Compare commits
4 Commits
architectu
...
thought-si
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a219e178a1 | ||
|
|
e06a15b3ab | ||
|
|
349e37de0a | ||
|
|
31c733383b |
@@ -1,55 +0,0 @@
|
||||
# Agents
|
||||
|
||||
Agents can be viewed as an FSM using an LLM to generate inputs into the system that operates over a DAG.
|
||||
|
||||
What this really means is that the agent is just a function without memory that uses text inputs and outputs in a
|
||||
defined order.
|
||||
|
||||
```python
|
||||
def my_agent(*args, **kwargs) -> str:
|
||||
# do whatever you want!
|
||||
return "Hi I'm an agent!"
|
||||
```
|
||||
|
||||
Now obviously, that's like saying water's wet, but we're going to be using that definition to inform our design of the
|
||||
library, namely, that we should *not* store agent state outside the function call.
|
||||
|
||||
## The Agent Class
|
||||
|
||||
So we don't have state, why are we using a class?
|
||||
|
||||
Well, we want to initialize things, we want to have some configuration, and we want to have some helper functions.
|
||||
Preferably all in a single place.
|
||||
|
||||
```python
|
||||
class BaseAgent:
|
||||
def agent_primitives(self) -> list[BaseAgent]:
|
||||
# Returns a list of Agents that are utilized by this agent to generate inputs
|
||||
# We use agent primitives here instead of subagents because these are going to be part
|
||||
# of the message graph, not a subagent tool call.
|
||||
raise NotImplementedError
|
||||
|
||||
def tools(self) -> list[BaseTool]:
|
||||
# Returns a list of tools that the agent needs to run
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def run(self, config, *args, **kwargs) -> ConversationGraph:
|
||||
llm = get_llm(config)
|
||||
tools = self.tools()
|
||||
for agent in self.agent_primitives():
|
||||
tools.extend(agent.tools())
|
||||
tools = remove_duplicates(tools)
|
||||
tools = initialize_tools(tools, config)
|
||||
return self(llm, tools, config, *args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def __call__(self, llm, tools, config, *args, **kwargs) -> ConversationGraph:
|
||||
# Returns a ConversationGraph that can be parsed to get the output of the agent
|
||||
# Use w/e args/kwargs you want, as long as llm/tools/config are satisfied.
|
||||
raise NotImplementedError
|
||||
```
|
||||
|
||||
Doesn't seem too bad (I hope), it is a bit annoying that we don't initialize everything in the constructor, but
|
||||
hopefully we all kinda like it :)
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
# LLM Client
|
||||
|
||||
A quick wrapper over openai apis
|
||||
|
||||
## Responsibilities
|
||||
|
||||
- Transform "normal" chat/completions requests into graphs
|
||||
- Translate graphs into LLM requests
|
||||
- Keep a history of graphs parsed by it
|
||||
- On Policy Data
|
||||
- Deduplicating graphs, so we don't keep previous history as separate graphs
|
||||
|
||||
## How to use
|
||||
Exactly the same as the openai api! Just with the additional support of graph inputs and outputs.
|
||||
@@ -1,114 +0,0 @@
|
||||
# Message Graph
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
%% Message nodes
|
||||
SystemMsg["📋 System Message<br/>Role: System<br/>Content: Messages are nodes in a graph"]
|
||||
UserMsg["👤 User Message<br/>Role: User<br/>Content: But messages aren't the only thing in the graph"]
|
||||
subgraph PrevMessages["Previous Messages"]
|
||||
PrevSystemMsg["📋 System Message<br/>Role: System<br/>Content: Edits are kept in the graph as context"]
|
||||
PrevUserMsg["👤 User Message<br/>Role: User<br/>Content: So we can ensure they're immutable while keeping them editable"]
|
||||
end
|
||||
|
||||
%% Chat Response as a subgraph
|
||||
subgraph ChatResponseBox["💬 Chat Response"]
|
||||
ChatMetadata["📊 Metadata<br/>Temp: 1.0<br/>..."]
|
||||
ChatResponseText["📝 Response<br/>Hello, Here's a subagent call: <tool>subagent</tool>"]
|
||||
ChatContent["Content: Hello, Here's a subagent call..."]
|
||||
end
|
||||
|
||||
%% Tool Response as a subgraph
|
||||
subgraph ToolResponseBox["🔧 Tool Response"]
|
||||
subgraph ToolMetadata["📊 Tool Metadata"]
|
||||
ToolMetadataLength["Length: 3"]
|
||||
subgraph ToolChat["💭 Subagent Chat"]
|
||||
SubagentSystem["📋 System<br/>Content: Subagent call received"]
|
||||
SubagentUser["👤 User<br/>Content: Process this request"]
|
||||
SubagentAssistant["🤖 Assistant<br/>Content: Processing..."]
|
||||
SubagentSystem --> SubagentUser
|
||||
SubagentUser --> SubagentAssistant
|
||||
end
|
||||
end
|
||||
ToolContent["Content: Subagent call output"]
|
||||
end
|
||||
|
||||
%% Graph flow connections
|
||||
SystemMsg --> UserMsg
|
||||
PrevSystemMsg --> PrevUserMsg
|
||||
PrevMessages -.-> UserMsg
|
||||
UserMsg --> ChatResponseBox
|
||||
ChatResponseBox --> ToolResponseBox
|
||||
|
||||
class SystemMsg,UserMsg messageNode
|
||||
class ChatResponseBox responseNode
|
||||
class ToolResponseBox responseNode
|
||||
class ChatMetadata,ChatResponseText,ChatContent,ToolMetadata,ToolChat,ToolContent,ToolMetadataLength metadataNode
|
||||
```
|
||||
|
||||
Messages should be a graph (DAG, specifically) of immutable elements.
|
||||
|
||||
## Why immutable elements?
|
||||
We want to train on policy
|
||||
- This means the context cannot change after we call a response.
|
||||
|
||||
## Why a graph?
|
||||
Nodes and connections are a natural way to represent the flow of information in an agent conversation.
|
||||
|
||||
## Will this be annoying to deal with?
|
||||
|
||||
It shouldn't be! While there will be internal stuff that may look ???, for the interface, it should be as simple as your
|
||||
normal context window edits, so `message_history[2]['content'] = my_edit`, but internally we'll deal with the recordkeeping
|
||||
and how this ends up parsing into on policy training data, if requested.
|
||||
|
||||
## Edges
|
||||
|
||||
Edges are the connections between nodes, and there are two types we are concerned with:
|
||||
- **Sequential edges**: These represent the flow of conversation, connecting messages in the order they were sent. For example, a user message followed by an assistant response.
|
||||
- **Parallel edges**: These represent versioning, e.g. edit history, context squishing, etc.
|
||||
We, however, are only concerned about parallel edges when we break the prefix, and ignore any other parallel edges.
|
||||
|
||||
## So what does this look like in practice?
|
||||
|
||||
```python
|
||||
import copy
|
||||
|
||||
|
||||
class MessageGraph:
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
self.prev_graph = None
|
||||
|
||||
def append(self, message):
|
||||
self.messages.append(message)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.messages[index]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
# check if an assistant message is after this indx
|
||||
needs_new_graph = False
|
||||
first_idx = -1
|
||||
for i in range(key, len(self.messages)):
|
||||
if (i == key) and (value['role'] == 'assistant') and (value['content'] == self.messages[i]['content']):
|
||||
# no op
|
||||
return
|
||||
needs_new_graph = needs_new_graph or (self.messages[i]['role'] == 'assistant')
|
||||
if needs_new_graph and first_idx == -1:
|
||||
first_idx = i
|
||||
if needs_new_graph:
|
||||
self.prev_graph = copy.deepcopy(self)
|
||||
self.messages[key] = value
|
||||
|
||||
def __len__(self):
|
||||
return len(self.messages)
|
||||
|
||||
def __eq__(self, other):
|
||||
return "\n\n".join(f"{msg['role']}: {msg['content']}" for msg in self) == "\n\n".join(
|
||||
f"{msg['role']}: {msg['content']}" for msg in other)
|
||||
|
||||
|
||||
# in use
|
||||
messages = MessageGraph()
|
||||
messages.append({'role': 'system', 'content': 'Hello, I am a system message'})
|
||||
messages[0] = {'role': 'user', 'content': 'Hello, I am a user message'}
|
||||
```
|
||||
@@ -1,16 +0,0 @@
|
||||
# Tools
|
||||
|
||||
Not much on this, yet. Tools are just a stateful wrapper around a function, so we can do things like:
|
||||
- Keep a docker container running
|
||||
- Keep a game online
|
||||
|
||||
```python
|
||||
class BaseTool:
|
||||
def definitions(self) -> List[Dict[str, Any]]:
|
||||
# OpenAI API compatible definitions
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Dict[str, Any]:
|
||||
# Returns at minimum {'role': 'tool', 'content': '...'}
|
||||
raise NotImplementedError
|
||||
```
|
||||
789
batch_runner.py
789
batch_runner.py
File diff suppressed because it is too large
Load Diff
12
gemini_nothinking.sh
Normal file
12
gemini_nothinking.sh
Normal file
@@ -0,0 +1,12 @@
|
||||
python batch_runner.py \
|
||||
--dataset_file="source-data/agent_tasks_eval.jsonl" \
|
||||
--batch_size=1 \
|
||||
--run_name="agenttasks_eval_gemini-4.5-3-nothinking" \
|
||||
--distribution="science" \
|
||||
--model="gemini-3-pro-preview" \
|
||||
--base_url="https://generativelanguage.googleapis.com/v1beta/openai/" \
|
||||
--api_key="${GEMINI_API_KEY}" \
|
||||
--num_workers=10 \
|
||||
--max_turns=60 \
|
||||
--verbose \
|
||||
--ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use the terminal or search tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run. If you need to use a tool that isn't available, you can use the terminal tool to install or create it in many cases as well. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. If you require API keys please check which ones already exist in your environment variables in a way that does not read them."
|
||||
381
profiling.py
Normal file
381
profiling.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""
|
||||
Profiling module for tracking timing statistics of tools and LLM API calls.
|
||||
|
||||
This module provides a centralized way to track timing information for various
|
||||
operations in the agent system, including:
|
||||
- Individual tool executions
|
||||
- OpenAI API calls
|
||||
- Aggregate statistics (min, max, median, mean, total)
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from collections import defaultdict
|
||||
import statistics
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfilingStats:
|
||||
"""Statistics for a particular operation type."""
|
||||
call_count: int = 0
|
||||
total_time: float = 0.0
|
||||
min_time: float = float('inf')
|
||||
max_time: float = 0.0
|
||||
times: List[float] = field(default_factory=list)
|
||||
|
||||
def add_timing(self, duration: float):
|
||||
"""Add a timing measurement."""
|
||||
self.call_count += 1
|
||||
self.total_time += duration
|
||||
self.min_time = min(self.min_time, duration)
|
||||
self.max_time = max(self.max_time, duration)
|
||||
self.times.append(duration)
|
||||
|
||||
@property
|
||||
def mean_time(self) -> float:
|
||||
"""Calculate mean time."""
|
||||
return self.total_time / self.call_count if self.call_count > 0 else 0.0
|
||||
|
||||
@property
|
||||
def median_time(self) -> float:
|
||||
"""Calculate median time."""
|
||||
return statistics.median(self.times) if self.times else 0.0
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dictionary for serialization."""
|
||||
return {
|
||||
"call_count": self.call_count,
|
||||
"total_time": self.total_time,
|
||||
"min_time": self.min_time if self.min_time != float('inf') else 0.0,
|
||||
"max_time": self.max_time,
|
||||
"mean_time": self.mean_time,
|
||||
"median_time": self.median_time
|
||||
}
|
||||
|
||||
|
||||
class Profiler:
|
||||
"""
|
||||
Global profiler for tracking timing statistics across tools and API calls.
|
||||
|
||||
Usage:
|
||||
profiler = Profiler()
|
||||
|
||||
# Time a tool execution
|
||||
with profiler.time_tool("web_search"):
|
||||
# ... tool execution code ...
|
||||
pass
|
||||
|
||||
# Time an API call
|
||||
with profiler.time_api_call():
|
||||
# ... API call code ...
|
||||
pass
|
||||
|
||||
# Get statistics
|
||||
stats = profiler.get_statistics()
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the profiler."""
|
||||
self.tool_stats: Dict[str, ProfilingStats] = defaultdict(ProfilingStats)
|
||||
self.api_stats: ProfilingStats = ProfilingStats()
|
||||
self._enabled = True
|
||||
|
||||
def enable(self):
|
||||
"""Enable profiling."""
|
||||
self._enabled = True
|
||||
|
||||
def disable(self):
|
||||
"""Disable profiling."""
|
||||
self._enabled = False
|
||||
|
||||
def reset(self):
|
||||
"""Reset all profiling data."""
|
||||
self.tool_stats.clear()
|
||||
self.api_stats = ProfilingStats()
|
||||
|
||||
def record_tool_timing(self, tool_name: str, duration: float):
|
||||
"""Record timing for a tool execution."""
|
||||
if self._enabled:
|
||||
self.tool_stats[tool_name].add_timing(duration)
|
||||
|
||||
def record_api_timing(self, duration: float):
|
||||
"""Record timing for an API call."""
|
||||
if self._enabled:
|
||||
self.api_stats.add_timing(duration)
|
||||
|
||||
def get_statistics(self) -> Dict:
|
||||
"""
|
||||
Get all profiling statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary containing tool and API statistics
|
||||
"""
|
||||
return {
|
||||
"tools": {
|
||||
tool_name: stats.to_dict()
|
||||
for tool_name, stats in sorted(self.tool_stats.items())
|
||||
},
|
||||
"api_calls": self.api_stats.to_dict()
|
||||
}
|
||||
|
||||
def print_statistics(self, detailed: bool = True):
|
||||
"""
|
||||
Print profiling statistics in a readable format.
|
||||
|
||||
Args:
|
||||
detailed: If True, show per-tool breakdown. If False, show summary only.
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("📊 PROFILING STATISTICS")
|
||||
print("="*80)
|
||||
|
||||
# API Call Statistics
|
||||
print("\n🔷 OpenAI API Calls:")
|
||||
if self.api_stats.call_count > 0:
|
||||
api_dict = self.api_stats.to_dict()
|
||||
print(f" Total Calls: {api_dict['call_count']}")
|
||||
print(f" Total Time: {api_dict['total_time']:.2f}s")
|
||||
print(f" Min Time: {api_dict['min_time']:.2f}s")
|
||||
print(f" Max Time: {api_dict['max_time']:.2f}s")
|
||||
print(f" Mean Time: {api_dict['mean_time']:.2f}s")
|
||||
print(f" Median Time: {api_dict['median_time']:.2f}s")
|
||||
else:
|
||||
print(" No API calls recorded")
|
||||
|
||||
# Tool Statistics
|
||||
print("\n🔧 Tool Executions:")
|
||||
if self.tool_stats:
|
||||
if detailed:
|
||||
for tool_name in sorted(self.tool_stats.keys()):
|
||||
stats_dict = self.tool_stats[tool_name].to_dict()
|
||||
print(f"\n 📌 {tool_name}:")
|
||||
print(f" Total Calls: {stats_dict['call_count']}")
|
||||
print(f" Total Time: {stats_dict['total_time']:.2f}s")
|
||||
print(f" Min Time: {stats_dict['min_time']:.2f}s")
|
||||
print(f" Max Time: {stats_dict['max_time']:.2f}s")
|
||||
print(f" Mean Time: {stats_dict['mean_time']:.2f}s")
|
||||
print(f" Median Time: {stats_dict['median_time']:.2f}s")
|
||||
|
||||
# Summary
|
||||
total_tool_calls = sum(s.call_count for s in self.tool_stats.values())
|
||||
total_tool_time = sum(s.total_time for s in self.tool_stats.values())
|
||||
print(f"\n 📊 Summary:")
|
||||
print(f" Total Tool Calls: {total_tool_calls}")
|
||||
print(f" Total Tool Time: {total_tool_time:.2f}s")
|
||||
print(f" Unique Tools Used: {len(self.tool_stats)}")
|
||||
else:
|
||||
print(" No tool executions recorded")
|
||||
|
||||
# Overall Summary
|
||||
total_api_time = self.api_stats.total_time
|
||||
total_tool_time = sum(s.total_time for s in self.tool_stats.values())
|
||||
print(f"\n📈 Overall Summary:")
|
||||
print(f" Total API Time: {total_api_time:.2f}s")
|
||||
print(f" Total Tool Time: {total_tool_time:.2f}s")
|
||||
print(f" Total Time: {total_api_time + total_tool_time:.2f}s")
|
||||
print("="*80 + "\n")
|
||||
|
||||
def export_to_json(self) -> str:
|
||||
"""Export statistics as JSON string."""
|
||||
import json
|
||||
return json.dumps(self.get_statistics(), indent=2)
|
||||
|
||||
def export_to_file(self, filepath: str):
|
||||
"""
|
||||
Export statistics to a JSON file.
|
||||
|
||||
Args:
|
||||
filepath: Path to output file
|
||||
"""
|
||||
import json
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(self.get_statistics(), f, indent=2)
|
||||
print(f"📁 Profiling statistics exported to: {filepath}")
|
||||
|
||||
|
||||
# Global profiler instance
|
||||
_global_profiler: Optional[Profiler] = None
|
||||
|
||||
|
||||
def get_profiler() -> Profiler:
|
||||
"""Get or create the global profiler instance."""
|
||||
global _global_profiler
|
||||
if _global_profiler is None:
|
||||
_global_profiler = Profiler()
|
||||
return _global_profiler
|
||||
|
||||
|
||||
def reset_profiler():
|
||||
"""Reset the global profiler."""
|
||||
global _global_profiler
|
||||
if _global_profiler is not None:
|
||||
_global_profiler.reset()
|
||||
|
||||
|
||||
class TimingContext:
|
||||
"""Context manager for timing operations."""
|
||||
|
||||
def __init__(self, profiler: Profiler, operation_type: str, operation_name: Optional[str] = None):
|
||||
"""
|
||||
Initialize timing context.
|
||||
|
||||
Args:
|
||||
profiler: Profiler instance to record timing
|
||||
operation_type: 'tool' or 'api'
|
||||
operation_name: Name of the operation (required for tools)
|
||||
"""
|
||||
self.profiler = profiler
|
||||
self.operation_type = operation_type
|
||||
self.operation_name = operation_name
|
||||
self.start_time = None
|
||||
|
||||
def __enter__(self):
|
||||
"""Start timing."""
|
||||
self.start_time = time.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Stop timing and record."""
|
||||
duration = time.time() - self.start_time
|
||||
|
||||
if self.operation_type == 'tool':
|
||||
self.profiler.record_tool_timing(self.operation_name, duration)
|
||||
elif self.operation_type == 'api':
|
||||
self.profiler.record_api_timing(duration)
|
||||
|
||||
return False # Don't suppress exceptions
|
||||
|
||||
|
||||
def aggregate_profiling_stats(stats_list: List[Dict]) -> Dict:
|
||||
"""
|
||||
Aggregate multiple profiling statistics dictionaries into one.
|
||||
|
||||
This is useful for batch processing where each worker process has its own
|
||||
profiler instance that needs to be combined.
|
||||
|
||||
Args:
|
||||
stats_list: List of statistics dictionaries from get_statistics()
|
||||
|
||||
Returns:
|
||||
Dict: Aggregated statistics with combined tool and API call data
|
||||
"""
|
||||
aggregated = {
|
||||
"tools": defaultdict(lambda: {"times": []}),
|
||||
"api_calls": {"times": []}
|
||||
}
|
||||
|
||||
# Aggregate tool statistics
|
||||
for stats in stats_list:
|
||||
# Aggregate tool timings
|
||||
for tool_name, tool_stats in stats.get("tools", {}).items():
|
||||
# Reconstruct individual timings from aggregated stats
|
||||
# Since we have mean_time and call_count, we approximate
|
||||
aggregated["tools"][tool_name]["times"].extend(
|
||||
[tool_stats.get("mean_time", 0.0)] * tool_stats.get("call_count", 0)
|
||||
)
|
||||
|
||||
# Aggregate API call timings
|
||||
api_stats = stats.get("api_calls", {})
|
||||
if api_stats.get("call_count", 0) > 0:
|
||||
aggregated["api_calls"]["times"].extend(
|
||||
[api_stats.get("mean_time", 0.0)] * api_stats.get("call_count", 0)
|
||||
)
|
||||
|
||||
# Calculate final statistics for tools
|
||||
final_stats = {"tools": {}, "api_calls": {}}
|
||||
|
||||
for tool_name, data in aggregated["tools"].items():
|
||||
times = data["times"]
|
||||
if times:
|
||||
final_stats["tools"][tool_name] = {
|
||||
"call_count": len(times),
|
||||
"total_time": sum(times),
|
||||
"min_time": min(times),
|
||||
"max_time": max(times),
|
||||
"mean_time": statistics.mean(times),
|
||||
"median_time": statistics.median(times)
|
||||
}
|
||||
|
||||
# Calculate final statistics for API calls
|
||||
api_times = aggregated["api_calls"]["times"]
|
||||
if api_times:
|
||||
final_stats["api_calls"] = {
|
||||
"call_count": len(api_times),
|
||||
"total_time": sum(api_times),
|
||||
"min_time": min(api_times),
|
||||
"max_time": max(api_times),
|
||||
"mean_time": statistics.mean(api_times),
|
||||
"median_time": statistics.median(api_times)
|
||||
}
|
||||
else:
|
||||
final_stats["api_calls"] = {
|
||||
"call_count": 0,
|
||||
"total_time": 0.0,
|
||||
"min_time": 0.0,
|
||||
"max_time": 0.0,
|
||||
"mean_time": 0.0,
|
||||
"median_time": 0.0
|
||||
}
|
||||
|
||||
return final_stats
|
||||
|
||||
|
||||
def print_aggregated_statistics(stats: Dict, detailed: bool = True):
|
||||
"""
|
||||
Print aggregated profiling statistics in a readable format.
|
||||
|
||||
Args:
|
||||
stats: Aggregated statistics dictionary from aggregate_profiling_stats()
|
||||
detailed: If True, show per-tool breakdown. If False, show summary only.
|
||||
"""
|
||||
print("\n" + "="*80)
|
||||
print("📊 AGGREGATED PROFILING STATISTICS")
|
||||
print("="*80)
|
||||
|
||||
# API Call Statistics
|
||||
print("\n🔷 OpenAI API Calls:")
|
||||
api_stats = stats.get("api_calls", {})
|
||||
if api_stats.get("call_count", 0) > 0:
|
||||
print(f" Total Calls: {api_stats['call_count']}")
|
||||
print(f" Total Time: {api_stats['total_time']:.2f}s")
|
||||
print(f" Min Time: {api_stats['min_time']:.2f}s")
|
||||
print(f" Max Time: {api_stats['max_time']:.2f}s")
|
||||
print(f" Mean Time: {api_stats['mean_time']:.2f}s")
|
||||
print(f" Median Time: {api_stats['median_time']:.2f}s")
|
||||
else:
|
||||
print(" No API calls recorded")
|
||||
|
||||
# Tool Statistics
|
||||
print("\n🔧 Tool Executions:")
|
||||
tool_stats = stats.get("tools", {})
|
||||
if tool_stats:
|
||||
if detailed:
|
||||
for tool_name in sorted(tool_stats.keys()):
|
||||
stats_dict = tool_stats[tool_name]
|
||||
print(f"\n 📌 {tool_name}:")
|
||||
print(f" Total Calls: {stats_dict['call_count']}")
|
||||
print(f" Total Time: {stats_dict['total_time']:.2f}s")
|
||||
print(f" Min Time: {stats_dict['min_time']:.2f}s")
|
||||
print(f" Max Time: {stats_dict['max_time']:.2f}s")
|
||||
print(f" Mean Time: {stats_dict['mean_time']:.2f}s")
|
||||
print(f" Median Time: {stats_dict['median_time']:.2f}s")
|
||||
|
||||
# Summary
|
||||
total_tool_calls = sum(s["call_count"] for s in tool_stats.values())
|
||||
total_tool_time = sum(s["total_time"] for s in tool_stats.values())
|
||||
print(f"\n 📊 Summary:")
|
||||
print(f" Total Tool Calls: {total_tool_calls}")
|
||||
print(f" Total Tool Time: {total_tool_time:.2f}s")
|
||||
print(f" Unique Tools Used: {len(tool_stats)}")
|
||||
else:
|
||||
print(" No tool executions recorded")
|
||||
|
||||
# Overall Summary
|
||||
total_api_time = api_stats.get("total_time", 0.0)
|
||||
total_tool_time = sum(s["total_time"] for s in tool_stats.values())
|
||||
print(f"\n📈 Overall Summary:")
|
||||
print(f" Total API Time: {total_api_time:.2f}s")
|
||||
print(f" Total Tool Time: {total_tool_time:.2f}s")
|
||||
print(f" Total Time: {total_api_time + total_tool_time:.2f}s")
|
||||
print("="*80 + "\n")
|
||||
@@ -1,9 +1,6 @@
|
||||
firecrawl-py
|
||||
openai
|
||||
fal-client
|
||||
fire
|
||||
git@github.com:NousResearch/hecate.git
|
||||
tenacity
|
||||
python-dotenv
|
||||
fire
|
||||
httpx
|
||||
httpx
|
||||
90
run_agent.py
90
run_agent.py
@@ -45,6 +45,9 @@ else:
|
||||
from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
|
||||
# Import profiling
|
||||
from profiling import get_profiler
|
||||
|
||||
|
||||
class AIAgent:
|
||||
"""
|
||||
@@ -364,6 +367,10 @@ class AIAgent:
|
||||
Returns:
|
||||
Dict: Complete conversation result with final response and message history
|
||||
"""
|
||||
# Reset profiler for this conversation to get fresh stats
|
||||
from profiling import reset_profiler as reset_prof
|
||||
reset_prof()
|
||||
|
||||
# Generate unique task_id if not provided to isolate VMs between concurrent tasks
|
||||
import uuid
|
||||
effective_task_id = task_id or str(uuid.uuid4())
|
||||
@@ -394,6 +401,8 @@ class AIAgent:
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"API Request - Model: {self.model}, Messages: {len(messages)}, Tools: {len(self.tools) if self.tools else 0}")
|
||||
logging.debug(f"Last message role: {messages[-1]['role'] if messages else 'none'}")
|
||||
# Log the last few messages to see if thought_signature is present
|
||||
logging.debug(f"Last message content: {json.dumps(messages[-1] if messages else {}, indent=2)}")
|
||||
|
||||
api_start_time = time.time()
|
||||
retry_count = 0
|
||||
@@ -419,6 +428,9 @@ class AIAgent:
|
||||
api_duration = time.time() - api_start_time
|
||||
print(f"⏱️ OpenAI-compatible API call completed in {api_duration:.2f}s")
|
||||
|
||||
# Record API timing in profiler
|
||||
get_profiler().record_api_timing(api_duration)
|
||||
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"API Response received - Usage: {response.usage if hasattr(response, 'usage') else 'N/A'}")
|
||||
|
||||
@@ -449,22 +461,58 @@ class AIAgent:
|
||||
if self.verbose_logging:
|
||||
for tc in assistant_message.tool_calls:
|
||||
logging.debug(f"Tool call: {tc.function.name} with args: {tc.function.arguments[:200]}...")
|
||||
# Debug: Check what attributes are available on tool_call
|
||||
logging.debug(f"Tool call attributes: {dir(tc)}")
|
||||
# Try to dump the model to see all fields
|
||||
if hasattr(tc, 'model_dump'):
|
||||
logging.debug(f"Tool call data: {tc.model_dump()}")
|
||||
|
||||
# Add assistant message with tool calls to conversation
|
||||
# Extract thought_signature if present (required for Gemini models)
|
||||
tool_calls_data = []
|
||||
for tool_call in assistant_message.tool_calls:
|
||||
tool_call_dict = {
|
||||
"id": tool_call.id,
|
||||
"type": tool_call.type,
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments
|
||||
}
|
||||
}
|
||||
# Try multiple ways to access thought_signature (Gemini-specific)
|
||||
# Gemini uses extra_content.google.thought_signature structure
|
||||
thought_sig = None
|
||||
|
||||
# Method 1: Check extra_content attribute
|
||||
if hasattr(tool_call, 'extra_content'):
|
||||
extra = tool_call.extra_content
|
||||
if isinstance(extra, dict) and 'google' in extra:
|
||||
thought_sig = extra['google'].get('thought_signature')
|
||||
|
||||
# Method 2: Check model_dump() if available (Pydantic v2)
|
||||
if thought_sig is None and hasattr(tool_call, 'model_dump'):
|
||||
dumped = tool_call.model_dump()
|
||||
if 'extra_content' in dumped and isinstance(dumped['extra_content'], dict):
|
||||
google_data = dumped['extra_content'].get('google', {})
|
||||
thought_sig = google_data.get('thought_signature')
|
||||
|
||||
if thought_sig is not None:
|
||||
tool_call_dict["extra_content"] = {
|
||||
"google": {
|
||||
"thought_signature": thought_sig
|
||||
}
|
||||
}
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"Captured thought_signature for tool call {tool_call.id}")
|
||||
elif self.verbose_logging:
|
||||
logging.debug(f"No thought_signature found for tool call {tool_call.id}")
|
||||
|
||||
tool_calls_data.append(tool_call_dict)
|
||||
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": assistant_message.content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tool_call.id,
|
||||
"type": tool_call.type,
|
||||
"function": {
|
||||
"name": tool_call.function.name,
|
||||
"arguments": tool_call.function.arguments
|
||||
}
|
||||
}
|
||||
for tool_call in assistant_message.tool_calls
|
||||
]
|
||||
"tool_calls": tool_calls_data
|
||||
})
|
||||
|
||||
# Execute each tool call
|
||||
@@ -490,11 +538,15 @@ class AIAgent:
|
||||
tool_duration = time.time() - tool_start_time
|
||||
result_preview = function_result[:200] if len(function_result) > 200 else function_result
|
||||
|
||||
# Record tool timing in profiler
|
||||
get_profiler().record_tool_timing(function_name, tool_duration)
|
||||
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
|
||||
logging.debug(f"Tool result preview: {result_preview}...")
|
||||
|
||||
# Add tool result to conversation
|
||||
# Note: thought_signature should NOT be in tool responses, only in assistant messages
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"content": function_result,
|
||||
@@ -562,11 +614,15 @@ class AIAgent:
|
||||
if self.verbose_logging:
|
||||
logging.warning(f"Failed to cleanup VM for task {effective_task_id}: {e}")
|
||||
|
||||
# Get profiling statistics for this conversation
|
||||
profiling_stats = get_profiler().get_statistics()
|
||||
|
||||
return {
|
||||
"final_response": final_response,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": completed
|
||||
"completed": completed,
|
||||
"profiling_stats": profiling_stats
|
||||
}
|
||||
|
||||
def chat(self, message: str) -> str:
|
||||
@@ -594,7 +650,8 @@ def main(
|
||||
list_tools: bool = False,
|
||||
save_trajectories: bool = False,
|
||||
verbose: bool = False,
|
||||
log_prefix_chars: int = 20
|
||||
log_prefix_chars: int = 20,
|
||||
show_profiling: bool = True
|
||||
):
|
||||
"""
|
||||
Main function for running the agent directly.
|
||||
@@ -613,6 +670,7 @@ def main(
|
||||
save_trajectories (bool): Save conversation trajectories to JSONL files. Defaults to False.
|
||||
verbose (bool): Enable verbose logging for debugging. Defaults to False.
|
||||
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses. Defaults to 20.
|
||||
show_profiling (bool): Display profiling statistics after conversation. Defaults to True.
|
||||
|
||||
Toolset Examples:
|
||||
- "research": Web search, extract, crawl + vision tools
|
||||
@@ -763,7 +821,11 @@ def main(
|
||||
print(f"\n🎯 FINAL RESPONSE:")
|
||||
print("-" * 30)
|
||||
print(result['final_response'])
|
||||
|
||||
|
||||
# Display profiling statistics if enabled
|
||||
if show_profiling:
|
||||
get_profiler().print_statistics(detailed=True)
|
||||
|
||||
print("\n👋 Agent execution completed!")
|
||||
|
||||
|
||||
|
||||
20
safe_print.py
Normal file
20
safe_print.py
Normal file
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Simple safe print that tries rich, falls back to regular print."""
|
||||
|
||||
try:
|
||||
from rich import print as rich_print
|
||||
RICH_AVAILABLE = True
|
||||
except ImportError:
|
||||
RICH_AVAILABLE = False
|
||||
|
||||
|
||||
def safe_print(*args, **kwargs):
|
||||
"""Try rich.print, fall back to regular print if it fails."""
|
||||
if RICH_AVAILABLE:
|
||||
try:
|
||||
rich_print(*args, **kwargs)
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
# Fallback to regular print
|
||||
print(*args, **kwargs)
|
||||
Reference in New Issue
Block a user