Compare commits
13 Commits
terminal
...
add-morph-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e7019d98bf | ||
|
|
587d1cf720 | ||
|
|
4ece87efb0 | ||
|
|
96cff78335 | ||
|
|
58d5fa1e4c | ||
|
|
f4ff1f496b | ||
|
|
e1710378b7 | ||
|
|
bc71dffd4c | ||
|
|
ebb46ba0e6 | ||
|
|
3078053795 | ||
|
|
cde7e64418 | ||
|
|
bf4223f381 | ||
|
|
1dacd941f6 |
18
.gitignore
vendored
18
.gitignore
vendored
@@ -1,2 +1,18 @@
|
||||
/venv/
|
||||
/_pycache/
|
||||
/_pycache/
|
||||
hecate/
|
||||
hecate-lib/
|
||||
*.pyc*
|
||||
__pycache__/
|
||||
.venv/
|
||||
.vscode/
|
||||
.env
|
||||
.env.local
|
||||
.env.development.local
|
||||
.env.test.local
|
||||
.env.production.local
|
||||
.env.development
|
||||
.env.test
|
||||
export*
|
||||
__pycache__/model_tools.cpython-310.pyc
|
||||
__pycache__/web_tools.cpython-310.pyc
|
||||
|
||||
566
image_generation_tool.py
Normal file
566
image_generation_tool.py
Normal file
@@ -0,0 +1,566 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Image Generation Tools Module
|
||||
|
||||
This module provides image generation tools using FAL.ai's FLUX.1 Krea model with
|
||||
automatic upscaling via FAL.ai's Clarity Upscaler for enhanced image quality.
|
||||
|
||||
Available tools:
|
||||
- image_generate_tool: Generate images from text prompts with automatic upscaling
|
||||
|
||||
Features:
|
||||
- High-quality image generation using FLUX.1 Krea model
|
||||
- Automatic 2x upscaling using Clarity Upscaler for enhanced quality
|
||||
- Comprehensive parameter control (size, steps, guidance, etc.)
|
||||
- Proper error handling and validation with fallback to original images
|
||||
- Debug logging support
|
||||
- Sync mode for immediate results
|
||||
|
||||
Usage:
|
||||
from image_generation_tool import image_generate_tool
|
||||
import asyncio
|
||||
|
||||
# Generate and automatically upscale an image
|
||||
result = await image_generate_tool(
|
||||
prompt="A serene mountain landscape with cherry blossoms",
|
||||
image_size="landscape_4_3",
|
||||
num_images=1
|
||||
)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
import uuid
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, Union
|
||||
import fal_client
|
||||
|
||||
# Configuration for image generation
|
||||
DEFAULT_MODEL = "fal-ai/flux/krea"
|
||||
DEFAULT_IMAGE_SIZE = "landscape_4_3"
|
||||
DEFAULT_NUM_INFERENCE_STEPS = 50
|
||||
DEFAULT_GUIDANCE_SCALE = 4.5
|
||||
DEFAULT_NUM_IMAGES = 1
|
||||
DEFAULT_OUTPUT_FORMAT = "png"
|
||||
|
||||
# Configuration for automatic upscaling
|
||||
UPSCALER_MODEL = "fal-ai/clarity-upscaler"
|
||||
UPSCALER_FACTOR = 2
|
||||
UPSCALER_SAFETY_CHECKER = False
|
||||
UPSCALER_DEFAULT_PROMPT = "masterpiece, best quality, highres"
|
||||
UPSCALER_NEGATIVE_PROMPT = "(worst quality, low quality, normal quality:2)"
|
||||
UPSCALER_CREATIVITY = 0.35
|
||||
UPSCALER_RESEMBLANCE = 0.6
|
||||
UPSCALER_GUIDANCE_SCALE = 4
|
||||
UPSCALER_NUM_INFERENCE_STEPS = 18
|
||||
|
||||
# Valid parameter values for validation based on FLUX Krea documentation
|
||||
VALID_IMAGE_SIZES = [
|
||||
"square_hd", "square", "portrait_4_3", "portrait_16_9", "landscape_4_3", "landscape_16_9"
|
||||
]
|
||||
VALID_OUTPUT_FORMATS = ["jpeg", "png"]
|
||||
VALID_ACCELERATION_MODES = ["none", "regular", "high"]
|
||||
|
||||
# Debug mode configuration
|
||||
DEBUG_MODE = os.getenv("IMAGE_TOOLS_DEBUG", "false").lower() == "true"
|
||||
DEBUG_SESSION_ID = str(uuid.uuid4())
|
||||
DEBUG_LOG_PATH = Path("./logs")
|
||||
DEBUG_DATA = {
|
||||
"session_id": DEBUG_SESSION_ID,
|
||||
"start_time": datetime.datetime.now().isoformat(),
|
||||
"debug_enabled": DEBUG_MODE,
|
||||
"tool_calls": []
|
||||
} if DEBUG_MODE else None
|
||||
|
||||
# Create logs directory if debug mode is enabled
|
||||
if DEBUG_MODE:
|
||||
DEBUG_LOG_PATH.mkdir(exist_ok=True)
|
||||
print(f"🐛 Image generation debug mode enabled - Session ID: {DEBUG_SESSION_ID}")
|
||||
|
||||
|
||||
def _log_debug_call(tool_name: str, call_data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Log a debug call entry to the global debug data structure.
|
||||
|
||||
Args:
|
||||
tool_name (str): Name of the tool being called
|
||||
call_data (Dict[str, Any]): Data about the call including parameters and results
|
||||
"""
|
||||
if not DEBUG_MODE or not DEBUG_DATA:
|
||||
return
|
||||
|
||||
call_entry = {
|
||||
"timestamp": datetime.datetime.now().isoformat(),
|
||||
"tool_name": tool_name,
|
||||
**call_data
|
||||
}
|
||||
|
||||
DEBUG_DATA["tool_calls"].append(call_entry)
|
||||
|
||||
|
||||
def _save_debug_log() -> None:
|
||||
"""
|
||||
Save the current debug data to a JSON file in the logs directory.
|
||||
"""
|
||||
if not DEBUG_MODE or not DEBUG_DATA:
|
||||
return
|
||||
|
||||
try:
|
||||
debug_filename = f"image_tools_debug_{DEBUG_SESSION_ID}.json"
|
||||
debug_filepath = DEBUG_LOG_PATH / debug_filename
|
||||
|
||||
# Update end time
|
||||
DEBUG_DATA["end_time"] = datetime.datetime.now().isoformat()
|
||||
DEBUG_DATA["total_calls"] = len(DEBUG_DATA["tool_calls"])
|
||||
|
||||
with open(debug_filepath, 'w', encoding='utf-8') as f:
|
||||
json.dump(DEBUG_DATA, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"🐛 Image generation debug log saved: {debug_filepath}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error saving image generation debug log: {str(e)}")
|
||||
|
||||
|
||||
def _validate_parameters(
|
||||
image_size: Union[str, Dict[str, int]],
|
||||
num_inference_steps: int,
|
||||
guidance_scale: float,
|
||||
num_images: int,
|
||||
output_format: str,
|
||||
acceleration: str = "none"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate and normalize image generation parameters for FLUX Krea model.
|
||||
|
||||
Args:
|
||||
image_size: Either a preset string or custom size dict
|
||||
num_inference_steps: Number of inference steps
|
||||
guidance_scale: Guidance scale value
|
||||
num_images: Number of images to generate
|
||||
output_format: Output format for images
|
||||
acceleration: Acceleration mode for generation speed
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Validated and normalized parameters
|
||||
|
||||
Raises:
|
||||
ValueError: If any parameter is invalid
|
||||
"""
|
||||
validated = {}
|
||||
|
||||
# Validate image_size
|
||||
if isinstance(image_size, str):
|
||||
if image_size not in VALID_IMAGE_SIZES:
|
||||
raise ValueError(f"Invalid image_size '{image_size}'. Must be one of: {VALID_IMAGE_SIZES}")
|
||||
validated["image_size"] = image_size
|
||||
elif isinstance(image_size, dict):
|
||||
if "width" not in image_size or "height" not in image_size:
|
||||
raise ValueError("Custom image_size must contain 'width' and 'height' keys")
|
||||
if not isinstance(image_size["width"], int) or not isinstance(image_size["height"], int):
|
||||
raise ValueError("Custom image_size width and height must be integers")
|
||||
if image_size["width"] < 64 or image_size["height"] < 64:
|
||||
raise ValueError("Custom image_size dimensions must be at least 64x64")
|
||||
if image_size["width"] > 2048 or image_size["height"] > 2048:
|
||||
raise ValueError("Custom image_size dimensions must not exceed 2048x2048")
|
||||
validated["image_size"] = image_size
|
||||
else:
|
||||
raise ValueError("image_size must be either a preset string or a dict with width/height")
|
||||
|
||||
# Validate num_inference_steps
|
||||
if not isinstance(num_inference_steps, int) or num_inference_steps < 1 or num_inference_steps > 100:
|
||||
raise ValueError("num_inference_steps must be an integer between 1 and 100")
|
||||
validated["num_inference_steps"] = num_inference_steps
|
||||
|
||||
# Validate guidance_scale (FLUX Krea default is 4.5)
|
||||
if not isinstance(guidance_scale, (int, float)) or guidance_scale < 0.1 or guidance_scale > 20.0:
|
||||
raise ValueError("guidance_scale must be a number between 0.1 and 20.0")
|
||||
validated["guidance_scale"] = float(guidance_scale)
|
||||
|
||||
# Validate num_images
|
||||
if not isinstance(num_images, int) or num_images < 1 or num_images > 4:
|
||||
raise ValueError("num_images must be an integer between 1 and 4")
|
||||
validated["num_images"] = num_images
|
||||
|
||||
# Validate output_format
|
||||
if output_format not in VALID_OUTPUT_FORMATS:
|
||||
raise ValueError(f"Invalid output_format '{output_format}'. Must be one of: {VALID_OUTPUT_FORMATS}")
|
||||
validated["output_format"] = output_format
|
||||
|
||||
# Validate acceleration
|
||||
if acceleration not in VALID_ACCELERATION_MODES:
|
||||
raise ValueError(f"Invalid acceleration '{acceleration}'. Must be one of: {VALID_ACCELERATION_MODES}")
|
||||
validated["acceleration"] = acceleration
|
||||
|
||||
return validated
|
||||
|
||||
|
||||
async def _upscale_image(image_url: str, original_prompt: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Upscale an image using FAL.ai's Clarity Upscaler.
|
||||
|
||||
Args:
|
||||
image_url (str): URL of the image to upscale
|
||||
original_prompt (str): Original prompt used to generate the image
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Upscaled image data or None if upscaling fails
|
||||
"""
|
||||
try:
|
||||
print(f"🔍 Upscaling image with Clarity Upscaler...")
|
||||
|
||||
# Prepare arguments for upscaler
|
||||
upscaler_arguments = {
|
||||
"image_url": image_url,
|
||||
"prompt": f"{UPSCALER_DEFAULT_PROMPT}, {original_prompt}",
|
||||
"upscale_factor": UPSCALER_FACTOR,
|
||||
"negative_prompt": UPSCALER_NEGATIVE_PROMPT,
|
||||
"creativity": UPSCALER_CREATIVITY,
|
||||
"resemblance": UPSCALER_RESEMBLANCE,
|
||||
"guidance_scale": UPSCALER_GUIDANCE_SCALE,
|
||||
"num_inference_steps": UPSCALER_NUM_INFERENCE_STEPS,
|
||||
"enable_safety_checker": UPSCALER_SAFETY_CHECKER
|
||||
}
|
||||
|
||||
# Submit upscaler request
|
||||
handler = await fal_client.submit_async(
|
||||
UPSCALER_MODEL,
|
||||
arguments=upscaler_arguments
|
||||
)
|
||||
|
||||
# Get the upscaled result
|
||||
result = await handler.get()
|
||||
|
||||
if result and "image" in result:
|
||||
upscaled_image = result["image"]
|
||||
print(f"✅ Image upscaled successfully to {upscaled_image.get('width', 'unknown')}x{upscaled_image.get('height', 'unknown')}")
|
||||
return {
|
||||
"url": upscaled_image["url"],
|
||||
"width": upscaled_image.get("width", 0),
|
||||
"height": upscaled_image.get("height", 0),
|
||||
"upscaled": True,
|
||||
"upscale_factor": UPSCALER_FACTOR
|
||||
}
|
||||
else:
|
||||
print("❌ Upscaler returned invalid response")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error upscaling image: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def image_generate_tool(
|
||||
prompt: str,
|
||||
image_size: Union[str, Dict[str, int]] = DEFAULT_IMAGE_SIZE,
|
||||
num_inference_steps: int = DEFAULT_NUM_INFERENCE_STEPS,
|
||||
guidance_scale: float = DEFAULT_GUIDANCE_SCALE,
|
||||
num_images: int = DEFAULT_NUM_IMAGES,
|
||||
enable_safety_checker: bool = True,
|
||||
output_format: str = DEFAULT_OUTPUT_FORMAT,
|
||||
acceleration: str = "none",
|
||||
allow_nsfw_images: bool = True,
|
||||
seed: Optional[int] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate images from text prompts using FAL.ai's FLUX.1 Krea model with automatic upscaling.
|
||||
|
||||
This tool uses FAL.ai's FLUX.1 Krea model for high-quality text-to-image generation
|
||||
with extensive customization options. Generated images are automatically upscaled 2x
|
||||
using FAL.ai's Clarity Upscaler for enhanced quality. The final upscaled images are
|
||||
returned as URLs that can be displayed using <img src="{URL}"></img> tags.
|
||||
|
||||
Args:
|
||||
prompt (str): The text prompt describing the desired image
|
||||
image_size (Union[str, Dict[str, int]]): Preset size or custom {"width": int, "height": int}
|
||||
num_inference_steps (int): Number of denoising steps (1-50, default: 28)
|
||||
guidance_scale (float): How closely to follow prompt (0.1-20.0, default: 4.5)
|
||||
num_images (int): Number of images to generate (1-4, default: 1)
|
||||
enable_safety_checker (bool): Enable content safety filtering (default: True)
|
||||
output_format (str): Image format "jpeg" or "png" (default: "png")
|
||||
acceleration (str): Generation speed "none", "regular", or "high" (default: "none")
|
||||
allow_nsfw_images (bool): Allow generation of NSFW content (default: True)
|
||||
seed (Optional[int]): Random seed for reproducible results (optional)
|
||||
|
||||
Returns:
|
||||
str: JSON string containing minimal generation results:
|
||||
{
|
||||
"success": bool,
|
||||
"image": str or None # URL of the upscaled image, or None if failed
|
||||
}
|
||||
"""
|
||||
debug_call_data = {
|
||||
"parameters": {
|
||||
"prompt": prompt,
|
||||
"image_size": image_size,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"guidance_scale": guidance_scale,
|
||||
"num_images": num_images,
|
||||
"enable_safety_checker": enable_safety_checker,
|
||||
"output_format": output_format,
|
||||
"acceleration": acceleration,
|
||||
"allow_nsfw_images": allow_nsfw_images,
|
||||
"seed": seed
|
||||
},
|
||||
"error": None,
|
||||
"success": False,
|
||||
"images_generated": 0,
|
||||
"generation_time": 0
|
||||
}
|
||||
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
try:
|
||||
print(f"🎨 Generating {num_images} image(s) with FLUX Krea: {prompt[:80]}{'...' if len(prompt) > 80 else ''}")
|
||||
|
||||
# Validate prompt
|
||||
if not prompt or not isinstance(prompt, str) or len(prompt.strip()) == 0:
|
||||
raise ValueError("Prompt is required and must be a non-empty string")
|
||||
|
||||
if len(prompt) > 1000:
|
||||
raise ValueError("Prompt must be 1000 characters or less")
|
||||
|
||||
# Check API key availability
|
||||
if not os.getenv("FAL_KEY"):
|
||||
raise ValueError("FAL_KEY environment variable not set")
|
||||
|
||||
# Validate parameters
|
||||
validated_params = _validate_parameters(
|
||||
image_size, num_inference_steps, guidance_scale, num_images, output_format, acceleration
|
||||
)
|
||||
|
||||
# Prepare arguments for FAL.ai FLUX Krea API
|
||||
arguments = {
|
||||
"prompt": prompt.strip(),
|
||||
"image_size": validated_params["image_size"],
|
||||
"num_inference_steps": validated_params["num_inference_steps"],
|
||||
"guidance_scale": validated_params["guidance_scale"],
|
||||
"num_images": validated_params["num_images"],
|
||||
"enable_safety_checker": enable_safety_checker,
|
||||
"output_format": validated_params["output_format"],
|
||||
"acceleration": validated_params["acceleration"],
|
||||
"allow_nsfw_images": allow_nsfw_images,
|
||||
"sync_mode": True # Use sync mode for immediate results
|
||||
}
|
||||
|
||||
# Add seed if provided
|
||||
if seed is not None and isinstance(seed, int):
|
||||
arguments["seed"] = seed
|
||||
|
||||
print(f"🚀 Submitting generation request to FAL.ai FLUX Krea...")
|
||||
print(f" Model: {DEFAULT_MODEL}")
|
||||
print(f" Size: {validated_params['image_size']}")
|
||||
print(f" Steps: {validated_params['num_inference_steps']}")
|
||||
print(f" Guidance: {validated_params['guidance_scale']}")
|
||||
print(f" Acceleration: {validated_params['acceleration']}")
|
||||
|
||||
# Submit request to FAL.ai
|
||||
handler = await fal_client.submit_async(
|
||||
DEFAULT_MODEL,
|
||||
arguments=arguments
|
||||
)
|
||||
|
||||
# Get the result
|
||||
result = await handler.get()
|
||||
|
||||
generation_time = (datetime.datetime.now() - start_time).total_seconds()
|
||||
|
||||
# Process the response
|
||||
if not result or "images" not in result:
|
||||
raise ValueError("Invalid response from FAL.ai API - no images returned")
|
||||
|
||||
images = result.get("images", [])
|
||||
if not images:
|
||||
raise ValueError("No images were generated")
|
||||
|
||||
# Format image data and upscale images
|
||||
formatted_images = []
|
||||
for img in images:
|
||||
if isinstance(img, dict) and "url" in img:
|
||||
original_image = {
|
||||
"url": img["url"],
|
||||
"width": img.get("width", 0),
|
||||
"height": img.get("height", 0)
|
||||
}
|
||||
|
||||
# Attempt to upscale the image
|
||||
upscaled_image = await _upscale_image(img["url"], prompt.strip())
|
||||
|
||||
if upscaled_image:
|
||||
# Use upscaled image if successful
|
||||
formatted_images.append(upscaled_image)
|
||||
else:
|
||||
# Fall back to original image if upscaling fails
|
||||
print(f"⚠️ Using original image as fallback")
|
||||
original_image["upscaled"] = False
|
||||
formatted_images.append(original_image)
|
||||
|
||||
if not formatted_images:
|
||||
raise ValueError("No valid image URLs returned from API")
|
||||
|
||||
upscaled_count = sum(1 for img in formatted_images if img.get("upscaled", False))
|
||||
print(f"✅ Generated {len(formatted_images)} image(s) in {generation_time:.1f}s ({upscaled_count} upscaled)")
|
||||
|
||||
# Prepare successful response - minimal format
|
||||
response_data = {
|
||||
"success": True,
|
||||
"image": formatted_images[0]["url"] if formatted_images else None
|
||||
}
|
||||
|
||||
debug_call_data["success"] = True
|
||||
debug_call_data["images_generated"] = len(formatted_images)
|
||||
debug_call_data["generation_time"] = generation_time
|
||||
|
||||
# Log debug information
|
||||
_log_debug_call("image_generate_tool", debug_call_data)
|
||||
_save_debug_log()
|
||||
|
||||
return json.dumps(response_data, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
generation_time = (datetime.datetime.now() - start_time).total_seconds()
|
||||
error_msg = f"Error generating image: {str(e)}"
|
||||
print(f"❌ {error_msg}")
|
||||
|
||||
# Prepare error response - minimal format
|
||||
response_data = {
|
||||
"success": False,
|
||||
"image": None
|
||||
}
|
||||
|
||||
debug_call_data["error"] = error_msg
|
||||
debug_call_data["generation_time"] = generation_time
|
||||
_log_debug_call("image_generate_tool", debug_call_data)
|
||||
_save_debug_log()
|
||||
|
||||
return json.dumps(response_data, indent=2)
|
||||
|
||||
|
||||
def check_fal_api_key() -> bool:
|
||||
"""
|
||||
Check if the FAL.ai API key is available in environment variables.
|
||||
|
||||
Returns:
|
||||
bool: True if API key is set, False otherwise
|
||||
"""
|
||||
return bool(os.getenv("FAL_KEY"))
|
||||
|
||||
|
||||
def check_image_generation_requirements() -> bool:
|
||||
"""
|
||||
Check if all requirements for image generation tools are met.
|
||||
|
||||
Returns:
|
||||
bool: True if requirements are met, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Check API key
|
||||
if not check_fal_api_key():
|
||||
return False
|
||||
|
||||
# Check if fal_client is available
|
||||
import fal_client
|
||||
return True
|
||||
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def get_debug_session_info() -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about the current debug session.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing debug session information
|
||||
"""
|
||||
if not DEBUG_MODE or not DEBUG_DATA:
|
||||
return {
|
||||
"enabled": False,
|
||||
"session_id": None,
|
||||
"log_path": None,
|
||||
"total_calls": 0
|
||||
}
|
||||
|
||||
return {
|
||||
"enabled": True,
|
||||
"session_id": DEBUG_SESSION_ID,
|
||||
"log_path": str(DEBUG_LOG_PATH / f"image_tools_debug_{DEBUG_SESSION_ID}.json"),
|
||||
"total_calls": len(DEBUG_DATA["tool_calls"])
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
Simple test/demo when run directly
|
||||
"""
|
||||
print("🎨 Image Generation Tools Module - FLUX.1 Krea + Auto Upscaling")
|
||||
print("=" * 60)
|
||||
|
||||
# Check if API key is available
|
||||
api_available = check_fal_api_key()
|
||||
|
||||
if not api_available:
|
||||
print("❌ FAL_KEY environment variable not set")
|
||||
print("Please set your API key: export FAL_KEY='your-key-here'")
|
||||
print("Get API key at: https://fal.ai/")
|
||||
exit(1)
|
||||
else:
|
||||
print("✅ FAL.ai API key found")
|
||||
|
||||
# Check if fal_client is available
|
||||
try:
|
||||
import fal_client
|
||||
print("✅ fal_client library available")
|
||||
except ImportError:
|
||||
print("❌ fal_client library not found")
|
||||
print("Please install: pip install fal-client")
|
||||
exit(1)
|
||||
|
||||
print("🛠️ Image generation tools ready for use!")
|
||||
print(f"🤖 Using model: {DEFAULT_MODEL}")
|
||||
print(f"🔍 Auto-upscaling with: {UPSCALER_MODEL} ({UPSCALER_FACTOR}x)")
|
||||
|
||||
# Show debug mode status
|
||||
if DEBUG_MODE:
|
||||
print(f"🐛 Debug mode ENABLED - Session ID: {DEBUG_SESSION_ID}")
|
||||
print(f" Debug logs will be saved to: ./logs/image_tools_debug_{DEBUG_SESSION_ID}.json")
|
||||
else:
|
||||
print("🐛 Debug mode disabled (set IMAGE_TOOLS_DEBUG=true to enable)")
|
||||
|
||||
print("\nBasic usage:")
|
||||
print(" from image_generation_tool import image_generate_tool")
|
||||
print(" import asyncio")
|
||||
print("")
|
||||
print(" async def main():")
|
||||
print(" # Generate image with automatic 2x upscaling")
|
||||
print(" result = await image_generate_tool(")
|
||||
print(" prompt='A serene mountain landscape with cherry blossoms',")
|
||||
print(" image_size='landscape_4_3',")
|
||||
print(" num_images=1")
|
||||
print(" )")
|
||||
print(" print(result)")
|
||||
print(" asyncio.run(main())")
|
||||
|
||||
print("\nSupported image sizes:")
|
||||
for size in VALID_IMAGE_SIZES:
|
||||
print(f" - {size}")
|
||||
print(" - Custom: {'width': 512, 'height': 768} (if needed)")
|
||||
|
||||
print("\nAcceleration modes:")
|
||||
for mode in VALID_ACCELERATION_MODES:
|
||||
print(f" - {mode}")
|
||||
|
||||
print("\nExample prompts:")
|
||||
print(" - 'A candid street photo of a woman with a pink bob and bold eyeliner'")
|
||||
print(" - 'Modern architecture building with glass facade, sunset lighting'")
|
||||
print(" - 'Abstract art with vibrant colors and geometric patterns'")
|
||||
print(" - 'Portrait of a wise old owl perched on ancient tree branch'")
|
||||
print(" - 'Futuristic cityscape with flying cars and neon lights'")
|
||||
|
||||
print("\nDebug mode:")
|
||||
print(" # Enable debug logging")
|
||||
print(" export IMAGE_TOOLS_DEBUG=true")
|
||||
print(" # Debug logs capture all image generation calls and results")
|
||||
print(" # Logs saved to: ./logs/image_tools_debug_UUID.json")
|
||||
586
mixture_of_agents_tool.py
Normal file
586
mixture_of_agents_tool.py
Normal file
@@ -0,0 +1,586 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Mixture-of-Agents Tool Module
|
||||
|
||||
This module implements the Mixture-of-Agents (MoA) methodology that leverages
|
||||
the collective strengths of multiple LLMs through a layered architecture to
|
||||
achieve state-of-the-art performance on complex reasoning tasks.
|
||||
|
||||
Based on the research paper: "Mixture-of-Agents Enhances Large Language Model Capabilities"
|
||||
by Junlin Wang et al. (arXiv:2406.04692v1)
|
||||
|
||||
Key Features:
|
||||
- Multi-layer LLM collaboration for enhanced reasoning
|
||||
- Parallel processing of reference models for efficiency
|
||||
- Intelligent aggregation and synthesis of diverse responses
|
||||
- Specialized for extremely difficult problems requiring intense reasoning
|
||||
- Optimized for coding, mathematics, and complex analytical tasks
|
||||
|
||||
Available Tool:
|
||||
- mixture_of_agents_tool: Process complex queries using multiple frontier models
|
||||
|
||||
Architecture:
|
||||
1. Reference models generate diverse initial responses in parallel
|
||||
2. Aggregator model synthesizes responses into a high-quality output
|
||||
3. Multiple layers can be used for iterative refinement (future enhancement)
|
||||
|
||||
Models Used:
|
||||
- Reference Models: claude-opus-4-20250514, gemini-2.5-pro, o4-mini, deepseek-r1
|
||||
- Aggregator Model: claude-opus-4-20250514 (highest capability for synthesis)
|
||||
|
||||
Configuration:
|
||||
To customize the MoA setup, modify the configuration constants at the top of this file:
|
||||
- REFERENCE_MODELS: List of models for generating diverse initial responses
|
||||
- AGGREGATOR_MODEL: Model used to synthesize the final response
|
||||
- REFERENCE_TEMPERATURE/AGGREGATOR_TEMPERATURE: Sampling temperatures
|
||||
- MIN_SUCCESSFUL_REFERENCES: Minimum successful models needed to proceed
|
||||
|
||||
Usage:
|
||||
from mixture_of_agents_tool import mixture_of_agents_tool
|
||||
import asyncio
|
||||
|
||||
# Process a complex query
|
||||
result = await mixture_of_agents_tool(
|
||||
user_prompt="Solve this complex mathematical proof..."
|
||||
)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
import uuid
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Initialize Nous Research API client for MoA processing
|
||||
nous_client = AsyncOpenAI(
|
||||
api_key=os.getenv("NOUS_API_KEY"),
|
||||
base_url="https://inference-api.nousresearch.com/v1"
|
||||
)
|
||||
|
||||
# Configuration for MoA processing
|
||||
# Reference models - these generate diverse initial responses in parallel
|
||||
REFERENCE_MODELS = [
|
||||
"claude-opus-4-20250514",
|
||||
"gemini-2.5-pro",
|
||||
"gpt-5",
|
||||
"deepseek-r1"
|
||||
]
|
||||
|
||||
# Aggregator model - synthesizes reference responses into final output
|
||||
AGGREGATOR_MODEL = "claude-opus-4-20250514" # Use highest capability model for aggregation
|
||||
|
||||
# Temperature settings optimized for MoA performance
|
||||
REFERENCE_TEMPERATURE = 0.6 # Balanced creativity for diverse perspectives
|
||||
AGGREGATOR_TEMPERATURE = 0.4 # Focused synthesis for consistency
|
||||
|
||||
# Failure handling configuration
|
||||
MIN_SUCCESSFUL_REFERENCES = 1 # Minimum successful reference models needed to proceed
|
||||
|
||||
# System prompt for the aggregator model (from the research paper)
|
||||
AGGREGATOR_SYSTEM_PROMPT = """You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.
|
||||
|
||||
Responses from models:"""
|
||||
|
||||
# Debug mode configuration
|
||||
DEBUG_MODE = os.getenv("MOA_TOOLS_DEBUG", "false").lower() == "true"
|
||||
DEBUG_SESSION_ID = str(uuid.uuid4())
|
||||
DEBUG_LOG_PATH = Path("./logs")
|
||||
DEBUG_DATA = {
|
||||
"session_id": DEBUG_SESSION_ID,
|
||||
"start_time": datetime.datetime.now().isoformat(),
|
||||
"debug_enabled": DEBUG_MODE,
|
||||
"tool_calls": []
|
||||
} if DEBUG_MODE else None
|
||||
|
||||
# Create logs directory if debug mode is enabled
|
||||
if DEBUG_MODE:
|
||||
DEBUG_LOG_PATH.mkdir(exist_ok=True)
|
||||
print(f"🐛 MoA debug mode enabled - Session ID: {DEBUG_SESSION_ID}")
|
||||
|
||||
|
||||
def _log_debug_call(tool_name: str, call_data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Log a debug call entry to the global debug data structure.
|
||||
|
||||
Args:
|
||||
tool_name (str): Name of the tool being called
|
||||
call_data (Dict[str, Any]): Data about the call including parameters and results
|
||||
"""
|
||||
if not DEBUG_MODE or not DEBUG_DATA:
|
||||
return
|
||||
|
||||
call_entry = {
|
||||
"timestamp": datetime.datetime.now().isoformat(),
|
||||
"tool_name": tool_name,
|
||||
**call_data
|
||||
}
|
||||
|
||||
DEBUG_DATA["tool_calls"].append(call_entry)
|
||||
|
||||
|
||||
def _save_debug_log() -> None:
|
||||
"""
|
||||
Save the current debug data to a JSON file in the logs directory.
|
||||
"""
|
||||
if not DEBUG_MODE or not DEBUG_DATA:
|
||||
return
|
||||
|
||||
try:
|
||||
debug_filename = f"moa_tools_debug_{DEBUG_SESSION_ID}.json"
|
||||
debug_filepath = DEBUG_LOG_PATH / debug_filename
|
||||
|
||||
# Update end time
|
||||
DEBUG_DATA["end_time"] = datetime.datetime.now().isoformat()
|
||||
DEBUG_DATA["total_calls"] = len(DEBUG_DATA["tool_calls"])
|
||||
|
||||
with open(debug_filepath, 'w', encoding='utf-8') as f:
|
||||
json.dump(DEBUG_DATA, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"🐛 MoA debug log saved: {debug_filepath}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error saving MoA debug log: {str(e)}")
|
||||
|
||||
|
||||
def _construct_aggregator_prompt(system_prompt: str, responses: List[str]) -> str:
|
||||
"""
|
||||
Construct the final system prompt for the aggregator including all model responses.
|
||||
|
||||
Args:
|
||||
system_prompt (str): Base system prompt for aggregation
|
||||
responses (List[str]): List of responses from reference models
|
||||
|
||||
Returns:
|
||||
str: Complete system prompt with enumerated responses
|
||||
"""
|
||||
response_text = "\n".join([f"{i+1}. {response}" for i, response in enumerate(responses)])
|
||||
return f"{system_prompt}\n\n{response_text}"
|
||||
|
||||
|
||||
async def _run_reference_model_safe(
|
||||
model: str,
|
||||
user_prompt: str,
|
||||
temperature: float = REFERENCE_TEMPERATURE,
|
||||
max_tokens: int = 32000,
|
||||
max_retries: int = 3
|
||||
) -> tuple[str, str, bool]:
|
||||
"""
|
||||
Run a single reference model with retry logic and graceful failure handling.
|
||||
|
||||
Args:
|
||||
model (str): Model identifier to use
|
||||
user_prompt (str): The user's query
|
||||
temperature (float): Sampling temperature for response generation
|
||||
max_tokens (int): Maximum tokens in response
|
||||
max_retries (int): Maximum number of retry attempts
|
||||
|
||||
Returns:
|
||||
tuple[str, str, bool]: (model_name, response_content_or_error, success_flag)
|
||||
"""
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
print(f"🤖 Querying {model} (attempt {attempt + 1}/{max_retries})")
|
||||
|
||||
# Build parameters for the API call
|
||||
api_params = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": user_prompt}]
|
||||
}
|
||||
|
||||
# GPT models (especially gpt-4o-mini) don't support custom temperature values
|
||||
# Only include temperature for non-GPT models
|
||||
if not model.lower().startswith('gpt-'):
|
||||
api_params["temperature"] = temperature
|
||||
|
||||
response = await nous_client.chat.completions.create(**api_params)
|
||||
|
||||
content = response.choices[0].message.content.strip()
|
||||
print(f"✅ {model} responded ({len(content)} characters)")
|
||||
return model, content, True
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e)
|
||||
# Log more detailed error information for debugging
|
||||
if "invalid" in error_str.lower():
|
||||
print(f"⚠️ {model} invalid request error (attempt {attempt + 1}): {error_str}")
|
||||
elif "rate" in error_str.lower() or "limit" in error_str.lower():
|
||||
print(f"⚠️ {model} rate limit error (attempt {attempt + 1}): {error_str}")
|
||||
else:
|
||||
print(f"⚠️ {model} unknown error (attempt {attempt + 1}): {error_str}")
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
# Exponential backoff for rate limiting
|
||||
sleep_time = 2 ** attempt
|
||||
print(f" Retrying in {sleep_time}s...")
|
||||
await asyncio.sleep(sleep_time)
|
||||
else:
|
||||
error_msg = f"{model} failed after {max_retries} attempts: {error_str}"
|
||||
print(f"❌ {error_msg}")
|
||||
return model, error_msg, False
|
||||
|
||||
|
||||
async def _run_aggregator_model(
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
temperature: float = AGGREGATOR_TEMPERATURE,
|
||||
max_tokens: int = None
|
||||
) -> str:
|
||||
"""
|
||||
Run the aggregator model to synthesize the final response.
|
||||
|
||||
Args:
|
||||
system_prompt (str): System prompt with all reference responses
|
||||
user_prompt (str): Original user query
|
||||
temperature (float): Focused temperature for consistent aggregation
|
||||
max_tokens (int): Maximum tokens in final response
|
||||
|
||||
Returns:
|
||||
str: Synthesized final response
|
||||
"""
|
||||
print(f"🧠 Running aggregator model: {AGGREGATOR_MODEL}")
|
||||
|
||||
# Build parameters for the API call
|
||||
api_params = {
|
||||
"model": AGGREGATOR_MODEL,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
}
|
||||
|
||||
# GPT models (especially gpt-4o-mini) don't support custom temperature values
|
||||
# Only include temperature for non-GPT models
|
||||
if not AGGREGATOR_MODEL.lower().startswith('gpt-'):
|
||||
api_params["temperature"] = temperature
|
||||
|
||||
response = await nous_client.chat.completions.create(**api_params)
|
||||
|
||||
content = response.choices[0].message.content.strip()
|
||||
print(f"✅ Aggregation complete ({len(content)} characters)")
|
||||
return content
|
||||
|
||||
|
||||
async def mixture_of_agents_tool(
|
||||
user_prompt: str,
|
||||
reference_models: Optional[List[str]] = None,
|
||||
aggregator_model: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Process a complex query using the Mixture-of-Agents methodology.
|
||||
|
||||
This tool leverages multiple frontier language models to collaboratively solve
|
||||
extremely difficult problems requiring intense reasoning. It's particularly
|
||||
effective for:
|
||||
- Complex mathematical proofs and calculations
|
||||
- Advanced coding problems and algorithm design
|
||||
- Multi-step analytical reasoning tasks
|
||||
- Problems requiring diverse domain expertise
|
||||
- Tasks where single models show limitations
|
||||
|
||||
The MoA approach uses a fixed 2-layer architecture:
|
||||
1. Layer 1: Multiple reference models generate diverse responses in parallel (temp=0.6)
|
||||
2. Layer 2: Aggregator model synthesizes the best elements into final response (temp=0.4)
|
||||
|
||||
Args:
|
||||
user_prompt (str): The complex query or problem to solve
|
||||
reference_models (Optional[List[str]]): Custom reference models to use
|
||||
aggregator_model (Optional[str]): Custom aggregator model to use
|
||||
|
||||
Returns:
|
||||
str: JSON string containing the MoA results with the following structure:
|
||||
{
|
||||
"success": bool,
|
||||
"response": str,
|
||||
"models_used": {
|
||||
"reference_models": List[str],
|
||||
"aggregator_model": str
|
||||
},
|
||||
"processing_time": float
|
||||
}
|
||||
|
||||
Raises:
|
||||
Exception: If MoA processing fails or API key is not set
|
||||
"""
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
debug_call_data = {
|
||||
"parameters": {
|
||||
"user_prompt": user_prompt[:200] + "..." if len(user_prompt) > 200 else user_prompt,
|
||||
"reference_models": reference_models or REFERENCE_MODELS,
|
||||
"aggregator_model": aggregator_model or AGGREGATOR_MODEL,
|
||||
"reference_temperature": REFERENCE_TEMPERATURE,
|
||||
"aggregator_temperature": AGGREGATOR_TEMPERATURE,
|
||||
"min_successful_references": MIN_SUCCESSFUL_REFERENCES
|
||||
},
|
||||
"error": None,
|
||||
"success": False,
|
||||
"reference_responses_count": 0,
|
||||
"failed_models_count": 0,
|
||||
"failed_models": [],
|
||||
"final_response_length": 0,
|
||||
"processing_time_seconds": 0,
|
||||
"models_used": {}
|
||||
}
|
||||
|
||||
try:
|
||||
print(f"🚀 Starting Mixture-of-Agents processing...")
|
||||
print(f"📝 Query: {user_prompt[:100]}{'...' if len(user_prompt) > 100 else ''}")
|
||||
|
||||
# Validate API key availability
|
||||
if not os.getenv("NOUS_API_KEY"):
|
||||
raise ValueError("NOUS_API_KEY environment variable not set")
|
||||
|
||||
# Use provided models or defaults
|
||||
ref_models = reference_models or REFERENCE_MODELS
|
||||
agg_model = aggregator_model or AGGREGATOR_MODEL
|
||||
|
||||
print(f"🔄 Using {len(ref_models)} reference models in 2-layer MoA architecture")
|
||||
|
||||
# Layer 1: Generate diverse responses from reference models (with failure handling)
|
||||
print("📡 Layer 1: Generating reference responses...")
|
||||
model_results = await asyncio.gather(*[
|
||||
_run_reference_model_safe(model, user_prompt, REFERENCE_TEMPERATURE)
|
||||
for model in ref_models
|
||||
])
|
||||
|
||||
# Separate successful and failed responses
|
||||
successful_responses = []
|
||||
failed_models = []
|
||||
|
||||
for model_name, content, success in model_results:
|
||||
if success:
|
||||
successful_responses.append(content)
|
||||
else:
|
||||
failed_models.append(model_name)
|
||||
|
||||
successful_count = len(successful_responses)
|
||||
failed_count = len(failed_models)
|
||||
|
||||
print(f"📊 Reference model results: {successful_count} successful, {failed_count} failed")
|
||||
|
||||
if failed_models:
|
||||
print(f"⚠️ Failed models: {', '.join(failed_models)}")
|
||||
|
||||
# Check if we have enough successful responses to proceed
|
||||
if successful_count < MIN_SUCCESSFUL_REFERENCES:
|
||||
raise ValueError(f"Insufficient successful reference models ({successful_count}/{len(ref_models)}). Need at least {MIN_SUCCESSFUL_REFERENCES} successful responses.")
|
||||
|
||||
debug_call_data["reference_responses_count"] = successful_count
|
||||
debug_call_data["failed_models_count"] = failed_count
|
||||
debug_call_data["failed_models"] = failed_models
|
||||
|
||||
# Layer 2: Aggregate responses using the aggregator model
|
||||
print("🧠 Layer 2: Synthesizing final response...")
|
||||
aggregator_system_prompt = _construct_aggregator_prompt(
|
||||
AGGREGATOR_SYSTEM_PROMPT,
|
||||
successful_responses
|
||||
)
|
||||
|
||||
final_response = await _run_aggregator_model(
|
||||
aggregator_system_prompt,
|
||||
user_prompt,
|
||||
AGGREGATOR_TEMPERATURE
|
||||
)
|
||||
|
||||
# Calculate processing time
|
||||
end_time = datetime.datetime.now()
|
||||
processing_time = (end_time - start_time).total_seconds()
|
||||
|
||||
print(f"✅ MoA processing completed in {processing_time:.2f} seconds")
|
||||
|
||||
# Prepare successful response (only final aggregated result, minimal fields)
|
||||
result = {
|
||||
"success": True,
|
||||
"response": final_response,
|
||||
"models_used": {
|
||||
"reference_models": ref_models,
|
||||
"aggregator_model": agg_model
|
||||
}
|
||||
}
|
||||
|
||||
debug_call_data["success"] = True
|
||||
debug_call_data["final_response_length"] = len(final_response)
|
||||
debug_call_data["processing_time_seconds"] = processing_time
|
||||
debug_call_data["models_used"] = result["models_used"]
|
||||
|
||||
# Log debug information
|
||||
_log_debug_call("mixture_of_agents_tool", debug_call_data)
|
||||
_save_debug_log()
|
||||
|
||||
return json.dumps(result, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error in MoA processing: {str(e)}"
|
||||
print(f"❌ {error_msg}")
|
||||
|
||||
# Calculate processing time even for errors
|
||||
end_time = datetime.datetime.now()
|
||||
processing_time = (end_time - start_time).total_seconds()
|
||||
|
||||
# Prepare error response (minimal fields)
|
||||
result = {
|
||||
"success": False,
|
||||
"response": "MoA processing failed. Please try again or use a single model for this query.",
|
||||
"models_used": {
|
||||
"reference_models": reference_models or REFERENCE_MODELS,
|
||||
"aggregator_model": aggregator_model or AGGREGATOR_MODEL
|
||||
},
|
||||
"error": error_msg
|
||||
}
|
||||
|
||||
debug_call_data["error"] = error_msg
|
||||
debug_call_data["processing_time_seconds"] = processing_time
|
||||
_log_debug_call("mixture_of_agents_tool", debug_call_data)
|
||||
_save_debug_log()
|
||||
|
||||
return json.dumps(result, indent=2)
|
||||
|
||||
|
||||
def check_nous_api_key() -> bool:
|
||||
"""
|
||||
Check if the Nous Research API key is available in environment variables.
|
||||
|
||||
Returns:
|
||||
bool: True if API key is set, False otherwise
|
||||
"""
|
||||
return bool(os.getenv("NOUS_API_KEY"))
|
||||
|
||||
|
||||
def check_moa_requirements() -> bool:
|
||||
"""
|
||||
Check if all requirements for MoA tools are met.
|
||||
|
||||
Returns:
|
||||
bool: True if requirements are met, False otherwise
|
||||
"""
|
||||
return check_nous_api_key()
|
||||
|
||||
|
||||
def get_debug_session_info() -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about the current debug session.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing debug session information
|
||||
"""
|
||||
if not DEBUG_MODE or not DEBUG_DATA:
|
||||
return {
|
||||
"enabled": False,
|
||||
"session_id": None,
|
||||
"log_path": None,
|
||||
"total_calls": 0
|
||||
}
|
||||
|
||||
return {
|
||||
"enabled": True,
|
||||
"session_id": DEBUG_SESSION_ID,
|
||||
"log_path": str(DEBUG_LOG_PATH / f"moa_tools_debug_{DEBUG_SESSION_ID}.json"),
|
||||
"total_calls": len(DEBUG_DATA["tool_calls"])
|
||||
}
|
||||
|
||||
|
||||
def get_available_models() -> Dict[str, List[str]]:
|
||||
"""
|
||||
Get information about available models for MoA processing.
|
||||
|
||||
Returns:
|
||||
Dict[str, List[str]]: Dictionary with reference and aggregator models
|
||||
"""
|
||||
return {
|
||||
"reference_models": REFERENCE_MODELS,
|
||||
"aggregator_models": [AGGREGATOR_MODEL],
|
||||
"supported_models": REFERENCE_MODELS + [AGGREGATOR_MODEL]
|
||||
}
|
||||
|
||||
|
||||
def get_moa_configuration() -> Dict[str, Any]:
|
||||
"""
|
||||
Get the current MoA configuration settings.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing all configuration parameters
|
||||
"""
|
||||
return {
|
||||
"reference_models": REFERENCE_MODELS,
|
||||
"aggregator_model": AGGREGATOR_MODEL,
|
||||
"reference_temperature": REFERENCE_TEMPERATURE,
|
||||
"aggregator_temperature": AGGREGATOR_TEMPERATURE,
|
||||
"min_successful_references": MIN_SUCCESSFUL_REFERENCES,
|
||||
"total_reference_models": len(REFERENCE_MODELS),
|
||||
"failure_tolerance": f"{len(REFERENCE_MODELS) - MIN_SUCCESSFUL_REFERENCES}/{len(REFERENCE_MODELS)} models can fail"
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
Simple test/demo when run directly
|
||||
"""
|
||||
print("🤖 Mixture-of-Agents Tool Module")
|
||||
print("=" * 50)
|
||||
|
||||
# Check if API key is available
|
||||
api_available = check_nous_api_key()
|
||||
|
||||
if not api_available:
|
||||
print("❌ NOUS_API_KEY environment variable not set")
|
||||
print("Please set your API key: export NOUS_API_KEY='your-key-here'")
|
||||
print("Get API key at: https://inference-api.nousresearch.com/")
|
||||
exit(1)
|
||||
else:
|
||||
print("✅ Nous Research API key found")
|
||||
|
||||
print("🛠️ MoA tools ready for use!")
|
||||
|
||||
# Show current configuration
|
||||
config = get_moa_configuration()
|
||||
print(f"\n⚙️ Current Configuration:")
|
||||
print(f" 🤖 Reference models ({len(config['reference_models'])}): {', '.join(config['reference_models'])}")
|
||||
print(f" 🧠 Aggregator model: {config['aggregator_model']}")
|
||||
print(f" 🌡️ Reference temperature: {config['reference_temperature']}")
|
||||
print(f" 🌡️ Aggregator temperature: {config['aggregator_temperature']}")
|
||||
print(f" 🛡️ Failure tolerance: {config['failure_tolerance']}")
|
||||
print(f" 📊 Minimum successful models: {config['min_successful_references']}")
|
||||
|
||||
# Show debug mode status
|
||||
if DEBUG_MODE:
|
||||
print(f"\n🐛 Debug mode ENABLED - Session ID: {DEBUG_SESSION_ID}")
|
||||
print(f" Debug logs will be saved to: ./logs/moa_tools_debug_{DEBUG_SESSION_ID}.json")
|
||||
else:
|
||||
print("\n🐛 Debug mode disabled (set MOA_TOOLS_DEBUG=true to enable)")
|
||||
|
||||
print("\nBasic usage:")
|
||||
print(" from mixture_of_agents_tool import mixture_of_agents_tool")
|
||||
print(" import asyncio")
|
||||
print("")
|
||||
print(" async def main():")
|
||||
print(" result = await mixture_of_agents_tool(")
|
||||
print(" user_prompt='Solve this complex mathematical proof...'")
|
||||
print(" )")
|
||||
print(" print(result)")
|
||||
print(" asyncio.run(main())")
|
||||
|
||||
print("\nBest use cases:")
|
||||
print(" - Complex mathematical proofs and calculations")
|
||||
print(" - Advanced coding problems and algorithm design")
|
||||
print(" - Multi-step analytical reasoning tasks")
|
||||
print(" - Problems requiring diverse domain expertise")
|
||||
print(" - Tasks where single models show limitations")
|
||||
|
||||
print("\nPerformance characteristics:")
|
||||
print(" - Higher latency due to multiple model calls")
|
||||
print(" - Significantly improved quality for complex tasks")
|
||||
print(" - Parallel processing for efficiency")
|
||||
print(f" - Optimized temperatures: {REFERENCE_TEMPERATURE} for reference models, {AGGREGATOR_TEMPERATURE} for aggregation")
|
||||
print(" - Token-efficient: only returns final aggregated response")
|
||||
print(" - Resilient: continues with partial model failures")
|
||||
print(f" - Configurable: easy to modify models and settings at top of file")
|
||||
print(" - State-of-the-art results on challenging benchmarks")
|
||||
|
||||
print("\nDebug mode:")
|
||||
print(" # Enable debug logging")
|
||||
print(" export MOA_TOOLS_DEBUG=true")
|
||||
print(" # Debug logs capture all MoA processing steps and metrics")
|
||||
print(" # Logs saved to: ./logs/moa_tools_debug_UUID.json")
|
||||
1157
model_tools.py
1157
model_tools.py
File diff suppressed because it is too large
Load Diff
@@ -1,2 +1,3 @@
|
||||
tavily-python
|
||||
openai
|
||||
firecrawl-py
|
||||
openai
|
||||
fal-client
|
||||
347
run_agent.py
347
run_agent.py
@@ -26,6 +26,7 @@ import time
|
||||
from typing import List, Dict, Any, Optional
|
||||
from openai import OpenAI
|
||||
import fire
|
||||
from datetime import datetime
|
||||
|
||||
# Import our tool system
|
||||
from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements
|
||||
@@ -45,7 +46,13 @@ class AIAgent:
|
||||
api_key: str = None,
|
||||
model: str = "gpt-4",
|
||||
max_iterations: int = 10,
|
||||
tool_delay: float = 1.0
|
||||
tool_delay: float = 1.0,
|
||||
enabled_tools: List[str] = None,
|
||||
disabled_tools: List[str] = None,
|
||||
enabled_toolsets: List[str] = None,
|
||||
disabled_toolsets: List[str] = None,
|
||||
save_trajectories: bool = False,
|
||||
morph_snapshot_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the AI Agent.
|
||||
@@ -56,10 +63,24 @@ class AIAgent:
|
||||
model (str): Model name to use (default: "gpt-4")
|
||||
max_iterations (int): Maximum number of tool calling iterations (default: 10)
|
||||
tool_delay (float): Delay between tool calls in seconds (default: 1.0)
|
||||
enabled_tools (List[str]): Only enable these specific tools (optional)
|
||||
disabled_tools (List[str]): Disable these specific tools (optional)
|
||||
enabled_toolsets (List[str]): Only enable tools from these toolsets (optional)
|
||||
disabled_toolsets (List[str]): Disable tools from these toolsets (optional)
|
||||
save_trajectories (bool): Whether to save conversation trajectories to JSONL files (default: False)
|
||||
morph_snapshot_id (str | None): Morph Cloud snapshot id from which to start terminal tool
|
||||
"""
|
||||
self.model = model
|
||||
self.max_iterations = max_iterations
|
||||
self.tool_delay = tool_delay
|
||||
self.save_trajectories = save_trajectories
|
||||
self.morph_snapshot_id = morph_snapshot_id
|
||||
|
||||
# Store tool filtering options
|
||||
self.enabled_tools = enabled_tools
|
||||
self.disabled_tools = disabled_tools
|
||||
self.enabled_toolsets = enabled_toolsets
|
||||
self.disabled_toolsets = disabled_toolsets
|
||||
|
||||
# Initialize OpenAI client
|
||||
client_kwargs = {}
|
||||
@@ -78,40 +99,215 @@ class AIAgent:
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize OpenAI client: {e}")
|
||||
|
||||
# Get available tools
|
||||
self.tools = get_tool_definitions()
|
||||
print(f"🛠️ Loaded {len(self.tools)} tools")
|
||||
# Get available tools with filtering
|
||||
self.tools = get_tool_definitions(
|
||||
enabled_tools=enabled_tools,
|
||||
disabled_tools=disabled_tools,
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
disabled_toolsets=disabled_toolsets
|
||||
)
|
||||
|
||||
# Show tool configuration
|
||||
if self.tools:
|
||||
tool_names = [tool["function"]["name"] for tool in self.tools]
|
||||
print(f"🛠️ Loaded {len(self.tools)} tools: {', '.join(tool_names)}")
|
||||
|
||||
# Show filtering info if applied
|
||||
if enabled_tools:
|
||||
print(f" ✅ Enabled tools: {', '.join(enabled_tools)}")
|
||||
if disabled_tools:
|
||||
print(f" ❌ Disabled tools: {', '.join(disabled_tools)}")
|
||||
if enabled_toolsets:
|
||||
print(f" ✅ Enabled toolsets: {', '.join(enabled_toolsets)}")
|
||||
if disabled_toolsets:
|
||||
print(f" ❌ Disabled toolsets: {', '.join(disabled_toolsets)}")
|
||||
else:
|
||||
print("🛠️ No tools loaded (all tools filtered out or unavailable)")
|
||||
|
||||
# Check tool requirements
|
||||
requirements = check_toolset_requirements()
|
||||
missing_reqs = [name for name, available in requirements.items() if not available]
|
||||
if missing_reqs:
|
||||
print(f"⚠️ Some tools may not work due to missing requirements: {missing_reqs}")
|
||||
if self.tools:
|
||||
requirements = check_toolset_requirements()
|
||||
missing_reqs = [name for name, available in requirements.items() if not available]
|
||||
if missing_reqs:
|
||||
print(f"⚠️ Some tools may not work due to missing requirements: {missing_reqs}")
|
||||
|
||||
# Show trajectory saving status
|
||||
if self.save_trajectories:
|
||||
print("📝 Trajectory saving enabled")
|
||||
|
||||
def create_system_message(self, custom_system: str = None) -> str:
|
||||
def _format_tools_for_system_message(self) -> str:
|
||||
"""
|
||||
Create the system message for the agent.
|
||||
Format tool definitions for the system message in the trajectory format.
|
||||
|
||||
Returns:
|
||||
str: JSON string representation of tool definitions
|
||||
"""
|
||||
if not self.tools:
|
||||
return "[]"
|
||||
|
||||
# Convert tool definitions to the format expected in trajectories
|
||||
formatted_tools = []
|
||||
for tool in self.tools:
|
||||
func = tool["function"]
|
||||
formatted_tool = {
|
||||
"name": func["name"],
|
||||
"description": func.get("description", ""),
|
||||
"parameters": func.get("parameters", {}),
|
||||
"required": None # Match the format in the example
|
||||
}
|
||||
formatted_tools.append(formatted_tool)
|
||||
|
||||
return json.dumps(formatted_tools)
|
||||
|
||||
def _convert_to_trajectory_format(self, messages: List[Dict[str, Any]], user_query: str, completed: bool) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert internal message format to trajectory format for saving.
|
||||
|
||||
Args:
|
||||
custom_system (str): Custom system message (optional)
|
||||
messages (List[Dict]): Internal message history
|
||||
user_query (str): Original user query
|
||||
completed (bool): Whether the conversation completed successfully
|
||||
|
||||
Returns:
|
||||
str: System message content
|
||||
List[Dict]: Messages in trajectory format
|
||||
"""
|
||||
if custom_system:
|
||||
return custom_system
|
||||
trajectory = []
|
||||
|
||||
return (
|
||||
"You are an AI assistant that provides helpful responses. You may use extremely long chains of thought "
|
||||
"to deeply consider the problem and deliberate with yourself via systematic reasoning processes to help "
|
||||
"come to a correct solution prior to answering. You should enclose your thoughts and internal monologue "
|
||||
"inside <thinking> tags.\n\n"
|
||||
"You are equipped with web research tools that allow you to search the web, extract content from web pages, "
|
||||
"and crawl websites. Use these tools to gather current information and provide accurate, well-researched responses. "
|
||||
"You can call multiple tools in parallel if they are not reliant on each other's results. You can also use "
|
||||
"sequential tool calls to build on data you've collected from previous tool calls. Continue using tools until "
|
||||
"you feel confident you have enough information to provide a comprehensive answer."
|
||||
# Add system message with tool definitions
|
||||
system_msg = (
|
||||
"You are a function calling AI model. You are provided with function signatures within <tools> </tools> XML tags. "
|
||||
"You may call one or more functions to assist with the user query. If available tools are not relevant in assisting "
|
||||
"with user query, just respond in natural conversational language. Don't make assumptions about what values to plug "
|
||||
"into functions. After calling & executing the functions, you will be provided with function results within "
|
||||
"<tool_response> </tool_response> XML tags. Here are the available tools:\n"
|
||||
f"<tools>\n{self._format_tools_for_system_message()}\n</tools>\n"
|
||||
"For each function call return a JSON object, with the following pydantic model json schema for each:\n"
|
||||
"{'title': 'FunctionCall', 'type': 'object', 'properties': {'name': {'title': 'Name', 'type': 'string'}, "
|
||||
"'arguments': {'title': 'Arguments', 'type': 'object'}}, 'required': ['name', 'arguments']}\n"
|
||||
"Each function call should be enclosed within <tool_call> </tool_call> XML tags.\n"
|
||||
"Example:\n<tool_call>\n{'name': <function-name>,'arguments': <args-dict>}\n</tool_call>"
|
||||
)
|
||||
|
||||
trajectory.append({
|
||||
"from": "system",
|
||||
"value": system_msg
|
||||
})
|
||||
|
||||
# Add the initial user message
|
||||
trajectory.append({
|
||||
"from": "human",
|
||||
"value": user_query
|
||||
})
|
||||
|
||||
# Process remaining messages
|
||||
i = 1 # Skip the first user message as we already added it
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
|
||||
if msg["role"] == "assistant":
|
||||
# Check if this message has tool calls
|
||||
if "tool_calls" in msg and msg["tool_calls"]:
|
||||
# Format assistant message with tool calls
|
||||
content = ""
|
||||
if msg.get("content") and msg["content"].strip():
|
||||
content = msg["content"] + "\n"
|
||||
|
||||
# Add tool calls wrapped in XML tags
|
||||
for tool_call in msg["tool_calls"]:
|
||||
tool_call_json = {
|
||||
"name": tool_call["function"]["name"],
|
||||
"arguments": json.loads(tool_call["function"]["arguments"]) if isinstance(tool_call["function"]["arguments"], str) else tool_call["function"]["arguments"]
|
||||
}
|
||||
content += f"<tool_call>\n{json.dumps(tool_call_json)}\n</tool_call>\n"
|
||||
|
||||
trajectory.append({
|
||||
"from": "gpt",
|
||||
"value": content.rstrip()
|
||||
})
|
||||
|
||||
# Collect all subsequent tool responses
|
||||
tool_responses = []
|
||||
j = i + 1
|
||||
while j < len(messages) and messages[j]["role"] == "tool":
|
||||
tool_msg = messages[j]
|
||||
# Format tool response with XML tags
|
||||
tool_response = f"<tool_response>\n"
|
||||
|
||||
# Try to parse tool content as JSON if it looks like JSON
|
||||
tool_content = tool_msg["content"]
|
||||
try:
|
||||
if tool_content.strip().startswith(("{", "[")):
|
||||
tool_content = json.loads(tool_content)
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
pass # Keep as string if not valid JSON
|
||||
|
||||
tool_response += json.dumps({
|
||||
"tool_call_id": tool_msg.get("tool_call_id", ""),
|
||||
"name": msg["tool_calls"][len(tool_responses)]["function"]["name"] if len(tool_responses) < len(msg["tool_calls"]) else "unknown",
|
||||
"content": tool_content
|
||||
})
|
||||
tool_response += "\n</tool_response>"
|
||||
tool_responses.append(tool_response)
|
||||
j += 1
|
||||
|
||||
# Add all tool responses as a single message
|
||||
if tool_responses:
|
||||
trajectory.append({
|
||||
"from": "tool",
|
||||
"value": "\n".join(tool_responses)
|
||||
})
|
||||
i = j - 1 # Skip the tool messages we just processed
|
||||
|
||||
else:
|
||||
# Regular assistant message without tool calls
|
||||
trajectory.append({
|
||||
"from": "gpt",
|
||||
"value": msg["content"] or ""
|
||||
})
|
||||
|
||||
elif msg["role"] == "user":
|
||||
trajectory.append({
|
||||
"from": "human",
|
||||
"value": msg["content"]
|
||||
})
|
||||
|
||||
i += 1
|
||||
|
||||
return trajectory
|
||||
|
||||
def _save_trajectory(self, messages: List[Dict[str, Any]], user_query: str, completed: bool):
|
||||
"""
|
||||
Save conversation trajectory to JSONL file.
|
||||
|
||||
Args:
|
||||
messages (List[Dict]): Complete message history
|
||||
user_query (str): Original user query
|
||||
completed (bool): Whether the conversation completed successfully
|
||||
"""
|
||||
if not self.save_trajectories:
|
||||
return
|
||||
|
||||
# Convert messages to trajectory format
|
||||
trajectory = self._convert_to_trajectory_format(messages, user_query, completed)
|
||||
|
||||
# Determine which file to save to
|
||||
filename = "trajectory_samples.jsonl" if completed else "failed_trajectories.jsonl"
|
||||
|
||||
# Create trajectory entry
|
||||
entry = {
|
||||
"conversations": trajectory,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"model": self.model,
|
||||
"completed": completed
|
||||
}
|
||||
|
||||
# Append to JSONL file
|
||||
try:
|
||||
with open(filename, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||||
print(f"💾 Trajectory saved to {filename}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to save trajectory: {e}")
|
||||
|
||||
def run_conversation(
|
||||
self,
|
||||
@@ -133,13 +329,6 @@ class AIAgent:
|
||||
# Initialize conversation
|
||||
messages = conversation_history or []
|
||||
|
||||
# Add system message if not already present
|
||||
if not messages or messages[0]["role"] != "system":
|
||||
messages.insert(0, {
|
||||
"role": "system",
|
||||
"content": self.create_system_message(system_message)
|
||||
})
|
||||
|
||||
# Add user message
|
||||
messages.append({
|
||||
"role": "user",
|
||||
@@ -202,6 +391,9 @@ class AIAgent:
|
||||
function_args = {}
|
||||
|
||||
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())})")
|
||||
|
||||
if function_name == "terminal" and self.morph_snapshot_id is not None:
|
||||
function_args["snapshot_id"] = self.morph_snapshot_id
|
||||
|
||||
# Execute the tool
|
||||
function_result = handle_function_call(function_name, function_args)
|
||||
@@ -256,11 +448,17 @@ class AIAgent:
|
||||
if final_response is None:
|
||||
final_response = "I've reached the maximum number of iterations. Here's what I found so far."
|
||||
|
||||
# Determine if conversation completed successfully
|
||||
completed = final_response is not None and api_call_count < self.max_iterations
|
||||
|
||||
# Save trajectory if enabled
|
||||
self._save_trajectory(messages, user_message, completed)
|
||||
|
||||
return {
|
||||
"final_response": final_response,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": final_response is not None
|
||||
"completed": completed
|
||||
}
|
||||
|
||||
def chat(self, message: str) -> str:
|
||||
@@ -282,7 +480,14 @@ def main(
|
||||
model: str = "claude-opus-4-20250514",
|
||||
api_key: str = None,
|
||||
base_url: str = "https://api.anthropic.com/v1/",
|
||||
max_turns: int = 10
|
||||
max_turns: int = 10,
|
||||
enabled_tools: str = None,
|
||||
disabled_tools: str = None,
|
||||
enabled_toolsets: str = None,
|
||||
disabled_toolsets: str = None,
|
||||
list_tools: bool = False,
|
||||
save_trajectories: bool = False,
|
||||
morph_snapshot_id: str | None = None,
|
||||
):
|
||||
"""
|
||||
Main function for running the agent directly.
|
||||
@@ -293,17 +498,91 @@ def main(
|
||||
api_key (str): API key for authentication. Uses ANTHROPIC_API_KEY env var if not provided.
|
||||
base_url (str): Base URL for the model API. Defaults to https://api.anthropic.com/v1/
|
||||
max_turns (int): Maximum number of API call iterations. Defaults to 10.
|
||||
enabled_tools (str): Comma-separated list of tools to enable (e.g., "web_search,terminal")
|
||||
disabled_tools (str): Comma-separated list of tools to disable (e.g., "terminal")
|
||||
enabled_toolsets (str): Comma-separated list of toolsets to enable (e.g., "web_tools")
|
||||
disabled_toolsets (str): Comma-separated list of toolsets to disable (e.g., "terminal_tools")
|
||||
list_tools (bool): Just list available tools and exit
|
||||
save_trajectories (bool): Save conversation trajectories to JSONL files. Defaults to False.
|
||||
morph_snapshot_id (str | None): Morph Cloud snapshot id to start terminal tool from
|
||||
"""
|
||||
print("🤖 AI Agent with Tool Calling")
|
||||
print("=" * 50)
|
||||
|
||||
# Handle tool listing
|
||||
if list_tools:
|
||||
from model_tools import get_all_tool_names, get_toolset_for_tool, get_available_toolsets
|
||||
|
||||
print("📋 Available Tools & Toolsets:")
|
||||
print("-" * 30)
|
||||
|
||||
# Show toolsets
|
||||
toolsets = get_available_toolsets()
|
||||
print("📦 Toolsets:")
|
||||
for name, info in toolsets.items():
|
||||
status = "✅" if info["available"] else "❌"
|
||||
print(f" {status} {name}: {info['description']}")
|
||||
if not info["available"]:
|
||||
print(f" Requirements: {', '.join(info['requirements'])}")
|
||||
|
||||
# Show individual tools
|
||||
all_tools = get_all_tool_names()
|
||||
print(f"\n🔧 Individual Tools ({len(all_tools)} available):")
|
||||
for tool_name in all_tools:
|
||||
toolset = get_toolset_for_tool(tool_name)
|
||||
print(f" 📌 {tool_name} (from {toolset})")
|
||||
|
||||
print(f"\n💡 Usage Examples:")
|
||||
print(f" # Run with only web tools")
|
||||
print(f" python run_agent.py --enabled_toolsets=web_tools --query='search for Python news'")
|
||||
print(f" # Run with specific tools only")
|
||||
print(f" python run_agent.py --enabled_tools=web_search,web_extract --query='research topic'")
|
||||
print(f" # Run without terminal tools")
|
||||
print(f" python run_agent.py --disabled_tools=terminal --query='web research only'")
|
||||
print(f" # Run with trajectory saving enabled")
|
||||
print(f" python run_agent.py --save_trajectories --query='your question here'")
|
||||
return
|
||||
|
||||
# Parse tool selection arguments
|
||||
enabled_tools_list = None
|
||||
disabled_tools_list = None
|
||||
enabled_toolsets_list = None
|
||||
disabled_toolsets_list = None
|
||||
|
||||
if enabled_tools:
|
||||
enabled_tools_list = [t.strip() for t in enabled_tools.split(",")]
|
||||
print(f"🎯 Enabled tools: {enabled_tools_list}")
|
||||
|
||||
if disabled_tools:
|
||||
disabled_tools_list = [t.strip() for t in disabled_tools.split(",")]
|
||||
print(f"🚫 Disabled tools: {disabled_tools_list}")
|
||||
|
||||
if enabled_toolsets:
|
||||
enabled_toolsets_list = [t.strip() for t in enabled_toolsets.split(",")]
|
||||
print(f"🎯 Enabled toolsets: {enabled_toolsets_list}")
|
||||
|
||||
if disabled_toolsets:
|
||||
disabled_toolsets_list = [t.strip() for t in disabled_toolsets.split(",")]
|
||||
print(f"🚫 Disabled toolsets: {disabled_toolsets_list}")
|
||||
|
||||
if save_trajectories:
|
||||
print(f"💾 Trajectory saving: ENABLED")
|
||||
print(f" - Successful conversations → trajectory_samples.jsonl")
|
||||
print(f" - Failed conversations → failed_trajectories.jsonl")
|
||||
|
||||
# Initialize agent with provided parameters
|
||||
try:
|
||||
agent = AIAgent(
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
max_iterations=max_turns
|
||||
max_iterations=max_turns,
|
||||
enabled_tools=enabled_tools_list,
|
||||
disabled_tools=disabled_tools_list,
|
||||
enabled_toolsets=enabled_toolsets_list,
|
||||
disabled_toolsets=disabled_toolsets_list,
|
||||
save_trajectories=save_trajectories,
|
||||
morph_snapshot_id=morph_snapshot_id
|
||||
)
|
||||
except RuntimeError as e:
|
||||
print(f"❌ Failed to initialize agent: {e}")
|
||||
|
||||
@@ -78,7 +78,8 @@ def terminal_tool(
|
||||
session_id: Optional[str] = None,
|
||||
background: bool = False,
|
||||
idle_threshold: float = 5.0,
|
||||
timeout: Optional[int] = None
|
||||
timeout: Optional[int] = None,
|
||||
snapshot_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Execute a command on a Morph VM with optional interactive session support.
|
||||
@@ -136,7 +137,7 @@ def terminal_tool(
|
||||
)
|
||||
|
||||
# Execute with lifecycle management
|
||||
result = run_tool_with_lifecycle_management(tool_call)
|
||||
result = run_tool_with_lifecycle_management(tool_call, snapshot_id=snapshot_id)
|
||||
|
||||
# Format the result with all possible fields
|
||||
# Map hecate's "stdout" to "output" for compatibility
|
||||
@@ -231,4 +232,4 @@ if __name__ == "__main__":
|
||||
print(f" MORPH_API_KEY: {'Set' if os.getenv('MORPH_API_KEY') else 'Not set'}")
|
||||
print(f" OPENAI_API_KEY: {'Set' if os.getenv('OPENAI_API_KEY') else 'Not set (optional)'}")
|
||||
print(f" HECATE_VM_LIFETIME_SECONDS: {os.getenv('HECATE_VM_LIFETIME_SECONDS', '300')} (default: 300)")
|
||||
print(f" HECATE_DEFAULT_SNAPSHOT_ID: {os.getenv('HECATE_DEFAULT_SNAPSHOT_ID', 'snapshot_p5294qxt')} (default: snapshot_p5294qxt)")
|
||||
print(f" HECATE_DEFAULT_SNAPSHOT_ID: {os.getenv('HECATE_DEFAULT_SNAPSHOT_ID', 'snapshot_p5294qxt')} (default: snapshot_p5294qxt)")
|
||||
|
||||
14
test_run.sh
Normal file
14
test_run.sh
Normal file
@@ -0,0 +1,14 @@
|
||||
export WEB_TOOLS_DEBUG=true
|
||||
|
||||
python run_agent.py \
|
||||
--query "Tell me about this animal pictured: https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQi1nkrYXY-ijQv5aCxkwooyg2roNFxj0ewJA&s" \
|
||||
--max_turns 30 \
|
||||
--model claude-sonnet-4-20250514 \
|
||||
--base_url https://api.anthropic.com/v1/ \
|
||||
--api_key $ANTHROPIC_API_KEY \
|
||||
--enabled_toolsets=vision_tools
|
||||
|
||||
#Possible Toolsets:
|
||||
#web_tools
|
||||
#vision_tools
|
||||
#terminal_tools
|
||||
620
test_web_tools.py
Normal file
620
test_web_tools.py
Normal file
@@ -0,0 +1,620 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive Test Suite for Web Tools Module
|
||||
|
||||
This script tests all web tools functionality to ensure they work correctly.
|
||||
Run this after any updates to the web_tools.py module or Firecrawl library.
|
||||
|
||||
Usage:
|
||||
python test_web_tools.py # Run all tests
|
||||
python test_web_tools.py --no-llm # Skip LLM processing tests
|
||||
python test_web_tools.py --verbose # Show detailed output
|
||||
|
||||
Requirements:
|
||||
- FIRECRAWL_API_KEY environment variable must be set
|
||||
- NOUS_API_KEY environment vitinariable (optional, for LLM tests)
|
||||
"""
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
|
||||
# Import the web tools to test
|
||||
from web_tools import (
|
||||
web_search_tool,
|
||||
web_extract_tool,
|
||||
web_crawl_tool,
|
||||
check_firecrawl_api_key,
|
||||
check_nous_api_key,
|
||||
get_debug_session_info
|
||||
)
|
||||
|
||||
|
||||
class Colors:
|
||||
"""ANSI color codes for terminal output"""
|
||||
HEADER = '\033[95m'
|
||||
BLUE = '\033[94m'
|
||||
CYAN = '\033[96m'
|
||||
GREEN = '\033[92m'
|
||||
WARNING = '\033[93m'
|
||||
FAIL = '\033[91m'
|
||||
ENDC = '\033[0m'
|
||||
BOLD = '\033[1m'
|
||||
UNDERLINE = '\033[4m'
|
||||
|
||||
|
||||
def print_header(text: str):
|
||||
"""Print a formatted header"""
|
||||
print(f"\n{Colors.HEADER}{Colors.BOLD}{'='*60}{Colors.ENDC}")
|
||||
print(f"{Colors.HEADER}{Colors.BOLD}{text}{Colors.ENDC}")
|
||||
print(f"{Colors.HEADER}{Colors.BOLD}{'='*60}{Colors.ENDC}")
|
||||
|
||||
|
||||
def print_section(text: str):
|
||||
"""Print a formatted section header"""
|
||||
print(f"\n{Colors.CYAN}{Colors.BOLD}📌 {text}{Colors.ENDC}")
|
||||
print(f"{Colors.CYAN}{'-'*50}{Colors.ENDC}")
|
||||
|
||||
|
||||
def print_success(text: str):
|
||||
"""Print success message"""
|
||||
print(f"{Colors.GREEN}✅ {text}{Colors.ENDC}")
|
||||
|
||||
|
||||
def print_error(text: str):
|
||||
"""Print error message"""
|
||||
print(f"{Colors.FAIL}❌ {text}{Colors.ENDC}")
|
||||
|
||||
|
||||
def print_warning(text: str):
|
||||
"""Print warning message"""
|
||||
print(f"{Colors.WARNING}⚠️ {text}{Colors.ENDC}")
|
||||
|
||||
|
||||
def print_info(text: str, indent: int = 0):
|
||||
"""Print info message"""
|
||||
indent_str = " " * indent
|
||||
print(f"{indent_str}{Colors.BLUE}ℹ️ {text}{Colors.ENDC}")
|
||||
|
||||
|
||||
class WebToolsTester:
|
||||
"""Test suite for web tools"""
|
||||
|
||||
def __init__(self, verbose: bool = False, test_llm: bool = True):
|
||||
self.verbose = verbose
|
||||
self.test_llm = test_llm
|
||||
self.test_results = {
|
||||
"passed": [],
|
||||
"failed": [],
|
||||
"skipped": []
|
||||
}
|
||||
self.start_time = None
|
||||
self.end_time = None
|
||||
|
||||
def log_result(self, test_name: str, status: str, details: str = ""):
|
||||
"""Log test result"""
|
||||
result = {
|
||||
"test": test_name,
|
||||
"status": status,
|
||||
"details": details,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
if status == "passed":
|
||||
self.test_results["passed"].append(result)
|
||||
print_success(f"{test_name}: {details}" if details else test_name)
|
||||
elif status == "failed":
|
||||
self.test_results["failed"].append(result)
|
||||
print_error(f"{test_name}: {details}" if details else test_name)
|
||||
elif status == "skipped":
|
||||
self.test_results["skipped"].append(result)
|
||||
print_warning(f"{test_name} skipped: {details}" if details else f"{test_name} skipped")
|
||||
|
||||
def test_environment(self) -> bool:
|
||||
"""Test environment setup and API keys"""
|
||||
print_section("Environment Check")
|
||||
|
||||
# Check Firecrawl API key
|
||||
if not check_firecrawl_api_key():
|
||||
self.log_result("Firecrawl API Key", "failed", "FIRECRAWL_API_KEY not set")
|
||||
return False
|
||||
else:
|
||||
self.log_result("Firecrawl API Key", "passed", "Found")
|
||||
|
||||
# Check Nous API key (optional)
|
||||
if not check_nous_api_key():
|
||||
self.log_result("Nous API Key", "skipped", "NOUS_API_KEY not set (LLM tests will be skipped)")
|
||||
self.test_llm = False
|
||||
else:
|
||||
self.log_result("Nous API Key", "passed", "Found")
|
||||
|
||||
# Check debug mode
|
||||
debug_info = get_debug_session_info()
|
||||
if debug_info["enabled"]:
|
||||
print_info(f"Debug mode enabled - Session: {debug_info['session_id']}")
|
||||
print_info(f"Debug log: {debug_info['log_path']}")
|
||||
|
||||
return True
|
||||
|
||||
def test_web_search(self) -> List[str]:
|
||||
"""Test web search functionality"""
|
||||
print_section("Test 1: Web Search")
|
||||
|
||||
test_queries = [
|
||||
("Python web scraping tutorial", 5),
|
||||
("Firecrawl API documentation", 3),
|
||||
("inflammatory arthritis symptoms treatment", 8) # Test medical query from your example
|
||||
]
|
||||
|
||||
extracted_urls = []
|
||||
|
||||
for query, limit in test_queries:
|
||||
try:
|
||||
print(f"\n Testing search: '{query}' (limit={limit})")
|
||||
|
||||
if self.verbose:
|
||||
print(f" Calling web_search_tool(query='{query}', limit={limit})")
|
||||
|
||||
# Perform search
|
||||
result = web_search_tool(query, limit)
|
||||
|
||||
# Parse result
|
||||
try:
|
||||
data = json.loads(result)
|
||||
except json.JSONDecodeError as e:
|
||||
self.log_result(f"Search: {query[:30]}...", "failed", f"Invalid JSON: {e}")
|
||||
if self.verbose:
|
||||
print(f" Raw response (first 500 chars): {result[:500]}...")
|
||||
continue
|
||||
|
||||
if "error" in data:
|
||||
self.log_result(f"Search: {query[:30]}...", "failed", f"API error: {data['error']}")
|
||||
continue
|
||||
|
||||
# Check structure
|
||||
if "success" not in data or "data" not in data:
|
||||
self.log_result(f"Search: {query[:30]}...", "failed", "Missing success or data fields")
|
||||
if self.verbose:
|
||||
print(f" Response keys: {list(data.keys())}")
|
||||
continue
|
||||
|
||||
web_results = data.get("data", {}).get("web", [])
|
||||
|
||||
if not web_results:
|
||||
self.log_result(f"Search: {query[:30]}...", "failed", "Empty web results array")
|
||||
if self.verbose:
|
||||
print(f" data.web content: {data.get('data', {}).get('web')}")
|
||||
continue
|
||||
|
||||
# Validate each result
|
||||
valid_results = 0
|
||||
missing_fields = []
|
||||
|
||||
for i, result in enumerate(web_results):
|
||||
required_fields = ["url", "title", "description"]
|
||||
has_all_fields = all(key in result for key in required_fields)
|
||||
|
||||
if has_all_fields:
|
||||
valid_results += 1
|
||||
# Collect URLs for extraction test
|
||||
if len(extracted_urls) < 3:
|
||||
extracted_urls.append(result["url"])
|
||||
|
||||
if self.verbose:
|
||||
print(f" Result {i+1}: ✓ {result['title'][:50]}...")
|
||||
print(f" URL: {result['url'][:60]}...")
|
||||
else:
|
||||
missing = [f for f in required_fields if f not in result]
|
||||
missing_fields.append(f"Result {i+1} missing: {missing}")
|
||||
if self.verbose:
|
||||
print(f" Result {i+1}: ✗ Missing fields: {missing}")
|
||||
|
||||
# Log results
|
||||
if valid_results == len(web_results):
|
||||
self.log_result(
|
||||
f"Search: {query[:30]}...",
|
||||
"passed",
|
||||
f"All {valid_results} results valid"
|
||||
)
|
||||
else:
|
||||
self.log_result(
|
||||
f"Search: {query[:30]}...",
|
||||
"failed",
|
||||
f"Only {valid_results}/{len(web_results)} valid. Issues: {'; '.join(missing_fields[:3])}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.log_result(f"Search: {query[:30]}...", "failed", f"Exception: {type(e).__name__}: {str(e)}")
|
||||
if self.verbose:
|
||||
import traceback
|
||||
print(f" Traceback: {traceback.format_exc()}")
|
||||
|
||||
if self.verbose and extracted_urls:
|
||||
print(f"\n URLs collected for extraction test: {len(extracted_urls)}")
|
||||
for url in extracted_urls:
|
||||
print(f" - {url}")
|
||||
|
||||
return extracted_urls
|
||||
|
||||
async def test_web_extract(self, urls: List[str] = None):
|
||||
"""Test web content extraction"""
|
||||
print_section("Test 2: Web Extract (without LLM)")
|
||||
|
||||
# Use provided URLs or defaults
|
||||
if not urls:
|
||||
urls = [
|
||||
"https://docs.firecrawl.dev/introduction",
|
||||
"https://www.python.org/about/"
|
||||
]
|
||||
print(f" Using default URLs for testing")
|
||||
else:
|
||||
print(f" Using {len(urls)} URLs from search results")
|
||||
|
||||
# Test extraction
|
||||
if urls:
|
||||
try:
|
||||
test_urls = urls[:2] # Test with max 2 URLs
|
||||
print(f"\n Extracting content from {len(test_urls)} URL(s)...")
|
||||
for url in test_urls:
|
||||
print(f" - {url}")
|
||||
|
||||
if self.verbose:
|
||||
print(f" Calling web_extract_tool(urls={test_urls}, format='markdown', use_llm_processing=False)")
|
||||
|
||||
result = await web_extract_tool(
|
||||
test_urls,
|
||||
format="markdown",
|
||||
use_llm_processing=False
|
||||
)
|
||||
|
||||
# Parse result
|
||||
try:
|
||||
data = json.loads(result)
|
||||
except json.JSONDecodeError as e:
|
||||
self.log_result("Extract (no LLM)", "failed", f"Invalid JSON: {e}")
|
||||
if self.verbose:
|
||||
print(f" Raw response (first 500 chars): {result[:500]}...")
|
||||
return
|
||||
|
||||
if "error" in data:
|
||||
self.log_result("Extract (no LLM)", "failed", f"API error: {data['error']}")
|
||||
return
|
||||
|
||||
results = data.get("results", [])
|
||||
|
||||
if not results:
|
||||
self.log_result("Extract (no LLM)", "failed", "No results in response")
|
||||
if self.verbose:
|
||||
print(f" Response keys: {list(data.keys())}")
|
||||
return
|
||||
|
||||
# Validate each result
|
||||
valid_results = 0
|
||||
failed_results = 0
|
||||
total_content_length = 0
|
||||
extraction_details = []
|
||||
|
||||
for i, result in enumerate(results):
|
||||
title = result.get("title", "No title")
|
||||
content = result.get("content", "")
|
||||
error = result.get("error")
|
||||
|
||||
if error:
|
||||
failed_results += 1
|
||||
extraction_details.append(f"Page {i+1}: ERROR - {error}")
|
||||
if self.verbose:
|
||||
print(f" Page {i+1}: ✗ Error - {error}")
|
||||
elif content:
|
||||
content_len = len(content)
|
||||
total_content_length += content_len
|
||||
valid_results += 1
|
||||
extraction_details.append(f"Page {i+1}: {title[:40]}... ({content_len} chars)")
|
||||
if self.verbose:
|
||||
print(f" Page {i+1}: ✓ {title[:50]}... - {content_len} characters")
|
||||
print(f" First 100 chars: {content[:100]}...")
|
||||
else:
|
||||
extraction_details.append(f"Page {i+1}: {title[:40]}... (EMPTY)")
|
||||
if self.verbose:
|
||||
print(f" Page {i+1}: ⚠ {title[:50]}... - Empty content")
|
||||
|
||||
# Log results
|
||||
if valid_results > 0:
|
||||
self.log_result(
|
||||
"Extract (no LLM)",
|
||||
"passed",
|
||||
f"{valid_results}/{len(results)} pages extracted, {total_content_length} total chars"
|
||||
)
|
||||
else:
|
||||
self.log_result(
|
||||
"Extract (no LLM)",
|
||||
"failed",
|
||||
f"No valid content. {failed_results} errors, {len(results) - failed_results} empty"
|
||||
)
|
||||
if self.verbose:
|
||||
print(f"\n Extraction details:")
|
||||
for detail in extraction_details:
|
||||
print(f" {detail}")
|
||||
|
||||
except Exception as e:
|
||||
self.log_result("Extract (no LLM)", "failed", f"Exception: {type(e).__name__}: {str(e)}")
|
||||
if self.verbose:
|
||||
import traceback
|
||||
print(f" Traceback: {traceback.format_exc()}")
|
||||
|
||||
async def test_web_extract_with_llm(self, urls: List[str] = None):
|
||||
"""Test web extraction with LLM processing"""
|
||||
print_section("Test 3: Web Extract (with Gemini LLM)")
|
||||
|
||||
if not self.test_llm:
|
||||
self.log_result("Extract (with LLM)", "skipped", "LLM testing disabled")
|
||||
return
|
||||
|
||||
# Use a URL likely to have substantial content
|
||||
test_url = urls[0] if urls else "https://docs.firecrawl.dev/features/scrape"
|
||||
|
||||
try:
|
||||
print(f"\n Extracting and processing: {test_url}")
|
||||
|
||||
result = await web_extract_tool(
|
||||
[test_url],
|
||||
format="markdown",
|
||||
use_llm_processing=True,
|
||||
min_length=1000 # Lower threshold for testing
|
||||
)
|
||||
|
||||
data = json.loads(result)
|
||||
|
||||
if "error" in data:
|
||||
self.log_result("Extract (with LLM)", "failed", data["error"])
|
||||
return
|
||||
|
||||
results = data.get("results", [])
|
||||
|
||||
if not results:
|
||||
self.log_result("Extract (with LLM)", "failed", "No results returned")
|
||||
return
|
||||
|
||||
result = results[0]
|
||||
content = result.get("content", "")
|
||||
|
||||
if content:
|
||||
content_len = len(content)
|
||||
|
||||
# Check if content was actually processed (should be shorter than typical raw content)
|
||||
if content_len > 0:
|
||||
self.log_result(
|
||||
"Extract (with LLM)",
|
||||
"passed",
|
||||
f"Content processed: {content_len} chars"
|
||||
)
|
||||
|
||||
if self.verbose:
|
||||
print(f"\n First 300 chars of processed content:")
|
||||
print(f" {content[:300]}...")
|
||||
else:
|
||||
self.log_result("Extract (with LLM)", "failed", "No content after processing")
|
||||
else:
|
||||
self.log_result("Extract (with LLM)", "failed", "No content field in result")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
self.log_result("Extract (with LLM)", "failed", f"Invalid JSON: {e}")
|
||||
except Exception as e:
|
||||
self.log_result("Extract (with LLM)", "failed", str(e))
|
||||
|
||||
async def test_web_crawl(self):
|
||||
"""Test web crawling functionality"""
|
||||
print_section("Test 4: Web Crawl")
|
||||
|
||||
test_sites = [
|
||||
("https://docs.firecrawl.dev", None, 2), # Test docs site
|
||||
("https://firecrawl.dev", None, 3), # Test main site
|
||||
]
|
||||
|
||||
for url, instructions, expected_min_pages in test_sites:
|
||||
try:
|
||||
print(f"\n Testing crawl of: {url}")
|
||||
if instructions:
|
||||
print(f" Instructions: {instructions}")
|
||||
else:
|
||||
print(f" No instructions (general crawl)")
|
||||
print(f" Expected minimum pages: {expected_min_pages}")
|
||||
|
||||
# Show what's being called
|
||||
if self.verbose:
|
||||
print(f" Calling web_crawl_tool(url='{url}', instructions={instructions}, use_llm_processing=False)")
|
||||
|
||||
result = await web_crawl_tool(
|
||||
url,
|
||||
instructions=instructions,
|
||||
use_llm_processing=False # Disable LLM for faster testing
|
||||
)
|
||||
|
||||
# Check if result is valid JSON
|
||||
try:
|
||||
data = json.loads(result)
|
||||
except json.JSONDecodeError as e:
|
||||
self.log_result(f"Crawl: {url}", "failed", f"Invalid JSON response: {e}")
|
||||
if self.verbose:
|
||||
print(f" Raw response (first 500 chars): {result[:500]}...")
|
||||
continue
|
||||
|
||||
# Check for errors
|
||||
if "error" in data:
|
||||
self.log_result(f"Crawl: {url}", "failed", f"API error: {data['error']}")
|
||||
continue
|
||||
|
||||
# Get results
|
||||
results = data.get("results", [])
|
||||
|
||||
if not results:
|
||||
self.log_result(f"Crawl: {url}", "failed", "No pages in results array")
|
||||
if self.verbose:
|
||||
print(f" Full response: {json.dumps(data, indent=2)[:1000]}...")
|
||||
continue
|
||||
|
||||
# Analyze pages
|
||||
valid_pages = 0
|
||||
empty_pages = 0
|
||||
total_content = 0
|
||||
page_details = []
|
||||
|
||||
for i, page in enumerate(results):
|
||||
content = page.get("content", "")
|
||||
title = page.get("title", "Untitled")
|
||||
error = page.get("error")
|
||||
|
||||
if error:
|
||||
page_details.append(f"Page {i+1}: ERROR - {error}")
|
||||
elif content:
|
||||
valid_pages += 1
|
||||
content_len = len(content)
|
||||
total_content += content_len
|
||||
page_details.append(f"Page {i+1}: {title[:40]}... ({content_len} chars)")
|
||||
else:
|
||||
empty_pages += 1
|
||||
page_details.append(f"Page {i+1}: {title[:40]}... (EMPTY)")
|
||||
|
||||
# Show detailed results if verbose
|
||||
if self.verbose:
|
||||
print(f"\n Crawl Results:")
|
||||
print(f" Total pages returned: {len(results)}")
|
||||
print(f" Valid pages (with content): {valid_pages}")
|
||||
print(f" Empty pages: {empty_pages}")
|
||||
print(f" Total content size: {total_content} characters")
|
||||
print(f"\n Page Details:")
|
||||
for detail in page_details[:10]: # Show first 10 pages
|
||||
print(f" - {detail}")
|
||||
if len(page_details) > 10:
|
||||
print(f" ... and {len(page_details) - 10} more pages")
|
||||
|
||||
# Determine pass/fail
|
||||
if valid_pages >= expected_min_pages:
|
||||
self.log_result(
|
||||
f"Crawl: {url}",
|
||||
"passed",
|
||||
f"{valid_pages}/{len(results)} valid pages, {total_content} chars total"
|
||||
)
|
||||
else:
|
||||
self.log_result(
|
||||
f"Crawl: {url}",
|
||||
"failed",
|
||||
f"Only {valid_pages} valid pages (expected >= {expected_min_pages}), {empty_pages} empty, {len(results)} total"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.log_result(f"Crawl: {url}", "failed", f"Exception: {type(e).__name__}: {str(e)}")
|
||||
if self.verbose:
|
||||
import traceback
|
||||
print(f" Traceback:")
|
||||
print(" " + "\n ".join(traceback.format_exc().split("\n")))
|
||||
|
||||
async def run_all_tests(self):
|
||||
"""Run all tests"""
|
||||
self.start_time = datetime.now()
|
||||
|
||||
print_header("WEB TOOLS TEST SUITE")
|
||||
print(f"Started at: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
# Test environment
|
||||
if not self.test_environment():
|
||||
print_error("\nCannot proceed without required API keys!")
|
||||
return False
|
||||
|
||||
# Test search and collect URLs
|
||||
urls = self.test_web_search()
|
||||
|
||||
# Test extraction
|
||||
await self.test_web_extract(urls if urls else None)
|
||||
|
||||
# Test extraction with LLM
|
||||
if self.test_llm:
|
||||
await self.test_web_extract_with_llm(urls if urls else None)
|
||||
|
||||
# Test crawling
|
||||
await self.test_web_crawl()
|
||||
|
||||
# Print summary
|
||||
self.end_time = datetime.now()
|
||||
duration = (self.end_time - self.start_time).total_seconds()
|
||||
|
||||
print_header("TEST SUMMARY")
|
||||
print(f"Duration: {duration:.2f} seconds")
|
||||
print(f"\n{Colors.GREEN}Passed: {len(self.test_results['passed'])}{Colors.ENDC}")
|
||||
print(f"{Colors.FAIL}Failed: {len(self.test_results['failed'])}{Colors.ENDC}")
|
||||
print(f"{Colors.WARNING}Skipped: {len(self.test_results['skipped'])}{Colors.ENDC}")
|
||||
|
||||
# List failed tests
|
||||
if self.test_results["failed"]:
|
||||
print(f"\n{Colors.FAIL}{Colors.BOLD}Failed Tests:{Colors.ENDC}")
|
||||
for test in self.test_results["failed"]:
|
||||
print(f" - {test['test']}: {test['details']}")
|
||||
|
||||
# Save results to file
|
||||
self.save_results()
|
||||
|
||||
return len(self.test_results["failed"]) == 0
|
||||
|
||||
def save_results(self):
|
||||
"""Save test results to a JSON file"""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"test_results_web_tools_{timestamp}.json"
|
||||
|
||||
results = {
|
||||
"test_suite": "Web Tools",
|
||||
"start_time": self.start_time.isoformat() if self.start_time else None,
|
||||
"end_time": self.end_time.isoformat() if self.end_time else None,
|
||||
"duration_seconds": (self.end_time - self.start_time).total_seconds() if self.start_time and self.end_time else None,
|
||||
"summary": {
|
||||
"passed": len(self.test_results["passed"]),
|
||||
"failed": len(self.test_results["failed"]),
|
||||
"skipped": len(self.test_results["skipped"])
|
||||
},
|
||||
"results": self.test_results,
|
||||
"environment": {
|
||||
"firecrawl_api_key": check_firecrawl_api_key(),
|
||||
"nous_api_key": check_nous_api_key(),
|
||||
"debug_mode": get_debug_session_info()["enabled"]
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
with open(filename, 'w') as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print_info(f"Test results saved to: {filename}")
|
||||
except Exception as e:
|
||||
print_warning(f"Failed to save results: {e}")
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point"""
|
||||
parser = argparse.ArgumentParser(description="Test Web Tools Module")
|
||||
parser.add_argument("--no-llm", action="store_true", help="Skip LLM processing tests")
|
||||
parser.add_argument("--verbose", "-v", action="store_true", help="Show detailed output")
|
||||
parser.add_argument("--debug", action="store_true", help="Enable debug mode for web tools")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set debug mode if requested
|
||||
if args.debug:
|
||||
os.environ["WEB_TOOLS_DEBUG"] = "true"
|
||||
print_info("Debug mode enabled for web tools")
|
||||
|
||||
# Create tester
|
||||
tester = WebToolsTester(
|
||||
verbose=args.verbose,
|
||||
test_llm=not args.no_llm
|
||||
)
|
||||
|
||||
# Run tests
|
||||
success = await tester.run_all_tests()
|
||||
|
||||
# Exit with appropriate code
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
346
vision_tools.py
Normal file
346
vision_tools.py
Normal file
@@ -0,0 +1,346 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Vision Tools Module
|
||||
|
||||
This module provides vision analysis tools that work with image URLs.
|
||||
Uses Gemini Flash via Nous Research API for intelligent image understanding.
|
||||
|
||||
Available tools:
|
||||
- vision_analyze_tool: Analyze images from URLs with custom prompts
|
||||
|
||||
Features:
|
||||
- Comprehensive image description
|
||||
- Context-aware analysis based on user queries
|
||||
- Proper error handling and validation
|
||||
- Debug logging support
|
||||
|
||||
Usage:
|
||||
from vision_tools import vision_analyze_tool
|
||||
import asyncio
|
||||
|
||||
# Analyze an image
|
||||
result = await vision_analyze_tool(
|
||||
image_url="https://example.com/image.jpg",
|
||||
user_prompt="What architectural style is this building?"
|
||||
)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
import uuid
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# Initialize Nous Research API client for vision processing
|
||||
nous_client = AsyncOpenAI(
|
||||
api_key=os.getenv("NOUS_API_KEY"),
|
||||
base_url="https://inference-api.nousresearch.com/v1"
|
||||
)
|
||||
|
||||
# Configuration for vision processing
|
||||
DEFAULT_VISION_MODEL = "gemini-2.5-flash"
|
||||
|
||||
# Debug mode configuration
|
||||
DEBUG_MODE = os.getenv("VISION_TOOLS_DEBUG", "false").lower() == "true"
|
||||
DEBUG_SESSION_ID = str(uuid.uuid4())
|
||||
DEBUG_LOG_PATH = Path("./logs")
|
||||
DEBUG_DATA = {
|
||||
"session_id": DEBUG_SESSION_ID,
|
||||
"start_time": datetime.datetime.now().isoformat(),
|
||||
"debug_enabled": DEBUG_MODE,
|
||||
"tool_calls": []
|
||||
} if DEBUG_MODE else None
|
||||
|
||||
# Create logs directory if debug mode is enabled
|
||||
if DEBUG_MODE:
|
||||
DEBUG_LOG_PATH.mkdir(exist_ok=True)
|
||||
print(f"🐛 Vision debug mode enabled - Session ID: {DEBUG_SESSION_ID}")
|
||||
|
||||
|
||||
def _log_debug_call(tool_name: str, call_data: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Log a debug call entry to the global debug data structure.
|
||||
|
||||
Args:
|
||||
tool_name (str): Name of the tool being called
|
||||
call_data (Dict[str, Any]): Data about the call including parameters and results
|
||||
"""
|
||||
if not DEBUG_MODE or not DEBUG_DATA:
|
||||
return
|
||||
|
||||
call_entry = {
|
||||
"timestamp": datetime.datetime.now().isoformat(),
|
||||
"tool_name": tool_name,
|
||||
**call_data
|
||||
}
|
||||
|
||||
DEBUG_DATA["tool_calls"].append(call_entry)
|
||||
|
||||
|
||||
def _save_debug_log() -> None:
|
||||
"""
|
||||
Save the current debug data to a JSON file in the logs directory.
|
||||
"""
|
||||
if not DEBUG_MODE or not DEBUG_DATA:
|
||||
return
|
||||
|
||||
try:
|
||||
debug_filename = f"vision_tools_debug_{DEBUG_SESSION_ID}.json"
|
||||
debug_filepath = DEBUG_LOG_PATH / debug_filename
|
||||
|
||||
# Update end time
|
||||
DEBUG_DATA["end_time"] = datetime.datetime.now().isoformat()
|
||||
DEBUG_DATA["total_calls"] = len(DEBUG_DATA["tool_calls"])
|
||||
|
||||
with open(debug_filepath, 'w', encoding='utf-8') as f:
|
||||
json.dump(DEBUG_DATA, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"🐛 Vision debug log saved: {debug_filepath}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error saving vision debug log: {str(e)}")
|
||||
|
||||
|
||||
def _validate_image_url(url: str) -> bool:
|
||||
"""
|
||||
Basic validation of image URL format.
|
||||
|
||||
Args:
|
||||
url (str): The URL to validate
|
||||
|
||||
Returns:
|
||||
bool: True if URL appears to be valid, False otherwise
|
||||
"""
|
||||
if not url or not isinstance(url, str):
|
||||
return False
|
||||
|
||||
# Check if it's a valid URL format
|
||||
if not (url.startswith('http://') or url.startswith('https://')):
|
||||
return False
|
||||
|
||||
# Check for common image extensions (optional, as URLs may not have extensions)
|
||||
image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp', '.svg']
|
||||
|
||||
return True # Allow all HTTP/HTTPS URLs for flexibility
|
||||
|
||||
|
||||
async def vision_analyze_tool(
|
||||
image_url: str,
|
||||
user_prompt: str,
|
||||
model: str = DEFAULT_VISION_MODEL
|
||||
) -> str:
|
||||
"""
|
||||
Analyze an image from a URL using vision AI.
|
||||
|
||||
This tool processes images using Gemini Flash via Nous Research API.
|
||||
The user_prompt parameter is expected to be pre-formatted by the calling
|
||||
function (typically model_tools.py) to include both full description
|
||||
requests and specific questions.
|
||||
|
||||
Args:
|
||||
image_url (str): The URL of the image to analyze
|
||||
user_prompt (str): The pre-formatted prompt for the vision model
|
||||
model (str): The vision model to use (default: gemini-2.5-flash)
|
||||
|
||||
Returns:
|
||||
str: JSON string containing the analysis results with the following structure:
|
||||
{
|
||||
"success": bool,
|
||||
"analysis": str (defaults to error message if None)
|
||||
}
|
||||
|
||||
Raises:
|
||||
Exception: If analysis fails or API key is not set
|
||||
"""
|
||||
debug_call_data = {
|
||||
"parameters": {
|
||||
"image_url": image_url,
|
||||
"user_prompt": user_prompt,
|
||||
"model": model
|
||||
},
|
||||
"error": None,
|
||||
"success": False,
|
||||
"analysis_length": 0,
|
||||
"model_used": model
|
||||
}
|
||||
|
||||
try:
|
||||
print(f"🔍 Analyzing image from URL: {image_url[:60]}{'...' if len(image_url) > 60 else ''}")
|
||||
print(f"📝 User prompt: {user_prompt[:100]}{'...' if len(user_prompt) > 100 else ''}")
|
||||
|
||||
# Validate image URL
|
||||
if not _validate_image_url(image_url):
|
||||
raise ValueError("Invalid image URL format. Must start with http:// or https://")
|
||||
|
||||
# Check API key availability
|
||||
if not os.getenv("NOUS_API_KEY"):
|
||||
raise ValueError("NOUS_API_KEY environment variable not set")
|
||||
|
||||
# Use the prompt as provided (model_tools.py now handles full description formatting)
|
||||
comprehensive_prompt = user_prompt
|
||||
|
||||
# Prepare the message with image URL format
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": comprehensive_prompt
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": image_url
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
print(f"🧠 Processing image with {model}...")
|
||||
|
||||
# Call the vision API
|
||||
response = await nous_client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=0.1, # Low temperature for consistent analysis
|
||||
max_tokens=2000 # Generous limit for detailed analysis
|
||||
)
|
||||
|
||||
# Extract the analysis
|
||||
analysis = response.choices[0].message.content.strip()
|
||||
analysis_length = len(analysis)
|
||||
|
||||
print(f"✅ Image analysis completed ({analysis_length} characters)")
|
||||
|
||||
# Prepare successful response
|
||||
result = {
|
||||
"success": True,
|
||||
"analysis": analysis or "There was a problem with the request and the image could not be analyzed."
|
||||
}
|
||||
|
||||
debug_call_data["success"] = True
|
||||
debug_call_data["analysis_length"] = analysis_length
|
||||
|
||||
# Log debug information
|
||||
_log_debug_call("vision_analyze_tool", debug_call_data)
|
||||
_save_debug_log()
|
||||
|
||||
return json.dumps(result, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error analyzing image: {str(e)}"
|
||||
print(f"❌ {error_msg}")
|
||||
|
||||
# Prepare error response
|
||||
result = {
|
||||
"success": False,
|
||||
"analysis": "There was a problem with the request and the image could not be analyzed."
|
||||
}
|
||||
|
||||
debug_call_data["error"] = error_msg
|
||||
_log_debug_call("vision_analyze_tool", debug_call_data)
|
||||
_save_debug_log()
|
||||
|
||||
return json.dumps(result, indent=2)
|
||||
|
||||
|
||||
def check_nous_api_key() -> bool:
|
||||
"""
|
||||
Check if the Nous Research API key is available in environment variables.
|
||||
|
||||
Returns:
|
||||
bool: True if API key is set, False otherwise
|
||||
"""
|
||||
return bool(os.getenv("NOUS_API_KEY"))
|
||||
|
||||
|
||||
def check_vision_requirements() -> bool:
|
||||
"""
|
||||
Check if all requirements for vision tools are met.
|
||||
|
||||
Returns:
|
||||
bool: True if requirements are met, False otherwise
|
||||
"""
|
||||
return check_nous_api_key()
|
||||
|
||||
|
||||
def get_debug_session_info() -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about the current debug session.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Dictionary containing debug session information
|
||||
"""
|
||||
if not DEBUG_MODE or not DEBUG_DATA:
|
||||
return {
|
||||
"enabled": False,
|
||||
"session_id": None,
|
||||
"log_path": None,
|
||||
"total_calls": 0
|
||||
}
|
||||
|
||||
return {
|
||||
"enabled": True,
|
||||
"session_id": DEBUG_SESSION_ID,
|
||||
"log_path": str(DEBUG_LOG_PATH / f"vision_tools_debug_{DEBUG_SESSION_ID}.json"),
|
||||
"total_calls": len(DEBUG_DATA["tool_calls"])
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
Simple test/demo when run directly
|
||||
"""
|
||||
print("👁️ Vision Tools Module")
|
||||
print("=" * 40)
|
||||
|
||||
# Check if API key is available
|
||||
api_available = check_nous_api_key()
|
||||
|
||||
if not api_available:
|
||||
print("❌ NOUS_API_KEY environment variable not set")
|
||||
print("Please set your API key: export NOUS_API_KEY='your-key-here'")
|
||||
print("Get API key at: https://inference-api.nousresearch.com/")
|
||||
exit(1)
|
||||
else:
|
||||
print("✅ Nous Research API key found")
|
||||
|
||||
print("🛠️ Vision tools ready for use!")
|
||||
print(f"🧠 Using model: {DEFAULT_VISION_MODEL}")
|
||||
|
||||
# Show debug mode status
|
||||
if DEBUG_MODE:
|
||||
print(f"🐛 Debug mode ENABLED - Session ID: {DEBUG_SESSION_ID}")
|
||||
print(f" Debug logs will be saved to: ./logs/vision_tools_debug_{DEBUG_SESSION_ID}.json")
|
||||
else:
|
||||
print("🐛 Debug mode disabled (set VISION_TOOLS_DEBUG=true to enable)")
|
||||
|
||||
print("\nBasic usage:")
|
||||
print(" from vision_tools import vision_analyze_tool")
|
||||
print(" import asyncio")
|
||||
print("")
|
||||
print(" async def main():")
|
||||
print(" result = await vision_analyze_tool(")
|
||||
print(" image_url='https://example.com/image.jpg',")
|
||||
print(" user_prompt='What do you see in this image?'")
|
||||
print(" )")
|
||||
print(" print(result)")
|
||||
print(" asyncio.run(main())")
|
||||
|
||||
print("\nExample prompts:")
|
||||
print(" - 'What architectural style is this building?'")
|
||||
print(" - 'Describe the emotions and mood in this image'")
|
||||
print(" - 'What text can you read in this image?'")
|
||||
print(" - 'Identify any safety hazards visible'")
|
||||
print(" - 'What products or brands are shown?'")
|
||||
|
||||
print("\nDebug mode:")
|
||||
print(" # Enable debug logging")
|
||||
print(" export VISION_TOOLS_DEBUG=true")
|
||||
print(" # Debug logs capture all vision analysis calls and results")
|
||||
print(" # Logs saved to: ./logs/vision_tools_debug_UUID.json")
|
||||
922
web_tools.py
922
web_tools.py
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user