Compare commits
8 Commits
cluster-fa
...
thought-si
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a219e178a1 | ||
|
|
e06a15b3ab | ||
|
|
349e37de0a | ||
|
|
ab7293bed6 | ||
|
|
1614c15bb1 | ||
|
|
f813959750 | ||
|
|
f957ec2267 | ||
|
|
92e3074c10 |
5
.gitignore
vendored
5
.gitignore
vendored
@@ -25,3 +25,8 @@ hermes-*/*
|
||||
examples/
|
||||
tests/quick_test_dataset.jsonl
|
||||
tests/sample_dataset.jsonl
|
||||
run_datagen_kimik2-thinking.sh
|
||||
run_datagen_megascience_glm4-6.sh
|
||||
run_datagen_sonnet.sh
|
||||
source-data/*
|
||||
run_datagen_megascience_glm4-6.sh
|
||||
|
||||
431
batch_runner.py
431
batch_runner.py
@@ -23,7 +23,7 @@ Usage:
|
||||
|
||||
# Configure tool failure thresholds
|
||||
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
|
||||
--max_tool_failures=20 --max_tool_failure_rate=0.3
|
||||
--max_tool_failures=20 --max_tool_failure_rate=0.3 --min_tool_calls_for_rate=10
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -52,6 +52,74 @@ from safe_print import safe_print
|
||||
# Global configuration for worker processes
|
||||
_WORKER_CONFIG = {}
|
||||
|
||||
# Canonical names for the terminal tool (old & new implementations)
|
||||
_TERMINAL_TOOL_NAMES = {"terminal", "terminal_tool", "simple_terminal_tool"}
|
||||
|
||||
|
||||
def _is_terminal_tool_name(tool_name: Optional[str]) -> bool:
|
||||
"""Return True if the given tool name corresponds to a terminal tool."""
|
||||
return bool(tool_name) and tool_name.lower() in _TERMINAL_TOOL_NAMES
|
||||
|
||||
|
||||
def _terminal_tool_failed(content_json: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Determine whether the terminal tool itself failed (not the user command).
|
||||
|
||||
Terminal failures are indicated by explicit status flags or negative exit codes.
|
||||
Regular command failures (non-zero positive exit codes, stderr, timeouts) are not counted.
|
||||
"""
|
||||
if not isinstance(content_json, dict):
|
||||
return False
|
||||
|
||||
status = str(content_json.get("status", "")).lower()
|
||||
if status in {"error", "disabled"}:
|
||||
return True
|
||||
|
||||
exit_code = content_json.get("exit_code")
|
||||
if isinstance(exit_code, int) and exit_code < 0:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _categorize_error_type(error_message: str) -> str:
|
||||
"""
|
||||
Categorize an error message into a failure type.
|
||||
|
||||
Args:
|
||||
error_message (str): The error message to categorize
|
||||
|
||||
Returns:
|
||||
str: Category of the error
|
||||
"""
|
||||
error_lower = error_message.lower()
|
||||
|
||||
# Common error patterns
|
||||
if "timeout" in error_lower or "timed out" in error_lower:
|
||||
return "Timeout"
|
||||
elif "connection" in error_lower or "connect" in error_lower:
|
||||
return "Connection Error"
|
||||
elif "rate limit" in error_lower or "ratelimit" in error_lower or "429" in error_lower:
|
||||
return "Rate Limit"
|
||||
elif "authentication" in error_lower or "auth" in error_lower or "unauthorized" in error_lower or "401" in error_lower:
|
||||
return "Authentication"
|
||||
elif "not found" in error_lower or "404" in error_lower:
|
||||
return "Not Found"
|
||||
elif "permission" in error_lower or "forbidden" in error_lower or "403" in error_lower:
|
||||
return "Permission Denied"
|
||||
elif "invalid" in error_lower or "malformed" in error_lower or "bad request" in error_lower or "400" in error_lower:
|
||||
return "Invalid Input"
|
||||
elif "out of memory" in error_lower or "oom" in error_lower:
|
||||
return "Out of Memory"
|
||||
elif "network" in error_lower:
|
||||
return "Network Error"
|
||||
elif "server error" in error_lower or "500" in error_lower or "502" in error_lower or "503" in error_lower:
|
||||
return "Server Error"
|
||||
elif "vm" in error_lower and ("fail" in error_lower or "error" in error_lower):
|
||||
return "VM Error"
|
||||
else:
|
||||
return "Other"
|
||||
|
||||
|
||||
def _extract_tool_errors_from_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -61,7 +129,7 @@ def _extract_tool_errors_from_messages(messages: List[Dict[str, Any]]) -> List[D
|
||||
messages (List[Dict]): Message history
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of tool errors with tool name, error message, and context
|
||||
List[Dict]: List of tool errors with tool name, error message, error type, and context
|
||||
"""
|
||||
tool_errors = []
|
||||
tool_calls_map = {} # Map tool_call_id to tool name
|
||||
@@ -87,23 +155,37 @@ def _extract_tool_errors_from_messages(messages: List[Dict[str, Any]]) -> List[D
|
||||
content_json = json.loads(content) if isinstance(content, str) else content
|
||||
|
||||
if isinstance(content_json, dict):
|
||||
# Check if error field exists AND has a non-null value
|
||||
if "error" in content_json and content_json["error"] is not None:
|
||||
has_error = True
|
||||
error_msg = str(content_json["error"])
|
||||
# Get tool name for special handling
|
||||
tool_name = tool_calls_map.get(tool_call_id, "unknown")
|
||||
|
||||
# Special handling for terminal tool responses
|
||||
if "content" in content_json and isinstance(content_json["content"], dict):
|
||||
inner_content = content_json["content"]
|
||||
if inner_content.get("error") is not None or inner_content.get("exit_code", 0) != 0:
|
||||
# Special handling for terminal tool outputs
|
||||
if _is_terminal_tool_name(tool_name):
|
||||
if _terminal_tool_failed(content_json):
|
||||
has_error = True
|
||||
error_msg = inner_content.get("error") or f"Exit code: {inner_content.get('exit_code')}"
|
||||
# Prefer explicit error text, fall back to status or generic message
|
||||
error_msg = str(
|
||||
content_json.get("error")
|
||||
or content_json.get("status")
|
||||
or "Terminal tool failure"
|
||||
)
|
||||
else:
|
||||
# For other tools, check if error field exists AND has a non-null value
|
||||
if "error" in content_json and content_json["error"] is not None:
|
||||
has_error = True
|
||||
error_msg = str(content_json["error"])
|
||||
|
||||
# Check for "success": false pattern
|
||||
if content_json.get("success") is False:
|
||||
has_error = True
|
||||
if not error_msg:
|
||||
error_msg = str(content_json.get("message", content_json.get("error", "Unknown error")))
|
||||
# Check nested content structure (some tools wrap responses)
|
||||
if "content" in content_json and isinstance(content_json["content"], dict):
|
||||
inner_content = content_json["content"]
|
||||
if inner_content.get("error") is not None:
|
||||
has_error = True
|
||||
error_msg = inner_content.get("error")
|
||||
|
||||
# Check for "success": false pattern
|
||||
if content_json.get("success") is False:
|
||||
has_error = True
|
||||
if not error_msg:
|
||||
error_msg = str(content_json.get("message", content_json.get("error", "Unknown error")))
|
||||
|
||||
except:
|
||||
# If not JSON, check if content explicitly states an error
|
||||
@@ -114,9 +196,11 @@ def _extract_tool_errors_from_messages(messages: List[Dict[str, Any]]) -> List[D
|
||||
# Record error if found
|
||||
if has_error and tool_call_id in tool_calls_map:
|
||||
tool_name = tool_calls_map[tool_call_id]
|
||||
error_message = error_msg or "Unknown error"
|
||||
tool_errors.append({
|
||||
"tool_name": tool_name,
|
||||
"error_message": error_msg or "Unknown error",
|
||||
"error_message": error_message,
|
||||
"error_type": _categorize_error_type(error_message),
|
||||
"full_content": content[:500] # Keep first 500 chars of full response
|
||||
})
|
||||
|
||||
@@ -160,32 +244,37 @@ def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, i
|
||||
elif msg["role"] == "tool":
|
||||
tool_call_id = msg.get("tool_call_id", "")
|
||||
content = msg.get("content", "")
|
||||
|
||||
|
||||
# Determine if tool call was successful
|
||||
is_success = True
|
||||
try:
|
||||
# Try to parse as JSON and check for actual error values
|
||||
content_json = json.loads(content) if isinstance(content, str) else content
|
||||
|
||||
|
||||
if isinstance(content_json, dict):
|
||||
# Check if error field exists AND has a non-null value
|
||||
if "error" in content_json and content_json["error"] is not None:
|
||||
is_success = False
|
||||
|
||||
# Special handling for terminal tool responses
|
||||
# Terminal wraps its response in a "content" field
|
||||
if "content" in content_json and isinstance(content_json["content"], dict):
|
||||
inner_content = content_json["content"]
|
||||
# Check for actual error (non-null error field or non-zero exit code)
|
||||
has_error = (inner_content.get("error") is not None or
|
||||
inner_content.get("exit_code", 0) != 0)
|
||||
if has_error:
|
||||
# Get tool name for special handling
|
||||
tool_name = tool_calls_map.get(tool_call_id, "unknown")
|
||||
|
||||
# Special handling for terminal tool: only count as failure when the tool itself fails
|
||||
if _is_terminal_tool_name(tool_name):
|
||||
if _terminal_tool_failed(content_json):
|
||||
is_success = False
|
||||
|
||||
# Check for "success": false pattern used by some tools
|
||||
if content_json.get("success") is False:
|
||||
is_success = False
|
||||
|
||||
else:
|
||||
# For other tools, check if error field exists AND has a non-null value
|
||||
if "error" in content_json and content_json["error"] is not None:
|
||||
is_success = False
|
||||
|
||||
# Check nested content structure (some tools wrap responses)
|
||||
if "content" in content_json and isinstance(content_json["content"], dict):
|
||||
inner_content = content_json["content"]
|
||||
# Check for actual error (non-null error field)
|
||||
if inner_content.get("error") is not None:
|
||||
is_success = False
|
||||
|
||||
# Check for "success": false pattern used by some tools
|
||||
if content_json.get("success") is False:
|
||||
is_success = False
|
||||
|
||||
except:
|
||||
# If not JSON, check if content is empty or explicitly states an error
|
||||
# Note: We avoid simple substring matching to prevent false positives
|
||||
@@ -262,12 +351,16 @@ def _process_single_prompt(
|
||||
result["completed"]
|
||||
)
|
||||
|
||||
# Get profiling stats from the result
|
||||
profiling_stats = result.get("profiling_stats", {"tools": {}, "api_calls": {}})
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"prompt_index": prompt_index,
|
||||
"trajectory": trajectory,
|
||||
"tool_stats": tool_stats,
|
||||
"tool_errors": tool_errors,
|
||||
"profiling_stats": profiling_stats,
|
||||
"completed": result["completed"],
|
||||
"api_calls": result["api_calls"],
|
||||
"toolsets_used": selected_toolsets,
|
||||
@@ -291,6 +384,7 @@ def _process_single_prompt(
|
||||
"error": error_msg,
|
||||
"traceback": tb,
|
||||
"tool_errors": [],
|
||||
"profiling_stats": {"tools": {}, "api_calls": {}},
|
||||
"trajectory": None,
|
||||
"tool_stats": {},
|
||||
"toolsets_used": [],
|
||||
@@ -339,6 +433,7 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
|
||||
|
||||
# Initialize aggregated stats for this batch
|
||||
batch_tool_stats = {}
|
||||
batch_profiling_stats = [] # Collect profiling stats from each prompt
|
||||
completed_in_batch = []
|
||||
all_tool_errors = [] # Track all tool errors in this batch
|
||||
exception_errors = [] # Track top-level exceptions
|
||||
@@ -360,7 +455,8 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
|
||||
"prompt_index": prompt_index,
|
||||
"tool_name": tool_error["tool_name"],
|
||||
"error_message": tool_error["error_message"],
|
||||
"full_content": tool_error.get("full_content", "")
|
||||
"full_content": tool_error.get("full_content", ""),
|
||||
"error_type": tool_error.get("error_type", "Other")
|
||||
})
|
||||
|
||||
# Track top-level exceptions (not tool errors)
|
||||
@@ -395,21 +491,26 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
|
||||
"success": 0,
|
||||
"failure": 0
|
||||
}
|
||||
|
||||
|
||||
batch_tool_stats[tool_name]["count"] += stats["count"]
|
||||
batch_tool_stats[tool_name]["success"] += stats["success"]
|
||||
batch_tool_stats[tool_name]["failure"] += stats["failure"]
|
||||
|
||||
|
||||
# Collect profiling statistics
|
||||
if result.get("profiling_stats"):
|
||||
batch_profiling_stats.append(result["profiling_stats"])
|
||||
|
||||
completed_in_batch.append(prompt_index)
|
||||
print(f" ✅ Prompt {prompt_index} completed")
|
||||
|
||||
print(f"✅ Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)")
|
||||
|
||||
|
||||
return {
|
||||
"batch_num": batch_num,
|
||||
"processed": len(prompts_to_process),
|
||||
"skipped": len(batch_data) - len(prompts_to_process),
|
||||
"tool_stats": batch_tool_stats,
|
||||
"profiling_stats": batch_profiling_stats,
|
||||
"completed_prompts": completed_in_batch,
|
||||
"tool_errors": all_tool_errors,
|
||||
"exception_errors": exception_errors
|
||||
@@ -438,6 +539,7 @@ class BatchRunner:
|
||||
max_tool_failures: int = 10,
|
||||
max_tool_failure_rate: float = 0.5,
|
||||
keep_recent_errors: int = 5,
|
||||
min_tool_calls_for_rate: int = 10,
|
||||
):
|
||||
"""
|
||||
Initialize the batch runner.
|
||||
@@ -458,6 +560,7 @@ class BatchRunner:
|
||||
max_tool_failures (int): Maximum number of tool failures before stopping (default: 10)
|
||||
max_tool_failure_rate (float): Maximum tool failure rate (0.0-1.0) before stopping (default: 0.5)
|
||||
keep_recent_errors (int): Number of recent errors to keep per tool (default: 5)
|
||||
min_tool_calls_for_rate (int): Minimum number of tool calls before checking failure rate (default: 10)
|
||||
"""
|
||||
self.dataset_file = Path(dataset_file)
|
||||
self.batch_size = batch_size
|
||||
@@ -474,6 +577,7 @@ class BatchRunner:
|
||||
self.max_tool_failures = max_tool_failures
|
||||
self.max_tool_failure_rate = max_tool_failure_rate
|
||||
self.keep_recent_errors = keep_recent_errors
|
||||
self.min_tool_calls_for_rate = min_tool_calls_for_rate
|
||||
|
||||
# Validate distribution
|
||||
if not validate_distribution(distribution):
|
||||
@@ -482,12 +586,15 @@ class BatchRunner:
|
||||
# Setup output directory
|
||||
self.output_dir = Path("data") / run_name
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Checkpoint file
|
||||
self.checkpoint_file = self.output_dir / "checkpoint.json"
|
||||
|
||||
|
||||
# Statistics file
|
||||
self.stats_file = self.output_dir / "statistics.json"
|
||||
|
||||
# Errors file
|
||||
self.errors_file = self.output_dir / "errors.json"
|
||||
|
||||
# Load dataset
|
||||
self.dataset = self._load_dataset()
|
||||
@@ -506,6 +613,7 @@ class BatchRunner:
|
||||
safe_print(f" [yellow]Tool failure limits:[/yellow]")
|
||||
safe_print(f" Max failures: {self.max_tool_failures}")
|
||||
safe_print(f" Max failure rate: {self.max_tool_failure_rate:.1%}")
|
||||
safe_print(f" Min tool calls for rate check: {self.min_tool_calls_for_rate}")
|
||||
safe_print(f" Keep recent errors: {self.keep_recent_errors}")
|
||||
if self.ephemeral_system_prompt:
|
||||
prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt
|
||||
@@ -604,7 +712,8 @@ class BatchRunner:
|
||||
|
||||
def _consolidate_data(self, num_batches: int, tool_stats: Dict[str, Dict[str, int]],
|
||||
start_time: float, tool_errors_by_tool: Dict[str, List[Dict]],
|
||||
exception_errors: List[Dict], early_exit: bool = False, exit_reason: str = None):
|
||||
exception_errors: List[Dict], early_exit: bool = False, exit_reason: str = None,
|
||||
profiling_stats_list: List[Dict] = None):
|
||||
"""
|
||||
Consolidate batch data into trajectories.jsonl and save statistics.
|
||||
|
||||
@@ -616,6 +725,7 @@ class BatchRunner:
|
||||
exception_errors (List): Top-level exceptions
|
||||
early_exit (bool): Whether this is an early exit
|
||||
exit_reason (str): Reason for early exit
|
||||
profiling_stats_list (List[Dict]): List of profiling statistics from each conversation
|
||||
"""
|
||||
# Combine all batch files into a single trajectories.jsonl file
|
||||
combined_file = self.output_dir / "trajectories.jsonl"
|
||||
@@ -644,7 +754,50 @@ class BatchRunner:
|
||||
stats["success_rate"] = 0.0
|
||||
stats["failure_rate"] = 0.0
|
||||
|
||||
# Save final statistics
|
||||
# Build failure type breakdown for each tool
|
||||
failure_type_breakdown = {}
|
||||
for tool_name, errors in tool_errors_by_tool.items():
|
||||
failure_types = {}
|
||||
for error in errors:
|
||||
error_type = error.get("error_type", "Other")
|
||||
if error_type not in failure_types:
|
||||
failure_types[error_type] = 0
|
||||
failure_types[error_type] += 1
|
||||
|
||||
# Calculate percentages
|
||||
total_failures = len(errors)
|
||||
failure_type_breakdown[tool_name] = {
|
||||
"total_failures": total_failures,
|
||||
"types": {
|
||||
error_type: {
|
||||
"count": count,
|
||||
"percentage": round((count / total_failures) * 100, 2)
|
||||
}
|
||||
for error_type, count in failure_types.items()
|
||||
}
|
||||
}
|
||||
|
||||
# Save error information to separate file
|
||||
error_data = {
|
||||
"run_name": self.run_name,
|
||||
"completed_at": datetime.now().isoformat(),
|
||||
"total_tool_errors": sum(len(errors) for errors in tool_errors_by_tool.values()),
|
||||
"total_exception_errors": len(exception_errors),
|
||||
"tool_errors": tool_errors_by_tool,
|
||||
"failure_type_breakdown": failure_type_breakdown,
|
||||
"exception_errors": exception_errors[:self.keep_recent_errors] # Keep k most recent
|
||||
}
|
||||
|
||||
with open(self.errors_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(error_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Aggregate profiling statistics if available
|
||||
aggregated_profiling_stats = None
|
||||
if profiling_stats_list:
|
||||
from profiling import aggregate_profiling_stats
|
||||
aggregated_profiling_stats = aggregate_profiling_stats(profiling_stats_list)
|
||||
|
||||
# Save final statistics (without detailed errors)
|
||||
final_stats = {
|
||||
"run_name": self.run_name,
|
||||
"distribution": self.distribution,
|
||||
@@ -657,13 +810,17 @@ class BatchRunner:
|
||||
"duration_seconds": round(time.time() - start_time, 2),
|
||||
"early_exit": early_exit,
|
||||
"exit_reason": exit_reason,
|
||||
"tool_errors": tool_errors_by_tool,
|
||||
"exception_errors": exception_errors[:self.keep_recent_errors], # Keep k most recent
|
||||
"tool_statistics": tool_stats
|
||||
"tool_statistics": tool_stats,
|
||||
"profiling_statistics": aggregated_profiling_stats
|
||||
}
|
||||
|
||||
with open(self.stats_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(final_stats, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Display aggregated profiling statistics
|
||||
if aggregated_profiling_stats:
|
||||
from profiling import print_aggregated_statistics
|
||||
print_aggregated_statistics(aggregated_profiling_stats, detailed=True)
|
||||
|
||||
|
||||
def run(self, resume: bool = False):
|
||||
@@ -705,6 +862,7 @@ class BatchRunner:
|
||||
|
||||
# Aggregate statistics across all batches
|
||||
total_tool_stats = {}
|
||||
all_profiling_stats = [] # Collect all profiling stats for aggregation
|
||||
tool_errors_by_tool = {} # {tool_name: [list of k most recent errors]}
|
||||
all_exception_errors = []
|
||||
all_completed_prompts = list(completed_prompts_set)
|
||||
@@ -729,65 +887,81 @@ class BatchRunner:
|
||||
for batch_num, batch_data in enumerate(self.batches)
|
||||
]
|
||||
|
||||
# Process batches and check tool failure threshold after each batch
|
||||
for batch_num, task in enumerate(tasks):
|
||||
# Process single batch
|
||||
result = pool.apply(_process_batch_worker, (task,))
|
||||
# Process batches in parallel and check tool failure threshold as results come in
|
||||
# imap_unordered allows parallel processing while getting results as they complete
|
||||
batch_num = 0
|
||||
try:
|
||||
for result in pool.imap_unordered(_process_batch_worker, tasks):
|
||||
# Update statistics
|
||||
all_completed_prompts.extend(result.get("completed_prompts", []))
|
||||
total_processed += result.get("processed", 0)
|
||||
|
||||
# Update statistics
|
||||
all_completed_prompts.extend(result.get("completed_prompts", []))
|
||||
total_processed += result.get("processed", 0)
|
||||
# Aggregate tool stats
|
||||
for tool_name, stats in result.get("tool_stats", {}).items():
|
||||
if tool_name not in total_tool_stats:
|
||||
total_tool_stats[tool_name] = {
|
||||
"count": 0,
|
||||
"success": 0,
|
||||
"failure": 0
|
||||
}
|
||||
|
||||
# Aggregate tool stats
|
||||
for tool_name, stats in result.get("tool_stats", {}).items():
|
||||
if tool_name not in total_tool_stats:
|
||||
total_tool_stats[tool_name] = {
|
||||
"count": 0,
|
||||
"success": 0,
|
||||
"failure": 0
|
||||
}
|
||||
total_tool_stats[tool_name]["count"] += stats["count"]
|
||||
total_tool_stats[tool_name]["success"] += stats["success"]
|
||||
total_tool_stats[tool_name]["failure"] += stats["failure"]
|
||||
|
||||
total_tool_stats[tool_name]["count"] += stats["count"]
|
||||
total_tool_stats[tool_name]["success"] += stats["success"]
|
||||
total_tool_stats[tool_name]["failure"] += stats["failure"]
|
||||
# Collect profiling stats from this batch
|
||||
if result.get("profiling_stats"):
|
||||
all_profiling_stats.extend(result["profiling_stats"])
|
||||
|
||||
# Aggregate tool errors (keep k most recent per tool)
|
||||
for tool_error in result.get("tool_errors", []):
|
||||
tool_name = tool_error["tool_name"]
|
||||
if tool_name not in tool_errors_by_tool:
|
||||
tool_errors_by_tool[tool_name] = []
|
||||
# Aggregate tool errors (keep k most recent per tool)
|
||||
for tool_error in result.get("tool_errors", []):
|
||||
tool_name = tool_error["tool_name"]
|
||||
if tool_name not in tool_errors_by_tool:
|
||||
tool_errors_by_tool[tool_name] = []
|
||||
|
||||
# Add error and keep only k most recent
|
||||
tool_errors_by_tool[tool_name].append(tool_error)
|
||||
if len(tool_errors_by_tool[tool_name]) > self.keep_recent_errors:
|
||||
tool_errors_by_tool[tool_name] = tool_errors_by_tool[tool_name][-self.keep_recent_errors:]
|
||||
# Add error and keep only k most recent
|
||||
tool_errors_by_tool[tool_name].append(tool_error)
|
||||
if len(tool_errors_by_tool[tool_name]) > self.keep_recent_errors:
|
||||
tool_errors_by_tool[tool_name] = tool_errors_by_tool[tool_name][-self.keep_recent_errors:]
|
||||
|
||||
total_tool_errors += 1
|
||||
total_tool_errors += 1
|
||||
|
||||
# Track exception errors
|
||||
all_exception_errors.extend(result.get("exception_errors", []))
|
||||
# Track exception errors
|
||||
all_exception_errors.extend(result.get("exception_errors", []))
|
||||
|
||||
# Check tool failure thresholds
|
||||
if total_processed > 0:
|
||||
tool_failure_rate = total_tool_errors / total_processed
|
||||
# Check tool failure thresholds
|
||||
# Calculate total tool calls (not prompts)
|
||||
total_tool_calls = sum(stats["count"] for stats in total_tool_stats.values())
|
||||
|
||||
# Check absolute count threshold
|
||||
if total_tool_errors >= self.max_tool_failures:
|
||||
early_exit = True
|
||||
exit_reason = f"Exceeded maximum tool failures ({total_tool_errors}/{self.max_tool_failures})"
|
||||
safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]")
|
||||
pool.terminate() # Stop all workers immediately
|
||||
break
|
||||
|
||||
# Check rate threshold
|
||||
if tool_failure_rate >= self.max_tool_failure_rate:
|
||||
early_exit = True
|
||||
exit_reason = f"Exceeded tool failure rate ({tool_failure_rate:.2%} >= {self.max_tool_failure_rate:.2%})"
|
||||
safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]")
|
||||
break
|
||||
# Check rate threshold (only if we have enough tool calls to trust the rate)
|
||||
if total_tool_calls >= self.min_tool_calls_for_rate:
|
||||
tool_failure_rate = total_tool_errors / total_tool_calls
|
||||
|
||||
# Update checkpoint after each batch
|
||||
checkpoint_data["completed_prompts"] = all_completed_prompts
|
||||
self._save_checkpoint(checkpoint_data)
|
||||
if tool_failure_rate >= self.max_tool_failure_rate:
|
||||
early_exit = True
|
||||
exit_reason = f"Exceeded tool failure rate ({tool_failure_rate:.2%} >= {self.max_tool_failure_rate:.2%}, {total_tool_errors}/{total_tool_calls} tool calls)"
|
||||
safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]")
|
||||
pool.terminate() # Stop all workers immediately
|
||||
break
|
||||
|
||||
# Update checkpoint after each batch completes
|
||||
checkpoint_data["completed_prompts"] = all_completed_prompts
|
||||
self._save_checkpoint(checkpoint_data)
|
||||
|
||||
batch_num += 1
|
||||
except KeyboardInterrupt:
|
||||
safe_print("\n[bold yellow]⚠️ Interrupted by user, stopping workers...[/bold yellow]")
|
||||
pool.terminate()
|
||||
early_exit = True
|
||||
exit_reason = "Interrupted by user"
|
||||
|
||||
# Save final checkpoint
|
||||
checkpoint_data["completed_prompts"] = all_completed_prompts
|
||||
@@ -802,7 +976,8 @@ class BatchRunner:
|
||||
tool_errors_by_tool,
|
||||
all_exception_errors,
|
||||
early_exit,
|
||||
exit_reason
|
||||
exit_reason,
|
||||
all_profiling_stats
|
||||
)
|
||||
|
||||
# Print summary
|
||||
@@ -846,10 +1021,20 @@ class BatchRunner:
|
||||
for idx, (error_msg, instances) in enumerate(list(unique_errors.items())[:3]):
|
||||
error_preview = error_msg if len(error_msg) <= 100 else error_msg[:97] + "..."
|
||||
safe_print(f" [{idx+1}] [dim]{error_preview}[/dim] (x{len(instances)})")
|
||||
# Show one example with prompt index
|
||||
|
||||
# Show one example with prompt index and full content prefix
|
||||
example = instances[-1] # Most recent
|
||||
safe_print(f" [dim]Prompt {example['prompt_index']}[/dim]")
|
||||
|
||||
# Show full content prefix (first 200 chars)
|
||||
full_content = example.get('full_content', '')
|
||||
if full_content and full_content != error_preview:
|
||||
content_preview = full_content[:200]
|
||||
if len(full_content) > 200:
|
||||
content_preview += "..."
|
||||
# Show with prefix indicator
|
||||
safe_print(f" [dim]Content: {content_preview}[/dim]")
|
||||
|
||||
if len(unique_errors) > 3:
|
||||
safe_print(f" [dim]... and {len(unique_errors) - 3} more error types[/dim]")
|
||||
|
||||
@@ -861,10 +1046,20 @@ class BatchRunner:
|
||||
safe_print(f"\n[bold red]💥 Top-level Exceptions: {len(all_exception_errors)}[/bold red]")
|
||||
safe_print("[red]-[/red]" * 70)
|
||||
for error in all_exception_errors[:self.keep_recent_errors]:
|
||||
error_preview = error["error"][:100]
|
||||
if len(error["error"]) > 100:
|
||||
error_msg = error["error"]
|
||||
error_preview = error_msg[:150]
|
||||
if len(error_msg) > 150:
|
||||
error_preview += "..."
|
||||
safe_print(f" Prompt {error['prompt_index']}: [dim]{error_preview}[/dim]")
|
||||
safe_print(f" [red]Prompt {error['prompt_index']}:[/red] [dim]{error_preview}[/dim]")
|
||||
|
||||
# Show traceback prefix if available
|
||||
traceback_text = error.get("traceback", "")
|
||||
if traceback_text:
|
||||
# Show last 3 lines of traceback for context
|
||||
tb_lines = traceback_text.strip().split('\n')
|
||||
relevant_lines = tb_lines[-3:] if len(tb_lines) > 3 else tb_lines
|
||||
for line in relevant_lines:
|
||||
safe_print(f" [dim]{line}[/dim]")
|
||||
|
||||
safe_print(f"\n[cyan]📈 Tool Usage Statistics:[/cyan]")
|
||||
safe_print("-" * 70)
|
||||
@@ -890,15 +1085,53 @@ class BatchRunner:
|
||||
else:
|
||||
safe_print("No tool calls were made during this run.")
|
||||
|
||||
# Display failure type breakdown for tools with failures
|
||||
if tool_errors_by_tool:
|
||||
safe_print(f"\n[cyan]📊 Failure Type Breakdown:[/cyan]")
|
||||
safe_print("-" * 70)
|
||||
|
||||
# Sort tools by total error count
|
||||
sorted_tools = sorted(
|
||||
tool_errors_by_tool.items(),
|
||||
key=lambda x: len(x[1]),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
for tool_name, errors in sorted_tools:
|
||||
# Count failure types for this tool
|
||||
failure_types = {}
|
||||
for error in errors:
|
||||
error_type = error.get("error_type", "Other")
|
||||
if error_type not in failure_types:
|
||||
failure_types[error_type] = 0
|
||||
failure_types[error_type] += 1
|
||||
|
||||
# Display tool name and total failures
|
||||
total_failures = len(errors)
|
||||
safe_print(f"\n[yellow]{tool_name}[/yellow] ({total_failures} failures):")
|
||||
|
||||
# Sort failure types by count
|
||||
sorted_types = sorted(
|
||||
failure_types.items(),
|
||||
key=lambda x: x[1],
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# Display each failure type with count and percentage
|
||||
for failure_type, count in sorted_types:
|
||||
percentage = (count / total_failures) * 100
|
||||
safe_print(f" • {failure_type:<20} {count:>4} ({percentage:>5.1f}%)")
|
||||
|
||||
safe_print(f"\n[cyan]💾 Results saved to:[/cyan] {self.output_dir}")
|
||||
safe_print(f" - Trajectories: trajectories.jsonl (combined)")
|
||||
safe_print(f" - Individual batches: batch_*.jsonl (for debugging)")
|
||||
safe_print(f" - Statistics: {self.stats_file.name}")
|
||||
safe_print(f" - Errors: {self.errors_file.name}")
|
||||
safe_print(f" - Checkpoint: {self.checkpoint_file.name}")
|
||||
|
||||
if early_exit:
|
||||
safe_print(f"\n[bold yellow]ℹ️ Run was stopped early due to tool failures.[/bold yellow]")
|
||||
safe_print(f"[yellow] Check {self.stats_file.name} for detailed error information including tracebacks.[/yellow]")
|
||||
safe_print(f"[yellow] Check {self.errors_file.name} for detailed error information including tracebacks.[/yellow]")
|
||||
safe_print(f"[yellow] You can resume this run later with --resume flag.[/yellow]")
|
||||
|
||||
|
||||
@@ -920,6 +1153,7 @@ def main(
|
||||
max_tool_failures: int = 10,
|
||||
max_tool_failure_rate: float = 0.5,
|
||||
keep_recent_errors: int = 5,
|
||||
min_tool_calls_for_rate: int = 10,
|
||||
):
|
||||
"""
|
||||
Run batch processing of agent prompts from a dataset.
|
||||
@@ -942,6 +1176,7 @@ def main(
|
||||
max_tool_failures (int): Maximum number of tool failures before stopping (default: 10)
|
||||
max_tool_failure_rate (float): Maximum tool failure rate (0.0-1.0) before stopping (default: 0.5)
|
||||
keep_recent_errors (int): Number of recent errors to keep per tool for reporting (default: 5)
|
||||
min_tool_calls_for_rate (int): Minimum number of tool calls before checking failure rate (default: 10)
|
||||
|
||||
Examples:
|
||||
# Basic usage
|
||||
@@ -959,7 +1194,7 @@ def main(
|
||||
|
||||
# With custom tool failure thresholds
|
||||
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
|
||||
--max_tool_failures=20 --max_tool_failure_rate=0.3 --keep_recent_errors=10
|
||||
--max_tool_failures=20 --max_tool_failure_rate=0.3 --min_tool_calls_for_rate=10 --keep_recent_errors=10
|
||||
|
||||
# List available distributions
|
||||
python batch_runner.py --list_distributions
|
||||
@@ -1010,7 +1245,8 @@ def main(
|
||||
log_prefix_chars=log_prefix_chars,
|
||||
max_tool_failures=max_tool_failures,
|
||||
max_tool_failure_rate=max_tool_failure_rate,
|
||||
keep_recent_errors=keep_recent_errors
|
||||
keep_recent_errors=keep_recent_errors,
|
||||
min_tool_calls_for_rate=min_tool_calls_for_rate
|
||||
)
|
||||
|
||||
runner.run(resume=resume)
|
||||
@@ -1024,4 +1260,3 @@ def main(
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
|
||||
|
||||
12
gemini_nothinking.sh
Normal file
12
gemini_nothinking.sh
Normal file
@@ -0,0 +1,12 @@
|
||||
python batch_runner.py \
|
||||
--dataset_file="source-data/agent_tasks_eval.jsonl" \
|
||||
--batch_size=1 \
|
||||
--run_name="agenttasks_eval_gemini-4.5-3-nothinking" \
|
||||
--distribution="science" \
|
||||
--model="gemini-3-pro-preview" \
|
||||
--base_url="https://generativelanguage.googleapis.com/v1beta/openai/" \
|
||||
--api_key="${GEMINI_API_KEY}" \
|
||||
--num_workers=10 \
|
||||
--max_turns=60 \
|
||||
--verbose \
|
||||
--ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use the terminal or search tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run. If you need to use a tool that isn't available, you can use the terminal tool to install or create it in many cases as well. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. If you require API keys please check which ones already exist in your environment variables in a way that does not read them."
|
||||
@@ -31,7 +31,9 @@ import asyncio
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from tools.web_tools import web_search_tool, web_extract_tool, web_crawl_tool, check_firecrawl_api_key
|
||||
from tools.terminal_tool import terminal_tool, check_hecate_requirements, TERMINAL_TOOL_DESCRIPTION
|
||||
from tools.simple_terminal_tool import simple_terminal_tool, check_requirements as check_simple_terminal_requirements, SIMPLE_TERMINAL_TOOL_DESCRIPTION
|
||||
# Keep old terminal tool for backwards compatibility if needed
|
||||
# from tools.terminal_tool import terminal_tool, check_hecate_requirements, TERMINAL_TOOL_DESCRIPTION
|
||||
from tools.vision_tools import vision_analyze_tool, check_vision_requirements
|
||||
from tools.mixture_of_agents_tool import mixture_of_agents_tool, check_moa_requirements
|
||||
from tools.image_generation_tool import image_generate_tool, check_image_generation_requirements
|
||||
@@ -111,7 +113,7 @@ def get_web_tool_definitions() -> List[Dict[str, Any]]:
|
||||
def get_terminal_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get tool definitions for terminal tools in OpenAI's expected format.
|
||||
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of terminal tool definitions compatible with OpenAI API
|
||||
"""
|
||||
@@ -120,7 +122,7 @@ def get_terminal_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "terminal",
|
||||
"description": TERMINAL_TOOL_DESCRIPTION,
|
||||
"description": SIMPLE_TERMINAL_TOOL_DESCRIPTION,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -128,28 +130,18 @@ def get_terminal_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"type": "string",
|
||||
"description": "The command to execute on the VM"
|
||||
},
|
||||
"input_keys": {
|
||||
"type": "string",
|
||||
"description": "Keystrokes to send to the most recent interactive session (e.g., 'hello\\n' for typing hello + Enter). If no active session exists, this will be ignored."
|
||||
},
|
||||
"background": {
|
||||
"type": "boolean",
|
||||
"description": "Whether to run the command in the background (default: false)",
|
||||
"default": False
|
||||
},
|
||||
"idle_threshold": {
|
||||
"type": "number",
|
||||
"description": "Seconds to wait for output before considering session idle (default: 5.0)",
|
||||
"default": 5.0,
|
||||
"minimum": 0.1
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Command timeout in seconds (optional)",
|
||||
"minimum": 1
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
"required": ["command"]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -262,11 +254,11 @@ def get_all_tool_names() -> List[str]:
|
||||
# Web tools
|
||||
if check_firecrawl_api_key():
|
||||
tool_names.extend(["web_search", "web_extract", "web_crawl"])
|
||||
|
||||
# Terminal tools
|
||||
if check_hecate_requirements():
|
||||
|
||||
# Terminal tools
|
||||
if check_simple_terminal_requirements():
|
||||
tool_names.extend(["terminal"])
|
||||
|
||||
|
||||
# Vision tools
|
||||
if check_vision_requirements():
|
||||
tool_names.extend(["vision_analyze"])
|
||||
@@ -346,11 +338,11 @@ def get_tool_definitions(
|
||||
if check_firecrawl_api_key():
|
||||
for tool in get_web_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
if check_hecate_requirements():
|
||||
|
||||
if check_simple_terminal_requirements():
|
||||
for tool in get_terminal_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
|
||||
if check_vision_requirements():
|
||||
for tool in get_vision_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
@@ -494,12 +486,10 @@ def handle_terminal_function_call(function_name: str, function_args: Dict[str, A
|
||||
"""
|
||||
if function_name == "terminal":
|
||||
command = function_args.get("command")
|
||||
input_keys = function_args.get("input_keys")
|
||||
background = function_args.get("background", False)
|
||||
idle_threshold = function_args.get("idle_threshold", 5.0)
|
||||
timeout = function_args.get("timeout")
|
||||
|
||||
return terminal_tool(command, input_keys, None, background, idle_threshold, timeout, task_id)
|
||||
return simple_terminal_tool(command=command, background=background, timeout=timeout, task_id=task_id)
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown terminal function: {function_name}"}, ensure_ascii=False)
|
||||
@@ -681,10 +671,10 @@ def get_available_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||
"requirements": ["FIRECRAWL_API_KEY environment variable"]
|
||||
},
|
||||
"terminal_tools": {
|
||||
"available": check_hecate_requirements(),
|
||||
"tools": ["terminal_tool"],
|
||||
"description": "Execute commands with optional interactive session support on Linux VMs",
|
||||
"requirements": ["MORPH_API_KEY environment variable", "hecate package"]
|
||||
"available": check_simple_terminal_requirements(),
|
||||
"tools": ["simple_terminal_tool"],
|
||||
"description": "Execute commands on secure Linux VMs without session persistence",
|
||||
"requirements": ["MORPH_API_KEY environment variable"]
|
||||
},
|
||||
"vision_tools": {
|
||||
"available": check_vision_requirements(),
|
||||
@@ -711,13 +701,13 @@ def get_available_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||
def check_toolset_requirements() -> Dict[str, bool]:
|
||||
"""
|
||||
Check if all requirements for available toolsets are met.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict: Status of each toolset's requirements
|
||||
"""
|
||||
return {
|
||||
"web_tools": check_firecrawl_api_key(),
|
||||
"terminal_tools": check_hecate_requirements(),
|
||||
"terminal_tools": check_simple_terminal_requirements(),
|
||||
"vision_tools": check_vision_requirements(),
|
||||
"moa_tools": check_moa_requirements(),
|
||||
"image_tools": check_image_generation_requirements()
|
||||
|
||||
381
profiling.py
Normal file
381
profiling.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""
|
||||
Profiling module for tracking timing statistics of tools and LLM API calls.
|
||||
|
||||
This module provides a centralized way to track timing information for various
|
||||
operations in the agent system, including:
|
||||
- Individual tool executions
|
||||
- OpenAI API calls
|
||||
- Aggregate statistics (min, max, median, mean, total)
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
import statistics
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfilingStats:
|
||||
"""Statistics for a particular operation type."""
|
||||
call_count: int = 0
|
||||
total_time: float = 0.0
|
||||
min_time: float = float('inf')
|
||||
max_time: float = 0.0
|
||||
times: List[float] = field(default_factory=list)
|
||||
|
||||
def add_timing(self, duration: float):
|
||||
"""Add a timing measurement."""
|
||||
self.call_count += 1
|
||||
self.total_time += duration
|
||||
self.min_time = min(self.min_time, duration)
|
||||
self.max_time = max(self.max_time, duration)
|
||||
self.times.append(duration)
|
||||
|
||||
@property
|
||||
def mean_time(self) -> float:
|
||||
"""Calculate mean time."""
|
||||
return self.total_time / self.call_count if self.call_count > 0 else 0.0
|
||||
|
||||
@property
|
||||
def median_time(self) -> float:
|
||||
"""Calculate median time."""
|
||||
return statistics.median(self.times) if self.times else 0.0
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dictionary for serialization."""
|
||||
return {
|
||||
"call_count": self.call_count,
|
||||
"total_time": self.total_time,
|
||||
"min_time": self.min_time if self.min_time != float('inf') else 0.0,
|
||||
"max_time": self.max_time,
|
||||
"mean_time": self.mean_time,
|
||||
"median_time": self.median_time
|
||||
}
|
||||
|
||||
|
||||
class Profiler:
|
||||
"""
|
||||
Global profiler for tracking timing statistics across tools and API calls.
|
||||
|
||||
Usage:
|
||||
profiler = Profiler()
|
||||
|
||||
# Time a tool execution
|
||||
with profiler.time_tool("web_search"):
|
||||
# ... tool execution code ...
|
||||
pass
|
||||
|
||||
# Time an API call
|
||||
with profiler.time_api_call():
|
||||
# ... API call code ...
|
||||
pass
|
||||
|
||||
# Get statistics
|
||||
stats = profiler.get_statistics()
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the profiler."""
|
||||
self.tool_stats: Dict[str, ProfilingStats] = defaultdict(ProfilingStats)
|
||||
self.api_stats: ProfilingStats = ProfilingStats()
|
||||
self._enabled = True
|
||||
|
||||
def enable(self):
|
||||
"""Enable profiling."""
|
||||
self._enabled = True
|
||||
|
||||
def disable(self):
|
||||
"""Disable profiling."""
|
||||
self._enabled = False
|
||||
|
||||
def reset(self):
|
||||
"""Reset all profiling data."""
|
||||
self.tool_stats.clear()
|
||||
self.api_stats = ProfilingStats()
|
||||
|
||||
def record_tool_timing(self, tool_name: str, duration: float):
|
||||
"""Record timing for a tool execution."""
|
||||
if self._enabled:
|
||||
self.tool_stats[tool_name].add_timing(duration)
|
||||
|
||||
def record_api_timing(self, duration: float):
|
||||
"""Record timing for an API call."""
|
||||
if self._enabled:
|
||||
self.api_stats.add_timing(duration)
|
||||
|
||||
def get_statistics(self) -> Dict:
|
||||
"""
|
||||
Get all profiling statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary containing tool and API statistics
|
||||
"""
|
||||
return {
|
||||
"tools": {
|
||||
tool_name: stats.to_dict()
|
||||
for tool_name, stats in sorted(self.tool_stats.items())
|
||||
},
|
||||
"api_calls": self.api_stats.to_dict()
|
||||
}
|
||||
|
||||
def print_statistics(self, detailed: bool = True):
|
||||
"""
|
||||
Print profiling statistics in a readable format.
|
||||
|
||||
Args:
|
||||
detailed: If True, show per-tool breakdown. If False, show summary only.
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("📊 PROFILING STATISTICS")
|
||||
print("="*80)
|
||||
|
||||
# API Call Statistics
|
||||
print("\n🔷 OpenAI API Calls:")
|
||||
if self.api_stats.call_count > 0:
|
||||
api_dict = self.api_stats.to_dict()
|
||||
print(f" Total Calls: {api_dict['call_count']}")
|
||||
print(f" Total Time: {api_dict['total_time']:.2f}s")
|
||||
print(f" Min Time: {api_dict['min_time']:.2f}s")
|
||||
print(f" Max Time: {api_dict['max_time']:.2f}s")
|
||||
print(f" Mean Time: {api_dict['mean_time']:.2f}s")
|
||||
print(f" Median Time: {api_dict['median_time']:.2f}s")
|
||||
else:
|
||||
print(" No API calls recorded")
|
||||
|
||||
# Tool Statistics
|
||||
print("\n🔧 Tool Executions:")
|
||||
if self.tool_stats:
|
||||
if detailed:
|
||||
for tool_name in sorted(self.tool_stats.keys()):
|
||||
stats_dict = self.tool_stats[tool_name].to_dict()
|
||||
print(f"\n 📌 {tool_name}:")
|
||||
print(f" Total Calls: {stats_dict['call_count']}")
|
||||
print(f" Total Time: {stats_dict['total_time']:.2f}s")
|
||||
print(f" Min Time: {stats_dict['min_time']:.2f}s")
|
||||
print(f" Max Time: {stats_dict['max_time']:.2f}s")
|
||||
print(f" Mean Time: {stats_dict['mean_time']:.2f}s")
|
||||
print(f" Median Time: {stats_dict['median_time']:.2f}s")
|
||||
|
||||
# Summary
|
||||
total_tool_calls = sum(s.call_count for s in self.tool_stats.values())
|
||||
total_tool_time = sum(s.total_time for s in self.tool_stats.values())
|
||||
print(f"\n 📊 Summary:")
|
||||
print(f" Total Tool Calls: {total_tool_calls}")
|
||||
print(f" Total Tool Time: {total_tool_time:.2f}s")
|
||||
print(f" Unique Tools Used: {len(self.tool_stats)}")
|
||||
else:
|
||||
print(" No tool executions recorded")
|
||||
|
||||
# Overall Summary
|
||||
total_api_time = self.api_stats.total_time
|
||||
total_tool_time = sum(s.total_time for s in self.tool_stats.values())
|
||||
print(f"\n📈 Overall Summary:")
|
||||
print(f" Total API Time: {total_api_time:.2f}s")
|
||||
print(f" Total Tool Time: {total_tool_time:.2f}s")
|
||||
print(f" Total Time: {total_api_time + total_tool_time:.2f}s")
|
||||
print("="*80 + "\n")
|
||||
|
||||
def export_to_json(self) -> str:
|
||||
"""Export statistics as JSON string."""
|
||||
import json
|
||||
return json.dumps(self.get_statistics(), indent=2)
|
||||
|
||||
def export_to_file(self, filepath: str):
|
||||
"""
|
||||
Export statistics to a JSON file.
|
||||
|
||||
Args:
|
||||
filepath: Path to output file
|
||||
"""
|
||||
import json
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(self.get_statistics(), f, indent=2)
|
||||
print(f"📁 Profiling statistics exported to: {filepath}")
|
||||
|
||||
|
||||
# Global profiler instance
|
||||
_global_profiler: Optional[Profiler] = None
|
||||
|
||||
|
||||
def get_profiler() -> Profiler:
|
||||
"""Get or create the global profiler instance."""
|
||||
global _global_profiler
|
||||
if _global_profiler is None:
|
||||
_global_profiler = Profiler()
|
||||
return _global_profiler
|
||||
|
||||
|
||||
def reset_profiler():
|
||||
"""Reset the global profiler."""
|
||||
global _global_profiler
|
||||
if _global_profiler is not None:
|
||||
_global_profiler.reset()
|
||||
|
||||
|
||||
class TimingContext:
|
||||
"""Context manager for timing operations."""
|
||||
|
||||
def __init__(self, profiler: Profiler, operation_type: str, operation_name: Optional[str] = None):
|
||||
"""
|
||||
Initialize timing context.
|
||||
|
||||
Args:
|
||||
profiler: Profiler instance to record timing
|
||||
operation_type: 'tool' or 'api'
|
||||
operation_name: Name of the operation (required for tools)
|
||||
"""
|
||||
self.profiler = profiler
|
||||
self.operation_type = operation_type
|
||||
self.operation_name = operation_name
|
||||
self.start_time = None
|
||||
|
||||
def __enter__(self):
|
||||
"""Start timing."""
|
||||
self.start_time = time.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Stop timing and record."""
|
||||
duration = time.time() - self.start_time
|
||||
|
||||
if self.operation_type == 'tool':
|
||||
self.profiler.record_tool_timing(self.operation_name, duration)
|
||||
elif self.operation_type == 'api':
|
||||
self.profiler.record_api_timing(duration)
|
||||
|
||||
return False # Don't suppress exceptions
|
||||
|
||||
|
||||
def aggregate_profiling_stats(stats_list: List[Dict]) -> Dict:
|
||||
"""
|
||||
Aggregate multiple profiling statistics dictionaries into one.
|
||||
|
||||
This is useful for batch processing where each worker process has its own
|
||||
profiler instance that needs to be combined.
|
||||
|
||||
Args:
|
||||
stats_list: List of statistics dictionaries from get_statistics()
|
||||
|
||||
Returns:
|
||||
Dict: Aggregated statistics with combined tool and API call data
|
||||
"""
|
||||
aggregated = {
|
||||
"tools": defaultdict(lambda: {"times": []}),
|
||||
"api_calls": {"times": []}
|
||||
}
|
||||
|
||||
# Aggregate tool statistics
|
||||
for stats in stats_list:
|
||||
# Aggregate tool timings
|
||||
for tool_name, tool_stats in stats.get("tools", {}).items():
|
||||
# Reconstruct individual timings from aggregated stats
|
||||
# Since we have mean_time and call_count, we approximate
|
||||
aggregated["tools"][tool_name]["times"].extend(
|
||||
[tool_stats.get("mean_time", 0.0)] * tool_stats.get("call_count", 0)
|
||||
)
|
||||
|
||||
# Aggregate API call timings
|
||||
api_stats = stats.get("api_calls", {})
|
||||
if api_stats.get("call_count", 0) > 0:
|
||||
aggregated["api_calls"]["times"].extend(
|
||||
[api_stats.get("mean_time", 0.0)] * api_stats.get("call_count", 0)
|
||||
)
|
||||
|
||||
# Calculate final statistics for tools
|
||||
final_stats = {"tools": {}, "api_calls": {}}
|
||||
|
||||
for tool_name, data in aggregated["tools"].items():
|
||||
times = data["times"]
|
||||
if times:
|
||||
final_stats["tools"][tool_name] = {
|
||||
"call_count": len(times),
|
||||
"total_time": sum(times),
|
||||
"min_time": min(times),
|
||||
"max_time": max(times),
|
||||
"mean_time": statistics.mean(times),
|
||||
"median_time": statistics.median(times)
|
||||
}
|
||||
|
||||
# Calculate final statistics for API calls
|
||||
api_times = aggregated["api_calls"]["times"]
|
||||
if api_times:
|
||||
final_stats["api_calls"] = {
|
||||
"call_count": len(api_times),
|
||||
"total_time": sum(api_times),
|
||||
"min_time": min(api_times),
|
||||
"max_time": max(api_times),
|
||||
"mean_time": statistics.mean(api_times),
|
||||
"median_time": statistics.median(api_times)
|
||||
}
|
||||
else:
|
||||
final_stats["api_calls"] = {
|
||||
"call_count": 0,
|
||||
"total_time": 0.0,
|
||||
"min_time": 0.0,
|
||||
"max_time": 0.0,
|
||||
"mean_time": 0.0,
|
||||
"median_time": 0.0
|
||||
}
|
||||
|
||||
return final_stats
|
||||
|
||||
|
||||
def print_aggregated_statistics(stats: Dict, detailed: bool = True):
|
||||
"""
|
||||
Print aggregated profiling statistics in a readable format.
|
||||
|
||||
Args:
|
||||
stats: Aggregated statistics dictionary from aggregate_profiling_stats()
|
||||
detailed: If True, show per-tool breakdown. If False, show summary only.
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("📊 AGGREGATED PROFILING STATISTICS")
|
||||
print("="*80)
|
||||
|
||||
# API Call Statistics
|
||||
print("\n🔷 OpenAI API Calls:")
|
||||
api_stats = stats.get("api_calls", {})
|
||||
if api_stats.get("call_count", 0) > 0:
|
||||
print(f" Total Calls: {api_stats['call_count']}")
|
||||
print(f" Total Time: {api_stats['total_time']:.2f}s")
|
||||
print(f" Min Time: {api_stats['min_time']:.2f}s")
|
||||
print(f" Max Time: {api_stats['max_time']:.2f}s")
|
||||
print(f" Mean Time: {api_stats['mean_time']:.2f}s")
|
||||
print(f" Median Time: {api_stats['median_time']:.2f}s")
|
||||
else:
|
||||
print(" No API calls recorded")
|
||||
|
||||
# Tool Statistics
|
||||
print("\n🔧 Tool Executions:")
|
||||
tool_stats = stats.get("tools", {})
|
||||
if tool_stats:
|
||||
if detailed:
|
||||
for tool_name in sorted(tool_stats.keys()):
|
||||
stats_dict = tool_stats[tool_name]
|
||||
print(f"\n 📌 {tool_name}:")
|
||||
print(f" Total Calls: {stats_dict['call_count']}")
|
||||
print(f" Total Time: {stats_dict['total_time']:.2f}s")
|
||||
print(f" Min Time: {stats_dict['min_time']:.2f}s")
|
||||
print(f" Max Time: {stats_dict['max_time']:.2f}s")
|
||||
print(f" Mean Time: {stats_dict['mean_time']:.2f}s")
|
||||
print(f" Median Time: {stats_dict['median_time']:.2f}s")
|
||||
|
||||
# Summary
|
||||
total_tool_calls = sum(s["call_count"] for s in tool_stats.values())
|
||||
total_tool_time = sum(s["total_time"] for s in tool_stats.values())
|
||||
print(f"\n 📊 Summary:")
|
||||
print(f" Total Tool Calls: {total_tool_calls}")
|
||||
print(f" Total Tool Time: {total_tool_time:.2f}s")
|
||||
print(f" Unique Tools Used: {len(tool_stats)}")
|
||||
else:
|
||||
print(" No tool executions recorded")
|
||||
|
||||
# Overall Summary
|
||||
total_api_time = api_stats.get("total_time", 0.0)
|
||||
total_tool_time = sum(s["total_time"] for s in tool_stats.values())
|
||||
print(f"\n📈 Overall Summary:")
|
||||
print(f" Total API Time: {total_api_time:.2f}s")
|
||||
print(f" Total Tool Time: {total_tool_time:.2f}s")
|
||||
print(f" Total Time: {total_api_time + total_tool_time:.2f}s")
|
||||
print("="*80 + "\n")
|
||||
120
run_agent.py
120
run_agent.py
@@ -45,6 +45,9 @@ else:
|
||||
from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
|
||||
# Import profiling
|
||||
from profiling import get_profiler
|
||||
|
||||
|
||||
class AIAgent:
|
||||
"""
|
||||
@@ -364,6 +367,10 @@ class AIAgent:
|
||||
Returns:
|
||||
Dict: Complete conversation result with final response and message history
|
||||
"""
|
||||
# Reset profiler for this conversation to get fresh stats
|
||||
from profiling import reset_profiler as reset_prof
|
||||
reset_prof()
|
||||
|
||||
# Generate unique task_id if not provided to isolate VMs between concurrent tasks
|
||||
import uuid
|
||||
effective_task_id = task_id or str(uuid.uuid4())
|
||||
@@ -388,17 +395,19 @@ class AIAgent:
|
||||
|
||||
while api_call_count < self.max_iterations:
|
||||
api_call_count += 1
|
||||
print(f"\n🔄 Making API call #{api_call_count}...")
|
||||
print(f"\n🔄 Making OpenAI-compatible API call #{api_call_count}...")
|
||||
|
||||
# Log request details if verbose
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"API Request - Model: {self.model}, Messages: {len(messages)}, Tools: {len(self.tools) if self.tools else 0}")
|
||||
logging.debug(f"Last message role: {messages[-1]['role'] if messages else 'none'}")
|
||||
# Log the last few messages to see if thought_signature is present
|
||||
logging.debug(f"Last message content: {json.dumps(messages[-1] if messages else {}, indent=2)}")
|
||||
|
||||
api_start_time = time.time()
|
||||
retry_count = 0
|
||||
max_retries = 3
|
||||
|
||||
max_retries = 6 # Increased to allow longer backoff periods
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
# Prepare messages for API call
|
||||
@@ -407,30 +416,33 @@ class AIAgent:
|
||||
if active_system_prompt:
|
||||
# Insert system message at the beginning
|
||||
api_messages = [{"role": "system", "content": active_system_prompt}] + api_messages
|
||||
|
||||
|
||||
# Make API call with tools
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=api_messages,
|
||||
tools=self.tools if self.tools else None,
|
||||
timeout=60.0 # Add explicit timeout
|
||||
timeout=300.0 # 5 minute timeout for long-running agent tasks
|
||||
)
|
||||
|
||||
|
||||
api_duration = time.time() - api_start_time
|
||||
print(f"⏱️ API call completed in {api_duration:.2f}s")
|
||||
|
||||
print(f"⏱️ OpenAI-compatible API call completed in {api_duration:.2f}s")
|
||||
|
||||
# Record API timing in profiler
|
||||
get_profiler().record_api_timing(api_duration)
|
||||
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"API Response received - Usage: {response.usage if hasattr(response, 'usage') else 'N/A'}")
|
||||
|
||||
|
||||
break # Success, exit retry loop
|
||||
|
||||
|
||||
except Exception as api_error:
|
||||
retry_count += 1
|
||||
if retry_count > max_retries:
|
||||
raise api_error
|
||||
|
||||
wait_time = min(2 ** retry_count, 10) # Exponential backoff, max 10s
|
||||
print(f"⚠️ API call failed (attempt {retry_count}/{max_retries}): {str(api_error)[:100]}")
|
||||
|
||||
wait_time = min(2 ** retry_count, 60) # Exponential backoff: 2s, 4s, 8s, 16s, 32s, 60s, 60s
|
||||
print(f"⚠️ OpenAI-compatible API call failed (attempt {retry_count}/{max_retries}): {str(api_error)[:100]}")
|
||||
print(f"⏳ Retrying in {wait_time}s...")
|
||||
logging.warning(f"API retry {retry_count}/{max_retries} after error: {api_error}")
|
||||
time.sleep(wait_time)
|
||||
@@ -449,22 +461,58 @@ class AIAgent:
|
||||
if self.verbose_logging:
|
||||
for tc in assistant_message.tool_calls:
|
||||
logging.debug(f"Tool call: {tc.function.name} with args: {tc.function.arguments[:200]}...")
|
||||
# Debug: Check what attributes are available on tool_call
|
||||
logging.debug(f"Tool call attributes: {dir(tc)}")
|
||||
# Try to dump the model to see all fields
|
||||
if hasattr(tc, 'model_dump'):
|
||||
logging.debug(f"Tool call data: {tc.model_dump()}")
|
||||
|
||||
# Add assistant message with tool calls to conversation
|
||||
# Extract thought_signature if present (required for Gemini models)
|
||||
tool_calls_data = []
|
||||
for tool_call in assistant_message.tool_calls:
|
||||
tool_call_dict = {
|
||||
"id": tool_call.id,
|
||||
"type": tool_call.type,
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments
|
||||
}
|
||||
}
|
||||
# Try multiple ways to access thought_signature (Gemini-specific)
|
||||
# Gemini uses extra_content.google.thought_signature structure
|
||||
thought_sig = None
|
||||
|
||||
# Method 1: Check extra_content attribute
|
||||
if hasattr(tool_call, 'extra_content'):
|
||||
extra = tool_call.extra_content
|
||||
if isinstance(extra, dict) and 'google' in extra:
|
||||
thought_sig = extra['google'].get('thought_signature')
|
||||
|
||||
# Method 2: Check model_dump() if available (Pydantic v2)
|
||||
if thought_sig is None and hasattr(tool_call, 'model_dump'):
|
||||
dumped = tool_call.model_dump()
|
||||
if 'extra_content' in dumped and isinstance(dumped['extra_content'], dict):
|
||||
google_data = dumped['extra_content'].get('google', {})
|
||||
thought_sig = google_data.get('thought_signature')
|
||||
|
||||
if thought_sig is not None:
|
||||
tool_call_dict["extra_content"] = {
|
||||
"google": {
|
||||
"thought_signature": thought_sig
|
||||
}
|
||||
}
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"Captured thought_signature for tool call {tool_call.id}")
|
||||
elif self.verbose_logging:
|
||||
logging.debug(f"No thought_signature found for tool call {tool_call.id}")
|
||||
|
||||
tool_calls_data.append(tool_call_dict)
|
||||
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": assistant_message.content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": tool_call.type,
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments
|
||||
}
|
||||
}
|
||||
for tool_call in assistant_message.tool_calls
|
||||
]
|
||||
"tool_calls": tool_calls_data
|
||||
})
|
||||
|
||||
# Execute each tool call
|
||||
@@ -490,11 +538,15 @@ class AIAgent:
|
||||
tool_duration = time.time() - tool_start_time
|
||||
result_preview = function_result[:200] if len(function_result) > 200 else function_result
|
||||
|
||||
# Record tool timing in profiler
|
||||
get_profiler().record_tool_timing(function_name, tool_duration)
|
||||
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
|
||||
logging.debug(f"Tool result preview: {result_preview}...")
|
||||
|
||||
# Add tool result to conversation
|
||||
# Note: thought_signature should NOT be in tool responses, only in assistant messages
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"content": function_result,
|
||||
@@ -522,11 +574,11 @@ class AIAgent:
|
||||
"content": final_response
|
||||
})
|
||||
|
||||
print(f"🎉 Conversation completed after {api_call_count} API call(s)")
|
||||
print(f"🎉 Conversation completed after {api_call_count} OpenAI-compatible API call(s)")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error during API call #{api_call_count}: {str(e)}"
|
||||
error_msg = f"Error during OpenAI-compatible API call #{api_call_count}: {str(e)}"
|
||||
print(f"❌ {error_msg}")
|
||||
|
||||
if self.verbose_logging:
|
||||
@@ -562,11 +614,15 @@ class AIAgent:
|
||||
if self.verbose_logging:
|
||||
logging.warning(f"Failed to cleanup VM for task {effective_task_id}: {e}")
|
||||
|
||||
# Get profiling statistics for this conversation
|
||||
profiling_stats = get_profiler().get_statistics()
|
||||
|
||||
return {
|
||||
"final_response": final_response,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": completed
|
||||
"completed": completed,
|
||||
"profiling_stats": profiling_stats
|
||||
}
|
||||
|
||||
def chat(self, message: str) -> str:
|
||||
@@ -594,7 +650,8 @@ def main(
|
||||
list_tools: bool = False,
|
||||
save_trajectories: bool = False,
|
||||
verbose: bool = False,
|
||||
log_prefix_chars: int = 20
|
||||
log_prefix_chars: int = 20,
|
||||
show_profiling: bool = True
|
||||
):
|
||||
"""
|
||||
Main function for running the agent directly.
|
||||
@@ -613,6 +670,7 @@ def main(
|
||||
save_trajectories (bool): Save conversation trajectories to JSONL files. Defaults to False.
|
||||
verbose (bool): Enable verbose logging for debugging. Defaults to False.
|
||||
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses. Defaults to 20.
|
||||
show_profiling (bool): Display profiling statistics after conversation. Defaults to True.
|
||||
|
||||
Toolset Examples:
|
||||
- "research": Web search, extract, crawl + vision tools
|
||||
@@ -763,7 +821,11 @@ def main(
|
||||
print(f"\n🎯 FINAL RESPONSE:")
|
||||
print("-" * 30)
|
||||
print(result['final_response'])
|
||||
|
||||
|
||||
# Display profiling statistics if enabled
|
||||
if show_profiling:
|
||||
get_profiler().print_statistics(detailed=True)
|
||||
|
||||
print("\n👋 Agent execution completed!")
|
||||
|
||||
|
||||
|
||||
@@ -161,11 +161,11 @@ def _construct_aggregator_prompt(system_prompt: str, responses: List[str]) -> st
|
||||
|
||||
|
||||
async def _run_reference_model_safe(
|
||||
model: str,
|
||||
user_prompt: str,
|
||||
model: str,
|
||||
user_prompt: str,
|
||||
temperature: float = REFERENCE_TEMPERATURE,
|
||||
max_tokens: int = 32000,
|
||||
max_retries: int = 3
|
||||
max_retries: int = 6
|
||||
) -> tuple[str, str, bool]:
|
||||
"""
|
||||
Run a single reference model with retry logic and graceful failure handling.
|
||||
@@ -212,8 +212,8 @@ async def _run_reference_model_safe(
|
||||
print(f"⚠️ {model} unknown error (attempt {attempt + 1}): {error_str}")
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
# Exponential backoff for rate limiting
|
||||
sleep_time = 2 ** attempt
|
||||
# Exponential backoff for rate limiting: 2s, 4s, 8s, 16s, 32s, 60s
|
||||
sleep_time = min(2 ** (attempt + 1), 60)
|
||||
print(f" Retrying in {sleep_time}s...")
|
||||
await asyncio.sleep(sleep_time)
|
||||
else:
|
||||
|
||||
395
tools/simple_terminal_tool.py
Normal file
395
tools/simple_terminal_tool.py
Normal file
@@ -0,0 +1,395 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple Terminal Tool Module
|
||||
|
||||
A simplified terminal tool that executes commands on MorphCloud VMs without tmux.
|
||||
No session persistence, no interactive app support - just simple command execution.
|
||||
|
||||
Features:
|
||||
- Direct SSH command execution
|
||||
- Background task support
|
||||
- VM lifecycle management with TTL
|
||||
- Automatic cleanup after inactivity
|
||||
|
||||
Usage:
|
||||
from simple_terminal_tool import simple_terminal_tool
|
||||
|
||||
# Execute a simple command
|
||||
result = simple_terminal_tool("ls -la")
|
||||
|
||||
# Execute in background
|
||||
result = simple_terminal_tool("python server.py", background=True)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import threading
|
||||
import atexit
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
# Tool description for LLM
|
||||
SIMPLE_TERMINAL_TOOL_DESCRIPTION = """Execute commands on a secure Linux VM environment.
|
||||
|
||||
**Environment:**
|
||||
- Minimal Debian-based OS with internet access
|
||||
- Automatic VM lifecycle management (creates on-demand, reuses, cleans up)
|
||||
- Filesystem is persisted between tool calls but environment variables, venvs, etc are reset.
|
||||
|
||||
**Command Execution:**
|
||||
- Simple commands: Just provide the 'command' parameter
|
||||
- Background processes: Set 'background': True for servers/long-running tasks
|
||||
- Command timeout: Optional 'timeout' parameter in seconds
|
||||
|
||||
**Examples:**
|
||||
- Run command: `{"command": "ls -la"}`
|
||||
- Background task: `{"command": "source path/to/my/venv/bin/activate && python server.py", "background": True}`
|
||||
- With timeout: `{"command": "long_task.sh", "timeout": 300}`
|
||||
|
||||
**Best Practices:**
|
||||
- Run servers/long processes in background
|
||||
- Monitor disk usage for large tasks
|
||||
- Install whatever tools you need with sudo apt-get
|
||||
- Do not be afraid to run pip with --break-system-packages
|
||||
|
||||
**Things to avoid**
|
||||
- Do NOT use interactive tools such as tmux, vim, nano, python repl - you will get stuck. Even git sometimes becomes interactive if the output is large. If you're not sure pipe to cat.
|
||||
"""
|
||||
|
||||
# Global state for VM lifecycle management
|
||||
_active_instances: Dict[str, Any] = {}
|
||||
_last_activity: Dict[str, float] = {}
|
||||
_instance_lock = threading.Lock()
|
||||
_cleanup_thread = None
|
||||
_cleanup_running = False
|
||||
|
||||
|
||||
def _cleanup_inactive_vms(vm_lifetime_seconds: int = 300):
|
||||
"""Clean up VMs that have been inactive for longer than vm_lifetime_seconds."""
|
||||
global _active_instances, _last_activity
|
||||
|
||||
current_time = time.time()
|
||||
tasks_to_cleanup = []
|
||||
|
||||
with _instance_lock:
|
||||
for task_id, last_time in list(_last_activity.items()):
|
||||
if current_time - last_time > vm_lifetime_seconds:
|
||||
tasks_to_cleanup.append(task_id)
|
||||
|
||||
for task_id in tasks_to_cleanup:
|
||||
try:
|
||||
if task_id in _active_instances:
|
||||
instance = _active_instances[task_id]
|
||||
if hasattr(instance, 'terminate'):
|
||||
instance.terminate()
|
||||
elif hasattr(instance, 'stop'):
|
||||
instance.stop()
|
||||
elif hasattr(instance, 'delete'):
|
||||
instance.delete()
|
||||
|
||||
del _active_instances[task_id]
|
||||
print(f"[VM Cleanup] Terminated inactive VM for task: {task_id}")
|
||||
|
||||
if task_id in _last_activity:
|
||||
del _last_activity[task_id]
|
||||
|
||||
except Exception as e:
|
||||
# 404 errors are benign - VM already cleaned up by TTL
|
||||
error_str = str(e)
|
||||
if "404" in error_str or "InstanceNotFoundError" in error_str or "not found" in error_str.lower():
|
||||
print(f"[VM Cleanup] VM for task {task_id} already cleaned up (likely TTL expiration)")
|
||||
else:
|
||||
print(f"[VM Cleanup] Error cleaning up VM for task {task_id}: {e}")
|
||||
|
||||
|
||||
def _cleanup_thread_worker():
|
||||
"""Background thread worker that periodically cleans up inactive VMs."""
|
||||
global _cleanup_running
|
||||
|
||||
while _cleanup_running:
|
||||
try:
|
||||
vm_lifetime = int(os.getenv("HECATE_VM_LIFETIME_SECONDS", "300"))
|
||||
_cleanup_inactive_vms(vm_lifetime)
|
||||
except Exception as e:
|
||||
print(f"[VM Cleanup] Error in cleanup thread: {e}")
|
||||
|
||||
for _ in range(60):
|
||||
if not _cleanup_running:
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def _start_cleanup_thread():
|
||||
"""Start the background cleanup thread if not already running."""
|
||||
global _cleanup_thread, _cleanup_running
|
||||
|
||||
with _instance_lock:
|
||||
if _cleanup_thread is None or not _cleanup_thread.is_alive():
|
||||
_cleanup_running = True
|
||||
_cleanup_thread = threading.Thread(target=_cleanup_thread_worker, daemon=True)
|
||||
_cleanup_thread.start()
|
||||
|
||||
|
||||
def _stop_cleanup_thread():
|
||||
"""Stop the background cleanup thread."""
|
||||
global _cleanup_running
|
||||
_cleanup_running = False
|
||||
if _cleanup_thread is not None:
|
||||
_cleanup_thread.join(timeout=5)
|
||||
|
||||
|
||||
def cleanup_vm(task_id: str):
|
||||
"""Manually clean up a specific VM by task_id."""
|
||||
global _active_instances, _last_activity
|
||||
|
||||
with _instance_lock:
|
||||
try:
|
||||
if task_id in _active_instances:
|
||||
instance = _active_instances[task_id]
|
||||
if hasattr(instance, 'terminate'):
|
||||
instance.terminate()
|
||||
elif hasattr(instance, 'stop'):
|
||||
instance.stop()
|
||||
elif hasattr(instance, 'delete'):
|
||||
instance.delete()
|
||||
|
||||
del _active_instances[task_id]
|
||||
print(f"[VM Cleanup] Manually terminated VM for task: {task_id}")
|
||||
|
||||
if task_id in _last_activity:
|
||||
del _last_activity[task_id]
|
||||
|
||||
except Exception as e:
|
||||
# 404 errors are benign - VM already cleaned up by TTL
|
||||
error_str = str(e)
|
||||
if "404" in error_str or "InstanceNotFoundError" in error_str or "not found" in error_str.lower():
|
||||
print(f"[VM Cleanup] VM for task {task_id} already cleaned up (likely TTL expiration)")
|
||||
else:
|
||||
print(f"[VM Cleanup] Error manually cleaning up VM for task {task_id}: {e}")
|
||||
|
||||
|
||||
atexit.register(_stop_cleanup_thread)
|
||||
|
||||
|
||||
def _execute_ssh_command(instance, command: str, timeout: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute a command via SSH on the VM instance.
|
||||
|
||||
Args:
|
||||
instance: MorphVM instance
|
||||
command: Command to execute
|
||||
timeout: Optional timeout in seconds
|
||||
|
||||
Returns:
|
||||
dict with stdout, stderr, returncode
|
||||
"""
|
||||
ssh_context_manager = None
|
||||
try:
|
||||
# Use the instance's SSH context manager
|
||||
ssh_context_manager = instance.ssh()
|
||||
ssh_context = ssh_context_manager.__enter__()
|
||||
|
||||
# Execute the command
|
||||
result = ssh_context.run(command, get_pty=False, timeout=timeout or 120)
|
||||
|
||||
# Close the SSH connection
|
||||
if ssh_context_manager:
|
||||
try:
|
||||
ssh_context_manager.__exit__(None, None, None)
|
||||
except:
|
||||
pass
|
||||
|
||||
return {
|
||||
"stdout": result.stdout or "",
|
||||
"stderr": result.stderr or "",
|
||||
"returncode": result.returncode
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# Close connection on error
|
||||
if ssh_context_manager:
|
||||
try:
|
||||
ssh_context_manager.__exit__(None, None, None)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Check if it's a timeout
|
||||
error_str = str(e).lower()
|
||||
if "timeout" in error_str:
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": f"Command timed out after {timeout or 120} seconds",
|
||||
"returncode": 124
|
||||
}
|
||||
|
||||
return {
|
||||
"stdout": "",
|
||||
"stderr": f"SSH execution failed: {str(e)}",
|
||||
"returncode": -1
|
||||
}
|
||||
|
||||
|
||||
def simple_terminal_tool(
|
||||
command: str,
|
||||
background: bool = False,
|
||||
timeout: Optional[int] = None,
|
||||
task_id: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Execute a command on a MorphCloud VM without session persistence.
|
||||
|
||||
Args:
|
||||
command: The command to execute
|
||||
background: Whether to run in background (default: False)
|
||||
timeout: Command timeout in seconds (default: 120)
|
||||
task_id: Unique identifier for VM isolation (optional)
|
||||
|
||||
Returns:
|
||||
str: JSON string with output, exit_code, and error fields
|
||||
|
||||
Examples:
|
||||
# Execute a simple command
|
||||
>>> result = simple_terminal_tool(command="ls -la /tmp")
|
||||
|
||||
# Run a background task
|
||||
>>> result = simple_terminal_tool(command="python server.py", background=True)
|
||||
|
||||
# With custom timeout
|
||||
>>> result = simple_terminal_tool(command="long_task.sh", timeout=300)
|
||||
"""
|
||||
global _active_instances, _last_activity
|
||||
|
||||
try:
|
||||
# Import required modules
|
||||
try:
|
||||
from morphcloud.api import MorphCloudClient
|
||||
except ImportError as import_error:
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": f"Terminal tool disabled: {import_error}",
|
||||
"status": "disabled"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# Get configuration
|
||||
vm_ttl_seconds = int(os.getenv("HECATE_VM_TTL_SECONDS", "1200"))
|
||||
snapshot_id = os.getenv("HECATE_DEFAULT_SNAPSHOT_ID", "snapshot_defv9tjg")
|
||||
|
||||
# Check API key
|
||||
morph_api_key = os.getenv("MORPH_API_KEY")
|
||||
if not morph_api_key:
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": "MORPH_API_KEY environment variable not set",
|
||||
"status": "disabled"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# Use task_id for VM isolation
|
||||
effective_task_id = task_id or "default"
|
||||
|
||||
# Start cleanup thread
|
||||
_start_cleanup_thread()
|
||||
|
||||
# Get or create VM instance
|
||||
with _instance_lock:
|
||||
if effective_task_id not in _active_instances:
|
||||
morph_client = MorphCloudClient(api_key=morph_api_key)
|
||||
_active_instances[effective_task_id] = morph_client.instances.start(
|
||||
snapshot_id=snapshot_id,
|
||||
ttl_seconds=vm_ttl_seconds,
|
||||
ttl_action="stop"
|
||||
)
|
||||
|
||||
# Update last activity time
|
||||
_last_activity[effective_task_id] = time.time()
|
||||
instance = _active_instances[effective_task_id]
|
||||
|
||||
# Wait for instance to be ready
|
||||
instance.wait_until_ready()
|
||||
|
||||
# Prepare command for execution
|
||||
if background:
|
||||
# Run in background with nohup and redirect output
|
||||
exec_command = f"nohup {command} > /tmp/bg_output.log 2>&1 &"
|
||||
result = _execute_ssh_command(instance, exec_command, timeout=10)
|
||||
|
||||
# For background tasks, return immediately with info
|
||||
if result["returncode"] == 0:
|
||||
return json.dumps({
|
||||
"output": "Background task started successfully",
|
||||
"exit_code": 0,
|
||||
"error": None
|
||||
}, ensure_ascii=False)
|
||||
else:
|
||||
return json.dumps({
|
||||
"output": result["stdout"],
|
||||
"exit_code": result["returncode"],
|
||||
"error": result["stderr"]
|
||||
}, ensure_ascii=False)
|
||||
else:
|
||||
# Run foreground command
|
||||
result = _execute_ssh_command(instance, command, timeout=timeout)
|
||||
|
||||
# Combine stdout and stderr for output
|
||||
output = result["stdout"]
|
||||
if result["stderr"] and result["returncode"] != 0:
|
||||
output = f"{output}\n{result['stderr']}" if output else result["stderr"]
|
||||
|
||||
return json.dumps({
|
||||
"output": output.strip(),
|
||||
"exit_code": result["returncode"],
|
||||
"error": result["stderr"] if result["returncode"] != 0 else None
|
||||
}, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
return json.dumps({
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": f"Failed to execute command: {str(e)}",
|
||||
"status": "error"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
def check_requirements() -> bool:
|
||||
"""Check if all requirements for the simple terminal tool are met."""
|
||||
required_vars = ["MORPH_API_KEY"]
|
||||
missing_required = [var for var in required_vars if not os.getenv(var)]
|
||||
|
||||
if missing_required:
|
||||
print(f"Missing required environment variables: {', '.join(missing_required)}")
|
||||
return False
|
||||
|
||||
try:
|
||||
from morphcloud.api import MorphCloudClient
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"MorphCloud not available: {e}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""Simple test when run directly."""
|
||||
print("Simple Terminal Tool Module")
|
||||
print("=" * 40)
|
||||
|
||||
if not check_requirements():
|
||||
print("Requirements not met. Please check the messages above.")
|
||||
exit(1)
|
||||
|
||||
print("All requirements met!")
|
||||
print("\nAvailable Tool:")
|
||||
print(" - simple_terminal_tool: Execute commands without session persistence")
|
||||
|
||||
print("\nUsage Examples:")
|
||||
print(" # Execute a command")
|
||||
print(" result = simple_terminal_tool(command='ls -la')")
|
||||
print(" ")
|
||||
print(" # Run a background task")
|
||||
print(" result = simple_terminal_tool(command='python server.py', background=True)")
|
||||
|
||||
print("\nEnvironment Variables:")
|
||||
print(f" MORPH_API_KEY: {'Set' if os.getenv('MORPH_API_KEY') else 'Not set'}")
|
||||
print(f" HECATE_VM_TTL_SECONDS: {os.getenv('HECATE_VM_TTL_SECONDS', '1200')} (default: 1200 / 20 minutes)")
|
||||
print(f" HECATE_VM_LIFETIME_SECONDS: {os.getenv('HECATE_VM_LIFETIME_SECONDS', '300')} (default: 300 / 5 minutes)")
|
||||
print(f" HECATE_DEFAULT_SNAPSHOT_ID: {os.getenv('HECATE_DEFAULT_SNAPSHOT_ID', 'snapshot_defv9tjg')}")
|
||||
@@ -184,10 +184,10 @@ Your goal is to preserve ALL important information while reducing length. Never
|
||||
Create a markdown summary that captures all key information in a well-organized, scannable format. Include important quotes and code snippets in their original formatting. Focus on actionable information, specific details, and unique insights."""
|
||||
|
||||
# Call the LLM asynchronously with retry logic for flaky API
|
||||
max_retries = 3
|
||||
max_retries = 6
|
||||
retry_delay = 2 # Start with 2 seconds
|
||||
last_error = None
|
||||
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
response = await nous_client.chat.completions.create(
|
||||
@@ -206,7 +206,7 @@ Create a markdown summary that captures all key information in a well-organized,
|
||||
print(f"⚠️ LLM API call failed (attempt {attempt + 1}/{max_retries}): {str(api_error)[:100]}")
|
||||
print(f" Retrying in {retry_delay}s...")
|
||||
await asyncio.sleep(retry_delay)
|
||||
retry_delay *= 2 # Exponential backoff: 2s, 4s, 8s
|
||||
retry_delay = min(retry_delay * 2, 60) # Exponential backoff: 2s, 4s, 8s, 16s, 32s, 60s
|
||||
else:
|
||||
# All retries exhausted
|
||||
raise last_error
|
||||
|
||||
@@ -67,7 +67,7 @@ DISTRIBUTIONS = {
|
||||
"description": "Web research with vision analysis and reasoning",
|
||||
"toolsets": {
|
||||
"web": 94, # 90% chance of web tools
|
||||
"vision": 50, # 50% chance of vision tools
|
||||
"vision": 65, # 50% chance of vision tools
|
||||
"moa": 10, # 40% chance of reasoning tools
|
||||
"terminal": 94, # 10% chance of terminal tools
|
||||
"image_gen": 15 # 80% chance of image generation tools
|
||||
|
||||
Reference in New Issue
Block a user