Compare commits
66 Commits
rl-capabil
...
atropos-in
| Author | SHA1 | Date | |
|---|---|---|---|
| | 24c13bc412 | | |
| | 06e9422324 | | |
| | 907616a692 | | |
| | 33a00d9b8e | | |
| | a2312076da | | |
| | 499490d06a | | |
| | 35b2250b36 | | |
| | 395392e5de | | |
| | 2041b354a9 | | |
| | 3951eab399 | | |
| | 62001e3bf5 | | |
| | c8b30e9efa | | |
| | f82c3081f2 | | |
| | a69924631c | | |
| | 4619d1c8ef | | |
| | 98d945f6de | | |
| | 507b77c4ac | | |
| | b99c2a2644 | | |
| | 975c849308 | | |
| | 9dc27880cd | | |
| | 3b9c53e6db | | |
| | 05dd31131f | | |
| | 36ea883d45 | | |
| | 6be8cdeeca | | |
| | 192ce958c3 | | |
| | c441681dc2 | | |
| | dd70d57b9b | | |
| | f12ea1bc02 | | |
| | fa76a331b0 | | |
| | d999d9876d | | |
| | 578a5fb6a9 | | |
| | a8809bbd3e | | |
| | a478e44585 | | |
| | c0494b3558 | | |
| | 7f1cd014f2 | | |
| | 07b615e96e | | |
| | ab387a6120 | | |
| | ac79725923 | | |
| | 0bc914b00c | | |
| | 411e7f8ff4 | | |
| | eb2e6b73fe | | |
| | 664acf7426 | | |
| | fd1c3da305 | | |
| | 8dd38318fc | | |
| | 4d619bcd21 | | |
| | beac2ee06a | | |
| | 487487406d | | |
| | 87464821d8 | | |
| | 661d8f4d6c | | |
| | bf13a848ef | | |
| | 88286f6da3 | | |
| | 5b82190460 | | |
| | 8380895ae3 | | |
| | ea7aa0b0d4 | | |
| | 7130fa50cb | | |
| | 5a9c98a771 | | |
| | 6cb4fe948a | | |
| | 30221d8c20 | | |
| | b5b1fef20a | | |
| | 16fb41f9cc | | |
| | 4939130485 | | |
| | 8dccd6569e | | |
| | db348dc467 | | |
| | 88722e230d | | |
| | 68fb0efe0e | | |
| | e38c274f8d | | |
115
.clinerules
Normal file
@@ -0,0 +1,115 @@
|
||||
# Cline's Memory Bank
|
||||
|
||||
I am Cline, an expert software engineer with a unique characteristic: my memory resets completely between sessions. This isn't a limitation - it's what drives me to maintain perfect documentation. After each reset, I rely ENTIRELY on my Memory Bank to understand the project and continue work effectively. I MUST read ALL memory bank files at the start of EVERY task - this is not optional.
|
||||
|
||||
## Memory Bank Structure
|
||||
|
||||
The Memory Bank consists of core files and optional context files, all in Markdown format. Files build upon each other in a clear hierarchy:
|
||||
|
||||
flowchart TD
|
||||
PB[projectbrief.md] --> PC[productContext.md]
|
||||
PB --> SP[systemPatterns.md]
|
||||
PB --> TC[techContext.md]
|
||||
|
||||
PC --> AC[activeContext.md]
|
||||
SP --> AC
|
||||
TC --> AC
|
||||
|
||||
AC --> P[progress.md]
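The hierarchy in the flowchart above can be read as a dependency graph. As a minimal sketch (the `DEPENDS_ON` mapping and the `reading_order` helper are illustrative names, not part of Cline or this repository), a topological sort yields one valid order in which to read the files so that each file comes after the ones it builds on:

```python
from graphlib import TopologicalSorter

# Each file maps to the files it builds upon (edges taken from the flowchart).
DEPENDS_ON = {
    "projectbrief.md": set(),
    "productContext.md": {"projectbrief.md"},
    "systemPatterns.md": {"projectbrief.md"},
    "techContext.md": {"projectbrief.md"},
    "activeContext.md": {"productContext.md", "systemPatterns.md", "techContext.md"},
    "progress.md": {"activeContext.md"},
}


def reading_order(depends_on=DEPENDS_ON):
    """Return the files ordered so that every file follows the files it builds on."""
    return list(TopologicalSorter(depends_on).static_order())


if __name__ == "__main__":
    for filename in reading_order():
        print(filename)
```

The three middle files have no ordering among themselves, so their relative position may vary; only the first, second-to-last, and last positions are fixed by the graph.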
|
||||
|
||||
### Core Files (Required)
|
||||
1. `projectbrief.md`
|
||||
- Foundation document that shapes all other files
|
||||
- Created at project start if it doesn't exist
|
||||
- Defines core requirements and goals
|
||||
- Source of truth for project scope
|
||||
|
||||
2. `productContext.md`
|
||||
- Why this project exists
|
||||
- Problems it solves
|
||||
- How it should work
|
||||
- User experience goals
|
||||
|
||||
3. `activeContext.md`
|
||||
- Current work focus
|
||||
- Recent changes
|
||||
- Next steps
|
||||
- Active decisions and considerations
|
||||
- Important patterns and preferences
|
||||
- Learnings and project insights
|
||||
|
||||
4. `systemPatterns.md`
|
||||
- System architecture
|
||||
- Key technical decisions
|
||||
- Design patterns in use
|
||||
- Component relationships
|
||||
- Critical implementation paths
|
||||
|
||||
5. `techContext.md`
|
||||
- Technologies used
|
||||
- Development setup
|
||||
- Technical constraints
|
||||
- Dependencies
|
||||
- Tool usage patterns
|
||||
|
||||
6. `progress.md`
|
||||
- What works
|
||||
- What's left to build
|
||||
- Current status
|
||||
- Known issues
|
||||
- Evolution of project decisions
|
||||
|
||||
### Additional Context
|
||||
Create additional files/folders within memory-bank/ when they help organize:
|
||||
- Complex feature documentation
|
||||
- Integration specifications
|
||||
- API documentation
|
||||
- Testing strategies
|
||||
- Deployment procedures
|
||||
|
||||
## Core Workflows
|
||||
|
||||
### Plan Mode
|
||||
flowchart TD
|
||||
Start[Start] --> ReadFiles[Read Memory Bank]
|
||||
ReadFiles --> CheckFiles{Files Complete?}
|
||||
|
||||
CheckFiles -->|No| Plan[Create Plan]
|
||||
Plan --> Document[Document in Chat]
|
||||
|
||||
CheckFiles -->|Yes| Verify[Verify Context]
|
||||
Verify --> Strategy[Develop Strategy]
|
||||
Strategy --> Present[Present Approach]
|
||||
|
||||
### Act Mode
|
||||
flowchart TD
|
||||
Start[Start] --> Context[Check Memory Bank]
|
||||
Context --> Update[Update Documentation]
|
||||
Update --> Execute[Execute Task]
|
||||
Execute --> Document[Document Changes]
|
||||
|
||||
## Documentation Updates
|
||||
|
||||
Memory Bank updates occur when:
|
||||
1. Discovering new project patterns
|
||||
2. After implementing significant changes
|
||||
3. When user requests with **update memory bank** (MUST review ALL files)
|
||||
4. When context needs clarification
|
||||
|
||||
flowchart TD
|
||||
Start[Update Process]
|
||||
|
||||
subgraph Process
|
||||
P1[Review ALL Files]
|
||||
P2[Document Current State]
|
||||
P3[Clarify Next Steps]
|
||||
P4[Document Insights & Patterns]
|
||||
|
||||
P1 --> P2 --> P3 --> P4
|
||||
end
|
||||
|
||||
Start --> Process
|
||||
|
||||
Note: When triggered by **update memory bank**, I MUST review every memory bank file, even if some don't require updates. Focus particularly on activeContext.md and progress.md as they track current state.
|
||||
|
||||
REMEMBER: After every memory reset, I begin completely fresh. The Memory Bank is my only link to previous work. It must be maintained with precision and clarity, as my effectiveness depends entirely on its accuracy.
|
||||
152
.env.example
@@ -1,17 +1,73 @@
|
||||
# Hermes Agent Environment Configuration
|
||||
# Copy this file to .env and fill in your API keys
|
||||
|
||||
# =============================================================================
|
||||
# CORE SETTINGS
|
||||
# =============================================================================
|
||||
# Agent backend:
|
||||
# - openai : default Hermes-Agent loop (OpenAI function-calling via OpenAI SDK)
|
||||
# - atropos : Atroposlib ServerManager/ManagedServer-backed loop (training/env integration)
|
||||
HERMES_BACKEND=openai
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LOCAL / SELF-HOSTED OPENAI-COMPATIBLE ENDPOINTS (vLLM, SGLang, llama.cpp, etc.)
|
||||
# =============================================================================
|
||||
# For local development (matches the Atropos test env defaults):
|
||||
# ATROPOS_SERVER_BASE_URL=http://127.0.0.1:8080
|
||||
# ATROPOS_SERVER_MODEL=hermes-4-36b
|
||||
# For hosted inference (Nous Research inference API):
|
||||
ATROPOS_SERVER_BASE_URL=
|
||||
ATROPOS_SERVER_MODEL=
|
||||
ATROPOS_TOKENIZER_NAME=
|
||||
# Set this to your Nous API key (Bearer token).
|
||||
ATROPOS_SERVER_API_KEY=
|
||||
|
||||
# Debugging (prints to stdout; use with care)
|
||||
# HERMES_DEBUG_ATROPOS_REQUEST=1
|
||||
# HERMES_DEBUG_ATROPOS_RESPONSE=1
|
||||
# HERMES_DEBUG_OPENAI_REQUEST=1
|
||||
# HERMES_DEBUG_OPENAI_RESPONSE=1
|
||||
|
||||
# =============================================================================
|
||||
# LOCAL / SELF-HOSTED OPENAI-COMPATIBLE ENDPOINTS (vLLM, SGLang, llama.cpp, etc.)
|
||||
# =============================================================================
|
||||
# If you set ATROPOS_SERVER_BASE_URL or OPENAI_BASE_URL, Hermes will use it instead
|
||||
# of OpenRouter.
|
||||
#
|
||||
# Local server convenience (base URL without /v1):
|
||||
# llama.cpp example (see `Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh`):
|
||||
# ATROPOS_SERVER_BASE_URL=http://127.0.0.1:8080
|
||||
# ATROPOS_SERVER_MODEL=hermes-4-36b
|
||||
# ATROPOS_TOKENIZER_NAME=NousResearch/Hermes-4.3-36B
|
||||
# ATROPOS_SERVER_API_KEY=local
|
||||
#
|
||||
# Hosted Nous inference API:
|
||||
# ATROPOS_SERVER_BASE_URL=https://inference-api.nousresearch.com
|
||||
# ATROPOS_SERVER_MODEL=Hermes-4.3-36B
|
||||
# ATROPOS_TOKENIZER_NAME=NousResearch/Hermes-4.3-36B
|
||||
# ATROPOS_SERVER_API_KEY=sk-... (Bearer token)
|
||||
#
|
||||
# If you plan to run GRPO-style group sampling (e.g. `--env.group_size 4`) against
|
||||
# llama.cpp, start the server with at least that many slots, e.g.:
|
||||
# LLAMA_CPP_PARALLEL=4 Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh
|
||||
#
|
||||
# Generic OpenAI-compatible (base URL should include /v1):
|
||||
# OPENAI_BASE_URL=http://127.0.0.1:8080/v1
|
||||
# OPENAI_API_KEY=local
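The comments above describe two conventions: `ATROPOS_SERVER_BASE_URL` is given without `/v1`, while `OPENAI_BASE_URL` already includes it, and either one replaces OpenRouter. A rough sketch of how a client might normalize these follows. The helper name `resolve_base_url` and the precedence of `ATROPOS_SERVER_BASE_URL` over `OPENAI_BASE_URL` are assumptions for illustration; the source does not state which wins when both are set.

```python
import os

OPENROUTER_DEFAULT = "https://openrouter.ai/api/v1"


def resolve_base_url(env=None):
    """Pick an OpenAI-compatible base URL ending in /v1 (hypothetical helper)."""
    env = os.environ if env is None else env

    atropos = (env.get("ATROPOS_SERVER_BASE_URL") or "").strip().rstrip("/")
    if atropos:
        # Documented without /v1, so append it unless it is already present.
        return atropos if atropos.endswith("/v1") else atropos + "/v1"

    openai = (env.get("OPENAI_BASE_URL") or "").strip().rstrip("/")
    if openai:
        # Documented as already including /v1, so return it unchanged.
        return openai

    # Neither override is set: fall back to OpenRouter.
    return (env.get("OPENROUTER_BASE_URL") or OPENROUTER_DEFAULT).rstrip("/")
```

For example, `resolve_base_url({"ATROPOS_SERVER_BASE_URL": "http://127.0.0.1:8080"})` yields the `/v1`-suffixed form of the local llama.cpp URL.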
|
||||
|
||||
# =============================================================================
|
||||
# LLM PROVIDER (OpenRouter)
|
||||
# =============================================================================
|
||||
# OpenRouter provides access to many models through one API
|
||||
# All LLM calls go through OpenRouter - no direct provider keys needed
|
||||
# Get your key at: https://openrouter.ai/keys
|
||||
OPENROUTER_BASE_URL=https://openrouter.ai/api/v1
|
||||
OPENROUTER_API_KEY=
|
||||
|
||||
# Default model to use (OpenRouter format: provider/model)
|
||||
# Examples: anthropic/claude-sonnet-4, openai/gpt-4o, google/gemini-2.0-flash, zhipuai/glm-4-plus
|
||||
LLM_MODEL=anthropic/claude-sonnet-4
|
||||
# Examples: anthropic/claude-opus-4.6, openai/gpt-4o, google/gemini-2.0-flash, zhipuai/glm-4-plus
|
||||
LLM_MODEL=anthropic/claude-opus-4.6
|
||||
|
||||
# =============================================================================
|
||||
# TOOL API KEYS
|
||||
@@ -46,8 +102,11 @@ TERMINAL_DOCKER_IMAGE=python:3.11
|
||||
TERMINAL_SINGULARITY_IMAGE=docker://python:3.11
|
||||
TERMINAL_MODAL_IMAGE=python:3.11
|
||||
|
||||
# Working directory inside the container
|
||||
TERMINAL_CWD=/tmp
|
||||
# Working directory for terminal commands
|
||||
# For CLI: "." means current directory (resolved automatically from config.yaml)
|
||||
# For containers (docker/singularity/modal): absolute path inside the container
|
||||
# Usually managed by config.yaml (terminal.cwd) — uncomment to override
|
||||
# TERMINAL_CWD=.
|
||||
|
||||
# Default command timeout in seconds
|
||||
TERMINAL_TIMEOUT=60
|
||||
@@ -89,12 +148,87 @@ TERMINAL_LIFETIME_SECONDS=300
|
||||
# SUDO_PASSWORD=your_password_here
|
||||
|
||||
# =============================================================================
|
||||
# MODAL CLOUD BACKEND (Optional - for TERMINAL_ENV=modal)
|
||||
# MODAL CLOUD BACKEND (for TERMINAL_ENV=modal)
|
||||
# =============================================================================
|
||||
# Modal uses CLI authentication, not environment variables.
|
||||
# Run: pip install modal && modal setup
|
||||
# This will authenticate via browser and store credentials locally.
|
||||
# No API key needed in .env - Modal handles auth automatically.
|
||||
# Modal provides cloud sandboxes with per-second billing and auto-scaling.
|
||||
# This implementation uses a warm pool of sandboxes for cost efficiency.
|
||||
#
|
||||
# SETUP:
|
||||
# pip install modal && modal setup
|
||||
# (Authenticates via browser, stores credentials locally)
|
||||
#
|
||||
# FEATURES:
|
||||
# - Auto-scaling warm sandbox pool (no cold start after first use)
|
||||
# - Named sandbox recovery (reconnects after restart)
|
||||
# - Profile-based heterogeneous environments (CPU, GPU, different images)
|
||||
# - Server-side idle_timeout protection against orphaned sandboxes
|
||||
|
||||
# Modal app name (groups all sandboxes, used for recovery)
|
||||
TERMINAL_MODAL_APP_NAME=hermes-sandbox
|
||||
|
||||
# Default profile when none specified
|
||||
TERMINAL_MODAL_DEFAULT_PROFILE=default
|
||||
|
||||
# Profile config file (optional - YAML format, see modal_profiles.yaml)
|
||||
# TERMINAL_MODAL_PROFILES_FILE=modal_profiles.yaml
|
||||
|
||||
# --- Default Profile Settings (used if no YAML file) ---
|
||||
# These apply when no profile is specified or for the "default" profile
|
||||
TERMINAL_MODAL_IMAGE=python:3.11
|
||||
TERMINAL_MODAL_MIN_POOL=1
|
||||
TERMINAL_MODAL_MAX_POOL=5
|
||||
TERMINAL_MODAL_IDLE_TIMEOUT=120
|
||||
TERMINAL_MODAL_MAX_LIFETIME=3600
|
||||
TERMINAL_MODAL_SCALE_DOWN_IDLE=180
|
||||
|
||||
# --- Custom Profile Example: pytorch-gpu ---
|
||||
# Uncomment to enable a GPU profile for ML tasks
|
||||
# Usage: terminal_tool("python train.py", profile="pytorch-gpu")
|
||||
#
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_IMAGE=pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_GPU=T4
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_MEMORY=16384
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_MIN_POOL=0
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_MAX_POOL=2
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_IDLE_TIMEOUT=60
|
||||
|
||||
# --- Custom Profile Example: node ---
|
||||
# Uncomment to enable a Node.js profile
|
||||
# Usage: terminal_tool("npm test", profile="node")
|
||||
#
|
||||
# TERMINAL_MODAL_PROFILE_node_IMAGE=node:18
|
||||
# TERMINAL_MODAL_PROFILE_node_MIN_POOL=0
|
||||
# TERMINAL_MODAL_PROFILE_node_MAX_POOL=3
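The profile examples above follow a `TERMINAL_MODAL_PROFILE_<name>_<FIELD>` naming pattern. The sketch below shows one way such variables could be grouped into per-profile settings. It is not the project's loader: `parse_profiles`, the fixed `FIELDS` tuple, and the mapping of underscores in the variable name to hyphens in the profile name (`pytorch_gpu` to `pytorch-gpu`, inferred from the usage comments) are all assumptions.

```python
import re

# Field suffixes that appear in the examples above; treated as the full set here.
FIELDS = ("IMAGE", "GPU", "MEMORY", "MIN_POOL", "MAX_POOL", "IDLE_TIMEOUT")
PATTERN = re.compile(
    r"^TERMINAL_MODAL_PROFILE_(?P<name>.+)_(?P<field>" + "|".join(FIELDS) + r")$"
)


def parse_profiles(env):
    """Group profile variables into {profile_name: {field: raw_string_value}}."""
    profiles = {}
    for key, value in env.items():
        match = PATTERN.match(key)
        if not match:
            continue  # not a per-profile variable
        # Assumed convention: pytorch_gpu in the variable name means "pytorch-gpu".
        name = match.group("name").replace("_", "-")
        profiles.setdefault(name, {})[match.group("field").lower()] = value
    return profiles
```

Anchoring the known field names at the end of the key is what lets a profile name contain underscores without being confused with fields such as `MIN_POOL`.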
|
||||
|
||||
# =============================================================================
|
||||
# MODAL SECRETS (Secure credential injection)
|
||||
# =============================================================================
|
||||
# Modal Secrets allow you to securely pass API keys, passwords, and other
|
||||
# sensitive data to your sandboxes without exposing them in code or logs.
|
||||
#
|
||||
# SETUP SECRETS:
|
||||
# 1. Via Dashboard: https://modal.com/secrets
|
||||
# 2. Via CLI: modal secret create my-secret KEY1=value1 KEY2=value2
|
||||
# 3. Via CLI with env: modal secret create my-secret API_KEY="$API_KEY"
|
||||
#
|
||||
# LIST SECRETS:
|
||||
# modal secret list
|
||||
#
|
||||
# DELETE SECRETS:
|
||||
# modal secret delete my-secret
|
||||
|
||||
# Global secrets applied to ALL profiles (comma-separated secret names)
|
||||
# These secrets must be created on Modal dashboard or via CLI first
|
||||
# TERMINAL_MODAL_SECRETS=my-api-keys,database-creds
|
||||
|
||||
# Per-profile secrets (comma-separated secret names)
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_SECRETS=huggingface-token,wandb-key
|
||||
|
||||
# Per-profile environment variables (semicolon-separated KEY=VALUE pairs)
|
||||
# TERMINAL_MODAL_PROFILE_default_ENV_VARS=DEBUG=1;LOG_LEVEL=info
|
||||
|
||||
# Load local .env file into sandbox (useful for development)
|
||||
# TERMINAL_MODAL_PROFILE_default_USE_DOTENV=true
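The two list formats described above differ: secret names are comma-separated, while environment variables are semicolon-separated `KEY=VALUE` pairs. A small sketch of parsing both follows; the function names are hypothetical and the error handling is one possible choice, not the project's behavior.

```python
def parse_secret_names(value):
    """Split a comma-separated list of Modal secret names, ignoring blanks."""
    return [name.strip() for name in value.split(",") if name.strip()]


def parse_env_vars(value):
    """Split semicolon-separated KEY=VALUE pairs into a dict."""
    pairs = {}
    for item in value.split(";"):
        item = item.strip()
        if not item:
            continue
        key, sep, val = item.partition("=")
        if not sep or not key.strip():
            raise ValueError(f"expected KEY=VALUE, got {item!r}")
        pairs[key.strip()] = val
    return pairs
```

Splitting on the first `=` only means a value may itself contain `=`, but a value containing `;` cannot be expressed in this format.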
|
||||
|
||||
# =============================================================================
|
||||
# BROWSER TOOL CONFIGURATION (agent-browser + Browserbase)
|
||||
|
||||
28
.gitignore
vendored
@@ -39,6 +39,34 @@ agent-browser/
|
||||
*.pem
|
||||
privvy*
|
||||
images/
|
||||
__pycache__/
|
||||
hermes_agent.egg-info/
|
||||
wandb/
|
||||
testlogs
|
||||
|
||||
# CLI config (may contain sensitive SSH paths)
|
||||
cli-config.yaml
|
||||
|
||||
.DS_Store
|
||||
|
||||
# artifacts
|
||||
*.jsonl
|
||||
*.html
|
||||
*.json
|
||||
*.log
|
||||
*.csv
|
||||
|
||||
# Singularity/Apptainer images (large binary files)
|
||||
*.sif
|
||||
|
||||
# Test files
|
||||
test_singularity_*.py
|
||||
test_*.py
|
||||
!tests/test_*.py
|
||||
|
||||
# Nomad data
|
||||
/tmp/NomadClient*/
|
||||
|
||||
*.egg-info*
|
||||
wandb
|
||||
logs
|
||||
574
README.md
@@ -15,11 +15,13 @@ irm https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/ins
|
||||
```
|
||||
|
||||
The installer will:
|
||||
- Clone to `~/.hermes-agent` (with submodules: mini-swe-agent, tinker-atropos)
|
||||
- Create a virtual environment
|
||||
- Install all dependencies
|
||||
- Install [uv](https://docs.astral.sh/uv/) (fast Python package manager) if not present
|
||||
- Install Python 3.11 via uv if not already available (no sudo needed)
|
||||
- Clone to `~/.hermes/hermes-agent` (with submodules: mini-swe-agent, tinker-atropos)
|
||||
- Create a virtual environment with Python 3.11
|
||||
- Install all dependencies and submodule packages
|
||||
- Symlink `hermes` into `~/.local/bin` so it works globally (no venv activation needed)
|
||||
- Run the interactive setup wizard
|
||||
- Add `hermes` to your PATH
|
||||
|
||||
After installation, reload your shell and run:
|
||||
```bash
|
||||
@@ -64,13 +66,13 @@ You need at least one LLM provider:
|
||||
| Provider | Get Key | Env Variable |
|
||||
|----------|---------|--------------|
|
||||
| **OpenRouter** (recommended) | [openrouter.ai/keys](https://openrouter.ai/keys) | `OPENROUTER_API_KEY` |
|
||||
| Anthropic | [console.anthropic.com](https://console.anthropic.com/) | `ANTHROPIC_API_KEY` |
|
||||
| OpenAI | [platform.openai.com](https://platform.openai.com/api-keys) | `OPENAI_API_KEY` |
|
||||
|
||||
|
||||
### Optional API Keys
|
||||
|
||||
| Feature | Provider | Env Variable |
|
||||
|---------|----------|--------------|
|
||||
| Custom OpenAI Endpoint (OAI or VLLM/SGLANG) | [platform.openai.com](https://platform.openai.com/api-keys) | `OPENAI_API_KEY` |
|
||||
| Web scraping | [Firecrawl](https://firecrawl.dev/) | `FIRECRAWL_API_KEY` |
|
||||
| Browser automation | [Browserbase](https://browserbase.com/) | `BROWSERBASE_API_KEY`, `BROWSERBASE_PROJECT_ID` |
|
||||
| Image generation | [FAL](https://fal.ai/) | `FAL_KEY` |
|
||||
@@ -179,8 +181,8 @@ hermes config set terminal.singularity_image ~/python.sif
|
||||
|
||||
**Modal** (serverless cloud):
|
||||
```bash
|
||||
pip install modal boto3
|
||||
modal setup # Authenticate
|
||||
uv pip install "swe-rex[modal]" # Installs swe-rex + modal + boto3
|
||||
modal setup # Authenticate with Modal
|
||||
hermes config set terminal.backend modal
|
||||
```
|
||||
|
||||
@@ -275,16 +277,19 @@ See [docs/messaging.md](docs/messaging.md) for WhatsApp and advanced setup.
|
||||
|
||||
Train language models with reinforcement learning using the Tinker API and Atropos framework.
|
||||
|
||||
> **Note:** RL training tools require **Python 3.11+** (the upstream `tinker` package has this requirement). On Python 3.10, the RL toolset will be automatically disabled — all other features work fine.
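As an illustration of the version gate described in the note, a check along these lines would decide whether the RL toolset can load. The names `MIN_RL_PYTHON` and `rl_tools_supported` are hypothetical; the README does not show how the project performs this check.

```python
import sys

MIN_RL_PYTHON = (3, 11)


def rl_tools_supported(version_info=None):
    """Return True when the interpreter meets the 3.11+ requirement for RL tools."""
    version = sys.version_info if version_info is None else version_info
    # Compare only (major, minor); the patch level does not matter here.
    return tuple(version[:2]) >= MIN_RL_PYTHON


if __name__ == "__main__":
    print("RL toolset available:", rl_tools_supported())
```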
|
||||
|
||||
#### Requirements
|
||||
|
||||
1. **API Keys:** Add to `~/.hermes/.env`:
|
||||
1. **Python 3.11+** (check with `python3 --version`)
|
||||
2. **API Keys:** Add to `~/.hermes/.env`:
|
||||
```bash
|
||||
TINKER_API_KEY=your-tinker-key # Get from https://tinker-console.thinkingmachines.ai/keys
|
||||
WANDB_API_KEY=your-wandb-key # Get from https://wandb.ai/authorize
|
||||
OPENROUTER_API_KEY=your-key # Optional: for rl_test_inference
|
||||
```
|
||||
|
||||
2. **That's it!** tinker-atropos is included as a submodule - no separate installation needed.
|
||||
3. **That's it!** tinker-atropos is included as a submodule — the installer handles it automatically.
|
||||
|
||||
#### Using RL Tools
|
||||
|
||||
@@ -320,6 +325,94 @@ For extended RL workflows with longer timeouts:
|
||||
python rl_cli.py --model "anthropic/claude-sonnet-4-20250514"
|
||||
```
|
||||
|
||||
### 🧪 Atropos RL Environments
|
||||
|
||||
Hermes-Agent integrates with the [Atropos](https://github.com/NousResearch/atropos) RL framework through a layered environment system. This allows training models with reinforcement learning on agentic tasks using hermes-agent's tools.
|
||||
|
||||
#### Architecture
|
||||
|
||||
The integration has three layers:
|
||||
|
||||
| Layer | File | Purpose |
|-------|------|---------|
| **Agent Loop** | `environments/agent_loop.py` | Reusable multi-turn tool-calling engine (standard OpenAI spec) |
| **Base Environment** | `environments/hermes_base_env.py` | Abstract Atropos `BaseEnv` subclass with toolset resolution, ToolContext, scoring |
| **Concrete Envs** | `environments/terminal_test_env.py`, `environments/hermes_swe_env.py` | Task-specific environments |
|
||||
|
||||
#### Two-Phase Operation
|
||||
|
||||
- **Phase 1 (OpenAI server type)**: Works with any OpenAI-compatible endpoint (VLLM, SGLang, OpenRouter, OpenAI API). The server handles tool call parsing natively. Good for **SFT data generation**, **verifier testing**, and **evaluation**.
|
||||
- **Phase 2 (VLLM server type)**: Uses ManagedServer for exact token IDs + logprobs via `/generate`. Client-side tool call parser registry reconstructs structured `tool_calls` from raw output. Required for **full RL training**.
|
||||
|
||||
#### Quick Start
|
||||
|
||||
```bash
|
||||
# 1. Launch VLLM with tool parser
|
||||
vllm serve YourModel --enable-auto-tool-choice --tool-call-parser hermes
|
||||
|
||||
# 2. Start the Atropos API server
|
||||
run-api
|
||||
|
||||
# 3. Run an environment
|
||||
python environments/terminal_test_env.py serve \
|
||||
--openai.base_url http://localhost:8000/v1 \
|
||||
--openai.model_name YourModel \
|
||||
--openai.server_type openai
|
||||
```
|
||||
|
||||
#### ToolContext (Reward Functions)
|
||||
|
||||
Reward functions receive a `ToolContext` with unrestricted access to all hermes-agent tools, scoped to the rollout's sandbox:
|
||||
|
||||
```python
async def compute_reward(self, item, result, ctx: ToolContext) -> float:
    # Run tests in the model's terminal sandbox
    test = ctx.terminal("pytest -v")
    if test["exit_code"] == 0:
        return 1.0
    # Or check a file, search the web, navigate a browser...
    return 0.0
```
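A reward function of this shape can be exercised without a real sandbox by passing in a stand-in context. The sketch below assumes only what the example shows, namely that `ctx.terminal(...)` returns a mapping with an `exit_code` key. `FakeToolContext`, its `output` key, and the `score` helper are hypothetical and exist only for this illustration.

```python
import asyncio


class FakeToolContext:
    """Stand-in for ToolContext that returns a canned terminal result (hypothetical)."""

    def __init__(self, exit_code):
        self.exit_code = exit_code
        self.commands = []  # records what the reward function tried to run

    def terminal(self, command):
        self.commands.append(command)
        return {"exit_code": self.exit_code, "output": ""}


async def compute_reward(item, result, ctx):
    # Same logic as the README example, written as a free function for testing.
    test = ctx.terminal("pytest -v")
    return 1.0 if test["exit_code"] == 0 else 0.0


def score(exit_code):
    """Run the async reward function once against a fake context."""
    ctx = FakeToolContext(exit_code)
    return asyncio.run(compute_reward(item={}, result=None, ctx=ctx))
```

Calling `score(0)` simulates a passing test run and `score(1)` a failing one, which makes it easy to unit-test reward logic before launching rollouts.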
|
||||
|
||||
#### Creating Custom Environments
|
||||
|
||||
Subclass `HermesAgentBaseEnv` and implement 5 methods:
|
||||
|
||||
```python
from environments.hermes_base_env import HermesAgentBaseEnv

class MyEnv(HermesAgentBaseEnv):
    name = "my-env"
    async def setup(self): ...                                # Load data
    async def get_next_item(self): ...                        # Return next item
    def format_prompt(self, item): ...                        # Item -> prompt string
    async def compute_reward(self, item, result, ctx): ...    # Score with ToolContext
    async def evaluate(self, *args, **kwargs): ...            # Periodic eval

if __name__ == "__main__":
    MyEnv.cli()
```
|
||||
|
||||
#### Toolset Distributions
|
||||
|
||||
Configure which tools are available per group, either explicitly or probabilistically:
|
||||
|
||||
```bash
|
||||
# Explicit toolsets
|
||||
--env.enabled_toolsets '["terminal","file","web"]'
|
||||
|
||||
# Probabilistic distribution (sampled per group)
|
||||
--env.distribution development
|
||||
```
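To make "sampled per group" concrete, here is a sketch of one plausible interpretation: each toolset has an independent inclusion probability, and one draw is made per rollout group. The README does not show the actual distribution format, so the `DEVELOPMENT` mapping, its probabilities, and `sample_toolsets` are assumptions.

```python
import random

# Hypothetical distribution: toolset name -> probability of being enabled for a group.
DEVELOPMENT = {"terminal": 1.0, "file": 0.9, "web": 0.5}


def sample_toolsets(distribution, rng):
    """Include each toolset independently with its probability; call once per group."""
    return [name for name, p in distribution.items() if rng.random() < p]


if __name__ == "__main__":
    rng = random.Random(0)  # seeded so runs are reproducible
    for group in range(3):
        print(group, sample_toolsets(DEVELOPMENT, rng))
```

Because `rng.random()` returns values in `[0, 1)`, a probability of `1.0` always includes the toolset and `0.0` never does.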
|
||||
|
||||
#### Tool Call Parsers (Phase 2)
|
||||
|
||||
For VLLM server type, a parser registry extracts structured `tool_calls` from raw model output. Supported parsers: `hermes`, `mistral`, `llama3_json`, `qwen`, `deepseek_v3`, `deepseek_v3_1`, `kimi_k2`, `longcat`, `glm45`, `glm47`, `qwen3_coder`.
|
||||
|
||||
```bash
|
||||
--env.tool_call_parser hermes  # Match the parser passed to vLLM's --tool-call-parser flag
|
||||
```
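To illustrate what a client-side parser does, the sketch below extracts tool calls from raw text, assuming the Hermes convention of wrapping a JSON object with `name` and `arguments` in `<tool_call>` tags. This is a simplified stand-in, not the registry's `hermes` parser; `parse_tool_calls` and the generated `call_<n>` ids are illustrative.

```python
import json
import re

# Assumed wire format: <tool_call>{"name": ..., "arguments": {...}}</tool_call>
TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)


def parse_tool_calls(raw_output):
    """Return (content, tool_calls) in an OpenAI-like shape from raw model text."""
    tool_calls = []
    for index, match in enumerate(TOOL_CALL_RE.finditer(raw_output)):
        try:
            payload = json.loads(match.group(1))
        except json.JSONDecodeError:
            continue  # skip malformed calls rather than failing the whole rollout
        tool_calls.append({
            "id": f"call_{index}",
            "type": "function",
            "function": {
                "name": payload.get("name"),
                # OpenAI-style tool calls carry arguments as a JSON string.
                "arguments": json.dumps(payload.get("arguments", {})),
            },
        })
    # Whatever remains outside the tags is treated as the assistant's text content.
    content = TOOL_CALL_RE.sub("", raw_output).strip()
    return content, tool_calls
```

Other model families use different delimiters, which is why the parser is selected by name rather than hardcoded.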
|
||||
|
||||
### ⏰ Scheduled Tasks (Cron)
|
||||
|
||||
Schedule tasks to run automatically:
|
||||
@@ -425,26 +518,332 @@ skills/
|
||||
|
||||
## Manual Installation
|
||||
|
||||
If you prefer not to use the installer:
|
||||
If you prefer full control over the installation process (or the quick-install script doesn't suit your environment), follow these steps to set everything up by hand.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
| Requirement | Minimum Version | Check Command | Notes |
|
||||
|-------------|----------------|---------------|-------|
|
||||
| **Git** | Any recent | `git --version` | Required |
|
||||
| **Node.js** | 18+ | `node --version` | Optional — needed for browser automation tools |
|
||||
| **ripgrep** | Any | `rg --version` | Optional — faster file search in terminal tool (falls back to grep) |
|
||||
|
||||
> **Note:** Python and pip are **not** prerequisites. The installer uses [uv](https://docs.astral.sh/uv/) to provision Python 3.11 automatically (no sudo needed). If you already have Python 3.11+ installed, uv will use it.
|
||||
|
||||
<details>
|
||||
<summary><strong>Installing prerequisites by platform</strong></summary>
|
||||
|
||||
**Ubuntu / Debian:**
|
||||
```bash
|
||||
sudo apt update && sudo apt install git
|
||||
# Optional:
|
||||
sudo apt install ripgrep nodejs npm
|
||||
```
|
||||
|
||||
**macOS (Homebrew):**
|
||||
```bash
|
||||
brew install git
|
||||
# Optional:
|
||||
brew install ripgrep node
|
||||
```
|
||||
|
||||
**Windows (WSL recommended):**
|
||||
Use the [Windows Subsystem for Linux](https://learn.microsoft.com/en-us/windows/wsl/install) and follow the Ubuntu instructions above. Alternatively, use the PowerShell quick-install script at the top of this README.
|
||||
|
||||
</details>
|
||||
|
||||
---
|
||||
|
||||
### Step 1: Clone the Repository
|
||||
|
||||
Clone with `--recurse-submodules` to pull the required submodules ([mini-swe-agent](https://github.com/SWE-agent/mini-swe-agent) for the terminal tool backend and [tinker-atropos](https://github.com/nousresearch/tinker-atropos) for RL training):
|
||||
|
||||
```bash
|
||||
# Clone the repository (with submodules)
|
||||
git clone --recurse-submodules https://github.com/NousResearch/hermes-agent.git
|
||||
cd hermes-agent
|
||||
```
|
||||
|
||||
If you already cloned without `--recurse-submodules`, initialize them manually:
|
||||
```bash
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Step 2: Install uv & Create Virtual Environment
|
||||
|
||||
[uv](https://docs.astral.sh/uv/) is a fast Python package manager that can also provision Python itself. Install it and create the venv in one go:
|
||||
|
||||
```bash
|
||||
# Install uv (if not already installed)
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
# Create venv with Python 3.11 (uv downloads it if not present — no sudo needed)
|
||||
uv venv venv --python 3.11
|
||||
```
|
||||
|
||||
> **Tip:** You do **not** need to activate the venv to use `hermes`. The entry point has a hardcoded shebang pointing to the venv Python, so it works globally once symlinked (see Step 8). For installing packages, uv can target the venv directly via `VIRTUAL_ENV`.
|
||||
|
||||
---
|
||||
|
||||
### Step 3: Install Python Dependencies
|
||||
|
||||
Install the main package in editable mode with all optional extras (messaging, cron, CLI menus, modal):
|
||||
|
||||
```bash
|
||||
# Tell uv which venv to install into
|
||||
export VIRTUAL_ENV="$(pwd)/venv"
|
||||
|
||||
# Install with all extras
|
||||
uv pip install -e ".[all]"
|
||||
```
|
||||
|
||||
If you only want the core agent (no Telegram/Discord/cron support):
|
||||
```bash
|
||||
uv pip install -e "."
|
||||
```
|
||||
|
||||
<details>
|
||||
<summary><strong>Optional extras breakdown</strong></summary>
|
||||
|
||||
| Extra | What it adds | Install command |
|-------|-------------|-----------------|
| `all` | Everything below | `uv pip install -e ".[all]"` |
| `messaging` | Telegram & Discord gateway | `uv pip install -e ".[messaging]"` |
| `cron` | Cron expression parsing for scheduled tasks | `uv pip install -e ".[cron]"` |
| `cli` | Terminal menu UI for setup wizard | `uv pip install -e ".[cli]"` |
| `modal` | Modal cloud execution backend (swe-rex + modal + boto3) | `uv pip install -e ".[modal]"` |
| `dev` | pytest & test utilities | `uv pip install -e ".[dev]"` |
|
||||
|
||||
You can combine extras: `uv pip install -e ".[messaging,cron]"`
|
||||
|
||||
</details>
|
||||
|
||||
---
|
||||
|
||||
### Step 4: Install Submodule Packages
|
||||
|
||||
These are local packages checked out as Git submodules. Install them in editable mode:
|
||||
|
||||
```bash
|
||||
# Terminal tool backend (required for the terminal/command-execution tool)
|
||||
uv pip install -e "./mini-swe-agent"
|
||||
|
||||
# RL training backend
|
||||
uv pip install -e "./tinker-atropos"
|
||||
```
|
||||
|
||||
Both are optional — if you skip them, the corresponding toolsets simply won't be available.
|
||||
|
||||
---
|
||||
|
||||
### Step 5: Install Node.js Dependencies (Optional)
|
||||
|
||||
Only needed if you plan to use the **browser automation** toolset (Browserbase-powered):
|
||||
|
||||
```bash
|
||||
npm install
|
||||
```
|
||||
|
||||
This installs the `agent-browser` package defined in `package.json`. Skip this step if you don't need browser tools.
|
||||
|
||||
---
|
||||
|
||||
### Step 6: Create the Configuration Directory
|
||||
|
||||
Hermes stores all user configuration in `~/.hermes/`:
|
||||
|
||||
```bash
|
||||
# Create the directory structure
|
||||
mkdir -p ~/.hermes/{cron,sessions,logs}
|
||||
|
||||
# Copy the example config file
|
||||
cp cli-config.yaml.example ~/.hermes/config.yaml
|
||||
|
||||
# Create an empty .env file for API keys
|
||||
touch ~/.hermes/.env
|
||||
```
|
||||
|
||||
Your `~/.hermes/` directory should now look like:
|
||||
```
~/.hermes/
├── config.yaml     # Agent settings (model, terminal, toolsets, compression, etc.)
├── .env            # API keys and secrets (one per line: KEY=value)
├── cron/           # Scheduled job data
├── sessions/       # Messaging gateway sessions
└── logs/           # Conversation logs
```
|
||||
|
||||
---
|
||||
|
||||
### Step 7: Add Your API Keys
|
||||
|
||||
Open `~/.hermes/.env` in your editor and add at minimum an LLM provider key:
|
||||
|
||||
```bash
|
||||
# Required — at least one LLM provider:
|
||||
OPENROUTER_API_KEY=sk-or-v1-your-key-here
|
||||
|
||||
# Optional — enable additional tools:
|
||||
FIRECRAWL_API_KEY=fc-your-key # Web search & scraping
|
||||
BROWSERBASE_API_KEY=bb-your-key # Browser automation
|
||||
BROWSERBASE_PROJECT_ID=your-project-id # Browser automation
|
||||
FAL_KEY=your-fal-key # Image generation (FLUX)
|
||||
TINKER_API_KEY=your-tinker-key # RL training
|
||||
WANDB_API_KEY=your-wandb-key # RL training metrics
|
||||
|
||||
# Optional — messaging gateway:
|
||||
TELEGRAM_BOT_TOKEN=123456:ABC-DEF # From @BotFather
|
||||
TELEGRAM_ALLOWED_USERS=your-user-id # Comma-separated
|
||||
DISCORD_BOT_TOKEN=MTIz... # From Developer Portal
|
||||
DISCORD_ALLOWED_USERS=your-user-id # Comma-separated
|
||||
```
|
||||
|
||||
Or set them one at a time via the CLI:
|
||||
```bash
|
||||
hermes config set OPENROUTER_API_KEY sk-or-v1-your-key-here
|
||||
```
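To check a hand-edited `~/.hermes/.env` before launching, a small script can confirm that an LLM provider key is present. The sketch below parses the `KEY=value` format with trailing `#` comments as used in this step. It is a simplified parser with no quoting rules and is not how Hermes itself loads the file; `parse_env_text` and `has_llm_key` are hypothetical names.

```python
def parse_env_text(text):
    """Parse KEY=value lines, skipping blank lines and full-line comments."""
    values = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # Treat " #" as the start of an inline comment, as in the example above.
        value = value.split(" #", 1)[0].strip()
        values[key.strip()] = value
    return values


def has_llm_key(values):
    """True when a non-empty OpenRouter key is configured."""
    return bool(values.get("OPENROUTER_API_KEY"))


if __name__ == "__main__":
    from pathlib import Path

    env_path = Path.home() / ".hermes" / ".env"
    text = env_path.read_text() if env_path.exists() else ""
    print("LLM key configured:", has_llm_key(parse_env_text(text)))
```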
|
||||
|
||||
---
|
||||
|
||||
### Step 8: Add `hermes` to Your PATH
|
||||
|
||||
The `hermes` entry point at `venv/bin/hermes` has a hardcoded shebang pointing to the venv's Python, so it works **without activating the venv**. The recommended approach is a symlink into `~/.local/bin` (most distributions already have this on PATH):
|
||||
|
||||
```bash
|
||||
mkdir -p ~/.local/bin
|
||||
ln -sf "$(pwd)/venv/bin/hermes" ~/.local/bin/hermes
|
||||
```
|
||||
|
||||
If `~/.local/bin` isn't on your PATH yet, add it:
|
||||
|
||||
**Bash** (`~/.bashrc`):
|
||||
```bash
|
||||
echo '' >> ~/.bashrc
|
||||
echo '# Hermes Agent' >> ~/.bashrc
|
||||
echo 'export PATH="$HOME/.local/bin:$PATH"' >> ~/.bashrc
|
||||
source ~/.bashrc
|
||||
```
|
||||
|
||||
**Zsh** (`~/.zshrc`):
|
||||
```bash
|
||||
echo '' >> ~/.zshrc
|
||||
echo '# Hermes Agent' >> ~/.zshrc
|
||||
echo 'export PATH="$HOME/.local/bin:$PATH"' >> ~/.zshrc
|
||||
source ~/.zshrc
|
||||
```
|
||||
|
||||
**Fish** (`~/.config/fish/config.fish`):
|
||||
```fish
|
||||
fish_add_path $HOME/.local/bin
|
||||
```
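Before editing a shell profile, it can help to confirm whether `~/.local/bin` is already on `PATH`. A minimal sketch follows; `dir_on_path` is a hypothetical helper, and it compares normalized path strings without resolving symlinks.

```python
import os


def dir_on_path(directory, path_value=None):
    """Return True when `directory` appears as an entry in PATH."""
    path_value = os.environ.get("PATH", "") if path_value is None else path_value
    target = os.path.normpath(os.path.expanduser(directory))
    entries = [
        os.path.normpath(os.path.expanduser(entry))
        for entry in path_value.split(os.pathsep)
        if entry
    ]
    return target in entries


if __name__ == "__main__":
    print("~/.local/bin on PATH:", dir_on_path("~/.local/bin"))
```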
|
||||
|
||||
---
|
||||
|
||||
### Step 9: Run the Setup Wizard (Optional)
|
||||
|
||||
The interactive setup wizard walks you through configuring your API keys and preferences:
|
||||
|
||||
```bash
|
||||
hermes setup
|
||||
```
|
||||
|
||||
This is optional if you already configured `~/.hermes/.env` and `~/.hermes/config.yaml` manually in the steps above.
|
||||
|
||||
---
|
||||
|
||||
### Step 10: Verify the Installation
|
||||
|
||||
```bash
|
||||
# Check that the command is available
|
||||
hermes version
|
||||
|
||||
# Run diagnostics to verify everything is working
|
||||
hermes doctor
|
||||
|
||||
# Check your configuration
|
||||
hermes status
|
||||
|
||||
# Test with a quick query
|
||||
hermes chat -q "Hello! What tools do you have available?"
|
||||
```
|
||||
|
||||
If `hermes doctor` reports issues, it will tell you exactly what's missing and how to fix it.
|
||||
|
||||
---
|
||||
|
||||
### Quick-Reference: Manual Install (Condensed)
|
||||
|
||||
For those who just want the commands without the explanations:
|
||||
|
||||
```bash
|
||||
# Install uv (if not already installed)
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
# Clone & enter
|
||||
git clone --recurse-submodules https://github.com/NousResearch/hermes-agent.git
|
||||
cd hermes-agent
|
||||
|
||||
# Run setup script
|
||||
./setup-hermes.sh
|
||||
# Create venv with Python 3.11 (uv downloads it if needed)
|
||||
uv venv venv --python 3.11
|
||||
export VIRTUAL_ENV="$(pwd)/venv"
|
||||
|
||||
# Or manually:
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install -e ".[all]"
|
||||
# Install everything
|
||||
uv pip install -e ".[all]"
|
||||
uv pip install -e "./mini-swe-agent"
|
||||
uv pip install -e "./tinker-atropos"
|
||||
npm install # optional, for browser tools
|
||||
|
||||
# Install submodules (required for terminal and RL tools)
|
||||
pip install -e "./mini-swe-agent" # Terminal tool backend
|
||||
pip install -e "./tinker-atropos" # RL training backend
|
||||
# Configure
|
||||
mkdir -p ~/.hermes/{cron,sessions,logs}
|
||||
cp cli-config.yaml.example ~/.hermes/config.yaml
|
||||
touch ~/.hermes/.env
|
||||
echo 'OPENROUTER_API_KEY=sk-or-v1-your-key' >> ~/.hermes/.env
|
||||
|
||||
hermes setup
|
||||
# Make hermes available globally (no venv activation needed)
|
||||
mkdir -p ~/.local/bin
|
||||
ln -sf "$(pwd)/venv/bin/hermes" ~/.local/bin/hermes
|
||||
|
||||
# Verify
|
||||
hermes doctor
|
||||
hermes
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Updating a Manual Installation
|
||||
|
||||
To update an existing manual install to the latest version:
|
||||
|
||||
```bash
|
||||
cd /path/to/hermes-agent
|
||||
export VIRTUAL_ENV="$(pwd)/venv"
|
||||
|
||||
# Pull latest code and submodules
|
||||
git pull origin main
|
||||
git submodule update --init --recursive
|
||||
|
||||
# Reinstall (picks up new dependencies)
|
||||
uv pip install -e ".[all]"
|
||||
uv pip install -e "./mini-swe-agent"
|
||||
uv pip install -e "./tinker-atropos"
|
||||
|
||||
# Check for new config options added since your last update
|
||||
hermes config check
|
||||
hermes config migrate # Interactively add any missing options
|
||||
```
|
||||
|
||||
### Uninstalling a Manual Installation
|
||||
|
||||
```bash
|
||||
# Remove the hermes symlink
|
||||
rm -f ~/.local/bin/hermes
|
||||
|
||||
# Remove the cloned repository
|
||||
rm -rf /path/to/hermes-agent
|
||||
|
||||
# Remove user configuration (optional — keep if you plan to reinstall)
|
||||
rm -rf ~/.hermes
|
||||
```
|
||||
|
||||
---
|
||||
@@ -596,6 +995,137 @@ All variables go in `~/.hermes/.env`. Run `hermes config set VAR value` to set t
|
||||
|
||||
---
|
||||
|
||||
## RL Training with Tinker
|
||||
|
||||
Hermes-Agent includes an RL training integration with [Tinker](https://thinkingmachines.ai/tinker/) (Thinking Machines) and [Atropos](https://github.com/NousResearch/atropos) for training language models with reinforcement learning from agent trajectories.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
1. **Install with Atropos extras** (includes Tinker SDK, atroposlib, torch, wandb):
|
||||
```bash
|
||||
pip install -e ".[atropos]"
|
||||
```
|
||||
|
||||
2. **Initialize the tinker-atropos submodule**:
|
||||
```bash
|
||||
git submodule update --init
|
||||
pip install -e ./tinker-atropos
|
||||
```
|
||||
|
||||
3. **Get API keys**:
|
||||
- `TINKER_API_KEY` from [Tinker Console](https://tinker-console.thinkingmachines.ai/keys) (requires billing setup)
|
||||
- `WANDB_API_KEY` from [Weights & Biases](https://wandb.ai/settings) (for metrics tracking)
|
||||
|
||||
4. **Add keys to your `.env` file**:
|
||||
```bash
|
||||
# Add to .env or ~/.hermes/.env
|
||||
TINKER_API_KEY=your_tinker_key
|
||||
WANDB_API_KEY=your_wandb_key
|
||||
```
|
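A quick way to confirm that both keys are present before launching training is to parse the file and look for them. This is a rough sketch that assumes the `.env` file contains only plain `KEY=VALUE` lines and `#` comments (no `export` prefixes or inline comments); `parse_env_lines` and `missing_keys` are hypothetical helpers, not part of Hermes or Tinker.

```python
from typing import Dict, Iterable, List

# The two keys named in the prerequisites above.
REQUIRED_KEYS = ("TINKER_API_KEY", "WANDB_API_KEY")


def parse_env_lines(lines: Iterable[str]) -> Dict[str, str]:
    """Parse simple KEY=VALUE lines, skipping blanks and # comments."""
    values: Dict[str, str] = {}
    for raw in lines:
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # Drop surrounding single or double quotes, if any.
        values[key.strip()] = value.strip().strip('"').strip("'")
    return values


def missing_keys(values: Dict[str, str]) -> List[str]:
    """Return required keys that are absent or empty."""
    return [k for k in REQUIRED_KEYS if not values.get(k)]
```

For example, `missing_keys(parse_env_lines(open(".env")))` returns an empty list when both keys are set.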
||||
|
||||
### Architecture
|
||||
|
||||
The RL training pipeline uses three processes that communicate over HTTP:
|
||||
|
||||
```
|
||||
┌──────────────────────┐ ┌─────────────────────┐ ┌────────────────────────┐
|
||||
│ Atropos Rollout API │ │ Tinker Trainer │ │ Environment │
|
||||
│ (port 8000) │◄──│ (port 8001) │◄──│ (worker) │
|
||||
│ │ │ │ │ │
|
||||
│ • Collects batches │ │ • LoRA training │ │ • Generates prompts │
|
||||
│ • Coordinates env │ │ • Inference server │ │ • Calls inference API │
|
||||
│ and trainer │ │ • Weight updates │ │ • Scores responses │
|
||||
│ │ │ • WandB logging │ │ • Sends scored batches │
|
||||
└──────────────────────┘ └─────────────────────┘ └────────────────────────┘
|
||||
```
|
||||
|
||||
### Quick Start: GSM8k Agent Training
|
||||
|
||||
This example trains a model on GSM8k math problems using a Python REPL tool. The model learns to write and execute Python code to solve each problem:
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start Atropos Rollout API
|
||||
cd tinker-atropos
|
||||
source ../.venv/bin/activate
|
||||
set -a && source ../.env && set +a
|
||||
run-api
|
||||
|
||||
# Terminal 2: Start Tinker Trainer + Inference Server
|
||||
cd tinker-atropos
|
||||
source ../.venv/bin/activate
|
||||
set -a && source ../.env && set +a
|
||||
python launch_training.py --config configs/gsm8k_agent.yaml
|
||||
|
||||
# Terminal 3: Start GSM8k Agent Environment
|
||||
cd tinker-atropos
|
||||
source ../.venv/bin/activate
|
||||
set -a && source ../.env && set +a
|
||||
python tinker_atropos/environments/gsm8k_agent.py serve --config configs/gsm8k_agent.yaml
|
||||
```
|
||||
|
||||
### Available Environments
|
||||
|
||||
| Environment | File | Description |
|
||||
|------------|------|-------------|
|
||||
| `gsm8k` | `gsm8k_tinker.py` | Standard GSM8k math (no tools) |
|
||||
| `gsm8k_agent` | `gsm8k_agent.py` | GSM8k with Python REPL tool calling |
|
||||
|
||||
### Configuration
|
||||
|
||||
Configs are YAML files in `tinker-atropos/configs/` with three sections:
|
||||
|
||||
```yaml
|
||||
env: # Atropos environment settings
|
||||
group_size: 4 # Parallel rollouts per problem
|
||||
batch_size: 16 # Training batch size
|
||||
tokenizer_name: "Qwen/Qwen3-4B-Instruct-2507"
|
||||
max_token_length: 2048 # Max generation length
|
||||
total_steps: 20 # Training steps
|
||||
|
||||
openai: # Inference server (served by Tinker trainer)
|
||||
- model_name: "Qwen/Qwen3-4B-Instruct-2507"
|
||||
base_url: "http://localhost:8001/v1"
|
||||
|
||||
tinker: # Tinker training parameters
|
||||
lora_rank: 16 # LoRA rank (lower = faster, less capacity)
|
||||
learning_rate: 0.00005 # Learning rate
|
||||
max_token_trainer_length: 4096 # Max tokens for training
|
||||
wandb_project: "hermes-agent-rl"
|
||||
```
|
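The three-section layout can be sanity-checked before a run. The sketch below assumes the YAML has already been loaded into a plain dictionary by a YAML parser of your choice; `check_config` is a hypothetical helper, and the checks it performs are illustrative assumptions, not what `launch_training.py` actually validates.

```python
from typing import Any, Dict, List

# Section names taken from the example config above.
EXPECTED_SECTIONS = ("env", "openai", "tinker")


def check_config(config: Dict[str, Any]) -> List[str]:
    """Return a list of human-readable problems; empty means the shape looks right."""
    problems: List[str] = []
    for section in EXPECTED_SECTIONS:
        if section not in config:
            problems.append(f"missing section: {section}")
    # `openai` is a list of inference server entries in the example config.
    servers = config.get("openai", [])
    if not isinstance(servers, list):
        problems.append("openai should be a list of server entries")
    else:
        for i, server in enumerate(servers):
            for key in ("model_name", "base_url"):
                if key not in server:
                    problems.append(f"openai[{i}] missing {key}")
    return problems


example = {
    "env": {"group_size": 4, "batch_size": 16, "total_steps": 20},
    "openai": [{"model_name": "Qwen/Qwen3-4B-Instruct-2507",
                "base_url": "http://localhost:8001/v1"}],
    "tinker": {"lora_rank": 16, "learning_rate": 0.00005},
}
```

Calling `check_config(example)` returns an empty list. Removing a section or a `base_url` produces a message naming what is missing.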
||||
|
||||
### RL CLI (Agent-Driven Training)
|
||||
|
||||
For interactive training management via the Hermes agent:
|
||||
|
||||
```bash
|
||||
# Interactive mode - let the agent manage training
|
||||
python rl_cli.py --interactive
|
||||
|
||||
# List available environments
|
||||
python rl_cli.py --list-environments
|
||||
|
||||
# Direct task
|
||||
python rl_cli.py "Train a model on GSM8k with tool use"
|
||||
```
|
||||
|
||||
### Sandbox Backends for Agent Environments
|
||||
|
||||
For agent environments that need isolated tool execution (e.g., SWE tasks), Hermes-Agent supports multiple sandbox backends:
|
||||
|
||||
| Backend | Use Case | Command |
|
||||
|---------|----------|---------|
|
||||
| **Nomad + Docker** | Default, local development | `--env.tool_pool_mode nomad` |
|
||||
| **Nomad + Singularity** | HPC clusters without Docker | `--env.tool_pool_mode nomad --env.driver singularity` |
|
||||
| **Modal** | Cloud-based, auto-scaling | `--env.tool_pool_mode modal` |
|
||||
|
||||
See [docs/MODAL_BACKEND.md](docs/MODAL_BACKEND.md) for Modal backend details.
|
||||
|
||||
### Cost
|
||||
|
||||
Check the [Tinker Rate Card](https://tinker-console.thinkingmachines.ai/rate-card) for available models and pricing.
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
```bash
|
||||
|
||||
Binary file not shown.
Binary file not shown.
41
atropos/Dockerfile
Normal file
@@ -0,0 +1,41 @@
|
||||
# Dockerfile for atropos-agent sandbox server
|
||||
# Runs inside Nomad containers to handle tool execution
|
||||
# Includes bubblewrap for namespace-based slot isolation
|
||||
|
||||
FROM python:3.11-slim
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
# Bubblewrap for namespace isolation
|
||||
bubblewrap \
|
||||
# `script` for PTY allocation (used for stable tmux+asciinema startup)
|
||||
util-linux \
|
||||
# Git for SWE-style tasks (cloning repos)
|
||||
git \
|
||||
# tmux for stateful terminal sessions (Phase 4.7+)
|
||||
tmux \
|
||||
# Common tools agents might need
|
||||
curl \
|
||||
wget \
|
||||
jq \
|
||||
# Cleanup
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Python dependencies (sandbox server + optional terminal recording)
|
||||
RUN pip install --no-cache-dir aiohttp asciinema
|
||||
|
||||
# Copy the sandbox server
|
||||
COPY sandbox_server.py /app/sandbox_server.py
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Create data directory for slot workspaces
|
||||
RUN mkdir -p /data
|
||||
|
||||
# Verify bubblewrap is installed and working
|
||||
RUN bwrap --version
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
# Default command - can be overridden by Nomad job spec
|
||||
CMD ["python", "sandbox_server.py", "--port", "8080", "--slots", "10", "--data-dir", "/data"]
|
||||
47
atropos/__init__.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
Atropos integration for Hermes-Agent.
|
||||
|
||||
This package is intentionally optional: Hermes-Agent should work without Atropos.
|
||||
If you import anything from `atropos.*` without having `atroposlib` installed,
|
||||
we raise a clear error with install instructions.
|
||||
|
||||
Install (recommended, from repo checkout):
|
||||
uv sync --extra atropos
|
||||
|
||||
Or (pip / editable):
|
||||
pip install -e '.[atropos]'
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def _require_atroposlib() -> None:
|
||||
try:
|
||||
import atroposlib # noqa: F401
|
||||
except ModuleNotFoundError as exc: # pragma: no cover
|
||||
raise ModuleNotFoundError(
|
||||
"Hermes-Agent Atropos integration requires `atroposlib`, but it is not installed.\n"
|
||||
"Install it with:\n"
|
||||
" uv sync --extra atropos\n"
|
||||
"or:\n"
|
||||
" pip install -e '.[atropos]'\n"
|
||||
) from exc
|
||||
|
||||
|
||||
_require_atroposlib()
|
||||
|
||||
# Re-export the most commonly used pieces for convenience.
|
||||
# Agent imports are eager (always available).
|
||||
from .agent import AgentConfig, AgentResult, AgentStep, AtroposAgent, SequenceData # noqa: E402
|
||||
|
||||
# Env imports are lazy to avoid pulling in deleted atropos.tools dependencies.
|
||||
# Use: from atropos.envs import AgentEnv, AgentEnvConfig (if needed)
|
||||
|
||||
__all__ = [
|
||||
"AtroposAgent",
|
||||
"AgentConfig",
|
||||
"AgentResult",
|
||||
"AgentStep",
|
||||
"SequenceData",
|
||||
]
|
||||
|
||||
15
atropos/agent/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Agent abstractions for atropos-agent.
|
||||
|
||||
Provides the core AtroposAgent class for running ReACT-style agent loops.
|
||||
"""
|
||||
|
||||
from .atropos_agent import AgentConfig, AgentResult, AgentStep, AtroposAgent, SequenceData
|
||||
|
||||
__all__ = [
|
||||
"AtroposAgent",
|
||||
"AgentConfig",
|
||||
"AgentResult",
|
||||
"AgentStep",
|
||||
"SequenceData",
|
||||
]
|
||||
850
atropos/agent/atropos_agent.py
Normal file
@@ -0,0 +1,850 @@
|
||||
"""
|
||||
ReACT-style agent implementation for atropos-agent.
|
||||
|
||||
This module provides the core AtroposAgent class that implements a basic
|
||||
Reason-Act-Observe loop with tool calling capabilities.
|
||||
|
||||
Uses ManagedServer from atroposlib for automatic token/logprob tracking,
|
||||
making trajectories ready for RL training.
|
||||
|
||||
The agent uses Hermes-style XML tags for tool calls:
|
||||
- <think>...</think> for reasoning
|
||||
- <tool_call>{"name": "...", "arguments": {...}}</tool_call> for actions
|
||||
- <tool_response>...</tool_response> for observations
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from uuid import uuid4
|
||||
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Optional, Union
|
||||
|
||||
from dotenv import load_dotenv
|
||||
import httpx
|
||||
|
||||
from ..tools import ToolCall, ToolRegistry, ToolResult
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# Default system prompt with tool calling instructions.
|
||||
AGENT_SYSTEM_PROMPT = """You are a deep thinking AI. You MUST enclose your internal reasoning inside <think>...</think> tags.
|
||||
|
||||
You are a function calling AI model.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags.
|
||||
You must call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.
|
||||
You can ONLY respond without a tool call if you are totally certain you have the final answer to the user's question or task
|
||||
After calling & executing a function, you will be provided with function results within <tool_response></tool_response> XML tags.
|
||||
|
||||
Here are the available tools:
|
||||
<tools>
|
||||
{tools_json}
|
||||
</tools>
|
||||
|
||||
Use the following JSON schema for each tool call you will make:
|
||||
{"title": "FunctionCall", "type": "object", "properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"]}
|
||||
|
||||
## REQUIRED TOOL FORMAT
|
||||
|
||||
When you decide to call a tool, your assistant message MUST be:
|
||||
1) exactly one <think>...</think> block, followed by
|
||||
2) one or more <tool_call>...</tool_call> blocks,
|
||||
and NOTHING else in that message.
|
||||
|
||||
If you need to explain anything, put it inside <think>. Do NOT write natural language outside <think> or <tool_call>.
|
||||
|
||||
For each function call return a JSON object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
||||
<tool_call>
|
||||
{"name": "<function-name>", "arguments": {"arg1": "value1"}}
|
||||
</tool_call>
|
||||
|
||||
Each <tool_call> must be on its own and contain ONLY the JSON object (no extra text).
|
||||
The JSON inside <tool_call> MUST be valid JSON with double quotes.
|
||||
|
||||
Do NOT output <tool_response> in an assistant message.
|
||||
|
||||
After you receive tool results, you may either call more tools (same required format) or provide the final answer.
|
||||
When providing the final answer, do NOT include any <tool_call> blocks.
|
||||
|
||||
## TERMINAL TOOL NOTES
|
||||
|
||||
- Commands execute under POSIX `/bin/sh` (not bash).
|
||||
- Each tool call runs in a fresh shell: environment changes (like `cd` or venv activation) do not persist across tool calls.
|
||||
- Avoid bash-only features like `source`, `[[ ... ]]`, or process substitution.
|
||||
- Prefer explicit venv usage:
|
||||
- `python -m venv .venv && . .venv/bin/activate && python -m pip install -e .` (POSIX `.` activation), or
|
||||
- `.venv/bin/python -m pip install -e .` (no activation required).
|
||||
|
||||
## ICL (examples)
|
||||
|
||||
User: Show the current directory.
|
||||
Assistant:
|
||||
<think>I should run pwd.</think>
|
||||
<tool_call>
|
||||
{"name": "terminal", "arguments": {"command": "pwd"}}
|
||||
</tool_call>
|
||||
User: <tool_response>{"success": true, "output": "/tmp\\n"}</tool_response>
|
||||
Assistant: /tmp
|
||||
|
||||
User: List files, then count them.
|
||||
Assistant:
|
||||
<think>I should count files.</think>
|
||||
<tool_call>
|
||||
{"name": "terminal", "arguments": {"command": "ls -1 | wc -l"}}
|
||||
</tool_call>
|
||||
User: <tool_response>{"success": true, "output": "3\\n"}</tool_response>
|
||||
Assistant: 3
|
||||
|
||||
User: Run pwd, then print ok (two tool calls).
|
||||
Assistant:
|
||||
<think>I should run two commands.</think>
|
||||
<tool_call>
|
||||
{"name": "terminal", "arguments": {"command": "pwd"}}
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
{"name": "terminal", "arguments": {"command": "echo ok"}}
|
||||
</tool_call>
|
||||
User: <tool_response>{"success": true, "output": "/tmp\\n"}</tool_response>
|
||||
User: <tool_response>{"success": true, "output": "ok\\n"}</tool_response>
|
||||
Assistant: ok
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentConfig:
|
||||
"""Configuration for the AtroposAgent."""
|
||||
|
||||
# Generation parameters
|
||||
temperature: Optional[float] = 0.7
|
||||
# Default to "let the backend decide" (important for tool-tag completions that may be longer).
|
||||
max_tokens: Optional[int] = None
|
||||
|
||||
# Agent behavior
|
||||
max_steps: int = 50
|
||||
system_prompt: Optional[str] = None
|
||||
tool_delay_s: float = 0.0
|
||||
|
||||
# Working directory for tools
|
||||
working_dir: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SequenceData:
|
||||
"""Token/logprob data from a single completion."""
|
||||
|
||||
full_text: str
|
||||
tokens: List[int]
|
||||
masked_tokens: List[int] # -100 for prompt, actual IDs for completion
|
||||
logprobs: List[float] # 1.0 for prompt, actual values for completion
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
@classmethod
|
||||
def from_sequence_node(cls, node) -> "SequenceData":
|
||||
"""Create from a ManagedServer SequenceNode."""
|
||||
return cls(
|
||||
full_text=node.full_text,
|
||||
tokens=node.tokens,
|
||||
masked_tokens=node.masked_tokens,
|
||||
logprobs=node.logprobs,
|
||||
metadata=getattr(node, "metadata", None),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentStep:
|
||||
"""A single step in the agent's trajectory."""
|
||||
|
||||
step_number: int
|
||||
assistant_message: str
|
||||
tool_calls: List[ToolCall] = field(default_factory=list)
|
||||
tool_results: List[ToolResult] = field(default_factory=list)
|
||||
sequence_data: Optional[SequenceData] = None # Token data from this step
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
return len(self.tool_calls) > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResult:
|
||||
"""Result of running an agent trajectory."""
|
||||
|
||||
success: bool
|
||||
final_response: str
|
||||
steps: List[AgentStep] = field(default_factory=list)
|
||||
total_tokens: int = 0
|
||||
error: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Full trajectory token data for RL training
|
||||
trajectory_data: Optional[SequenceData] = None
|
||||
|
||||
@property
|
||||
def num_steps(self) -> int:
|
||||
return len(self.steps)
|
||||
|
||||
@property
|
||||
def total_tool_calls(self) -> int:
|
||||
return sum(len(step.tool_calls) for step in self.steps)
|
||||
|
||||
def to_messages(self) -> List[Dict[str, str]]:
|
||||
"""Convert trajectory to messages format for logging."""
|
||||
messages = []
|
||||
for step in self.steps:
|
||||
messages.append({"role": "assistant", "content": step.assistant_message})
|
||||
if step.tool_results:
|
||||
# Combine all tool responses
|
||||
responses = "\n".join(r.to_xml() for r in step.tool_results)
|
||||
messages.append({"role": "user", "content": responses})
|
||||
return messages
|
||||
|
||||
def to_scored_data(self, score: float) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Convert to format suitable for ScoredDataGroup.
|
||||
|
||||
Args:
|
||||
score: The score for this trajectory
|
||||
|
||||
Returns:
|
||||
Dict with tokens, masks, scores suitable for training, or None if no data
|
||||
"""
|
||||
if self.trajectory_data is None:
|
||||
return None
|
||||
|
||||
return {
|
||||
"tokens": self.trajectory_data.tokens,
|
||||
"masks": self.trajectory_data.masked_tokens,
|
||||
"scores": score,
|
||||
"logprobs": self.trajectory_data.logprobs,
|
||||
}
|
||||
|
||||
|
||||
class AtroposAgent:
|
||||
"""
|
||||
A ReACT-style agent that uses LLMs with tool calling.
|
||||
|
||||
This implementation wraps ManagedServer for automatic token/logprob tracking,
|
||||
making trajectories ready for RL training.
|
||||
|
||||
Example:
|
||||
# `server` may be an Atropos `ServerManager` (recommended) or a single `APIServer`.
|
||||
# In practice, environments usually construct this via `BaseEnv`.
|
||||
server = ...
|
||||
tools = ToolRegistry()
|
||||
tools.register(BashTool())
|
||||
|
||||
agent = AtroposAgent(server=server, tools=tools)
|
||||
result = await agent.run("List the files in the current directory")
|
||||
|
||||
# Access token data for training
|
||||
if result.trajectory_data:
|
||||
print(f"Tokens: {result.trajectory_data.tokens}")
|
||||
print(f"Masked: {result.trajectory_data.masked_tokens}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server, # ServerManager or APIServer
|
||||
tools: Optional[ToolRegistry] = None,
|
||||
config: Optional[AgentConfig] = None,
|
||||
tokenizer: Optional[Any] = None,
|
||||
execute_tool: Optional[Callable[[ToolCall], Awaitable[ToolResult]]] = None,
|
||||
):
|
||||
self.server = server
|
||||
self.tools = tools or ToolRegistry()
|
||||
self.config = config or AgentConfig()
|
||||
self.tokenizer = tokenizer or getattr(server, "tokenizer", None)
|
||||
self.execute_tool = execute_tool or self.tools.execute
|
||||
|
||||
@asynccontextmanager
|
||||
async def _managed(self) -> AsyncGenerator[Any, None]:
|
||||
"""
|
||||
Yield a ManagedServer-like object.
|
||||
|
||||
- If `self.server` is a ServerManager, use its `managed_server()` context manager.
|
||||
- If `self.server` is a single APIServer, wrap it in `ManagedServer` directly.
|
||||
"""
|
||||
if os.getenv("ATROPOS_BYPASS_MANAGED_SERVER") == "1":
|
||||
yield _DirectChatCompletionClient(server=self.server)
|
||||
return
|
||||
if hasattr(self.server, "managed_server"):
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
yield managed
|
||||
else:
|
||||
managed = ManagedServer(server=self.server, tokenizer=self.tokenizer)
|
||||
try:
|
||||
yield managed
|
||||
finally:
|
||||
managed.reset()
|
||||
|
||||
def _build_system_prompt(self) -> str:
|
||||
"""Build the system prompt with tool descriptions."""
|
||||
if self.config.system_prompt:
|
||||
return self.config.system_prompt
|
||||
|
||||
tools_json = self.tools.get_prompt_tool_definitions_json()
|
||||
# Avoid `str.format()` here because the prompt contains many literal `{}` braces
|
||||
# in JSON examples; we only want to substitute the single `{tools_json}` token.
|
||||
return AGENT_SYSTEM_PROMPT.replace("{tools_json}", tools_json)
|
||||
|
||||
def _infer_server_model_for_debug(self) -> Optional[str]:
|
||||
"""
|
||||
Best-effort inference of the configured model name for debug payload saving.
|
||||
|
||||
ManagedServer/server_manager typically injects `model` internally, so `chat_kwargs`
|
||||
may not contain it. For replaying saved payloads via curl, it's useful to persist it.
|
||||
"""
|
||||
servers = getattr(self.server, "servers", None)
|
||||
if isinstance(servers, list) and servers:
|
||||
s0 = servers[0]
|
||||
cfg = getattr(s0, "config", None)
|
||||
model = getattr(cfg, "model_name", None) or getattr(s0, "model_name", None)
|
||||
if isinstance(model, str) and model:
|
||||
return model
|
||||
model = getattr(self.server, "model_name", None) or getattr(self.server, "model", None)
|
||||
if isinstance(model, str) and model:
|
||||
return model
|
||||
return None
|
||||
|
||||
def _infer_server_base_url_for_debug(self) -> Optional[str]:
|
||||
"""
|
||||
Best-effort inference of the configured base_url for debug logging.
|
||||
|
||||
This is helpful when diagnosing hangs / retries at the transport layer.
|
||||
"""
|
||||
servers = getattr(self.server, "servers", None)
|
||||
if isinstance(servers, list) and servers:
|
||||
s0 = servers[0]
|
||||
cfg = getattr(s0, "config", None)
|
||||
base_url = getattr(cfg, "base_url", None) or getattr(s0, "base_url", None)
|
||||
if isinstance(base_url, str) and base_url:
|
||||
return base_url
|
||||
base_url = getattr(self.server, "base_url", None)
|
||||
if isinstance(base_url, str) and base_url:
|
||||
return base_url
|
||||
return None
|
||||
|
||||
def _extract_response_metadata(self, response: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract lightweight, JSON-serializable metadata from an OpenAI-style response.
|
||||
|
||||
This is useful for debugging training runs, especially when ManagedServer state
|
||||
tracking is unavailable (e.g. OpenAI-compatible chat endpoints).
|
||||
"""
|
||||
meta: Dict[str, Any] = {}
|
||||
try:
|
||||
rid = getattr(response, "id", None)
|
||||
if isinstance(rid, str) and rid:
|
||||
meta["id"] = rid
|
||||
model = getattr(response, "model", None)
|
||||
if isinstance(model, str) and model:
|
||||
meta["model"] = model
|
||||
created = getattr(response, "created", None)
|
||||
if isinstance(created, int):
|
||||
meta["created"] = created
|
||||
system_fingerprint = getattr(response, "system_fingerprint", None)
|
||||
if isinstance(system_fingerprint, str) and system_fingerprint:
|
||||
meta["system_fingerprint"] = system_fingerprint
|
||||
|
||||
choices = getattr(response, "choices", None)
|
||||
if isinstance(choices, list) and choices:
|
||||
fr = getattr(choices[0], "finish_reason", None)
|
||||
if isinstance(fr, str) and fr:
|
||||
meta["finish_reason"] = fr
|
||||
|
||||
usage = getattr(response, "usage", None)
|
||||
if usage is not None:
|
||||
if hasattr(usage, "model_dump"):
|
||||
meta["usage"] = usage.model_dump()
|
||||
elif isinstance(usage, dict):
|
||||
meta["usage"] = usage
|
||||
except Exception:
|
||||
pass
|
||||
return meta
|
||||
|
||||
def _debug_dump_request(self, *, step_num: int, chat_kwargs: Dict[str, Any]) -> None:
|
||||
if os.getenv("ATROPOS_DEBUG_AGENT_REQUEST") != "1":
|
||||
return
|
||||
try:
|
||||
# Avoid dumping megabytes by default; messages can be huge.
|
||||
meta = {
|
||||
"step": step_num,
|
||||
"base_url": self._infer_server_base_url_for_debug(),
|
||||
"model": chat_kwargs.get("model") or self._infer_server_model_for_debug(),
|
||||
"chat_kwargs_keys": sorted(list(chat_kwargs.keys())),
|
||||
"n": chat_kwargs.get("n"),
|
||||
"max_tokens": chat_kwargs.get("max_tokens"),
|
||||
"temperature": chat_kwargs.get("temperature"),
|
||||
"num_messages": len(chat_kwargs.get("messages") or []),
|
||||
}
|
||||
print("\n=== ATROPOS_DEBUG_AGENT_REQUEST ===", flush=True)
|
||||
print(meta, flush=True)
|
||||
|
||||
if os.getenv("ATROPOS_DEBUG_AGENT_REQUEST_FULL") == "1":
|
||||
payload = dict(chat_kwargs)
|
||||
# Make the payload more legible and less huge.
|
||||
try:
|
||||
dumped = json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
except Exception:
|
||||
dumped = repr(payload)
|
||||
print("\n=== ATROPOS_DEBUG_AGENT_REQUEST_FULL ===", flush=True)
|
||||
print(dumped[:200_000], flush=True)
|
||||
|
||||
# Optional: save the FULL request payload to disk (no truncation).
|
||||
save_dir = os.getenv("ATROPOS_DEBUG_AGENT_REQUEST_SAVE_DIR")
|
||||
if save_dir:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
payload: Dict[str, Any] = dict(chat_kwargs)
|
||||
if "model" not in payload:
|
||||
model = self._infer_server_model_for_debug()
|
||||
if model:
|
||||
payload["model"] = model
|
||||
# Use a unique filename so parallel trajectories don't clobber each other.
|
||||
fname = os.path.join(
|
||||
save_dir,
|
||||
f"atropos_agent_request_step{step_num}_{int(time.time()*1000)}_{os.getpid()}_{uuid4().hex}.json",
|
||||
)
|
||||
with open(fname, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, ensure_ascii=False, indent=2)
|
||||
print(f"[AtroposAgent] saved request payload: {fname}", flush=True)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def _debug_dump_response(self, *, step_num: int, response: Any) -> None:
|
||||
if os.getenv("ATROPOS_DEBUG_AGENT_RESPONSE") != "1":
|
||||
return
|
||||
print("\n=== ATROPOS_DEBUG_AGENT_RESPONSE ===", flush=True)
|
||||
print({"step": step_num, "type": type(response).__name__}, flush=True)
|
||||
try:
|
||||
dumped = response.model_dump() # openai pydantic model
|
||||
except Exception:
|
||||
dumped = getattr(response, "__dict__", {"repr": repr(response)})
|
||||
# Keep the dump bounded; we only need enough to see the assistant message content.
|
||||
text = str(dumped)
|
||||
print(text[:200_000], flush=True)
|
||||
|
||||
async def _chat_completion_with_debug(
|
||||
self, *, managed: Any, step_num: int, chat_kwargs: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Call `managed.chat_completion()` with a hard timeout and richer failure logging.
|
||||
|
||||
Debug env vars:
|
||||
- `ATROPOS_AGENT_CHAT_TIMEOUT_S`: timeout in seconds for the `asyncio.wait_for` wrapper (default 240, capped at 240).
|
||||
- `ATROPOS_DEBUG_AGENT_WAIT_EVERY_S`: if set, prints a heartbeat while waiting.
|
||||
"""
|
||||
# Hard guardrail: never allow a single chat completion to block for too long.
|
||||
# This is essential for RL data-gen stability; long hangs should be treated as failures (score=0).
|
||||
timeout_s_raw = os.getenv("ATROPOS_AGENT_CHAT_TIMEOUT_S")
|
||||
timeout_s_default = 240.0
|
||||
timeout_s = float(timeout_s_raw) if timeout_s_raw else timeout_s_default
|
||||
timeout_s = min(timeout_s, 240.0)

        wait_every_raw = os.getenv("ATROPOS_DEBUG_AGENT_WAIT_EVERY_S")
        wait_every_s = float(wait_every_raw) if wait_every_raw else None

        async def _await_call() -> Any:
            if not wait_every_s or wait_every_s <= 0:
                return await managed.chat_completion(**chat_kwargs)

            # Heartbeat mode: wait in chunks without cancelling the underlying request.
            # NOTE: do NOT use `asyncio.wait_for(task, timeout=...)` here, because a timeout
            # will cancel the task and surface as `CancelledError` on the next loop.
            task = asyncio.create_task(managed.chat_completion(**chat_kwargs))
            t0 = time.perf_counter()
            try:
                while True:
                    done, _pending = await asyncio.wait({task}, timeout=wait_every_s)
                    if task in done:
                        return task.result()

                    waited = time.perf_counter() - t0
                    print(
                        f"[AtroposAgent] step={step_num} still waiting for chat_completion... ({waited:.1f}s)",
                        flush=True,
                    )
            except asyncio.CancelledError:
                task.cancel()
                raise

        try:
            return await asyncio.wait_for(_await_call(), timeout=timeout_s)
        except asyncio.TimeoutError as e:
            print("\n=== ATROPOS_DEBUG_AGENT_CHAT_TIMEOUT ===", flush=True)
            print({"step": step_num, "timeout_s": timeout_s}, flush=True)
            raise RuntimeError(f"chat_completion timed out after {timeout_s:.1f}s") from e
        except asyncio.CancelledError:
            # Treat cancellation as a hard failure rather than crashing the whole env run.
            # (Atropos/BaseEnv may cancel tasks during shutdown or retries.)
            raise RuntimeError("chat_completion cancelled") from None
        except Exception as e:
            detail: Dict[str, Any] = {
                "step": step_num,
                "exc_type": type(e).__name__,
                "exc_str": str(e),
            }
            if isinstance(e, httpx.HTTPStatusError):
                try:
                    detail["status_code"] = e.response.status_code
                    detail["response_text"] = e.response.text[:20_000]
                except Exception:
                    pass
            elif isinstance(e, httpx.RequestError):
                detail["request"] = repr(getattr(e, "request", None))

            print("\n=== ATROPOS_DEBUG_AGENT_CHAT_FAILURE ===", flush=True)
            print(detail, flush=True)
            raise
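
The heartbeat branch of `_chat_completion_with_debug` above can be isolated into a small standalone sketch. It is not part of the diff: `wait_with_heartbeat`, `_slow_call`, and `demo` are hypothetical names used only for illustration, and `_slow_call` stands in for `managed.chat_completion(...)`. It relies on `asyncio.wait()` returning on timeout without cancelling the task, unlike `asyncio.wait_for()`.

```python
import asyncio
import time
from typing import Any, Callable, Coroutine, List, Tuple


async def wait_with_heartbeat(
    make_call: Callable[[], Coroutine[Any, Any, Any]],
    every_s: float,
    on_beat: Callable[[float], None],
) -> Any:
    """Await make_call(), invoking on_beat(elapsed_s) after each `every_s` wait."""
    task = asyncio.create_task(make_call())
    t0 = time.perf_counter()
    try:
        while True:
            # asyncio.wait() leaves the task running when the timeout expires.
            done, _pending = await asyncio.wait({task}, timeout=every_s)
            if task in done:
                return task.result()
            on_beat(time.perf_counter() - t0)
    except asyncio.CancelledError:
        # Propagate outer cancellation to the in-flight request.
        task.cancel()
        raise


async def _slow_call() -> str:
    # Hypothetical stand-in for managed.chat_completion(...).
    await asyncio.sleep(0.2)
    return "done"


def demo() -> Tuple[str, List[float]]:
    beats: List[float] = []
    result = asyncio.run(wait_with_heartbeat(_slow_call, 0.05, beats.append))
    return result, beats
```

Calling `demo()` returns the call's result together with the elapsed times recorded at each heartbeat; the outer `asyncio.wait_for(..., timeout=timeout_s)` in the diff then supplies the hard deadline on top of this loop.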

    async def run(
        self,
        task: str,
        initial_messages: Optional[List[Dict[str, str]]] = None,
    ) -> AgentResult:
        """
        Run the agent on a task using ManagedServer for token tracking.

        Args:
            task: The task/prompt for the agent
            initial_messages: Optional additional context messages

        Returns:
            AgentResult with the trajectory, final response, and token data
        """
        messages = [
            {"role": "system", "content": self._build_system_prompt()},
        ]

        if initial_messages:
            messages.extend(initial_messages)

        messages.append({"role": "user", "content": task})

        steps = []
        final_response = ""
        final_node = None
        final_prompt_messages: Optional[List[Dict[str, str]]] = None
        last_node = None
        last_prompt_messages: Optional[List[Dict[str, str]]] = None
        last_response_text: str = ""

        # Use ManagedServer for automatic token tracking
        async with self._managed() as managed:
            for step_num in range(self.config.max_steps):
                # ReAct loop iteration: call the model, run any tools, feed back the
                # observations; the loop ends when a response contains no tool calls.
                try:
                    # Keep a copy of the prompt messages used for this completion.
                    # Useful for reconstructing tokens/masks when state tracking is unavailable.
                    prompt_messages = list(messages)
                    chat_kwargs: Dict[str, Any] = {"messages": messages, "n": 1}
                    if self.config.max_tokens is not None:
                        chat_kwargs["max_tokens"] = self.config.max_tokens
                    if self.config.temperature is not None:
                        chat_kwargs["temperature"] = self.config.temperature

                    t_req = time.perf_counter()
                    print(
                        f"[AtroposAgent] step={step_num+1} chat_completion start "
                        f"(messages={len(messages)}, max_tokens={self.config.max_tokens}, temp={self.config.temperature})",
                        flush=True,
                    )
                    self._debug_dump_request(step_num=step_num + 1, chat_kwargs=chat_kwargs)
                    response = await self._chat_completion_with_debug(
                        managed=managed, step_num=step_num + 1, chat_kwargs=chat_kwargs
                    )
                    self._debug_dump_response(step_num=step_num + 1, response=response)
                    response_meta = self._extract_response_metadata(response)
                    print(
                        f"[AtroposAgent] step={step_num+1} chat_completion done in {time.perf_counter() - t_req:.2f}s",
                        flush=True,
                    )

                    current_node = None
                    if hasattr(managed, "get_state"):
                        state = managed.get_state()
                        nodes = state.get("nodes", [])
                        current_node = nodes[-1] if nodes else None

                except Exception as e:
                    return AgentResult(
                        success=False,
                        final_response="",
                        steps=steps,
                        error=f"Generation error: {str(e)}",
                    )

                msg = response.choices[0].message
                # Some OpenAI-compatible servers populate `message.reasoning` and leave `content=""`.
                response_text = (msg.content or "") or (getattr(msg, "reasoning", None) or "")
                tool_calls = ToolCall.parse_from_text(response_text)
                last_node = current_node
                last_prompt_messages = prompt_messages
                last_response_text = response_text

                step_sequence_data = SequenceData.from_sequence_node(current_node) if current_node else None
                if step_sequence_data is None:
                    if response_meta:
                        # We still want metadata for debugging even if token/logprob state tracking is unavailable.
                        step_sequence_data = SequenceData(
                            full_text=response_text,
                            tokens=[],
                            masked_tokens=[],
                            logprobs=[],
                            metadata=response_meta,
                        )
                else:
                    # Guard against missing metadata: dict(None) would raise TypeError.
                    merged = dict(response_meta or {})
                    node_meta = step_sequence_data.metadata
                    if isinstance(node_meta, dict):
                        merged.update(node_meta)
                    step_sequence_data.metadata = merged or step_sequence_data.metadata

                step = AgentStep(
                    step_number=step_num + 1,
                    assistant_message=response_text,
                    tool_calls=tool_calls,
                    sequence_data=step_sequence_data,
                )

                if not tool_calls:
                    steps.append(step)
                    final_response = response_text
                    final_node = current_node
                    final_prompt_messages = prompt_messages
                    break

                messages.append({"role": "assistant", "content": response_text})

                tool_responses = []
                for call in tool_calls:
                    result = await self.execute_tool(call)
                    step.tool_results.append(result)
                    tool_responses.append(result.to_xml())
                    if self.config.tool_delay_s > 0:
                        await asyncio.sleep(self.config.tool_delay_s)

                steps.append(step)

                responses_text = "\n".join(tool_responses)
                # Tool observations are represented as user content with Hermes-style tags.
                # This is compatible with most OpenAI-compatible chat APIs and ensures
                # tokenizers/chat templates include tool outputs during training.
                messages.append({"role": "user", "content": responses_text})

            else:
                # Reached max steps without completing
                # Return a failure result but include the last observed completion so callers can
                # record the trajectory (score=0) without triggering retries.
                final_response = last_response_text or final_response
                final_node = last_node
                final_prompt_messages = last_prompt_messages
                trajectory_data = None
                if final_node:
                    trajectory_data = SequenceData.from_sequence_node(final_node)
                elif final_prompt_messages is not None and self.tokenizer is not None:
                    if hasattr(self.tokenizer, "apply_chat_template"):
                        prompt_text = self.tokenizer.apply_chat_template(
                            final_prompt_messages, tokenize=False, add_generation_prompt=True
                        )
                        prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=False)
                    else:
                        prompt_text = "\n".join([f"{m['role']}: {m['content']}" for m in final_prompt_messages])
                        prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=True)
                    output_tokens = self.tokenizer.encode(final_response, add_special_tokens=False)
                    tokens = prompt_tokens + output_tokens
                    masked_tokens = ([-100] * len(prompt_tokens)) + output_tokens
                    logprobs = ([1.0] * len(prompt_tokens)) + ([0.0] * len(output_tokens))
                    trajectory_data = SequenceData(
                        full_text=f"{prompt_text}{final_response}",
                        tokens=tokens,
                        masked_tokens=masked_tokens,
                        logprobs=logprobs,
                    )
                # Preserve response metadata (if any) even on failure trajectories.
                try:
                    if trajectory_data is not None and steps:
                        last_step = steps[-1]
                        if last_step.sequence_data and isinstance(last_step.sequence_data.metadata, dict):
                            trajectory_data.metadata = dict(last_step.sequence_data.metadata)
                except Exception:
                    pass
                return AgentResult(
                    success=False,
                    final_response=final_response,
                    steps=steps,
                    error=f"Reached maximum steps ({self.config.max_steps})",
                    trajectory_data=trajectory_data,
                )

        # Build result with trajectory data
        trajectory_data = None
        if final_node:
            trajectory_data = SequenceData.from_sequence_node(final_node)
        elif final_prompt_messages is not None and self.tokenizer is not None:
            if hasattr(self.tokenizer, "apply_chat_template"):
                prompt_text = self.tokenizer.apply_chat_template(
                    final_prompt_messages, tokenize=False, add_generation_prompt=True
                )
                prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=False)
            else:
                prompt_text = "\n".join([f"{m['role']}: {m['content']}" for m in final_prompt_messages])
                prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=True)
            output_tokens = self.tokenizer.encode(final_response, add_special_tokens=False)
            tokens = prompt_tokens + output_tokens
            masked_tokens = ([-100] * len(prompt_tokens)) + output_tokens
            logprobs = ([1.0] * len(prompt_tokens)) + ([0.0] * len(output_tokens))
            trajectory_data = SequenceData(
                full_text=f"{prompt_text}{final_response}",
                tokens=tokens,
                masked_tokens=masked_tokens,
                logprobs=logprobs,
            )

        # Ensure trajectory_data carries the most recent metadata we observed (if any).
        try:
            if trajectory_data is not None and steps:
                last_step = steps[-1]
                if last_step.sequence_data and isinstance(last_step.sequence_data.metadata, dict):
                    trajectory_data.metadata = dict(last_step.sequence_data.metadata)
        except Exception:
            pass

        return AgentResult(
            success=True,
            final_response=final_response,
            steps=steps,
            trajectory_data=trajectory_data,
        )
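
When no sequence node is available, `run` above rebuilds the training sequence from the tokenizer: prompt positions are masked with `-100` and the logprobs list is filled with constants. The sketch below reproduces only that list arithmetic on plain integer token IDs. `build_masked_sequence` and `IGNORE_INDEX` are hypothetical names, not part of the diff, and reading the `1.0` / `0.0` logprob values as placeholders is an assumption based on this code path having no real logprobs to report.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # mask value used for prompt positions in the diff above


def build_masked_sequence(
    prompt_tokens: List[int], output_tokens: List[int]
) -> Tuple[List[int], List[int], List[float]]:
    """Return (tokens, masked_tokens, logprobs), all of equal length."""
    tokens = prompt_tokens + output_tokens
    # Only completion tokens keep their IDs; prompt positions are masked out.
    masked_tokens = [IGNORE_INDEX] * len(prompt_tokens) + output_tokens
    # Constant fill as in the diff: 1.0 for prompt positions, 0.0 for completion
    # positions (assumed to be placeholders, since no logprobs were tracked).
    logprobs = [1.0] * len(prompt_tokens) + [0.0] * len(output_tokens)
    return tokens, masked_tokens, logprobs
```

For example, `build_masked_sequence([5, 6, 7], [8, 9])` yields three aligned lists of length five, with the first three label positions masked.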

    async def run_single_turn(
        self,
        messages: List[Dict[str, str]],
        execute_tools: bool = True,
    ) -> tuple[str, List[ToolResult], Optional[SequenceData]]:
        """
        Run a single turn of the agent (one LLM call + tool execution).

        This is useful for integration with BaseEnv where you want more
        control over the loop.

        Args:
            messages: The conversation history
            execute_tools: Whether to execute parsed tool calls

        Returns:
            Tuple of (response_text, tool_results, sequence_data)
        """
        async with self._managed() as managed:
            chat_kwargs: Dict[str, Any] = {"messages": messages, "n": 1}
            if self.config.max_tokens is not None:
                chat_kwargs["max_tokens"] = self.config.max_tokens
            if self.config.temperature is not None:
                chat_kwargs["temperature"] = self.config.temperature

            self._debug_dump_request(step_num=1, chat_kwargs=chat_kwargs)
            response = await self._chat_completion_with_debug(managed=managed, step_num=1, chat_kwargs=chat_kwargs)
            self._debug_dump_response(step_num=1, response=response)

            current_node = None
            if hasattr(managed, "get_state"):
                state = managed.get_state()
                nodes = state.get("nodes", [])
                current_node = nodes[-1] if nodes else None

        msg = response.choices[0].message
        response_text = (msg.content or "") or (getattr(msg, "reasoning", None) or "")
        tool_results = []

        if execute_tools:
            tool_calls = ToolCall.parse_from_text(response_text)
            for call in tool_calls:
                result = await self.execute_tool(call)
                tool_results.append(result)

        sequence_data = SequenceData.from_sequence_node(current_node) if current_node else None

        return response_text, tool_results, sequence_data


class _DirectChatCompletionClient:
    """
    Minimal stand-in for ManagedServer that calls the OpenAI-compatible endpoint directly.

    This is for isolating issues where `ManagedServer.chat_completion()` hangs or misbehaves.
    It intentionally does NOT do token/logprob tracking.
    """

    def __init__(self, server: Any):
        self._server = server

    def _server_config(self) -> tuple[str, str, str]:
        # ServerManager case: first configured server.
        servers = getattr(self._server, "servers", None)
        if isinstance(servers, list) and servers:
            s0 = servers[0]
            cfg = getattr(s0, "config", None)
            base_url = getattr(cfg, "base_url", None) or getattr(s0, "base_url", None)
            api_key = getattr(cfg, "api_key", None) or getattr(s0, "api_key", None)
            model = getattr(cfg, "model_name", None) or getattr(s0, "model_name", None)
            if isinstance(base_url, str) and isinstance(api_key, str) and isinstance(model, str):
                return base_url.rstrip("/"), api_key, model

        # APIServer-like fallback.
        base_url = getattr(self._server, "base_url", None)
        api_key = getattr(self._server, "api_key", None)
        model = getattr(self._server, "model_name", None) or getattr(self._server, "model", None)
        if isinstance(base_url, str) and isinstance(api_key, str) and isinstance(model, str):
            return base_url.rstrip("/"), api_key, model

        raise RuntimeError("Unable to resolve server base_url/api_key/model for direct chat completion")

    async def chat_completion(self, *, messages: List[Dict[str, str]], n: int = 1, **kwargs: Any) -> Any:
        base_url, api_key, model = self._server_config()
        url = f"{base_url}/chat/completions"

        payload: Dict[str, Any] = {
            "model": model,
            "messages": messages,
            "n": n,
        }
        # Pass through common generation kwargs.
        for k in ("max_tokens", "temperature", "top_p", "presence_penalty", "frequency_penalty", "stop"):
            if k in kwargs and kwargs[k] is not None:
                payload[k] = kwargs[k]

        timeout_s = float(os.getenv("ATROPOS_DIRECT_REQUEST_TIMEOUT_S") or "120")
        print(f"[AtroposAgent] DIRECT chat_completion POST {url} (timeout={timeout_s}s)", flush=True)
        async with httpx.AsyncClient(timeout=timeout_s) as client:
            resp = await client.post(
                url,
                headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
                json=payload,
            )
            resp.raise_for_status()
            data = resp.json()

        # Return a very small object compatible with the code paths that read
        # `response.choices[0].message.content`.
        class _Msg:
            def __init__(self, d: Dict[str, Any]):
                self.content = d.get("content")
                self.reasoning = d.get("reasoning")

        class _Choice:
            def __init__(self, d: Dict[str, Any]):
                self.message = _Msg(d.get("message") or {})

        class _Resp:
            def __init__(self, d: Dict[str, Any]):
                self._d = d
                self.choices = [_Choice(c) for c in (d.get("choices") or [])]

            def model_dump(self) -> Dict[str, Any]:
                return self._d

        return _Resp(data)
6
atropos/api/__init__.py
Normal file
@@ -0,0 +1,6 @@
"""
FastAPI services for atropos-agent.

- tool_executor_server: queued/batched sandbox tool execution (Phase 4)
"""
254
atropos/api/tool_executor_server.py
Normal file
@@ -0,0 +1,254 @@
"""
Tool Executor API (Phase 4)

This service provides a queued, batched execution layer on top of a ToolBackend.
It mirrors the stateful FastAPI + app.state pattern used in:
  atropos/atroposlib/api/server.py

Run (dev):
  uv run uvicorn atropos_agent.api.tool_executor_server:app --host 0.0.0.0 --port 9001
"""

from __future__ import annotations

import os
from typing import Any, Dict, Optional
from pathlib import Path

from fastapi import FastAPI, Header, HTTPException, status
from pydantic import BaseModel, Field

from ..backends.nomad_backend import NomadBackendConfig, NomadToolBackend
from ..tools import ToolRegistry, build_tool_registry
from ..tools.base import (
    ArtifactArchiveRequestPayload,
    ArtifactArchiveResponsePayload,
    ArtifactListRequestPayload,
    ArtifactListResponsePayload,
    ArtifactReadRequestPayload,
    ArtifactReadResponsePayload,
    ToolExecutorExecuteRequest,
    ToolExecutorReleaseRequest,
    ToolResultPayload,
)
from ..tools.tool_executor import ToolExecutor, ToolExecutorConfig


class ToolExecutorServerConfig(BaseModel):
    nomad_address: str = Field(default="http://localhost:4646")
    job_id: str = Field(default="atropos-sandbox-tool-executor")
    image: str = Field(default="atropos-sandbox:local")
    slots_per_container: int = Field(default=10)
    min_containers: int = Field(default=1)
    max_containers: int = Field(default=10)
    privileged: bool = Field(default=False)
    acquire_timeout_s: float = Field(default=30.0)

    batch_window_ms: int = Field(default=20)
    max_batch_size: int = Field(default=200)
    allow_network: bool = Field(default=True)

    tool_server_url: Optional[str] = Field(default=None)
    tool_server_token: Optional[str] = Field(default=None)

    token: Optional[str] = Field(default=None, description="Bearer token required for requests (optional in dev).")

    purge_job_on_shutdown: bool = Field(default=True)

    @classmethod
    def from_env(cls) -> "ToolExecutorServerConfig":
        # In dev, prefer loading secrets/config from the repo-local `.env` (not committed).
        try:
            from dotenv import load_dotenv  # type: ignore
        except Exception:  # pragma: no cover
            load_dotenv = None  # type: ignore[assignment]
        if load_dotenv is not None:
            env_path = Path(__file__).resolve().parents[2] / ".env"
            if env_path.exists():
                load_dotenv(dotenv_path=env_path)

        def _get_bool(name: str, default: bool) -> bool:
            raw = os.getenv(name)
            if raw is None:
                return default
            return raw.strip().lower() in {"1", "true", "yes", "y", "on"}

        return cls(
            nomad_address=os.getenv("TOOL_EXECUTOR_NOMAD_ADDRESS", "http://localhost:4646"),
            job_id=os.getenv("TOOL_EXECUTOR_JOB_ID", "atropos-sandbox-tool-executor"),
            image=os.getenv("TOOL_EXECUTOR_IMAGE", "atropos-sandbox:local"),
            slots_per_container=int(os.getenv("TOOL_EXECUTOR_SLOTS", "10")),
            min_containers=int(os.getenv("TOOL_EXECUTOR_MIN_CONTAINERS", "1")),
            max_containers=int(os.getenv("TOOL_EXECUTOR_MAX_CONTAINERS", "10")),
            privileged=_get_bool("TOOL_EXECUTOR_PRIVILEGED", False),
            acquire_timeout_s=float(os.getenv("TOOL_EXECUTOR_ACQUIRE_TIMEOUT_S", "30.0")),
            batch_window_ms=int(os.getenv("TOOL_EXECUTOR_BATCH_WINDOW_MS", "20")),
            max_batch_size=int(os.getenv("TOOL_EXECUTOR_MAX_BATCH_SIZE", "200")),
            allow_network=_get_bool("TOOL_EXECUTOR_ALLOW_NETWORK", True),
            tool_server_url=os.getenv("TOOL_EXECUTOR_TOOL_SERVER_URL") or None,
            tool_server_token=os.getenv("TOOL_EXECUTOR_TOOL_SERVER_TOKEN") or None,
            token=os.getenv("TOOL_EXECUTOR_TOKEN") or None,
            purge_job_on_shutdown=_get_bool("TOOL_EXECUTOR_PURGE_JOB_ON_SHUTDOWN", True),
        )


app = FastAPI(title="Atropos-Agent Tool Executor")


@app.get("/")
async def root() -> Dict[str, str]:
    return {"message": "Atropos-Agent Tool Executor"}


def _check_auth(cfg: ToolExecutorServerConfig, authorization: Optional[str]) -> None:
    if not cfg.token:
        return
    if not authorization:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Authorization header")
    if not authorization.lower().startswith("bearer "):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Authorization header")
    token = authorization.split(" ", 1)[1].strip()
    if token != cfg.token:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")


@app.on_event("startup")
async def _startup() -> None:
    cfg = ToolExecutorServerConfig.from_env()

    # Default to Atropos "full" tool surface: sandbox + external (if tool_server_url provided).
    tools: ToolRegistry = build_tool_registry(
        enabled_toolsets=["full"],
        disabled_toolsets=None,
        tool_server_url=cfg.tool_server_url,
    )

    backend = NomadToolBackend(
        NomadBackendConfig(
            nomad_address=cfg.nomad_address,
            sandbox_job_id=cfg.job_id,
            sandbox_image=cfg.image,
            slots_per_container=cfg.slots_per_container,
            min_containers=cfg.min_containers,
            max_containers=cfg.max_containers,
            privileged=cfg.privileged,
            acquire_timeout_s=cfg.acquire_timeout_s,
            purge_job_on_start=False,
        )
    )
    await backend.start()

    executor = ToolExecutor(
        backend=backend,
        tools=tools,
        config=ToolExecutorConfig(
            batch_window_ms=cfg.batch_window_ms,
            max_batch_size=cfg.max_batch_size,
            allow_network=cfg.allow_network,
            tool_server_url=cfg.tool_server_url,
            tool_server_token=cfg.tool_server_token,
        ),
    )
    await executor.start()

    app.state.cfg = cfg
    app.state.backend = backend
    app.state.executor = executor


@app.on_event("shutdown")
async def _shutdown() -> None:
    executor: Optional[ToolExecutor] = getattr(app.state, "executor", None)
    backend: Optional[NomadToolBackend] = getattr(app.state, "backend", None)
    cfg: Optional[ToolExecutorServerConfig] = getattr(app.state, "cfg", None)

    if executor is not None:
        await executor.close()

    if backend is not None:
        await backend.stop(purge=bool(cfg.purge_job_on_shutdown) if cfg else False)


@app.get("/health")
async def health() -> Dict[str, Any]:
    return {"status": "ok"}


@app.get("/status")
async def status_endpoint() -> Dict[str, Any]:
    executor: ToolExecutor = app.state.executor
    backend: NomadToolBackend = app.state.backend

    return {
        "queue_size": executor.queue_size(),
        "total_requests": executor.total_requests,
        "total_errors": executor.total_errors,
        "pool": backend.get_stats(),
    }


@app.post("/execute", response_model=ToolResultPayload)
async def execute_tool(
    req: ToolExecutorExecuteRequest,
    authorization: Optional[str] = Header(default=None),
    status_code: int = status.HTTP_200_OK,  # noqa: B008
) -> ToolResultPayload:
    cfg: ToolExecutorServerConfig = app.state.cfg
    _check_auth(cfg, authorization)

    executor: ToolExecutor = app.state.executor
    result = await executor.execute(
        trajectory_id=req.trajectory_id,
        call=req.tool.to_tool_call(),
        timeout_s=req.timeout_s,
    )
    return ToolResultPayload.from_tool_result(result)


@app.post("/release")
async def release_trajectory(
    req: ToolExecutorReleaseRequest,
    authorization: Optional[str] = Header(default=None),
) -> Dict[str, Any]:
    cfg: ToolExecutorServerConfig = app.state.cfg
    _check_auth(cfg, authorization)

    executor: ToolExecutor = app.state.executor
    await executor.release_trajectory(req.trajectory_id, reset_workspace=req.reset_workspace)
    return {"status": "ok"}


@app.post("/artifacts/read", response_model=ArtifactReadResponsePayload)
async def artifacts_read(
    req: ArtifactReadRequestPayload,
    authorization: Optional[str] = Header(default=None),
) -> ArtifactReadResponsePayload:
    cfg: ToolExecutorServerConfig = app.state.cfg
    _check_auth(cfg, authorization)

    executor: ToolExecutor = app.state.executor
    return await executor.read_artifact(req)


@app.post("/artifacts/list", response_model=ArtifactListResponsePayload)
async def artifacts_list(
    req: ArtifactListRequestPayload,
    authorization: Optional[str] = Header(default=None),
) -> ArtifactListResponsePayload:
    cfg: ToolExecutorServerConfig = app.state.cfg
    _check_auth(cfg, authorization)

    executor: ToolExecutor = app.state.executor
    return await executor.list_artifacts(req)


@app.post("/artifacts/archive", response_model=ArtifactArchiveResponsePayload)
async def artifacts_archive(
    req: ArtifactArchiveRequestPayload,
    authorization: Optional[str] = Header(default=None),
) -> ArtifactArchiveResponsePayload:
    cfg: ToolExecutorServerConfig = app.state.cfg
    _check_auth(cfg, authorization)

    executor: ToolExecutor = app.state.executor
    return await executor.archive_artifacts(req)
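
Every endpoint in this file runs the same `_check_auth` gate, so its decision table is worth seeing without FastAPI around it. The sketch below is an illustration, not the code in the diff: `check_bearer` is a hypothetical helper that returns a status code and message instead of raising `HTTPException`. It also swaps the `!=` comparison for `hmac.compare_digest`, a constant-time comparison that is an assumption on my part rather than something this change does.

```python
import hmac
from typing import Optional, Tuple


def check_bearer(expected: Optional[str], authorization: Optional[str]) -> Tuple[int, str]:
    """Return (http_status, detail); 200 means the request may proceed."""
    if not expected:
        # No token configured: auth is disabled (the dev default in the diff).
        return 200, "auth disabled"
    if not authorization:
        return 401, "Missing Authorization header"
    if not authorization.lower().startswith("bearer "):
        return 401, "Invalid Authorization header"
    token = authorization.split(" ", 1)[1].strip()
    # Assumption: constant-time comparison instead of the `!=` used in the diff.
    if not hmac.compare_digest(token.encode(), expected.encode()):
        return 403, "Invalid token"
    return 200, "ok"
```

The scheme prefix is matched case-insensitively, so `bearer <token>` and `Bearer <token>` are treated the same, as in `_check_auth`.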
140
atropos/api/tool_server.py
Normal file
@@ -0,0 +1,140 @@
"""
External ToolServer (Phase 4.5+).

This server executes tools that must NOT run inside the sandbox, typically
because they require credentials or access to external services.

Run (dev):
  uv run uvicorn atropos_agent.api.tool_server:app --host 0.0.0.0 --port 9002
"""

from __future__ import annotations

import asyncio
import os
import inspect
from typing import Any, Dict, List, Optional
from pathlib import Path

from fastapi import FastAPI, Header, HTTPException, status
from pydantic import BaseModel, Field

from ..tools import ToolRegistry, build_tool_registry
from ..tools.base import ToolResultPayload, ToolServerExecuteRequest


class ToolServerConfig(BaseModel):
    token: Optional[str] = Field(
        default=None,
        description="Bearer token required for requests (optional in dev).",
    )
    max_concurrency: int = Field(default=16, ge=1, description="Max concurrent tool executions.")

    @classmethod
    def from_env(cls) -> "ToolServerConfig":
        # In dev, prefer loading secrets from the repo-local `.env` (not committed).
        try:
            from dotenv import load_dotenv  # type: ignore
        except Exception:  # pragma: no cover
            load_dotenv = None  # type: ignore[assignment]
        if load_dotenv is not None:
            env_path = Path(__file__).resolve().parents[2] / ".env"
            if env_path.exists():
                load_dotenv(dotenv_path=env_path)

        token = os.getenv("TOOL_SERVER_TOKEN") or None
        max_concurrency = int(os.getenv("TOOL_SERVER_MAX_CONCURRENCY", "16"))
        return cls(token=token, max_concurrency=max_concurrency)


app = FastAPI(title="Atropos-Agent Tool Server")


@app.get("/")
async def root() -> Dict[str, str]:
    return {"message": "Atropos-Agent Tool Server"}


@app.on_event("startup")
async def _startup() -> None:
    cfg = ToolServerConfig.from_env()

    # External-only registry. It will only include tools that are enabled by toolsets and
    # whose Hermes requirements/keys are satisfied in this process.
    tools: ToolRegistry = build_tool_registry(
        enabled_toolsets=["all"],
        disabled_toolsets=["terminal", "sandbox", "filesystem", "terminal_stateful", "default"],
        tool_server_url="enabled",
    )

    app.state.cfg = cfg
    app.state.tools = tools
    app.state.semaphore = asyncio.Semaphore(cfg.max_concurrency)


@app.get("/health")
async def health() -> Dict[str, Any]:
    return {"status": "ok"}


@app.get("/tools")
async def list_tools() -> Dict[str, Any]:
    tools: ToolRegistry = app.state.tools
    return {"tools": [s.to_dict() for s in tools.get_schemas()]}


def _check_auth(cfg: ToolServerConfig, authorization: Optional[str]) -> None:
    if not cfg.token:
        return
    if not authorization:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Authorization header")
    if not authorization.lower().startswith("bearer "):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Authorization header")
    token = authorization.split(" ", 1)[1].strip()
    if token != cfg.token:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")


@app.post("/execute", response_model=ToolResultPayload)
async def execute_tool(
    req: ToolServerExecuteRequest,
    authorization: Optional[str] = Header(default=None),
) -> ToolResultPayload:
    cfg: ToolServerConfig = app.state.cfg
    _check_auth(cfg, authorization)

    tools: ToolRegistry = app.state.tools
    sem: asyncio.Semaphore = app.state.semaphore

    tool = tools.get(req.tool.name)
    if tool is None:
        return ToolResultPayload(
            success=False,
            error=f"Unknown tool: {req.tool.name}",
            uniq_id=req.tool.uniq_id,
        )

    async with sem:
        try:
            kwargs = dict(req.tool.arguments)
            sig = inspect.signature(tool.execute).parameters
            # Some tools can benefit from extra context.
            if req.trajectory_id and "trajectory_id" in sig:
                kwargs["trajectory_id"] = req.trajectory_id
            if req.slot_id and "slot_id" in sig:
                kwargs["slot_id"] = req.slot_id
            if req.container_addr and "container_addr" in sig:
                kwargs["container_addr"] = req.container_addr
            if "task_id" in sig:
                kwargs["task_id"] = req.trajectory_id
            result = await tool.execute(**kwargs)
        except Exception as e:
            return ToolResultPayload(
                success=False,
                error=f"Tool execution error: {e}",
                uniq_id=req.tool.uniq_id,
            )

    if result.uniq_id is None:
        result.uniq_id = req.tool.uniq_id
    return ToolResultPayload.from_tool_result(result)
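
The `/execute` handler above passes context such as `trajectory_id` or `slot_id` only to tools whose `execute` signature declares the parameter, and it bounds concurrency with a semaphore. A reduced sketch of that pattern follows. `inject_context`, `echo_tool`, `run_limited`, and `demo` are hypothetical names for illustration; the handler's special case that maps `task_id` to the trajectory ID is left out, and the sketch tests `is not None` where the handler tests truthiness.

```python
import asyncio
import inspect
from typing import Any, Awaitable, Callable, Dict


def inject_context(
    execute: Callable[..., Awaitable[Any]],
    arguments: Dict[str, Any],
    context: Dict[str, Any],
) -> Dict[str, Any]:
    """Copy `arguments`, adding only the context keys that `execute` declares."""
    params = inspect.signature(execute).parameters
    kwargs = dict(arguments)
    for name, value in context.items():
        if value is not None and name in params:
            kwargs[name] = value
    return kwargs


async def echo_tool(text: str, trajectory_id: str = "") -> str:
    # Hypothetical tool: accepts trajectory_id but does not accept slot_id.
    return f"{trajectory_id}:{text}"


async def run_limited(sem: asyncio.Semaphore, **kwargs: Any) -> str:
    async with sem:  # at most `sem`'s initial value executions run at once
        return await echo_tool(**kwargs)


def demo() -> str:
    kwargs = inject_context(
        echo_tool, {"text": "hi"}, {"trajectory_id": "t1", "slot_id": "s9"}
    )

    async def main() -> str:
        sem = asyncio.Semaphore(2)
        return await run_limited(sem, **kwargs)

    return asyncio.run(main())
```

Because `slot_id` is not in `echo_tool`'s signature it is dropped silently, which is what lets one request schema serve tools with different context needs.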
27
atropos/backends/__init__.py
Normal file
@@ -0,0 +1,27 @@
from __future__ import annotations

from typing import Any

from .base import ToolBackend
from .modal_backend import ModalSandboxConfig, ModalToolBackend
from .nomad_backend import NomadBackendConfig, NomadToolBackend


def create_tool_backend(cfg: Any) -> ToolBackend:
    mode = str(getattr(cfg, "tool_pool_mode", "nomad")).strip().lower()
    if mode == "nomad":
        return NomadToolBackend(NomadBackendConfig.from_agent_env_config(cfg))
    if mode == "modal":
        return ModalToolBackend(ModalSandboxConfig.from_agent_env_config(cfg))
    raise ValueError(f"Unknown tool_pool_mode: {mode}")


__all__ = [
    "ToolBackend",
    "create_tool_backend",
    "NomadBackendConfig",
    "NomadToolBackend",
    "ModalSandboxConfig",
    "ModalToolBackend",
]

89
atropos/backends/base.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
Backend interfaces for AgentEnv tool execution.
|
||||
|
||||
The goal of this module is to decouple ToolExecutor / AgentEnv from any single
|
||||
execution backend (Nomad/Docker today; Modal later).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol, Tuple
|
||||
|
||||
from ..slots.executor import ExecutionResult
|
||||
from ..slots.slot import Slot
|
||||
|
||||
|
||||
class ToolBackend(Protocol):
|
||||
"""
|
||||
Minimal interface required by ToolExecutor.
|
||||
|
||||
Backends provide:
|
||||
- lifecycle (start/stop)
|
||||
- slot acquisition/release (workspace affinity)
|
||||
- batched tool execution across slots
|
||||
- optional artifact helpers (for env verification / demos)
|
||||
"""
|
||||
|
||||
@property
|
||||
def default_timeout_s(self) -> Optional[float]:
|
||||
"""Default sandbox execution timeout in seconds (if any)."""
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the backend (provision workers/containers, health checks, etc)."""
|
||||
|
||||
async def stop(self, *, purge: bool = False) -> None:
|
||||
"""Stop the backend and optionally purge remote resources."""
|
||||
|
||||
async def acquire(self, trajectory_id: Optional[str] = None) -> Slot:
|
||||
"""Acquire a slot for a trajectory (workspace affinity)."""
|
||||
|
||||
async def release(self, slot: Slot, *, reset_workspace: bool = False) -> None:
|
||||
"""Release a slot back to the pool."""
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
requests: List[Tuple[Slot, str, Dict[str, Any]]],
|
||||
*,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> List[ExecutionResult]:
|
||||
"""Execute a batch of sandbox tool calls and return results in order."""
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Optional artifact helpers (supported by the Nomad sandbox-server today)
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
async def read_artifact(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str,
|
||||
*,
|
||||
encoding: str = "text",
|
||||
max_bytes: Optional[int] = None,
|
||||
include_sha256: bool = False,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def list_artifacts(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str = ".",
|
||||
*,
|
||||
recursive: bool = False,
|
||||
max_entries: Optional[int] = None,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def archive_artifacts(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str = ".",
|
||||
*,
|
||||
archive_format: str = "tar.gz",
|
||||
max_bytes: Optional[int] = None,
|
||||
max_entries: Optional[int] = None,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
raise NotImplementedError
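Because `ToolBackend` is a `typing.Protocol`, a backend only needs to match the method shapes; inheritance is optional. The sketch below assumes a reduced interface (`MiniBackend`) and a stand-in slot type (`FakeSlot`), both hypothetical, to show an in-memory backend that keeps workspace affinity and returns batch results in request order.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Protocol, Tuple


@dataclass
class FakeSlot:
    """Hypothetical stand-in for Slot; it only carries an id."""
    slot_id: str


Request = Tuple[FakeSlot, str, Dict[str, Any]]


class MiniBackend(Protocol):
    """Reduced, illustrative subset of the ToolBackend interface."""

    async def acquire(self, trajectory_id: Optional[str] = None) -> FakeSlot: ...
    async def release(self, slot: FakeSlot, *, reset_workspace: bool = False) -> None: ...
    async def execute_batch(
        self, requests: List[Request], *, timeout_s: Optional[float] = None
    ) -> List[tuple]: ...


class EchoBackend:
    """In-memory backend that echoes requests instead of running tools."""

    def __init__(self) -> None:
        self.in_use: Dict[str, FakeSlot] = {}

    async def acquire(self, trajectory_id: Optional[str] = None) -> FakeSlot:
        key = trajectory_id or "anonymous"
        # Workspace affinity: the same trajectory id maps to the same slot.
        return self.in_use.setdefault(key, FakeSlot(slot_id=f"slot-{key}"))

    async def release(self, slot: FakeSlot, *, reset_workspace: bool = False) -> None:
        self.in_use = {k: s for k, s in self.in_use.items() if s is not slot}

    async def execute_batch(
        self, requests: List[Request], *, timeout_s: Optional[float] = None
    ) -> List[tuple]:
        # Results come back in the same order as the requests.
        return [(slot.slot_id, tool, args) for slot, tool, args in requests]


async def demo(backend: MiniBackend) -> List[tuple]:
    slot = await backend.acquire("t1")
    try:
        return await backend.execute_batch(
            [(slot, "bash", {"command": "ls"}), (slot, "bash", {"command": "pwd"})]
        )
    finally:
        await backend.release(slot)


results = asyncio.run(demo(EchoBackend()))
print(results)
```

A fake of this kind may be useful for unit-testing code that depends on a backend without provisioning Nomad or Modal.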
|
||||
|
||||
1179
atropos/backends/modal_backend.py
Normal file
File diff suppressed because it is too large
156
atropos/backends/nomad_backend.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Nomad/Docker tool backend.
|
||||
|
||||
This backend is the current default for AgentEnv: it provisions a Nomad job
|
||||
running `sandbox_server.py` and multiplexes stateless slots inside each container.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from ..slots import Slot, SlotPool, SlotPoolConfig
|
||||
from ..slots.executor import ExecutionResult
|
||||
from .base import ToolBackend
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NomadBackendConfig:
|
||||
nomad_address: str
|
||||
sandbox_job_id: str
|
||||
sandbox_image: str
|
||||
slots_per_container: int
|
||||
min_containers: int
|
||||
max_containers: int
|
||||
privileged: bool
|
||||
acquire_timeout_s: float
|
||||
purge_job_on_start: bool
|
||||
# Driver selection: "docker" or "singularity"
|
||||
driver: str = "docker"
|
||||
# Path to .sif file for singularity driver (required if driver="singularity")
|
||||
singularity_image: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_agent_env_config(cls, cfg: Any) -> "NomadBackendConfig":
|
||||
return cls(
|
||||
nomad_address=str(getattr(cfg, "nomad_address")),
|
||||
sandbox_job_id=str(getattr(cfg, "sandbox_job_id")),
|
||||
sandbox_image=str(getattr(cfg, "sandbox_image")),
|
||||
slots_per_container=int(getattr(cfg, "slots_per_container")),
|
||||
min_containers=int(getattr(cfg, "min_containers")),
|
||||
max_containers=int(getattr(cfg, "max_containers")),
|
||||
privileged=bool(getattr(cfg, "privileged")),
|
||||
acquire_timeout_s=float(getattr(cfg, "acquire_timeout_s")),
|
||||
purge_job_on_start=bool(getattr(cfg, "purge_job_on_start", False)),
|
||||
driver=str(getattr(cfg, "driver", "docker")),
|
||||
singularity_image=getattr(cfg, "singularity_image", None),
|
||||
)
|
||||
|
||||
|
||||
class NomadToolBackend(ToolBackend):
|
||||
def __init__(self, config: NomadBackendConfig):
|
||||
self.config = config
|
||||
self.pool = SlotPool(
|
||||
SlotPoolConfig(
|
||||
nomad_address=config.nomad_address,
|
||||
job_id=config.sandbox_job_id,
|
||||
image=config.sandbox_image,
|
||||
slots_per_container=config.slots_per_container,
|
||||
min_containers=config.min_containers,
|
||||
max_containers=config.max_containers,
|
||||
privileged=config.privileged,
|
||||
acquire_timeout=config.acquire_timeout_s,
|
||||
purge_job_on_start=bool(config.purge_job_on_start),
|
||||
driver=config.driver,
|
||||
singularity_image=config.singularity_image,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def default_timeout_s(self) -> Optional[float]:
|
||||
t = getattr(self.pool.executor, "timeout", None)
|
||||
total = getattr(t, "total", None)
|
||||
try:
|
||||
return float(total) if total is not None else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def start(self) -> None:
|
||||
await self.pool.start()
|
||||
|
||||
async def stop(self, *, purge: bool = False) -> None:
|
||||
await self.pool.stop(purge_job=purge)
|
||||
|
||||
async def acquire(self, trajectory_id: Optional[str] = None) -> Slot:
|
||||
return await self.pool.acquire(trajectory_id)
|
||||
|
||||
async def release(self, slot: Slot, *, reset_workspace: bool = False) -> None:
|
||||
await self.pool.release(slot, reset_workspace=reset_workspace)
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
requests: List[Tuple[Slot, str, Dict[str, Any]]],
|
||||
*,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> List[ExecutionResult]:
|
||||
return await self.pool.execute_batch(requests, timeout=timeout_s)
|
||||
|
||||
async def read_artifact(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str,
|
||||
*,
|
||||
encoding: str = "text",
|
||||
max_bytes: Optional[int] = None,
|
||||
include_sha256: bool = False,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
return await self.pool.executor.read_artifact(
|
||||
slot,
|
||||
path,
|
||||
encoding=encoding,
|
||||
max_bytes=max_bytes,
|
||||
include_sha256=include_sha256,
|
||||
timeout=timeout_s,
|
||||
)
|
||||
|
||||
async def list_artifacts(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str = ".",
|
||||
*,
|
||||
recursive: bool = False,
|
||||
max_entries: Optional[int] = None,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
return await self.pool.executor.list_artifacts(
|
||||
slot,
|
||||
path,
|
||||
recursive=recursive,
|
||||
max_entries=max_entries,
|
||||
timeout=timeout_s,
|
||||
)
|
||||
|
||||
async def archive_artifacts(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str = ".",
|
||||
*,
|
||||
archive_format: str = "tar.gz",
|
||||
max_bytes: Optional[int] = None,
|
||||
max_entries: Optional[int] = None,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
return await self.pool.executor.archive_artifacts(
|
||||
slot,
|
||||
path,
|
||||
archive_format=archive_format,
|
||||
max_bytes=max_bytes,
|
||||
max_entries=max_entries,
|
||||
timeout=timeout_s,
|
||||
)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
return self.pool.get_stats()
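`NomadBackendConfig.from_agent_env_config()` copies fields from a duck-typed config object into a frozen dataclass, coercing types along the way. A reduced sketch of that pattern follows. `MiniNomadConfig` and `from_env_config` are hypothetical names, and the Singularity check is an assumption based on the field comment ("required if driver='singularity'"); the diff does not show where, or whether, the real code validates it.

```python
from dataclasses import FrozenInstanceError, dataclass
from types import SimpleNamespace
from typing import Any, Optional


@dataclass(frozen=True)
class MiniNomadConfig:
    """Reduced, illustrative version of NomadBackendConfig."""
    nomad_address: str
    slots_per_container: int
    driver: str = "docker"
    singularity_image: Optional[str] = None

    @classmethod
    def from_env_config(cls, cfg: Any) -> "MiniNomadConfig":
        config = cls(
            # Required fields: getattr without a default raises AttributeError.
            nomad_address=str(getattr(cfg, "nomad_address")),
            slots_per_container=int(getattr(cfg, "slots_per_container")),
            # Optional fields fall back to defaults.
            driver=str(getattr(cfg, "driver", "docker")),
            singularity_image=getattr(cfg, "singularity_image", None),
        )
        # Assumption: fail early when the Singularity image is missing.
        if config.driver == "singularity" and not config.singularity_image:
            raise ValueError("singularity_image is required when driver='singularity'")
        return config


config = MiniNomadConfig.from_env_config(
    SimpleNamespace(nomad_address="http://localhost:4646", slots_per_container="10")
)
print(config.slots_per_container, config.driver)  # 10 docker
```

Freezing the dataclass means the backend cannot mutate its configuration after construction, which keeps the pool settings consistent for the lifetime of the job.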
|
||||
|
||||
18
atropos/envs/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
Environment implementations for atropos-agent.
|
||||
|
||||
NOTE: AgentEnv is the OLD environment system, replaced by
|
||||
environments/hermes_base_env.py (HermesAgentBaseEnv).
|
||||
Import is lazy to avoid pulling in deleted dependencies.
|
||||
"""
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
"""Lazy import to avoid breaking when old dependencies are removed."""
|
||||
if name in ("AgentEnv", "AgentEnvConfig"):
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
return {"AgentEnv": AgentEnv, "AgentEnvConfig": AgentEnvConfig}[name]
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
__all__ = ["AgentEnv", "AgentEnvConfig"]
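The module-level `__getattr__` above relies on PEP 562: Python calls it only when normal attribute lookup on the module fails. The sketch below reproduces the mechanism on a throwaway module object so it can run on its own. `lazy_demo`, `HeavyThing`, and `load_count` are hypothetical, and the explicit caching via `setattr` is an addition for illustration (the original relies on the import system's own cache).

```python
import types

lazy = types.ModuleType("lazy_demo")
load_count = {"n": 0}


def _module_getattr(name: str):
    """Called only when `name` is not already in the module's namespace."""
    if name == "HeavyThing":
        load_count["n"] += 1
        value = object()  # stand-in for an expensive or fragile import
        setattr(lazy, name, value)  # cache so later lookups skip this hook
        return value
    raise AttributeError(f"module 'lazy_demo' has no attribute {name!r}")


# Placing the function in the module namespace activates the PEP 562 hook.
lazy.__getattr__ = _module_getattr

first = lazy.HeavyThing
second = lazy.HeavyThing
print(first is second, load_count["n"])  # True 1
```

Importing the package therefore stays cheap, and a missing legacy dependency only surfaces when someone actually asks for the old name.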
|
||||
537
atropos/envs/agent_env.py
Normal file
@@ -0,0 +1,537 @@
|
||||
"""
|
||||
AgentEnv - Atropos BaseEnv extension for agent/tool-call workloads.
|
||||
|
||||
AgentEnv is responsible for starting the sandbox tool execution backend and
|
||||
providing helpers for running agent trajectories with queued/batched tool calls.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Tuple, TypeVar
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, Item, ScoredDataGroup, ScoredDataItem
|
||||
from atroposlib.envs.server_handling.server_baseline import AsyncSemWithAdaptiveWeight
|
||||
|
||||
from ..agent import AgentConfig, AgentResult, AtroposAgent
|
||||
from ..backends import ToolBackend, create_tool_backend
|
||||
from ..tools import ToolRegistry, build_tool_registry
|
||||
from ..tools.tool_executor import ToolExecutor, ToolExecutorConfig
|
||||
|
||||
# Main BaseEnv subclasses. Subclass these to get agent and tooling functionality.
|
||||
|
||||
class AgentEnvConfig(BaseEnvConfig):
|
||||
tool_pool_mode: str = Field(default="nomad", description="Tool execution backend ('nomad' or 'modal')")
|
||||
|
||||
allow_network: bool = Field(
|
||||
default=True,
|
||||
description="Whether sandbox bash commands may access the network (env policy).",
|
||||
)
|
||||
require_sandbox: bool = Field(
|
||||
default=False,
|
||||
description="Fail closed if bubblewrap sandboxing is unavailable/unusable for stateless sandbox tools.",
|
||||
)
|
||||
require_stateful_sandbox: bool = Field(
|
||||
default=False,
|
||||
description="Fail closed if bubblewrap/PID isolation is unavailable for stateful terminal tools (tmux).",
|
||||
)
|
||||
tool_batch_window_ms: int = Field(default=20, description="ToolExecutor batching window (ms)")
|
||||
tool_max_batch_size: int = Field(default=200, description="ToolExecutor maximum batch size")
|
||||
|
||||
# Nomad mode settings. TODO: split backend-specific settings into their own config classes.
|
||||
nomad_address: str = Field(default="http://localhost:4646", description="Nomad API address")
|
||||
sandbox_job_id: str = Field(default="atropos-sandbox-agent-env", description="Nomad job id for sandbox containers")
|
||||
sandbox_image: str = Field(default="atropos-sandbox:local", description="Docker image for sandbox containers")
|
||||
slots_per_container: int = Field(default=10, description="Nomad mode: slots per container")
|
||||
min_containers: int = Field(default=1, description="Nomad mode: minimum containers")
|
||||
max_containers: int = Field(default=10, description="Nomad mode: maximum containers")
|
||||
privileged: bool = Field(default=False, description="Nomad mode: run container privileged")
|
||||
acquire_timeout_s: float = Field(default=30.0, description="Slot acquisition timeout (seconds)")
|
||||
purge_job_on_start: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Nomad mode: stop/purge the sandbox job on startup. This is helpful in local dev and training runs "
|
||||
"to recover from previous crashes that leave the job in a restart backoff state."
|
||||
),
|
||||
)
|
||||
purge_job_on_shutdown: bool = Field(default=True, description="Nomad mode: stop/purge job on shutdown")
|
||||
|
||||
# Nomad driver selection (docker or singularity)
|
||||
driver: str = Field(
|
||||
default="docker",
|
||||
description="Nomad task driver: 'docker' (default) or 'singularity' (for HPC without sudo Docker)",
|
||||
)
|
||||
singularity_image: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Path to .sif file for Singularity driver (required if driver='singularity')",
|
||||
)
|
||||
|
||||
# Modal mode settings
|
||||
modal_app_name: str = Field(default="atropos-sandbox", description="Modal app name prefix")
|
||||
modal_image: str = Field(default="python:3.11", description="Modal: container image")
|
||||
modal_gpu: Optional[str] = Field(default=None, description="Modal: GPU type (None, 'T4', 'A10G', 'A100', 'H100')")
|
||||
modal_cpu: float = Field(default=1.0, description="Modal: CPU cores")
|
||||
modal_memory: int = Field(default=2048, description="Modal: memory in MB")
|
||||
modal_slots_per_sandbox: int = Field(default=10, description="Modal: slots per sandbox")
|
||||
modal_min_sandboxes: int = Field(default=1, description="Modal: minimum sandboxes")
|
||||
modal_max_sandboxes: int = Field(default=5, description="Modal: maximum sandboxes")
|
||||
modal_idle_timeout: int = Field(default=120, description="Modal: server-side idle timeout (seconds)")
|
||||
modal_max_lifetime: int = Field(default=3600, description="Modal: max sandbox lifetime (seconds)")
|
||||
modal_acquire_timeout: float = Field(default=60.0, description="Modal: slot acquisition timeout (seconds)")
|
||||
modal_execution_timeout: float = Field(default=30.0, description="Modal: default command execution timeout (seconds)")
|
||||
modal_secrets: str = Field(default="", description="Modal: comma-separated list of Modal Secret names")
|
||||
modal_env_vars: str = Field(default="", description="Modal: semicolon-separated KEY=VALUE pairs for env vars")
|
||||
modal_workspace_base: str = Field(default="/data", description="Modal: workspace base directory in sandbox")
|
||||
|
||||
# basic agent defaults
|
||||
agent_max_steps: int = Field(default=50, description="Max ReAct steps per trajectory")
|
||||
agent_temperature: float = Field(default=0.7, description="Sampling temperature")
|
||||
agent_max_tokens: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Max tokens per model response (default: let backend decide)",
|
||||
)
|
||||
agent_tool_delay_s: float = Field(default=0.0, description="Delay between tool calls (seconds)")
|
||||
|
||||
# tool selection
|
||||
enabled_toolsets: List[str] = Field(
|
||||
default_factory=lambda: ["default"],
|
||||
description="Toolsets to enable (Hermes-style grouping).",
|
||||
)
|
||||
disabled_toolsets: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Toolsets to disable (applied after enabled_toolsets).",
|
||||
)
|
||||
|
||||
# external ToolServer routing (Phase 4.5+)
|
||||
tool_server_url: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Base URL for external ToolServer (enables external tools).",
|
||||
)
|
||||
tool_server_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Bearer token for ToolServer auth (optional in dev).",
|
||||
)
|
||||
|
||||
AgentEnvConfigT = TypeVar("AgentEnvConfigT", bound="AgentEnvConfig")
|
||||
|
||||
|
||||
class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
|
||||
env_config_cls = AgentEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AgentEnvConfigT,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.config: AgentEnvConfigT = config
|
||||
|
||||
self.tools: ToolRegistry = self.build_tools()
|
||||
|
||||
self._backend: Optional[ToolBackend] = None
|
||||
self._tool_executor: Optional[ToolExecutor] = None
|
||||
self._tool_server_inprocess: bool = False
|
||||
self._trajectory_workspace_meta: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def build_tools(self) -> ToolRegistry:
"""Wrap the original Hermes-Agent ToolRegistry for use by the Atropos AgentEnv.

See the Hermes-Agent docs for the available toolsets and tools.
|
||||
"""
|
||||
return build_tool_registry(
|
||||
enabled_toolsets=self.config.enabled_toolsets or ["default"],
|
||||
disabled_toolsets=self.config.disabled_toolsets or None,
|
||||
tool_server_url=self.config.tool_server_url,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def build_task(self, item: Item) -> str:
|
||||
"""Return the user-facing task string for the agent."""
|
||||
|
||||
@abstractmethod
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
"""Return a scalar score for this trajectory."""
|
||||
|
||||
async def setup_trajectory_workspace(
|
||||
self,
|
||||
item: Item,
|
||||
*,
|
||||
trajectory_id: str,
|
||||
exec_tool: Callable[["ToolCall"], Awaitable["ToolResult"]],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Optional hook: prepare the sandbox workspace before the agent starts.
|
||||
|
||||
Examples:
|
||||
- clone a repo and checkout a commit
|
||||
- write fixture files (e.g. images) for external-tool demos
|
||||
- pre-install dependencies
|
||||
|
||||
Default: no-op.
|
||||
"""
|
||||
_ = (item, trajectory_id, exec_tool)
|
||||
return {}
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
final_response: str,
|
||||
*,
|
||||
trajectory_id: str,
|
||||
exec_tool: Callable[["ToolCall"], Awaitable["ToolResult"]],
|
||||
agent_result: Optional[AgentResult] = None,
|
||||
workspace_meta: Optional[Dict[str, Any]] = None,
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
"""
|
||||
Optional hook: run in-sandbox verification before scoring.
|
||||
|
||||
Many agent envs need to execute verification inside the same trajectory
|
||||
workspace (e.g. pytest) before releasing/resetting the slot.
|
||||
|
||||
Default: calls `score_trajectory()` and returns empty metadata.
|
||||
"""
|
||||
_ = (trajectory_id, exec_tool, agent_result, workspace_meta) # default ignores in-workspace verification
|
||||
score = await self.score_trajectory(item, final_response)
|
||||
return score, {}
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
return AgentConfig(
|
||||
max_steps=self.config.agent_max_steps,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.agent_max_tokens,
|
||||
tool_delay_s=self.config.agent_tool_delay_s,
|
||||
)
|
||||
|
||||
async def setup(self) -> None:
|
||||
print(f"[AgentEnv] setup(): starting tool backend ({self.config.tool_pool_mode})", flush=True)
|
||||
await self._start_tool_backend()
|
||||
print("[AgentEnv] setup(): configuring server concurrency", flush=True)
|
||||
self._configure_server_concurrency()
|
||||
print("[AgentEnv] setup(): running env-specific setup_agent_env()", flush=True)
|
||||
await self.setup_agent_env()
|
||||
print("[AgentEnv] setup(): done", flush=True)
|
||||
|
||||
def _configure_server_concurrency(self) -> None:
|
||||
"""
|
||||
Ensure the LLM server concurrency isn't accidentally capped below `group_size`.
|
||||
|
||||
In `BaseEnv process` mode, groups are collected concurrently and if the underlying
|
||||
ServerManager/OpenAIServer semaphore is left at 1, we serialize inference even
|
||||
when `--env.group_size` is > 1.
|
||||
"""
|
||||
desired = int(getattr(self.config, "group_size", 1) or 1)
|
||||
if desired <= 1:
|
||||
return
|
||||
|
||||
servers = getattr(self.server, "servers", None)
|
||||
if not isinstance(servers, list) or not servers:
|
||||
return
|
||||
|
||||
for s in servers:
|
||||
sem = getattr(s, "sem", None)
|
||||
eval_sem = getattr(s, "eval_sem", None)
|
||||
# Only increase; never shrink.
|
||||
if sem is not None and getattr(sem, "max_val", 0) < desired:
|
||||
s.sem = AsyncSemWithAdaptiveWeight(desired)
|
||||
if hasattr(s, "config") and hasattr(s.config, "num_max_requests_at_once"):
|
||||
s.config.num_max_requests_at_once = desired
|
||||
if eval_sem is not None and getattr(eval_sem, "max_val", 0) < desired:
|
||||
s.eval_sem = AsyncSemWithAdaptiveWeight(desired)
|
||||
if hasattr(s, "config") and hasattr(s.config, "num_requests_for_eval"):
|
||||
s.config.num_requests_for_eval = desired
|
||||
|
||||
@abstractmethod
|
||||
async def setup_agent_env(self) -> None:
|
||||
"""Subclass hook for env-specific setup."""
|
||||
|
||||
async def evaluate(self, *args, **kwargs): # noqa: ARG002
|
||||
"""
|
||||
Default eval hook (no-op).
|
||||
|
||||
Atropos BaseEnv requires an `evaluate()` implementation. Many agent envs
|
||||
won't have a meaningful evaluation path during early PoC work; they can
|
||||
override this when needed.
|
||||
"""
|
||||
return {}
|
||||
|
||||
async def env_manager(self):
|
||||
try:
|
||||
return await super().env_manager()
|
||||
finally:
|
||||
await self.shutdown_tool_backend()
|
||||
|
||||
async def process_manager(self):
|
||||
try:
|
||||
return await super().process_manager()
|
||||
finally:
|
||||
await self.shutdown_tool_backend()
|
||||
|
||||
async def _start_tool_backend(self) -> None:
|
||||
if self._tool_executor is not None:
|
||||
return
|
||||
|
||||
tool_server_url = self.config.tool_server_url
|
||||
tool_server_client = None
|
||||
if tool_server_url == "inprocess":
|
||||
import httpx
|
||||
from ..api.tool_server import app as tool_server_app
|
||||
|
||||
await tool_server_app.router.startup()
|
||||
tool_server_client = httpx.AsyncClient(
|
||||
transport=httpx.ASGITransport(app=tool_server_app),
|
||||
base_url="http://toolserver",
|
||||
)
|
||||
tool_server_url = "http://toolserver"
|
||||
self._tool_server_inprocess = True
|
||||
|
||||
backend = create_tool_backend(self.config)
|
||||
await backend.start()
|
||||
|
||||
executor = ToolExecutor(
|
||||
backend=backend,
|
||||
tools=self.tools,
|
||||
config=ToolExecutorConfig(
|
||||
batch_window_ms=self.config.tool_batch_window_ms,
|
||||
max_batch_size=self.config.tool_max_batch_size,
|
||||
allow_network=self.config.allow_network,
|
||||
require_sandbox=self.config.require_sandbox,
|
||||
require_stateful_sandbox=self.config.require_stateful_sandbox,
|
||||
tool_server_url=tool_server_url,
|
||||
tool_server_token=self.config.tool_server_token,
|
||||
),
|
||||
)
|
||||
await executor.start()
|
||||
if tool_server_client is not None:
|
||||
executor._tool_server_client = tool_server_client # type: ignore[attr-defined]
|
||||
|
||||
self._backend = backend
|
||||
self._tool_executor = executor
|
||||
|
||||
async def shutdown_tool_backend(self) -> None:
|
||||
executor = self._tool_executor
|
||||
backend = self._backend
|
||||
inprocess_tool_server = self._tool_server_inprocess
|
||||
self._tool_executor = None
|
||||
self._backend = None
|
||||
self._tool_server_inprocess = False
|
||||
|
||||
if executor is not None:
|
||||
await executor.close()
|
||||
if backend is not None:
|
||||
await backend.stop(purge=bool(self.config.purge_job_on_shutdown))
|
||||
if inprocess_tool_server:
|
||||
from ..api.tool_server import app as tool_server_app
|
||||
|
||||
await tool_server_app.router.shutdown()
|
||||
|
||||
async def collect_trajectory(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[ScoredDataItem], List[Item]]:
|
||||
if self._tool_executor is None:
|
||||
raise RuntimeError("Tool backend not started")
|
||||
|
||||
trajectory_id = str(uuid.uuid4())
|
||||
t0 = time.perf_counter()
|
||||
print(f"[AgentEnv] collect_trajectory(): tid={trajectory_id} start", flush=True)
|
||||
task = self.build_task(item)
|
||||
agent_config = self.build_agent_config(item)
|
||||
if os.getenv("ATROPOS_DEBUG_PRINT_TASK") == "1":
|
||||
print(f"Starting trajectory {trajectory_id} with task: {task}", flush=True)
|
||||
else:
|
||||
# Avoid printing the full task prompt by default (can be huge/noisy).
|
||||
one_line = " ".join(str(task).splitlines()).strip()
|
||||
preview = one_line[:240] + ("…" if len(one_line) > 240 else "")
|
||||
print(f"Starting trajectory {trajectory_id} (task preview): {preview}", flush=True)
|
||||
|
||||
async def _exec(call):
|
||||
return await self._tool_executor.execute(trajectory_id, call)
|
||||
|
||||
agent = AtroposAgent(
|
||||
server=self.server,
|
||||
tokenizer=self.tokenizer,
|
||||
tools=self.tools,
|
||||
config=agent_config,
|
||||
execute_tool=_exec,
|
||||
)
|
||||
|
||||
try:
|
||||
print(f"[AgentEnv] tid={trajectory_id} setup_trajectory_workspace() start", flush=True)
|
||||
workspace_meta = await self.setup_trajectory_workspace(item, trajectory_id=trajectory_id, exec_tool=_exec)
|
||||
if not isinstance(workspace_meta, dict):
|
||||
workspace_meta = {}
|
||||
self._trajectory_workspace_meta[trajectory_id] = workspace_meta
|
||||
print(
|
||||
f"[AgentEnv] tid={trajectory_id} setup_trajectory_workspace() done in {time.perf_counter() - t0:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
print(f"[AgentEnv] tid={trajectory_id} agent.run() start", flush=True)
|
||||
result = await agent.run(task)
|
||||
print(
|
||||
f"[AgentEnv] tid={trajectory_id} agent.run() done in {time.perf_counter() - t0:.2f}s "
|
||||
f"success={result.success} tool_calls={result.total_tool_calls}",
|
||||
flush=True,
|
||||
)
|
||||
if not result.success or result.trajectory_data is None:
|
||||
# Do not trigger BaseEnv retries for agent failures.
|
||||
# Record the trajectory with score 0.0 so training/eval can see the failure mode.
|
||||
messages = [{"role": "system", "content": agent._build_system_prompt()}] # noqa: SLF001
|
||||
messages.append({"role": "user", "content": task})
|
||||
for step in result.steps:
|
||||
messages.append({"role": "assistant", "content": step.assistant_message})
|
||||
if step.tool_results:
|
||||
tool_text = "\n".join(r.to_xml() for r in step.tool_results)
|
||||
messages.append({"role": "user", "content": tool_text})
|
||||
|
||||
scored: ScoredDataItem = {
|
||||
"tokens": (result.trajectory_data.tokens if result.trajectory_data else []),
|
||||
"masks": (result.trajectory_data.masked_tokens if result.trajectory_data else []),
|
||||
"scores": 0.0,
|
||||
}
|
||||
if result.trajectory_data is not None:
|
||||
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
|
||||
if getattr(result.trajectory_data, "metadata", None):
|
||||
scored["overrides"] = {"managed_metadata": result.trajectory_data.metadata}
|
||||
if self.config.include_messages:
|
||||
# Record a final failure marker as a user-side tool_response-like block so it survives templates.
|
||||
import json
|
||||
|
||||
err = result.error or "agent_failed"
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"<tool_response>{json.dumps({'success': False, 'error': err})}</tool_response>",
|
||||
}
|
||||
)
|
||||
scored["messages"] = messages
|
||||
return scored, []
|
||||
|
||||
print(f"[AgentEnv] tid={trajectory_id} verify_and_score_trajectory() start", flush=True)
|
||||
score, score_metadata = await self.verify_and_score_trajectory(
|
||||
item,
|
||||
result.final_response,
|
||||
trajectory_id=trajectory_id,
|
||||
exec_tool=_exec,
|
||||
agent_result=result,
|
||||
workspace_meta=workspace_meta,
|
||||
)
|
||||
print(
|
||||
f"[AgentEnv] tid={trajectory_id} verify_and_score_trajectory() done in {time.perf_counter() - t0:.2f}s "
|
||||
f"score={score}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": agent._build_system_prompt()}] # noqa: SLF001
|
||||
messages.append({"role": "user", "content": task})
|
||||
for step in result.steps:
|
||||
messages.append({"role": "assistant", "content": step.assistant_message})
|
||||
if step.tool_results:
|
||||
tool_text = "\n".join(r.to_xml() for r in step.tool_results)
|
||||
messages.append({"role": "user", "content": tool_text})
|
||||
|
||||
# Optional: allow env verification to attach additional messages (e.g. install logs).
|
||||
if self.config.include_messages and isinstance(score_metadata, dict):
|
||||
extra = score_metadata.get("verification_messages")
|
||||
if isinstance(extra, list):
|
||||
for m in extra:
|
||||
if isinstance(m, dict) and isinstance(m.get("role"), str) and isinstance(m.get("content"), str):
|
||||
messages.append({"role": m["role"], "content": m["content"]})
|
||||
|
||||
scored: ScoredDataItem = {
|
||||
"tokens": result.trajectory_data.tokens,
|
||||
"masks": result.trajectory_data.masked_tokens,
|
||||
"scores": score,
|
||||
}
|
||||
# Atroposlib expects policy logprobs at the *group* level under `inference_logprobs`.
|
||||
# We stash per-item values here and lift them into the group in `collect_trajectories()`.
|
||||
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
|
||||
if getattr(result.trajectory_data, "metadata", None):
|
||||
scored["overrides"] = {"managed_metadata": result.trajectory_data.metadata}
|
||||
if self.config.include_messages:
|
||||
scored["messages"] = messages
|
||||
|
||||
return scored, []
|
||||
finally:
|
||||
self._trajectory_workspace_meta.pop(trajectory_id, None)
|
||||
print(f"[AgentEnv] tid={trajectory_id} release_trajectory(reset_workspace=True)", flush=True)
|
||||
await self._tool_executor.release_trajectory(trajectory_id, reset_workspace=True)
|
||||
print(f"[AgentEnv] collect_trajectory(): tid={trajectory_id} done in {time.perf_counter() - t0:.2f}s", flush=True)
|
||||
|
||||
async def collect_trajectories(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
|
||||
tasks = [self.collect_trajectory(item) for _ in range(self.config.group_size)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
backlog: List[Item] = []
|
||||
items: List[ScoredDataItem] = []
|
||||
for scored, b in results:
|
||||
backlog.extend(b)
|
||||
if scored is not None:
|
||||
items.append(scored)
|
||||
|
||||
if len(items) != self.config.group_size:
|
||||
return None, backlog
|
||||
|
||||
group: ScoredDataGroup = ScoredDataGroup(
|
||||
tokens=[],
|
||||
masks=[],
|
||||
scores=[],
|
||||
advantages=[],
|
||||
ref_logprobs=[],
|
||||
messages=[] if self.config.include_messages else None,
|
||||
inference_logprobs=[],
|
||||
group_overrides={},
|
||||
overrides=[],
|
||||
images=[],
|
||||
generation_params=None,
|
||||
)
|
||||
|
||||
for it in items:
|
||||
group["tokens"].append(it["tokens"])
|
||||
group["masks"].append(it["masks"])
|
||||
group["scores"].append(it["scores"])
|
||||
# policy logprobs (for PPO/GRPO training) if present
|
||||
lp = it.get("inference_logprobs") # type: ignore[typeddict-item]
|
||||
if lp is not None:
|
||||
group["inference_logprobs"].append(lp)
|
||||
group["overrides"].append(it.get("overrides") or {}) # type: ignore[typeddict-item]
|
||||
if group.get("messages") is not None and it.get("messages") is not None:
|
||||
group["messages"].append(it["messages"])
|
||||
|
||||
return group, backlog
|
||||
|
||||
async def run_agent(self, task: str, *, trajectory_id: Optional[str] = None) -> Tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
Run the AtroposAgent on a single task and return (final_response, debug).
|
||||
|
||||
This is a helper intended for simple environments and tests.
|
||||
"""
|
||||
if self._tool_executor is None:
|
||||
raise RuntimeError("Tool backend not started")
|
||||
|
||||
tid = trajectory_id or str(uuid.uuid4())
|
||||
|
||||
async def _exec(call):
|
||||
return await self._tool_executor.execute(tid, call)
|
||||
|
||||
agent = AtroposAgent(
|
||||
server=self.server,
|
||||
tokenizer=self.tokenizer,
|
||||
tools=self.tools,
|
||||
config=AgentConfig(
|
||||
max_steps=self.config.agent_max_steps,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.agent_max_tokens,
|
||||
),
|
||||
execute_tool=_exec,
|
||||
)
|
||||
try:
    result = await agent.run(task)
finally:
    # Release the slot even if the agent raises, mirroring collect_trajectory().
    await self._tool_executor.release_trajectory(tid, reset_workspace=True)
return result.final_response, {"success": result.success, "error": result.error, "tool_calls": result.total_tool_calls}
|
||||
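The grouping step in `collect_trajectories()` gathers `group_size` trajectories concurrently, drops the whole group if any trajectory is missing, and otherwise lifts per-item fields into parallel lists. A simplified sketch follows, using plain dicts instead of `ScoredDataItem` / `ScoredDataGroup`; `fake_trajectory` and `collect_group` are hypothetical names, and the token values are made up for the demo.

```python
import asyncio
from typing import Any, Dict, List, Optional


async def fake_trajectory(i: int) -> Optional[Dict[str, Any]]:
    """Hypothetical stand-in for collect_trajectory(); returns one scored item."""
    await asyncio.sleep(0)  # yield to the event loop like a real rollout would
    return {
        "tokens": [i, i + 1],
        "masks": [-100, i + 1],
        "scores": float(i),
        "inference_logprobs": [0.0, -0.5],
    }


async def collect_group(group_size: int) -> Optional[Dict[str, List[Any]]]:
    # asyncio.gather returns results in the order the awaitables were passed.
    items = await asyncio.gather(*(fake_trajectory(i) for i in range(group_size)))
    kept = [it for it in items if it is not None]
    if len(kept) != group_size:
        return None  # incomplete groups are dropped entirely

    group: Dict[str, List[Any]] = {
        "tokens": [], "masks": [], "scores": [], "inference_logprobs": [],
    }
    for it in kept:
        for key in group:
            group[key].append(it[key])  # per-item value becomes one list entry
    return group


group = asyncio.run(collect_group(3))
print(group["scores"])  # [0.0, 1.0, 2.0]
```

Keeping every list the same length and in the same order lets a trainer index position `k` across `tokens`, `masks`, `scores`, and `inference_logprobs` to recover one trajectory.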
171
atropos/envs/hermes_compat_test_env.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
Hermes-Agent + Atropos (Nomad sandbox) compatibility smoke environment.
|
||||
|
||||
This environment is intended to validate, end-to-end:
|
||||
BaseEnv.process -> AgentEnv -> ToolExecutor (batched) -> Nomad SlotPool -> sandbox_server
|
||||
|
||||
It forces the model to use a sandbox tool by asking it to run a command that
|
||||
generates a high-entropy token inside the sandbox, then repeat it exactly.
|
||||
|
||||
Run (process mode):
|
||||
uv run python -m atropos.envs.hermes_compat_test_env process --env.use_wandb false --env.total_steps 2 --env.group_size 1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, Item
|
||||
|
||||
from ..agent import AgentConfig, AgentResult
|
||||
from ..tools import ToolCall
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def _forced_tool_item() -> Item:
|
||||
# Use double quotes in the shell command and show JSON escaping explicitly.
|
||||
# This avoids invalid JSON escapes like `\\'` (not valid JSON) that some models produce.
|
||||
cmd = 'python -c "import secrets; print(secrets.token_hex(16))"'
|
||||
return {
|
||||
"command": cmd,
|
||||
"prompt": (
|
||||
"You are acting as an agent inside a sandboxed environment.\n"
|
||||
"You MUST use the terminal tool to execute commands.\n"
|
||||
"Run this exact command:\n"
|
||||
f"{cmd}\n"
|
||||
"When you call the tool, use valid JSON inside <tool_call>. Example:\n"
|
||||
'<tool_call>{"name": "terminal", "arguments": {"command": '
|
||||
'"python -c \\"import secrets; print(secrets.token_hex(16))\\""}}'
|
||||
"</tool_call>\n"
|
||||
"Then respond with EXACTLY what it printed (the hex token) and nothing else.\n"
|
||||
"Do not guess. Do not explain."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class HermesCompatTestEnvConfig(AgentEnvConfig):
|
||||
server_base_url: str = Field(
|
||||
default="http://127.0.0.1:8080",
|
||||
description="Base URL for an OpenAI-compatible chat server (without /v1).",
|
||||
)
|
||||
server_model: str = Field(default="hermes-4-36b", description="Model name")
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||
|
||||
|
||||
class HermesCompatTestEnv(AgentEnv[HermesCompatTestEnvConfig]):
|
||||
name = "hermes_compat_test_env"
|
||||
env_config_cls = HermesCompatTestEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: HermesCompatTestEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._iter = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[HermesCompatTestEnvConfig, List[APIServerConfig]]:
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||
|
||||
env_config = HermesCompatTestEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
include_messages=True,
|
||||
ensure_scores_are_not_same=False,
|
||||
total_steps=2,
|
||||
batch_size=1,
|
||||
server_base_url=base_url,
|
||||
server_model=model,
|
||||
# Tooling: sandbox-only terminal.
|
||||
enabled_toolsets=["terminal"],
|
||||
disabled_toolsets=[],
|
||||
# Default to Nomad sandboxing; users can override via --env.* args.
|
||||
sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
|
||||
# In local dev it's common for a previous crash to leave the job in backoff.
|
||||
purge_job_on_start=True,
|
||||
purge_job_on_shutdown=True,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=base_url.rstrip("/") if base_url.rstrip("/").endswith("/v1") else f"{base_url.rstrip('/')}/v1",
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=1,
|
||||
num_requests_for_eval=1,
|
||||
timeout=120,
|
||||
)
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup_agent_env(self) -> None:
|
||||
return None
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
self._iter += 1
|
||||
return _forced_tool_item()
|
||||
|
||||
def build_task(self, item: Item) -> str:
|
||||
return str(item.get("prompt") or "")
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
# Avoid imposing max_tokens by default; tool-tag responses can be long for some models.
|
||||
return AgentConfig(
|
||||
max_steps=min(8, int(self.config.agent_max_steps)),
|
||||
temperature=0.2,
|
||||
max_tokens=None,
|
||||
)
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
# Scoring happens in verify_and_score_trajectory so we can inspect tool results.
|
||||
_ = (item, final_response)
|
||||
return 0.0
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
final_response: str,
|
||||
*,
|
||||
trajectory_id: str, # noqa: ARG002
|
||||
exec_tool, # noqa: ARG002
|
||||
agent_result: AgentResult | None = None,
|
||||
workspace_meta: Dict[str, Any] | None = None, # noqa: ARG002
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
if agent_result is None:
|
||||
return 0.0, {"error": "Missing agent_result"}
|
||||
|
||||
observed: str = ""
|
||||
tool_ok = False
|
||||
for step in agent_result.steps:
|
||||
for res in step.tool_results:
|
||||
if not res.success:
|
||||
return 0.0, {"error": res.error, "output": res.output}
|
||||
out = (res.output or "").strip()
|
||||
if out:
|
||||
observed = out.splitlines()[-1].strip()
|
||||
tool_ok = True
|
||||
|
||||
final = (final_response or "").strip()
|
||||
score = 1.0 if tool_ok and agent_result.total_tool_calls > 0 and observed and final == observed else 0.0
|
||||
return score, {"observed": observed, "tool_calls": agent_result.total_tool_calls, "command": item.get("command")}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
HermesCompatTestEnv.cli()
|
||||
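Review note on the scoring rule in `verify_and_score_trajectory` above: the reward is 1.0 only when a tool call produced output and the final response repeats the last line of that output exactly. A minimal standalone sketch of that rule follows. The function names are hypothetical, and the sketch takes plain strings instead of `AgentResult` steps, so it omits the early return on a failed tool call and the `total_tool_calls` check.

```python
from typing import List


def last_nonempty_line(output: str) -> str:
    """Return the last line of the stripped tool output, or "" if there is none."""
    text = (output or "").strip()
    return text.splitlines()[-1].strip() if text else ""


def score_echo(tool_outputs: List[str], final_response: str) -> float:
    """1.0 only if some tool produced output and the answer repeats its last line."""
    observed = ""
    for out in tool_outputs:
        line = last_nonempty_line(out)
        if line:
            observed = line  # a later tool call overrides an earlier one
    final = (final_response or "").strip()
    return 1.0 if observed and final == observed else 0.0
```

Because the comparison is an exact string match after stripping whitespace, a response such as `abc123.` scores 0.0 even though it contains the token.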
172
atropos/envs/sandbox_terminal_smoke_env.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
Nomad sandbox terminal smoke environment (training-oriented).
|
||||
|
||||
Validates, end-to-end:
|
||||
BaseEnv.process -> AgentEnv -> ToolExecutor (batched) -> Nomad SlotPool -> sandbox_server
|
||||
|
||||
It forces the model to use a sandbox tool by asking it to run a command that
|
||||
generates a high-entropy token inside the sandbox, then repeat it exactly.
|
||||
|
||||
Run (process mode):
|
||||
uv run python -m atropos.envs.sandbox_terminal_smoke_env process --env.use_wandb false --env.total_steps 2 --env.group_size 1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, Item
|
||||
|
||||
from ..agent import AgentConfig, AgentResult
|
||||
from ..tools import ToolCall
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
STRICT_TOOLCALL_SYSTEM_PROMPT = None
|
||||
|
||||
|
||||
def _forced_tool_item() -> Item:
|
||||
# Use double quotes in the shell command and show JSON escaping explicitly.
|
||||
# This avoids invalid JSON escapes like `\\'` (not valid JSON) that some models produce.
|
||||
cmd = 'python -c "import secrets; print(secrets.token_hex(16))"'
|
||||
return {
|
||||
"command": cmd,
|
||||
"prompt": (
|
||||
"You MUST use the terminal tool.\n"
|
||||
"Run this exact command:\n"
|
||||
f"{cmd}\n"
|
||||
"When you call the tool, use valid JSON inside <tool_call>. Example:\n"
|
||||
'<tool_call>{"name": "terminal", "arguments": {"command": '
|
||||
'"python -c \\"import secrets; print(secrets.token_hex(16))\\""}}'
|
||||
"</tool_call>\n"
|
||||
"Then respond with EXACTLY what it printed (the hex token) and nothing else.\n"
|
||||
"Do not guess. Do not explain."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class SandboxTerminalSmokeEnvConfig(AgentEnvConfig):
|
||||
server_base_url: str = Field(
|
||||
default="http://127.0.0.1:8080",
|
||||
description="Base URL for an OpenAI-compatible chat server (without /v1).",
|
||||
)
|
||||
server_model: str = Field(default="hermes-4-36b", description="Model name")
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||
|
||||
|
||||
class SandboxTerminalSmokeEnv(AgentEnv[SandboxTerminalSmokeEnvConfig]):
|
||||
name = "sandbox_terminal_smoke_env"
|
||||
env_config_cls = SandboxTerminalSmokeEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SandboxTerminalSmokeEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._iter = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[SandboxTerminalSmokeEnvConfig, List[APIServerConfig]]:
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||
|
||||
env_config = SandboxTerminalSmokeEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
include_messages=True,
|
||||
ensure_scores_are_not_same=False,
|
||||
total_steps=2,
|
||||
batch_size=1,
|
||||
server_base_url=base_url,
|
||||
server_model=model,
|
||||
# Tooling: sandbox-only terminal.
|
||||
enabled_toolsets=["terminal"],
|
||||
disabled_toolsets=[],
|
||||
# Default to Nomad sandboxing; users can override via --env.* args.
|
||||
sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
|
||||
purge_job_on_start=True,
|
||||
purge_job_on_shutdown=True,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=base_url.rstrip("/") if base_url.rstrip("/").endswith("/v1") else f"{base_url.rstrip('/')}/v1",
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=1,
|
||||
num_requests_for_eval=1,
|
||||
timeout=120,
|
||||
)
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup_agent_env(self) -> None:
|
||||
return None
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
self._iter += 1
|
||||
return _forced_tool_item()
|
||||
|
||||
def build_task(self, item: Item) -> str:
|
||||
return str(item.get("prompt") or "")
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
# Avoid imposing max_tokens by default; tool-tag responses can be long for some models.
|
||||
return AgentConfig(
|
||||
max_steps=min(8, int(self.config.agent_max_steps)),
|
||||
temperature=0.2,
|
||||
max_tokens=None,
|
||||
system_prompt=STRICT_TOOLCALL_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
# Scoring happens in verify_and_score_trajectory so we can inspect tool results.
|
||||
_ = (item, final_response)
|
||||
return 0.0
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
final_response: str,
|
||||
*,
|
||||
trajectory_id: str, # noqa: ARG002
|
||||
exec_tool, # noqa: ARG002
|
||||
agent_result: AgentResult | None = None,
|
||||
workspace_meta: Dict[str, Any] | None = None, # noqa: ARG002
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
if agent_result is None:
|
||||
return 0.0, {"error": "Missing agent_result"}
|
||||
|
||||
observed: str = ""
|
||||
tool_ok = False
|
||||
for step in agent_result.steps:
|
||||
for res in step.tool_results:
|
||||
if not res.success:
|
||||
return 0.0, {"error": res.error, "output": res.output}
|
||||
out = (res.output or "").strip()
|
||||
if out:
|
||||
observed = out.splitlines()[-1].strip()
|
||||
tool_ok = True
|
||||
|
||||
final = (final_response or "").strip()
|
||||
score = 1.0 if tool_ok and agent_result.total_tool_calls > 0 and observed and final == observed else 0.0
|
||||
return score, {"observed": observed, "tool_calls": agent_result.total_tool_calls, "command": item.get("command")}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
SandboxTerminalSmokeEnv.cli()
|
||||
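Review note on the prompt built in `_forced_tool_item` above: the `<tool_call>` example is escaped by hand, which makes it easy to miscount backslashes. One alternative, sketched here under the assumption that the example only needs to be valid JSON, is to let `json.dumps` produce the escaping. `tool_call_example` is a hypothetical helper, not an existing function in this repository.

```python
import json

cmd = 'python -c "import secrets; print(secrets.token_hex(16))"'


def tool_call_example(command: str) -> str:
    """Render a <tool_call> example whose body is valid JSON by construction."""
    payload = {"name": "terminal", "arguments": {"command": command}}
    return f"<tool_call>{json.dumps(payload)}</tool_call>"


def parse_tool_call(text: str) -> dict:
    """Strip the surrounding tags and parse the JSON body back into a dict."""
    body = text[len("<tool_call>"):-len("</tool_call>")]
    return json.loads(body)


example = tool_call_example(cmd)
```

Round-tripping `example` through `parse_tool_call` recovers the original command, which is a cheap check that the text shown to the model is parseable.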
418
atropos/envs/swe_smith_oracle_env.py
Normal file
@@ -0,0 +1,418 @@
|
||||
"""
|
||||
SWE-smith-oracle environment.
|
||||
|
||||
This environment is intentionally minimal:
|
||||
- prepares a sandbox workspace by cloning a public GitHub repo at `base_commit`
|
||||
- runs an AtroposAgent tool loop to apply a fix
|
||||
- verifies by running pytest nodeids from the dataset (reward = pass/fail)
|
||||
- Python only (no multi-language support currently, need to properly build & add to dropbox)
|
||||
- TODO: Get the other nonpython sandboxes up and running, then add a config knob to switch between them per row
|
||||
- oh and add to dockerhub
|
||||
|
||||
Dataset: NousResearch/SWE-smith-oracle (train; does NOT use SWE-bench eval set).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, Item
|
||||
|
||||
from ..agent import AgentConfig
|
||||
from ..tools import ToolCall
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
|
||||
class SweSmithOracleEnvConfig(AgentEnvConfig):
|
||||
dataset_name: str = Field(default="NousResearch/SWE-smith-oracle")
|
||||
dataset_split: str = Field(default="train")
|
||||
max_items: int = Field(default=0, description="0 = no limit")
|
||||
shuffle: bool = Field(default=True)
|
||||
seed: int = Field(default=0)
|
||||
|
||||
python_only: bool = Field(default=True, description="Filter to Python-evaluable rows")
|
||||
score_include_fail_to_pass: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"If true (default), score tests on PASS_TO_PASS ∪ FAIL_TO_PASS. "
|
||||
"Disable to only run PASS_TO_PASS (faster but weaker signal)."
|
||||
),
|
||||
)
|
||||
|
||||
prompt_mode: str = Field(
|
||||
default="problem_statement",
|
||||
description="Task prompt content: 'problem_statement' (fast) or 'problem_statement+text' (slower, includes dataset 'text').",
|
||||
)
|
||||
|
||||
repo_base_url: str = Field(default="https://github.com", description="Base URL for repo cloning")
|
||||
install_timeout_s: float = Field(default=600.0)
|
||||
test_timeout_s: float = Field(default=600.0)
|
||||
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||
|
||||
|
||||
class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
|
||||
"""
|
||||
SWE-smith-oracle AgentEnv.
|
||||
|
||||
This is designed for benchmarking multiplexed slot execution vs naive container-per-trajectory.
|
||||
"""
|
||||
|
||||
name = "swe_smith_oracle_env"
|
||||
env_config_cls = SweSmithOracleEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SweSmithOracleEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._dataset = None
|
||||
self._indices: List[int] = []
|
||||
self._cursor = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[SweSmithOracleEnvConfig, List[APIServerConfig]]:
|
||||
# Defaults for running the env via CLI in offline `process` mode.
|
||||
# Override via env vars or `--env.*` flags as needed.
|
||||
base_url_raw = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
base_url = base_url_raw.rstrip("/")
|
||||
if not base_url.endswith("/v1"):
|
||||
base_url = f"{base_url}/v1"
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||
|
||||
env_config = SweSmithOracleEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1,
|
||||
batch_size=1,
|
||||
steps_per_eval=1,
|
||||
max_token_length=8192,
|
||||
inference_weight=1.0,
|
||||
wandb_name="swe_smith_oracle",
|
||||
enabled_toolsets=["terminal"],
|
||||
disabled_toolsets=[],
|
||||
sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
|
||||
purge_job_on_start=True,
|
||||
purge_job_on_shutdown=True,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=1,
|
||||
num_requests_for_eval=1,
|
||||
timeout=int(os.getenv("ATROPOS_SERVER_TIMEOUT_S") or "300"),
|
||||
),
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
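Review note on `config_init` above: the server settings come from a chain of environment variables, and the base URL is normalized so that it ends in `/v1` exactly once. The sketch below restates both steps as small helpers. `first_set` and `normalize_v1_base_url` are hypothetical names introduced for illustration; the variable names in the usage are the ones read by the code above.

```python
import os
from typing import Mapping, Optional, Sequence


def first_set(names: Sequence[str], default: str, env: Optional[Mapping[str, str]] = None) -> str:
    """Return the first non-empty variable among `names`, else `default`."""
    source = os.environ if env is None else env
    for name in names:
        value = source.get(name)
        if value:  # empty strings count as unset, like `os.getenv(x) or ...`
            return value
    return default


def normalize_v1_base_url(raw: str) -> str:
    """Drop trailing slashes and append /v1 unless it is already present."""
    base = raw.rstrip("/")
    return base if base.endswith("/v1") else f"{base}/v1"


fake_env = {"ATROPOS_SERVER_BASE_URL": "", "OPENAI_BASE_URL": "http://host:9000/v1/"}
chosen = first_set(
    ["ATROPOS_SERVER_BASE_URL", "OPENAI_BASE_URL", "LLM_BASE_URL"],
    default="http://127.0.0.1:8080",
    env=fake_env,
)
```

Passing `env` explicitly keeps the helper testable without mutating the process environment.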
|
||||
async def setup_agent_env(self) -> None:
|
||||
from datasets import load_dataset
|
||||
|
||||
t0 = time.perf_counter()
|
||||
print(
|
||||
f"[SweSmithOracleEnv] loading dataset {self.config.dataset_name}:{self.config.dataset_split} "
|
||||
f"(python_only={self.config.python_only}, max_items={self.config.max_items or 'all'})",
|
||||
flush=True,
|
||||
)
|
||||
ds = load_dataset(self.config.dataset_name, split=self.config.dataset_split)
|
||||
self._dataset = ds
|
||||
|
||||
indices: List[int] = []
|
||||
for idx in range(len(ds)):
|
||||
row = ds[idx]
|
||||
if self.config.python_only and not self._is_python_row(row):
|
||||
continue
|
||||
indices.append(idx)
|
||||
|
||||
if self.config.shuffle:
|
||||
rnd = random.Random(self.config.seed)
|
||||
rnd.shuffle(indices)
|
||||
|
||||
if self.config.max_items and self.config.max_items > 0:
|
||||
indices = indices[: self.config.max_items]
|
||||
|
||||
self._indices = indices
|
||||
self._cursor = 0
|
||||
|
||||
print(
|
||||
f"[SweSmithOracleEnv] loaded {len(self._indices)} items from {self.config.dataset_name}:{self.config.dataset_split} "
|
||||
f"in {time.perf_counter() - t0:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
def _is_python_row(self, row: Dict[str, Any]) -> bool:
|
||||
nodeids = row.get("PASS_TO_PASS")
|
||||
if not isinstance(nodeids, list) or not nodeids:
|
||||
return False
|
||||
for nid in nodeids:
|
||||
if not isinstance(nid, str) or ".py::" not in nid:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
print(f"[SweSmithOracleEnv] get_next_item() cursor={self._cursor}/{len(self._indices)}", flush=True)
|
||||
if not self._dataset or not self._indices:
|
||||
raise RuntimeError("Dataset not initialized or empty after filtering (did setup_agent_env() run?)")
|
||||
if self._cursor >= len(self._indices):
|
||||
self._cursor = 0
|
||||
idx = self._indices[self._cursor]
|
||||
self._cursor += 1
|
||||
return dict(self._dataset[idx])
|
||||
|
||||
def _repo_name(self, item: Item) -> str:
|
||||
repo = item.get("repo") or ""
|
||||
if isinstance(repo, str) and "/" in repo:
|
||||
return repo.split("/")[-1]
|
||||
return "repo"
|
||||
|
||||
def build_task(self, item: Item) -> str:
|
||||
repo = item.get("repo") or ""
|
||||
base_commit = item.get("base_commit") or ""
|
||||
problem = str(item.get("problem_statement") or "")
|
||||
context = str(item.get("text") or "")
|
||||
|
||||
nodeids = self._tests_for_item(item)
|
||||
tests_list = "\n".join(f"- {t}" for t in nodeids)
|
||||
|
||||
repo_dir = self._repo_name(item)
|
||||
|
||||
tests_block = (
|
||||
"Run these tests to verify:\n"
|
||||
f"{tests_list}\n\n"
|
||||
"When done, briefly describe what you changed and confirm tests pass."
|
||||
)
|
||||
|
||||
prompt_mode = (self.config.prompt_mode or "problem_statement").strip().lower()
|
||||
if prompt_mode not in {"problem_statement", "problem_statement+text"}:
|
||||
raise ValueError(
|
||||
f"Invalid prompt_mode={self.config.prompt_mode!r}. "
|
||||
"Expected 'problem_statement' or 'problem_statement+text'."
|
||||
)
|
||||
|
||||
context_block = ""
|
||||
if prompt_mode == "problem_statement+text" and context:
|
||||
# Note: We intentionally do NOT truncate/cap here. This mode is for debugging / richer prompts and can be slow.
|
||||
context_block = f"\nAdditional context:\n{context}\n"
|
||||
|
||||
return (
|
||||
"You are a senior software engineer. Fix the repository so the specified tests pass.\n\n"
|
||||
f"Repository: {repo} (checked out at base_commit={base_commit})\n"
|
||||
f"Workspace path: ./{repo_dir}\n\n"
|
||||
"Constraints:\n"
|
||||
"- You MUST use the terminal tool to inspect, edit, and verify the repository. Do not respond with a patch file.\n"
|
||||
f"- Start by inspecting the repo (e.g. `ls`, `cd ./{repo_dir}`, `git status`).\n"
|
||||
"- Use a workspace-local virtualenv (e.g. inside the repo at ./.venv) to avoid cross-run contamination.\n"
|
||||
"- Use non-interactive commands only.\n"
|
||||
"- Terminal commands run under POSIX /bin/sh and each tool call runs in a fresh shell (no persisted env vars).\n"
|
||||
" Avoid bash-only `source`; prefer `. .venv/bin/activate` or `.venv/bin/python ...`.\n\n"
|
||||
"Problem statement:\n"
|
||||
f"{problem}\n\n"
|
||||
f"{context_block}\n"
|
||||
f"{tests_block}"
|
||||
)
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
# SWE tasks are longer than the simple test env.
|
||||
return AgentConfig(
|
||||
max_steps=self.config.agent_max_steps,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.agent_max_tokens,
|
||||
tool_delay_s=self.config.agent_tool_delay_s,
|
||||
)
|
||||
|
||||
async def setup_trajectory_workspace(self, item: Item, *, trajectory_id: str, exec_tool) -> Dict[str, Any]:
|
||||
t0 = time.perf_counter()
|
||||
repo = item.get("repo")
|
||||
base_commit = item.get("base_commit")
|
||||
instance_id = item.get("instance_id") or item.get("id") or item.get("problem_id")
|
||||
if not isinstance(repo, str) or not isinstance(base_commit, str):
|
||||
raise RuntimeError("Invalid dataset row: missing repo/base_commit")
|
||||
|
||||
repo_dir = self._repo_name(item)
|
||||
clone_url = f"{self.config.repo_base_url.rstrip('/')}/{repo}.git"
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): "
|
||||
f"repo={repo} base_commit={base_commit} instance_id={instance_id} dir=./{repo_dir}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Repo setup strategy:
|
||||
# - Maintain a shared, per-container bare repo cache under /data/repo_cache
|
||||
# - For each trajectory, create an isolated git worktree under the slot workspace
|
||||
# This avoids cloning/fetching full repos per trajectory and is crucial for multiplexing.
|
||||
|
||||
def _repo_cache_slug(repo_name: str) -> str:
|
||||
return repo_name.replace("/", "__")
|
||||
|
||||
repo_slug = _repo_cache_slug(repo)
|
||||
cache_root = "/data/repo_cache"
|
||||
bare_repo = f"{cache_root}/{repo_slug}.git"
|
||||
lock_file = f"{cache_root}/.locks/{repo_slug}.lock"
|
||||
|
||||
# Use flock to serialize operations that mutate the shared bare repo (fetch/worktree).
|
||||
# util-linux (flock) is included in the sandbox image.
|
||||
worktree_cmd = (
|
||||
"set -e; "
|
||||
f"rm -rf {repo_dir}; "
|
||||
f"mkdir -p {cache_root}/.locks; "
|
||||
f": > {lock_file}; "
|
||||
f"flock -x {lock_file} sh -lc '"
|
||||
f"set -e; "
|
||||
"export GIT_TERMINAL_PROMPT=0; "
|
||||
"export GIT_LFS_SKIP_SMUDGE=1; "
|
||||
f"if [ ! -d \"{bare_repo}\" ]; then "
|
||||
f" git init --bare \"{bare_repo}\"; "
|
||||
f" git -C \"{bare_repo}\" remote add origin \"{clone_url}\"; "
|
||||
"fi; "
|
||||
f"git -C \"{bare_repo}\" remote set-url origin \"{clone_url}\"; "
|
||||
f"git -C \"{bare_repo}\" worktree prune || true; "
|
||||
f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
|
||||
f" git -C \"{bare_repo}\" fetch --depth 1 origin \"{base_commit}\" || true; "
|
||||
"fi; "
|
||||
f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
|
||||
f" git -C \"{bare_repo}\" fetch --prune origin; "
|
||||
"fi; "
|
||||
f"git --git-dir=\"{bare_repo}\" worktree add --detach \"{repo_dir}\" \"{base_commit}\"; "
|
||||
"'"
|
||||
)
|
||||
|
||||
print(f"[SweSmithOracleEnv] tid={trajectory_id} preparing worktree from repo cache", flush=True)
|
||||
res = await exec_tool(
|
||||
ToolCall(
|
||||
name="terminal",
|
||||
arguments={"command": worktree_cmd, "timeout": self.config.install_timeout_s},
|
||||
)
|
||||
)
|
||||
if not res.success:
|
||||
raise RuntimeError(
|
||||
"git worktree setup failed "
|
||||
f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): {res.error}\n{res.output}"
|
||||
)
|
||||
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): worktree ready in {time.perf_counter() - t0:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
return {"repo_dir": repo_dir, "base_commit": base_commit}
|
||||
|
||||
def _tests_for_item(self, item: Item) -> List[str]:
|
||||
tests: List[str] = []
|
||||
if self.config.score_include_fail_to_pass:
|
||||
for key in ("PASS_TO_PASS", "FAIL_TO_PASS"):
|
||||
nodeids = item.get(key)
|
||||
if isinstance(nodeids, list):
|
||||
tests.extend([n for n in nodeids if isinstance(n, str)])
|
||||
else:
|
||||
nodeids = item.get("PASS_TO_PASS")
|
||||
if isinstance(nodeids, list):
|
||||
tests.extend([n for n in nodeids if isinstance(n, str)])
|
||||
# Stable order for reproducibility.
|
||||
return sorted(dict.fromkeys(tests))
|
||||
|
||||
def _chunk_nodeids(self, nodeids: List[str], max_per_chunk: int = 50) -> List[List[str]]:
|
||||
chunks: List[List[str]] = []
|
||||
for i in range(0, len(nodeids), max_per_chunk):
|
||||
chunks.append(nodeids[i : i + max_per_chunk])
|
||||
return chunks
|
||||
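Review note on the chunked pytest invocation: the verification step joins node ids with spaces and passes them to `/bin/sh`. Parametrized ids can contain brackets or spaces, which the shell may glob or split. A minimal sketch of chunking combined with `shlex.quote` follows; `chunk_nodeids` and `pytest_commands` are hypothetical free functions, and the command layout simply mirrors the one used in `verify_and_score_trajectory`.

```python
import shlex
from typing import List


def chunk_nodeids(nodeids: List[str], max_per_chunk: int = 50) -> List[List[str]]:
    """Split node ids into consecutive chunks of at most `max_per_chunk`."""
    return [nodeids[i : i + max_per_chunk] for i in range(0, len(nodeids), max_per_chunk)]


def pytest_commands(repo_dir: str, nodeids: List[str], max_per_chunk: int = 50) -> List[str]:
    """Build one shell command per chunk, quoting every node id."""
    commands: List[str] = []
    for chunk in chunk_nodeids(nodeids, max_per_chunk):
        joined = " ".join(shlex.quote(n) for n in chunk)
        commands.append(
            f"cd {shlex.quote(repo_dir)} && . .venv/bin/activate && python -m pytest -q {joined}"
        )
    return commands


cmds = pytest_commands(
    "repo",
    ["tests/test_a.py::test_x[a b]", "tests/test_b.py::test_y"],
    max_per_chunk=1,
)
```

Ids made only of shell-safe characters pass through unchanged, so the quoting is invisible for the common case.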
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
final_response: str, # noqa: ARG002
|
||||
*,
|
||||
trajectory_id: str,
|
||||
exec_tool,
|
||||
agent_result=None,
|
||||
workspace_meta: Optional[Dict[str, Any]] = None,
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
_ = trajectory_id
|
||||
repo_dir = self._repo_name(item)
|
||||
|
||||
# Training correctness: do not reward trajectories that never actually used tools.
|
||||
if agent_result is not None and getattr(agent_result, "total_tool_calls", 0) <= 0:
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} verify (dataset_tests): no tool calls; score=0.0",
|
||||
flush=True,
|
||||
)
|
||||
return 0.0, {
|
||||
"verification_mode": "dataset_tests",
|
||||
"error": "No tool calls were made by the agent",
|
||||
}
|
||||
|
||||
nodeids = self._tests_for_item(item)
|
||||
if not nodeids:
|
||||
return 0.0, {"error": "No tests provided"}
|
||||
|
||||
print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (dataset_tests): ensuring venv + deps", flush=True)
|
||||
setup_cmd = (
|
||||
f"cd {repo_dir} && "
|
||||
"python -m venv .venv && "
|
||||
". .venv/bin/activate && "
|
||||
"python -m pip install -U pip setuptools wheel && "
|
||||
"python -m pip install -e . && "
|
||||
"python -m pip install pytest"
|
||||
)
|
||||
setup_res = await exec_tool(
|
||||
ToolCall(name="terminal", arguments={"command": setup_cmd, "timeout": self.config.install_timeout_s})
|
||||
)
|
||||
verification_messages = [{"role": "user", "content": setup_res.to_xml()}]
|
||||
if not setup_res.success:
|
||||
return 0.0, {
|
||||
"verification_mode": "dataset_tests",
|
||||
"phase": "install",
|
||||
"error": setup_res.error,
|
||||
"output": setup_res.output,
|
||||
"verification_messages": verification_messages,
|
||||
}
|
||||
|
||||
chunks = self._chunk_nodeids(nodeids, max_per_chunk=50)
|
||||
for chunk_idx, chunk in enumerate(chunks):
|
||||
joined = " ".join(chunk)
|
||||
cmd = f"cd {repo_dir} && . .venv/bin/activate && python -m pytest -q {joined}"
|
||||
res = await exec_tool(
|
||||
ToolCall(
|
||||
name="terminal",
|
||||
arguments={"command": cmd, "timeout": self.config.test_timeout_s},
|
||||
)
|
||||
)
|
||||
verification_messages.append({"role": "user", "content": res.to_xml()})
|
||||
if not res.success:
|
||||
return 0.0, {
|
||||
"verification_mode": "dataset_tests",
|
||||
"phase": "pytest",
|
||||
"failed_chunk": chunk_idx,
|
||||
"error": res.error,
|
||||
"output": res.output,
|
||||
"verification_messages": verification_messages,
|
||||
}
|
||||
|
||||
return 1.0, {"verification_mode": "dataset_tests", "passed": True, "verification_messages": verification_messages}
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
# Not used; scoring happens in verify_and_score_trajectory.
|
||||
_ = (item, final_response)
|
||||
return 0.0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
SweSmithOracleEnv.cli()
|
||||
217
atropos/envs/test_env.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Simple test environment for validating the atropos-agent setup.
|
||||
|
||||
This environment uses a local OpenAI-compatible server for LLM testing to verify:
|
||||
- BaseEnv extension works correctly
|
||||
- API communication via OpenAI-compatible endpoint
|
||||
- Basic trajectory collection
|
||||
|
||||
This is a minimal environment for testing, not production use.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
Item,
|
||||
)
|
||||
|
||||
from ..agent import AgentConfig
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# Simple test prompts for validation
|
||||
TEST_PROMPTS = [
|
||||
{
|
||||
"prompt": "What is 2 + 2? Answer with just the number.",
|
||||
"expected": "4",
|
||||
},
|
||||
{
|
||||
"prompt": "What is the capital of France? Answer with just the city name.",
|
||||
"expected": "Paris",
|
||||
},
|
||||
{
|
||||
"prompt": "What color is the sky on a clear day? Answer with just the color.",
|
||||
"expected": "Blue",
|
||||
},
|
||||
{
|
||||
"prompt": "How many days are in a week? Answer with just the number.",
|
||||
"expected": "7",
|
||||
},
|
||||
{
|
||||
"prompt": "What is 10 * 5? Answer with just the number.",
|
||||
"expected": "50",
|
||||
},
|
||||
]
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"You are a helpful assistant. Answer questions concisely and directly. "
|
||||
"When asked for a simple answer, provide just that answer without explanation."
|
||||
)
|
||||
|
||||
|
||||
class SimpleTestEnvConfig(AgentEnvConfig):
|
||||
"""Configuration for the simple test environment."""
|
||||
|
||||
server_base_url: str = Field(
|
||||
default="http://127.0.0.1:8080",
|
||||
description="Base URL for an OpenAI-compatible server (without /v1)",
|
||||
)
|
||||
server_model: str = Field(
|
||||
default="hermes-4-36b",
|
||||
description="Model name",
|
||||
)
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||
|
||||
|
||||
class SimpleTestEnv(AgentEnv[SimpleTestEnvConfig]):
|
||||
"""
|
||||
A simple test environment to validate the atropos-agent setup.
|
||||
|
||||
Uses a local OpenAI-compatible LLM endpoint with basic question-answering tasks.
|
||||
Scoring is based on whether the response contains the expected answer.
|
||||
"""
|
||||
|
||||
name = "simple_test_env"
|
||||
env_config_cls = SimpleTestEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SimpleTestEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.iter = 0
|
||||
self.test_prompts = TEST_PROMPTS
|
||||
self.percent_correct_buffer: List[float] = []
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[SimpleTestEnvConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Initialize configuration with local server settings from environment variables.
|
||||
"""
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||
|
||||
env_config = SimpleTestEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=4,
|
||||
use_wandb=False, # Disable wandb for simple testing
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=10,
|
||||
batch_size=16,
|
||||
steps_per_eval=5,
|
||||
max_token_length=2048,
|
||||
inference_weight=1.0,
|
||||
wandb_name="simple_test",
|
||||
server_base_url=base_url,
|
||||
server_model=model,
|
||||
)
|
||||
|
||||
# OpenAI-compatible servers typically expose chat completions at /v1.
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=f"{base_url.rstrip('/')}/v1",  # Strip a trailing slash to avoid "//v1"
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=4,
|
||||
num_requests_for_eval=8,
|
||||
timeout=120, # Local models may be slower
|
||||
),
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup_agent_env(self):
|
||||
"""Setup the environment - load test data."""
|
||||
print(f"SimpleTestEnv setup complete. {len(self.test_prompts)} test prompts loaded.")
|
||||
print(f"Using server at: {self.config.server_base_url}")
|
||||
print(f"Model: {self.config.server_model}")
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
"""Get the next test prompt."""
|
||||
item = self.test_prompts[self.iter % len(self.test_prompts)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def build_task(self, item: Item) -> str:
|
||||
return item["prompt"]
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
return AgentConfig(
|
||||
max_steps=5,
|
||||
temperature=0.7,
|
||||
max_tokens=256,
|
||||
system_prompt=SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
expected = item["expected"].lower()
|
||||
response_lower = (final_response or "").lower()
|
||||
score = 1.0 if expected in response_lower else 0.0
|
||||
self.percent_correct_buffer.append(score)
|
||||
return score
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""
|
||||
Simple evaluation - run through all test prompts once.
|
||||
"""
|
||||
correct = 0
|
||||
total = len(self.test_prompts)
|
||||
|
||||
for item in self.test_prompts:
|
||||
messages = [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": item["prompt"]},
|
||||
]
|
||||
|
||||
response = await self.server.chat_completion(
|
||||
messages=messages,
|
||||
n=1,
|
||||
max_tokens=256,
|
||||
temperature=0.0, # Greedy for eval
|
||||
split="eval",
|
||||
)
|
||||
|
||||
response_text = response.choices[0].message.content or ""
|
||||
expected = item["expected"].lower()
|
||||
|
||||
if expected in response_text.lower():
|
||||
correct += 1
|
||||
|
||||
accuracy = correct / total
|
||||
print(f"Evaluation: {correct}/{total} = {accuracy:.2%} accuracy")
|
||||
return {"eval_accuracy": accuracy}
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log metrics (simplified for testing)."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self.percent_correct_buffer:
|
||||
avg_correct = sum(self.percent_correct_buffer) / len(self.percent_correct_buffer)
|
||||
wandb_metrics["train/percent_correct"] = avg_correct
|
||||
print(f"Train accuracy: {avg_correct:.2%}")
|
||||
self.percent_correct_buffer = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Allow running as CLI
|
||||
SimpleTestEnv.cli()
|
||||
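The scoring rule used by `score_trajectory` and `evaluate` in the file above is a case-insensitive substring check. The following is a minimal standalone sketch of that rule; the helper name `contains_expected` is hypothetical and does not appear in the diff. It also illustrates a caveat of substring matching: a short expected answer such as "4" also matches inside "14".

```python
from typing import Optional


def contains_expected(expected: str, response: Optional[str]) -> float:
    """Return 1.0 if `expected` occurs in `response`, ignoring case; else 0.0.

    Hypothetical helper mirroring the check in SimpleTestEnv; a None
    response is treated as an empty string.
    """
    return 1.0 if expected.lower() in (response or "").lower() else 0.0


print(contains_expected("Paris", "The capital is paris."))  # 1.0
print(contains_expected("4", None))                         # 0.0
print(contains_expected("4", "14"))                         # 1.0 (substring false positive)
```

For a smoke test this looseness is probably acceptable; a stricter variant might compare the stripped response for equality instead.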
165
atropos/envs/toolserver_smoke_env.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
ToolServer routing smoke environment.
|
||||
|
||||
Validates that:
|
||||
- sandbox tools run through Nomad SlotPool (terminal -> bash in sandbox)
|
||||
- external tools run through ToolServer (skills_list)
|
||||
|
||||
This env uses ToolServer in-process by default (`tool_server_url="inprocess"`),
|
||||
so it is self-contained for local testing.
|
||||
|
||||
Run:
|
||||
uv run python -m atropos.envs.toolserver_smoke_env process --env.use_wandb false --env.total_steps 1 --env.group_size 1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, Item
|
||||
|
||||
from ..agent import AgentConfig, AgentResult
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class ToolServerSmokeEnvConfig(AgentEnvConfig):
|
||||
server_base_url: str = Field(
|
||||
default="http://127.0.0.1:8080",
|
||||
description="Base URL for an OpenAI-compatible chat server (without /v1).",
|
||||
)
|
||||
server_model: str = Field(default="hermes-4-36b", description="Model name")
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||
|
||||
|
||||
class ToolServerSmokeEnv(AgentEnv[ToolServerSmokeEnvConfig]):
|
||||
name = "toolserver_smoke_env"
|
||||
env_config_cls = ToolServerSmokeEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ToolServerSmokeEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._iter = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[ToolServerSmokeEnvConfig, List[APIServerConfig]]:
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||
|
||||
env_config = ToolServerSmokeEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
include_messages=True,
|
||||
ensure_scores_are_not_same=False,
|
||||
total_steps=1,
|
||||
batch_size=1,
|
||||
server_base_url=base_url,
|
||||
server_model=model,
|
||||
enabled_toolsets=["terminal", "skills"],
|
||||
disabled_toolsets=[],
|
||||
# Self-contained ToolServer for local smoke.
|
||||
tool_server_url="inprocess",
|
||||
sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
|
||||
purge_job_on_start=True,
|
||||
purge_job_on_shutdown=True,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=f"{base_url.rstrip('/')}/v1",
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=1,
|
||||
num_requests_for_eval=1,
|
||||
timeout=120,
|
||||
)
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup_agent_env(self) -> None:
|
||||
return None
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
self._iter += 1
|
||||
return {
|
||||
"prompt": (
|
||||
"You MUST call exactly one tool per assistant message.\n"
|
||||
"\n"
|
||||
"Step 1) Call the skills_list tool (no arguments), then stop.\n"
|
||||
"Step 2) After you receive the tool response, call the terminal tool to run:\n"
|
||||
"python -c \"print('ok')\"\n"
|
||||
"Step 3) After you receive the terminal tool response, answer with just: ok\n"
|
||||
"\n"
|
||||
"Tool call format requirements:\n"
|
||||
"- Every tool call MUST be a complete XML block with a closing tag.\n"
|
||||
"- Do NOT emit a second <tool_call> in the same assistant message.\n"
|
||||
"\n"
|
||||
"Example:\n"
|
||||
"<tool_call>{\"name\": \"skills_list\", \"arguments\": {}}</tool_call>\n"
|
||||
"Do not include anything else in your final answer."
|
||||
)
|
||||
}
|
||||
|
||||
def build_task(self, item: Item) -> str:
|
||||
return str(item.get("prompt") or "")
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
return AgentConfig(
|
||||
max_steps=min(10, int(self.config.agent_max_steps)),
|
||||
temperature=0.2,
|
||||
max_tokens=None,
|
||||
)
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
_ = (item, final_response)
|
||||
return 0.0
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
final_response: str,
|
||||
*,
|
||||
trajectory_id: str, # noqa: ARG002
|
||||
exec_tool, # noqa: ARG002
|
||||
agent_result: AgentResult | None = None,
|
||||
workspace_meta: Dict[str, Any] | None = None, # noqa: ARG002
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
if agent_result is None:
|
||||
return 0.0, {"error": "Missing agent_result"}
|
||||
|
||||
called = {c.name for s in agent_result.steps for c in s.tool_calls}
|
||||
need = {"skills_list", "terminal"}
|
||||
if not need.issubset(called):
|
||||
return 0.0, {"error": f"Missing tool calls: {sorted(need - called)}", "called": sorted(called)}
|
||||
|
||||
terminal_ok = False
|
||||
for step in agent_result.steps:
|
||||
for call, res in zip(step.tool_calls, step.tool_results):
|
||||
if call.name != "terminal":
|
||||
continue
|
||||
if res.success and ((res.output or "").strip().splitlines() or [""])[-1].strip() == "ok":  # Empty output must not raise IndexError
|
||||
terminal_ok = True
|
||||
|
||||
score = 1.0 if terminal_ok and (final_response or "").strip() == "ok" else 0.0
|
||||
return score, {"called": sorted(called), "final": (final_response or "").strip()}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ToolServerSmokeEnv.cli()
|
||||
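The verification in `verify_and_score_trajectory` above can be exercised without a running agent. The sketch below restates its logic against stand-in dataclasses; `FakeCall`, `FakeResult`, `FakeStep`, and `smoke_score` are hypothetical names introduced only for illustration, and the real `AgentResult` types may carry more fields. The sketch treats empty terminal output as a failed check rather than indexing into an empty list.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeCall:
    name: str


@dataclass
class FakeResult:
    success: bool
    output: Optional[str] = ""


@dataclass
class FakeStep:
    tool_calls: List[FakeCall] = field(default_factory=list)
    tool_results: List[FakeResult] = field(default_factory=list)


def smoke_score(steps: List[FakeStep], final_response: Optional[str]) -> float:
    """Score 1.0 only if both tools were called, terminal printed 'ok', and the final answer is 'ok'."""
    called = {c.name for s in steps for c in s.tool_calls}
    if not {"skills_list", "terminal"}.issubset(called):
        return 0.0

    terminal_ok = False
    for step in steps:
        for call, res in zip(step.tool_calls, step.tool_results):
            if call.name != "terminal" or not res.success:
                continue
            lines = (res.output or "").strip().splitlines()
            # Only the last non-empty line of the terminal output is checked.
            if lines and lines[-1].strip() == "ok":
                terminal_ok = True

    return 1.0 if terminal_ok and (final_response or "").strip() == "ok" else 0.0


good = [
    FakeStep([FakeCall("skills_list")], [FakeResult(True, "[]")]),
    FakeStep([FakeCall("terminal")], [FakeResult(True, "ok\n")]),
]
print(smoke_score(good, "ok"))  # 1.0
```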
11
atropos/nomad/__init__.py
Normal file
@@ -0,0 +1,11 @@
"""
Nomad integration for atropos-agent.

Provides:
- NomadClient: Client for Nomad HTTP API
- Job templates for sandbox containers
"""

from .client import NomadClient

__all__ = ["NomadClient"]
500
atropos/nomad/client.py
Normal file
@@ -0,0 +1,500 @@
|
||||
"""
|
||||
Nomad API Client for atropos-agent.
|
||||
|
||||
Provides a simple async client for interacting with the Nomad HTTP API:
|
||||
- Submit/stop jobs
|
||||
- Query allocations
|
||||
- Get allocation addresses
|
||||
- Scale jobs up/down
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
||||
class AllocationStatus(Enum):
|
||||
"""Nomad allocation status."""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETE = "complete"
|
||||
FAILED = "failed"
|
||||
LOST = "lost"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Allocation:
|
||||
"""Information about a Nomad allocation."""
|
||||
id: str
|
||||
job_id: str
|
||||
task_group: str
|
||||
node_id: str
|
||||
status: AllocationStatus
|
||||
# Network info for reaching the allocation
|
||||
address: Optional[str] = None
|
||||
port: Optional[int] = None
|
||||
|
||||
@property
|
||||
def http_address(self) -> Optional[str]:
|
||||
"""Get full HTTP address for the allocation."""
|
||||
if self.address and self.port:
|
||||
return f"http://{self.address}:{self.port}"
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class JobStatus:
|
||||
"""Status of a Nomad job."""
|
||||
id: str
|
||||
name: str
|
||||
status: str
|
||||
allocations: List[Allocation] = field(default_factory=list)
|
||||
count: int = 0 # Total desired count, summed across task groups
|
||||
|
||||
|
||||
class NomadClient:
|
||||
"""
|
||||
Async client for Nomad HTTP API.
|
||||
|
||||
Usage:
|
||||
client = NomadClient(address="http://localhost:4646")
|
||||
|
||||
# Submit a job
|
||||
await client.submit_job(job_spec)
|
||||
|
||||
# Get allocations
|
||||
allocs = await client.get_job_allocations("sandbox-python")
|
||||
|
||||
# Scale job
|
||||
await client.scale_job("sandbox-python", count=5)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
address: str = "http://localhost:4646",
|
||||
token: Optional[str] = None,
|
||||
timeout: float = 30.0,
|
||||
):
|
||||
self.address = address.rstrip("/")
|
||||
self.token = token or os.environ.get("NOMAD_TOKEN")
|
||||
self.timeout = aiohttp.ClientTimeout(total=timeout)
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create HTTP session."""
|
||||
if self._session is None or self._session.closed:
|
||||
headers = {}
|
||||
if self.token:
|
||||
headers["X-Nomad-Token"] = self.token
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=self.timeout,
|
||||
headers=headers,
|
||||
)
|
||||
return self._session
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP session."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Make an HTTP request to Nomad API."""
|
||||
session = await self._get_session()
|
||||
url = f"{self.address}{path}"
|
||||
|
||||
try:
|
||||
async with session.request(method, url, json=data) as response:
|
||||
if response.status == 404:
|
||||
return {"error": "not_found", "status": 404}
|
||||
|
||||
text = await response.text()
|
||||
if not text:
|
||||
return {"status": response.status}
|
||||
|
||||
try:
|
||||
result = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return {"text": text, "status": response.status}
|
||||
|
||||
if response.status >= 400:
|
||||
return {"error": result, "status": response.status}
|
||||
|
||||
return result if isinstance(result, dict) else {"data": result, "status": response.status}
|
||||
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as e:  # Timeouts are not ClientError subclasses
|
||||
return {"error": str(e), "status": 0}
|
||||
|
||||
# Job Operations
|
||||
|
||||
async def submit_job(self, job_spec: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Submit a job to Nomad.
|
||||
|
||||
Args:
|
||||
job_spec: Job specification dict (HCL converted to JSON)
|
||||
|
||||
Returns:
|
||||
Response with EvalID if successful
|
||||
"""
|
||||
return await self._request("POST", "/v1/jobs", {"Job": job_spec})
|
||||
|
||||
async def stop_job(self, job_id: str, purge: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Stop (and optionally purge) a job.
|
||||
|
||||
Args:
|
||||
job_id: Job identifier
|
||||
purge: If True, completely remove the job
|
||||
"""
|
||||
path = f"/v1/job/{job_id}"
|
||||
if purge:
|
||||
path += "?purge=true"
|
||||
return await self._request("DELETE", path)
|
||||
|
||||
async def get_job(self, job_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get job details."""
|
||||
result = await self._request("GET", f"/v1/job/{job_id}")
|
||||
if "error" in result and result.get("status") == 404:
|
||||
return None
|
||||
return result
|
||||
|
||||
async def get_job_status(self, job_id: str) -> Optional[JobStatus]:
|
||||
"""Get job status with allocations."""
|
||||
job = await self.get_job(job_id)
|
||||
if not job:
|
||||
return None
|
||||
|
||||
allocs = await self.get_job_allocations(job_id)
|
||||
|
||||
# Get count from task groups
|
||||
count = 0
|
||||
task_groups = job.get("TaskGroups", [])
|
||||
for tg in task_groups:
|
||||
count += tg.get("Count", 1)
|
||||
|
||||
return JobStatus(
|
||||
id=job_id,
|
||||
name=job.get("Name", job_id),
|
||||
status=job.get("Status", "unknown"),
|
||||
allocations=allocs,
|
||||
count=count,
|
||||
)
|
||||
|
||||
# Allocation Operations
|
||||
|
||||
async def get_job_allocations(self, job_id: str) -> List[Allocation]:
|
||||
"""Get all allocations for a job."""
|
||||
result = await self._request("GET", f"/v1/job/{job_id}/allocations")
|
||||
|
||||
if "error" in result:
|
||||
return []
|
||||
|
||||
allocs_data = result.get("data", result) if isinstance(result, dict) else result
|
||||
if not isinstance(allocs_data, list):
|
||||
return []
|
||||
|
||||
allocations = []
|
||||
for alloc_data in allocs_data:
|
||||
# Parse allocation info
|
||||
alloc_id = alloc_data.get("ID", "")
|
||||
status_str = alloc_data.get("ClientStatus", "unknown")
|
||||
|
||||
try:
|
||||
status = AllocationStatus(status_str)
|
||||
except ValueError:
|
||||
status = AllocationStatus.PENDING
|
||||
|
||||
# Get network info - need to fetch detailed allocation for this
|
||||
address = None
|
||||
port = None
|
||||
|
||||
# First try the summary data
|
||||
resources = alloc_data.get("AllocatedResources") or {}
|
||||
shared = resources.get("Shared") or {}
|
||||
networks = shared.get("Networks") or []
|
||||
|
||||
# If no networks in summary, fetch detailed allocation
|
||||
if not networks and alloc_id:
|
||||
detailed = await self.get_allocation(alloc_id)
|
||||
if detailed:
|
||||
resources = detailed.get("AllocatedResources") or {}
|
||||
shared = resources.get("Shared") or {}
|
||||
networks = shared.get("Networks") or []
|
||||
|
||||
if networks:
|
||||
network = networks[0]
|
||||
address = network.get("IP")
|
||||
# Look for dynamic ports OR reserved ports (Singularity/raw_exec uses reserved)
|
||||
dyn_ports = network.get("DynamicPorts") or []
|
||||
reserved_ports = network.get("ReservedPorts") or []
|
||||
for dp in dyn_ports + reserved_ports:
|
||||
if dp.get("Label") == "http":
|
||||
port = dp.get("Value")
|
||||
break
|
||||
|
||||
allocations.append(Allocation(
|
||||
id=alloc_id,
|
||||
job_id=job_id,
|
||||
task_group=alloc_data.get("TaskGroup", ""),
|
||||
node_id=alloc_data.get("NodeID", ""),
|
||||
status=status,
|
||||
address=address,
|
||||
port=port,
|
||||
))
|
||||
|
||||
return allocations
|
||||
|
||||
async def get_allocation(self, alloc_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get detailed allocation info."""
|
||||
result = await self._request("GET", f"/v1/allocation/{alloc_id}")
|
||||
if "error" in result and result.get("status") == 404:
|
||||
return None
|
||||
return result
|
||||
|
||||
# Scaling Operations
|
||||
|
||||
async def scale_job(self, job_id: str, count: int, task_group: str = "sandbox") -> Dict[str, Any]:
|
||||
"""
|
||||
Scale a job's task group to specified count.
|
||||
|
||||
Args:
|
||||
job_id: Job identifier
|
||||
count: Desired number of allocations
|
||||
task_group: Name of task group to scale
|
||||
"""
|
||||
payload = {
|
||||
"Count": count,
|
||||
"Target": {
|
||||
"Group": task_group,
|
||||
},
|
||||
}
|
||||
return await self._request("POST", f"/v1/job/{job_id}/scale", payload)
|
||||
|
||||
async def get_job_scale_status(self, job_id: str) -> Dict[str, int]:
|
||||
"""
|
||||
Get current scale status for a job.
|
||||
|
||||
Returns:
|
||||
Dict mapping task group name to its number of running allocations
|
||||
"""
|
||||
result = await self._request("GET", f"/v1/job/{job_id}/scale")
|
||||
|
||||
if "error" in result:
|
||||
return {}
|
||||
|
||||
task_groups = result.get("TaskGroups", {})
|
||||
return {
|
||||
name: info.get("Running", 0)
|
||||
for name, info in task_groups.items()
|
||||
}
|
||||
|
||||
# Health Check
|
||||
|
||||
async def is_healthy(self) -> bool:
|
||||
"""Check if Nomad is reachable and healthy."""
|
||||
try:
|
||||
result = await self._request("GET", "/v1/status/leader")
|
||||
return "error" not in result
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_leader(self) -> Optional[str]:
|
||||
"""Get current Nomad leader address."""
|
||||
result = await self._request("GET", "/v1/status/leader")
|
||||
if isinstance(result, dict) and "data" in result:
|
||||
return result["data"]
|
||||
return None
|
||||
|
||||
|
||||
def load_job_template(
|
||||
template_name: str = "sandbox",
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Load and configure a job template.
|
||||
|
||||
Args:
|
||||
template_name: Name of template (e.g., "sandbox")
|
||||
**kwargs: Template variables to substitute
|
||||
|
||||
Returns:
|
||||
Job specification dict ready for Nomad API
|
||||
"""
|
||||
# Default job template for sandbox container
|
||||
if template_name == "sandbox":
|
||||
return create_sandbox_job(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown template: {template_name}")
|
||||
|
||||
|
||||
def create_sandbox_job(
|
||||
job_id: str = "atropos-sandbox",
|
||||
image: str = "atropos-sandbox:local", # Use :local tag to avoid registry pull
|
||||
count: int = 1,
|
||||
slots_per_container: int = 10,
|
||||
privileged: bool = False,
|
||||
cpu: int = 500,
|
||||
memory: int = 512,
|
||||
port: int = 8080,
|
||||
datacenter: str = "dc1",
|
||||
driver: str = "docker", # "docker" or "singularity"
|
||||
singularity_image: Optional[str] = None, # Path to .sif file for singularity driver
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a sandbox job specification.
|
||||
|
||||
This job runs the sandbox_server.py inside a container,
|
||||
with the specified number of slots for agent workspaces.
|
||||
|
||||
Args:
|
||||
job_id: Unique job identifier
|
||||
image: Docker image to use (for docker driver)
|
||||
count: Number of container instances
|
||||
slots_per_container: Number of slots per container
|
||||
privileged: Run container in privileged mode (recommended for bubblewrap)
|
||||
cpu: CPU allocation in MHz
|
||||
memory: Memory allocation in MB
|
||||
port: HTTP port for sandbox server
|
||||
datacenter: Nomad datacenter
|
||||
driver: Container driver - "docker" or "singularity"
|
||||
singularity_image: Path to .sif file (required if driver="singularity")
|
||||
|
||||
Returns:
|
||||
Job specification dict
|
||||
"""
|
||||
# Build task config based on driver
|
||||
if driver == "singularity":
|
||||
if not singularity_image:
|
||||
raise ValueError("singularity_image path required when driver='singularity'")
|
||||
|
||||
# Use raw_exec driver to run apptainer via shell for variable expansion
|
||||
# The container binds the allocation directory for workspace persistence
|
||||
# For raw_exec, we use static port since Nomad's dynamic port mapping doesn't
|
||||
# work the same as Docker - the process runs directly on the host.
|
||||
shell_cmd = (
|
||||
f'apptainer run '
|
||||
f'--bind "$NOMAD_ALLOC_DIR/data:/data" '
|
||||
f'--pwd /app '
|
||||
f'--env PYTHONUNBUFFERED=1 '
|
||||
f'{singularity_image} '
|
||||
f'python sandbox_server.py '
|
||||
f'--port {port} '
|
||||
f'--slots {slots_per_container} '
|
||||
f'--data-dir /data'
|
||||
)
|
||||
task_config = {
|
||||
"command": "/bin/sh",
|
||||
"args": ["-c", shell_cmd],
|
||||
}
|
||||
task_driver = "raw_exec"
|
||||
else:
|
||||
# Docker driver (default)
|
||||
task_config = {
|
||||
"image": image,
|
||||
"force_pull": False, # Use local image, don't try to pull
|
||||
"ports": ["http"],
|
||||
"privileged": privileged,
|
||||
"command": "python",
|
||||
"args": [
|
||||
"sandbox_server.py",
|
||||
"--port", str(port),
|
||||
"--slots", str(slots_per_container),
|
||||
"--data-dir", "/data",
|
||||
],
|
||||
# Note: On Linux, you can mount persistent storage:
|
||||
# "volumes": ["${NOMAD_ALLOC_DIR}/data:/data"],
|
||||
# On macOS/Docker Desktop, skip volumes for PoC
|
||||
# (container /data is ephemeral but works for testing)
|
||||
}
|
||||
task_driver = "docker"
|
||||
|
||||
# For Singularity/raw_exec, use static ports since the process runs directly on host.
|
||||
# For Docker, use dynamic ports with port mapping.
|
||||
if driver == "singularity":
|
||||
network_config = {
|
||||
"Mode": "host",
|
||||
"ReservedPorts": [
|
||||
{
|
||||
"Label": "http",
|
||||
"Value": port,
|
||||
}
|
||||
],
|
||||
}
|
||||
else:
|
||||
network_config = {
|
||||
"Mode": "host",
|
||||
"DynamicPorts": [
|
||||
{
|
||||
"Label": "http",
|
||||
"To": port,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
return {
|
||||
"ID": job_id,
|
||||
"Name": job_id,
|
||||
"Type": "service",
|
||||
"Datacenters": [datacenter],
|
||||
"TaskGroups": [
|
||||
{
|
||||
"Name": "sandbox",
|
||||
"Count": count,
|
||||
# Speed up deployments and avoid Consul checks. Without this, Nomad may
|
||||
# keep an "active deployment" around for the default MinHealthyTime,
|
||||
# which blocks immediate scaling under load.
|
||||
"Update": {
|
||||
"HealthCheck": "task_states",
|
||||
"MinHealthyTime": 0,
|
||||
},
|
||||
"Networks": [network_config],
|
||||
"Tasks": [
|
||||
{
|
||||
"Name": "sandbox-server",
|
||||
"Driver": task_driver,
|
||||
"Config": task_config,
|
||||
"Env": {
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"NOMAD_ALLOC_DIR": "${NOMAD_ALLOC_DIR}",
|
||||
},
|
||||
"Resources": {
|
||||
"CPU": cpu,
|
||||
"MemoryMB": memory,
|
||||
},
|
||||
# Note: Services with Checks require Consul, which we skip for the PoC
|
||||
}
|
||||
],
|
||||
"RestartPolicy": {
|
||||
"Attempts": 3,
|
||||
"Interval": 300_000_000_000, # 5 minutes
|
||||
"Delay": 10_000_000_000, # 10 seconds
|
||||
"Mode": "delay",
|
||||
},
|
||||
"ReschedulePolicy": {
|
||||
"Attempts": 5,
|
||||
"Interval": 3600_000_000_000, # 1 hour
|
||||
"Delay": 30_000_000_000, # 30 seconds
|
||||
"DelayFunction": "exponential",
|
||||
"MaxDelay": 300_000_000_000, # 5 minutes
|
||||
"Unlimited": False,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
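`get_job_allocations` in the client above resolves an allocation's address by taking the first network entry and searching both `DynamicPorts` and `ReservedPorts` for the port labeled `http`. The sketch below isolates that lookup as a pure function so it can be checked against sample payloads; the function name `resolve_http_endpoint` and the sample data are assumptions made for illustration, not part of the diff.

```python
from typing import Any, Dict, List, Optional, Tuple


def resolve_http_endpoint(
    networks: List[Dict[str, Any]],
    label: str = "http",
) -> Tuple[Optional[str], Optional[int]]:
    """Return (ip, port) for the first network, or None where a value is missing."""
    if not networks:
        return None, None

    network = networks[0]
    address = network.get("IP")
    # Docker jobs use dynamic ports; raw_exec/Singularity jobs use reserved ports.
    ports = (network.get("DynamicPorts") or []) + (network.get("ReservedPorts") or [])
    for entry in ports:
        if entry.get("Label") == label:
            return address, entry.get("Value")
    return address, None


# Hypothetical payload shaped like the fields the client reads.
sample = [{"IP": "10.0.0.5", "DynamicPorts": [{"Label": "http", "Value": 25432}]}]
print(resolve_http_endpoint(sample))  # ('10.0.0.5', 25432)
```

Keeping the lookup separate from the HTTP calls makes it straightforward to test both the dynamic-port and reserved-port cases without a Nomad cluster.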
1912
atropos/sandbox_server.py
Normal file
File diff suppressed because it is too large
20
atropos/slots/__init__.py
Normal file
@@ -0,0 +1,20 @@
"""
Slot-based multiplexing for atropos-agent.

Provides:
- Slot: Isolated workspace for a single trajectory
- SlotPool: Manages slots across Nomad allocations
- SandboxExecutor: Executes tools in sandbox containers
"""

from .executor import SandboxExecutor
from .pool import SlotPool, SlotPoolConfig
from .slot import Slot, SlotState

__all__ = [
    "Slot",
    "SlotState",
    "SlotPool",
    "SlotPoolConfig",
    "SandboxExecutor",
]
457
atropos/slots/executor.py
Normal file
@@ -0,0 +1,457 @@
|
||||
"""
|
||||
SandboxExecutor - HTTP client for sandbox container communication.
|
||||
|
||||
Sends tool execution requests to sandbox_server.py running inside Nomad containers.
|
||||
Supports single and batch execution for efficiency.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .slot import Slot, SlotState
|
||||
from ..tools.base import ToolCall, ToolResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionRequest:
|
||||
"""Request to execute a tool in a slot."""
|
||||
slot: Slot
|
||||
tool_name: str
|
||||
args: Dict[str, Any]
|
||||
execution_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
timeout: float = 30.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionResult:
|
||||
"""Result from sandbox execution."""
|
||||
success: bool
|
||||
output: str = ""
|
||||
error: str = ""
|
||||
execution_id: str = ""
|
||||
slot_id: str = ""
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_tool_result(self) -> ToolResult:
|
||||
"""Convert to ToolResult for agent consumption."""
|
||||
return ToolResult(
|
||||
success=self.success,
|
||||
output=self.output,
|
||||
error=self.error,
|
||||
metadata=self.metadata,
|
||||
uniq_id=self.execution_id,
|
||||
)
|
||||
|
||||
|
||||
class SandboxExecutor:
|
||||
"""
|
||||
HTTP client for executing tools in sandbox containers.
|
||||
|
||||
Communicates with sandbox_server.py running inside Nomad allocations.
|
||||
Supports both single execution and batched parallel execution.
|
||||
|
||||
Usage:
|
||||
executor = SandboxExecutor()
|
||||
|
||||
# Single execution
|
||||
result = await executor.execute(slot, "bash", {"command": "ls"})
|
||||
|
||||
# Batch execution
|
||||
results = await executor.execute_batch([
|
||||
(slot1, "bash", {"command": "ls"}),
|
||||
(slot2, "write_file", {"path": "test.txt", "content": "hello"}),
|
||||
])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timeout: float = 30.0,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
):
|
||||
self.timeout = aiohttp.ClientTimeout(total=timeout)
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create HTTP session."""
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession(timeout=self.timeout)
|
||||
return self._session
|
||||
|
||||
async def close(self):
|
||||
"""Close HTTP session."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
slot: Slot,
|
||||
tool_name: str,
|
||||
args: Dict[str, Any],
|
||||
timeout: Optional[float] = None,
|
||||
) -> ExecutionResult:
|
||||
"""
|
||||
Execute a tool in a slot's workspace.
|
||||
|
||||
Args:
|
||||
slot: Slot to execute in
|
||||
tool_name: Name of tool (bash, read_file, write_file)
|
||||
args: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
ExecutionResult with output or error
|
||||
"""
|
||||
execution_id = str(uuid.uuid4())
|
||||
exec_timeout = timeout or self.timeout.total or 30.0
|
||||
|
||||
# Mark slot as executing
|
||||
original_state = slot.state
|
||||
try:
|
||||
if slot.state == SlotState.ACQUIRED:
|
||||
slot.start_execution(execution_id)
|
||||
|
||||
result = await self._send_execute_request(
|
||||
container_addr=slot.container_addr,
|
||||
slot_id=slot.slot_id,
|
||||
tool_name=tool_name,
|
||||
args=args,
|
||||
execution_id=execution_id,
|
||||
timeout=exec_timeout,
|
||||
)
|
||||
result.slot_id = slot.slot_id
|
||||
return result
|
||||
|
||||
finally:
|
||||
# Restore slot state
|
||||
if slot.state == SlotState.EXECUTING:
|
||||
slot.end_execution()
|
||||
|
||||
async def _send_execute_request(
|
||||
self,
|
||||
container_addr: str,
|
||||
slot_id: str,
|
||||
tool_name: str,
|
||||
args: Dict[str, Any],
|
||||
execution_id: str,
|
||||
timeout: float,
|
||||
) -> ExecutionResult:
|
||||
"""Send execution request to sandbox server with retry logic."""
|
||||
session = await self._get_session()
|
||||
url = f"{container_addr}/execute"
|
||||
|
||||
payload = {
|
||||
"slot_id": slot_id,
|
||||
"tool": tool_name,
|
||||
"args": args,
|
||||
"execution_id": execution_id,
|
||||
"timeout": timeout,
|
||||
}
|
||||
|
||||
last_error = None
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=timeout)) as response:
|
||||
data = await response.json()
|
||||
|
||||
return ExecutionResult(
|
||||
success=data.get("success", False),
|
||||
output=data.get("output", ""),
|
||||
error=data.get("error", ""),
|
||||
execution_id=data.get("execution_id", execution_id),
|
||||
metadata=data.get("metadata", {}),
|
||||
)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
last_error = str(e)
|
||||
if attempt < self.max_retries - 1:
|
||||
await asyncio.sleep(self.retry_delay * (attempt + 1))
|
||||
continue
|
||||
except asyncio.TimeoutError:
|
||||
last_error = f"Request timed out after {timeout}s"
|
||||
break
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
break
|
||||
|
||||
return ExecutionResult(
|
||||
success=False,
|
||||
error=f"Failed after up to {self.max_retries} attempts: {last_error}",
|
||||
execution_id=execution_id,
|
||||
)
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
requests: List[Tuple[Slot, str, Dict[str, Any]]],
|
||||
timeout: Optional[float] = None,
|
||||
) -> List[ExecutionResult]:
|
||||
"""
|
||||
Execute multiple tools in parallel across slots.
|
||||
|
||||
This is the key optimization - we batch tool calls to maximize
|
||||
container utilization while agents are waiting for LLM responses.
|
||||
|
||||
Args:
|
||||
requests: List of (slot, tool_name, args) tuples
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
List of ExecutionResults in same order as requests
|
||||
"""
|
||||
if not requests:
|
||||
return []
|
||||
|
||||
# Group requests by container address for batch API
|
||||
by_container: Dict[str, List[Tuple[int, Slot, str, Dict[str, Any], str]]] = {}
|
||||
|
||||
for idx, (slot, tool_name, args) in enumerate(requests):
|
||||
execution_id = str(uuid.uuid4())
|
||||
container = slot.container_addr
|
||||
|
||||
if container not in by_container:
|
||||
by_container[container] = []
|
||||
by_container[container].append((idx, slot, tool_name, args, execution_id))
|
||||
|
||||
# Mark slots as executing
|
||||
if slot.state == SlotState.ACQUIRED:
|
||||
slot.start_execution(execution_id)
|
||||
|
||||
# Execute batches in parallel
|
||||
exec_timeout = timeout or self.timeout.total or 30.0
|
||||
batch_tasks = []
|
||||
|
||||
for container_addr, batch_requests in by_container.items():
|
||||
task = self._send_batch_request(
|
||||
container_addr=container_addr,
|
||||
batch_requests=batch_requests,
|
||||
timeout=exec_timeout,
|
||||
)
|
||||
batch_tasks.append(task)
|
||||
|
||||
# Gather all batch results
|
||||
batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
|
||||
# Collect results in original order
|
||||
results: List[Optional[ExecutionResult]] = [None] * len(requests)
|
||||
|
||||
for batch_result in batch_results:
|
||||
if isinstance(batch_result, Exception):
|
||||
# Mark all in this batch as failed
|
||||
continue
|
||||
|
||||
for idx, result in batch_result:
|
||||
results[idx] = result
|
||||
|
||||
# Fill in any missing results
|
||||
for idx, result in enumerate(results):
|
||||
if result is None:
|
||||
slot, tool_name, args = requests[idx]
|
||||
results[idx] = ExecutionResult(
|
||||
success=False,
|
||||
error="Batch execution failed",
|
||||
slot_id=slot.slot_id,
|
||||
)
|
||||
|
||||
# End execution on all slots
|
||||
for slot, _, _ in requests:
|
||||
if slot.state == SlotState.EXECUTING:
|
||||
slot.end_execution()
|
||||
|
||||
return results # type: ignore
|
||||
|
||||
async def _send_batch_request(
|
||||
self,
|
||||
container_addr: str,
|
||||
batch_requests: List[Tuple[int, Slot, str, Dict[str, Any], str]],
|
||||
timeout: float,
|
||||
) -> List[Tuple[int, ExecutionResult]]:
|
||||
"""Send batch execution request to a single container."""
|
||||
session = await self._get_session()
|
||||
url = f"{container_addr}/batch"
|
||||
|
||||
# Build batch payload
|
||||
payload = [
|
||||
{
|
||||
"slot_id": slot.slot_id,
|
||||
"tool": tool_name,
|
||||
"args": args,
|
||||
"execution_id": execution_id,
|
||||
"timeout": timeout,
|
||||
}
|
||||
for _, slot, tool_name, args, execution_id in batch_requests
|
||||
]
|
||||
|
||||
try:
|
||||
async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=timeout)) as response:
|
||||
data = await response.json()
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise ValueError(f"Expected list response, got {type(data)}")
|
||||
|
||||
results = []
|
||||
for i, (idx, slot, _, _, execution_id) in enumerate(batch_requests):
|
||||
if i < len(data):
|
||||
item = data[i]
|
||||
result = ExecutionResult(
|
||||
success=item.get("success", False),
|
||||
output=item.get("output", ""),
|
||||
error=item.get("error", ""),
|
||||
execution_id=item.get("execution_id", execution_id),
|
||||
slot_id=slot.slot_id,
|
||||
metadata=item.get("metadata", {}),
|
||||
)
|
||||
else:
|
||||
result = ExecutionResult(
|
||||
success=False,
|
||||
error="Missing result in batch response",
|
||||
execution_id=execution_id,
|
||||
slot_id=slot.slot_id,
|
||||
)
|
||||
results.append((idx, result))
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
# Return error for all requests in batch
|
||||
return [
|
||||
(idx, ExecutionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
execution_id=execution_id,
|
||||
slot_id=slot.slot_id,
|
||||
))
|
||||
for idx, slot, _, _, execution_id in batch_requests
|
||||
]
|
||||
|
||||
async def reset_slot(self, slot: Slot) -> ExecutionResult:
|
||||
"""
|
||||
Reset a slot's workspace (delete all files).
|
||||
|
||||
Useful when reusing a slot for a new trajectory.
|
||||
"""
|
||||
session = await self._get_session()
|
||||
url = f"{slot.container_addr}/reset"
|
||||
|
||||
try:
|
||||
async with session.post(url, json={"slot_id": slot.slot_id}) as response:
|
||||
data = await response.json()
|
||||
return ExecutionResult(
|
||||
success=data.get("success", False),
|
||||
output=data.get("output", ""),
|
||||
error=data.get("error", ""),
|
||||
slot_id=slot.slot_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ExecutionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
slot_id=slot.slot_id,
|
||||
)
|
||||
|
||||
async def health_check(self, container_addr: str) -> bool:
|
||||
"""Check if a sandbox container is healthy."""
|
||||
session = await self._get_session()
|
||||
url = f"{container_addr}/health"
|
||||
|
||||
try:
|
||||
async with session.get(url) as response:
|
||||
data = await response.json()
|
||||
return data.get("status") == "ok"
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_container_status(
|
||||
self,
|
||||
container_addr: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get status info from a sandbox container."""
|
||||
session = await self._get_session()
|
||||
url = f"{container_addr}/health"
|
||||
|
||||
try:
|
||||
async with session.get(url) as response:
|
||||
return await response.json()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Artifact helpers (optional)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def _post_json(
|
||||
self,
|
||||
url: str,
|
||||
payload: Dict[str, Any],
|
||||
timeout: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
session = await self._get_session()
|
||||
try:
|
||||
async with session.post(url, json=payload, timeout=aiohttp.ClientTimeout(total=timeout) if timeout is not None else self.timeout) as response:
|
||||
data = await response.json()
|
||||
if isinstance(data, dict):
|
||||
data.setdefault("http_status", response.status)
|
||||
return data
|
||||
return {"success": False, "error": f"Unexpected response type: {type(data)}", "http_status": response.status}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def read_artifact(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str,
|
||||
*,
|
||||
encoding: str = "text",
|
||||
max_bytes: Optional[int] = None,
|
||||
include_sha256: bool = False,
|
||||
timeout: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{slot.container_addr}/artifacts/read"
|
||||
payload: Dict[str, Any] = {"slot_id": slot.slot_id, "path": path, "encoding": encoding, "include_sha256": include_sha256}
|
||||
if max_bytes is not None:
|
||||
payload["max_bytes"] = max_bytes
|
||||
return await self._post_json(url, payload, timeout=timeout)
|
||||
|
||||
async def list_artifacts(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str = ".",
|
||||
*,
|
||||
recursive: bool = False,
|
||||
max_entries: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{slot.container_addr}/artifacts/list"
|
||||
payload: Dict[str, Any] = {"slot_id": slot.slot_id, "path": path, "recursive": recursive}
|
||||
if max_entries is not None:
|
||||
payload["max_entries"] = max_entries
|
||||
return await self._post_json(url, payload, timeout=timeout)
|
||||
|
||||
async def archive_artifacts(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str = ".",
|
||||
*,
|
||||
archive_format: str = "tar.gz",
|
||||
max_bytes: Optional[int] = None,
|
||||
max_entries: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{slot.container_addr}/artifacts/archive"
|
||||
payload: Dict[str, Any] = {"slot_id": slot.slot_id, "path": path, "format": archive_format}
|
||||
if max_bytes is not None:
|
||||
payload["max_bytes"] = max_bytes
|
||||
if max_entries is not None:
|
||||
payload["max_entries"] = max_entries
|
||||
return await self._post_json(url, payload, timeout=timeout)
|
||||
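The batching logic in `execute_batch` above comes down to three steps: group requests by container, run one call per group concurrently, then write each result back to its original index. The sketch below isolates that pattern with only the standard library. It is a simplified illustration, not the executor itself: `fan_out`, `fake_send`, and the `"batch failed"` placeholder are hypothetical names, and plain strings stand in for `ExecutionResult`.

```python
import asyncio
from collections import defaultdict
from typing import Any, Awaitable, Callable, Dict, List, Tuple

Request = Tuple[str, str, Dict[str, Any]]  # (container_addr, tool_name, args)
Indexed = Tuple[int, Request]              # request plus its position in the input
SendBatch = Callable[[str, List[Indexed]], Awaitable[List[Tuple[int, str]]]]


async def fan_out(requests: List[Request], send_batch: SendBatch) -> List[str]:
    """Group requests per container, run the groups concurrently, restore order."""
    groups: Dict[str, List[Indexed]] = defaultdict(list)
    for idx, req in enumerate(requests):
        groups[req[0]].append((idx, req))

    # One coroutine per container; exceptions are returned instead of raised.
    batches = await asyncio.gather(
        *(send_batch(addr, items) for addr, items in groups.items()),
        return_exceptions=True,
    )

    # Pre-fill with a failure placeholder so a failed group needs no extra pass.
    results: List[str] = ["batch failed"] * len(requests)
    for batch in batches:
        if isinstance(batch, BaseException):
            continue  # the placeholder stays for every request in this group
        for idx, value in batch:
            results[idx] = value
    return results


async def fake_send(addr: str, items: List[Indexed]) -> List[Tuple[int, str]]:
    """Hypothetical transport: fails for one container, echoes for the rest."""
    if addr == "http://b":
        raise ConnectionError("container unreachable")
    return [(idx, f"{addr}:{req[1]}") for idx, req in items]


def demo() -> List[str]:
    reqs: List[Request] = [
        ("http://a", "bash", {"command": "ls"}),
        ("http://b", "bash", {"command": "pwd"}),
        ("http://a", "write_file", {"path": "test.txt", "content": "hello"}),
    ]
    return asyncio.run(fan_out(reqs, fake_send))
```

Carrying the original index alongside each request is what lets the caller get results back in request order even though containers finish in arbitrary order.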
659
atropos/slots/pool.py
Normal file
@@ -0,0 +1,659 @@
|
||||
"""
|
||||
SlotPool - Manages slots across Nomad allocations.
|
||||
|
||||
The SlotPool is the core abstraction for slot-based multiplexing:
|
||||
- Tracks available/acquired slots across containers
|
||||
- Handles slot acquisition and release
|
||||
- Auto-scales Nomad job count based on demand
|
||||
- Provides batched tool execution
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from ..nomad.client import (
|
||||
Allocation,
|
||||
AllocationStatus,
|
||||
NomadClient,
|
||||
create_sandbox_job,
|
||||
)
|
||||
from .executor import ExecutionResult, SandboxExecutor
|
||||
from .slot import Slot, SlotState, create_slots_for_allocation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlotPoolConfig:
|
||||
"""Configuration for SlotPool."""
|
||||
|
||||
# Nomad settings
|
||||
nomad_address: str = "http://localhost:4646"
|
||||
job_id: str = "atropos-sandbox"
|
||||
datacenter: str = "dc1"
|
||||
|
||||
# Container settings
|
||||
image: str = "atropos-sandbox:local" # Use :local tag to avoid registry pull
|
||||
slots_per_container: int = 10
|
||||
privileged: bool = False
|
||||
cpu: int = 500 # MHz
|
||||
memory: int = 512 # MB
|
||||
|
||||
# Driver selection: "docker" or "singularity"
|
||||
driver: str = "docker"
|
||||
# Path to .sif file for singularity driver (required if driver="singularity")
|
||||
singularity_image: Optional[str] = None
|
||||
|
||||
# Scaling settings
|
||||
min_containers: int = 1
|
||||
max_containers: int = 10
|
||||
|
||||
# Timeouts
|
||||
acquire_timeout: float = 30.0 # Seconds between acquire polls (also triggers scale-up attempts)
|
||||
health_check_interval: float = 30.0 # Seconds between health checks
|
||||
scale_cooldown: float = 60.0 # Seconds between scale operations
|
||||
|
||||
# Job lifecycle
|
||||
purge_job_on_start: bool = False # Purge any pre-existing job before starting (local dev/training friendly)
|
||||
|
||||
# Local Docker image convenience (macOS/Nomad dev mode)
|
||||
auto_build_local_image: bool = True # If image endswith :local and is missing, build it from the bundled Dockerfile.
|
||||
dockerfile_path: Optional[str] = None # Override Dockerfile path (default: Hermes-Agent/atropos/Dockerfile).
|
||||
docker_build_context: Optional[str] = None # Override build context (default: Hermes-Agent/atropos).
|
||||
|
||||
|
||||
class SlotPool:
|
||||
"""
|
||||
Manages a pool of slots across Nomad allocations.
|
||||
|
||||
The SlotPool:
|
||||
- Deploys sandbox containers to Nomad
|
||||
- Tracks slots across all running containers
|
||||
- Handles slot acquisition/release
|
||||
- Auto-scales based on demand
|
||||
- Provides batched execution via SandboxExecutor
|
||||
|
||||
Usage:
|
||||
config = SlotPoolConfig(
|
||||
nomad_address="http://localhost:4646",
|
||||
job_id="my-sandbox",
|
||||
slots_per_container=10,
|
||||
)
|
||||
|
||||
pool = SlotPool(config)
|
||||
await pool.start()
|
||||
|
||||
# Acquire a slot
|
||||
slot = await pool.acquire()
|
||||
|
||||
# Execute tool
|
||||
result = await pool.execute(slot, "bash", {"command": "ls"})
|
||||
|
||||
# Release slot
|
||||
await pool.release(slot)
|
||||
|
||||
# Shutdown
|
||||
await pool.stop()
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[SlotPoolConfig] = None):
|
||||
self.config = config or SlotPoolConfig()
|
||||
|
||||
# Nomad client
|
||||
self.nomad = NomadClient(address=self.config.nomad_address)
|
||||
|
||||
# Sandbox executor for tool execution
|
||||
self.executor = SandboxExecutor()
|
||||
|
||||
# Slot tracking
|
||||
self._slots: Dict[str, Slot] = {} # slot_key -> Slot
|
||||
self._available_queue: asyncio.Queue[str] = asyncio.Queue()
|
||||
self._lock = asyncio.Lock()
|
||||
self._scale_lock = asyncio.Lock()
|
||||
|
||||
# State
|
||||
self._started = False
|
||||
self._health_task: Optional[asyncio.Task] = None
|
||||
self._scale_task: Optional[asyncio.Task] = None
|
||||
self._last_scale_time = 0.0
|
||||
|
||||
def _default_dockerfile_path(self) -> Path:
|
||||
# Hermes-Agent/atropos/Dockerfile lives one directory above this module (the atropos package root) in source checkouts.
|
||||
return Path(__file__).resolve().parents[1] / "Dockerfile"
|
||||
|
||||
def _default_build_context(self) -> Path:
|
||||
return Path(__file__).resolve().parents[1]
|
||||
|
||||
def _docker_image_exists(self, image: str) -> bool:
|
||||
try:
|
||||
proc = subprocess.run(
|
||||
["docker", "image", "inspect", image],
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
check=False,
|
||||
env={**os.environ, "DOCKER_CLI_HINTS": "false"},
|
||||
)
|
||||
return proc.returncode == 0
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
|
||||
def _try_build_local_image(self, image: str) -> None:
|
||||
dockerfile = Path(self.config.dockerfile_path) if self.config.dockerfile_path else self._default_dockerfile_path()
|
||||
context = Path(self.config.docker_build_context) if self.config.docker_build_context else self._default_build_context()
|
||||
|
||||
if not dockerfile.exists():
|
||||
raise RuntimeError(
|
||||
f"Sandbox Dockerfile not found at {dockerfile}. "
|
||||
"Build the sandbox image manually, or set --env.auto_build_local_image false and provide a non-local image."
|
||||
)
|
||||
if not context.exists():
|
||||
raise RuntimeError(f"Docker build context not found at {context}")
|
||||
|
||||
# Prefer buildx+--load to ensure the image ends up in the local daemon (required by Nomad's docker driver).
|
||||
buildx_cmd = [
|
||||
"docker",
|
||||
"buildx",
|
||||
"build",
|
||||
"--load",
|
||||
"-t",
|
||||
image,
|
||||
"-f",
|
||||
str(dockerfile),
|
||||
str(context),
|
||||
]
|
||||
proc = subprocess.run(buildx_cmd, check=False, env={**os.environ, "DOCKER_CLI_HINTS": "false"})
|
||||
if proc.returncode == 0:
|
||||
return
|
||||
|
||||
# Fallback to classic docker build if buildx isn't available.
|
||||
build_cmd = ["docker", "build", "-t", image, "-f", str(dockerfile), str(context)]
|
||||
proc2 = subprocess.run(build_cmd, check=False, env={**os.environ, "DOCKER_CLI_HINTS": "false"})
|
||||
if proc2.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"Failed to build local sandbox image {image}. "
|
||||
f"Tried: {' '.join(buildx_cmd)} and {' '.join(build_cmd)}"
|
||||
)
|
||||
|
||||
def _ensure_local_image(self) -> None:
|
||||
image = (self.config.image or "").strip()
|
||||
if not image.endswith(":local"):
|
||||
return
|
||||
if not self.config.auto_build_local_image:
|
||||
return
|
||||
|
||||
if self._docker_image_exists(image):
|
||||
return
|
||||
|
||||
logger.info(f"Local sandbox image {image} not found; building it now...")
|
||||
self._try_build_local_image(image)
|
||||
|
||||
def _slot_key(self, alloc_id: str, slot_id: str) -> str:
|
||||
"""Generate unique key for a slot."""
|
||||
return f"{alloc_id}:{slot_id}"
|
||||
|
||||
@property
|
||||
def total_slots(self) -> int:
|
||||
"""Total number of slots in pool."""
|
||||
return len(self._slots)
|
||||
|
||||
@property
|
||||
def available_slots(self) -> int:
|
||||
"""Number of available slots."""
|
||||
return sum(1 for s in self._slots.values() if s.is_available)
|
||||
|
||||
@property
|
||||
def acquired_slots(self) -> int:
|
||||
"""Number of acquired slots."""
|
||||
return sum(1 for s in self._slots.values() if s.is_acquired)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Start the slot pool.
|
||||
|
||||
- Checks if Nomad is healthy
|
||||
- Deploys sandbox job if not running
|
||||
- Discovers existing allocations
|
||||
- Starts health check background task
|
||||
"""
|
||||
if self._started:
|
||||
return
|
||||
|
||||
logger.info(f"Starting SlotPool (job_id={self.config.job_id})")
|
||||
|
||||
try:
|
||||
# Make sure local sandbox images exist before Nomad tries to pull them.
|
||||
# This is a common footgun in macOS dev mode with :local tags.
|
||||
self._ensure_local_image()
|
||||
|
||||
# Check Nomad health
|
||||
if not await self.nomad.is_healthy():
|
||||
raise RuntimeError(f"Nomad is not reachable at {self.config.nomad_address}")
|
||||
|
||||
if self.config.purge_job_on_start:
|
||||
logger.info(f"Purging any existing Nomad job: {self.config.job_id}")
|
||||
await self.nomad.stop_job(self.config.job_id, purge=True)
|
||||
|
||||
# Check if job exists (after optional purge)
|
||||
job = await self.nomad.get_job(self.config.job_id)
|
||||
|
||||
if job is None:
|
||||
# Deploy new job
|
||||
logger.info(f"Deploying sandbox job: {self.config.job_id} (driver={self.config.driver})")
|
||||
job_spec = create_sandbox_job(
|
||||
job_id=self.config.job_id,
|
||||
image=self.config.image,
|
||||
count=self.config.min_containers,
|
||||
slots_per_container=self.config.slots_per_container,
|
||||
privileged=self.config.privileged,
|
||||
cpu=self.config.cpu,
|
||||
memory=self.config.memory,
|
||||
datacenter=self.config.datacenter,
|
||||
driver=self.config.driver,
|
||||
singularity_image=self.config.singularity_image,
|
||||
)
|
||||
result = await self.nomad.submit_job(job_spec)
|
||||
if "error" in result:
|
||||
raise RuntimeError(f"Failed to submit job: {result}")
|
||||
|
||||
# Wait for allocations to be running (even if the job already existed).
|
||||
await self._wait_for_healthy_allocations(self.config.min_containers)
|
||||
|
||||
# Discover existing allocations and slots
|
||||
await self._refresh_slots()
|
||||
|
||||
# Start health check task
|
||||
self._health_task = asyncio.create_task(self._health_check_loop())
|
||||
|
||||
self._started = True
|
||||
logger.info(f"SlotPool started: {self.total_slots} slots available")
|
||||
except Exception:
|
||||
# Ensure aiohttp sessions are not leaked if we fail to start.
|
||||
await self.stop(purge_job=False)
|
||||
raise
|
||||
|
||||
async def stop(self, purge_job: bool = False) -> None:
|
||||
"""
|
||||
Stop the slot pool.
|
||||
|
||||
Args:
|
||||
purge_job: If True, also stop and purge the Nomad job
|
||||
"""
|
||||
logger.info("Stopping SlotPool")
|
||||
|
||||
# Cancel health check task
|
||||
if self._health_task:
|
||||
self._health_task.cancel()
|
||||
try:
|
||||
await self._health_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
self._health_task = None
|
||||
|
||||
if self._scale_task:
|
||||
self._scale_task.cancel()
|
||||
try:
|
||||
await self._scale_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
self._scale_task = None
|
||||
|
||||
# Optionally stop the job (do this even if start() never completed).
|
||||
if purge_job:
|
||||
logger.info(f"Stopping Nomad job: {self.config.job_id}")
|
||||
await self.nomad.stop_job(self.config.job_id, purge=True)
|
||||
|
||||
# Close connections
|
||||
await self.executor.close()
|
||||
await self.nomad.close()
|
||||
|
||||
self._started = False
|
||||
self._slots.clear()
|
||||
|
||||
# Clear the queue
|
||||
while not self._available_queue.empty():
|
||||
try:
|
||||
self._available_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
async def acquire(self, trajectory_id: Optional[str] = None) -> Slot:
|
||||
"""
|
||||
Acquire an available slot.
|
||||
|
||||
If no slot is available, waits in intervals of acquire_timeout seconds.
|
||||
After each interval it attempts a scale-up, then keeps waiting (no overall deadline).
|
||||
|
||||
Args:
|
||||
trajectory_id: Optional ID of trajectory acquiring the slot
|
||||
|
||||
Returns:
|
||||
Acquired Slot
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the pool has not been started
|
||||
"""
|
||||
if not self._started:
|
||||
raise RuntimeError("SlotPool not started")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Try to get an available slot
|
||||
slot_key = await asyncio.wait_for(
|
||||
self._available_queue.get(),
|
||||
timeout=self.config.acquire_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# Try to scale up, but keep waiting even if scaling isn't possible.
|
||||
# In practice, slots may become available shortly (e.g. contention),
|
||||
# and scaling may be temporarily blocked by Nomad deployments.
|
||||
await self._try_scale_up()
|
||||
continue
|
||||
|
||||
slot = self._slots.get(slot_key)
|
||||
if slot is None:
|
||||
# Slot was removed; discard stale queue entry and retry.
|
||||
continue
|
||||
|
||||
try:
|
||||
slot.acquire(trajectory_id)
|
||||
except RuntimeError:
|
||||
# Slot isn't actually available (e.g. duplicate queue entry); retry.
|
||||
continue
|
||||
|
||||
logger.debug(f"Acquired slot {slot.slot_id} (alloc={slot.alloc_id[:8]})")
|
||||
return slot
|
||||
|
||||
async def release(self, slot: Slot, reset_workspace: bool = False) -> None:
|
||||
"""
|
||||
Release a slot back to the pool.
|
||||
|
||||
Args:
|
||||
slot: Slot to release
|
||||
reset_workspace: If True, clear the workspace files
|
||||
"""
|
||||
slot_key = self._slot_key(slot.alloc_id, slot.slot_id)
|
||||
|
||||
if slot_key not in self._slots:
|
||||
logger.warning(f"Releasing unknown slot: {slot_key}")
|
||||
return
|
||||
|
||||
# Optionally reset workspace
|
||||
if reset_workspace:
|
||||
await self.executor.reset_slot(slot)
|
||||
|
||||
slot.release()
|
||||
await self._available_queue.put(slot_key)
|
||||
|
||||
logger.debug(f"Released slot {slot.slot_id}")
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
slot: Slot,
|
||||
tool_name: str,
|
||||
args: Dict[str, Any],
|
||||
timeout: Optional[float] = None,
|
||||
) -> ExecutionResult:
|
||||
"""
|
||||
Execute a tool in a slot's workspace.
|
||||
|
||||
Args:
|
||||
slot: Slot to execute in
|
||||
tool_name: Name of tool (bash, read_file, write_file)
|
||||
args: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
ExecutionResult
|
||||
"""
|
||||
return await self.executor.execute(slot, tool_name, args, timeout)
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
requests: List[Tuple[Slot, str, Dict[str, Any]]],
|
||||
timeout: Optional[float] = None,
|
||||
) -> List[ExecutionResult]:
|
||||
"""
|
||||
Execute multiple tools in parallel.
|
||||
|
||||
This is the key optimization - batch execution across multiple slots
|
||||
maximizes container utilization.
|
||||
|
||||
Args:
|
||||
requests: List of (slot, tool_name, args) tuples
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
List of ExecutionResults in same order
|
||||
"""
|
||||
return await self.executor.execute_batch(requests, timeout)
|
||||
|
||||
async def _refresh_slots(self) -> None:
|
||||
"""Refresh slot inventory from Nomad allocations."""
|
||||
async with self._lock:
|
||||
allocs = await self.nomad.get_job_allocations(self.config.job_id)
|
||||
|
||||
# Track which slots we've seen
|
||||
seen_keys = set()
|
||||
|
||||
for alloc in allocs:
|
||||
if alloc.status != AllocationStatus.RUNNING:
|
||||
continue
|
||||
|
||||
if not alloc.http_address:
|
||||
continue
|
||||
|
||||
# Check container health
|
||||
healthy = await self.executor.health_check(alloc.http_address)
|
||||
if not healthy:
|
||||
continue
|
||||
|
||||
# Create slots for this allocation
|
||||
for i in range(self.config.slots_per_container):
|
||||
slot_id = f"slot_{i}"
|
||||
slot_key = self._slot_key(alloc.id, slot_id)
|
||||
seen_keys.add(slot_key)
|
||||
|
||||
if slot_key not in self._slots:
|
||||
# New slot
|
||||
slot = Slot(
|
||||
slot_id=slot_id,
|
||||
alloc_id=alloc.id,
|
||||
container_addr=alloc.http_address,
|
||||
)
|
||||
self._slots[slot_key] = slot
|
||||
await self._available_queue.put(slot_key)
|
||||
logger.debug(f"Added slot: {slot_key}")
|
||||
|
||||
# Remove slots from dead allocations
|
||||
for slot_key in list(self._slots.keys()):
|
||||
if slot_key not in seen_keys:
|
||||
slot = self._slots.pop(slot_key)
|
||||
logger.debug(f"Removed slot: {slot_key}")
|
||||
|
||||
async def _wait_for_healthy_allocations(
|
||||
self,
|
||||
min_count: int,
|
||||
timeout: float = 120.0
|
||||
) -> None:
|
||||
"""Wait for allocations to become healthy."""
|
||||
import time
|
||||
start = time.time()
|
||||
|
||||
def _summarize_alloc_detail(detail: Dict[str, Any]) -> str:
|
||||
task_states = detail.get("TaskStates") or {}
|
||||
parts: List[str] = []
|
||||
if isinstance(task_states, dict):
|
||||
for task_name, st in task_states.items():
|
||||
events = (st or {}).get("Events") or []
|
||||
if isinstance(events, list) and events:
|
||||
# Include a few recent events; the latest can be a generic restart message
|
||||
# while the true root cause is slightly earlier (e.g. image pull failure).
|
||||
recent = events[-3:]
|
||||
msgs: List[str] = []
|
||||
for ev in recent:
|
||||
desc = ev.get("DisplayMessage") or ev.get("Message") or ev.get("Type") or ""
|
||||
if desc:
|
||||
msgs.append(desc)
|
||||
if msgs:
|
||||
parts.append(f"{task_name}: " + " | ".join(msgs))
|
||||
return "; ".join(parts)
|
||||
|
||||
def _alloc_events_lower(detail: Dict[str, Any]) -> str:
|
||||
task_states = detail.get("TaskStates") or {}
|
||||
texts: List[str] = []
|
||||
if isinstance(task_states, dict):
|
||||
for _task_name, st in task_states.items():
|
||||
events = (st or {}).get("Events") or []
|
||||
if isinstance(events, list):
|
||||
for ev in events[-10:]:
|
||||
desc = ev.get("DisplayMessage") or ev.get("Message") or ev.get("Type") or ""
|
||||
if desc:
|
||||
texts.append(desc)
|
||||
return " ".join(texts).lower()
|
||||
|
||||
while time.time() - start < timeout:
|
||||
allocs = await self.nomad.get_job_allocations(self.config.job_id)
|
||||
|
||||
healthy_count = 0
|
||||
for alloc in allocs:
|
||||
if alloc.status == AllocationStatus.RUNNING and alloc.http_address:
|
||||
if await self.executor.health_check(alloc.http_address):
|
||||
healthy_count += 1
|
||||
|
||||
# Fast-fail on obvious driver/image errors to avoid waiting out the full timeout.
|
||||
if alloc.id:
|
||||
detail = await self.nomad.get_allocation(alloc.id)
|
||||
if isinstance(detail, dict):
|
||||
summary = _summarize_alloc_detail(detail)
|
||||
lowered = _alloc_events_lower(detail) or summary.lower()
|
||||
if "failed to pull" in lowered or "pull access denied" in lowered:
|
||||
raise RuntimeError(
|
||||
"Nomad allocation failed to start due to a Docker image pull error. "
|
||||
f"Allocation {alloc.id[:8]}: {summary}\n"
|
||||
"If you're using a local image tag (e.g. `atropos-sandbox:local`) on macOS, "
|
||||
"make sure the image is loaded into Docker, e.g.:\n"
|
||||
" docker buildx build --load -t atropos-sandbox:local -f Hermes-Agent/atropos/Dockerfile Hermes-Agent/atropos"
|
||||
)
|
||||
if "exceeded allowed attempts" in lowered:
|
||||
raise RuntimeError(
|
||||
"Nomad allocation is crash-looping and has entered restart backoff. "
|
||||
f"Allocation {alloc.id[:8]}: {summary}\n"
|
||||
"Inspect logs with:\n"
|
||||
f" nomad alloc logs -stderr -task sandbox-server {alloc.id}\n"
|
||||
"Common causes include: missing local Docker image tag, container entrypoint error, "
|
||||
"or sandbox-server startup failure."
|
||||
)
|
||||
|
||||
if healthy_count >= min_count:
|
||||
return
|
||||
|
||||
await asyncio.sleep(2.0)
|
||||
|
||||
# Timed out: include allocation status detail to help debugging.
|
||||
allocs = await self.nomad.get_job_allocations(self.config.job_id)
|
||||
alloc_lines: List[str] = []
|
||||
for alloc in allocs[:10]:
|
||||
addr = alloc.http_address or "-"
|
||||
line = f"{alloc.id[:8]} status={alloc.status.value} http={addr}"
|
||||
detail = await self.nomad.get_allocation(alloc.id)
|
||||
if isinstance(detail, dict):
|
||||
summary = _summarize_alloc_detail(detail)
|
||||
if summary:
|
||||
line += f" detail={summary}"
|
||||
alloc_lines.append(line)
|
||||
|
||||
hint = (
|
||||
"Timed out waiting for healthy sandbox allocations.\n"
|
||||
f"Job: {self.config.job_id}, desired_healthy: {min_count}\n"
|
||||
"Allocations:\n - " + "\n - ".join(alloc_lines)
|
||||
)
|
||||
raise RuntimeError(hint)
|
||||
|
||||
async def _try_scale_up(self) -> bool:
|
||||
"""Attempt to scale up the job."""
|
||||
import time
|
||||
|
||||
async with self._scale_lock:
|
||||
# Check cooldown
|
||||
if time.time() - self._last_scale_time < self.config.scale_cooldown:
|
||||
return False
|
||||
|
||||
# Check max containers
|
||||
status = await self.nomad.get_job_status(self.config.job_id)
|
||||
if status is None:
|
||||
return False
|
||||
|
||||
current_count = status.count
|
||||
if current_count >= self.config.max_containers:
|
||||
logger.warning(f"Cannot scale up: already at max ({self.config.max_containers})")
|
||||
return False
|
||||
|
||||
# Scale up
|
||||
new_count = min(current_count + 1, self.config.max_containers)
|
||||
logger.info(f"Scaling up from {current_count} to {new_count} containers")
|
||||
|
||||
scale_resp = await self.nomad.scale_job(
|
||||
self.config.job_id,
|
||||
count=new_count,
|
||||
task_group="sandbox",
|
||||
)
|
||||
|
||||
# Nomad may return non-JSON errors (e.g. plain text) with a status field.
|
||||
if isinstance(scale_resp, dict) and scale_resp.get("status", 200) >= 400:
|
||||
logger.warning(f"Scale request rejected: {scale_resp}")
|
||||
self._last_scale_time = time.time()
|
||||
return False
|
||||
|
||||
self._last_scale_time = time.time()
|
||||
|
||||
# Wait for new allocation in the background so contended acquires can still
|
||||
# make progress (e.g. by grabbing slots released by other trajectories).
|
||||
if self._scale_task is None or self._scale_task.done():
|
||||
self._scale_task = asyncio.create_task(self._wait_for_scale(new_count))
|
||||
|
||||
return True
|
||||
|
||||
async def _wait_for_scale(self, desired_count: int) -> None:
|
||||
try:
|
||||
await self._wait_for_healthy_allocations(desired_count, timeout=60.0)
|
||||
await self._refresh_slots()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to scale up: {e}")
|
||||
|
||||
async def _health_check_loop(self) -> None:
|
||||
"""Background task to monitor container health."""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.config.health_check_interval)
|
||||
await self._refresh_slots()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Health check error: {e}")
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get pool statistics."""
|
||||
slots_by_state = {}
|
||||
for slot in self._slots.values():
|
||||
state = slot.state.value
|
||||
slots_by_state[state] = slots_by_state.get(state, 0) + 1
|
||||
|
||||
container_count = len({s.alloc_id for s in self._slots.values()}) if self._slots else 0
|
||||
|
||||
return {
|
||||
"total_slots": self.total_slots,
|
||||
"available_slots": self.available_slots,
|
||||
"acquired_slots": self.acquired_slots,
|
||||
"containers": container_count,
|
||||
"slots_by_state": slots_by_state,
|
||||
"started": self._started,
|
||||
}
|
||||
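The scale-up gating in `_try_scale_up` above can be sketched in isolation. This is a minimal illustration, not the pool's actual code: `ScaleGate` and `next_count` are hypothetical names, and the clock is passed in explicitly so the behaviour is easy to follow. Only the cooldown check, the `max_containers` cap, and the grow-by-one step are taken from the source.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ScaleGate:
    """Hypothetical stand-in for the pool's scale-up bookkeeping."""

    max_containers: int
    scale_cooldown: float
    last_scale_time: float = float("-inf")

    def next_count(self, current_count: int, now: float) -> Optional[int]:
        """Return the new container count, or None if scaling is not allowed."""
        if now - self.last_scale_time < self.scale_cooldown:
            return None  # still inside the cooldown window
        if current_count >= self.max_containers:
            return None  # already at the configured cap
        self.last_scale_time = now
        # Grow by one container at a time, never past the cap.
        return min(current_count + 1, self.max_containers)


gate = ScaleGate(max_containers=3, scale_cooldown=30.0)
first = gate.next_count(1, now=100.0)   # allowed: 1 -> 2
second = gate.next_count(2, now=110.0)  # refused: only 10s since last scale
third = gate.next_count(2, now=130.0)   # allowed again: 2 -> 3
```

One difference worth noting: the method above also resets the cooldown timer when Nomad rejects the scale request, which this sketch does not model.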
159
atropos/slots/slot.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
Slot abstraction for atropos-agent.
|
||||
|
||||
A Slot represents an isolated workspace for a single agent trajectory.
|
||||
Slots are hosted on Nomad allocations and provide workspace isolation
|
||||
via filesystem directories.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
import uuid
|
||||
|
||||
|
||||
class SlotState(Enum):
|
||||
"""State of a slot in the pool."""
|
||||
AVAILABLE = "available" # Ready to be acquired
|
||||
ACQUIRED = "acquired" # Assigned to a trajectory
|
||||
EXECUTING = "executing" # Currently executing a tool
|
||||
RELEASING = "releasing" # Being released back to pool
|
||||
ERROR = "error" # In error state
|
||||
|
||||
|
||||
@dataclass
|
||||
class Slot:
|
||||
"""
|
||||
An isolated workspace for a single agent trajectory.
|
||||
|
||||
Slots are the unit of scheduling - each trajectory runs in its own slot,
|
||||
with an isolated workspace directory. Multiple slots share a container.
|
||||
|
||||
Attributes:
|
||||
slot_id: Unique identifier for this slot (e.g., "slot_0")
|
||||
alloc_id: Nomad allocation ID hosting this slot
|
||||
container_addr: HTTP address of the sandbox server (e.g., "http://10.0.0.1:8080")
|
||||
workspace_dir: Path to workspace in container (e.g., "/data/slot_0")
|
||||
state: Current state of the slot
|
||||
trajectory_id: ID of trajectory currently using this slot (if acquired)
|
||||
metadata: Additional metadata
|
||||
"""
|
||||
slot_id: str
|
||||
alloc_id: str
|
||||
container_addr: str
|
||||
workspace_dir: str = ""
|
||||
state: SlotState = SlotState.AVAILABLE
|
||||
trajectory_id: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Set default workspace_dir if not provided."""
|
||||
if not self.workspace_dir:
|
||||
self.workspace_dir = f"/data/{self.slot_id}"
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
"""Check if slot is available for acquisition."""
|
||||
return self.state == SlotState.AVAILABLE
|
||||
|
||||
@property
|
||||
def is_acquired(self) -> bool:
|
||||
"""Check if slot is currently acquired."""
|
||||
return self.state in (SlotState.ACQUIRED, SlotState.EXECUTING)
|
||||
|
||||
def acquire(self, trajectory_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Mark slot as acquired by a trajectory.
|
||||
|
||||
Args:
|
||||
trajectory_id: Optional ID of acquiring trajectory
|
||||
"""
|
||||
if not self.is_available:
|
||||
raise RuntimeError(f"Cannot acquire slot {self.slot_id}: state is {self.state}")
|
||||
|
||||
self.state = SlotState.ACQUIRED
|
||||
self.trajectory_id = trajectory_id or str(uuid.uuid4())
|
||||
|
||||
def start_execution(self, execution_id: Optional[str] = None) -> None:
|
||||
"""Mark slot as executing."""
|
||||
if self.state != SlotState.ACQUIRED:
|
||||
raise RuntimeError(f"Cannot start execution on slot {self.slot_id}: state is {self.state}")
|
||||
|
||||
self.state = SlotState.EXECUTING
|
||||
if execution_id:
|
||||
self.metadata["current_execution_id"] = execution_id
|
||||
|
||||
def end_execution(self) -> None:
|
||||
"""Mark execution as complete, return to acquired state."""
|
||||
if self.state != SlotState.EXECUTING:
|
||||
raise RuntimeError(f"Cannot end execution on slot {self.slot_id}: state is {self.state}")
|
||||
|
||||
self.state = SlotState.ACQUIRED
|
||||
self.metadata.pop("current_execution_id", None)
|
||||
|
||||
def release(self) -> None:
|
||||
"""Release slot back to available state."""
|
||||
self.state = SlotState.AVAILABLE
|
||||
self.trajectory_id = None
|
||||
self.metadata.pop("current_execution_id", None)
|
||||
|
||||
def mark_error(self, error: str) -> None:
|
||||
"""Mark slot as in error state."""
|
||||
self.state = SlotState.ERROR
|
||||
self.metadata["error"] = error
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for serialization."""
|
||||
return {
|
||||
"slot_id": self.slot_id,
|
||||
"alloc_id": self.alloc_id,
|
||||
"container_addr": self.container_addr,
|
||||
"workspace_dir": self.workspace_dir,
|
||||
"state": self.state.value,
|
||||
"trajectory_id": self.trajectory_id,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Slot":
|
||||
"""Create from dictionary."""
|
||||
return cls(
|
||||
slot_id=data["slot_id"],
|
||||
alloc_id=data["alloc_id"],
|
||||
container_addr=data["container_addr"],
|
||||
workspace_dir=data.get("workspace_dir", ""),
|
||||
state=SlotState(data.get("state", "available")),
|
||||
trajectory_id=data.get("trajectory_id"),
|
||||
metadata=data.get("metadata", {}),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Slot({self.slot_id}, state={self.state.value}, alloc={self.alloc_id[:8]}...)"
|
||||
|
||||
|
||||
def create_slots_for_allocation(
|
||||
alloc_id: str,
|
||||
container_addr: str,
|
||||
num_slots: int = 10,
|
||||
) -> list["Slot"]:
|
||||
"""
|
||||
Create slots for a Nomad allocation.
|
||||
|
||||
Args:
|
||||
alloc_id: Nomad allocation ID
|
||||
container_addr: HTTP address of sandbox server
|
||||
num_slots: Number of slots to create
|
||||
|
||||
Returns:
|
||||
List of Slot objects
|
||||
"""
|
||||
slots = []
|
||||
for i in range(num_slots):
|
||||
slot_id = f"slot_{i}"
|
||||
slots.append(Slot(
|
||||
slot_id=slot_id,
|
||||
alloc_id=alloc_id,
|
||||
container_addr=container_addr,
|
||||
workspace_dir=f"/data/{slot_id}",
|
||||
))
|
||||
return slots
|
||||
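The lifecycle enforced by `Slot` (acquire, start_execution, end_execution, release) amounts to a small state machine. The sketch below is a condensed illustration rather than the class itself: `State`, `ALLOWED`, `step`, and `run` are hypothetical names, and the `RELEASING` and `ERROR` states are left out for brevity.

```python
from enum import Enum
from typing import List


class State(Enum):
    AVAILABLE = "available"
    ACQUIRED = "acquired"
    EXECUTING = "executing"


# Guarded transitions, mirroring the checks in acquire/start_execution/end_execution.
ALLOWED = {
    ("acquire", State.AVAILABLE): State.ACQUIRED,
    ("start_execution", State.ACQUIRED): State.EXECUTING,
    ("end_execution", State.EXECUTING): State.ACQUIRED,
}


def step(state: State, action: str) -> State:
    """Apply one action; release is unconditional, like Slot.release()."""
    if action == "release":
        return State.AVAILABLE
    try:
        return ALLOWED[(action, state)]
    except KeyError:
        raise RuntimeError(f"Cannot {action}: state is {state}") from None


def run(actions: List[str]) -> State:
    """Replay a sequence of actions starting from an available slot."""
    state = State.AVAILABLE
    for action in actions:
        state = step(state, action)
    return state


final = run(["acquire", "start_execution", "end_execution", "release"])
```

As in the source, `release` succeeds from any state, so a caller can always return a slot to the pool even after a failed execution.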
2
atropos/terminal/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""Terminal helpers for stateful sandbox interactions."""
|
||||
|
||||
115
atropos/terminal/asciinema_stream.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import pyte
|
||||
|
||||
|
||||
class AsciinemaStreamDecoder:
|
||||
def __init__(self, *, default_width: int = 80, default_height: int = 24) -> None:
|
||||
self._default_width = max(1, int(default_width))
|
||||
self._default_height = max(1, int(default_height))
|
||||
self._buffer = ""
|
||||
self._has_header = False
|
||||
self.width = self._default_width
|
||||
self.height = self._default_height
|
||||
self._screen = pyte.Screen(self.width, self.height)
|
||||
self._stream = pyte.Stream(self._screen)
|
||||
|
||||
def reset(self) -> None:
|
||||
self._buffer = ""
|
||||
self._has_header = False
|
||||
self.width = self._default_width
|
||||
self.height = self._default_height
|
||||
self._screen = pyte.Screen(self.width, self.height)
|
||||
self._stream = pyte.Stream(self._screen)
|
||||
|
||||
def feed(self, chunk: str | bytes) -> None:
|
||||
if not chunk:
|
||||
return
|
||||
if isinstance(chunk, bytes):
|
||||
chunk = chunk.decode("utf-8", errors="replace")
|
||||
self._buffer += chunk
|
||||
while True:
|
||||
line, sep, rest = self._buffer.partition("\n")
|
||||
if not sep:
|
||||
break
|
||||
self._buffer = rest
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
parsed = self._parse_json_line(line)
|
||||
if parsed is None:
|
||||
continue
|
||||
if not self._has_header:
|
||||
if isinstance(parsed, dict):
|
||||
self._init_from_header(parsed)
|
||||
continue
|
||||
if isinstance(parsed, list):
|
||||
self._has_header = True
|
||||
self._apply_event(parsed)
|
||||
continue
|
||||
continue
|
||||
if isinstance(parsed, list):
|
||||
self._apply_event(parsed)
|
||||
|
||||
def render(self) -> str:
|
||||
return "\n".join(self._screen.display)
|
||||
|
||||
def _parse_json_line(self, line: str) -> Any | None:
|
||||
try:
|
||||
return json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
def _init_from_header(self, header: dict[str, Any]) -> None:
|
||||
width = _coerce_int(
|
||||
header.get("width") or header.get("columns") or header.get("cols"),
|
||||
self._default_width,
|
||||
)
|
||||
height = _coerce_int(
|
||||
header.get("height") or header.get("rows") or header.get("lines"),
|
||||
self._default_height,
|
||||
)
|
||||
self.width = max(1, width)
|
||||
self.height = max(1, height)
|
||||
self._screen = pyte.Screen(self.width, self.height)
|
||||
self._stream = pyte.Stream(self._screen)
|
||||
self._has_header = True
|
||||
|
||||
def _apply_event(self, event: list[Any]) -> None:
|
||||
if len(event) < 2:
|
||||
return
|
||||
event_type = event[1]
|
||||
payload = event[2] if len(event) > 2 else ""
|
||||
if event_type == "o":
|
||||
if isinstance(payload, str):
|
||||
self._stream.feed(payload)
|
||||
elif event_type == "r":
|
||||
width, height = _parse_resize(payload)
|
||||
if width and height:
|
||||
self.width = width
|
||||
self.height = height
|
||||
self._screen.resize(width, height)
|
||||
|
||||
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return int(default)
|
||||
|
||||
|
||||
def _parse_resize(payload: Any) -> tuple[int, int]:
|
||||
if isinstance(payload, str) and "x" in payload:
|
||||
left, right = payload.lower().split("x", 1)
|
||||
return _coerce_int(left, 0), _coerce_int(right, 0)
|
||||
if isinstance(payload, dict):
|
||||
width = _coerce_int(payload.get("width") or payload.get("columns") or payload.get("cols"), 0)
|
||||
height = _coerce_int(payload.get("height") or payload.get("rows") or payload.get("lines"), 0)
|
||||
return width, height
|
||||
if isinstance(payload, list) and len(payload) >= 2:
|
||||
return _coerce_int(payload[0], 0), _coerce_int(payload[1], 0)
|
||||
return 0, 0
|
||||
|
||||
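The line buffering in `AsciinemaStreamDecoder.feed` can be shown without `pyte`. The sketch below is a simplified illustration with a hypothetical helper name, `split_events`. It assumes input shaped like an asciicast recording (one JSON header object, then `[time, "o", data]` event lines) and, unlike the decoder, does not track whether a header has already been seen or handle resize events.

```python
import json


def split_events(chunks):
    """Yield (kind, value) pairs from JSON lines that arrive in arbitrary chunks."""
    buffer = ""
    for chunk in chunks:
        buffer += chunk
        while True:
            line, sep, rest = buffer.partition("\n")
            if not sep:
                break  # incomplete line: keep it buffered until more data arrives
            buffer = rest
            line = line.strip()
            if not line:
                continue
            try:
                parsed = json.loads(line)
            except json.JSONDecodeError:
                continue  # malformed lines are skipped, as in the decoder
            if isinstance(parsed, dict):
                yield "header", parsed
            elif isinstance(parsed, list) and len(parsed) >= 3 and parsed[1] == "o":
                yield "output", parsed[2]


# Chunk boundaries deliberately fall in the middle of JSON lines.
chunks = [
    '{"version": 2, "width": 100, ',
    '"height": 30}\n[0.1, "o", "hel',
    'lo"]\nnot json\n[0.2, "o", "!"]\n',
]
events = list(split_events(chunks))
```

Buffering until a newline is what lets the caller feed raw socket or file reads without caring where the chunk boundaries fall.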
31
atropos/tools/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
Tool abstractions for atropos-agent.
|
||||
|
||||
Provides base Tool class, ToolCall/ToolResult types, and specialized tools.
|
||||
|
||||
Kept modules:
|
||||
- base.py: ToolSchema, ToolCall, ToolResult, Tool ABC, ToolRegistry
|
||||
- tool_executor.py: Batched execution queue with slot routing
|
||||
- terminal_stateful_tool.py: Persistent terminal sessions
|
||||
- tmux_tool.py: Tmux-based streaming terminal
|
||||
|
||||
Removed (replaced by hermes-agent equivalents):
|
||||
- build_registry.py → model_tools.py + toolsets.py
|
||||
- sandbox_stubs.py → atropos/backends/ execute() methods
|
||||
- hermes_external_tools.py → environments/agent_loop.py handle_function_call()
|
||||
- toolset_resolver.py → toolsets.py
|
||||
"""
|
||||
|
||||
from .base import Tool, ToolCall, ToolRegistry, ToolResult, ToolSchema
|
||||
from .terminal_stateful_tool import TerminalStatefulTool
|
||||
from .tmux_tool import TmuxTool
|
||||
|
||||
__all__ = [
|
||||
"Tool",
|
||||
"ToolCall",
|
||||
"ToolRegistry",
|
||||
"ToolResult",
|
||||
"ToolSchema",
|
||||
"TerminalStatefulTool",
|
||||
"TmuxTool",
|
||||
]
|
||||
423
atropos/tools/base.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""
|
||||
Base Tool abstraction for atropos-agent.
|
||||
|
||||
Tools follow a simple pattern:
|
||||
1. Define schema (name, description, parameters)
|
||||
2. Implement execute() method
|
||||
3. Return ToolResult with output/error
|
||||
|
||||
Tool calls use Hermes-style XML tags:
|
||||
<tool_call>{"name": "bash", "arguments": {"command": "ls"}}</tool_call>
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolSchema:
|
||||
"""JSON Schema for a tool's parameters."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
parameters: Dict[str, Any] = field(default_factory=dict)
|
||||
required: List[str] = field(default_factory=list)
|
||||
external: bool = False # Whether the tool must be executed via an external ToolServer (secret proxy) and not inside the sandbox.
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to OpenAI-compatible function schema."""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": self.parameters,
|
||||
"required": self.required,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def to_prompt_description(self) -> str:
|
||||
"""Convert to human-readable description for system prompt."""
|
||||
params_desc = []
|
||||
for name, spec in self.parameters.items():
|
||||
req = "(required)" if name in self.required else "(optional)"
|
||||
desc = spec.get("description", "")
|
||||
param_type = spec.get("type", "string")
|
||||
params_desc.append(f" - {name} ({param_type}) {req}: {desc}")
|
||||
|
||||
params_str = "\n".join(params_desc) if params_desc else " (no parameters)"
|
||||
return f"**{self.name}**: {self.description}\nParameters:\n{params_str}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
"""A parsed tool call from model output."""
|
||||
|
||||
name: str
|
||||
arguments: Dict[str, Any]
|
||||
raw_text: str = "" # Original XML/JSON text
|
||||
uniq_id: str = field(default_factory=lambda: str(uuid.uuid4())) # Unique tool-call id for traceability/reconstruction.
|
||||
|
||||
@classmethod
|
||||
def parse_from_text(cls, text: str) -> List["ToolCall"]:
|
||||
"""
|
||||
Extract tool calls from text using Hermes-style XML tags.
|
||||
|
||||
Supported formats (STRICT: requires well-formed closing tags):
|
||||
- Hermes JSON wrapper:
|
||||
<tool_call>{"name": "...", "arguments": {...}}</tool_call>
|
||||
- GLM/llama.cpp style:
|
||||
<tool_call>terminal{"command":"ls -la"}</tool_call>
|
||||
"""
|
||||
calls: List["ToolCall"] = []
|
||||
|
||||
if not text:
|
||||
return calls
|
||||
|
||||
def _append_from_payload(*, name: str, arguments: Dict[str, Any], raw: str, uniq_id: Optional[str] = None) -> None:
|
||||
if not isinstance(name, str) or not name:
|
||||
return
|
||||
if not isinstance(arguments, dict):
|
||||
return
|
||||
calls.append(
|
||||
cls(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
raw_text=raw,
|
||||
uniq_id=uniq_id or str(uuid.uuid4()),
|
||||
)
|
||||
)
|
||||
|
||||
# STRICT parsing: only accept well-formed <tool_call>...</tool_call> blocks.
|
||||
pattern = r"<tool_call>\s*(.*?)\s*</tool_call>"
|
||||
for inner in re.findall(pattern, text, re.DOTALL):
|
||||
cleaned = (inner or "").strip()
|
||||
if not cleaned:
|
||||
continue
|
||||
|
||||
# Hermes JSON wrapper.
|
||||
if cleaned.startswith("{"):
|
||||
try:
|
||||
data = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
uniq_id = data.get("uniq_id") or data.get("id") or None
|
||||
_append_from_payload(
|
||||
name=data.get("name", ""),
|
||||
arguments=data.get("arguments", {}),
|
||||
raw=inner,
|
||||
uniq_id=uniq_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# GLM/llama.cpp style: terminal{...}
|
||||
m = re.match(r"^\s*([A-Za-z0-9_.:\-]+)\s*(\{.*\})\s*$", cleaned, re.DOTALL)
|
||||
if not m:
|
||||
continue
|
||||
name = m.group(1)
|
||||
args_text = m.group(2)
|
||||
try:
|
||||
args = json.loads(args_text)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
_append_from_payload(name=name, arguments=args, raw=inner)
|
||||
|
||||
return calls
|
||||
|
||||
@classmethod
|
||||
def has_tool_call(cls, text: str) -> bool:
|
||||
"""Check if text contains any tool calls."""
|
||||
return bool(re.search(r"<tool_call>", text))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""Result from executing a tool."""
|
||||
|
||||
success: bool
|
||||
output: str = ""
|
||||
error: str = ""
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
uniq_id: Optional[str] = None # Should match ToolCall.uniq_id for async execution tracking.
|
||||
|
||||
def to_xml(self) -> str:
|
||||
"""Format as XML for including in conversation."""
|
||||
data = {
|
||||
"success": self.success,
|
||||
"output": self.output,
|
||||
}
|
||||
if self.uniq_id:
|
||||
data["uniq_id"] = self.uniq_id
|
||||
if self.error:
|
||||
data["error"] = self.error
|
||||
if self.metadata:
|
||||
data["metadata"] = self.metadata
|
||||
return f"<tool_response>{json.dumps(data)}</tool_response>"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"success": self.success,
|
||||
"output": self.output,
|
||||
"error": self.error,
|
||||
"metadata": self.metadata,
|
||||
"uniq_id": self.uniq_id,
|
||||
}
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
"""
|
||||
Abstract base class for tools.
|
||||
|
||||
Subclasses must implement:
|
||||
- schema: ToolSchema describing the tool
|
||||
- execute(): async method that performs the tool action
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def schema(self) -> ToolSchema:
|
||||
"""Return the tool's schema."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Tool name (from schema)."""
|
||||
return self.schema.name
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""
|
||||
Execute the tool with given arguments.
|
||||
|
||||
Args:
|
||||
**kwargs: Tool-specific arguments
|
||||
|
||||
Returns:
|
||||
ToolResult with success/failure and output
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_available(self) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Return whether this tool should be exposed/executable in the current process.
|
||||
|
||||
Tools that depend on optional binaries/services/env vars can override this
|
||||
to avoid advertising a tool that will fail at runtime.
|
||||
"""
|
||||
return True, None
|
||||
|
||||
async def __call__(self, **kwargs) -> ToolResult:
|
||||
"""Allow calling tool instance directly."""
|
||||
return await self.execute(**kwargs)
|
||||
|
||||
# Note: this registry only wraps tool declarations, both for tools executed out-of-process by the external ToolServer and for tools preinstalled in the sandbox environments.
|
||||
class ToolRegistry:
|
||||
"""Registry of available tools."""
|
||||
|
||||
def __init__(self):
|
||||
self._tools: Dict[str, Tool] = {}
|
||||
|
||||
def register(self, tool: Tool) -> None:
|
||||
"""Register a tool."""
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def get(self, name: str) -> Optional[Tool]:
|
||||
"""Get a tool by name."""
|
||||
return self._tools.get(name)
|
||||
|
||||
def list_tools(self) -> List[Tool]:
|
||||
"""List all registered tools."""
|
||||
return list(self._tools.values())
|
||||
|
||||
def get_schemas(self) -> List[ToolSchema]:
|
||||
"""Get schemas for all registered tools."""
|
||||
return [tool.schema for tool in self._tools.values()]
|
||||
|
||||
def get_prompt_description(self) -> str:
|
||||
"""Generate tool descriptions for system prompt."""
|
||||
descriptions = [tool.schema.to_prompt_description() for tool in self._tools.values()]
|
||||
return "\n\n".join(descriptions)
|
||||
|
||||
def get_prompt_tool_definitions_json(self) -> str:
|
||||
"""
|
||||
Return a Hermes-style JSON list of tool definitions for use inside a `<tools>...</tools>` block.
|
||||
|
||||
Hermes trajectories historically use a simplified schema list:
|
||||
[{"name": ..., "description": ..., "parameters": {...}, "required": null}, ...]
|
||||
"""
|
||||
formatted: List[Dict[str, Any]] = []
|
||||
for tool in self._tools.values():
|
||||
fn = tool.schema.to_dict().get("function", {})
|
||||
formatted.append(
|
||||
{
|
||||
"name": fn.get("name", tool.name),
|
||||
"description": fn.get("description", ""),
|
||||
"parameters": fn.get("parameters", {}),
|
||||
# Keep parity with Hermes saved trajectories (required is typically null there).
|
||||
"required": None,
|
||||
}
|
||||
)
|
||||
return json.dumps(formatted, ensure_ascii=False)
|
||||
|
||||
async def execute(self, call: ToolCall) -> ToolResult:
|
||||
"""Execute a tool call."""
|
||||
tool = self.get(call.name)
|
||||
if tool is None:
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=f"Unknown tool: {call.name}",
|
||||
uniq_id=call.uniq_id,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await tool.execute(**call.arguments)
|
||||
if result.uniq_id is None:
|
||||
result.uniq_id = call.uniq_id
|
||||
return result
|
||||
except Exception as e:
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=f"Tool execution error: {str(e)}",
|
||||
uniq_id=call.uniq_id,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# FastAPI / transport models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ToolCallPayload(BaseModel):
|
||||
name: str
|
||||
arguments: Dict[str, Any] = Field(default_factory=dict)
|
||||
uniq_id: str
|
||||
|
||||
@classmethod
|
||||
def from_tool_call(cls, call: ToolCall) -> "ToolCallPayload":
|
||||
return cls(name=call.name, arguments=call.arguments, uniq_id=call.uniq_id)
|
||||
|
||||
def to_tool_call(self) -> ToolCall:
|
||||
return ToolCall(name=self.name, arguments=self.arguments, uniq_id=self.uniq_id)
|
||||
|
||||
|
||||
class ToolResultPayload(BaseModel):
|
||||
success: bool
|
||||
output: str = ""
|
||||
error: str = ""
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
uniq_id: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_tool_result(cls, result: ToolResult) -> "ToolResultPayload":
|
||||
return cls(
|
||||
success=result.success,
|
||||
output=result.output,
|
||||
error=result.error,
|
||||
metadata=result.metadata,
|
||||
uniq_id=result.uniq_id,
|
||||
)
|
||||
|
||||
def to_tool_result(self) -> ToolResult:
|
||||
return ToolResult(
|
||||
success=self.success,
|
||||
output=self.output,
|
||||
error=self.error,
|
||||
metadata=self.metadata,
|
||||
uniq_id=self.uniq_id,
|
||||
)
|
||||
|
||||
|
||||
class ToolExecutorExecuteRequest(BaseModel):
|
||||
trajectory_id: str
|
||||
tool: ToolCallPayload
|
||||
timeout_s: Optional[float] = None
|
||||
|
||||
|
||||
class ToolExecutorReleaseRequest(BaseModel):
|
||||
trajectory_id: str
|
||||
reset_workspace: bool = False
|
||||
|
||||
|
||||
class ToolServerExecuteRequest(BaseModel):
|
||||
trajectory_id: Optional[str] = None
|
||||
tool: ToolCallPayload
|
||||
timeout_s: Optional[float] = None
|
||||
# Optional sandbox context for tools that need workspace artifacts.
|
||||
# This is set by ToolExecutor and is NOT model-controlled.
|
||||
slot_id: Optional[str] = None
|
||||
container_addr: Optional[str] = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Artifact transport models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ArtifactReadRequestPayload(BaseModel):
|
||||
trajectory_id: str
|
||||
path: str
|
||||
encoding: Literal["text", "base64"] = "text"
|
||||
max_bytes: Optional[int] = None
|
||||
include_sha256: bool = False
|
||||
|
||||
|
||||
class ArtifactReadResponsePayload(BaseModel):
|
||||
success: bool
|
||||
content: str = ""
|
||||
error: str = ""
|
||||
encoding: str = "text"
|
||||
truncated: bool = False
|
||||
bytes: int = 0
|
||||
file_size: Optional[int] = None
|
||||
path: str = ""
|
||||
mime: Optional[str] = None
|
||||
sha256: Optional[str] = None
|
||||
|
||||
|
||||
class ArtifactListRequestPayload(BaseModel):
|
||||
trajectory_id: str
|
||||
path: str = "."
|
||||
recursive: bool = False
|
||||
max_entries: Optional[int] = None
|
||||
|
||||
|
||||
class ArtifactListEntryPayload(BaseModel):
|
||||
path: str
|
||||
is_dir: bool
|
||||
size: int
|
||||
mtime: float
|
||||
|
||||
|
||||
class ArtifactListResponsePayload(BaseModel):
|
||||
success: bool
|
||||
entries: List[ArtifactListEntryPayload] = Field(default_factory=list)
|
||||
truncated: bool = False
|
||||
error: str = ""
|
||||
|
||||
|
||||
class ArtifactArchiveRequestPayload(BaseModel):
|
||||
trajectory_id: str
|
||||
path: str = "."
|
||||
format: Literal["tar.gz", "tgz"] = "tar.gz"
|
||||
max_bytes: Optional[int] = None
|
||||
max_entries: Optional[int] = None
|
||||
|
||||
|
||||
class ArtifactArchiveResponsePayload(BaseModel):
|
||||
success: bool
|
||||
content: str = ""
|
||||
error: str = ""
|
||||
encoding: str = "base64"
|
||||
format: str = "tar.gz"
|
||||
bytes: int = 0
|
||||
entry_count: int = 0
|
||||
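The two tool-call formats accepted by `ToolCall.parse_from_text` can be exercised with a condensed sketch. This is an illustration, not the method itself: `parse_tool_calls` and the two compiled pattern names are hypothetical, and the result is a list of plain `(name, arguments)` tuples instead of `ToolCall` objects.

```python
import json
import re

# Only well-formed <tool_call>...</tool_call> blocks are considered.
TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)
# GLM/llama.cpp style: a tool name immediately followed by a JSON object.
NAME_ARGS_RE = re.compile(r"^([A-Za-z0-9_.:-]+)\s*(\{.*\})$", re.DOTALL)


def parse_tool_calls(text):
    """Return (name, arguments) pairs for each parseable tool call in text."""
    calls = []
    for inner in TOOL_CALL_RE.findall(text or ""):
        if inner.startswith("{"):
            # Hermes JSON wrapper: {"name": ..., "arguments": {...}}
            try:
                data = json.loads(inner)
            except json.JSONDecodeError:
                continue
            name, args = data.get("name", ""), data.get("arguments", {})
        else:
            m = NAME_ARGS_RE.match(inner)
            if not m:
                continue
            try:
                name, args = m.group(1), json.loads(m.group(2))
            except json.JSONDecodeError:
                continue
        if isinstance(name, str) and name and isinstance(args, dict):
            calls.append((name, args))
    return calls


text = (
    'Listing files. <tool_call>{"name": "bash", "arguments": {"command": "ls"}}</tool_call> '
    'then <tool_call>terminal{"command":"ls -la"}</tool_call> '
    'and an unclosed <tool_call>{"name": "bash"'
)
calls = parse_tool_calls(text)
```

The trailing unclosed tag is ignored, which matches the strict parsing described in the docstring: a call without its closing tag never reaches the JSON step.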
45
atropos/tools/terminal_stateful_tool.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Stateful terminal tool schema.
|
||||
|
||||
This is a sandbox tool that routes to the sandbox server as `bash_stateful`
|
||||
via ToolExecutor mapping. It exists to expose an explicit, opt-in terminal
|
||||
primitive suitable for stateful workflows (e.g. tmux sessions / TUIs).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from .base import Tool, ToolResult, ToolSchema
|
||||
|
||||
|
||||
class TerminalStatefulTool(Tool):
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
return ToolSchema(
|
||||
name="terminal_stateful",
|
||||
description=(
|
||||
"Execute a command in the sandbox, allowing stateful/background processes to persist "
|
||||
"across tool calls within the same trajectory slot (e.g. tmux sessions). "
|
||||
"Use sparingly; output is still non-interactive."
|
||||
),
|
||||
parameters={
|
||||
"command": {"type": "string", "description": "The command to execute"},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Command timeout in seconds (optional).",
|
||||
"minimum": 1,
|
||||
},
|
||||
},
|
||||
required=["command"],
|
||||
)
|
||||
|
||||
def is_available(self) -> tuple[bool, str | None]:
|
||||
return True, None
|
||||
|
||||
async def execute(self, command: str, timeout: Optional[int] = None) -> ToolResult:
|
||||
_ = (command, timeout)
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error="terminal_stateful must be executed via ToolExecutor inside the sandbox",
|
||||
)
|
||||
89
atropos/tools/tmux_tool.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
tmux tool schema (sandbox).
|
||||
|
||||
This is a sandbox tool that provides basic tmux session control suitable for
|
||||
TUI-style terminal interactions:
|
||||
- send keys (arrow keys, enter, etc.)
|
||||
- capture the current screen buffer
|
||||
|
||||
Execution is routed by ToolExecutor to the sandbox server's `tmux` backend.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .base import Tool, ToolResult, ToolSchema
|
||||
|
||||
|
||||
class TmuxTool(Tool):
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
return ToolSchema(
|
||||
name="tmux",
|
||||
description=(
|
||||
"Control a per-trajectory tmux session inside the sandbox (stateful terminal). "
|
||||
"Use this for TUI-style interactions: send keys and capture the current screen."
|
||||
),
|
||||
parameters={
|
||||
"action": {
|
||||
"type": "string",
|
||||
"description": "Action to perform: start | send_keys | stream | capture | stop.",
|
||||
"enum": ["start", "send_keys", "stream", "stop", "capture"],
|
||||
},
|
||||
"keys": {
|
||||
"description": "Keys to send (string or list of strings) when action=send_keys.",
|
||||
},
|
||||
"block": {
|
||||
"type": "boolean",
|
||||
"description": "If true, wait for shell command completion (only valid at a shell prompt).",
|
||||
"default": False,
|
||||
},
|
||||
"min_wait_s": {
|
||||
"type": "number",
|
||||
"description": "For non-blocking send_keys, sleep this long after sending keys (seconds).",
|
||||
"default": 0.0,
|
||||
},
|
||||
"max_wait_s": {
|
||||
"type": "number",
|
||||
"description": "For blocking send_keys, max time to wait for completion (seconds).",
|
||||
},
|
||||
"capture_entire": {
|
||||
"type": "boolean",
|
||||
"description": "Deprecated. Streaming is preferred.",
|
||||
"default": False,
|
||||
},
|
||||
"max_bytes": {
|
||||
"type": "integer",
|
||||
"description": "Max bytes to return per stream call.",
|
||||
},
|
||||
"reset": {
|
||||
"type": "boolean",
|
||||
"description": "If true, reset stream offset to the beginning of the asciinema recording.",
|
||||
"default": False,
|
||||
},
|
||||
"pane_width": {
|
||||
"type": "integer",
|
||||
"description": "Pane width for action=start (columns).",
|
||||
"minimum": 20,
|
||||
},
|
||||
"pane_height": {
|
||||
"type": "integer",
|
||||
"description": "Pane height for action=start (rows).",
|
||||
"minimum": 10,
|
||||
},
|
||||
},
|
||||
required=["action"],
|
||||
)
|
||||
|
||||
def is_available(self) -> tuple[bool, str | None]:
|
||||
return True, None
|
||||
|
||||
async def execute(self, **kwargs: Any) -> ToolResult:
|
||||
# This tool is intended to be executed via ToolExecutor -> sandbox server.
|
||||
# We keep a safe fallback for non-sandbox contexts.
|
||||
action = str(kwargs.get("action") or "").strip()
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=f"tmux tool must be executed in the sandbox (got action={action!r})",
|
||||
)
|
||||
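A model-side session with the `tmux` tool might look like the sketch below, which builds Hermes-style `<tool_call>` tags from argument dictionaries that follow the schema above. The helper `tool_call_tag` is hypothetical, the specific values (pane size, byte limit, wait time) are illustrative, and the `"Enter"` key name is an assumption based on tmux's usual key naming, since the sandbox backend is not shown here.

```python
import json


def tool_call_tag(name, arguments):
    """Wrap a call in the <tool_call> tag format described in base.py."""
    payload = json.dumps({"name": name, "arguments": arguments})
    return f"<tool_call>{payload}</tool_call>"


# Only "action" is required by the schema; every other field is optional.
session = [
    {"action": "start", "pane_width": 120, "pane_height": 40},
    {"action": "send_keys", "keys": ["ls -la", "Enter"], "block": True, "max_wait_s": 10},
    {"action": "stream", "max_bytes": 4096},
    {"action": "stop"},
]
tags = [tool_call_tag("tmux", args) for args in session]

for tag in tags:
    print(tag)
```

Each tag is executed by ToolExecutor inside the sandbox; calling `TmuxTool.execute` directly only returns the failure result defined above.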
500
atropos/tools/tool_executor.py
Normal file
@@ -0,0 +1,500 @@
|
||||
"""
|
||||
ToolExecutor - queued, batched tool dispatch for multiplexed agent trajectories.
|
||||
|
||||
This component is responsible for:
|
||||
- Maintaining trajectory -> Slot affinity (workspace continuity)
|
||||
- Batching sandbox tool calls across trajectories to maximize container utilization
|
||||
- Routing external tools (ToolSchema.external=True) to a ToolServer (Phase 4.5)
|
||||
|
||||
For now, only sandbox tools are executed:
|
||||
- bash
|
||||
- read_file
|
||||
- write_file
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import (
|
||||
ArtifactArchiveRequestPayload,
|
||||
ArtifactArchiveResponsePayload,
|
||||
ArtifactListRequestPayload,
|
||||
ArtifactListResponsePayload,
|
||||
ArtifactReadRequestPayload,
|
||||
ArtifactReadResponsePayload,
|
||||
ToolCall,
|
||||
ToolCallPayload,
|
||||
ToolRegistry,
|
||||
ToolResult,
|
||||
ToolResultPayload,
|
||||
ToolServerExecuteRequest,
|
||||
)
|
||||
from ..backends.base import ToolBackend
|
||||
from ..slots import Slot
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolExecutorConfig:
|
||||
batch_window_ms: int = 20
|
||||
max_batch_size: int = 200
|
||||
allow_network: bool = True
|
||||
require_sandbox: bool = False
|
||||
require_stateful_sandbox: bool = False
|
||||
tool_server_url: Optional[str] = None
|
||||
tool_server_token: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _QueuedToolRequest:
|
||||
trajectory_id: str
|
||||
call: ToolCall
|
||||
timeout_s: Optional[float]
|
||||
future: asyncio.Future
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
backend: ToolBackend,
|
||||
tools: ToolRegistry,
|
||||
config: Optional[ToolExecutorConfig] = None,
|
||||
) -> None:
|
||||
self.backend = backend
|
||||
self.tools = tools
|
||||
self.config = config or ToolExecutorConfig()
|
||||
|
||||
self._queue: asyncio.Queue[Optional[_QueuedToolRequest]] = asyncio.Queue()
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._stopping = asyncio.Event()
|
||||
|
||||
self._slots_lock = asyncio.Lock()
|
||||
self._slot_by_trajectory: Dict[str, Slot] = {}
|
||||
|
||||
self._tool_server_client: Optional[httpx.AsyncClient] = None
|
||||
self._tool_server_lock = asyncio.Lock()
|
||||
|
||||
# lightweight stats for status endpoints
|
||||
self.total_requests: int = 0
|
||||
self.total_errors: int = 0
|
||||
self.latencies_s: List[float] = []
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._task is None:
|
||||
self._task = asyncio.create_task(self._run_loop())
|
||||
|
||||
def queue_size(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
async def close(self) -> None:
|
||||
self._stopping.set()
|
||||
await self._queue.put(None)
|
||||
if self._task:
|
||||
await self._task
|
||||
self._task = None
|
||||
|
||||
client = self._tool_server_client
|
||||
self._tool_server_client = None
|
||||
if client is not None:
|
||||
await client.aclose()
|
||||
|
||||
# Best-effort release any remaining slots.
|
||||
async with self._slots_lock:
|
||||
slots = list(self._slot_by_trajectory.items())
|
||||
self._slot_by_trajectory.clear()
|
||||
|
||||
for _, slot in slots:
|
||||
try:
|
||||
await self.backend.release(slot, reset_workspace=False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
trajectory_id: str,
|
||||
call: ToolCall,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> ToolResult:
|
||||
if self._task is None:
|
||||
raise RuntimeError("ToolExecutor not started (call start() first)")
|
||||
|
||||
# Allow tool args to suggest a timeout (Hermes-compatible terminal tool),
|
||||
# but never let the model choose "infinite" timeouts.
|
||||
if timeout_s is None:
|
||||
raw_timeout = call.arguments.get("timeout")
|
||||
if isinstance(raw_timeout, (int, float)):
|
||||
timeout_s = float(raw_timeout)
|
||||
if timeout_s is not None:
|
||||
timeout_s = max(1.0, min(float(timeout_s), 600.0))
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
fut: asyncio.Future = loop.create_future()
|
||||
started = time.perf_counter()
|
||||
await self._queue.put(_QueuedToolRequest(trajectory_id=trajectory_id, call=call, timeout_s=timeout_s, future=fut))
|
||||
try:
|
||||
result: ToolResult = await fut
|
||||
return result
|
||||
finally:
|
||||
self.latencies_s.append(time.perf_counter() - started)
|
||||
|
||||
async def release_trajectory(self, trajectory_id: str, reset_workspace: bool = False) -> None:
|
||||
async with self._slots_lock:
|
||||
slot = self._slot_by_trajectory.pop(trajectory_id, None)
|
||||
|
||||
if slot is not None:
|
||||
await self.backend.release(slot, reset_workspace=reset_workspace)
|
||||
|
||||
async def _get_slot_if_present(self, trajectory_id: str) -> Optional[Slot]:
|
||||
async with self._slots_lock:
|
||||
return self._slot_by_trajectory.get(trajectory_id)
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Artifact helpers (optional)
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
async def read_artifact(self, req: ArtifactReadRequestPayload) -> ArtifactReadResponsePayload:
|
||||
slot = await self._get_slot_if_present(req.trajectory_id)
|
||||
if slot is None:
|
||||
return ArtifactReadResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)")
|
||||
data = await self.backend.read_artifact(
|
||||
slot,
|
||||
req.path,
|
||||
encoding=req.encoding,
|
||||
max_bytes=req.max_bytes,
|
||||
include_sha256=req.include_sha256,
|
||||
)
|
||||
if isinstance(data, dict):
|
||||
data = dict(data)
|
||||
data.pop("http_status", None)
|
||||
try:
|
||||
return ArtifactReadResponsePayload(**(data or {}))
|
||||
except Exception as e:
|
||||
return ArtifactReadResponsePayload(success=False, error=f"Invalid artifact read response: {e}")
|
||||
|
||||
async def list_artifacts(self, req: ArtifactListRequestPayload) -> ArtifactListResponsePayload:
|
||||
slot = await self._get_slot_if_present(req.trajectory_id)
|
||||
if slot is None:
|
||||
return ArtifactListResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)")
|
||||
data = await self.backend.list_artifacts(
|
||||
slot,
|
||||
req.path,
|
||||
recursive=req.recursive,
|
||||
max_entries=req.max_entries,
|
||||
)
|
||||
if isinstance(data, dict):
|
||||
data = dict(data)
|
||||
data.pop("http_status", None)
|
||||
try:
|
||||
return ArtifactListResponsePayload(**(data or {}))
|
||||
except Exception as e:
|
||||
return ArtifactListResponsePayload(success=False, error=f"Invalid artifact list response: {e}")
|
||||
|
||||
async def archive_artifacts(self, req: ArtifactArchiveRequestPayload) -> ArtifactArchiveResponsePayload:
|
||||
slot = await self._get_slot_if_present(req.trajectory_id)
|
||||
if slot is None:
|
||||
return ArtifactArchiveResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)")
|
||||
data = await self.backend.archive_artifacts(
|
||||
slot,
|
||||
req.path,
|
||||
archive_format=req.format,
|
||||
max_bytes=req.max_bytes,
|
||||
max_entries=req.max_entries,
|
||||
)
|
||||
if isinstance(data, dict):
|
||||
data = dict(data)
|
||||
data.pop("http_status", None)
|
||||
try:
|
||||
return ArtifactArchiveResponsePayload(**(data or {}))
|
||||
except Exception as e:
|
||||
return ArtifactArchiveResponsePayload(success=False, error=f"Invalid artifact archive response: {e}")
|
||||
|
||||
async def _get_or_acquire_slot(self, trajectory_id: str) -> Slot:
|
||||
async with self._slots_lock:
|
||||
existing = self._slot_by_trajectory.get(trajectory_id)
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
slot = await self.backend.acquire(trajectory_id)
|
||||
|
||||
async with self._slots_lock:
|
||||
existing = self._slot_by_trajectory.get(trajectory_id)
|
||||
if existing is not None:
|
||||
# Another coroutine won the race; return its slot.
|
||||
await self.backend.release(slot, reset_workspace=False)
|
||||
return existing
|
||||
self._slot_by_trajectory[trajectory_id] = slot
|
||||
return slot
|
||||
|
||||
async def _run_loop(self) -> None:
|
||||
pending: List[_QueuedToolRequest] = []
|
||||
deadline: Optional[float] = None
|
||||
|
||||
batch_window_s = max(0.0, self.config.batch_window_ms / 1000.0)
|
||||
max_batch = max(1, self.config.max_batch_size)
|
||||
|
||||
while True:
|
||||
if self._stopping.is_set() and self._queue.empty() and not pending:
|
||||
break
|
||||
|
||||
timeout = None
|
||||
if pending and deadline is not None:
|
||||
timeout = max(0.0, deadline - time.perf_counter())
|
||||
|
||||
try:
|
||||
item = await asyncio.wait_for(self._queue.get(), timeout=timeout)
|
||||
if item is None:
|
||||
continue
|
||||
pending.append(item)
|
||||
if len(pending) == 1:
|
||||
deadline = time.perf_counter() + batch_window_s
|
||||
if len(pending) < max_batch:
|
||||
continue
|
||||
except asyncio.TimeoutError:
|
||||
# batch window elapsed
|
||||
pass
|
||||
|
||||
if not pending:
|
||||
deadline = None
|
||||
continue
|
||||
|
||||
batch = pending
|
||||
pending = []
|
||||
deadline = None
|
||||
|
||||
await self._execute_batch(batch)
|
||||
|
||||
async def _get_tool_server_client(self) -> httpx.AsyncClient:
|
||||
url = self.config.tool_server_url
|
||||
if not url:
|
||||
raise RuntimeError("ToolServer not configured")
|
||||
|
||||
if self._tool_server_client is not None:
|
||||
return self._tool_server_client
|
||||
|
||||
async with self._tool_server_lock:
|
||||
if self._tool_server_client is None:
|
||||
self._tool_server_client = httpx.AsyncClient(base_url=url.rstrip("/"))
|
||||
return self._tool_server_client
|
||||
|
||||
def _tool_server_headers(self) -> Dict[str, str]:
|
||||
token = self.config.tool_server_token
|
||||
if not token:
|
||||
return {}
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async def _execute_external(self, req: _QueuedToolRequest) -> ToolResult:
|
||||
client = await self._get_tool_server_client()
|
||||
slot_id: Optional[str] = None
|
||||
container_addr: Optional[str] = None
|
||||
slot = await self._get_slot_if_present(req.trajectory_id)
|
||||
if slot is not None:
|
||||
slot_id = slot.slot_id
|
||||
container_addr = slot.container_addr
|
||||
|
||||
payload = ToolServerExecuteRequest(
|
||||
trajectory_id=req.trajectory_id,
|
||||
tool=ToolCallPayload.from_tool_call(req.call),
|
||||
timeout_s=req.timeout_s,
|
||||
slot_id=slot_id,
|
||||
container_addr=container_addr,
|
||||
)
|
||||
|
||||
try:
|
||||
resp = await client.post(
|
||||
"/execute",
|
||||
json=payload.model_dump(),
|
||||
headers=self._tool_server_headers(),
|
||||
timeout=req.timeout_s,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
parsed = ToolResultPayload(**data)
|
||||
result = parsed.to_tool_result()
|
||||
if result.uniq_id is None:
|
||||
result.uniq_id = req.call.uniq_id
|
||||
return result
|
||||
except Exception as e:
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=f"External tool failed: {e}",
|
||||
uniq_id=req.call.uniq_id,
|
||||
)
|
||||
|
||||
async def _execute_batch(self, batch: List[_QueuedToolRequest]) -> None:
|
||||
# Resolve tool schemas once per request and separate sandbox/external/unknown.
|
||||
sandbox_items: List[_QueuedToolRequest] = []
|
||||
external_items: List[_QueuedToolRequest] = []
|
||||
unknown_items: List[_QueuedToolRequest] = []
|
||||
|
||||
for it in batch:
|
||||
tool = self.tools.get(it.call.name)
|
||||
if tool is None:
|
||||
unknown_items.append(it)
|
||||
continue
|
||||
|
||||
schema = tool.schema
|
||||
if not schema.external:
|
||||
sandbox_items.append(it)
|
||||
else:
|
||||
external_items.append(it)
|
||||
|
||||
for it in unknown_items:
|
||||
self.total_requests += 1
|
||||
self.total_errors += 1
|
||||
if not it.future.done():
|
||||
it.future.set_result(
|
||||
ToolResult(
|
||||
success=False,
|
||||
error=f"Unknown tool: {it.call.name}",
|
||||
uniq_id=it.call.uniq_id,
|
||||
)
|
||||
)
|
||||
|
||||
if external_items:
|
||||
if not self.config.tool_server_url:
|
||||
for it in external_items:
|
||||
self.total_requests += 1
|
||||
self.total_errors += 1
|
||||
if not it.future.done():
|
||||
it.future.set_result(
|
||||
ToolResult(
|
||||
success=False,
|
||||
error=f"External tool not available (ToolServer not configured): {it.call.name}",
|
||||
uniq_id=it.call.uniq_id,
|
||||
)
|
||||
)
|
||||
else:
|
||||
results = await asyncio.gather(*[self._execute_external(it) for it in external_items])
|
||||
for it, res in zip(external_items, results):
|
||||
self.total_requests += 1
|
||||
if not getattr(res, "success", False):
|
||||
self.total_errors += 1
|
||||
if not it.future.done():
|
||||
it.future.set_result(res)
|
||||
|
||||
if not sandbox_items:
|
||||
return
|
||||
|
||||
# Acquire slots for the distinct trajectories in this batch.
|
||||
try:
|
||||
traj_ids = list({it.trajectory_id for it in sandbox_items})
|
||||
slots = await asyncio.gather(*[self._get_or_acquire_slot(tid) for tid in traj_ids])
|
||||
slot_by_traj = dict(zip(traj_ids, slots))
|
||||
except Exception as e:
|
||||
for it in sandbox_items:
|
||||
self.total_requests += 1
|
||||
self.total_errors += 1
|
||||
if not it.future.done():
|
||||
it.future.set_result(
|
||||
ToolResult(
|
||||
success=False,
|
||||
error=f"Failed to acquire slot: {e}",
|
||||
uniq_id=it.call.uniq_id,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Group by timeout so we don't accidentally make short timeouts wait on long ones.
|
||||
by_timeout: Dict[float, List[_QueuedToolRequest]] = {}
|
||||
default_timeout = self.backend.default_timeout_s
|
||||
|
||||
for it in sandbox_items:
|
||||
t = it.timeout_s
|
||||
if t is None:
|
||||
t = default_timeout
|
||||
if t is None:
|
||||
t = 30.0
|
||||
by_timeout.setdefault(float(t), []).append(it)
|
||||
|
||||
for timeout_s, items in by_timeout.items():
|
||||
requests = []
|
||||
dispatched: List[_QueuedToolRequest] = []
|
||||
for it in items:
|
||||
slot = slot_by_traj[it.trajectory_id]
|
||||
tool_name = it.call.name
|
||||
args = dict(it.call.arguments)
|
||||
|
||||
# Hermes compatibility: treat `terminal` as an alias of sandbox `bash`.
|
||||
if tool_name == "terminal":
|
||||
if args.get("background"):
|
||||
self.total_requests += 1
|
||||
self.total_errors += 1
|
||||
if not it.future.done():
|
||||
it.future.set_result(
|
||||
ToolResult(
|
||||
success=False,
|
||||
error="terminal background execution is not supported in sandbox",
|
||||
uniq_id=it.call.uniq_id,
|
||||
)
|
||||
)
|
||||
continue
|
||||
tool_name = "bash"
|
||||
# `timeout` is handled at the ToolExecutor level, not passed to the sandbox tool args.
|
||||
args.pop("timeout", None)
|
||||
elif tool_name == "terminal_stateful":
|
||||
tool_name = "bash_stateful"
|
||||
args.pop("timeout", None)
|
||||
elif tool_name == "tmux":
|
||||
# `tmux` is a sandbox tool backed by the stateful session manager.
|
||||
# Network policy is env-controlled.
|
||||
args.pop("allow_network", None)
|
||||
|
||||
if tool_name == "bash":
|
||||
# Network policy is set by the environment/executor, not by the model.
|
||||
args.pop("allow_network", None)
|
||||
args.pop("require_sandbox", None)
|
||||
args["allow_network"] = bool(self.config.allow_network)
|
||||
args["require_sandbox"] = bool(self.config.require_sandbox)
|
||||
# `timeout` is handled at the ToolExecutor level, not passed to the sandbox tool args.
|
||||
args.pop("timeout", None)
|
||||
elif tool_name == "bash_stateful":
|
||||
# Network policy is set by the environment/executor, not by the model.
|
||||
args.pop("allow_network", None)
|
||||
args.pop("require_sandbox", None)
|
||||
args.pop("require_stateful_sandbox", None)
|
||||
args["allow_network"] = bool(self.config.allow_network)
|
||||
args["require_stateful_sandbox"] = bool(self.config.require_stateful_sandbox)
|
||||
args.pop("timeout", None)
|
||||
elif tool_name == "tmux":
|
||||
# Network policy applies to the underlying stateful session.
|
||||
args.pop("allow_network", None)
|
||||
args.pop("require_sandbox", None)
|
||||
args.pop("require_stateful_sandbox", None)
|
||||
args["allow_network"] = bool(self.config.allow_network)
|
||||
args["require_stateful_sandbox"] = bool(self.config.require_stateful_sandbox)
|
||||
|
||||
requests.append((slot, tool_name, args))
|
||||
dispatched.append(it)
|
||||
|
||||
results = None
|
||||
try:
|
||||
if not dispatched:
|
||||
continue
|
||||
results = await self.backend.execute_batch(requests, timeout_s=timeout_s)
|
||||
except Exception as e:
|
||||
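# Only fail requests that were actually dispatched; rejected ones were already counted and resolved.
for it in dispatched: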
|
||||
self.total_requests += 1
|
||||
self.total_errors += 1
|
||||
if not it.future.done():
|
||||
it.future.set_result(
|
||||
ToolResult(
|
||||
success=False,
|
||||
error=f"Batch execution failed: {e}",
|
||||
uniq_id=it.call.uniq_id,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
for it, res in zip(dispatched, results):
|
||||
self.total_requests += 1
|
||||
if not getattr(res, "success", False):
|
||||
self.total_errors += 1
|
||||
tool_result = res.to_tool_result()
|
||||
tool_result.uniq_id = it.call.uniq_id
|
||||
if not it.future.done():
|
||||
it.future.set_result(tool_result)
|
||||
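The batching behaviour of `_run_loop` above can be illustrated in isolation. The following is a minimal sketch, not the executor itself: `collect_batches` and `demo` are hypothetical names, the queue carries plain strings instead of `_QueuedToolRequest` objects, and, unlike the original, the `None` sentinel flushes pending items directly rather than relying on a separate stop event. A batch is flushed when it reaches `max_batch` items or when the window that opened with its first item elapses.

```python
import asyncio
import time
from typing import List, Optional


async def collect_batches(
    queue: "asyncio.Queue[Optional[str]]",
    batch_window_s: float = 0.02,
    max_batch: int = 3,
) -> List[List[str]]:
    """Drain `queue` into batches until a None sentinel has been seen."""
    batches: List[List[str]] = []
    pending: List[str] = []
    deadline: Optional[float] = None
    stopping = False

    while not (stopping and queue.empty() and not pending):
        # Block indefinitely when idle; otherwise wait only until the window closes.
        timeout = None
        if pending and deadline is not None:
            timeout = max(0.0, deadline - time.perf_counter())
        try:
            item = await asyncio.wait_for(queue.get(), timeout=timeout)
            if item is None:
                stopping = True  # sentinel: flush whatever is pending below
            else:
                pending.append(item)
                if len(pending) == 1:
                    # The first item of a batch opens the window.
                    deadline = time.perf_counter() + batch_window_s
                if len(pending) < max_batch:
                    continue  # keep filling the batch
        except asyncio.TimeoutError:
            pass  # window elapsed: flush a partial batch
        if pending:
            batches.append(pending)
            pending = []
        deadline = None
    return batches


async def demo() -> List[List[str]]:
    queue: "asyncio.Queue[Optional[str]]" = asyncio.Queue()
    for i in range(7):
        queue.put_nowait(f"call-{i}")
    queue.put_nowait(None)
    # A generous window keeps the demo independent of scheduler jitter.
    return await collect_batches(queue, batch_window_s=0.5, max_batch=3)


if __name__ == "__main__":
    for batch in asyncio.run(demo()):
        print(batch)
```

A short window bounds the latency a single request pays while waiting for companions; a larger `max_batch` trades that latency for fewer backend round trips.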
415
atropos_compatible_agent.py
Normal file
@@ -0,0 +1,415 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Atropos-compatible Hermes agent runner.
|
||||
|
||||
This is a minimal subclass of Hermes-Agent's `AIAgent` that swaps the OpenAI
|
||||
function-calling backend for Atroposlib's `ManagedServer`/`ServerManager` backend
|
||||
and uses Hermes-style XML tool tags:
|
||||
|
||||
- <tool_call>{"name": "...", "arguments": {...}}</tool_call>
|
||||
- <tool_response>{...}</tool_response>
|
||||
|
||||
Tool observations are appended as `role="user"` messages containing one or more
|
||||
`<tool_response>` blocks so they survive common chat templates during tokenization.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import warnings
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
from model_tools import cleanup_vm, handle_function_call
|
||||
from run_agent import AIAgent
|
||||
|
||||
_TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)
|
||||
|
||||
|
||||
ATROPOS_TOOL_SYSTEM_PROMPT = """You are a helpful AI assistant with access to tools.
|
||||
|
||||
## Available Tools
|
||||
<tools>
|
||||
{tool_descriptions}
|
||||
</tools>
|
||||
|
||||
## How to Use Tools
|
||||
To call a tool, output:
|
||||
<tool_call>{{"name": "tool_name", "arguments": {{"arg1": "value1"}}}}</tool_call>
|
||||
|
||||
You may include optional reasoning in <think>...</think> before tool calls.
|
||||
|
||||
After each tool call, you will receive tool results as:
|
||||
<tool_response>{{...}}</tool_response>
|
||||
|
||||
Continue until finished, then provide a final response with no <tool_call> blocks.
|
||||
"""
|
||||
|
||||
|
||||
class AtroposAIAgent(AIAgent):
|
||||
"""
|
||||
Hermes `AIAgent` variant that uses Atroposlib ServerManager/ManagedServer.
|
||||
|
||||
Notes:
|
||||
- The default Hermes `AIAgent` remains unchanged; this class is opt-in.
|
||||
- The underlying server must expose `managed_server(tokenizer=...)` OR be a single
|
||||
APIServer-compatible object usable by Atroposlib's `ManagedServer`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
server: Any,
|
||||
tokenizer: Any = None,
|
||||
model: str = "local",
|
||||
max_iterations: int = 10,
|
||||
tool_delay: float = 0.0,
|
||||
enabled_toolsets: Optional[List[str]] = None,
|
||||
disabled_toolsets: Optional[List[str]] = None,
|
||||
save_trajectories: bool = False,
|
||||
verbose_logging: bool = False,
|
||||
quiet_mode: bool = False,
|
||||
ephemeral_system_prompt: Optional[str] = None,
|
||||
log_prefix_chars: int = 100,
|
||||
log_prefix: str = "",
|
||||
session_id: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
# Call parent init mainly to reuse tool selection + trajectory saving utilities.
|
||||
super().__init__(
|
||||
base_url="http://unused",
|
||||
api_key="dummy-key",
|
||||
model=model,
|
||||
max_iterations=max_iterations,
|
||||
tool_delay=tool_delay,
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
disabled_toolsets=disabled_toolsets,
|
||||
save_trajectories=save_trajectories,
|
||||
verbose_logging=verbose_logging,
|
||||
quiet_mode=quiet_mode,
|
||||
ephemeral_system_prompt=ephemeral_system_prompt,
|
||||
log_prefix_chars=log_prefix_chars,
|
||||
log_prefix=log_prefix,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
self.server = server
|
||||
self.tokenizer = tokenizer
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
@asynccontextmanager
|
||||
async def _managed(self) -> AsyncGenerator[Any, None]:
|
||||
if hasattr(self.server, "managed_server"):
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r"Using OpenAIServer with managed_server does not allow for state tracking",
|
||||
category=UserWarning,
|
||||
)
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
yield managed
|
||||
return
|
||||
|
||||
# Fall back to directly wrapping a single server object.
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
|
||||
managed = ManagedServer(server=self.server, tokenizer=self.tokenizer)
|
||||
try:
|
||||
yield managed
|
||||
finally:
|
||||
managed.reset()
|
||||
|
||||
def _tool_descriptions_text(self) -> str:
|
||||
if not self.tools:
|
||||
return "(no tools available)"
|
||||
|
||||
parts: List[str] = []
|
||||
for tool in self.tools:
|
||||
fn = (tool or {}).get("function", {})
|
||||
name = fn.get("name", "")
|
||||
desc = (fn.get("description") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
if desc:
|
||||
parts.append(f"- {name}: {desc}")
|
||||
else:
|
||||
parts.append(f"- {name}")
|
||||
return "\n".join(parts) if parts else "(no tools available)"
|
||||
|
||||
def _build_system_prompt(self, system_message: Optional[str]) -> Optional[str]:
|
||||
tool_prompt = ATROPOS_TOOL_SYSTEM_PROMPT.format(
|
||||
tool_descriptions=self._tool_descriptions_text()
|
||||
)
|
||||
|
||||
parts: List[str] = []
|
||||
if system_message:
|
||||
parts.append(system_message)
|
||||
if self.ephemeral_system_prompt:
|
||||
parts.append(self.ephemeral_system_prompt)
|
||||
parts.append(tool_prompt)
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _parse_tool_calls(self, content: str) -> Tuple[List[Tuple[str, Dict[str, Any]]], List[str]]:
|
||||
"""
|
||||
Returns:
|
||||
(calls, errors)
|
||||
"""
|
||||
calls: List[Tuple[str, Dict[str, Any]]] = []
|
||||
errors: List[str] = []
|
||||
|
||||
for raw in _TOOL_CALL_RE.findall(content or ""):
|
||||
try:
|
||||
payload = json.loads(raw)
|
||||
except json.JSONDecodeError as exc:
|
||||
errors.append(f"Invalid JSON inside <tool_call>: {exc}")
|
||||
continue
|
||||
|
||||
name = payload.get("name")
|
||||
args = payload.get("arguments", {})
|
||||
if not isinstance(name, str) or not name:
|
||||
errors.append("Tool call missing 'name' string")
|
||||
continue
|
||||
if not isinstance(args, dict):
|
||||
errors.append("Tool call 'arguments' must be an object")
|
||||
continue
|
||||
|
||||
calls.append((name, args))
|
||||
|
||||
return calls, errors
|
||||
|
||||
async def run_conversation_async(
|
||||
self,
|
||||
user_message: str,
|
||||
system_message: Optional[str] = None,
|
||||
conversation_history: Optional[List[Dict[str, Any]]] = None,
|
||||
task_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
import uuid
|
||||
|
||||
effective_task_id = task_id or str(uuid.uuid4())
|
||||
|
||||
messages: List[Dict[str, Any]] = conversation_history.copy() if conversation_history else []
|
||||
messages.append({"role": "user", "content": user_message})
|
||||
|
||||
active_system_prompt = self._build_system_prompt(system_message)
|
||||
|
||||
api_call_count = 0
|
||||
final_response: Optional[str] = None
|
||||
managed_state: Optional[Dict[str, Any]] = None
|
||||
completed = False
|
||||
|
||||
try:
|
||||
async with self._managed() as managed:
|
||||
while api_call_count < self.max_iterations:
|
||||
api_call_count += 1
|
||||
|
||||
api_messages = messages.copy()
|
||||
if active_system_prompt:
|
||||
api_messages = [{"role": "system", "content": active_system_prompt}] + api_messages
|
||||
|
||||
chat_kwargs: Dict[str, Any] = {"messages": api_messages, "n": 1}
|
||||
if self.max_tokens is not None:
|
||||
chat_kwargs["max_tokens"] = self.max_tokens
|
||||
if self.temperature is not None:
|
||||
chat_kwargs["temperature"] = self.temperature
|
||||
|
||||
# Prefer OpenAI tool calling when supported by the backend:
|
||||
# - Many providers normalize Hermes-style <tool_call> tags into tool_calls when `tools` is provided.
|
||||
# - ManagedServer (atroposlib) does prompt->completion conversion and does not support `tools`.
|
||||
# Only pass `tools` when we're calling an OpenAI-compatible chat endpoint directly.
|
||||
tool_schemas = self.tools if self.tools else None
|
||||
managed_cls = type(managed).__name__
|
||||
if tool_schemas and managed_cls != "ManagedServer":
|
||||
chat_kwargs["tools"] = tool_schemas
|
||||
|
||||
if os.getenv("HERMES_DEBUG_ATROPOS_REQUEST") == "1":
|
||||
meta = {
|
||||
"managed_type": managed_cls,
|
||||
"model": getattr(getattr(managed, "config", None), "model_name", self.model),
|
||||
"base_url": getattr(getattr(managed, "config", None), "base_url", None),
|
||||
"kwargs": chat_kwargs,
|
||||
}
|
||||
# Avoid dumping megabytes of data accidentally.
|
||||
# (Messages can be large; this is still "full" but bounded.)
|
||||
print("\n=== HERMES_DEBUG_ATROPOS_REQUEST ===", flush=True)
|
||||
print(json.dumps(meta, ensure_ascii=False, indent=2)[:200_000], flush=True)
|
||||
|
||||
response = await managed.chat_completion(**chat_kwargs)
|
||||
|
||||
if os.getenv("HERMES_DEBUG_ATROPOS_RESPONSE") == "1":
|
||||
try:
|
||||
dumped = response.model_dump() # openai pydantic model
|
||||
except Exception:
|
||||
dumped = getattr(response, "__dict__", {"repr": repr(response)})
|
||||
print("\n=== HERMES_DEBUG_ATROPOS_RESPONSE: ChatCompletion (raw) ===", flush=True)
|
||||
print(json.dumps(dumped, ensure_ascii=False, indent=2), flush=True)
|
||||
|
||||
if hasattr(managed, "get_state"):
|
||||
managed_state = managed.get_state()
|
||||
|
||||
msg = response.choices[0].message
|
||||
assistant_content = (msg.content or "")
|
||||
msg_reasoning = getattr(msg, "reasoning", None)
|
||||
|
||||
# Use tool_calls if the backend provides them (preferred).
|
||||
structured_tool_calls = getattr(msg, "tool_calls", None)
|
||||
|
||||
# If the backend emits content="" but includes useful text in reasoning,
|
||||
# use it for parsing *only if needed* (e.g. tool tags).
|
||||
if assistant_content == "" and isinstance(msg_reasoning, str) and msg_reasoning:
|
||||
if os.getenv("HERMES_DEBUG_ATROPOS_RESPONSE") == "1":
|
||||
print("\n=== HERMES_DEBUG_ATROPOS_RESPONSE: message.reasoning present (content empty) ===", flush=True)
|
||||
print(msg_reasoning, flush=True)
|
||||
|
||||
assistant_msg: Dict[str, Any] = {"role": "assistant", "content": assistant_content}
|
||||
if structured_tool_calls:
|
||||
# Preserve tool_calls so the next request is consistent with OpenAI protocol.
|
||||
try:
|
||||
assistant_msg["tool_calls"] = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": tc.type,
|
||||
"function": {"name": tc.function.name, "arguments": tc.function.arguments},
|
||||
}
|
||||
for tc in structured_tool_calls
|
||||
]
|
||||
except Exception:
|
||||
# Best-effort; keep conversation moving.
|
||||
pass
|
||||
messages.append(assistant_msg)
|
||||
|
||||
# Mode A: OpenAI tool calling (preferred when supported)
|
||||
if structured_tool_calls:
|
||||
for tc in structured_tool_calls:
|
||||
tool_start = time.time()
|
||||
try:
|
||||
tool_args = json.loads(tc.function.arguments or "{}")
|
||||
except Exception:
|
||||
tool_args = {}
|
||||
tool_result = handle_function_call(tc.function.name, tool_args, effective_task_id)
|
||||
tool_duration = time.time() - tool_start
|
||||
|
||||
# Keep the raw tool result as tool content (OpenAI protocol expects role=tool).
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tc.id,
|
||||
"content": tool_result,
|
||||
}
|
||||
)
|
||||
|
||||
if self.tool_delay and self.tool_delay > 0:
|
||||
await asyncio.sleep(self.tool_delay)
|
||||
|
||||
# Continue loop after tool execution.
|
||||
continue
|
||||
|
||||
# Mode B: Hermes XML tool tags in assistant text (fallback).
|
||||
parse_source = assistant_content or (msg_reasoning or "")
|
||||
tool_calls, parse_errors = self._parse_tool_calls(parse_source)
|
||||
|
||||
if parse_errors and not tool_calls:
|
||||
# Ask the model to retry with valid tool JSON.
|
||||
err_text = "; ".join(parse_errors[:3])
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"<tool_response>{json.dumps({'error': err_text}, ensure_ascii=False)}</tool_response>\n"
|
||||
"The previous <tool_call> blocks were invalid. Please output valid JSON inside <tool_call>."
|
||||
),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if not tool_calls:
|
||||
# No tool calls: treat as final answer.
|
||||
final_response = (assistant_content or "").strip()
|
||||
completed = True
|
||||
break
|
||||
|
||||
tool_responses: List[str] = []
|
||||
for tool_name, tool_args in tool_calls:
|
||||
tool_start = time.time()
|
||||
tool_result = handle_function_call(tool_name, tool_args, effective_task_id)
|
||||
tool_duration = time.time() - tool_start
|
||||
|
||||
try:
|
||||
parsed = json.loads(tool_result)
|
||||
payload: Any = parsed
|
||||
except Exception:
|
||||
payload = tool_result
|
||||
|
||||
tool_payload = {
|
||||
"name": tool_name,
|
||||
"duration_s": round(tool_duration, 3),
|
||||
"result": payload,
|
||||
}
|
||||
tool_responses.append(
|
||||
f"<tool_response>{json.dumps(tool_payload, ensure_ascii=False)}</tool_response>"
|
||||
)
|
||||
|
||||
if self.tool_delay and self.tool_delay > 0:
|
||||
await asyncio.sleep(self.tool_delay)
|
||||
|
||||
messages.append({"role": "user", "content": "\n".join(tool_responses)})
|
||||
|
||||
if final_response is None:
|
||||
final_response = "I've reached the maximum number of iterations."
|
||||
|
||||
finally:
|
||||
try:
|
||||
cleanup_vm(effective_task_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Save trajectory using Hermes formatting (optional).
|
||||
self._save_trajectory(messages, user_message, completed=completed)
|
||||
|
||||
return {
|
||||
"final_response": final_response,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": completed,
|
||||
"managed_state": managed_state,
|
||||
"system_prompt": active_system_prompt,
|
||||
"task_id": effective_task_id,
|
||||
}
|
||||
|
||||
def run_conversation(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
Sync wrapper for convenience.
|
||||
|
||||
If called from within a running event loop (e.g. prompt_toolkit), this
|
||||
runs the async conversation in a dedicated thread to avoid nested loops.
|
||||
"""
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(self.run_conversation_async(*args, **kwargs))
|
||||
|
||||
import queue
|
||||
import threading
|
||||
|
||||
out: "queue.Queue[object]" = queue.Queue(maxsize=1)
|
||||
|
||||
def runner() -> None:
|
||||
try:
|
||||
out.put(asyncio.run(self.run_conversation_async(*args, **kwargs)))
|
||||
except BaseException as exc: # noqa: BLE001
|
||||
out.put(exc)
|
||||
|
||||
thread = threading.Thread(target=runner, daemon=True)
|
||||
thread.start()
|
||||
|
||||
result = out.get()
|
||||
if isinstance(result, BaseException):
|
||||
raise result
|
||||
return result # type: ignore[return-value]
|
||||
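The Hermes-style tag parsing done by `_parse_tool_calls` can be sketched as a standalone function. This is an illustrative version under stated assumptions: `parse_tool_calls` and `TOOL_CALL_RE` are hypothetical module-level names, and the check that the decoded payload is a JSON object is an addition that the method in this diff does not perform.

```python
import json
import re
from typing import Any, Dict, List, Tuple

# In a raw string, \s matches whitespace; (.*?) lazily captures up to the first closing tag.
TOOL_CALL_RE = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)


def parse_tool_calls(content: str) -> Tuple[List[Tuple[str, Dict[str, Any]]], List[str]]:
    """Return (calls, errors) extracted from assistant text."""
    calls: List[Tuple[str, Dict[str, Any]]] = []
    errors: List[str] = []

    for raw in TOOL_CALL_RE.findall(content or ""):
        try:
            payload = json.loads(raw)
        except json.JSONDecodeError as exc:
            errors.append(f"Invalid JSON inside <tool_call>: {exc}")
            continue

        # Assumption: reject non-object payloads (e.g. a bare list) up front.
        if not isinstance(payload, dict):
            errors.append("Tool call payload must be a JSON object")
            continue

        name = payload.get("name")
        args = payload.get("arguments", {})
        if not isinstance(name, str) or not name:
            errors.append("Tool call missing 'name' string")
            continue
        if not isinstance(args, dict):
            errors.append("Tool call 'arguments' must be an object")
            continue

        calls.append((name, args))

    return calls, errors


sample = (
    "<think>list files first</think>\n"
    '<tool_call>{"name": "terminal", "arguments": {"command": "ls"}}</tool_call>\n'
    "<tool_call>{not json}</tool_call>"
)

if __name__ == "__main__":
    calls, errors = parse_tool_calls(sample)
    print(calls)
    print(errors)
```

Returning errors alongside the valid calls lets the caller decide whether to execute what parsed or to ask the model to retry, which is how the conversation loop uses the pair.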
189
batch_runner.py
@@ -41,24 +41,17 @@ from toolset_distributions import (
|
||||
sample_toolsets_from_distribution,
|
||||
validate_distribution
|
||||
)
|
||||
from model_tools import TOOL_TO_TOOLSET_MAP
|
||||
|
||||
|
||||
# Global configuration for worker processes
|
||||
_WORKER_CONFIG = {}
|
||||
|
||||
# All possible tools - used to ensure consistent schema across all trajectory entries
|
||||
# This is required because Arrow/Parquet (used by HuggingFace datasets) needs identical schemas
|
||||
ALL_POSSIBLE_TOOLS = {
|
||||
'terminal', 'web_search', 'web_extract',
|
||||
'vision_analyze', 'image_generate', 'mixture_of_agents',
|
||||
# Skills tools
|
||||
'skills_categories', 'skills_list', 'skill_view',
|
||||
# Browser automation tools
|
||||
'browser_navigate', 'browser_snapshot', 'browser_click',
|
||||
'browser_type', 'browser_scroll', 'browser_back',
|
||||
'browser_press', 'browser_close', 'browser_get_images',
|
||||
'browser_vision'
|
||||
}
|
||||
# All possible tools - auto-derived from the master mapping in model_tools.py.
|
||||
# This stays in sync automatically when new tools are added to TOOL_TO_TOOLSET_MAP.
|
||||
# Used for consistent schema in Arrow/Parquet (HuggingFace datasets) and for
|
||||
# filtering corrupted entries during trajectory combination.
|
||||
ALL_POSSIBLE_TOOLS = set(TOOL_TO_TOOLSET_MAP.keys())
|
||||
|
||||
# Default stats for tools that weren't used
|
||||
DEFAULT_TOOL_STATS = {'count': 0, 'success': 0, 'failure': 0}
|
||||
@@ -200,6 +193,42 @@ def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, i
|
||||
return tool_stats
|
||||
|
||||
|
||||
def _extract_reasoning_stats(messages: List[Dict[str, Any]]) -> Dict[str, int]:
|
||||
"""
|
||||
Count how many assistant turns have reasoning vs no reasoning.
|
||||
|
||||
Checks for <REASONING_SCRATCHPAD> in content or a non-empty 'reasoning' field
|
||||
(native thinking tokens). Returns counts for tracking reasoning coverage.
|
||||
|
||||
Args:
|
||||
messages: Message history
|
||||
|
||||
Returns:
|
||||
Dict with 'total_assistant_turns', 'turns_with_reasoning', 'turns_without_reasoning', 'has_any_reasoning'
|
||||
"""
|
||||
total = 0
|
||||
with_reasoning = 0
|
||||
|
||||
for msg in messages:
|
||||
if msg.get("role") != "assistant":
|
||||
continue
|
||||
total += 1
|
||||
|
||||
content = msg.get("content", "") or ""
|
||||
has_scratchpad = "<REASONING_SCRATCHPAD>" in content
|
||||
has_native_reasoning = bool((msg.get("reasoning") or "").strip())
|
||||
|
||||
if has_scratchpad or has_native_reasoning:
|
||||
with_reasoning += 1
|
||||
|
||||
return {
|
||||
"total_assistant_turns": total,
|
||||
"turns_with_reasoning": with_reasoning,
|
||||
"turns_without_reasoning": total - with_reasoning,
|
||||
"has_any_reasoning": with_reasoning > 0,
|
||||
}
|
||||
|
||||
|
||||
def _process_single_prompt(
|
||||
prompt_index: int,
|
||||
prompt_data: Dict[str, Any],
|
||||
@@ -244,6 +273,9 @@ def _process_single_prompt(
|
||||
providers_ignored=config.get("providers_ignored"),
|
||||
providers_order=config.get("providers_order"),
|
||||
provider_sort=config.get("provider_sort"),
|
||||
max_tokens=config.get("max_tokens"),
|
||||
reasoning_config=config.get("reasoning_config"),
|
||||
prefill_messages=config.get("prefill_messages"),
|
||||
)
|
||||
|
||||
# Run the agent with task_id to ensure each task gets its own isolated VM
|
||||
@@ -252,6 +284,9 @@ def _process_single_prompt(
|
||||
# Extract tool usage statistics
|
||||
tool_stats = _extract_tool_stats(result["messages"])
|
||||
|
||||
# Extract reasoning coverage stats
|
||||
reasoning_stats = _extract_reasoning_stats(result["messages"])
|
||||
|
||||
# Convert to trajectory format (using existing method)
|
||||
trajectory = agent._convert_to_trajectory_format(
|
||||
result["messages"],
|
||||
@@ -264,6 +299,7 @@ def _process_single_prompt(
|
||||
"prompt_index": prompt_index,
|
||||
"trajectory": trajectory,
|
||||
"tool_stats": tool_stats,
|
||||
"reasoning_stats": reasoning_stats,
|
||||
"completed": result["completed"],
|
||||
"partial": result.get("partial", False),
|
||||
"api_calls": result["api_calls"],
|
||||
@@ -332,7 +368,9 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
|
||||
|
||||
# Initialize aggregated stats for this batch
|
||||
batch_tool_stats = {}
|
||||
batch_reasoning_stats = {"total_assistant_turns": 0, "turns_with_reasoning": 0, "turns_without_reasoning": 0}
|
||||
completed_in_batch = []
|
||||
discarded_no_reasoning = 0
|
||||
|
||||
# Process each prompt sequentially in this batch
|
||||
for prompt_index, prompt_data in prompts_to_process:
|
||||
@@ -346,6 +384,13 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
|
||||
|
||||
# Save trajectory if successful
|
||||
if result["success"] and result["trajectory"]:
|
||||
# Discard samples with zero reasoning across all turns
|
||||
reasoning = result.get("reasoning_stats", {})
|
||||
if not reasoning.get("has_any_reasoning", True):
|
||||
print(f" 🚫 Prompt {prompt_index} discarded (no reasoning in any turn)")
|
||||
discarded_no_reasoning += 1
|
||||
continue
|
||||
|
||||
# Get and normalize tool stats for consistent schema across all entries
|
||||
raw_tool_stats = result.get("tool_stats", {})
|
||||
tool_stats = _normalize_tool_stats(raw_tool_stats)
|
||||
@@ -386,6 +431,10 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
|
||||
batch_tool_stats[tool_name]["success"] += stats["success"]
|
||||
batch_tool_stats[tool_name]["failure"] += stats["failure"]
|
||||
|
||||
# Aggregate reasoning stats
|
||||
for key in batch_reasoning_stats:
|
||||
batch_reasoning_stats[key] += result.get("reasoning_stats", {}).get(key, 0)
|
||||
|
||||
# Only mark as completed if successfully saved (failed prompts can be retried on resume)
|
||||
if result["success"] and result["trajectory"]:
|
||||
completed_in_batch.append(prompt_index)
|
||||
@@ -401,6 +450,8 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
|
||||
"processed": len(prompts_to_process),
|
||||
"skipped": len(batch_data) - len(prompts_to_process),
|
||||
"tool_stats": batch_tool_stats,
|
||||
"reasoning_stats": batch_reasoning_stats,
|
||||
"discarded_no_reasoning": discarded_no_reasoning,
|
||||
"completed_prompts": completed_in_batch
|
||||
}
|
||||
|
||||
@@ -428,6 +479,10 @@ class BatchRunner:
|
||||
providers_ignored: List[str] = None,
|
||||
providers_order: List[str] = None,
|
||||
provider_sort: str = None,
|
||||
max_tokens: int = None,
|
||||
reasoning_config: Dict[str, Any] = None,
|
||||
prefill_messages: List[Dict[str, Any]] = None,
|
||||
max_samples: int = None,
|
||||
):
|
||||
"""
|
||||
Initialize the batch runner.
|
||||
@@ -449,6 +504,10 @@ class BatchRunner:
|
||||
providers_ignored (List[str]): OpenRouter providers to ignore (optional)
|
||||
providers_order (List[str]): OpenRouter providers to try in order (optional)
|
||||
provider_sort (str): Sort providers by price/throughput/latency (optional)
|
||||
max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
|
||||
reasoning_config (Dict): OpenRouter reasoning config override (e.g. {"effort": "none"} to disable thinking)
|
||||
prefill_messages (List[Dict]): Messages to prepend as prefilled conversation context (few-shot priming)
|
||||
max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)
|
||||
"""
|
||||
self.dataset_file = Path(dataset_file)
|
||||
self.batch_size = batch_size
|
||||
@@ -466,6 +525,10 @@ class BatchRunner:
|
||||
self.providers_ignored = providers_ignored
|
||||
self.providers_order = providers_order
|
||||
self.provider_sort = provider_sort
|
||||
self.max_tokens = max_tokens
|
||||
self.reasoning_config = reasoning_config
|
||||
self.prefill_messages = prefill_messages
|
||||
self.max_samples = max_samples
|
||||
|
||||
# Validate distribution
|
||||
if not validate_distribution(distribution):
|
||||
@@ -481,8 +544,12 @@ class BatchRunner:
|
||||
# Statistics file
|
||||
self.stats_file = self.output_dir / "statistics.json"
|
||||
|
||||
# Load dataset
|
||||
# Load dataset (and optionally truncate to max_samples)
|
||||
self.dataset = self._load_dataset()
|
||||
if self.max_samples and self.max_samples < len(self.dataset):
|
||||
full_count = len(self.dataset)
|
||||
self.dataset = self.dataset[:self.max_samples]
|
||||
print(f"✂️ Truncated dataset from {full_count} to {self.max_samples} samples (--max_samples)")
|
||||
|
||||
# Create batches
|
||||
self.batches = self._create_batches()
|
||||
@@ -735,6 +802,9 @@ class BatchRunner:
|
||||
"providers_ignored": self.providers_ignored,
|
||||
"providers_order": self.providers_order,
|
||||
"provider_sort": self.provider_sort,
|
||||
"max_tokens": self.max_tokens,
|
||||
"reasoning_config": self.reasoning_config,
|
||||
"prefill_messages": self.prefill_messages,
|
||||
}
|
||||
|
||||
# For backward compatibility, still track by index (but this is secondary to content matching)
|
||||
@@ -797,6 +867,8 @@ class BatchRunner:
|
||||
|
||||
# Aggregate all batch statistics and update checkpoint
|
||||
all_completed_prompts = list(completed_prompts_set)
|
||||
total_reasoning_stats = {"total_assistant_turns": 0, "turns_with_reasoning": 0, "turns_without_reasoning": 0}
|
||||
|
||||
for batch_result in results:
|
||||
# Add newly completed prompts
|
||||
all_completed_prompts.extend(batch_result.get("completed_prompts", []))
|
||||
@@ -813,6 +885,10 @@ class BatchRunner:
|
||||
total_tool_stats[tool_name]["count"] += stats["count"]
|
||||
total_tool_stats[tool_name]["success"] += stats["success"]
|
||||
total_tool_stats[tool_name]["failure"] += stats["failure"]
|
||||
|
||||
# Aggregate reasoning stats
|
||||
for key in total_reasoning_stats:
|
||||
total_reasoning_stats[key] += batch_result.get("reasoning_stats", {}).get(key, 0)
|
||||
|
||||
# Save final checkpoint
|
||||
checkpoint_data["completed_prompts"] = all_completed_prompts
|
||||
@@ -835,15 +911,8 @@ class BatchRunner:
|
||||
combined_file = self.output_dir / "trajectories.jsonl"
|
||||
print(f"\n📦 Combining ALL batch files into {combined_file.name}...")
|
||||
|
||||
VALID_TOOLS = {'web_search', 'web_extract', 'terminal', 'vision_analyze',
|
||||
'image_generate', 'mixture_of_agents',
|
||||
# Skills tools
|
||||
'skills_categories', 'skills_list', 'skill_view',
|
||||
# Browser automation tools
|
||||
'browser_navigate', 'browser_snapshot', 'browser_click',
|
||||
'browser_type', 'browser_scroll', 'browser_back',
|
||||
'browser_press', 'browser_close', 'browser_get_images',
|
||||
'browser_vision'}
|
||||
# Valid tools auto-derived from model_tools.py — no manual updates needed
|
||||
VALID_TOOLS = ALL_POSSIBLE_TOOLS
|
||||
|
||||
total_entries = 0
|
||||
filtered_entries = 0
|
||||
@@ -892,7 +961,8 @@ class BatchRunner:
|
||||
"model": self.model,
|
||||
"completed_at": datetime.now().isoformat(),
|
||||
"duration_seconds": round(time.time() - start_time, 2),
|
||||
"tool_statistics": total_tool_stats
|
||||
"tool_statistics": total_tool_stats,
|
||||
"reasoning_statistics": total_reasoning_stats,
|
||||
}
|
||||
|
||||
with open(self.stats_file, 'w', encoding='utf-8') as f:
|
||||
@@ -930,6 +1000,25 @@ class BatchRunner:
|
||||
else:
|
||||
print("No tool calls were made during this run.")
|
||||
|
||||
# Print reasoning coverage stats
|
||||
total_discarded = sum(r.get("discarded_no_reasoning", 0) for r in results)
|
||||
|
||||
print("\n🧠 Reasoning Coverage:")
|
||||
print("-" * 70)
|
||||
total_turns = total_reasoning_stats["total_assistant_turns"]
|
||||
with_reasoning = total_reasoning_stats["turns_with_reasoning"]
|
||||
without_reasoning = total_reasoning_stats["turns_without_reasoning"]
|
||||
if total_turns > 0:
|
||||
pct_with = round(with_reasoning / total_turns * 100, 1)
|
||||
pct_without = round(without_reasoning / total_turns * 100, 1)
|
||||
print(f" Total assistant turns: {total_turns:,}")
|
||||
print(f" With reasoning: {with_reasoning:,} ({pct_with}%)")
|
||||
print(f" Without reasoning: {without_reasoning:,} ({pct_without}%)")
|
||||
else:
|
||||
print(" No assistant turns recorded.")
|
||||
if total_discarded > 0:
|
||||
print(f" 🚫 Samples discarded (zero reasoning): {total_discarded:,}")
|
||||
|
||||
print(f"\n💾 Results saved to: {self.output_dir}")
|
||||
print(f" - Trajectories: trajectories.jsonl (combined)")
|
||||
print(f" - Individual batches: batch_*.jsonl (for debugging)")
|
||||
@@ -956,6 +1045,11 @@ def main(
|
||||
providers_ignored: str = None,
|
||||
providers_order: str = None,
|
||||
provider_sort: str = None,
|
||||
max_tokens: int = None,
|
||||
reasoning_effort: str = None,
|
||||
reasoning_disabled: bool = False,
|
||||
prefill_messages_file: str = None,
|
||||
max_samples: int = None,
|
||||
):
|
||||
"""
|
||||
Run batch processing of agent prompts from a dataset.
|
||||
@@ -979,6 +1073,11 @@ def main(
|
||||
providers_ignored (str): Comma-separated list of OpenRouter providers to ignore (e.g. "together,deepinfra")
|
||||
providers_order (str): Comma-separated list of OpenRouter providers to try in order (e.g. "anthropic,openai,google")
|
||||
provider_sort (str): Sort providers by "price", "throughput", or "latency" (OpenRouter only)
|
||||
max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
|
||||
reasoning_effort (str): OpenRouter reasoning effort level: "xhigh", "high", "medium", "low", "minimal", "none" (default: "xhigh")
|
||||
reasoning_disabled (bool): Completely disable reasoning/thinking tokens (default: False)
|
||||
prefill_messages_file (str): Path to JSON file containing prefill messages (list of {role, content} dicts)
|
||||
max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)
|
||||
|
||||
Examples:
|
||||
# Basic usage
|
||||
@@ -990,9 +1089,13 @@ def main(
|
||||
# Use specific distribution
|
||||
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=image_test --distribution=image_gen
|
||||
|
||||
# With ephemeral system prompt (not saved to dataset)
|
||||
# With disabled reasoning and max tokens
|
||||
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
|
||||
--ephemeral_system_prompt="You are a helpful assistant focused on image generation."
|
||||
--reasoning_disabled --max_tokens=128000
|
||||
|
||||
# With prefill messages from file
|
||||
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
|
||||
--prefill_messages_file=configs/prefill_opus.json
|
||||
|
||||
# List available distributions
|
||||
python batch_runner.py --list_distributions
|
||||
@@ -1031,6 +1134,36 @@ def main(
|
||||
providers_ignored_list = [p.strip() for p in providers_ignored.split(",")] if providers_ignored else None
|
||||
providers_order_list = [p.strip() for p in providers_order.split(",")] if providers_order else None
|
||||
|
||||
# Build reasoning_config from CLI flags
|
||||
# --reasoning_disabled takes priority, then --reasoning_effort, then default (xhigh)
|
||||
reasoning_config = None
|
||||
if reasoning_disabled:
|
||||
# Completely disable reasoning/thinking tokens
|
||||
reasoning_config = {"effort": "none"}
|
||||
print("🧠 Reasoning: DISABLED (effort=none)")
|
||||
elif reasoning_effort:
|
||||
# Use specified effort level
|
||||
valid_efforts = ["xhigh", "high", "medium", "low", "minimal", "none"]
|
||||
if reasoning_effort not in valid_efforts:
|
||||
print(f"❌ Error: --reasoning_effort must be one of: {', '.join(valid_efforts)}")
|
||||
return
|
||||
reasoning_config = {"enabled": True, "effort": reasoning_effort}
|
||||
print(f"🧠 Reasoning effort: {reasoning_effort}")
|
||||
|
||||
# Load prefill messages from JSON file if provided
|
||||
prefill_messages = None
|
||||
if prefill_messages_file:
|
||||
try:
|
||||
with open(prefill_messages_file, 'r', encoding='utf-8') as f:
|
||||
prefill_messages = json.load(f)
|
||||
if not isinstance(prefill_messages, list):
|
||||
print("❌ Error: prefill_messages_file must contain a JSON array of messages")
|
||||
return
|
||||
print(f"💬 Loaded {len(prefill_messages)} prefill messages from {prefill_messages_file}")
|
||||
except Exception as e:
|
||||
print(f"❌ Error loading prefill messages: {e}")
|
||||
return
|
||||
|
||||
# Initialize and run batch runner
|
||||
try:
|
||||
runner = BatchRunner(
|
||||
@@ -1050,6 +1183,10 @@ def main(
|
||||
providers_ignored=providers_ignored_list,
|
||||
providers_order=providers_order_list,
|
||||
provider_sort=provider_sort,
|
||||
max_tokens=max_tokens,
|
||||
reasoning_config=reasoning_config,
|
||||
prefill_messages=prefill_messages,
|
||||
max_samples=max_samples,
|
||||
)
|
||||
|
||||
runner.run(resume=resume)
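
The reasoning-coverage rule added in this file can be checked on toy data: an assistant turn counts when its content contains `<REASONING_SCRATCHPAD>` or it carries a non-empty `reasoning` field, and a sample with no such turn is discarded. The sketch below restates that rule under illustrative names; `count_reasoning` and the sample messages are assumptions and are not part of the repository.

```python
from typing import Any, Dict, List


def count_reasoning(messages: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Illustrative restatement of the coverage rule, not the repository function."""
    total = 0
    with_reasoning = 0
    for msg in messages:
        if msg.get("role") != "assistant":
            continue
        total += 1
        content = msg.get("content") or ""
        native = (msg.get("reasoning") or "").strip()
        if "<REASONING_SCRATCHPAD>" in content or native:
            with_reasoning += 1
    return {
        "total_assistant_turns": total,
        "turns_with_reasoning": with_reasoning,
        "turns_without_reasoning": total - with_reasoning,
        "has_any_reasoning": with_reasoning > 0,
    }


# Made-up conversation: one scratchpad turn, one whitespace-only reasoning
# field (does not count), and one turn with native reasoning.
sample = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "<REASONING_SCRATCHPAD>plan</REASONING_SCRATCHPAD> ok"},
    {"role": "assistant", "content": "done", "reasoning": "   "},
    {"role": "assistant", "content": None, "reasoning": "native thoughts"},
]
stats = count_reasoning(sample)
```

A batch worker following the diff would keep this sample, since `has_any_reasoning` is true, and would skip one in which every assistant turn lacks reasoning.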
|
||||
|
||||
1228
batch_runner_threaded.py
Normal file
File diff suppressed because it is too large
@@ -7,7 +7,7 @@
|
||||
# =============================================================================
|
||||
model:
|
||||
# Default model to use (can be overridden with --model flag)
|
||||
default: "anthropic/claude-sonnet-4"
|
||||
default: "anthropic/claude-opus-4.6"
|
||||
|
||||
# API configuration (falls back to OPENROUTER_API_KEY env var)
|
||||
# api_key: "your-key-here" # Uncomment to set here instead of .env
|
||||
@@ -140,7 +140,7 @@ compression:
|
||||
|
||||
# Model to use for generating summaries (fast/cheap recommended)
|
||||
# This model compresses the middle turns into a concise summary
|
||||
summary_model: "google/gemini-2.0-flash-001"
|
||||
summary_model: "google/gemini-3-flash-preview"
|
||||
|
||||
# =============================================================================
|
||||
# Agent Behavior
|
||||
|
||||
154
cli.py
@@ -83,12 +83,12 @@ def load_cli_config() -> Dict[str, Any]:
|
||||
# Default configuration
|
||||
defaults = {
|
||||
"model": {
|
||||
"default": "anthropic/claude-opus-4-20250514",
|
||||
"default": "anthropic/claude-opus-4.6",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
},
|
||||
"terminal": {
|
||||
"env_type": "local",
|
||||
"cwd": "/tmp",
|
||||
"cwd": ".", # "." is resolved to os.getcwd() at runtime
|
||||
"timeout": 60,
|
||||
"lifetime_seconds": 300,
|
||||
"docker_image": "python:3.11",
|
||||
@@ -101,7 +101,7 @@ def load_cli_config() -> Dict[str, Any]:
|
||||
"compression": {
|
||||
"enabled": True, # Auto-compress when approaching context limit
|
||||
"threshold": 0.85, # Compress at 85% of model's context limit
|
||||
"summary_model": "google/gemini-2.0-flash-001", # Fast/cheap model for summaries
|
||||
"summary_model": "google/gemini-3-flash-preview", # Fast/cheap model for summaries
|
||||
},
|
||||
"agent": {
|
||||
"max_turns": 60, # Default max tool-calling iterations
|
||||
@@ -238,6 +238,10 @@ from toolsets import get_all_toolsets, get_toolset_info, resolve_toolset, valida
|
||||
# Cron job system for scheduled tasks
|
||||
from cron import create_job, list_jobs, remove_job, get_job, run_daemon as run_cron_daemon, tick as cron_tick
|
||||
|
||||
# Resource cleanup imports for safe shutdown (terminal VMs, browser sessions)
|
||||
from tools.terminal_tool import cleanup_all_environments as _cleanup_all_terminals
|
||||
from tools.browser_tool import _emergency_cleanup_all_sessions as _cleanup_all_browsers
|
||||
|
||||
# ============================================================================
|
||||
# ASCII Art & Branding
|
||||
# ============================================================================
|
||||
@@ -839,7 +843,7 @@ class HermesCLI:
|
||||
"""Display current configuration with kawaii ASCII art."""
|
||||
# Get terminal config from environment (which was set from cli-config.yaml)
|
||||
terminal_env = os.getenv("TERMINAL_ENV", "local")
|
||||
terminal_cwd = os.getenv("TERMINAL_CWD", "/tmp")
|
||||
terminal_cwd = os.getenv("TERMINAL_CWD", os.getcwd())
|
||||
terminal_timeout = os.getenv("TERMINAL_TIMEOUT", "60")
|
||||
|
||||
config_path = Path(__file__).parent / 'cli-config.yaml'
|
||||
@@ -1217,33 +1221,35 @@ class HermesCLI:
|
||||
Returns:
|
||||
bool: True to continue, False to exit
|
||||
"""
|
||||
cmd = command.lower().strip()
|
||||
# Lowercase only for dispatch matching; preserve original case for arguments
|
||||
cmd_lower = command.lower().strip()
|
||||
cmd_original = command.strip()
|
||||
|
||||
if cmd in ("/quit", "/exit", "/q"):
|
||||
if cmd_lower in ("/quit", "/exit", "/q"):
|
||||
return False
|
||||
elif cmd == "/help":
|
||||
elif cmd_lower == "/help":
|
||||
self.show_help()
|
||||
elif cmd == "/tools":
|
||||
elif cmd_lower == "/tools":
|
||||
self.show_tools()
|
||||
elif cmd == "/toolsets":
|
||||
elif cmd_lower == "/toolsets":
|
||||
self.show_toolsets()
|
||||
elif cmd == "/config":
|
||||
elif cmd_lower == "/config":
|
||||
self.show_config()
|
||||
elif cmd == "/clear":
|
||||
# Clear terminal screen
|
||||
import os as _os
|
||||
_os.system('clear' if _os.name != 'nt' else 'cls')
|
||||
elif cmd_lower == "/clear":
|
||||
# Clear terminal screen using Rich (portable, no shell needed)
|
||||
self.console.clear()
|
||||
# Reset conversation
|
||||
self.conversation_history = []
|
||||
# Show fresh banner
|
||||
self.show_banner()
|
||||
print(" ✨ (◕‿◕)✨ Fresh start! Screen cleared and conversation reset.\n")
|
||||
elif cmd == "/history":
|
||||
elif cmd_lower == "/history":
|
||||
self.show_history()
|
||||
elif cmd == "/reset":
|
||||
elif cmd_lower == "/reset":
|
||||
self.reset_conversation()
|
||||
elif cmd.startswith("/model"):
|
||||
parts = cmd.split(maxsplit=1)
|
||||
elif cmd_lower.startswith("/model"):
|
||||
# Use original case so model names like "Anthropic/Claude-Opus-4" are preserved
|
||||
parts = cmd_original.split(maxsplit=1)
|
||||
if len(parts) > 1:
|
||||
new_model = parts[1]
|
||||
self.model = new_model
|
||||
@@ -1256,18 +1262,20 @@ class HermesCLI:
|
||||
else:
|
||||
print(f"Current model: {self.model}")
|
||||
print(" Usage: /model <model-name> to change")
|
||||
elif cmd.startswith("/prompt"):
|
||||
self._handle_prompt_command(cmd)
|
||||
elif cmd.startswith("/personality"):
|
||||
self._handle_personality_command(cmd)
|
||||
elif cmd == "/save":
|
||||
elif cmd_lower.startswith("/prompt"):
|
||||
# Use original case so prompt text isn't lowercased
|
||||
self._handle_prompt_command(cmd_original)
|
||||
elif cmd_lower.startswith("/personality"):
|
||||
# Use original case (handler lowercases the personality name itself)
|
||||
self._handle_personality_command(cmd_original)
|
||||
elif cmd_lower == "/save":
|
||||
self.save_conversation()
|
||||
elif cmd.startswith("/cron"):
|
||||
self._handle_cron_command(command) # Use original command for proper parsing
|
||||
elif cmd == "/platforms" or cmd == "/gateway":
|
||||
elif cmd_lower.startswith("/cron"):
|
||||
self._handle_cron_command(cmd_original)
|
||||
elif cmd_lower == "/platforms" or cmd_lower == "/gateway":
|
||||
self._show_gateway_status()
|
||||
else:
|
||||
self.console.print(f"[bold red]Unknown command: {cmd}[/]")
|
||||
self.console.print(f"[bold red]Unknown command: {cmd_lower}[/]")
|
||||
self.console.print("[dim #B8860B]Type /help for available commands[/]")
|
||||
|
||||
return True
|
||||
@@ -1276,6 +1284,11 @@ class HermesCLI:
|
||||
"""
|
||||
Send a message to the agent and get a response.
|
||||
|
||||
Uses a dedicated _interrupt_queue (separate from _pending_input) to avoid
|
||||
race conditions between the process_loop and interrupt monitoring. Messages
|
||||
typed while the agent is running go to _interrupt_queue; messages typed while
|
||||
idle go to _pending_input.
|
||||
|
||||
Args:
|
||||
message: The user's message
|
||||
|
||||
@@ -1307,21 +1320,22 @@ class HermesCLI:
|
||||
agent_thread = threading.Thread(target=run_agent)
|
||||
agent_thread.start()
|
||||
|
||||
# Monitor for new input in the pending queue while agent runs
|
||||
# Monitor the dedicated interrupt queue while the agent runs.
|
||||
# _interrupt_queue is separate from _pending_input, so process_loop
|
||||
# and chat() never compete for the same queue.
|
||||
interrupt_msg = None
|
||||
while agent_thread.is_alive():
|
||||
# Check if there's new input in the queue (from the persistent input area)
|
||||
if hasattr(self, '_pending_input'):
|
||||
if hasattr(self, '_interrupt_queue'):
|
||||
try:
|
||||
interrupt_msg = self._pending_input.get(timeout=0.1)
|
||||
interrupt_msg = self._interrupt_queue.get(timeout=0.1)
|
||||
if interrupt_msg:
|
||||
print(f"\n⚡ New message detected, interrupting...")
|
||||
self.agent.interrupt(interrupt_msg)
|
||||
break
|
||||
except:
|
||||
except queue.Empty:
|
||||
pass # Queue empty or timeout, continue waiting
|
||||
else:
|
||||
# Fallback if no queue (shouldn't happen)
|
||||
# Fallback for non-interactive mode (e.g., single-query)
|
||||
agent_thread.join(0.1)
|
||||
|
||||
agent_thread.join() # Ensure agent thread completes
|
||||
@@ -1332,6 +1346,11 @@ class HermesCLI:
|
||||
# Get the final response
|
||||
response = result.get("final_response", "") if result else ""
|
||||
|
||||
# Handle failed results (e.g., non-retryable errors like invalid model)
|
||||
if result and result.get("failed") and not response:
|
||||
error_detail = result.get("error", "Unknown error")
|
||||
response = f"Error: {error_detail}"
|
||||
|
||||
# Handle interrupt - check if we were interrupted
|
||||
pending_message = None
|
||||
if result and result.get("interrupted"):
|
||||
@@ -1351,10 +1370,11 @@ class HermesCLI:
|
||||
print()
|
||||
print("─" * 60)
|
||||
|
||||
# If we have a pending message from interrupt, process it immediately
|
||||
if pending_message:
|
||||
print(f"\n📨 Processing: '{pending_message[:50]}{'...' if len(pending_message) > 50 else ''}'")
|
||||
return self.chat(pending_message) # Recursive call to handle the new message
|
||||
# If we have a pending message from interrupt, re-queue it for process_loop
|
||||
# instead of recursing (avoids unbounded recursion from rapid interrupts)
|
||||
if pending_message and hasattr(self, '_pending_input'):
|
||||
print(f"\n📨 Queued: '{pending_message[:50]}{'...' if len(pending_message) > 50 else ''}'")
|
||||
self._pending_input.put(pending_message)
|
||||
|
||||
return response
|
||||
|
||||
@@ -1401,8 +1421,10 @@ class HermesCLI:
|
||||
|
||||
# State for async operation
|
||||
self._agent_running = False
|
||||
self._pending_input = queue.Queue()
|
||||
self._pending_input = queue.Queue() # For normal input (commands + new queries)
|
||||
self._interrupt_queue = queue.Queue() # For messages typed while agent is running
|
||||
self._should_exit = False
|
||||
self._last_ctrl_c_time = 0 # Track double Ctrl+C for force exit
|
||||
|
||||
# Create a persistent input area using prompt_toolkit Application
|
||||
input_buffer = Buffer()
|
||||
@@ -1412,21 +1434,49 @@ class HermesCLI:
|
||||
|
||||
@kb.add('enter')
|
||||
def handle_enter(event):
|
||||
"""Handle Enter key - submit input."""
|
||||
"""Handle Enter key - submit input.
|
||||
|
||||
Routes to the correct queue based on agent state:
|
||||
- Agent running: goes to _interrupt_queue (chat() monitors this)
|
||||
- Agent idle: goes to _pending_input (process_loop monitors this)
|
||||
Commands (starting with /) always go to _pending_input so they're
|
||||
handled as commands, not sent as interrupt text to the agent.
|
||||
"""
|
||||
text = event.app.current_buffer.text.strip()
|
||||
if text:
|
||||
# Store the input
|
||||
self._pending_input.put(text)
|
||||
if self._agent_running and not text.startswith("/"):
|
||||
# Agent is working - route to interrupt queue for chat() to pick up
|
||||
self._interrupt_queue.put(text)
|
||||
else:
|
||||
# Agent idle, or it's a command - route to normal input queue
|
||||
self._pending_input.put(text)
|
||||
# Clear the buffer
|
||||
event.app.current_buffer.reset()
|
||||
|
||||
@kb.add('c-c')
|
||||
def handle_ctrl_c(event):
|
||||
"""Handle Ctrl+C - interrupt or exit."""
|
||||
"""Handle Ctrl+C - interrupt agent or force exit on double press.
|
||||
|
||||
First Ctrl+C: interrupt the running agent gracefully.
|
||||
Second Ctrl+C within 2 seconds (or when agent is idle): force exit.
|
||||
"""
|
||||
import time as _time
|
||||
now = _time.time()
|
||||
|
||||
if self._agent_running and self.agent:
|
||||
print("\n⚡ Interrupting agent...")
|
||||
# Check for double Ctrl+C (second press within 2 seconds)
|
||||
if now - self._last_ctrl_c_time < 2.0:
|
||||
print("\n⚡ Force exiting...")
|
||||
self._should_exit = True
|
||||
event.app.exit()
|
||||
return
|
||||
|
||||
# First Ctrl+C: try graceful interrupt
|
||||
self._last_ctrl_c_time = now
|
||||
print("\n⚡ Interrupting agent... (press Ctrl+C again to force exit)")
|
||||
self.agent.interrupt()
|
||||
else:
|
||||
# Agent not running, exit immediately
|
||||
self._should_exit = True
|
||||
event.app.exit()
|
||||
|
||||
@@ -1519,6 +1569,11 @@ class HermesCLI:
|
||||
process_thread = threading.Thread(target=process_loop, daemon=True)
|
||||
process_thread.start()
|
||||
|
||||
# Register atexit cleanup so resources are freed even on unexpected exit
|
||||
# (terminal VMs, browser sessions, etc.)
|
||||
atexit.register(_cleanup_all_browsers)
|
||||
atexit.register(_cleanup_all_terminals)
|
||||
|
||||
# Run the application with patch_stdout for proper output handling
|
||||
try:
|
||||
with patch_stdout():
|
||||
@@ -1527,6 +1582,15 @@ class HermesCLI:
|
||||
pass
|
||||
finally:
|
||||
self._should_exit = True
|
||||
# Explicitly clean up resources before exit
|
||||
try:
|
||||
_cleanup_all_terminals()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
_cleanup_all_browsers()
|
||||
except Exception:
|
||||
pass
|
||||
print("\nGoodbye! ⚕")
|
||||
|
||||
|
||||
@@ -1646,6 +1710,10 @@ def main(
|
||||
cli.show_toolsets()
|
||||
sys.exit(0)
|
||||
|
||||
# Register cleanup for single-query mode (interactive mode registers in run())
|
||||
atexit.register(_cleanup_all_browsers)
|
||||
atexit.register(_cleanup_all_terminals)
|
||||
|
||||
# Handle single query mode
|
||||
if query:
|
||||
cli.show_banner()
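
The two-queue routing described in the `handle_enter` docstring can be modelled without prompt_toolkit. This is a minimal sketch under assumed names: `InputRouter` and `submit` are hypothetical and only mirror the branching shown in the diff.

```python
import queue


class InputRouter:
    """Hypothetical stand-in for the two-queue routing in handle_enter."""

    def __init__(self) -> None:
        self.agent_running = False
        self.pending_input: "queue.Queue[str]" = queue.Queue()    # read by the process loop
        self.interrupt_queue: "queue.Queue[str]" = queue.Queue()  # read while the agent runs

    def submit(self, text: str) -> None:
        text = text.strip()
        if not text:
            return
        if self.agent_running and not text.startswith("/"):
            # Agent is busy and this is not a command: treat it as an interrupt.
            self.interrupt_queue.put(text)
        else:
            # Agent idle, or a slash command: handle through the normal input path.
            self.pending_input.put(text)


router = InputRouter()
router.submit("first question")   # idle: goes to pending_input
router.agent_running = True
router.submit("actually, stop")   # running: goes to interrupt_queue
router.submit("/help")            # command: goes to pending_input even while running
```

Keeping the queues separate means the consumer that waits on interrupts never removes items meant for the main loop, which is the race the diff comments describe.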
|
||||
|
||||
@@ -40,7 +40,7 @@ def run_job(job: dict) -> tuple[bool, str, Optional[str]]:
|
||||
# Create agent with default settings
|
||||
# Jobs run in isolated sessions (no prior context)
|
||||
agent = AIAgent(
|
||||
model=os.getenv("HERMES_MODEL", "anthropic/claude-sonnet-4"),
|
||||
model=os.getenv("HERMES_MODEL", "anthropic/claude-opus-4.6"),
|
||||
quiet_mode=True,
|
||||
session_id=f"cron_{job_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
)
|
||||
|
||||
224
docs/MODAL_BACKEND.md
Normal file
@@ -0,0 +1,224 @@
|
||||
# Modal Backend
|
||||
|
||||
Hermes Agent uses [Modal](https://modal.com) for scalable, isolated cloud execution environments. There are two Modal integrations:
|
||||
|
||||
1. **Terminal Tool** (`tools/terminal_tool.py`) - For CLI/agent command execution
|
||||
2. **Atropos Backend** (`atropos/backends/modal_backend.py`) - For batch RL training workloads
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
## Terminal Tool (CLI/Agent)
|
||||
|
||||
The terminal tool provides a simple interface for executing commands in Modal sandboxes.
|
||||
|
||||
### Configuration
|
||||
|
||||
Set environment variables:
|
||||
|
||||
```bash
|
||||
export TERMINAL_ENV=modal
|
||||
export TERMINAL_MODAL_IMAGE=python:3.11
|
||||
export TERMINAL_MODAL_APP_NAME=hermes-sandbox
|
||||
```
|
||||
|
||||
Or use a YAML config file (`modal_profiles.yaml`):
|
||||
|
||||
```yaml
|
||||
profiles:
|
||||
default:
|
||||
image: python:3.11
|
||||
cpu: 1.0
|
||||
memory: 2048
|
||||
min_pool: 1
|
||||
max_pool: 5
|
||||
idle_timeout: 120
|
||||
|
||||
gpu:
|
||||
image: pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
|
||||
gpu: T4
|
||||
memory: 16384
|
||||
min_pool: 0
|
||||
max_pool: 2
|
||||
```
|
||||
|
||||
### Features
|
||||
|
||||
| Feature | Description |
|---------|-------------|
| **Sandbox Pool** | Pre-warmed sandboxes for low latency |
| **Auto-scaling** | Grows/shrinks pool based on demand |
| **Idle Timeout** | Sandboxes auto-terminate when unused |
| **Profile Selection** | Different configs for different workloads |
| **Credential Injection** | `modal.Secret` integration |
|
||||
|
||||
### Usage
|
||||
|
||||
```python
|
||||
from tools.terminal_tool import terminal_tool
|
||||
|
||||
# Simple command
|
||||
output = terminal_tool("echo hello", task_id="my-task")
|
||||
|
||||
# With profile selection
|
||||
output = terminal_tool("python train.py", task_id="training", profile="gpu")
|
||||
|
||||
# Cleanup when done
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
cleanup_vm("my-task")
|
||||
```
|
||||
|
||||
### Architecture
|
||||
|
||||
```
|
||||
_ModalPoolManager (singleton)
|
||||
├── "default" pool → [sandbox-0, sandbox-1, ...]
|
||||
└── "gpu" pool → [sandbox-0, ...]
|
||||
|
||||
Each pool:
|
||||
- Maintains min_pool warm sandboxes
|
||||
- Scales up to max_pool on demand
|
||||
- Background thread scales down idle sandboxes
|
||||
```
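The pool rules above can be sketched without Modal at all. The following is a minimal illustration, not the code in `tools/terminal_tool.py`: `PooledSandbox`, `WarmPool`, and their methods are hypothetical names, and a real pool would create and terminate Modal sandboxes where this sketch only updates a list.

```python
import time
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class PooledSandbox:
    """Hypothetical stand-in for a sandbox handle (no Modal calls here)."""
    name: str
    busy: bool = False
    last_used: float = field(default_factory=time.monotonic)


class WarmPool:
    """Toy pool that follows the min_pool / max_pool / idle_timeout rules."""

    def __init__(self, min_pool: int, max_pool: int, idle_timeout: float):
        self.min_pool = min_pool
        self.max_pool = max_pool
        self.idle_timeout = idle_timeout
        self._counter = 0
        self.sandboxes: List[PooledSandbox] = []
        # Pre-warm min_pool sandboxes up front.
        for _ in range(min_pool):
            self._create()

    def _create(self) -> PooledSandbox:
        sb = PooledSandbox(name=f"sandbox-{self._counter}")
        self._counter += 1
        self.sandboxes.append(sb)
        return sb

    def acquire(self) -> Optional[PooledSandbox]:
        """Reuse an idle sandbox, else grow the pool, else report exhaustion."""
        for sb in self.sandboxes:
            if not sb.busy:
                sb.busy = True
                return sb
        if len(self.sandboxes) < self.max_pool:
            sb = self._create()
            sb.busy = True
            return sb
        return None  # pool is at max_pool and every sandbox is busy

    def release(self, sb: PooledSandbox) -> None:
        sb.busy = False
        sb.last_used = time.monotonic()

    def scale_down(self, now: Optional[float] = None) -> int:
        """Drop idle sandboxes past idle_timeout, never going below min_pool."""
        now = time.monotonic() if now is None else now
        removed = 0
        for sb in list(self.sandboxes):
            if len(self.sandboxes) <= self.min_pool:
                break
            if not sb.busy and now - sb.last_used > self.idle_timeout:
                self.sandboxes.remove(sb)
                removed += 1
        return removed
```

In this sketch `scale_down()` would be called periodically from the background thread; passing `now` explicitly makes the timeout behavior easy to test.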
|
||||
|
||||
---
|
||||
|
||||
## Atropos Backend (RL Training)
|
||||
|
||||
The Atropos backend is designed for high-throughput batch execution during reinforcement learning training.
|
||||
|
||||
### Key Concept: Slot-based Multiplexing
|
||||
|
||||
Instead of one sandbox per trajectory, multiple trajectories share sandboxes via **slots**:
|
||||
|
||||
```
|
||||
Sandbox (1 container)
|
||||
├── Slot 0 → Trajectory A (workspace: /data/slot_0)
|
||||
├── Slot 1 → Trajectory B (workspace: /data/slot_1)
|
||||
└── Slot 2 → Trajectory C (workspace: /data/slot_2)
|
||||
```
|
||||
|
||||
**Benefits**:
|
||||
- Fewer containers = lower cost
|
||||
- Warm-up time is shared across the trajectories in a sandbox
|
||||
- Better GPU utilization
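As a rough illustration of the slot idea, the sketch below maps trajectory IDs to `(sandbox, slot)` pairs and derives the `/data/slot_N` workspace path. `Slot` and `SlotAllocator` are hypothetical names used only for this example; the actual allocation logic in `atropos/backends/modal_backend.py` may differ.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass(frozen=True)
class Slot:
    sandbox_index: int
    slot_index: int

    @property
    def workspace(self) -> str:
        # Mirrors the /data/slot_N layout shown in the diagram.
        return f"/data/slot_{self.slot_index}"


class SlotAllocator:
    """Hypothetical allocator: packs trajectories into existing sandboxes first."""

    def __init__(self, slots_per_sandbox: int, max_sandboxes: int):
        self.slots_per_sandbox = slots_per_sandbox
        self.max_sandboxes = max_sandboxes
        # One list per sandbox; each entry is the owning trajectory ID or None.
        self._owners: List[List[Optional[str]]] = []
        self._by_trajectory: Dict[str, Slot] = {}

    def acquire(self, trajectory_id: str) -> Slot:
        if trajectory_id in self._by_trajectory:
            return self._by_trajectory[trajectory_id]
        # Fill free slots in already-running sandboxes before adding a new one.
        for sb_idx, owners in enumerate(self._owners):
            for slot_idx, owner in enumerate(owners):
                if owner is None:
                    return self._assign(sb_idx, slot_idx, trajectory_id)
        if len(self._owners) >= self.max_sandboxes:
            raise RuntimeError("all slots in all sandboxes are in use")
        self._owners.append([None] * self.slots_per_sandbox)
        return self._assign(len(self._owners) - 1, 0, trajectory_id)

    def _assign(self, sb_idx: int, slot_idx: int, trajectory_id: str) -> Slot:
        self._owners[sb_idx][slot_idx] = trajectory_id
        slot = Slot(sb_idx, slot_idx)
        self._by_trajectory[trajectory_id] = slot
        return slot

    def release(self, trajectory_id: str) -> None:
        slot = self._by_trajectory.pop(trajectory_id)
        self._owners[slot.sandbox_index][slot.slot_index] = None
```

Packing existing sandboxes first is what keeps the container count low: a new sandbox is only started once every slot in the running ones is taken.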
|
||||
|
||||
### Configuration
|
||||
|
||||
```python
|
||||
from atropos.backends.modal_backend import ModalSandboxConfig, ModalToolBackend
|
||||
|
||||
config = ModalSandboxConfig(
|
||||
name="default",
|
||||
image="python:3.11",
|
||||
cpu=1.0,
|
||||
memory=2048,
|
||||
slots_per_sandbox=10, # 10 trajectories per container
|
||||
min_sandboxes=1,
|
||||
max_sandboxes=5,
|
||||
)
|
||||
|
||||
backend = ModalToolBackend(config.with_app_name("my-training"))
|
||||
```
|
||||
|
||||
### Multi-Profile Support
|
||||
|
||||
Different trajectory types can request different resources:
|
||||
|
||||
```python
|
||||
backend = ModalToolBackend.with_profiles(
|
||||
app_name="rl-training",
|
||||
profiles={
|
||||
"default": ModalSandboxConfig(
|
||||
name="default",
|
||||
cpu=1.0,
|
||||
memory=2048,
|
||||
),
|
||||
"pytorch-gpu": ModalSandboxConfig(
|
||||
name="pytorch-gpu",
|
||||
image="pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime",
|
||||
gpu="T4",
|
||||
memory=16384,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# CPU task
|
||||
slot1 = await backend.acquire("traj-1", profile="default")
|
||||
|
||||
# GPU task
|
||||
slot2 = await backend.acquire("traj-2", profile="pytorch-gpu")
|
||||
```
|
||||
|
||||
### Batched Execution
|
||||
|
||||
The key optimization is executing many commands in parallel:
|
||||
|
||||
```python
|
||||
# Acquire slots for multiple trajectories
|
||||
slots = [await backend.acquire(f"traj-{i}") for i in range(50)]
|
||||
|
||||
# Execute batch across all slots in parallel
|
||||
results = await backend.execute_batch([
|
||||
(slot, "bash", {"command": "python step.py"})
|
||||
for slot in slots
|
||||
])
|
||||
|
||||
# Release slots
|
||||
for slot in slots:
|
||||
await backend.release(slot)
|
||||
```
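One plausible way to fan a batch out concurrently is `asyncio.gather` with a concurrency cap. This is a sketch under assumptions, not the backend's implementation: `run_in_slot` is a hypothetical placeholder for the per-slot call into a sandbox, and the standalone `execute_batch` function and its `max_concurrency` parameter are invented for illustration.

```python
import asyncio
from typing import Any, Dict, List, Tuple


async def run_in_slot(slot: str, tool: str, args: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical per-slot executor; a real backend would call into the sandbox."""
    await asyncio.sleep(0.01)  # stands in for network + command latency
    return {"slot": slot, "tool": tool, "output": f"ran {args['command']}"}


async def execute_batch(
    requests: List[Tuple[str, str, Dict[str, Any]]],
    max_concurrency: int = 16,
) -> List[Any]:
    """Run all requests concurrently, returning results in request order."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def guarded(slot: str, tool: str, args: Dict[str, Any]) -> Dict[str, Any]:
        # Cap the number of in-flight commands.
        async with semaphore:
            return await run_in_slot(slot, tool, args)

    # return_exceptions=True keeps one failed command from cancelling the rest;
    # failures appear as exception objects in the result list.
    return await asyncio.gather(
        *(guarded(*request) for request in requests),
        return_exceptions=True,
    )
```

Because `asyncio.gather` preserves input order, each result lines up with the request (and therefore the slot) that produced it.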
|
||||
|
||||
### Architecture
|
||||
|
||||
```
|
||||
ModalToolBackend
|
||||
└── _ModalMultiProfileManager
|
||||
├── "default" → _ModalSandboxPool
|
||||
│ ├── Sandbox 0 (slots 0-9)
|
||||
│ └── Sandbox 1 (slots 0-9)
|
||||
│
|
||||
└── "pytorch-gpu" → _ModalSandboxPool
|
||||
└── Sandbox 0 (slots 0-9)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Credentials
|
||||
|
||||
Inject secrets securely using Modal's secret management:
|
||||
|
||||
```bash
|
||||
# Create secret in Modal dashboard or CLI
|
||||
modal secret create my-api-key API_KEY=sk-xxx
|
||||
```
|
||||
|
||||
```python
|
||||
# Reference in config
|
||||
config = ModalSandboxConfig(
|
||||
secrets=["my-api-key"], # Modal secret names
|
||||
env_vars={"DEBUG": "1"}, # Additional env vars
|
||||
)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Modal package not installed"
|
||||
```bash
|
||||
pip install modal
|
||||
modal token new # Authenticate
|
||||
```
|
||||
|
||||
### "Sandbox creation failed"
|
||||
- Check Modal dashboard for quota limits
|
||||
- Verify image exists and is accessible
|
||||
- Check secret names are correct
|
||||
|
||||
### Shutdown errors
|
||||
These are harmless warnings during Python interpreter shutdown:
|
||||
```
|
||||
[Modal] Error terminating ...: cannot schedule new futures after interpreter shutdown
|
||||
```
|
||||
|
||||
The sandboxes still terminate on their own once Modal's `idle_timeout` elapses.
|
||||
28
environments/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
Hermes-Agent Atropos Environments
|
||||
|
||||
Provides a layered integration between hermes-agent's tool-calling capabilities
|
||||
and the Atropos RL training framework.
|
||||
|
||||
Layers:
|
||||
- agent_loop: Reusable multi-turn agent loop with standard OpenAI-spec tool calling
|
||||
- tool_context: Per-rollout tool access handle for reward/verification functions
|
||||
- hermes_base_env: Abstract base environment (BaseEnv subclass) for Atropos
|
||||
- tool_call_parsers: Client-side tool call parser registry for Phase 2 (VLLM /generate)
|
||||
|
||||
Concrete environments:
|
||||
- terminal_test_env: Simple file-creation tasks for testing the stack
|
||||
- hermes_swe_env: SWE-bench style tasks with Modal sandboxes
|
||||
"""
|
||||
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
from environments.tool_context import ToolContext
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
|
||||
__all__ = [
|
||||
"AgentResult",
|
||||
"HermesAgentLoop",
|
||||
"ToolContext",
|
||||
"HermesAgentBaseEnv",
|
||||
"HermesAgentEnvConfig",
|
||||
]
|
||||
574
environments/agent_loop.py
Normal file
@@ -0,0 +1,574 @@
|
||||
"""
|
||||
HermesAgentLoop -- Reusable Multi-Turn Agent Engine
|
||||
|
||||
Runs the hermes-agent tool-calling loop using standard OpenAI-spec tool calling.
|
||||
Works with any server that returns ChatCompletion objects with tool_calls:
|
||||
- Phase 1: OpenAI server type (VLLM, SGLang, OpenRouter, OpenAI API)
|
||||
- Phase 2: ManagedServer with client-side tool call parser
|
||||
|
||||
The loop passes tools= and checks response.choices[0].message.tool_calls,
|
||||
identical to hermes-agent's run_agent.py. Tool execution is dispatched via
|
||||
handle_function_call() from model_tools.py.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from model_tools import handle_function_call
|
||||
|
||||
# Thread pool for running sync tool calls that internally use asyncio.run()
|
||||
# (e.g., mini-swe-agent's modal/docker backends). Running them in a separate
|
||||
# thread gives them a clean event loop so they don't deadlock inside Atropos's loop.
|
||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolError:
|
||||
"""Record of a tool execution error during the agent loop."""
|
||||
|
||||
turn: int # Which turn the error occurred on
|
||||
tool_name: str # Which tool was called
|
||||
arguments: str # The arguments passed (truncated)
|
||||
error: str # The error message
|
||||
tool_result: str # The raw result returned to the model
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResult:
|
||||
"""Result of running the agent loop."""
|
||||
|
||||
# Full conversation history in OpenAI message format
|
||||
messages: List[Dict[str, Any]]
|
||||
# ManagedServer.get_state() if available (Phase 2), None otherwise
|
||||
managed_state: Optional[Dict[str, Any]] = None
|
||||
# How many LLM calls were made
|
||||
turns_used: int = 0
|
||||
# True if model stopped calling tools naturally (vs hitting max_turns)
|
||||
finished_naturally: bool = False
|
||||
# Extracted reasoning content per turn (from PR #297 helpers)
|
||||
reasoning_per_turn: List[Optional[str]] = field(default_factory=list)
|
||||
# Tool errors encountered during the loop
|
||||
tool_errors: List[ToolError] = field(default_factory=list)
|
||||
|
||||
# Tool-call metrics (for reward shaping + debugging)
|
||||
tool_calls_attempted: int = 0 # Valid tool name + attempted dispatch
|
||||
tool_calls_schema_valid: int = 0 # Arguments matched schema (no coercion)
|
||||
tool_calls_executed_ok: int = 0 # Tool ran and returned no error
|
||||
tool_calls_exec_error: int = 0 # Unknown tool / exception / tool returned error
|
||||
|
||||
|
||||
def _extract_reasoning_from_message(message) -> Optional[str]:
|
||||
"""
|
||||
Extract reasoning content from a ChatCompletion message.
|
||||
|
||||
Handles multiple provider formats:
|
||||
1. message.reasoning_content field (some providers)
|
||||
2. message.reasoning field (some providers)
|
||||
3. message.reasoning_details[].text (OpenRouter style)
|
||||
|
||||
Note: <think> block extraction from content is NOT done here -- that's
|
||||
already handled by the server in Phase 1, or by
|
||||
ManagedServer's patch in Phase 2.
|
||||
|
||||
Args:
|
||||
message: The assistant message from ChatCompletion response
|
||||
|
||||
Returns:
|
||||
Extracted reasoning text, or None if not found
|
||||
"""
|
||||
# Check reasoning_content field (common across providers)
|
||||
if hasattr(message, "reasoning_content") and message.reasoning_content:
|
||||
return message.reasoning_content
|
||||
|
||||
# Check reasoning field
|
||||
if hasattr(message, "reasoning") and message.reasoning:
|
||||
return message.reasoning
|
||||
|
||||
# Check reasoning_details (OpenRouter style)
|
||||
if hasattr(message, "reasoning_details") and message.reasoning_details:
|
||||
for detail in message.reasoning_details:
|
||||
if hasattr(detail, "text") and detail.text:
|
||||
return detail.text
|
||||
if isinstance(detail, dict) and detail.get("text"):
|
||||
return detail["text"]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class HermesAgentLoop:
|
||||
"""
|
||||
Runs hermes-agent's tool-calling loop using standard OpenAI-spec tool calling.
|
||||
|
||||
Same pattern as run_agent.py:
|
||||
- Pass tools= to the API
|
||||
- Check response.choices[0].message.tool_calls
|
||||
- Dispatch via handle_function_call()
|
||||
|
||||
Works identically with any server type -- OpenAI, VLLM, SGLang, OpenRouter,
|
||||
or ManagedServer with a parser. The server determines how tool_calls get
|
||||
populated on the response.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server,
|
||||
tool_schemas: List[Dict[str, Any]],
|
||||
valid_tool_names: Set[str],
|
||||
max_turns: int = 30,
|
||||
task_id: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
tool_handler=None,
|
||||
max_context_tokens: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the agent loop.
|
||||
|
||||
Args:
|
||||
server: Server object with chat_completion() method (OpenAIServer,
|
||||
ManagedServer, ServerManager, etc.)
|
||||
tool_schemas: OpenAI-format tool definitions from get_tool_definitions()
|
||||
valid_tool_names: Set of tool names the model is allowed to call
|
||||
max_turns: Maximum number of LLM calls before stopping
|
||||
task_id: Unique ID for terminal/browser session isolation
|
||||
temperature: Sampling temperature for generation
|
||||
max_tokens: Max tokens per generation (None for server default)
|
||||
tool_handler: Optional async callable(tool_name, args, task_id) -> str.
|
||||
When provided, used INSTEAD of handle_function_call() for
|
||||
tool dispatch. This allows sandbox backends (Modal, Nomad)
|
||||
to route tool calls through their slot-based execution.
|
||||
max_context_tokens: Maximum prompt tokens before truncation.
|
||||
If None, no truncation is applied.
|
||||
Recommended: set to max_model_len - max_tokens - 512 (safety margin).
|
||||
"""
|
||||
self.server = server
|
||||
self.tool_schemas = tool_schemas
|
||||
self.valid_tool_names = valid_tool_names
|
||||
self.max_turns = max_turns
|
||||
self.task_id = task_id or str(uuid.uuid4())
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.tool_handler = tool_handler
|
||||
self.max_context_tokens = max_context_tokens
|
||||
|
||||
|
||||
def _truncate_context(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Truncate conversation history to fit within max_context_tokens.
|
||||
|
||||
Strategy:
|
||||
- Keep system message (index 0) and initial user message (index 1) always
|
||||
- Keep last 6 messages (recent context) always
|
||||
- For everything in between, progressively truncate tool result content
|
||||
- If still too long, drop oldest middle messages entirely
|
||||
|
||||
Uses rough char/4 token estimate (fast, no tokenizer needed).
|
||||
"""
|
||||
if self.max_context_tokens is None:
|
||||
return messages
|
||||
|
||||
def estimate_tokens(msgs):
|
||||
total = 0
|
||||
for m in msgs:
|
||||
content = m.get("content", "") or ""
|
||||
total += len(content) // 4 + 10 # ~4 chars per token + overhead
|
||||
if "tool_calls" in m:
|
||||
total += 50 * len(m["tool_calls"]) # tool call overhead
|
||||
return total
|
||||
|
||||
est = estimate_tokens(messages)
|
||||
if est <= self.max_context_tokens:
|
||||
return messages
|
||||
|
||||
# Phase 1: Truncate tool result content in middle messages
|
||||
# Keep first 2 and last 6 messages untouched
|
||||
protect_head = 2
|
||||
protect_tail = max(0, min(6, len(messages) - protect_head))
|
||||
middle_start = protect_head
|
||||
middle_end = len(messages) - protect_tail
|
||||
|
||||
if middle_start < middle_end:
|
||||
# Truncate tool results from oldest first
|
||||
for i in range(middle_start, middle_end):
|
||||
if messages[i].get("role") == "tool":
|
||||
content = messages[i].get("content", "") or ""
|
||||
if len(content) > 200:
|
||||
messages[i] = dict(messages[i]) # copy
|
||||
messages[i]["content"] = content[:100] + "\n...[truncated]...\n" + content[-50:]
|
||||
|
||||
est = estimate_tokens(messages)
|
||||
if est <= self.max_context_tokens:
|
||||
logger.debug("Context truncated (phase 1: tool results): %d tokens", est)
|
||||
return messages
|
||||
|
||||
# Phase 2: Drop oldest middle messages entirely
|
||||
while middle_start < middle_end and estimate_tokens(messages) > self.max_context_tokens:
|
||||
# Remove the oldest middle message
|
||||
# But keep assistant+tool pairs together
|
||||
msg = messages[middle_start]
|
||||
messages.pop(middle_start)
|
||||
middle_end -= 1
|
||||
# If we removed an assistant with tool_calls, also remove matching tool responses
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
tool_ids = {tc.get("id") or tc.get("tool_call_id", "") for tc in msg.get("tool_calls", []) if isinstance(tc, dict)}
|
||||
# Remove tool responses for those IDs
|
||||
i = middle_start
|
||||
while i < middle_end:
|
||||
if messages[i].get("role") == "tool" and messages[i].get("tool_call_id", "") in tool_ids:
|
||||
messages.pop(i)
|
||||
middle_end -= 1
|
||||
else:
|
||||
i += 1
|
||||
|
||||
est = estimate_tokens(messages)
|
||||
logger.info("Context truncated (phase 2: dropped messages): %d estimated tokens, %d messages remaining", est, len(messages))
|
||||
return messages
|
||||
|
||||
def _normalize_tool_args(self, tool_name: str, tool_args_raw: str) -> tuple[Dict[str, Any], bool]:
|
||||
"""Normalize tool arguments into a dict.
|
||||
|
||||
Returns:
|
||||
(args_dict, schema_valid)
|
||||
|
||||
schema_valid is True only when the arguments decode directly into a dict
|
||||
(i.e. no double-decoding and no coercion/wrapping was needed).
|
||||
|
||||
This lets us keep the environment robust (never crash due to args format)
|
||||
while still scoring down malformed tool-call argument formats.
|
||||
"""
|
||||
try:
|
||||
decoded = json.loads(tool_args_raw)
|
||||
except json.JSONDecodeError:
|
||||
# Not valid JSON at all. Be robust: treat it as a plain string.
|
||||
# (Some parsers/providers may pass through non-JSON strings.)
|
||||
if tool_name == "terminal":
|
||||
return {"command": tool_args_raw}, False
|
||||
return {"input": tool_args_raw}, False
|
||||
|
||||
# Canonical case: decoded is already a dict
|
||||
if isinstance(decoded, dict):
|
||||
# For terminal tool, require a command key
|
||||
if tool_name == "terminal":
|
||||
cmd = decoded.get("command")
|
||||
if isinstance(cmd, str) and cmd.strip():
|
||||
return decoded, True
|
||||
# Common alternate key
|
||||
if isinstance(decoded.get("input"), str):
|
||||
return {"command": decoded.get("input")}, False
|
||||
return decoded, False
|
||||
return decoded, True
|
||||
|
||||
# Common drift case: decoded is a JSON string of an object
|
||||
if isinstance(decoded, str):
|
||||
s = decoded.strip()
|
||||
if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
|
||||
try:
|
||||
decoded2 = json.loads(s)
|
||||
except json.JSONDecodeError:
|
||||
decoded2 = None
|
||||
if isinstance(decoded2, dict):
|
||||
# Terminal tool: ensure command
|
||||
if tool_name == "terminal" and isinstance(decoded2.get("command"), str):
|
||||
return decoded2, False
|
||||
return decoded2, False
|
||||
|
||||
# Plain string (not JSON) — coerce to expected shape
|
||||
if tool_name == "terminal":
|
||||
return {"command": decoded}, False
|
||||
return {"input": decoded}, False
|
||||
|
||||
# Other JSON types (list/number/etc.) — wrap
|
||||
if tool_name == "terminal":
|
||||
return {"command": str(decoded)}, False
|
||||
return {"input": decoded}, False
|
||||
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
|
||||
"""
|
||||
Execute the full agent loop using standard OpenAI tool calling.
|
||||
|
||||
Args:
|
||||
messages: Initial conversation messages (system + user).
|
||||
This list is treated as the FULL trajectory and is
|
||||
appended to as the conversation progresses.
|
||||
|
||||
Prompt truncation (to avoid context overflow) is applied
|
||||
on a copy of this list per turn, so we do not lose
|
||||
earlier messages for reward computation/debugging.
|
||||
|
||||
Returns:
|
||||
AgentResult with full conversation history, managed state, and metadata
|
||||
"""
|
||||
reasoning_per_turn = []
|
||||
tool_errors: List[ToolError] = []
|
||||
|
||||
# Metrics to separate "attempted tool use" from "schema-valid tool use"
|
||||
tool_calls_attempted = 0
|
||||
tool_calls_schema_valid = 0
|
||||
tool_calls_executed_ok = 0
|
||||
tool_calls_exec_error = 0
|
||||
|
||||
for turn in range(self.max_turns):
|
||||
# Truncate context if approaching limit.
|
||||
# IMPORTANT: do this on a copy so we keep the full trajectory in `messages`
|
||||
# for reward computation + debugging, while only trimming the prompt view.
|
||||
prompt_messages = self._truncate_context(list(messages))
|
||||
|
||||
# Build the chat_completion kwargs
|
||||
chat_kwargs = {
|
||||
"messages": prompt_messages,
|
||||
"n": 1,
|
||||
"temperature": self.temperature,
|
||||
}
|
||||
|
||||
# Only pass tools if we have them
|
||||
if self.tool_schemas:
|
||||
chat_kwargs["tools"] = self.tool_schemas
|
||||
|
||||
# Only pass max_tokens if explicitly set
|
||||
if self.max_tokens is not None:
|
||||
chat_kwargs["max_tokens"] = self.max_tokens
|
||||
|
||||
# Make the API call -- standard OpenAI spec
|
||||
try:
|
||||
response = await self.server.chat_completion(**chat_kwargs)
|
||||
except Exception as e:
|
||||
logger.error("API call failed on turn %d: %s", turn + 1, e)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
tool_calls_attempted=tool_calls_attempted,
|
||||
tool_calls_schema_valid=tool_calls_schema_valid,
|
||||
tool_calls_executed_ok=tool_calls_executed_ok,
|
||||
tool_calls_exec_error=tool_calls_exec_error,
|
||||
)
|
||||
|
||||
if not response or not response.choices:
|
||||
logger.warning("Empty response on turn %d", turn + 1)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
tool_calls_attempted=tool_calls_attempted,
|
||||
tool_calls_schema_valid=tool_calls_schema_valid,
|
||||
tool_calls_executed_ok=tool_calls_executed_ok,
|
||||
tool_calls_exec_error=tool_calls_exec_error,
|
||||
)
|
||||
|
||||
assistant_msg = response.choices[0].message
|
||||
|
||||
# Extract reasoning content from the response (all provider formats)
|
||||
reasoning = _extract_reasoning_from_message(assistant_msg)
|
||||
reasoning_per_turn.append(reasoning)
|
||||
|
||||
# Check for tool calls -- standard OpenAI spec
|
||||
if assistant_msg.tool_calls:
|
||||
# Build the assistant message dict for conversation history
|
||||
msg_dict: Dict[str, Any] = {
|
||||
"role": "assistant",
|
||||
"content": assistant_msg.content or "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments,
|
||||
},
|
||||
}
|
||||
for tc in assistant_msg.tool_calls
|
||||
],
|
||||
}
|
||||
|
||||
# Preserve reasoning_content for multi-turn chat template handling
|
||||
# (e.g., Kimi-K2's template renders <think> blocks differently
|
||||
# for history vs. the latest turn based on this field)
|
||||
if reasoning:
|
||||
msg_dict["reasoning_content"] = reasoning
|
||||
|
||||
messages.append(msg_dict)
|
||||
|
||||
# Execute each tool call via hermes-agent's dispatch
|
||||
for tc in assistant_msg.tool_calls:
|
||||
tool_name = tc.function.name
|
||||
tool_args_raw = tc.function.arguments
|
||||
|
||||
# Validate tool name
|
||||
if tool_name not in self.valid_tool_names:
|
||||
tool_result = json.dumps(
|
||||
{
|
||||
"error": f"Unknown tool '{tool_name}'. "
|
||||
f"Available tools: {sorted(self.valid_tool_names)}"
|
||||
}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"Unknown tool '{tool_name}'",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.warning(
|
||||
"Model called unknown tool '%s' on turn %d",
|
||||
tool_name, turn + 1,
|
||||
)
|
||||
tool_calls_exec_error += 1
|
||||
else:
|
||||
tool_calls_attempted += 1
|
||||
|
||||
# Normalize args into a dict so we never crash due to formatting.
|
||||
# Track schema_valid separately so reward shaping can penalize
|
||||
# non-canonical formats (e.g. stringified JSON).
|
||||
args, schema_valid = self._normalize_tool_args(tool_name, tool_args_raw)
|
||||
if schema_valid:
|
||||
tool_calls_schema_valid += 1
|
||||
|
||||
try:
|
||||
if tool_name == "terminal":
|
||||
import os
|
||||
backend = os.getenv("TERMINAL_ENV", "local")
|
||||
if self.tool_handler:
|
||||
backend = "sandbox"
|
||||
cmd_preview = str(args.get("command", ""))[:80]
|
||||
print(f" 🖥️ [{backend}] $ {cmd_preview}")
|
||||
|
||||
if self.tool_handler:
|
||||
# Use custom tool handler (sandbox backend routing)
|
||||
tool_result = await self.tool_handler(
|
||||
tool_name, args, self.task_id
|
||||
)
|
||||
else:
|
||||
# Default: run via hermes-agent's handle_function_call
|
||||
# in a thread pool so backends that use asyncio.run()
|
||||
# internally (modal, docker) get a clean event loop
|
||||
# instead of deadlocking inside Atropos's loop.
|
||||
loop = asyncio.get_running_loop()
|
||||
tool_result = await loop.run_in_executor(
|
||||
_tool_executor,
|
||||
lambda: handle_function_call(
|
||||
tool_name, args, task_id=self.task_id
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
tool_calls_exec_error += 1
|
||||
tool_result = json.dumps(
|
||||
{"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"{type(e).__name__}: {str(e)}",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.error(
|
||||
"Tool '%s' execution failed on turn %d: %s",
|
||||
tool_name, turn + 1, e,
|
||||
)
|
||||
else:
|
||||
# Count tool result errors (if tool returns structured JSON error)
|
||||
tool_err = False
|
||||
try:
|
||||
result_data = json.loads(tool_result)
|
||||
if isinstance(result_data, dict):
|
||||
err = result_data.get("error")
|
||||
if err:
|
||||
tool_err = True
|
||||
|
||||
# Keep existing behavior: treat negative exit_code as tool error
|
||||
exit_code = result_data.get("exit_code")
|
||||
if exit_code is not None and isinstance(exit_code, int) and exit_code < 0:
|
||||
tool_err = True
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=str(err) if err else "negative exit_code",
|
||||
tool_result=tool_result[:500],
|
||||
))
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# Non-JSON tool output — assume ok
|
||||
pass
|
||||
|
||||
if tool_err:
|
||||
tool_calls_exec_error += 1
|
||||
else:
|
||||
tool_calls_executed_ok += 1
|
||||
|
||||
# Add tool response to conversation
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tc.id,
|
||||
"content": tool_result,
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Turn %d: %d tool calls executed",
|
||||
turn + 1,
|
||||
len(assistant_msg.tool_calls),
|
||||
)
|
||||
|
||||
else:
|
||||
# No tool calls -- model is done
|
||||
msg_dict = {
|
||||
"role": "assistant",
|
||||
"content": assistant_msg.content or "",
|
||||
}
|
||||
if reasoning:
|
||||
msg_dict["reasoning_content"] = reasoning
|
||||
messages.append(msg_dict)
|
||||
|
||||
logger.debug(
|
||||
"Turn %d: model finished naturally (no tool calls)", turn + 1
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=True,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
tool_calls_attempted=tool_calls_attempted,
|
||||
tool_calls_schema_valid=tool_calls_schema_valid,
|
||||
tool_calls_executed_ok=tool_calls_executed_ok,
|
||||
tool_calls_exec_error=tool_calls_exec_error,
|
||||
)
|
||||
|
||||
# Hit max turns without the model stopping
|
||||
logger.info("Agent hit max_turns (%d) without finishing", self.max_turns)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=self.max_turns,
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
tool_calls_attempted=tool_calls_attempted,
|
||||
tool_calls_schema_valid=tool_calls_schema_valid,
|
||||
tool_calls_executed_ok=tool_calls_executed_ok,
|
||||
tool_calls_exec_error=tool_calls_exec_error,
|
||||
)
|
||||
|
||||
def _get_managed_state(self) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get ManagedServer state if the server supports it.
|
||||
|
||||
Returns state dict with SequenceNodes containing tokens/logprobs/masks,
|
||||
or None if the server doesn't support get_state() (e.g., regular OpenAI server).
|
||||
"""
|
||||
if hasattr(self.server, "get_state"):
|
||||
return self.server.get_state()
|
||||
return None
|
||||
33
environments/configs/swe_default.yaml
Normal file
@@ -0,0 +1,33 @@
|
||||
# SWE Environment -- Default Configuration
|
||||
#
|
||||
# SWE-bench style tasks with Modal sandboxes for cloud isolation.
|
||||
# Uses terminal + file + web toolsets.
|
||||
#
|
||||
# Usage:
|
||||
# python environments/hermes_swe_env.py serve --config environments/configs/swe_default.yaml
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file", "web"]
|
||||
max_agent_turns: 30
|
||||
max_token_length: 4096
|
||||
group_size: 4
|
||||
terminal_backend: "modal"
|
||||
tool_call_parser: "hermes"
|
||||
tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
|
||||
dataset_name: "bigcode/humanevalpack"
|
||||
dataset_split: "test"
|
||||
prompt_field: "prompt"
|
||||
steps_per_eval: 50
|
||||
total_steps: 500
|
||||
use_wandb: true
|
||||
wandb_name: "hermes-swe"
|
||||
system_prompt: >
|
||||
You are a skilled software engineer. You have access to a terminal,
|
||||
file tools, and web search. Use these tools to complete the coding task.
|
||||
Write clean, working code and verify it runs correctly before finishing.
|
||||
|
||||
openai:
|
||||
base_url: "http://localhost:8000/v1"
|
||||
model_name: "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
|
||||
server_type: "openai"
|
||||
api_key: ""
|
||||
35
environments/configs/terminal_test_default.yaml
Normal file
@@ -0,0 +1,35 @@
|
||||
# Terminal Test Environment -- Default Configuration
|
||||
#
|
||||
# Simple file-creation tasks for validating the full Atropos + hermes-agent stack.
|
||||
# Uses Modal terminal backend and OpenRouter (Claude) for inference.
|
||||
# API keys loaded from ~/hermes-agent/.env
|
||||
#
|
||||
# Usage:
|
||||
# run-api
|
||||
# python environments/terminal_test_env.py serve
|
||||
# # Or with config file:
|
||||
# python environments/terminal_test_env.py serve --config environments/configs/terminal_test_default.yaml
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file"]
|
||||
max_agent_turns: 10
|
||||
max_token_length: 2048
|
||||
group_size: 3
|
||||
total_steps: 3
|
||||
steps_per_eval: 3
|
||||
terminal_backend: "modal"
|
||||
tool_call_parser: "hermes"
|
||||
tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
|
||||
ensure_scores_are_not_same: false
|
||||
use_wandb: false
|
||||
system_prompt: >
|
||||
You are a helpful assistant with access to a terminal and file tools.
|
||||
Complete the user's request by using the available tools.
|
||||
Be precise and follow instructions exactly.
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-opus-4.6"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
350
environments/gsm8k_agent_env.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""
|
||||
GSM8kAgentEnv -- Math Reasoning with Tool Use (Python REPL)
|
||||
|
||||
An agentic RL environment where models solve GSM8k math problems using
|
||||
a Python interpreter tool. Uses proper OpenAI-spec tool calling via
|
||||
HermesAgentBaseEnv (not ICL).
|
||||
|
||||
The model:
|
||||
1. Receives a math problem
|
||||
2. Can call the `terminal` tool to run Python code (`python3 -c "..."`)
|
||||
3. Provides a final answer in \\boxed{} format
|
||||
4. Gets reward: 1.0 if correct, 0.0 if wrong
|
||||
|
||||
Usage:
|
||||
# Phase 1 (OpenRouter, no training):
|
||||
python environments/gsm8k_agent_env.py process \\
|
||||
--env.data_path_to_save_groups gsm8k_agent_output.jsonl
|
||||
|
||||
# Phase 2 (VLLM + Tinker training):
|
||||
run-api
|
||||
python launch_training.py --config configs/gsm8k_agent.yaml
|
||||
python environments/gsm8k_agent_env.py serve
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Math verification helpers
|
||||
# =============================================================================
|
||||
|
||||
def _verify_math_answer(model_response: str, gold_answer: str) -> bool:
|
||||
"""
|
||||
Verify if the model's response contains the correct answer.
|
||||
Uses math_verify for robust LaTeX comparison. If math_verify is not installed, falls back to exact string matching on the last \\boxed{} answer.
|
||||
"""
|
||||
try:
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
|
||||
gold_parsed = parse(
|
||||
f"\\boxed{{{gold_answer}}}",
|
||||
extraction_mode="first_match",
|
||||
extraction_config=[LatexExtractionConfig()],
|
||||
)
|
||||
|
||||
# Strip <think> blocks if present
|
||||
answer_text = model_response
|
||||
if "</think>" in answer_text:
|
||||
answer_text = answer_text.split("</think>")[-1]
|
||||
|
||||
answer_parsed = parse(
|
||||
answer_text,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
|
||||
return bool(verify(answer_parsed, gold_parsed))
|
||||
|
||||
except ImportError:
|
||||
# Fallback: simple string matching for \\boxed{answer}
|
||||
import re
|
||||
pattern = r'\\boxed\{([^}]+)\}'
|
||||
matches = re.findall(pattern, model_response)
|
||||
if matches:
|
||||
model_answer = matches[-1].strip().replace(",", "")
|
||||
gold_clean = gold_answer.strip().replace(",", "")
|
||||
return model_answer == gold_clean
|
||||
return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Environment Config
|
||||
# =============================================================================
|
||||
|
||||
class GSM8kAgentEnvConfig(HermesAgentEnvConfig):
|
||||
"""Config with defaults for GSM8k agent environment."""
|
||||
pass
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Environment
|
||||
# =============================================================================
|
||||
|
||||
class GSM8kAgentEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
GSM8k math environment with Python REPL tool calling.
|
||||
|
||||
Models solve grade-school math problems by reasoning step by step
|
||||
and using Python (via the terminal tool) for calculations.
|
||||
|
||||
Exercises the full agentic RL training loop:
|
||||
- Model receives math problem
|
||||
- Makes tool calls to compute (python3 -c "...")
|
||||
- Provides final answer in \\boxed{}
|
||||
- Reward: binary (1.0 correct, 0.0 wrong)
|
||||
"""
|
||||
|
||||
name = "gsm8k-agent"
|
||||
env_config_cls = GSM8kAgentEnvConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[GSM8kAgentEnvConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Default config using terminal tool.
|
||||
|
||||
Reads from environment variables (set in .env):
|
||||
ATROPOS_SERVER_BASE_URL - Inference server URL
|
||||
ATROPOS_SERVER_MODEL - Model name on the server
|
||||
ATROPOS_TOKENIZER_NAME - HuggingFace tokenizer name
|
||||
ATROPOS_SERVER_API_KEY - API key for the server
|
||||
"""
|
||||
# Resolve inference server settings from env
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "https://openrouter.ai/api/v1"
|
||||
)
|
||||
if not base_url.rstrip("/").endswith("/v1"):
|
||||
base_url = base_url.rstrip("/") + "/v1"
|
||||
|
||||
model = (
|
||||
os.getenv("ATROPOS_SERVER_MODEL")
|
||||
or os.getenv("LLM_MODEL")
|
||||
or "Hermes-4.3-36B"
|
||||
)
|
||||
|
||||
api_key = (
|
||||
os.getenv("ATROPOS_SERVER_API_KEY")
|
||||
or os.getenv("NOUS_API_KEY")
|
||||
or os.getenv("OPENROUTER_API_KEY")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or ""
|
||||
)
|
||||
|
||||
tokenizer = (
|
||||
os.getenv("ATROPOS_TOKENIZER_NAME")
|
||||
or os.getenv("ATROPOS_TOKENIZER")
|
||||
or "NousResearch/Hermes-4.3-36B"
|
||||
)
|
||||
|
||||
env_config = GSM8kAgentEnvConfig(
|
||||
# Terminal + file toolsets (same as terminal_test_env.py)
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
# Agent settings
|
||||
max_agent_turns=5, # Math problems don't need many turns
|
||||
max_token_length=2048, # Room for reasoning + code
|
||||
agent_temperature=1.0,
|
||||
system_prompt=(
|
||||
"You are a helpful math assistant. You have access to a terminal "
|
||||
"where you can run Python code to help solve problems.\n\n"
|
||||
"When you need to calculate something, use the terminal tool with "
|
||||
"a command like: python3 -c \"print(2 + 2)\"\n\n"
|
||||
"When you have the final answer, write it inside \\boxed{} like: \\boxed{42}\n\n"
|
||||
"Work step by step. Use Python to verify your reasoning."
|
||||
),
|
||||
# Terminal backend (local for testing, modal for production)
|
||||
terminal_backend=os.getenv("TERMINAL_ENV", "local"),
|
||||
# Parser -- hermes format for Hermes models
|
||||
tool_call_parser="hermes",
|
||||
# Atropos settings
|
||||
group_size=4,
|
||||
tokenizer_name=tokenizer,
|
||||
steps_per_eval=5,
|
||||
total_steps=10,
|
||||
use_wandb=bool(os.getenv("WANDB_API_KEY")),
|
||||
wandb_name="gsm8k-agent",
|
||||
ensure_scores_are_not_same=False,
|
||||
# No external dataset (we load GSM8k ourselves)
|
||||
dataset_name=None,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url=base_url,
|
||||
model_name=model,
|
||||
server_type="openai",
|
||||
api_key=api_key,
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self):
|
||||
"""Load GSM8k dataset."""
|
||||
from datasets import load_dataset
|
||||
|
||||
self.train = load_dataset("gsm8k", "main", split="train").shuffle(seed=42)
|
||||
test_data = load_dataset("gsm8k", "main", split="test").shuffle(seed=42)
|
||||
self.test = [
|
||||
{
|
||||
"question": item["question"],
|
||||
"gold_answer": item["answer"].split("#")[-1].strip().replace(",", ""),
|
||||
}
|
||||
for item in test_data
|
||||
]
|
||||
self.iter = 0
|
||||
self.reward_buffer: List[float] = []
|
||||
self.tool_use_buffer: List[int] = []
|
||||
print(f"[GSM8kAgentEnv] Loaded {len(self.train)} train, {len(self.test)} test examples")
|
||||
|
||||
async def get_next_item(self) -> Dict[str, str]:
|
||||
"""Cycle through training problems."""
|
||||
item = self.train[self.iter % len(self.train)]
|
||||
self.iter += 1
|
||||
return {
|
||||
"question": item["question"],
|
||||
"gold_answer": item["answer"].split("#")[-1].strip().replace(",", ""),
|
||||
}
|
||||
|
||||
def format_prompt(self, item: Dict[str, str]) -> str:
|
||||
"""Format the math problem as a user message."""
|
||||
return item["question"]
|
||||
|
||||
async def compute_reward(
|
||||
self, item: Dict[str, str], result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Score: verify the model's \\boxed{} answer against the gold answer.
|
||||
|
||||
The agent has full access to terminal via ctx, but for GSM8k we just
|
||||
check the final answer from the conversation.
|
||||
"""
|
||||
# Get the last assistant message content
|
||||
final_text = ""
|
||||
for msg in reversed(result.messages):
|
||||
if msg.get("role") == "assistant" and msg.get("content"):
|
||||
final_text = msg["content"]
|
||||
break
|
||||
|
||||
correct = _verify_math_answer(final_text, item["gold_answer"])
|
||||
reward = 1.0 if correct else 0.0
|
||||
|
||||
self.reward_buffer.append(reward)
|
||||
# Count tool calls in this trajectory
|
||||
tool_call_count = sum(
|
||||
len(msg.get("tool_calls") or [])  # tool_calls may be present but None
|
||||
for msg in result.messages
|
||||
if msg.get("role") == "assistant"
|
||||
)
|
||||
self.tool_use_buffer.append(tool_call_count)
|
||||
|
||||
return reward
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""Evaluate on a subset of the test set (greedy, no tools for speed)."""
|
||||
start_time = time.time()
|
||||
correct = 0
|
||||
total = 0
|
||||
samples = []
|
||||
|
||||
eval_subset = self.test[:30] # Small subset for quick eval
|
||||
|
||||
for item in eval_subset:
|
||||
try:
|
||||
completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": self.config.system_prompt or ""},
|
||||
{"role": "user", "content": item["question"]},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
|
||||
response = completion.choices[0].message.content or ""
|
||||
is_correct = _verify_math_answer(response, item["gold_answer"])
|
||||
|
||||
if is_correct:
|
||||
correct += 1
|
||||
total += 1
|
||||
|
||||
samples.append({
|
||||
"question": item["question"],
|
||||
"gold_answer": item["gold_answer"],
|
||||
"response": response[:500],
|
||||
"correct": is_correct,
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Eval failed: %s", e)
|
||||
total += 1
|
||||
|
||||
percent_correct = correct / total if total > 0 else 0
|
||||
end_time = time.time()
|
||||
|
||||
await self.evaluate_log(
|
||||
metrics={"eval/percent_correct": percent_correct, "eval/total": total},
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log training metrics."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self.reward_buffer:
|
||||
wandb_metrics["train/percent_correct"] = sum(self.reward_buffer) / len(self.reward_buffer)
|
||||
wandb_metrics["train/total_rollouts"] = len(self.reward_buffer)
|
||||
self.reward_buffer = []
|
||||
|
||||
if self.tool_use_buffer:
|
||||
wandb_metrics["train/avg_tool_calls"] = sum(self.tool_use_buffer) / len(self.tool_use_buffer)
|
||||
wandb_metrics["train/tool_use_rate"] = sum(1 for t in self.tool_use_buffer if t > 0) / len(self.tool_use_buffer)
|
||||
self.tool_use_buffer = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GSM8kAgentEnv.cli()
|
||||
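The answer handling in gsm8k_agent_env.py can be exercised without math_verify installed. The following is a minimal sketch rather than the environment's actual code path: the helper names `extract_gold_answer`, `extract_last_boxed`, and `fallback_is_correct` are hypothetical, it assumes the GSM8k `answer` field ends with `#### <number>`, and unlike the fallback branch of `_verify_math_answer` it strips `<think>` blocks before matching.

```python
import re
from typing import Optional


def extract_gold_answer(answer_field: str) -> str:
    """Take the text after the final '####' marker and drop thousands separators."""
    return answer_field.split("####")[-1].strip().replace(",", "")


def extract_last_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in text, or None if absent."""
    matches = re.findall(r"\\boxed\{([^}]+)\}", text)
    if not matches:
        return None
    return matches[-1].strip().replace(",", "")


def fallback_is_correct(model_response: str, gold_answer: str) -> bool:
    """String-matching check, assuming the answer follows any </think> block."""
    if "</think>" in model_response:
        model_response = model_response.split("</think>")[-1]
    boxed = extract_last_boxed(model_response)
    return boxed is not None and boxed == gold_answer.strip().replace(",", "")


if __name__ == "__main__":
    gold = extract_gold_answer("48 / 2 = 24 clips in May.\n#### 1,234")
    print(gold, fallback_is_correct("The total is \\boxed{1,234}", gold))
```

Exact string matching treats `72` and `72.0` as different answers, which is one reason the environment prefers math_verify when it is available.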
1005
environments/hermes_base_env.py
Normal file
File diff suppressed because it is too large
229
environments/hermes_swe_env.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
HermesSweEnv -- SWE-Bench Style Environment with Modal Sandboxes
|
||||
|
||||
A concrete environment for software engineering tasks where the model writes code
|
||||
and the reward function runs tests to verify correctness. Uses Modal terminal
|
||||
backend for cloud-isolated sandboxes per rollout.
|
||||
|
||||
The reward function uses ToolContext.terminal() to run test commands in the same
|
||||
Modal sandbox the model used during its agentic loop. All filesystem state from
|
||||
the model's tool calls is preserved for verification.
|
||||
|
||||
Usage:
|
||||
# Phase 1: OpenAI server type
|
||||
vllm serve YourModel --enable-auto-tool-choice --tool-call-parser hermes
|
||||
run-api
|
||||
python environments/hermes_swe_env.py serve \\
|
||||
--openai.base_url http://localhost:8000/v1 \\
|
||||
--openai.model_name YourModel \\
|
||||
--openai.server_type openai \\
|
||||
--env.dataset_name bigcode/humanevalpack \\
|
||||
--env.terminal_backend modal
|
||||
|
||||
# Phase 2: VLLM server type (full RL training)
|
||||
python environments/hermes_swe_env.py serve \\
|
||||
--openai.base_url http://localhost:8000/v1 \\
|
||||
--openai.model_name YourModel \\
|
||||
--openai.server_type vllm \\
|
||||
--env.tool_call_parser hermes \\
|
||||
--env.terminal_backend modal
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path for imports
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HermesSweEnvConfig(HermesAgentEnvConfig):
|
||||
"""Config with defaults for SWE-bench style tasks."""
|
||||
|
||||
pass # Inherits all fields, overrides defaults in config_init
|
||||
|
||||
|
||||
class HermesSweEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
SWE-bench style environment using Modal terminal backend.
|
||||
|
||||
The model gets a coding task, uses terminal + file + web tools to solve it,
|
||||
and the reward function runs tests in the same Modal sandbox to verify.
|
||||
|
||||
Subclass this for specific SWE datasets (HumanEval, SWE-bench, etc.)
|
||||
and customize format_prompt() and compute_reward() as needed.
|
||||
"""
|
||||
|
||||
name = "hermes-swe"
|
||||
env_config_cls = HermesSweEnvConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[HermesSweEnvConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Default configuration for the SWE environment.
|
||||
|
||||
Uses Modal terminal backend for cloud isolation and terminal + file + web toolsets.
|
||||
"""
|
||||
env_config = HermesSweEnvConfig(
|
||||
# Toolsets: terminal for running code, file for reading/writing, web for docs
|
||||
enabled_toolsets=["terminal", "file", "web"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
# Agent settings -- SWE tasks need more turns
|
||||
max_agent_turns=30,
|
||||
max_token_length=4096,
|
||||
agent_temperature=1.0,
|
||||
system_prompt=(
|
||||
"You are a skilled software engineer. You have access to a terminal, "
|
||||
"file tools, and web search. Use these tools to complete the coding task. "
|
||||
"Write clean, working code and verify it runs correctly before finishing."
|
||||
),
|
||||
# Modal backend for cloud-isolated sandboxes
|
||||
terminal_backend="modal",
|
||||
# Dataset -- override via CLI for your specific SWE dataset
|
||||
dataset_name="bigcode/humanevalpack",
|
||||
dataset_split="test",
|
||||
prompt_field="prompt",
|
||||
# Atropos settings
|
||||
group_size=4,
|
||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
tool_call_parser="hermes",
|
||||
steps_per_eval=50,
|
||||
total_steps=500,
|
||||
use_wandb=True,
|
||||
wandb_name="hermes-swe",
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="http://localhost:8000/v1",
|
||||
model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
||||
server_type="openai", # Phase 1; switch to "vllm" for Phase 2
|
||||
api_key="",
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self):
|
||||
"""Load the SWE dataset."""
|
||||
if self.config.dataset_name:
|
||||
self.dataset = load_dataset(
|
||||
self.config.dataset_name, split=self.config.dataset_split
|
||||
)
|
||||
else:
|
||||
# Placeholder if no dataset specified
|
||||
self.dataset = []
|
||||
self.iter = 0
|
||||
self.reward_buffer: List[float] = []
|
||||
|
||||
async def get_next_item(self) -> Dict[str, Any]:
|
||||
"""Cycle through the SWE dataset."""
|
||||
if not self.dataset:
|
||||
raise ValueError("No dataset loaded. Set dataset_name in config.")
|
||||
item = self.dataset[self.iter % len(self.dataset)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def format_prompt(self, item: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Format the SWE task prompt.
|
||||
|
||||
Override this in subclasses for different dataset formats.
|
||||
Default assumes the dataset has a 'prompt' field and optionally a 'test' field.
|
||||
"""
|
||||
prompt = item.get(self.config.prompt_field, "")
|
||||
|
||||
# If the dataset has test information, include it in the prompt
|
||||
test_info = item.get("test", item.get("test_code", item.get("tests", "")))
|
||||
if test_info:
|
||||
prompt += f"\n\nTests to pass:\n{test_info}"
|
||||
|
||||
return prompt
|
||||
|
||||
async def compute_reward(
|
||||
self, item: Dict[str, Any], result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Score by running tests in the model's Modal sandbox.
|
||||
|
||||
Default implementation:
|
||||
- If the dataset item has a 'test' or 'test_code' field, run it
|
||||
- Check exit code: 0 = pass, non-zero = fail
|
||||
- Partial credit for file creation
|
||||
|
||||
Override this in subclasses for more sophisticated reward logic.
|
||||
"""
|
||||
# Find the test command from the dataset item
|
||||
test_code = item.get("test", item.get("test_code", item.get("tests", "")))
|
||||
|
||||
if test_code:
|
||||
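# Run the test in the model's sandbox. shlex.quote keeps quotes or other
# shell metacharacters inside test_code from breaking the command string.
import shlex

test_result = ctx.terminal(
    f"cd /workspace && python3 -c {shlex.quote(test_code)}", timeout=60
)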
|
||||
|
||||
if test_result["exit_code"] == 0:
|
||||
self.reward_buffer.append(1.0)
|
||||
return 1.0
|
||||
|
||||
# Partial credit: check if the model created any Python files
|
||||
file_check = ctx.terminal("find /workspace -name '*.py' -newer /tmp/.start_marker 2>/dev/null | head -5")
|
||||
if file_check["exit_code"] == 0 and file_check.get("output", "").strip():
|
||||
self.reward_buffer.append(0.1)
|
||||
return 0.1
|
||||
|
||||
self.reward_buffer.append(0.0)
|
||||
return 0.0
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""
|
||||
Run evaluation on a held-out set.
|
||||
|
||||
Override for dataset-specific evaluation logic.
|
||||
"""
|
||||
start_time = time.time()
|
||||
end_time = time.time()
|
||||
|
||||
eval_metrics = {"eval/placeholder": 0.0}
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log SWE-specific metrics."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self.reward_buffer:
|
||||
wandb_metrics["train/avg_reward"] = sum(self.reward_buffer) / len(
|
||||
self.reward_buffer
|
||||
)
|
||||
wandb_metrics["train/pass_rate"] = sum(
|
||||
1 for r in self.reward_buffer if r == 1.0
|
||||
) / len(self.reward_buffer)
|
||||
self.reward_buffer = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
HermesSweEnv.cli()
|
||||
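`compute_reward` in hermes_swe_env.py passes dataset test code to `python3 -c` through a shell, so quoting matters: test code that itself contains double quotes would end a double-quoted shell string early. The sketch below shows one way to build such a command with `shlex.quote` from the standard library. The helper name `build_test_command` and its `workdir` parameter are hypothetical and not part of the repository.

```python
import shlex


def build_test_command(test_code: str, workdir: str = "/workspace") -> str:
    """Build a shell command that runs test_code with python3 -c.

    shlex.quote wraps each argument so the shell passes it through as a
    single word, even when it contains quotes, spaces, or newlines.
    """
    return f"cd {shlex.quote(workdir)} && python3 -c {shlex.quote(test_code)}"


if __name__ == "__main__":
    code = 'assert "a" + \'b\' == "ab"\nprint("ok")'
    command = build_test_command(code)
    print(command)
    # Round trip: a POSIX shell tokenizer recovers the code unchanged.
    print(shlex.split(command)[-1] == code)
```

For long test files, writing the code to a file in the sandbox and running that file avoids command-length limits. That is a design alternative, not something the diff does.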
309
environments/patches.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
Monkey patches for making hermes-agent tools work inside async frameworks (Atropos).
|
||||
|
||||
Problem:
|
||||
Some tools use asyncio.run() internally (e.g., mini-swe-agent's Modal backend,
|
||||
web_extract). This crashes when called from inside Atropos's event loop because
|
||||
asyncio.run() can't be nested.
|
||||
|
||||
Solution:
|
||||
Replace the problematic methods with versions that use a dedicated background
|
||||
thread with its own event loop. The calling code sees the same sync interface --
|
||||
call a function, get a result -- but internally the async work happens on a
|
||||
separate thread that doesn't conflict with Atropos's loop.
|
||||
|
||||
These patches are safe for normal CLI use too: when there's no running event
|
||||
loop, the behavior is identical (the background thread approach works regardless).
|
||||
|
||||
What gets patched:
|
||||
- SwerexModalEnvironment.__init__ -- creates Modal deployment on a background thread
|
||||
- SwerexModalEnvironment.execute -- runs commands on the same background thread
|
||||
- SwerexModalEnvironment.stop -- stops deployment on the background thread
|
||||
|
||||
Usage:
|
||||
Call apply_patches() once at import time (done automatically by hermes_base_env.py).
|
||||
This is idempotent -- calling it multiple times is safe.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_patches_applied = False
|
||||
|
||||
|
||||
class _AsyncWorker:
|
||||
"""
|
||||
A dedicated background thread with its own event loop.
|
||||
|
||||
Allows sync code to submit async coroutines and block for results,
|
||||
even when called from inside another running event loop. Used to
|
||||
bridge sync tool interfaces with async backends (Modal, SWE-ReX).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._thread: threading.Thread | None = None
|
||||
self._started = threading.Event()
|
||||
|
||||
def start(self):
|
||||
"""Start the background event loop thread."""
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
if not self._started.wait(timeout=30):
    raise RuntimeError("AsyncWorker event loop did not start within 30 seconds")
|
||||
|
||||
def _run_loop(self):
|
||||
"""Background thread entry point -- runs the event loop forever."""
|
||||
self._loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._loop)
|
||||
self._started.set()
|
||||
self._loop.run_forever()
|
||||
|
||||
def run_coroutine(self, coro, timeout=600):
|
||||
"""
|
||||
Submit a coroutine to the background loop and block until it completes.
|
||||
|
||||
Safe to call from any thread, including threads that already have
|
||||
a running event loop.
|
||||
"""
|
||||
if self._loop is None or self._loop.is_closed():
|
||||
raise RuntimeError("AsyncWorker loop is not running")
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
|
||||
return future.result(timeout=timeout)
|
||||
|
||||
def stop(self):
|
||||
"""Stop the background event loop and join the thread."""
|
||||
if self._loop and self._loop.is_running():
|
||||
self._loop.call_soon_threadsafe(self._loop.stop)
|
||||
if self._thread:
|
||||
self._thread.join(timeout=10)
|
||||
|
||||
|
||||
def _patch_swerex_modal():
|
||||
"""
|
||||
Monkey patch SwerexModalEnvironment to use a background thread event loop
|
||||
instead of asyncio.run(). This makes it safe to call from inside Atropos's
|
||||
async event loop.
|
||||
|
||||
The patched methods have the exact same interface and behavior -- the only
|
||||
difference is HOW the async work is executed internally.
|
||||
"""
|
||||
try:
|
||||
from minisweagent.environments.extra.swerex_modal import (
|
||||
SwerexModalEnvironment,
|
||||
SwerexModalEnvironmentConfig,
|
||||
)
|
||||
from swerex.deployment.modal import ModalDeployment
|
||||
from swerex.runtime.abstract import Command as RexCommand
|
||||
except ImportError:
|
||||
# mini-swe-agent or swe-rex not installed -- nothing to patch
|
||||
logger.debug("mini-swe-agent Modal backend not available, skipping patch")
|
||||
return
|
||||
|
||||
# Save original methods so we can refer to config handling
|
||||
_original_init = SwerexModalEnvironment.__init__
|
||||
|
||||
def _patched_init(self, **kwargs):
|
||||
"""Patched __init__: creates Modal deployment on a background thread."""
|
||||
self.config = SwerexModalEnvironmentConfig(**kwargs)
|
||||
|
||||
# Start a dedicated event loop thread for all Modal async operations
|
||||
self._worker = _AsyncWorker()
|
||||
self._worker.start()
|
||||
|
||||
# Create AND start the deployment entirely on the worker's loop/thread
|
||||
# so all gRPC channels and async state are bound to that loop
|
||||
async def _create_and_start():
|
||||
deployment = ModalDeployment(
|
||||
image=self.config.image,
|
||||
startup_timeout=self.config.startup_timeout,
|
||||
runtime_timeout=self.config.runtime_timeout,
|
||||
deployment_timeout=self.config.deployment_timeout,
|
||||
install_pipx=self.config.install_pipx,
|
||||
modal_sandbox_kwargs=self.config.modal_sandbox_kwargs,
|
||||
)
|
||||
await deployment.start()
|
||||
return deployment
|
||||
|
||||
self.deployment = self._worker.run_coroutine(_create_and_start())
|
||||
|
||||
def _patched_execute(self, command: str, cwd: str = "", *, timeout: int | None = None) -> dict[str, Any]:
|
||||
"""Patched execute: runs commands on the background thread's loop."""
|
||||
async def _do_execute():
|
||||
return await self.deployment.runtime.execute(
|
||||
RexCommand(
|
||||
command=command,
|
||||
shell=True,
|
||||
check=False,
|
||||
cwd=cwd or self.config.cwd,
|
||||
timeout=timeout or self.config.timeout,
|
||||
merge_output_streams=True,
|
||||
env=self.config.env if self.config.env else None,
|
||||
)
|
||||
)
|
||||
|
||||
output = self._worker.run_coroutine(_do_execute())
|
||||
return {
|
||||
"output": output.stdout,
|
||||
"returncode": output.exit_code,
|
||||
}
|
||||
|
||||
def _patched_stop(self):
|
||||
"""Patched stop: stops deployment on the background thread, then stops the thread."""
|
||||
try:
|
||||
self._worker.run_coroutine(
|
||||
asyncio.wait_for(self.deployment.stop(), timeout=10),
|
||||
timeout=15,
|
||||
)
|
||||
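except Exception as e:
    # Best-effort shutdown: record the failure instead of silently dropping it
    logger.debug("Error while stopping Modal deployment: %s", e)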
|
||||
finally:
|
||||
self._worker.stop()
|
||||
|
||||
# Apply the patches
|
||||
SwerexModalEnvironment.__init__ = _patched_init
|
||||
SwerexModalEnvironment.execute = _patched_execute
|
||||
SwerexModalEnvironment.stop = _patched_stop
|
||||
|
||||
logger.debug("Patched SwerexModalEnvironment for async-safe operation")
|
||||
|
||||
|
||||
def _patch_vllm_server_for_sglang():
|
||||
"""
|
||||
(Mainly for RunPod serverless compatibility.)
|
||||
|
||||
Monkey patch VLLMServer._tokens_and_logprobs_completion_wrapper to handle
|
||||
SGLang's /generate response format.
|
||||
|
||||
VLLMServer expects:
|
||||
Request: {"prompt": {"prompt_token_ids": [...]}, "logprobs": 0}
|
||||
Response: {"logprobs": [[{token_id: logprob}]], "finish_reasons": [...]}
|
||||
|
||||
SGLang uses:
|
||||
Request: {"input_ids": [...], "sampling_params": {...}, "return_logprob": true}
|
||||
Response: {"text": "...", "meta_info": {"output_token_logprobs": [[logprob, token_id, text], ...]}}
|
||||
|
||||
This patch makes VLLMServer work with SGLang endpoints (e.g., RunPod SGLang workers).
|
||||
"""
|
||||
try:
|
||||
import aiohttp
|
||||
from atroposlib.envs.server_handling.vllm_server import VLLMServer
|
||||
except ImportError:
|
||||
logger.debug("atroposlib VLLMServer not available, skipping SGLang patch")
|
||||
return
|
||||
|
||||
# Save the original method
|
||||
_original_wrapper = VLLMServer._tokens_and_logprobs_completion_wrapper
|
||||
|
||||
async def _sglang_compatible_wrapper(self, **kwargs):
|
||||
"""
|
||||
Patched wrapper that sends the request in SGLang's /generate format
and converts the response into the tuple that VLLMServer callers expect.
|
||||
"""
|
||||
assert kwargs.get("model") is not None, "Model is required!"
|
||||
assert kwargs.get("prompt") is not None or kwargs.get("input_ids") is not None, "Prompt or input_ids required!"
|
||||
|
||||
# Get prompt tokens
|
||||
if "input_ids" in kwargs:
|
||||
prompt_tokens = kwargs.pop("input_ids")
|
||||
kwargs.pop("prompt", None)
|
||||
else:
|
||||
prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))
|
||||
|
||||
# Check for double BOS
|
||||
if (len(prompt_tokens) >= 2
|
||||
and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]):
|
||||
prompt_tokens = prompt_tokens[1:]
|
||||
|
||||
# Normalize kwargs
|
||||
max_tokens = kwargs.pop("max_new_tokens", kwargs.pop("max_completion_tokens", kwargs.pop("max_tokens", 2048)))
|
||||
n = kwargs.pop("n", 1)
|
||||
temperature = kwargs.pop("temperature", 1.0)
|
||||
kwargs.pop("model", None)
|
||||
|
||||
# Build SGLang-compatible request
|
||||
request_data = {
|
||||
"input_ids": prompt_tokens,
|
||||
"sampling_params": {
|
||||
"max_new_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"n": n,
|
||||
},
|
||||
"return_logprob": True,
|
||||
"top_logprobs_num": 0,
|
||||
}
|
||||
|
||||
# Strip only a trailing /v1 so a "/v1" elsewhere in the URL is left intact
generate_url = f"{self.config.base_url.rstrip('/').removesuffix('/v1')}/generate"
|
||||
|
||||
headers = {}
|
||||
if self.config.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.config.api_key}"
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
generate_url,
|
||||
json=request_data,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=self.config.timeout),
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
raw_text = await response.text()
|
||||
|
||||
# RunPod wraps JSON responses in quotes -- may need double-parse
|
||||
import json
|
||||
results = json.loads(raw_text)
|
||||
if isinstance(results, str):
|
||||
results = json.loads(results)
|
||||
|
||||
# Parse SGLang response format
|
||||
meta = results.get("meta_info", {})
|
||||
output_token_logprobs_raw = meta.get("output_token_logprobs", [])
|
||||
|
||||
# SGLang format: [[logprob, token_id, token_text], ...]
|
||||
output_tokens = []
|
||||
output_logprobs = []
|
||||
for entry in output_token_logprobs_raw:
|
||||
if isinstance(entry, (list, tuple)) and len(entry) >= 2:
|
||||
logprob, token_id = entry[0], entry[1]
|
||||
output_tokens.append(int(token_id))
|
||||
output_logprobs.append(float(logprob))
|
||||
|
||||
# Get finish reason
|
||||
finish_reason_raw = meta.get("finish_reason", "stop")
|
||||
if isinstance(finish_reason_raw, dict):
|
||||
finish_reason = finish_reason_raw.get("type", "stop")
|
||||
else:
|
||||
finish_reason = str(finish_reason_raw)
|
||||
|
||||
return (
|
||||
prompt_tokens,
|
||||
[output_tokens],
|
||||
[output_logprobs],
|
||||
[finish_reason],
|
||||
)
|
||||
|
||||
# Apply the patch
|
||||
VLLMServer._tokens_and_logprobs_completion_wrapper = _sglang_compatible_wrapper
|
||||
logger.info("Patched VLLMServer for SGLang /generate compatibility")
|
||||
|
||||
|
||||
def apply_patches():
|
||||
"""
|
||||
Apply all monkey patches needed for Atropos compatibility.
|
||||
|
||||
Safe to call multiple times -- patches are only applied once.
|
||||
Safe for normal CLI use -- patched code works identically when
|
||||
there is no running event loop.
|
||||
"""
|
||||
global _patches_applied
|
||||
if _patches_applied:
|
||||
return
|
||||
|
||||
_patch_swerex_modal()
|
||||
# _patch_vllm_server_for_sglang()  # currently disabled; only needed for SGLang /generate endpoints
|
||||
|
||||
_patches_applied = True
|
||||
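The module docstring of patches.py describes bridging sync tool code to async backends through a dedicated background thread with its own event loop. The sketch below illustrates that pattern in isolation. It is a simplified stand-in for `_AsyncWorker`, not a copy of it: `BackgroundLoop`, `add_later`, and `sync_tool` are hypothetical names, and the coroutine is a toy replacement for a Modal call.

```python
import asyncio
import threading


class BackgroundLoop:
    """Runs an event loop on a daemon thread and accepts coroutines from sync code."""

    def __init__(self):
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._loop.run_forever, daemon=True)
        self._thread.start()

    def run(self, coro, timeout=10):
        # Schedule on the background loop, then block this thread for the result.
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        return future.result(timeout=timeout)

    def stop(self):
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join(timeout=5)
        self._loop.close()


async def add_later(a, b):
    """Toy async backend call."""
    await asyncio.sleep(0.01)
    return a + b


def sync_tool(worker):
    # A sync interface over async work. Calling asyncio.run() here would raise
    # RuntimeError when invoked from inside an already running event loop.
    return worker.run(add_later(2, 3))


async def main():
    worker = BackgroundLoop()
    try:
        # Called from inside a running loop, as a tool would be under Atropos.
        return sync_tool(worker)
    finally:
        worker.stop()


if __name__ == "__main__":
    print(asyncio.run(main()))
```

One caveat of this design is that `future.result()` blocks the calling thread, so a call made directly from a coroutine stalls the outer event loop until the background work finishes. Offloading the sync call with `asyncio.to_thread` is one way to avoid that, if the caller can be made async.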
620
environments/swe_smith_oracle_env.py
Normal file
@@ -0,0 +1,620 @@
|
||||
"""
|
||||
SWE-smith-oracle environment (ported to HermesAgentBaseEnv).
|
||||
|
||||
Trains models to fix real GitHub repositories:
|
||||
- Clones a public GitHub repo at a specific commit
|
||||
- Runs an agent loop with terminal tool to apply a fix
|
||||
- Verifies by running pytest with nodeids from the dataset
|
||||
- Reward: 1.0 if all tests pass; otherwise 0.0 in the sandbox path, or a shaped partial reward in the local path
|
||||
|
||||
Dataset: NousResearch/SWE-smith-oracle (train split; does NOT use SWE-bench eval set).
|
||||
|
||||
Usage:
|
||||
# Process mode (OpenAI server, no training):
|
||||
python environments/swe_smith_oracle_env.py process \\
|
||||
--env.data_path_to_save_groups data/swe_oracle_output.jsonl
|
||||
|
||||
# With Modal sandbox backend:
|
||||
python environments/swe_smith_oracle_env.py process \\
|
||||
--env.tool_pool_mode modal \\
|
||||
--env.modal_image python:3.11
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Config
|
||||
# =============================================================================
|
||||
|
||||
class SweSmithOracleEnvConfig(HermesAgentEnvConfig):
|
||||
"""Config for SWE-smith-oracle environment."""
|
||||
|
||||
dataset_name: str = Field(default="NousResearch/SWE-smith-oracle")
|
||||
dataset_split: str = Field(default="train")
|
||||
max_items: int = Field(default=0, description="0 = no limit")
|
||||
shuffle: bool = Field(default=True)
|
||||
seed: int = Field(default=0)
|
||||
|
||||
python_only: bool = Field(default=True, description="Filter to Python-evaluable rows")
|
||||
score_include_fail_to_pass: bool = Field(
|
||||
default=True,
|
||||
description="Score tests on PASS_TO_PASS ∪ FAIL_TO_PASS. "
|
||||
"Disable to only run PASS_TO_PASS (faster but weaker signal).",
|
||||
)
|
||||
|
||||
prompt_mode: str = Field(
|
||||
default="problem_statement",
|
||||
description="'problem_statement' (fast) or 'problem_statement+text' (includes dataset 'text').",
|
||||
)
|
||||
|
||||
repo_base_url: str = Field(default="https://github.com", description="Base URL for repo cloning")
|
||||
install_timeout_s: float = Field(default=600.0)
|
||||
test_timeout_s: float = Field(default=600.0)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Environment
|
||||
# =============================================================================
|
||||
|
||||
class SweSmithOracleEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
SWE-smith-oracle environment for training models to fix real GitHub repos.
|
||||
|
||||
Uses proper OpenAI-spec tool calling via HermesAgentBaseEnv.
|
||||
The model gets terminal access to inspect, edit, and test the repository.
|
||||
"""
|
||||
|
||||
name = "swe-smith-oracle"
|
||||
env_config_cls = SweSmithOracleEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SweSmithOracleEnvConfig,
|
||||
server_configs,
|
||||
slurm=False,
|
||||
testing=False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._dataset = None
|
||||
self._indices: List[int] = []
|
||||
self._cursor = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[SweSmithOracleEnvConfig, List[APIServerConfig]]:
|
||||
"""Default config — reads from ATROPOS_SERVER_* env vars."""
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
if not base_url.rstrip("/").endswith("/v1"):
|
||||
base_url = base_url.rstrip("/") + "/v1"
|
||||
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "Hermes-4.3-36B"
|
||||
api_key = (
|
||||
os.getenv("ATROPOS_SERVER_API_KEY")
|
||||
or os.getenv("NOUS_API_KEY")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or "local"
|
||||
)
|
||||
|
||||
env_config = SweSmithOracleEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1,
|
||||
batch_size=1,
|
||||
steps_per_eval=1,
|
||||
max_token_length=8192,
|
||||
wandb_name="swe_smith_oracle",
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
terminal_backend=os.getenv("TERMINAL_ENV", "local"),
|
||||
# Longer agent turns for SWE tasks
|
||||
max_agent_turns=50,
|
||||
agent_temperature=0.7,
|
||||
system_prompt=(
|
||||
"You are a senior software engineer. You have access to a terminal "
|
||||
"to inspect and fix repositories. Use non-interactive commands only. "
|
||||
"Each terminal command runs in a fresh shell."
|
||||
),
|
||||
tool_call_parser="hermes",
|
||||
# Sandbox settings (used when tool_pool_mode != "default")
|
||||
sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
|
||||
purge_job_on_start=True,
|
||||
purge_job_on_shutdown=True,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
server_type="vllm",
|
||||
health_check=False,
|
||||
timeout=int(os.getenv("ATROPOS_SERVER_TIMEOUT_S") or "300"),
|
||||
),
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
# =========================================================================
|
||||
# Dataset loading
|
||||
# =========================================================================
|
||||
|
||||
async def setup(self):
|
||||
"""Load SWE-smith-oracle dataset."""
|
||||
from datasets import load_dataset
|
||||
|
||||
t0 = time.perf_counter()
|
||||
print(
|
||||
f"[SweSmithOracleEnv] loading dataset {self.config.dataset_name}:{self.config.dataset_split} "
|
||||
f"(python_only={self.config.python_only}, max_items={self.config.max_items or 'all'})",
|
||||
flush=True,
|
||||
)
|
||||
ds = load_dataset(self.config.dataset_name, split=self.config.dataset_split)
|
||||
self._dataset = ds
|
||||
|
||||
indices: List[int] = []
|
||||
for idx in range(len(ds)):
|
||||
row = ds[idx]
|
||||
if self.config.python_only and not self._is_python_row(row):
|
||||
continue
|
||||
indices.append(idx)
|
||||
|
||||
if self.config.shuffle:
|
||||
rnd = random.Random(self.config.seed)
|
||||
rnd.shuffle(indices)
|
||||
|
||||
if self.config.max_items and self.config.max_items > 0:
|
||||
indices = indices[: self.config.max_items]
|
||||
|
||||
self._indices = indices
|
||||
self._cursor = 0
|
||||
|
||||
print(
|
||||
f"[SweSmithOracleEnv] loaded {len(self._indices)} items "
|
||||
f"in {time.perf_counter() - t0:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
def _is_python_row(self, row: Dict[str, Any]) -> bool:
|
||||
nodeids = row.get("PASS_TO_PASS")
|
||||
if not isinstance(nodeids, list) or not nodeids:
|
||||
return False
|
||||
return all(isinstance(nid, str) and ".py::" in nid for nid in nodeids)
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
if not self._dataset or not self._indices:
|
||||
raise RuntimeError("Dataset not initialized")
|
||||
if self._cursor >= len(self._indices):
|
||||
self._cursor = 0
|
||||
idx = self._indices[self._cursor]
|
||||
self._cursor += 1
|
||||
return dict(self._dataset[idx])
|
||||
|
||||
# =========================================================================
|
||||
# Prompt formatting
|
||||
# =========================================================================
|
||||
|
||||
def _repo_name(self, item: Item) -> str:
|
||||
repo = item.get("repo") or ""
|
||||
if isinstance(repo, str) and "/" in repo:
|
||||
return repo.split("/")[-1]
|
||||
return "repo"
|
||||
|
||||
def format_prompt(self, item: Item) -> str:
|
||||
"""Build the SWE task prompt."""
|
||||
repo = item.get("repo") or ""
|
||||
base_commit = item.get("base_commit") or ""
|
||||
problem = str(item.get("problem_statement") or "")
|
||||
context = str(item.get("text") or "")
|
||||
repo_dir = self._repo_name(item)
|
||||
|
||||
nodeids = self._tests_for_item(item)
|
||||
tests_list = "\n".join(f"- {t}" for t in nodeids)
|
||||
|
||||
context_block = ""
|
||||
prompt_mode = (self.config.prompt_mode or "problem_statement").strip().lower()
|
||||
if prompt_mode == "problem_statement+text" and context:
|
||||
context_block = f"\nAdditional context:\n{context}\n"
|
||||
|
||||
return (
|
||||
f"Fix the repository so the specified tests pass.\n\n"
|
||||
f"Repository: {repo} (checked out at base_commit={base_commit})\n"
|
||||
f"Workspace path: ./{repo_dir}\n\n"
|
||||
"Constraints:\n"
|
||||
"- Use the terminal tool to inspect, edit, and verify the repository.\n"
|
||||
f"- Start by inspecting the repo (e.g. `ls`, `cd ./{repo_dir}`, `git status`).\n"
|
||||
"- Use a workspace-local virtualenv (.venv) to avoid cross-run contamination.\n"
|
||||
"- Use non-interactive commands only.\n"
|
||||
"- Prefer `. .venv/bin/activate` or `.venv/bin/python ...` (POSIX compatible).\n\n"
|
||||
f"Problem statement:\n{problem}\n\n"
|
||||
f"{context_block}"
|
||||
f"Run these tests to verify:\n{tests_list}\n\n"
|
||||
"When done, briefly describe what you changed and confirm tests pass."
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Test helpers
|
||||
# =========================================================================
|
||||
|
||||
def _tests_for_item(self, item: Item) -> List[str]:
|
||||
tests: List[str] = []
|
||||
if self.config.score_include_fail_to_pass:
|
||||
for key in ("PASS_TO_PASS", "FAIL_TO_PASS"):
|
||||
nodeids = item.get(key)
|
||||
if isinstance(nodeids, list):
|
||||
tests.extend([n for n in nodeids if isinstance(n, str)])
|
||||
else:
|
||||
nodeids = item.get("PASS_TO_PASS")
|
||||
if isinstance(nodeids, list):
|
||||
tests.extend([n for n in nodeids if isinstance(n, str)])
|
||||
return sorted(dict.fromkeys(tests))
|
||||
|
||||
def _chunk_nodeids(self, nodeids: List[str], max_per_chunk: int = 50) -> List[List[str]]:
|
||||
return [nodeids[i : i + max_per_chunk] for i in range(0, len(nodeids), max_per_chunk)]
|
||||
|
||||
# =========================================================================
|
||||
# Sandbox hooks: setup_trajectory_workspace + verify_and_score_trajectory
|
||||
# =========================================================================
|
||||
|
||||
async def setup_trajectory_workspace(
|
||||
self, item: Item, *, trajectory_id: str, exec_tool
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Prepare a sandbox workspace: bare repo cache + git worktree.
|
||||
|
||||
Uses a flock-serialized bare repo cache under /data/repo_cache so
|
||||
multiple trajectories sharing a sandbox don't clone the same repo
|
||||
in parallel. Each trajectory gets an isolated worktree at the
|
||||
specified base_commit.
|
||||
|
||||
Args:
|
||||
item: Dataset row with repo, base_commit, etc.
|
||||
trajectory_id: Unique trajectory ID
|
||||
exec_tool: async callable(tool_name, args, timeout) -> ExecutionResult
|
||||
|
||||
Returns:
|
||||
Dict with repo_dir, base_commit metadata
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
t0 = _time.perf_counter()
|
||||
repo = item.get("repo")
|
||||
base_commit = item.get("base_commit")
|
||||
instance_id = item.get("instance_id") or item.get("id") or item.get("problem_id")
|
||||
if not isinstance(repo, str) or not isinstance(base_commit, str):
|
||||
raise RuntimeError("Invalid dataset row: missing repo/base_commit")
|
||||
|
||||
repo_dir = self._repo_name(item)
|
||||
clone_url = f"{self.config.repo_base_url.rstrip('/')}/{repo}.git"
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): "
|
||||
f"repo={repo} base_commit={base_commit} instance_id={instance_id} dir=./{repo_dir}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Bare repo cache + worktree strategy (same as atropos/envs/swe_smith_oracle_env.py)
|
||||
repo_slug = repo.replace("/", "__")
|
||||
cache_root = "/data/repo_cache"
|
||||
bare_repo = f"{cache_root}/{repo_slug}.git"
|
||||
lock_file = f"{cache_root}/.locks/{repo_slug}.lock"
|
||||
|
||||
worktree_cmd = (
|
||||
"set -e; "
|
||||
f"rm -rf {repo_dir}; "
|
||||
f"mkdir -p {cache_root}/.locks; "
|
||||
f": > {lock_file}; "
|
||||
f"flock -x {lock_file} sh -lc '"
|
||||
f"set -e; "
|
||||
"export GIT_TERMINAL_PROMPT=0; "
|
||||
"export GIT_LFS_SKIP_SMUDGE=1; "
|
||||
f"if [ ! -d \"{bare_repo}\" ]; then "
|
||||
f" git init --bare \"{bare_repo}\"; "
|
||||
f" git -C \"{bare_repo}\" remote add origin \"{clone_url}\"; "
|
||||
"fi; "
|
||||
f"git -C \"{bare_repo}\" remote set-url origin \"{clone_url}\"; "
|
||||
f"git -C \"{bare_repo}\" worktree prune || true; "
|
||||
f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
|
||||
f" git -C \"{bare_repo}\" fetch --depth 1 origin \"{base_commit}\" || true; "
|
||||
"fi; "
|
||||
f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
|
||||
f" git -C \"{bare_repo}\" fetch --prune origin; "
|
||||
"fi; "
|
||||
f"git --git-dir=\"{bare_repo}\" worktree add --detach \"{repo_dir}\" \"{base_commit}\"; "
|
||||
"'"
|
||||
)
|
||||
|
||||
print(f"[SweSmithOracleEnv] tid={trajectory_id} preparing worktree from repo cache", flush=True)
|
||||
res = await exec_tool(
|
||||
"bash",
|
||||
{"command": worktree_cmd},
|
||||
timeout=self.config.install_timeout_s,
|
||||
)
|
||||
if not res.success:
|
||||
raise RuntimeError(
|
||||
f"git worktree setup failed "
|
||||
f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): "
|
||||
f"{res.error}\n{res.output}"
|
||||
)
|
||||
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): "
|
||||
f"worktree ready in {_time.perf_counter() - t0:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
return {"repo_dir": repo_dir, "base_commit": base_commit}
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
result: AgentResult,
|
||||
*,
|
||||
trajectory_id: str,
|
||||
exec_tool,
|
||||
workspace_meta: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[float, Dict[str, Any]]:
|
||||
"""
|
||||
In-sandbox verification: install deps + run pytest with dataset nodeids.
|
||||
|
||||
Args:
|
||||
item: Dataset row
|
||||
result: Agent's rollout result
|
||||
trajectory_id: Unique trajectory ID
|
||||
exec_tool: async callable(tool_name, args, timeout) -> ExecutionResult
|
||||
workspace_meta: From setup_trajectory_workspace (has repo_dir)
|
||||
|
||||
Returns:
|
||||
(reward, metadata) tuple
|
||||
"""
|
||||
repo_dir = (workspace_meta or {}).get("repo_dir") or self._repo_name(item)
|
||||
|
||||
# Don't reward trajectories that never used tools
|
||||
tool_call_count = sum(
|
||||
len(msg.get("tool_calls", []))
|
||||
for msg in result.messages
|
||||
if msg.get("role") == "assistant"
|
||||
)
|
||||
if tool_call_count == 0:
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} verify: no tool calls; score=0.0",
|
||||
flush=True,
|
||||
)
|
||||
return 0.0, {"error": "No tool calls were made by the agent"}
|
||||
|
||||
nodeids = self._tests_for_item(item)
|
||||
if not nodeids:
|
||||
return 0.0, {"error": "No tests provided"}
|
||||
|
||||
# Install dependencies
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} verify: installing deps + running tests",
|
||||
flush=True,
|
||||
)
|
||||
setup_cmd = (
|
||||
f"cd {repo_dir} && "
|
||||
"python -m venv .venv && "
|
||||
". .venv/bin/activate && "
|
||||
"python -m pip install -U pip setuptools wheel && "
|
||||
"python -m pip install -e . && "
|
||||
"python -m pip install pytest"
|
||||
)
|
||||
setup_res = await exec_tool(
|
||||
"bash", {"command": setup_cmd}, timeout=self.config.install_timeout_s
|
||||
)
|
||||
if not setup_res.success:
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} install failed; score=0.0",
|
||||
flush=True,
|
||||
)
|
||||
return 0.0, {
|
||||
"phase": "install",
|
||||
"error": setup_res.error,
|
||||
"output": setup_res.output,
|
||||
}
|
||||
|
||||
# Run test chunks
|
||||
chunks = self._chunk_nodeids(nodeids, max_per_chunk=50)
|
||||
for chunk_idx, chunk in enumerate(chunks):
|
||||
joined = " ".join(chunk)
|
||||
cmd = f"cd {repo_dir} && . .venv/bin/activate && python -m pytest -q {joined}"
|
||||
res = await exec_tool(
|
||||
"bash", {"command": cmd}, timeout=self.config.test_timeout_s
|
||||
)
|
||||
if not res.success:
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} tests failed (chunk {chunk_idx}); score=0.0",
|
||||
flush=True,
|
||||
)
|
||||
return 0.0, {
|
||||
"phase": "pytest",
|
||||
"failed_chunk": chunk_idx,
|
||||
"error": res.error,
|
||||
"output": res.output,
|
||||
}
|
||||
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} all tests passed; score=1.0",
|
||||
flush=True,
|
||||
)
|
||||
return 1.0, {"passed": True}
|
||||
|
||||
# =========================================================================
|
||||
# Reward: run pytest in the terminal (local / non-sandbox path)
|
||||
# =========================================================================
|
||||
|
||||
async def compute_reward(
|
||||
self, item: Item, result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Verify by running pytest with the dataset's nodeids.
|
||||
|
||||
Reward structure (shaped to give training signal even when model can't solve tasks):
|
||||
- 0.0: No tool calls at all
|
||||
- 0.05 per tool call, capped at 0.3: tool calls made, but install failed or no tests are defined
|
||||
- 0.4: Deps installed successfully, but at least one test chunk failed
|
||||
- 1.0: All tests pass
|
||||
|
||||
The partial rewards for tool calls help the model learn to USE tools
|
||||
before it can learn to use them CORRECTLY. This is critical for cold-start
|
||||
training where the base model barely makes any tool calls.
|
||||
"""
|
||||
repo_dir = self._repo_name(item)
|
||||
|
||||
# Count tool calls (assistant messages that have tool_calls).
|
||||
# NOTE: we keep scoring policy here intentionally simple and env-specific.
|
||||
# The agent loop exposes additional tool-call metrics (attempted/schema_valid/
|
||||
# executed_ok/exec_error) that other environments may choose to use for
|
||||
# reward shaping, but we don't hard-require any particular calling format here.
|
||||
tool_call_count = sum(
|
||||
len(msg.get("tool_calls", []))
|
||||
for msg in result.messages
|
||||
if msg.get("role") == "assistant"
|
||||
)
|
||||
|
||||
if tool_call_count == 0:
|
||||
print("[SweSmithOracleEnv] No tool calls made; score=0.0", flush=True)
|
||||
return 0.0
|
||||
|
||||
# Partial reward: 0.05 per tool call, capped at 0.3
|
||||
tool_call_reward = min(tool_call_count * 0.05, 0.3)
|
||||
|
||||
# Debug: log tool-call quality metrics if present
|
||||
attempted = getattr(result, "tool_calls_attempted", None)
|
||||
schema_valid = getattr(result, "tool_calls_schema_valid", None)
|
||||
executed_ok = getattr(result, "tool_calls_executed_ok", None)
|
||||
exec_error = getattr(result, "tool_calls_exec_error", None)
|
||||
if attempted is not None:
|
||||
print(
|
||||
f"[SweSmithOracleEnv] Tool calls: total={tool_call_count}, attempted={attempted}, schema_valid={schema_valid}, ok={executed_ok}, err={exec_error}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
nodeids = self._tests_for_item(item)
|
||||
if not nodeids:
|
||||
# No tests defined — just reward tool usage
|
||||
print(f"[SweSmithOracleEnv] No tests defined; score={tool_call_reward:.2f} (tool calls)", flush=True)
|
||||
return tool_call_reward
|
||||
|
||||
# Install deps + run tests
|
||||
print("[SweSmithOracleEnv] Verifying: installing deps + running tests", flush=True)
|
||||
setup_result = ctx.terminal(
|
||||
f"cd {repo_dir} && "
|
||||
"python -m venv .venv && "
|
||||
". .venv/bin/activate && "
|
||||
"python -m pip install -U pip setuptools wheel && "
|
||||
"python -m pip install -e . && "
|
||||
"python -m pip install pytest",
|
||||
timeout=int(self.config.install_timeout_s),
|
||||
)
|
||||
if setup_result.get("exit_code", 1) != 0:
|
||||
print(f"[SweSmithOracleEnv] Install failed; score={tool_call_reward:.2f} (tool calls only)", flush=True)
|
||||
return tool_call_reward
|
||||
|
||||
# Partial reward for successful install
|
||||
install_reward = 0.4
|
||||
|
||||
# Run test chunks
|
||||
chunks = self._chunk_nodeids(nodeids, max_per_chunk=50)
|
||||
for chunk_idx, chunk in enumerate(chunks):
|
||||
joined = " ".join(chunk)
|
||||
test_result = ctx.terminal(
|
||||
f"cd {repo_dir} && . .venv/bin/activate && python -m pytest -q {joined}",
|
||||
timeout=int(self.config.test_timeout_s),
|
||||
)
|
||||
if test_result.get("exit_code", 1) != 0:
|
||||
print(f"[SweSmithOracleEnv] Tests failed (chunk {chunk_idx}); score={install_reward:.2f} (install ok)", flush=True)
|
||||
return install_reward
|
||||
|
||||
print("[SweSmithOracleEnv] All tests passed; score=1.0", flush=True)
|
||||
return 1.0
|
||||
|
||||
# =========================================================================
|
||||
# Token truncation — keep start of trajectory, truncate from end
|
||||
# =========================================================================
|
||||
|
||||
def _build_scored_item(self, item, result, reward):
|
||||
"""
|
||||
Override to truncate tokens/masks from the END to fit within max_token_len.
|
||||
|
||||
Intuition (from NeurIPS finding): the start of the trajectory is most important
|
||||
for shifting the model distribution. Truncating from the end only costs ~2-3%
|
||||
vs handling the full sequence, but avoids the "Token length is too long" discard
|
||||
that throws away entire groups including valid training signal.
|
||||
"""
|
||||
scored_item, remaining = super()._build_scored_item(item, result, reward)
|
||||
if scored_item is None:
|
||||
return scored_item, remaining
|
||||
|
||||
# Prefer self.max_token_len, which comes from the trainer via /info.
|
||||
# It may be -1 if the trainer hasn't registered yet (race condition),
|
||||
# in which case we fall back to config.max_token_length below.
|
||||
max_len = self.max_token_len
|
||||
if max_len <= 0:
|
||||
# Fallback to config value
|
||||
max_len = getattr(self.config, 'max_token_length', 0)
|
||||
if max_len <= 0:
|
||||
return scored_item, remaining
|
||||
|
||||
# Leave some margin (64 tokens) to avoid edge cases with padding alignment
|
||||
truncate_to = max_len - 64
|
||||
|
||||
tokens = scored_item.get("tokens")
|
||||
masks = scored_item.get("masks")
|
||||
|
||||
if tokens is not None and len(tokens) >= max_len:
|
||||
orig_len = len(tokens)
|
||||
scored_item["tokens"] = tokens[:truncate_to]
|
||||
if masks is not None and len(masks) >= max_len:
|
||||
scored_item["masks"] = masks[:truncate_to]
|
||||
logger.info(
|
||||
"Truncated trajectory from %d to %d tokens (max_token_len=%d)",
|
||||
orig_len, truncate_to, max_len,
|
||||
)
|
||||
|
||||
return scored_item, remaining
|
||||
|
||||
# =========================================================================
|
||||
# Evaluation (minimal for now)
|
||||
# =========================================================================
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""Placeholder evaluation — SWE tasks are too expensive for frequent eval."""
|
||||
start_time = time.time()
|
||||
await self.evaluate_log(
|
||||
metrics={"eval/placeholder": 0.0},
|
||||
samples=[],
|
||||
start_time=start_time,
|
||||
end_time=time.time(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
SweSmithOracleEnv.cli()
|
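The reward ladder used by `compute_reward` in the file above can be restated as a small pure function. This is a sketch for illustration only: `shaped_reward` is a hypothetical helper, not part of the environment, and it assumes the item has at least one test node ID (with no tests, the environment returns the tool-call reward directly).

```python
def shaped_reward(tool_call_count: int, install_ok: bool, tests_ok: bool) -> float:
    """Mirror the local-path reward ladder for an item that has tests."""
    # No tool calls at all: nothing to reward.
    if tool_call_count == 0:
        return 0.0
    # Install failed: only the tool-call shaping term applies (0.05 each, capped at 0.3).
    if not install_ok:
        return min(tool_call_count * 0.05, 0.3)
    # Install succeeded but some test chunk failed: flat partial reward.
    if not tests_ok:
        return 0.4
    # Everything passed.
    return 1.0


# A few points on the ladder.
ladder = [
    shaped_reward(0, install_ok=True, tests_ok=True),
    shaped_reward(10, install_ok=False, tests_ok=False),
    shaped_reward(2, install_ok=True, tests_ok=False),
    shaped_reward(2, install_ok=True, tests_ok=True),
]
```

Note that the tiers are not additive: a successful install replaces the tool-call term with 0.4 rather than stacking on top of it.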
||||
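The end-truncation idea behind `_build_scored_item` can also be shown on plain lists. This is a simplified sketch, not the override itself: `truncate_from_end` is a hypothetical helper, it treats `tokens` and `masks` independently, and the clamp to zero for very small limits is an added assumption that the original code does not make.

```python
from typing import Dict, List


def truncate_from_end(
    scored: Dict[str, List[int]], max_len: int, margin: int = 64
) -> Dict[str, List[int]]:
    """Keep the start of each sequence and cut the tail to max_len - margin."""
    # A non-positive limit means "unknown"; leave the item untouched.
    if max_len <= 0:
        return scored
    # Assumption: clamp at zero so a tiny max_len never yields a negative slice bound.
    keep = max(max_len - margin, 0)
    out = dict(scored)
    for key in ("tokens", "masks"):
        seq = out.get(key)
        # Only sequences at or over the limit are cut; shorter ones pass through.
        if seq is not None and len(seq) >= max_len:
            out[key] = seq[:keep]
    return out


long_item = {"tokens": list(range(200)), "masks": [1] * 200}
short_item = {"tokens": [1, 2, 3], "masks": [1, 1, 1]}
cut = truncate_from_end(long_item, max_len=128)
kept = truncate_from_end(short_item, max_len=128)
```

The margin leaves headroom below the hard limit, so a sequence that is cut ends up strictly shorter than `max_len` rather than exactly at it.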
292
environments/terminal_test_env.py
Normal file
@@ -0,0 +1,292 @@
|
||||
"""
|
||||
TerminalTestEnv -- Simple Test Environment for Validating the Stack
|
||||
|
||||
A self-contained environment with inline tasks (no external dataset needed).
|
||||
Each task asks the model to create a file at a known path with specific content.
|
||||
The reward verifier cats the file and checks if the content matches.
|
||||
|
||||
Enables only terminal + file toolsets. Uses Modal terminal backend with
|
||||
OpenRouter (Claude) by default.
|
||||
|
||||
Training tasks (3):
|
||||
1. Create ~/greeting.txt with "Hello from Hermes Agent"
|
||||
2. Create ~/count.txt with numbers 1-5, one per line
|
||||
3. Create ~/answer.txt with the result of 123 + 456
|
||||
|
||||
Eval task (1):
|
||||
1. Create ~/result.txt with the result of 6 * 7
|
||||
|
||||
Usage:
|
||||
# Start Atropos API server
|
||||
run-api
|
||||
|
||||
# Run environment (uses OpenRouter + Modal by default)
|
||||
python environments/terminal_test_env.py serve
|
||||
|
||||
# Process mode (no run-api needed, saves to JSONL)
|
||||
python environments/terminal_test_env.py process \\
|
||||
--env.data_path_to_save_groups terminal_test_output.jsonl
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path for imports
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Inline task definitions -- no external dataset needed
|
||||
# =============================================================================
|
||||
|
||||
TRAIN_TASKS = [
|
||||
{
|
||||
"prompt": "Create a file at ~/greeting.txt containing exactly the text: Hello from Hermes Agent",
|
||||
"verify_path": "~/greeting.txt",
|
||||
"expected_content": "Hello from Hermes Agent",
|
||||
},
|
||||
{
|
||||
"prompt": "Create a file at ~/count.txt containing the numbers 1 through 5, one per line",
|
||||
"verify_path": "~/count.txt",
|
||||
"expected_content": "1\n2\n3\n4\n5",
|
||||
},
|
||||
{
|
||||
"prompt": "Create a file at ~/answer.txt containing the result of 123 + 456",
|
||||
"verify_path": "~/answer.txt",
|
||||
"expected_content": "579",
|
||||
},
|
||||
]
|
||||
|
||||
EVAL_TASKS = [
|
||||
{
|
||||
"prompt": "Create a file at ~/result.txt containing the result of 6 * 7",
|
||||
"verify_path": "~/result.txt",
|
||||
"expected_content": "42",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
class TerminalTestEnvConfig(HermesAgentEnvConfig):
|
||||
"""Config with defaults suitable for terminal testing."""
|
||||
|
||||
pass # Inherits all fields, overrides defaults in config_init
|
||||
|
||||
|
||||
class TerminalTestEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
Simple test environment with inline file-creation tasks.
|
||||
|
||||
All tasks follow the same pattern: "create a file at ~/X.txt with content Y".
|
||||
The verifier runs `cat ~/X.txt` in the rollout's terminal and checks the output
|
||||
against the expected string. Same verifier logic for all tasks.
|
||||
|
||||
This environment is designed to validate the full stack end-to-end:
|
||||
- Agent loop executes tool calls (terminal/file)
|
||||
- ToolContext provides terminal access to the reward function
|
||||
- Reward function verifies file content via cat
|
||||
- Scored data flows through the Atropos pipeline
|
||||
"""
|
||||
|
||||
name = "terminal-test"
|
||||
env_config_cls = TerminalTestEnvConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[TerminalTestEnvConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Default configuration for the terminal test environment.
|
||||
|
||||
Uses Modal terminal backend for cloud isolation and OpenRouter with
|
||||
Claude for inference. API keys loaded from ~/hermes-agent/.env.
|
||||
"""
|
||||
env_config = TerminalTestEnvConfig(
|
||||
# Terminal + file tools only
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
# Agent settings
|
||||
max_agent_turns=10, # Simple tasks, don't need many turns
|
||||
max_token_length=16000,
|
||||
agent_temperature=1.0,
|
||||
system_prompt=(
|
||||
"You are a helpful assistant with access to a terminal and file tools. "
|
||||
"Complete the user's request by using the available tools. "
|
||||
"Be precise and follow instructions exactly."
|
||||
),
|
||||
# Modal terminal backend for cloud-isolated sandboxes per rollout
|
||||
terminal_backend="modal",
|
||||
# Atropos settings
|
||||
group_size=3, # 3 rollouts per group
|
||||
tokenizer_name="NousResearch/q-30b-t-h45-e1",
|
||||
tool_call_parser="hermes",
|
||||
steps_per_eval=3, # Eval after all 3 steps
|
||||
total_steps=3, # 3 groups total (1 group per step)
|
||||
use_wandb=True,
|
||||
wandb_name="terminal-test",
|
||||
ensure_scores_are_not_same=False, # Allow all-same scores for simple tasks
|
||||
# No external dataset
|
||||
dataset_name=None,
|
||||
)
|
||||
|
||||
# OpenRouter with Claude -- API key loaded from .env (OPENROUTER_API_KEY)
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-opus-4.6",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False, # OpenRouter doesn't have a /health endpoint
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self):
|
||||
"""Initialize inline task lists."""
|
||||
self.train_tasks = list(TRAIN_TASKS)
|
||||
self.eval_tasks = list(EVAL_TASKS)
|
||||
self.iter = 0
|
||||
# Track reward stats for wandb logging
|
||||
self.reward_buffer: List[float] = []
|
||||
|
||||
async def get_next_item(self) -> Dict[str, str]:
|
||||
"""Cycle through training tasks."""
|
||||
item = self.train_tasks[self.iter % len(self.train_tasks)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def format_prompt(self, item: Dict[str, str]) -> str:
|
||||
"""The prompt is directly in the task item."""
|
||||
return item["prompt"]
|
||||
|
||||
async def compute_reward(
|
||||
self, item: Dict[str, str], result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Verify by cat-ing the expected file path and checking content matches.
|
||||
Same verifier for all tasks -- they all write a file at a known path.
|
||||
|
||||
Scoring:
|
||||
1.0 = exact match
|
||||
0.5 = expected content is present but has extra stuff
|
||||
0.0 = file doesn't exist or content doesn't match
|
||||
"""
|
||||
verify_result = ctx.terminal(f"cat {item['verify_path']}")
|
||||
|
||||
# File doesn't exist or can't be read
|
||||
if verify_result["exit_code"] != 0:
|
||||
self.reward_buffer.append(0.0)
|
||||
return 0.0
|
||||
|
||||
actual = verify_result.get("output", "").strip()
|
||||
expected = item["expected_content"].strip()
|
||||
|
||||
# Exact match
|
||||
if actual == expected:
|
||||
self.reward_buffer.append(1.0)
|
||||
return 1.0
|
||||
|
||||
# Partial credit: expected content is present but has extra stuff
|
||||
if expected in actual:
|
||||
self.reward_buffer.append(0.5)
|
||||
return 0.5
|
||||
|
||||
self.reward_buffer.append(0.0)
|
||||
return 0.0
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""
|
||||
Run each eval task as a single-turn completion (no agent loop, no file verification).
Logs the number of samples together with the collected prompt/response/expected records.
|
||||
"""
|
||||
start_time = time.time()
|
||||
correct = 0
|
||||
total = len(self.eval_tasks)
|
||||
samples = []
|
||||
|
||||
for eval_item in self.eval_tasks:
|
||||
try:
|
||||
# For eval, we do a simple single-turn completion (not full agent loop)
|
||||
# to keep eval fast. The agent loop is tested via training.
|
||||
completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": self.config.system_prompt or ""},
|
||||
{"role": "user", "content": eval_item["prompt"]},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
|
||||
response_content = (
|
||||
completion.choices[0].message.content if completion.choices else ""
|
||||
)
|
||||
|
||||
samples.append(
|
||||
{
|
||||
"prompt": eval_item["prompt"],
|
||||
"response": response_content,
|
||||
"expected": eval_item["expected_content"],
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Eval failed for item: %s", e)
|
||||
samples.append(
|
||||
{
|
||||
"prompt": eval_item["prompt"],
|
||||
"response": f"ERROR: {e}",
|
||||
"expected": eval_item["expected_content"],
|
||||
}
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
eval_metrics = {
|
||||
"eval/num_samples": total,
|
||||
}
|
||||
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log training metrics including reward stats and accuracy."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self.reward_buffer:
|
||||
total = len(self.reward_buffer)
|
||||
correct = sum(1 for r in self.reward_buffer if r == 1.0)
|
||||
partial = sum(1 for r in self.reward_buffer if r == 0.5)
|
||||
|
||||
wandb_metrics["train/avg_reward"] = sum(self.reward_buffer) / total
|
||||
wandb_metrics["train/accuracy"] = correct / total
|
||||
wandb_metrics["train/partial_match_rate"] = partial / total
|
||||
wandb_metrics["train/total_rollouts"] = total
|
||||
self.reward_buffer = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TerminalTestEnv.cli()
|
||||
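The three-tier scoring rule in `compute_reward` and the aggregation in `wandb_log` can be sketched as two pure functions. This is a minimal illustration, not the environment's code: `score_file_content` and `summarize` are hypothetical names, and the terminal call that reads the file is left out.

```python
from typing import Dict, List


def score_file_content(actual: str, expected: str) -> float:
    """Return 1.0 for an exact match, 0.5 if expected is a substring, else 0.0."""
    actual, expected = actual.strip(), expected.strip()
    if actual == expected:
        return 1.0
    if expected in actual:
        return 0.5
    return 0.0


def summarize(rewards: List[float]) -> Dict[str, float]:
    """Aggregate a non-empty reward buffer the way the logging step describes."""
    total = len(rewards)
    return {
        "avg_reward": sum(rewards) / total,
        "accuracy": sum(1 for r in rewards if r == 1.0) / total,
        "partial_match_rate": sum(1 for r in rewards if r == 0.5) / total,
    }
```

One consequence of the substring check: if the expected content is empty after stripping, any non-empty file scores 0.5, because the empty string is a substring of every string.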
120
environments/tool_call_parsers/__init__.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
Tool Call Parser Registry
|
||||
|
||||
Client-side parsers that extract structured tool_calls from raw model output text.
|
||||
Used in Phase 2 (VLLM server type) where ManagedServer's /generate endpoint returns
|
||||
raw text without tool call parsing.
|
||||
|
||||
Each parser is a standalone reimplementation of the corresponding VLLM parser's
|
||||
non-streaming extract_tool_calls() logic. No VLLM dependency -- only the standard library
(e.g. re, json, uuid, ast) and openai types.
|
||||
|
||||
Usage:
|
||||
from environments.tool_call_parsers import get_parser
|
||||
|
||||
parser = get_parser("hermes")
|
||||
content, tool_calls = parser.parse(raw_model_output)
|
||||
# content = text with tool call markup stripped
|
||||
# tool_calls = list of ChatCompletionMessageToolCall objects, or None
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type alias for parser return value
|
||||
ParseResult = Tuple[Optional[str], Optional[List[ChatCompletionMessageToolCall]]]
|
||||
|
||||
|
||||
class ToolCallParser(ABC):
|
||||
"""
|
||||
Base class for tool call parsers.
|
||||
|
||||
Each parser knows how to extract structured tool_calls from a specific
|
||||
model family's raw output text format.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
"""
|
||||
Parse raw model output text for tool calls.
|
||||
|
||||
Args:
|
||||
text: Raw decoded text from the model's completion
|
||||
|
||||
Returns:
|
||||
Tuple of (content, tool_calls) where:
|
||||
- content: text with tool call markup stripped (the message 'content' field),
|
||||
or None if the entire output was tool calls
|
||||
- tool_calls: list of ChatCompletionMessageToolCall objects,
|
||||
or None if no tool calls were found
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Global parser registry: name -> parser class
|
||||
PARSER_REGISTRY: Dict[str, Type[ToolCallParser]] = {}
|
||||
|
||||
|
||||
def register_parser(name: str):
|
||||
"""
|
||||
Decorator to register a parser class under a given name.
|
||||
|
||||
Usage:
|
||||
@register_parser("hermes")
|
||||
class HermesToolCallParser(ToolCallParser):
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(cls: Type[ToolCallParser]) -> Type[ToolCallParser]:
|
||||
PARSER_REGISTRY[name] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_parser(name: str) -> ToolCallParser:
|
||||
"""
|
||||
Get a parser instance by name.
|
||||
|
||||
Args:
|
||||
name: Parser name (e.g., "hermes", "mistral", "llama3_json")
|
||||
|
||||
Returns:
|
||||
Instantiated parser
|
||||
|
||||
Raises:
|
||||
KeyError: If parser name is not found in registry
|
||||
"""
|
||||
if name not in PARSER_REGISTRY:
|
||||
available = sorted(PARSER_REGISTRY.keys())
|
||||
raise KeyError(
|
||||
f"Tool call parser '{name}' not found. Available parsers: {available}"
|
||||
)
|
||||
return PARSER_REGISTRY[name]()
|
||||
|
||||
|
||||
def list_parsers() -> List[str]:
|
||||
"""Return sorted list of registered parser names."""
|
||||
return sorted(PARSER_REGISTRY.keys())
|
||||
|
||||
|
||||
# Import all parser modules to trigger registration via @register_parser decorators
|
||||
# Each module registers itself when imported
|
||||
from environments.tool_call_parsers.hermes_parser import HermesToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.longcat_parser import LongcatToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.mistral_parser import MistralToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.llama_parser import LlamaToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.qwen_parser import QwenToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.deepseek_v3_parser import DeepSeekV3ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.deepseek_v3_1_parser import DeepSeekV31ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.kimi_k2_parser import KimiK2ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.glm45_parser import Glm45ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.glm47_parser import Glm47ToolCallParser # noqa: E402, F401
|
||||
from environments.tool_call_parsers.qwen3_coder_parser import Qwen3CoderToolCallParser # noqa: E402, F401
|
||||
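The registration mechanism above (a decorator that stores the class under a name, stacked to create aliases) can be sketched without the openai dependency. The names `REGISTRY`, `register`, `get` and `DemoParser` are hypothetical stand-ins for illustration only.

```python
from typing import Callable, Dict

# name -> class; a stand-in for the registry described above
REGISTRY: Dict[str, type] = {}


def register(name: str) -> Callable[[type], type]:
    """Return a class decorator that stores the class under `name`."""
    def decorator(cls: type) -> type:
        REGISTRY[name] = cls
        return cls  # returning the class unchanged allows stacking
    return decorator


@register("demo_v1")  # applied second: registers an alias
@register("demo")     # applied first
class DemoParser:
    def parse(self, text: str):
        return text, None


def get(name: str):
    """Instantiate the class registered under `name`."""
    if name not in REGISTRY:
        raise KeyError(f"Parser '{name}' not found. Available: {sorted(REGISTRY)}")
    return REGISTRY[name]()
```

Because the decorator returns the class itself, stacking two decorators registers the same class under both names, which is how the parsers further down expose aliases such as `deepseek_v3_1` and `deepseek_v31`.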
71
environments/tool_call_parsers/deepseek_v3_1_parser.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
DeepSeek V3.1 tool call parser.
|
||||
|
||||
Similar to V3 but with a slightly different format:
|
||||
<｜tool▁call▁begin｜>function_name<｜tool▁sep｜>arguments<｜tool▁call▁end｜>
|
||||
|
||||
Note: V3 has type+name before the separator, V3.1 has name before and args after.
|
||||
|
||||
Based on VLLM's DeepSeekV31ToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("deepseek_v3_1")
|
||||
@register_parser("deepseek_v31")
|
||||
class DeepSeekV31ToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for DeepSeek V3.1 tool calls.
|
||||
|
||||
Slightly different regex than V3: function_name comes before the separator,
|
||||
arguments come after (no type field, no json code block wrapper).
|
||||
"""
|
||||
|
||||
START_TOKEN = "<｜tool▁calls▁begin｜>"

# Regex captures: function_name, function_arguments.
# The token delimiters are fullwidth vertical bars (U+FF5C), not ASCII "|",
# so they are literal characters and need no escaping.
PATTERN = re.compile(
r"<｜tool▁call▁begin｜>(?P<function_name>.*?)<｜tool▁sep｜>(?P<function_arguments>.*?)<｜tool▁call▁end｜>"
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.START_TOKEN not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for match in matches:
|
||||
func_name, func_args = match
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=func_name.strip(),
|
||||
arguments=func_args.strip(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
content = text[: text.find(self.START_TOKEN)].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
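The V3.1 extraction can be illustrated in isolation. The sketch below assumes the DeepSeek special tokens are written with fullwidth vertical bars (U+FF5C) and lower-block characters (U+2581); it builds them from escapes so the code survives text normalization. The helper names `tok` and `extract_calls` are hypothetical, and `re.DOTALL` is added here so multi-line arguments match, which the pattern above does not do.

```python
import re
from typing import List, Tuple

BAR = "\uff5c"  # fullwidth vertical line
GAP = "\u2581"  # lower one eighth block, used in place of spaces


def tok(name: str) -> str:
    """Build a special token such as <BAR tool GAP call GAP begin BAR>."""
    return f"<{BAR}{name.replace(' ', GAP)}{BAR}>"


CALLS_BEGIN = tok("tool calls begin")
CALL_BEGIN = tok("tool call begin")
SEP = tok("tool sep")
CALL_END = tok("tool call end")

# re.escape keeps the sketch safe even if a token contained regex metacharacters.
PATTERN = re.compile(
    re.escape(CALL_BEGIN)
    + r"(?P<name>.*?)"
    + re.escape(SEP)
    + r"(?P<args>.*?)"
    + re.escape(CALL_END),
    re.DOTALL,
)


def extract_calls(text: str) -> List[Tuple[str, str]]:
    """Return (function_name, raw_arguments) pairs, or [] if no tool section."""
    if CALLS_BEGIN not in text:
        return []
    return [
        (m.group("name").strip(), m.group("args").strip())
        for m in PATTERN.finditer(text)
    ]
```

If the tokens were written with ASCII `|` instead, each bar would have to be escaped as `\|`, because an unescaped `|` is the regex alternation operator.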
75
environments/tool_call_parsers/deepseek_v3_parser.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
DeepSeek V3 tool call parser.
|
||||
|
||||
Format uses special unicode tokens:
|
||||
<｜tool▁calls▁begin｜>
<｜tool▁call▁begin｜>type<｜tool▁sep｜>function_name
```json
{"arg": "value"}
```
<｜tool▁call▁end｜>
<｜tool▁calls▁end｜>
|
||||
|
||||
Based on VLLM's DeepSeekV3ToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("deepseek_v3")
|
||||
class DeepSeekV3ToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for DeepSeek V3 tool calls.
|
||||
|
||||
Uses special unicode tokens built from fullwidth vertical bars (U+FF5C) and lower-block characters (U+2581).
|
||||
Extracts type, function name, and JSON arguments from the structured format.
|
||||
"""
|
||||
|
||||
START_TOKEN = "<｜tool▁calls▁begin｜>"

# Regex captures: type, function_name, function_arguments.
# The token delimiters are fullwidth vertical bars (U+FF5C), not ASCII "|",
# so they are literal characters and need no escaping.
PATTERN = re.compile(
r"<｜tool▁call▁begin｜>(?P<type>.*)<｜tool▁sep｜>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*)\n```<｜tool▁call▁end｜>"
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.START_TOKEN not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for match in matches:
|
||||
tc_type, func_name, func_args = match
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=func_name.strip(),
|
||||
arguments=func_args.strip(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
# Content is everything before the tool calls section
|
||||
content = text[: text.find(self.START_TOKEN)].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
109
environments/tool_call_parsers/glm45_parser.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
GLM 4.5 (GLM-4-MoE) tool call parser.
|
||||
|
||||
Format uses custom arg_key/arg_value tags rather than standard JSON:
|
||||
<tool_call>function_name
|
||||
<arg_key>param1</arg_key><arg_value>value1</arg_value>
|
||||
<arg_key>param2</arg_key><arg_value>value2</arg_value>
|
||||
</tool_call>
|
||||
|
||||
Values are deserialized using json.loads -> ast.literal_eval -> raw string fallback.
|
||||
|
||||
Based on VLLM's Glm4MoeModelToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
def _deserialize_value(value: str) -> Any:
|
||||
"""
|
||||
Try to deserialize a string value to its native Python type.
|
||||
Attempts json.loads, then ast.literal_eval, then returns raw string.
|
||||
"""
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
try:
|
||||
return ast.literal_eval(value)
|
||||
except (ValueError, SyntaxError, TypeError):
|
||||
pass
|
||||
|
||||
return value
|
||||
|
||||
|
||||
@register_parser("glm45")
|
||||
class Glm45ToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for GLM 4.5 (GLM-4-MoE) tool calls.
|
||||
|
||||
Uses <tool_call>...</tool_call> tags with <arg_key>/<arg_value> pairs
|
||||
instead of standard JSON arguments.
|
||||
"""
|
||||
|
||||
FUNC_CALL_REGEX = re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL)
|
||||
FUNC_DETAIL_REGEX = re.compile(r"<tool_call>([^\n]*)\n(.*)</tool_call>", re.DOTALL)
|
||||
FUNC_ARG_REGEX = re.compile(
|
||||
r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL
|
||||
)
|
||||
|
||||
START_TOKEN = "<tool_call>"
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.START_TOKEN not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matched_calls = self.FUNC_CALL_REGEX.findall(text)
|
||||
if not matched_calls:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
|
||||
for match in matched_calls:
|
||||
detail = self.FUNC_DETAIL_REGEX.search(match)
|
||||
if not detail:
|
||||
continue
|
||||
|
||||
func_name = detail.group(1).strip()
|
||||
func_args_raw = detail.group(2)
|
||||
|
||||
# Parse arg_key/arg_value pairs
|
||||
pairs = self.FUNC_ARG_REGEX.findall(func_args_raw) if func_args_raw else []
|
||||
arg_dict: Dict[str, Any] = {}
|
||||
for key, value in pairs:
|
||||
arg_key = key.strip()
|
||||
arg_val = _deserialize_value(value.strip())
|
||||
arg_dict[arg_key] = arg_val
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=func_name,
|
||||
arguments=json.dumps(arg_dict, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
content = text[: text.find(self.START_TOKEN)].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
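The `json.loads` -> `ast.literal_eval` -> raw string fallback chain and the key/value pairing can be exercised on their own. This is a small sketch using only the standard library; `deserialize` and `parse_args` are hypothetical names, and the regex is the same shape as `FUNC_ARG_REGEX` above.

```python
import ast
import json
import re
from typing import Any, Dict

ARG_RE = re.compile(
    r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL
)


def deserialize(value: str) -> Any:
    """Try JSON first, then a Python literal, then keep the raw string."""
    try:
        return json.loads(value)
    except (json.JSONDecodeError, TypeError):
        pass
    try:
        return ast.literal_eval(value)
    except (ValueError, SyntaxError, TypeError):
        pass
    return value


def parse_args(body: str) -> Dict[str, Any]:
    """Turn <arg_key>/<arg_value> pairs into a dict of native values."""
    return {k.strip(): deserialize(v.strip()) for k, v in ARG_RE.findall(body)}
```

The second step matters for values a model writes in Python syntax, such as `{'a': True}`, which is not valid JSON but is a valid Python literal. A bare word such as `Paris` fails both steps and stays a string.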
35
environments/tool_call_parsers/glm47_parser.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
GLM 4.7 tool call parser.
|
||||
|
||||
Same as GLM 4.5 but with slightly different regex patterns.
|
||||
The tool_call tags may wrap differently and arg parsing handles
|
||||
newlines between key/value pairs.
|
||||
|
||||
Based on VLLM's Glm47MoeModelToolParser (extends Glm4MoeModelToolParser).
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, register_parser
|
||||
from environments.tool_call_parsers.glm45_parser import Glm45ToolCallParser
|
||||
|
||||
|
||||
@register_parser("glm47")
|
||||
class Glm47ToolCallParser(Glm45ToolCallParser):
|
||||
"""
|
||||
Parser for GLM 4.7 tool calls.
|
||||
Extends GLM 4.5 with updated regex patterns.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# GLM 4.7 uses a slightly different detail regex that includes
|
||||
# the <tool_call> wrapper and optional arg_key content
|
||||
self.FUNC_DETAIL_REGEX = re.compile(
|
||||
r"<tool_call>(.*?)(<arg_key>.*?)?</tool_call>", re.DOTALL
|
||||
)
|
||||
# GLM 4.7 handles newlines between arg_key and arg_value tags
|
||||
self.FUNC_ARG_REGEX = re.compile(
|
||||
r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>",
|
||||
re.DOTALL,
|
||||
)
|
||||
80
environments/tool_call_parsers/hermes_parser.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Hermes tool call parser.
|
||||
|
||||
Format: <tool_call>{"name": "func", "arguments": {...}}</tool_call>
|
||||
Based on VLLM's Hermes2ProToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("hermes")
|
||||
class HermesToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Hermes-format tool calls.
|
||||
|
||||
Matches <tool_call>...</tool_call> tags containing JSON with "name" and "arguments".
|
||||
Also handles unclosed <tool_call> at end-of-string (truncated generation).
|
||||
"""
|
||||
|
||||
# Matches both closed and unclosed tool_call tags
|
||||
PATTERN = re.compile(
|
||||
r"<tool_call>\s*(.*?)\s*</tool_call>|<tool_call>\s*(.*)", re.DOTALL
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if "<tool_call>" not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for match in matches:
|
||||
# match is a tuple: (closed_content, unclosed_content)
|
||||
raw_json = match[0] if match[0] else match[1]
|
||||
if not raw_json.strip():
|
||||
continue
|
||||
|
||||
tc_data = json.loads(raw_json)
|
||||
# Handle arguments: could be dict or already a JSON string
|
||||
raw_args = tc_data.get("arguments", {})
|
||||
if isinstance(raw_args, str):
|
||||
# Already a string — pass through as-is.
|
||||
# It may be a JSON string ("{...}") or a plain string ("ls").
|
||||
args_str = raw_args
|
||||
else:
|
||||
# Dict — serialize to JSON
|
||||
args_str = json.dumps(raw_args, ensure_ascii=False)
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc_data["name"],
|
||||
arguments=args_str,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
# Content is everything before the first <tool_call> tag
|
||||
content = text[: text.find("<tool_call>")].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
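How the two-branch pattern handles a closed tag followed by a truncated, unclosed one can be seen in a short standalone sketch. The function name `extract` is hypothetical, and unlike the parser above (which returns the original text unparsed if any payload fails to decode), this sketch skips an undecodable payload and keeps the rest.

```python
import json
import re
from typing import Any, Dict, List

# First branch: a properly closed tag. Second branch: an unclosed tag
# running to the end of the string (e.g. generation hit max_tokens).
PATTERN = re.compile(
    r"<tool_call>\s*(.*?)\s*</tool_call>|<tool_call>\s*(.*)", re.DOTALL
)


def extract(text: str) -> List[Dict[str, Any]]:
    """Return the decoded JSON payload of every tool_call tag that parses."""
    calls: List[Dict[str, Any]] = []
    for closed, unclosed in PATTERN.findall(text):
        raw = closed or unclosed  # exactly one group is non-empty per match
        if not raw.strip():
            continue
        try:
            calls.append(json.loads(raw))
        except json.JSONDecodeError:
            continue  # truncated or malformed payload
    return calls
```

An unclosed tag only yields a call when the JSON inside it happens to be complete; a payload cut off mid-object fails to decode.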
93
environments/tool_call_parsers/kimi_k2_parser.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
Kimi K2 tool call parser.
|
||||
|
||||
Format:
|
||||
<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|>function_id:0<|tool_call_argument_begin|>{"arg": "val"}<|tool_call_end|>
|
||||
<|tool_calls_section_end|>
|
||||
|
||||
The function_id format is typically "functions.func_name:index" or "func_name:index".
|
||||
|
||||
Based on VLLM's KimiK2ToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("kimi_k2")
|
||||
class KimiK2ToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Kimi K2 tool calls.
|
||||
|
||||
Uses section begin/end tokens wrapping individual tool call begin/end tokens.
|
||||
The tool_call_id contains the function name (after last dot, before colon).
|
||||
"""
|
||||
|
||||
# Support both singular and plural variants
|
||||
START_TOKENS = [
|
||||
"<|tool_calls_section_begin|>",
|
||||
"<|tool_call_section_begin|>",
|
||||
]
|
||||
|
||||
# Regex captures: tool_call_id (e.g., "functions.get_weather:0"), function_arguments
|
||||
PATTERN = re.compile(
|
||||
r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[^<]+:\d+)\s*"
|
||||
r"<\|tool_call_argument_begin\|>\s*"
|
||||
r"(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*"
|
||||
r"<\|tool_call_end\|>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
# Check for any variant of the start token
|
||||
has_start = any(token in text for token in self.START_TOKENS)
|
||||
if not has_start:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for match in matches:
|
||||
function_id, function_args = match
|
||||
|
||||
# Extract function name from ID format: "functions.get_weather:0" -> "get_weather"
|
||||
function_name = function_id.split(":")[0].split(".")[-1]
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=function_id, # Preserve the original ID format
|
||||
type="function",
|
||||
function=Function(
|
||||
name=function_name,
|
||||
arguments=function_args.strip(),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
# Content is everything before the tool calls section
|
||||
earliest_start = len(text)
|
||||
for token in self.START_TOKENS:
|
||||
idx = text.find(token)
|
||||
if idx >= 0 and idx < earliest_start:
|
||||
earliest_start = idx
|
||||
|
||||
content = text[:earliest_start].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
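The ID-to-name rule ("after the last dot, before the colon") and the pattern above can be checked together on a small input. A sketch under the assumption that the Kimi K2 tokens use ASCII bars as written in this file; `function_name_from_id` is a hypothetical helper name.

```python
import re

# Same shape as the pattern above: ID, then arguments up to the end token.
PATTERN = re.compile(
    r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[^<]+:\d+)\s*"
    r"<\|tool_call_argument_begin\|>\s*"
    r"(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*"
    r"<\|tool_call_end\|>",
    re.DOTALL,
)


def function_name_from_id(tool_call_id: str) -> str:
    """'functions.get_weather:0' -> 'get_weather'; 'get_weather:1' -> 'get_weather'."""
    return tool_call_id.split(":")[0].split(".")[-1]


sample = (
    "<|tool_calls_section_begin|>"
    '<|tool_call_begin|>functions.get_weather:0'
    '<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|>'
    "<|tool_calls_section_end|>"
)
matches = PATTERN.findall(sample)
```

The negative lookahead inside the arguments group stops a match from running across the start of the next call when an end token is missing.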
96
environments/tool_call_parsers/llama_parser.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
Llama 3.x / 4 tool call parser.
|
||||
|
||||
Format: The model outputs JSON objects with "name" and "arguments" (or "parameters") keys.
|
||||
May be preceded by <|python_tag|> token. Supports multiple JSON objects separated
|
||||
by content or semicolons.
|
||||
|
||||
Based on VLLM's Llama3JsonToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("llama3_json")
|
||||
@register_parser("llama4_json")
|
||||
class LlamaToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Llama 3.x and 4 JSON-format tool calls.
|
||||
|
||||
Finds JSON objects containing "name" + ("arguments" or "parameters") keys.
|
||||
Uses Python's json.JSONDecoder.raw_decode for robust extraction of
|
||||
JSON objects from mixed text.
|
||||
"""
|
||||
|
||||
BOT_TOKEN = "<|python_tag|>"
|
||||
|
||||
# Regex to find the start of potential JSON objects
|
||||
JSON_START = re.compile(r"\{")
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
# Quick check: need either the bot token or a JSON brace
|
||||
if self.BOT_TOKEN not in text and "{" not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
decoder = json.JSONDecoder()
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
end_index = -1 # Track where the last parsed JSON ended
|
||||
|
||||
for match in self.JSON_START.finditer(text):
|
||||
start = match.start()
|
||||
# Skip if this brace is inside a previously parsed JSON object.
# end_index is exclusive, so a brace exactly at end_index starts a new object.
if start < end_index:
|
||||
continue
|
||||
|
||||
try:
|
||||
obj, json_end = decoder.raw_decode(text[start:])
|
||||
end_index = start + json_end
|
||||
|
||||
# Must have "name" and either "arguments" or "parameters"
|
||||
name = obj.get("name")
|
||||
args = obj.get("arguments", obj.get("parameters"))
|
||||
|
||||
if not name or args is None:
|
||||
continue
|
||||
|
||||
# Normalize arguments to JSON string
|
||||
if not isinstance(args, str):
args = json.dumps(args, ensure_ascii=False)
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(name=name, arguments=args),
|
||||
)
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError, ValueError):
|
||||
continue
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
# Content is everything before the bot token if present, otherwise before
# the first "{" in the text (which may precede the first valid tool call)
|
||||
first_tc_start = text.find("{")
|
||||
if self.BOT_TOKEN in text:
|
||||
first_tc_start = text.find(self.BOT_TOKEN)
|
||||
content = text[:first_tc_start].strip() if first_tc_start > 0 else None
|
||||
|
||||
return content, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
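The scanning technique (try `json.JSONDecoder.raw_decode` at each opening brace, then skip braces that fall inside an object already decoded) can be sketched as follows. `find_tool_objects` is a hypothetical name; the sketch uses a strict comparison so an object that begins immediately after the previous one is still considered.

```python
import json
import re
from typing import Any, Dict, List


def find_tool_objects(text: str) -> List[Dict[str, Any]]:
    """Collect JSON objects with "name" plus "arguments" or "parameters"."""
    decoder = json.JSONDecoder()
    found: List[Dict[str, Any]] = []
    end = -1  # exclusive end offset of the last decoded object
    for m in re.finditer(r"\{", text):
        start = m.start()
        if start < end:
            continue  # nested brace inside an object we already decoded
        try:
            obj, length = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            continue  # not JSON at this brace; try the next one
        end = start + length
        args = obj.get("arguments", obj.get("parameters"))
        if obj.get("name") and args is not None:
            found.append({"name": obj["name"], "arguments": args})
    return found
```

`raw_decode` returns both the parsed object and the number of characters consumed, which is what makes it possible to pull JSON out of surrounding free text without a balanced-brace counter.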
69
environments/tool_call_parsers/longcat_parser.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Longcat Flash Chat tool call parser.
|
||||
|
||||
Same as Hermes but uses <longcat_tool_call> tags instead of <tool_call>.
|
||||
Based on VLLM's LongcatFlashToolParser (extends Hermes2ProToolParser).
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
@register_parser("longcat")
|
||||
class LongcatToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Longcat Flash Chat tool calls.
|
||||
Identical logic to Hermes, just different tag names.
|
||||
"""
|
||||
|
||||
PATTERN = re.compile(
|
||||
r"<longcat_tool_call>\s*(.*?)\s*</longcat_tool_call>|<longcat_tool_call>\s*(.*)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if "<longcat_tool_call>" not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
matches = self.PATTERN.findall(text)
|
||||
if not matches:
|
||||
return text, None
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for match in matches:
|
||||
raw_json = match[0] if match[0] else match[1]
|
||||
if not raw_json.strip():
|
||||
continue
|
||||
|
||||
tc_data = json.loads(raw_json)
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc_data["name"],
|
||||
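# Pass string arguments through unchanged (as the Hermes parser does); serialize anything else.
arguments=(
tc_data["arguments"]
if isinstance(tc_data.get("arguments"), str)
else json.dumps(tc_data.get("arguments", {}), ensure_ascii=False)
),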
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
content = text[: text.find("<longcat_tool_call>")].strip()
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
|
||||
130
environments/tool_call_parsers/mistral_parser.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""
|
||||
Mistral tool call parser.
|
||||
|
||||
Supports two formats depending on tokenizer version:
|
||||
- Pre-v11: content[TOOL_CALLS] [{"name": ..., "arguments": {...}}, ...]
|
||||
- v11+: content[TOOL_CALLS]tool_name1{"arg": "val"}[TOOL_CALLS]tool_name2{"arg": "val"}
|
||||
|
||||
Based on VLLM's MistralToolParser.extract_tool_calls()
|
||||
The [TOOL_CALLS] token is the bot_token used by Mistral models.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
def _generate_mistral_id() -> str:
|
||||
"""Mistral tool call IDs are 9-char alphanumeric strings."""
|
||||
import random
|
||||
import string
|
||||
|
||||
return "".join(random.choices(string.ascii_letters + string.digits, k=9))
|
||||
|
||||
|
||||
@register_parser("mistral")
|
||||
class MistralToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Mistral-format tool calls.
|
||||
|
||||
Detects the format by checking whether the content after [TOOL_CALLS] starts with '[' or '{'
(pre-v11 JSON array or single object) or with a tool name (v11+ format).
|
||||
"""
|
||||
|
||||
# The [TOOL_CALLS] token -- may appear as different strings depending on tokenizer
|
||||
BOT_TOKEN = "[TOOL_CALLS]"
|
||||
|
||||
# Fallback regex for pre-v11 format when JSON parsing fails
|
||||
TOOL_CALL_REGEX = re.compile(r"\[?\s*(\{.*?\})\s*\]?", re.DOTALL)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.BOT_TOKEN not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
parts = text.split(self.BOT_TOKEN)
|
||||
content = parts[0].strip()
|
||||
raw_tool_calls = parts[1:]
|
||||
|
||||
# Detect format: if the first raw part starts with '[' or '{', it's pre-v11
|
||||
first_raw = raw_tool_calls[0].strip() if raw_tool_calls else ""
|
||||
is_pre_v11 = first_raw.startswith("[") or first_raw.startswith("{")
|
||||
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
|
||||
if not is_pre_v11:
|
||||
# v11+ format: [TOOL_CALLS]tool_name{args}[TOOL_CALLS]tool_name2{args2}
|
||||
for raw in raw_tool_calls:
|
||||
raw = raw.strip()
|
||||
if not raw or "{" not in raw:
|
||||
continue
|
||||
|
||||
brace_idx = raw.find("{")
|
||||
tool_name = raw[:brace_idx].strip()
|
||||
args_str = raw[brace_idx:]
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=_generate_mistral_id(),
|
||||
type="function",
|
||||
function=Function(name=tool_name, arguments=args_str),
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Pre-v11 format: [TOOL_CALLS] [{"name": ..., "arguments": {...}}]
|
||||
try:
|
||||
parsed = json.loads(first_raw)
|
||||
if isinstance(parsed, dict):
|
||||
parsed = [parsed]
|
||||
|
||||
for tc in parsed:
|
||||
args = tc.get("arguments", {})
|
||||
if isinstance(args, dict):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=_generate_mistral_id(),
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc["name"], arguments=args
|
||||
),
|
||||
)
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# Fallback regex extraction
|
||||
match = self.TOOL_CALL_REGEX.findall(first_raw)
|
||||
if match:
|
||||
for raw_json in match:
|
||||
try:
|
||||
tc = json.loads(raw_json)
|
||||
args = tc.get("arguments", {})
|
||||
if isinstance(args, dict):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=_generate_mistral_id(),
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc["name"], arguments=args
|
||||
),
|
||||
)
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
return content if content else None, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
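The v11+ branch of the Mistral parser above can be illustrated with a dependency-free sketch. It is not the parser from this diff: `split_v11_tool_calls` and the sample string are hypothetical, and it returns plain `(name, raw_args)` tuples instead of `ChatCompletionMessageToolCall` objects so it runs without the `openai` package.

```python
from typing import List, Optional, Tuple

BOT_TOKEN = "[TOOL_CALLS]"


def split_v11_tool_calls(text: str) -> Tuple[Optional[str], List[Tuple[str, str]]]:
    """Split v11+-style output into leading content and (name, raw_args) pairs."""
    parts = text.split(BOT_TOKEN)
    content = parts[0].strip() or None
    calls: List[Tuple[str, str]] = []
    for raw in parts[1:]:
        raw = raw.strip()
        brace_idx = raw.find("{")
        if brace_idx < 0:
            continue  # no argument object in this segment, skip it
        # Everything before the first '{' is the tool name, the rest is the JSON args.
        calls.append((raw[:brace_idx].strip(), raw[brace_idx:]))
    return content, calls


# Hypothetical model output with two tool calls.
sample = 'Checking.[TOOL_CALLS]get_weather{"city": "Paris"}[TOOL_CALLS]get_time{"tz": "UTC"}'
content, calls = split_v11_tool_calls(sample)
```

The argument string is kept as raw text rather than parsed, which mirrors how the diff passes `args_str` straight into `Function(arguments=...)`.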
|
||||
163
environments/tool_call_parsers/qwen3_coder_parser.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
Qwen3-Coder tool call parser.
|
||||
|
||||
Format uses XML-style nested tags:
|
||||
<tool_call>
|
||||
<function=function_name>
|
||||
<parameter=param_name>value</parameter>
|
||||
<parameter=param_name2>value2</parameter>
|
||||
</function>
|
||||
</tool_call>
|
||||
|
||||
Parameters are extracted from <parameter=name>value</parameter> tags and
|
||||
type-converted using the schema if available, otherwise treated as strings.
|
||||
|
||||
Based on vLLM's Qwen3CoderToolParser.extract_tool_calls()
|
||||
"""
|
||||
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
)
|
||||
|
||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
||||
|
||||
|
||||
def _try_convert_value(value: str) -> Any:
|
||||
"""
|
||||
Try to convert a parameter value string to a native Python type.
|
||||
Handles null, numbers, booleans, JSON objects/arrays, and falls back to string.
|
||||
"""
|
||||
stripped = value.strip()
|
||||
|
||||
# Handle null
|
||||
if stripped.lower() == "null":
|
||||
return None
|
||||
|
||||
# Try JSON first (handles objects, arrays, strings, numbers, booleans)
|
||||
try:
|
||||
return json.loads(stripped)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Try Python literal eval (handles tuples, etc.)
|
||||
try:
|
||||
return ast.literal_eval(stripped)
|
||||
except (ValueError, SyntaxError, TypeError):
|
||||
pass
|
||||
|
||||
# Return as string
|
||||
return stripped
|
||||
|
||||
|
||||
@register_parser("qwen3_coder")
|
||||
class Qwen3CoderToolCallParser(ToolCallParser):
|
||||
"""
|
||||
Parser for Qwen3-Coder XML-format tool calls.
|
||||
|
||||
Uses nested XML tags: <tool_call><function=name><parameter=key>val</parameter></function></tool_call>
|
||||
"""
|
||||
|
||||
START_TOKEN = "<tool_call>"
|
||||
FUNCTION_PREFIX = "<function="
|
||||
|
||||
# Find complete tool_call blocks (or unclosed at end)
|
||||
TOOL_CALL_REGEX = re.compile(
|
||||
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL
|
||||
)
|
||||
|
||||
# Find function blocks within a tool_call
|
||||
FUNCTION_REGEX = re.compile(
|
||||
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
|
||||
)
|
||||
|
||||
# Find parameter blocks within a function
|
||||
PARAMETER_REGEX = re.compile(
|
||||
r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
def _parse_function_call(self, function_str: str) -> Optional[ChatCompletionMessageToolCall]:
|
||||
"""Parse a single <function=name>...</function> block into a ToolCall."""
|
||||
try:
|
||||
# Extract function name: everything before the first '>'
|
||||
gt_idx = function_str.index(">")
|
||||
func_name = function_str[:gt_idx].strip()
|
||||
params_str = function_str[gt_idx + 1:]
|
||||
|
||||
# Extract parameters
|
||||
param_dict: Dict[str, Any] = {}
|
||||
for match_text in self.PARAMETER_REGEX.findall(params_str):
|
||||
if ">" not in match_text:
|
||||
continue
|
||||
eq_idx = match_text.index(">")
|
||||
param_name = match_text[:eq_idx].strip()
|
||||
param_value = match_text[eq_idx + 1:]
|
||||
|
||||
# Clean up whitespace
|
||||
if param_value.startswith("\n"):
|
||||
param_value = param_value[1:]
|
||||
if param_value.endswith("\n"):
|
||||
param_value = param_value[:-1]
|
||||
|
||||
param_dict[param_name] = _try_convert_value(param_value)
|
||||
|
||||
return ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:24]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=func_name,
|
||||
arguments=json.dumps(param_dict, ensure_ascii=False),
|
||||
),
|
||||
)
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.FUNCTION_PREFIX not in text:
|
||||
return text, None
|
||||
|
||||
try:
|
||||
# Find all tool_call blocks
|
||||
tc_matches = self.TOOL_CALL_REGEX.findall(text)
|
||||
raw_blocks = [m[0] if m[0] else m[1] for m in tc_matches]
|
||||
|
||||
# Fallback: if no tool_call tags, try the whole text
|
||||
if not raw_blocks:
|
||||
raw_blocks = [text]
|
||||
|
||||
# Find function blocks within each tool_call
|
||||
function_strs: List[str] = []
|
||||
for block in raw_blocks:
|
||||
func_matches = self.FUNCTION_REGEX.findall(block)
|
||||
function_strs.extend(m[0] if m[0] else m[1] for m in func_matches)
|
||||
|
||||
if not function_strs:
|
||||
return text, None
|
||||
|
||||
# Parse each function call
|
||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
||||
for func_str in function_strs:
|
||||
tc = self._parse_function_call(func_str)
|
||||
if tc is not None:
|
||||
tool_calls.append(tc)
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
# Content before tool calls
|
||||
first_tc = text.find(self.START_TOKEN)
|
||||
if first_tc < 0:
|
||||
first_tc = text.find(self.FUNCTION_PREFIX)
|
||||
content = text[:first_tc].strip() if first_tc > 0 else None
|
||||
|
||||
return content, tool_calls
|
||||
|
||||
except Exception:
|
||||
return text, None
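The parameter extraction and type conversion in the Qwen3-Coder parser above can be sketched in isolation. This is a simplified illustration, not the code from the diff: `PARAM_RE`, `convert`, and `extract_params` are hypothetical names, and the regex here requires a closing `</parameter>` tag, whereas the parser's `PARAMETER_REGEX` also tolerates unclosed parameters.

```python
import ast
import json
import re
from typing import Any, Dict

# Simplified pattern: requires well-formed <parameter=name>value</parameter> pairs.
PARAM_RE = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)


def convert(value: str) -> Any:
    """Convert a raw parameter string: null, then JSON, then Python literal, else str."""
    stripped = value.strip()
    if stripped.lower() == "null":
        return None
    try:
        return json.loads(stripped)
    except json.JSONDecodeError:
        pass
    try:
        return ast.literal_eval(stripped)
    except (ValueError, SyntaxError, TypeError):
        pass
    return stripped


def extract_params(function_body: str) -> Dict[str, Any]:
    """Collect every parameter in a function block into a name -> value dict."""
    return {name.strip(): convert(raw) for name, raw in PARAM_RE.findall(function_body)}


# Hypothetical function body, as it would appear inside <function=...>...</function>.
body = """
<parameter=path>/workspace/solution.py</parameter>
<parameter=max_lines>20</parameter>
<parameter=recursive>true</parameter>
"""
params = extract_params(body)
```

Trying JSON before `ast.literal_eval` means lowercase `true`/`false` are read as booleans, while Python-only literals such as tuples still convert on the second attempt.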
|
||||
19
environments/tool_call_parsers/qwen_parser.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Qwen 2.5 tool call parser.
|
||||
|
||||
Uses the same <tool_call> format as Hermes.
|
||||
Registered as a separate parser name for clarity when using --tool-parser=qwen.
|
||||
"""
|
||||
|
||||
from environments.tool_call_parsers import register_parser
|
||||
from environments.tool_call_parsers.hermes_parser import HermesToolCallParser
|
||||
|
||||
|
||||
@register_parser("qwen")
|
||||
class QwenToolCallParser(HermesToolCallParser):
|
||||
"""
|
||||
Parser for Qwen 2.5 tool calls.
|
||||
Same <tool_call>{"name": ..., "arguments": ...}</tool_call> format as Hermes.
|
||||
"""
|
||||
|
||||
pass # Identical format -- inherits everything from Hermes
|
||||
289
environments/tool_context.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""
|
||||
ToolContext -- Unrestricted Tool Access for Reward Functions
|
||||
|
||||
A per-rollout handle that gives reward/verification functions direct access to
|
||||
ALL hermes-agent tools, scoped to the rollout's task_id. The same task_id means
|
||||
the terminal/browser session is the SAME one the model used during its rollout --
|
||||
all state (files, processes, browser tabs) is preserved.
|
||||
|
||||
The verifier author decides which tools to use. Nothing is hardcoded or gated.
|
||||
|
||||
Example usage in a compute_reward():
|
||||
async def compute_reward(self, item, result, ctx):
|
||||
# Run tests in the model's terminal sandbox
|
||||
test = ctx.terminal("pytest -v")
|
||||
if test["exit_code"] == 0:
|
||||
return 1.0
|
||||
|
||||
# Check if a file was created
|
||||
content = ctx.read_file("/workspace/solution.py")
|
||||
if content.get("content"):
|
||||
return 0.5
|
||||
|
||||
return 0.0
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
|
||||
from model_tools import handle_function_call
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
from tools.browser_tool import cleanup_browser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Thread pool for running sync tool calls that internally use asyncio.run()
|
||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
def _run_tool_in_thread(tool_name: str, arguments: Dict[str, Any], task_id: str) -> str:
|
||||
"""
|
||||
Run a tool call in a thread pool executor so backends that use asyncio.run()
|
||||
internally (modal, docker) get a clean event loop.
|
||||
|
||||
If we're already in an async context, the call is submitted to a short-lived worker thread and this function blocks until it returns.
|
||||
If not (e.g., called from sync code), runs directly.
|
||||
"""
|
||||
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running event loop -- safe to call directly
        return handle_function_call(tool_name, arguments, task_id)

    # We're in an async context -- run in a worker thread so that backends
    # calling asyncio.run() get a clean event loop. Keeping this outside the
    # try block means a RuntimeError raised by the tool itself is not mistaken
    # for "no running loop", so the tool is never invoked a second time.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(handle_function_call, tool_name, arguments, task_id)
        return future.result(timeout=300)
|
||||
|
||||
|
||||
class ToolContext:
|
||||
"""
|
||||
Open-ended access to all hermes-agent tools for a specific rollout.
|
||||
|
||||
Passed to compute_reward() so verifiers can use any tool they need:
|
||||
terminal commands, file reads/writes, web searches, browser automation, etc.
|
||||
All calls share the rollout's task_id for session isolation.
|
||||
"""
|
||||
|
||||
def __init__(self, task_id: str):
|
||||
self.task_id = task_id
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Terminal tools
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def terminal(self, command: str, timeout: int = 180) -> Dict[str, Any]:
|
||||
"""
|
||||
Run a command in the rollout's terminal session.
|
||||
|
||||
Args:
|
||||
command: Shell command to execute
|
||||
timeout: Command timeout in seconds
|
||||
|
||||
Returns:
|
||||
Dict with 'exit_code' (int) and 'output' (str)
|
||||
"""
|
||||
backend = os.getenv("TERMINAL_ENV", "local")
|
||||
logger.debug("ToolContext.terminal [%s backend] task=%s: %s", backend, self.task_id[:8], command[:100])
|
||||
|
||||
# Run in thread pool so modal/docker backends' asyncio.run() doesn't deadlock
|
||||
result = _run_tool_in_thread(
|
||||
"terminal",
|
||||
{"command": command, "timeout": timeout},
|
||||
self.task_id,
|
||||
)
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"exit_code": -1, "output": result}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# File tools
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def read_file(self, path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Read a file from the rollout's filesystem.
|
||||
|
||||
Args:
|
||||
path: File path to read
|
||||
|
||||
Returns:
|
||||
Dict with file content or error
|
||||
"""
|
||||
result = handle_function_call(
|
||||
"read_file", {"path": path}, task_id=self.task_id
|
||||
)
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
def write_file(self, path: str, content: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Write a file in the rollout's filesystem.
|
||||
|
||||
Args:
|
||||
path: File path to write
|
||||
content: Content to write
|
||||
|
||||
Returns:
|
||||
Dict with success status or error
|
||||
"""
|
||||
result = handle_function_call(
|
||||
"write_file", {"path": path, "content": content}, task_id=self.task_id
|
||||
)
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
def search(self, query: str, path: str = ".") -> Dict[str, Any]:
|
||||
"""
|
||||
Search for text in the rollout's filesystem.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
path: Directory to search in
|
||||
|
||||
Returns:
|
||||
Dict with search results
|
||||
"""
|
||||
result = handle_function_call(
|
||||
"search", {"query": query, "path": path}, task_id=self.task_id
|
||||
)
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Web tools
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def web_search(self, query: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Search the web.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
|
||||
Returns:
|
||||
Dict with search results
|
||||
"""
|
||||
result = handle_function_call("web_search", {"query": query})
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
def web_extract(self, urls: List[str]) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract content from URLs.
|
||||
|
||||
Args:
|
||||
urls: List of URLs to extract content from
|
||||
|
||||
Returns:
|
||||
Dict with extracted content
|
||||
"""
|
||||
result = handle_function_call("web_extract", {"urls": urls})
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Browser tools
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def browser_navigate(self, url: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Navigate the rollout's browser session to a URL.
|
||||
|
||||
Args:
|
||||
url: URL to navigate to
|
||||
|
||||
Returns:
|
||||
Dict with page snapshot or error
|
||||
"""
|
||||
result = handle_function_call(
|
||||
"browser_navigate", {"url": url}, task_id=self.task_id
|
||||
)
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
def browser_snapshot(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Take a snapshot of the current browser page.
|
||||
|
||||
Returns:
|
||||
Dict with page content/accessibility snapshot
|
||||
"""
|
||||
result = handle_function_call(
|
||||
"browser_snapshot", {}, task_id=self.task_id
|
||||
)
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Generic tool access
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Call any hermes-agent tool by name.
|
||||
|
||||
This is the generic escape hatch -- if a tool doesn't have a convenience
|
||||
wrapper above, you can call it directly here.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool (e.g., "vision_analyze", "skills_list")
|
||||
arguments: Dict of arguments for the tool
|
||||
|
||||
Returns:
|
||||
Raw JSON string result from the tool
|
||||
"""
|
||||
return _run_tool_in_thread(tool_name, arguments, self.task_id)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Cleanup
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def cleanup(self):
|
||||
"""
|
||||
Release all resources (terminal VMs, browser sessions) for this rollout.
|
||||
|
||||
Called automatically by the base environment via try/finally after
|
||||
compute_reward() completes. You generally don't need to call this yourself.
|
||||
"""
|
||||
try:
|
||||
cleanup_vm(self.task_id)
|
||||
except Exception as e:
|
||||
logger.debug("VM cleanup for task %s: %s", self.task_id, e)
|
||||
|
||||
# Suppress browser_tool's noisy debug prints during cleanup.
|
||||
# The cleanup still runs (safe), it just doesn't spam the console.
|
||||
_prev_quiet = os.environ.get("HERMES_QUIET")
|
||||
os.environ["HERMES_QUIET"] = "1"
|
||||
try:
|
||||
cleanup_browser(self.task_id)
|
||||
except Exception as e:
|
||||
logger.debug("Browser cleanup for task %s: %s", self.task_id, e)
|
||||
finally:
|
||||
if _prev_quiet is None:
|
||||
os.environ.pop("HERMES_QUIET", None)
|
||||
else:
|
||||
os.environ["HERMES_QUIET"] = _prev_quiet
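The reason `tool_context.py` routes some calls through a worker thread can be shown with a small self-contained sketch. `sync_tool` is a hypothetical stand-in for a backend that calls `asyncio.run()` internally; it is not one of the hermes-agent tools.

```python
import asyncio
import concurrent.futures


def sync_tool(x: int) -> int:
    """Hypothetical sync tool that starts its own event loop via asyncio.run()."""

    async def _work() -> int:
        await asyncio.sleep(0)
        return x * 2

    return asyncio.run(_work())


async def call_from_async(x: int) -> int:
    # Calling sync_tool(x) directly here would raise RuntimeError, because
    # asyncio.run() cannot be called while a loop is running in this thread.
    # A worker thread has no running loop, so asyncio.run() works there.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(sync_tool, x).result(timeout=30)


result = asyncio.run(call_from_async(21))
```

Note that `.result()` blocks the calling event loop until the tool finishes. Where that matters, `await loop.run_in_executor(pool, sync_tool, x)` is the non-blocking alternative, at the cost of making the wrapper a coroutine.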
|
||||
@@ -1,70 +0,0 @@
|
||||
---
|
||||
name: example-skill
|
||||
description: An example skill demonstrating the skill file format and structure
|
||||
---
|
||||
|
||||
# Example Skill
|
||||
|
||||
This is an example skill file that demonstrates how to create skills for the Hermes Agent.
|
||||
|
||||
## Skill File Format
|
||||
|
||||
Skills are markdown files with YAML frontmatter at the top:
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: your-skill-name
|
||||
description: A brief one-line description of what this skill does
|
||||
---
|
||||
```
|
||||
|
||||
The frontmatter fields:
|
||||
- **name**: The identifier used to reference this skill (lowercase, hyphens for spaces)
|
||||
- **description**: A brief description shown when listing skills (keep under 200 chars)
|
||||
|
||||
## Writing Effective Skills
|
||||
|
||||
### 1. Be Specific and Actionable
|
||||
|
||||
Good skills provide clear, actionable instructions:
|
||||
|
||||
```
|
||||
When reviewing code:
|
||||
1. Check for security vulnerabilities first
|
||||
2. Verify error handling is comprehensive
|
||||
3. Ensure tests cover edge cases
|
||||
```
|
||||
|
||||
### 2. Include Examples
|
||||
|
||||
Show concrete examples of what you want:
|
||||
|
||||
```python
|
||||
# Good: Descriptive variable names
|
||||
user_authentication_token = get_token()
|
||||
|
||||
# Bad: Cryptic abbreviations
|
||||
uat = gt()
|
||||
```
|
||||
|
||||
### 3. Define When to Use
|
||||
|
||||
Help the agent understand when this skill applies:
|
||||
|
||||
> Use this skill when: reviewing pull requests, auditing security, or checking code quality.
|
||||
|
||||
## Skill Categories
|
||||
|
||||
Consider organizing skills by purpose:
|
||||
|
||||
- **Conventions**: Coding standards, API patterns, naming rules
|
||||
- **Workflows**: Step-by-step processes for deployments, reviews, releases
|
||||
- **Knowledge**: Domain-specific information, system architecture, gotchas
|
||||
- **Templates**: Boilerplate for common tasks, response formats
|
||||
|
||||
## Tips
|
||||
|
||||
1. Keep the description concise - it's shown in the skills list
|
||||
2. Use headers to organize longer skills
|
||||
3. Include code examples where helpful
|
||||
4. Reference other skills if they're related
|
||||
@@ -481,7 +481,7 @@ class GatewayRunner:
|
||||
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "60"))
|
||||
|
||||
agent = AIAgent(
|
||||
model=os.getenv("HERMES_MODEL", "anthropic/claude-sonnet-4"),
|
||||
model=os.getenv("HERMES_MODEL", "anthropic/claude-opus-4.6"),
|
||||
max_iterations=max_iterations,
|
||||
quiet_mode=True,
|
||||
enabled_toolsets=[toolset],
|
||||
|
||||
34
hermes
@@ -7,6 +7,40 @@ Usage: ./hermes [options]
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
Fire (google/python-fire) does not support POSIX-style short flags like `-p`.
|
||||
We translate the most common shorthands to their long equivalents so wrapper
|
||||
scripts can reliably use:
|
||||
- `-p "..."` -> `--prompt "..."` (no TUI/banner; print result and exit)
|
||||
- `-q "..."` -> `--query "..."` (single-shot with banner UX)
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
def _rewrite_short_flags(argv: list[str]) -> list[str]:
|
||||
rewritten: list[str] = []
|
||||
i = 0
|
||||
while i < len(argv):
|
||||
arg = argv[i]
|
||||
if arg == "-p":
|
||||
rewritten.append("--prompt")
|
||||
if i + 1 < len(argv):
|
||||
rewritten.append(argv[i + 1])
|
||||
i += 2
|
||||
continue
|
||||
if arg == "-q":
|
||||
rewritten.append("--query")
|
||||
if i + 1 < len(argv):
|
||||
rewritten.append(argv[i + 1])
|
||||
i += 2
|
||||
continue
|
||||
rewritten.append(arg)
|
||||
i += 1
|
||||
return rewritten
|
||||
|
||||
sys.argv = [sys.argv[0]] + _rewrite_short_flags(sys.argv[1:])
|
||||
|
||||
from cli import main
|
||||
import fire
|
||||
|
||||
fire.Fire(main)
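The short-flag translation in the `hermes` script can also be written in a table-driven form. This is a sketch under the assumption that only `-p` and `-q` need translating; `SHORT_TO_LONG` and `rewrite_short_flags` are hypothetical names, and `--verbose` in the example is a placeholder rather than a documented flag.

```python
from typing import Dict, List

# Mapping of short flags to the long flags Fire understands.
SHORT_TO_LONG: Dict[str, str] = {"-p": "--prompt", "-q": "--query"}


def rewrite_short_flags(argv: List[str]) -> List[str]:
    """Translate short flags, copying the value that follows each one verbatim."""
    rewritten: List[str] = []
    it = iter(argv)
    for arg in it:
        if arg in SHORT_TO_LONG:
            rewritten.append(SHORT_TO_LONG[arg])
            # Consume the flag's value so it is never itself treated as a flag.
            value = next(it, None)
            if value is not None:
                rewritten.append(value)
        else:
            rewritten.append(arg)
    return rewritten


example = rewrite_short_flags(["-p", "hello", "--verbose"])
```

Consuming the value together with its flag matches the loop in the diff: a prompt whose text happens to be `-q` is passed through unchanged instead of being rewritten.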
|
||||
|
||||
@@ -13,6 +13,7 @@ Requires-Dist: httpx
|
||||
Requires-Dist: rich
|
||||
Requires-Dist: tenacity
|
||||
Requires-Dist: pyyaml
|
||||
Requires-Dist: prompt_toolkit
|
||||
Requires-Dist: requests
|
||||
Requires-Dist: jinja2
|
||||
Requires-Dist: pydantic>=2.0
|
||||
@@ -27,15 +28,12 @@ Requires-Dist: boto3; extra == "modal"
|
||||
Provides-Extra: dev
|
||||
Requires-Dist: pytest; extra == "dev"
|
||||
Requires-Dist: pytest-asyncio; extra == "dev"
|
||||
Provides-Extra: messaging
|
||||
Requires-Dist: python-telegram-bot>=20.0; extra == "messaging"
|
||||
Requires-Dist: discord.py>=2.0; extra == "messaging"
|
||||
Provides-Extra: cron
|
||||
Requires-Dist: croniter; extra == "cron"
|
||||
Provides-Extra: all
|
||||
Requires-Dist: croniter; extra == "all"
|
||||
Requires-Dist: python-telegram-bot>=20.0; extra == "all"
|
||||
Requires-Dist: discord.py>=2.0; extra == "all"
|
||||
Provides-Extra: atropos
|
||||
Requires-Dist: atroposlib @ git+https://github.com/NousResearch/atropos.git ; extra == "atropos"
|
||||
Requires-Dist: aiohttp; extra == "atropos"
|
||||
Requires-Dist: fastapi; extra == "atropos"
|
||||
Requires-Dist: uvicorn; extra == "atropos"
|
||||
Requires-Dist: pyte; extra == "atropos"
|
||||
|
||||
# Hermes Agent
|
||||
|
||||
@@ -44,7 +42,6 @@ An AI agent with advanced tool-calling capabilities, featuring a flexible toolse
|
||||
## Features
|
||||
|
||||
- **Interactive CLI**: Beautiful terminal interface with animated feedback, personalities, and session management
|
||||
- **Messaging Gateway**: Connect to Telegram, Discord, and WhatsApp for conversational AI anywhere
|
||||
- **Web Tools**: Search, extract content, and crawl websites
|
||||
- **Terminal Tools**: Execute commands via local, Docker, Singularity, Modal, or SSH backends
|
||||
- **Browser Tools**: Automate web browsers to navigate, click, type, and extract content
|
||||
@@ -53,85 +50,13 @@ An AI agent with advanced tool-calling capabilities, featuring a flexible toolse
|
||||
- **Creative Tools**: Generate images from text prompts
|
||||
- **Skills Tools**: On-demand knowledge documents with progressive disclosure
|
||||
- **Toolsets System**: Organize tools into logical groups for different scenarios
|
||||
- **Scheduled Tasks**: Cron jobs for automated agent tasks with delivery to platforms
|
||||
- **Context Compression**: Automatic summarization when approaching context limits
|
||||
- **Batch Processing**: Process datasets in parallel with checkpointing and statistics tracking
|
||||
- **Ephemeral System Prompts**: Guide model behavior without polluting training datasets
|
||||
|
||||
## Installation
|
||||
|
||||
### Quick Install (Recommended)
|
||||
|
||||
**Linux/macOS:**
|
||||
```bash
|
||||
curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash
|
||||
```
|
||||
|
||||
**Windows (PowerShell):**
|
||||
```powershell
|
||||
irm https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.ps1 | iex
|
||||
```
|
||||
|
||||
This installer will:
|
||||
- Clone the repository to `~/.hermes-agent`
|
||||
- Create a virtual environment and install dependencies
|
||||
- Set up the `hermes` command in your PATH
|
||||
- Run an interactive setup wizard to configure API keys
|
||||
|
||||
### Manual Installation
|
||||
|
||||
If you prefer to install manually:
|
||||
## Quick Start (CLI)
|
||||
|
||||
```bash
|
||||
# Clone with submodules
|
||||
git clone --recurse-submodules https://github.com/NousResearch/Hermes-Agent.git
|
||||
cd Hermes-Agent
|
||||
|
||||
# Run the setup script
|
||||
./setup-hermes.sh
|
||||
```
|
||||
|
||||
Or step-by-step:
|
||||
|
||||
```bash
|
||||
# Create and activate virtual environment
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate # Windows: venv\Scripts\activate
|
||||
|
||||
# Install in editable mode with all extras
|
||||
pip install -e ".[all]"
|
||||
|
||||
# Or install dependencies manually
|
||||
pip install -r requirements.txt
|
||||
pip install -e ./mini-swe-agent
|
||||
|
||||
# Copy and configure environment
|
||||
cp .env.example .env
|
||||
# Edit .env with your API keys
|
||||
|
||||
# Run the setup wizard
|
||||
hermes setup
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
Once installed, the `hermes` command is your main entry point:
|
||||
|
||||
```bash
|
||||
hermes # Interactive chat (default)
|
||||
hermes chat # Same as above
|
||||
hermes chat -q "Hello" # Single query, then exit
|
||||
hermes setup # Configure API keys and settings
|
||||
hermes status # Show configuration status
|
||||
hermes doctor # Diagnose issues
|
||||
hermes gateway # Start messaging gateway (Telegram/Discord/WhatsApp)
|
||||
hermes cron daemon # Run cron job scheduler
|
||||
hermes version # Show version info
|
||||
```
|
||||
|
||||
**Legacy `./hermes` script:**
|
||||
```bash
|
||||
# The old CLI script still works:
|
||||
# After setup (see below), just run:
|
||||
./hermes
|
||||
|
||||
# Or with options:
|
||||
@@ -145,9 +70,35 @@ The CLI provides:
|
||||
- Customizable personalities (`/personality kawaii`, `/personality pirate`, etc.)
|
||||
- Persistent configuration via `cli-config.yaml`
|
||||
|
||||
## Configuration
|
||||
## Setup
|
||||
|
||||
### Environment Variables
|
||||
### 1. Clone the Repository
|
||||
```bash
|
||||
# Clone with submodules (recommended)
|
||||
git clone --recurse-submodules https://github.com/NousResearch/Hermes-Agent.git
|
||||
cd Hermes-Agent
|
||||
|
||||
# Or if already cloned without submodules:
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
### 2. Install Dependencies
|
||||
```bash
|
||||
# Create and activate virtual environment (recommended)
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate # On Windows: venv\Scripts\activate
|
||||
|
||||
# Install Python packages
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Install mini-swe-agent for terminal tools
|
||||
pip install -e ./mini-swe-agent
|
||||
|
||||
# Install Node.js dependencies for browser tools (requires Node.js)
|
||||
npm install
|
||||
```
|
||||
|
||||
### 3. Configure Environment Variables
|
||||
```bash
|
||||
# Copy the example environment file
|
||||
cp .env.example .env
|
||||
@@ -376,169 +327,6 @@ logs/
|
||||
- **Trajectory Format**: Uses the same format as batch processing for consistency
|
||||
- **Git Ignored**: `logs/` is in `.gitignore` so logs aren't committed
|
||||
|
||||
## Context Compression
|
||||
|
||||
Long conversations can exceed the model's context limit. Hermes Agent automatically compresses context when approaching the limit:
|
||||
|
||||
**How it works:**
1. Tracks actual token usage from API responses (`usage.prompt_tokens`)
2. Triggers compression when tokens reach 85% of the model's context limit
3. Protects the first 3 turns (system prompt, initial request, first response)
4. Protects the last 4 turns (recent context is most relevant)
5. Summarizes the middle turns using a fast, cheap model (Gemini Flash)
6. Inserts the summary as a user message so the conversation continues seamlessly
|
||||
|
||||
**Configuration (`cli-config.yaml`):**
|
||||
```yaml
|
||||
compression:
|
||||
enabled: true # Enable auto-compression (default)
|
||||
threshold: 0.85 # Compress at 85% of context limit
|
||||
summary_model: "google/gemini-2.0-flash-001"
|
||||
```
|
||||
|
||||
**Or via environment variables:**
|
||||
```bash
|
||||
CONTEXT_COMPRESSION_ENABLED=true
|
||||
CONTEXT_COMPRESSION_THRESHOLD=0.85
|
||||
CONTEXT_COMPRESSION_MODEL=google/gemini-2.0-flash-001
|
||||
```
|
||||
|
||||
**When compression triggers, you'll see:**
|
||||
```
|
||||
📦 Context compression triggered (170,000 tokens ≥ 170,000 threshold)
|
||||
📊 Model context limit: 200,000 tokens (85% = 170,000)
|
||||
🗜️ Summarizing turns 4-15 (12 turns)
|
||||
✅ Compressed: 20 → 9 messages (~45,000 tokens saved)
|
||||
```
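The trigger and turn-protection rules described in this section can be sketched as follows. This is an illustration only, not the compressor's actual code: the function names are hypothetical, and treating each message as one turn is an assumption the README does not confirm.

```python
from typing import List, Tuple


def compression_threshold(context_limit: int, threshold: float = 0.85) -> int:
    """Token count at which compression triggers (85% of the limit by default)."""
    return round(context_limit * threshold)


def should_compress(prompt_tokens: int, context_limit: int, threshold: float = 0.85) -> bool:
    """True once reported prompt tokens reach the threshold."""
    return prompt_tokens >= compression_threshold(context_limit, threshold)


def split_turns(
    messages: List[str], head: int = 3, tail: int = 4
) -> Tuple[List[str], List[str], List[str]]:
    """Split into (protected head, middle to summarize, protected tail)."""
    if len(messages) <= head + tail:
        return messages, [], []  # nothing left in the middle to summarize
    return messages[:head], messages[head:-tail], messages[-tail:]


# Hypothetical 10-message conversation: m1..m10.
head, middle, tail = split_turns([f"m{i}" for i in range(1, 11)])
```

With a 200,000-token limit, `compression_threshold` gives 170,000, which matches the threshold shown in the sample output above.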
|
||||
|
||||
## Scheduled Tasks (Cron Jobs)
|
||||
|
||||
Hermes Agent can schedule automated tasks to run in the future - either one-time reminders or recurring jobs.
|
||||
|
||||
### CLI Commands
|
||||
|
||||
```bash
|
||||
# List scheduled jobs
|
||||
/cron
|
||||
|
||||
# Add a one-shot reminder (runs once in 30 minutes)
|
||||
/cron add 30m Remind me to check the build status
|
||||
|
||||
# Add a recurring job (every 2 hours)
|
||||
/cron add "every 2h" Check server status at 192.168.1.100 and report any issues
|
||||
|
||||
# Add a cron expression (daily at 9am)
|
||||
/cron add "0 9 * * *" Generate a morning briefing summarizing GitHub notifications
|
||||
|
||||
# Remove a job
|
||||
/cron remove abc123def456
|
||||
```
|
||||
|
||||
### Agent Self-Scheduling
|
||||
|
||||
The agent can also schedule its own follow-up tasks using tools:
|
||||
|
||||
```python
|
||||
# Available when using hermes-cli toolset (default for CLI)
|
||||
schedule_cronjob(prompt="...", schedule="30m", repeat=1) # One-shot
|
||||
schedule_cronjob(prompt="...", schedule="every 2h") # Recurring
|
||||
list_cronjobs() # View all jobs
|
||||
remove_cronjob(job_id="...") # Cancel a job
|
||||
```
|
||||
|
||||
**⚠️ Important:** Cronjobs run in **isolated sessions with NO prior context**. The prompt must be completely self-contained with all necessary information (file paths, URLs, server addresses, etc.). The future agent will not remember anything from the current conversation.
|
||||
|
||||
### Schedule Formats
|
||||
|
||||
| Format | Example | Description |
|--------|---------|-------------|
| Duration | `30m`, `2h`, `1d` | One-shot delay from now |
| Interval | `every 30m`, `every 2h` | Recurring at fixed intervals |
| Cron | `0 9 * * *` | Cron expression (requires `croniter`) |
| Timestamp | `2026-02-03T14:00` | One-shot at specific time |
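As a rough illustration of how the four schedule formats could be told apart, here is a minimal classifier. It is a sketch, not the scheduler's real parser: `classify_schedule` is a hypothetical helper, the `m`/`h`/`d` unit set is inferred from the examples in the table, and cron expressions are only recognized by having five fields, not validated.

```python
import re
from datetime import datetime

_DURATION = re.compile(r"\d+[mhd]")
_INTERVAL = re.compile(r"every \d+[mhd]")


def classify_schedule(spec: str) -> str:
    """Return 'duration', 'interval', 'cron', or 'timestamp' for a schedule string."""
    spec = spec.strip()
    if _DURATION.fullmatch(spec):
        return "duration"
    if _INTERVAL.fullmatch(spec):
        return "interval"
    if len(spec.split()) == 5:
        return "cron"  # five whitespace-separated fields, not further validated
    try:
        datetime.fromisoformat(spec)
    except ValueError:
        raise ValueError(f"unrecognized schedule: {spec!r}") from None
    return "timestamp"


kinds = [classify_schedule(s) for s in ("30m", "every 2h", "0 9 * * *", "2026-02-03T14:00")]
```

Checking the cheap regex cases first keeps the ISO timestamp parse as the last resort, so malformed input produces a single clear error.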
|
||||
|
||||
### Repeat Options
|
||||
|
||||
| repeat | Behavior |
|--------|----------|
| (omitted) | One-shot schedules run once; intervals/cron run forever |
| `1` | Run once then auto-delete |
| `N` | Run N times then auto-delete |
|
||||
|
||||
### Running the Cron Daemon
|
||||
|
||||
Jobs are stored in `~/.hermes/cron/jobs.json` and executed by a scheduler:
|
||||
|
||||
```bash
|
||||
# Option 1: Built-in daemon (checks every 60 seconds)
|
||||
python cli.py --cron-daemon
|
||||
|
||||
# Option 2: System cron integration (run once per minute)
|
||||
# Add to crontab: crontab -e
|
||||
*/1 * * * * cd ~/hermes-agent && python cli.py --cron-tick-once >> ~/.hermes/cron/cron.log 2>&1
|
||||
```
|
||||
|
||||
### Job Output
|
||||
|
||||
Job outputs are saved to `~/.hermes/cron/output/{job_id}/{timestamp}.md` for review.
|
||||
|
||||
## Messaging Gateway (Telegram, Discord, WhatsApp)
|
||||
|
||||
Connect Hermes Agent to messaging platforms so you can chat from anywhere.
|
||||
|
||||
### Quick Start
|
||||
|
||||
```bash
# 1. Add your bot token to .env
echo 'TELEGRAM_BOT_TOKEN="your_token"' >> .env

# 2. Test the gateway (foreground)
./scripts/hermes-gateway run

# 3. Install as a background service
./scripts/hermes-gateway install

# 4. Manage the service
./scripts/hermes-gateway start    # Start
./scripts/hermes-gateway stop     # Stop
./scripts/hermes-gateway status   # Check status
```
|
||||
|
||||
### Supported Platforms
|
||||
|
||||
| Platform | Setup | Toolset |
|----------|-------|---------|
| Telegram | Bot via @BotFather | `hermes-telegram` |
| Discord | Bot via Developer Portal | `hermes-discord` |
| WhatsApp | Node.js bridge | `hermes-whatsapp` |
|
||||
|
||||
### Session Management
|
||||
|
||||
- Sessions persist across messages (agent remembers context)
|
||||
- Reset policies: daily (4am), idle (2 hours), or both
|
||||
- Manual reset: send `/new` or `/reset`
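
The reset policies can be expressed as a single predicate. The sketch below is an assumption-laden illustration rather than the logic in `gateway/session.py`: `should_reset` is a hypothetical name, and it assumes that "both" means a reset fires when either condition is met.

```python
from datetime import datetime, timedelta


def should_reset(
    last_activity: datetime,
    now: datetime,
    policy: str = "both",
    daily_hour: int = 4,
    idle: timedelta = timedelta(hours=2),
) -> bool:
    """Decide whether a session should start fresh before handling a new message."""
    # Most recent daily boundary (e.g. 04:00) at or before `now`
    boundary = now.replace(hour=daily_hour, minute=0, second=0, microsecond=0)
    if boundary > now:
        boundary -= timedelta(days=1)
    daily_hit = last_activity < boundary      # session predates today's boundary
    idle_hit = now - last_activity >= idle    # nothing received for the idle window
    if policy == "daily":
        return daily_hit
    if policy == "idle":
        return idle_hit
    if policy == "both":
        return daily_hit or idle_hit          # assumption: either condition resets
    raise ValueError(f"Unknown reset policy: {policy!r}")
```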
|
||||
|
||||
### Cron Job Delivery
|
||||
|
||||
Schedule tasks that deliver to specific platforms:
|
||||
|
||||
```python
schedule_cronjob(
    prompt="Check server status...",
    schedule="every 1h",
    deliver="telegram"  # or "origin", "discord", etc.
)
```
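
How a `deliver` value maps to a platform might look like the sketch below. It is hypothetical throughout: `resolve_delivery` is not a function from `gateway/delivery.py`, and it assumes that `"origin"` means "the platform the job was scheduled from".

```python
from typing import Sequence


def resolve_delivery(
    deliver: str,
    origin_platform: str,
    supported: Sequence[str] = ("telegram", "discord", "whatsapp"),
) -> str:
    """Map a `deliver` value to a concrete platform name."""
    # Assumption: "origin" routes the result back to where the job was created
    target = origin_platform if deliver == "origin" else deliver
    if target not in supported:
        raise ValueError(f"Unsupported delivery target: {target!r}")
    return target
```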
|
||||
|
||||
### CLI Commands
|
||||
|
||||
| Command | Description |
|---------|-------------|
| `/platforms` | Show gateway configuration status |
| `--gateway` | Start the gateway (CLI flag) |
|
||||
|
||||
See [docs/messaging.md](docs/messaging.md) for full setup instructions.
|
||||
|
||||
## Interactive CLI
|
||||
|
||||
The CLI provides a rich interactive experience for working with the agent.
|
||||
@@ -571,8 +359,6 @@ The CLI provides a rich interactive experience for working with the agent.
|
||||
| `/history` | Show conversation history |
|
||||
| `/save` | Save current conversation to file |
|
||||
| `/config` | Show current configuration |
|
||||
| `/cron` | Manage scheduled tasks (list, add, remove) |
|
||||
| `/platforms` | Show gateway/messaging platform status |
|
||||
| `/quit` | Exit the CLI |
|
||||
|
||||
### Configuration
|
||||
@@ -830,11 +616,6 @@ All environment variables can be configured in the `.env` file (copy from `.env.
|
||||
- `TERMINAL_SSH_PORT`: SSH port (default: `22`)
|
||||
- `TERMINAL_SSH_KEY`: Path to SSH private key (optional, uses ssh-agent if not set)
|
||||
|
||||
**Context Compression (auto-shrinks long conversations):**
|
||||
- `CONTEXT_COMPRESSION_ENABLED`: Enable auto-compression (default: `true`)
|
||||
- `CONTEXT_COMPRESSION_THRESHOLD`: Compress at this % of context limit (default: `0.85`)
|
||||
- `CONTEXT_COMPRESSION_MODEL`: Model for generating summaries (default: `google/gemini-2.0-flash-001`)
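
To make the threshold concrete: with the default of `0.85`, compression would trigger once a conversation uses 85% of the model's context window. The check could be sketched as follows; `should_compress` is a hypothetical helper, and reading the variables directly from the environment is an assumption about how the settings are consumed.

```python
import os
from typing import Mapping


def should_compress(
    used_tokens: int,
    context_limit: int,
    env: Mapping[str, str] = os.environ,
) -> bool:
    """Return True when the conversation has crossed the compression threshold."""
    enabled = env.get("CONTEXT_COMPRESSION_ENABLED", "true").lower() == "true"
    threshold = float(env.get("CONTEXT_COMPRESSION_THRESHOLD", "0.85"))
    return enabled and used_tokens >= threshold * context_limit
```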
|
||||
|
||||
**Browser Tool Configuration (agent-browser + Browserbase):**
|
||||
- `BROWSERBASE_API_KEY`: Browserbase API key for cloud browser execution
|
||||
- `BROWSERBASE_PROJECT_ID`: Browserbase project ID
|
||||
@@ -866,3 +647,13 @@ All environment variables can be configured in the `.env` file (copy from `.env.
|
||||
| `skills/` | On-demand knowledge documents |
|
||||
| `docs/` | Documentation |
|
||||
| `configs/` | Example batch run scripts |
|
||||
|
||||
# Atropos Integrations & RL Training
|
||||
|
||||
## Nomad Setup
|
||||
Follow the official Nomad deployment guide: https://developer.hashicorp.com/nomad/docs/deploy
|
||||
|
||||
## Atropos dependencies
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install -e '.[atropos]'
|
||||
|
||||
@@ -1,43 +1,66 @@
|
||||
README.md
|
||||
atropos_compatible_agent.py
|
||||
batch_runner.py
|
||||
cli.py
|
||||
local_server.py
|
||||
model_tools.py
|
||||
pyproject.toml
|
||||
run_agent.py
|
||||
toolset_distributions.py
|
||||
toolsets.py
|
||||
trajectory_compressor.py
|
||||
cron/__init__.py
|
||||
cron/jobs.py
|
||||
cron/scheduler.py
|
||||
gateway/__init__.py
|
||||
gateway/config.py
|
||||
gateway/delivery.py
|
||||
gateway/run.py
|
||||
gateway/session.py
|
||||
atropos/__init__.py
|
||||
atropos/sandbox_server.py
|
||||
atropos/agent/__init__.py
|
||||
atropos/agent/atropos_agent.py
|
||||
atropos/api/__init__.py
|
||||
atropos/api/tool_executor_server.py
|
||||
atropos/api/tool_server.py
|
||||
atropos/backends/__init__.py
|
||||
atropos/backends/base.py
|
||||
atropos/backends/modal_backend.py
|
||||
atropos/backends/nomad_backend.py
|
||||
atropos/envs/__init__.py
|
||||
atropos/envs/agent_env.py
|
||||
atropos/envs/hermes_compat_test_env.py
|
||||
atropos/envs/sandbox_terminal_smoke_env.py
|
||||
atropos/envs/swe_smith_oracle_env.py
|
||||
atropos/envs/test_env.py
|
||||
atropos/envs/toolserver_smoke_env.py
|
||||
atropos/nomad/__init__.py
|
||||
atropos/nomad/client.py
|
||||
atropos/slots/__init__.py
|
||||
atropos/slots/executor.py
|
||||
atropos/slots/pool.py
|
||||
atropos/slots/slot.py
|
||||
atropos/terminal/__init__.py
|
||||
atropos/terminal/asciinema_stream.py
|
||||
atropos/tools/__init__.py
|
||||
atropos/tools/base.py
|
||||
atropos/tools/build_registry.py
|
||||
atropos/tools/hermes_external_tools.py
|
||||
atropos/tools/sandbox_stubs.py
|
||||
atropos/tools/terminal_stateful_tool.py
|
||||
atropos/tools/tmux_tool.py
|
||||
atropos/tools/tool_executor.py
|
||||
atropos/tools/toolset_resolver.py
|
||||
hermes_agent.egg-info/PKG-INFO
|
||||
hermes_agent.egg-info/SOURCES.txt
|
||||
hermes_agent.egg-info/dependency_links.txt
|
||||
hermes_agent.egg-info/entry_points.txt
|
||||
hermes_agent.egg-info/requires.txt
|
||||
hermes_agent.egg-info/top_level.txt
|
||||
hermes_cli/__init__.py
|
||||
hermes_cli/cron.py
|
||||
hermes_cli/doctor.py
|
||||
hermes_cli/gateway.py
|
||||
hermes_cli/main.py
|
||||
hermes_cli/setup.py
|
||||
hermes_cli/status.py
|
||||
tests/test_batch_runner.py
|
||||
tests/test_checkpoint_resumption.py
|
||||
tests/test_modal_integration.py
|
||||
tests/test_modal_stress.py
|
||||
tests/test_modal_terminal.py
|
||||
tests/test_nous_api_limits.py
|
||||
tests/test_nous_api_pattern.py
|
||||
tests/test_temperature_fix.py
|
||||
tests/test_tool_call_parsing.py
|
||||
tests/test_web_tools.py
|
||||
tools/__init__.py
|
||||
tools/browser_tool.py
|
||||
tools/cronjob_tools.py
|
||||
tools/image_generation_tool.py
|
||||
tools/mixture_of_agents_tool.py
|
||||
tools/skills_tool.py
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
[console_scripts]
|
||||
hermes = hermes_cli.main:main
|
||||
hermes-agent = run_agent:main
|
||||
hermes-atropos-sandbox-smoke = atropos.envs.sandbox_terminal_smoke_env:SandboxTerminalSmokeEnv.cli
|
||||
hermes-atropos-toolserver-smoke = atropos.envs.toolserver_smoke_env:ToolServerSmokeEnv.cli
|
||||
|
||||
@@ -5,6 +5,7 @@ httpx
|
||||
rich
|
||||
tenacity
|
||||
pyyaml
|
||||
prompt_toolkit
|
||||
requests
|
||||
jinja2
|
||||
pydantic>=2.0
|
||||
@@ -14,22 +15,17 @@ litellm>=1.75.5
|
||||
typer
|
||||
platformdirs
|
||||
|
||||
[all]
|
||||
croniter
|
||||
python-telegram-bot>=20.0
|
||||
discord.py>=2.0
|
||||
|
||||
[cron]
|
||||
croniter
|
||||
[atropos]
|
||||
atroposlib @ git+https://github.com/NousResearch/atropos.git
|
||||
aiohttp
|
||||
fastapi
|
||||
uvicorn
|
||||
pyte
|
||||
|
||||
[dev]
|
||||
pytest
|
||||
pytest-asyncio
|
||||
|
||||
[messaging]
|
||||
python-telegram-bot>=20.0
|
||||
discord.py>=2.0
|
||||
|
||||
[modal]
|
||||
modal
|
||||
boto3
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
atropos
|
||||
atropos_compatible_agent
|
||||
batch_runner
|
||||
cli
|
||||
cron
|
||||
gateway
|
||||
hermes_cli
|
||||
local_server
|
||||
model_tools
|
||||
run_agent
|
||||
tools
|
||||
|
||||
@@ -71,7 +71,7 @@ def ensure_hermes_home():
|
||||
# =============================================================================
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"model": "anthropic/claude-sonnet-4.5",
|
||||
"model": "anthropic/claude-opus-4.6",
|
||||
"toolsets": ["hermes-cli"],
|
||||
"max_turns": 100,
|
||||
|
||||
@@ -91,7 +91,7 @@ DEFAULT_CONFIG = {
|
||||
"compression": {
|
||||
"enabled": True,
|
||||
"threshold": 0.85,
|
||||
"summary_model": "google/gemini-2.0-flash-001",
|
||||
"summary_model": "google/gemini-3-flash-preview",
|
||||
},
|
||||
|
||||
"display": {
|
||||
@@ -555,7 +555,7 @@ def show_config():
|
||||
print(f" Enabled: {'yes' if enabled else 'no'}")
|
||||
if enabled:
|
||||
print(f" Threshold: {compression.get('threshold', 0.85) * 100:.0f}%")
|
||||
print(f" Model: {compression.get('summary_model', 'google/gemini-2.0-flash-001')}")
|
||||
print(f" Model: {compression.get('summary_model', 'google/gemini-3-flash-preview')}")
|
||||
|
||||
# Messaging
|
||||
print()
|
||||
|
||||
@@ -58,8 +58,11 @@ def run_doctor(args):
|
||||
print(color("◆ Python Environment", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
py_version = sys.version_info
|
||||
if py_version >= (3, 10):
|
||||
if py_version >= (3, 11):
|
||||
check_ok(f"Python {py_version.major}.{py_version.minor}.{py_version.micro}")
|
||||
elif py_version >= (3, 10):
|
||||
check_ok(f"Python {py_version.major}.{py_version.minor}.{py_version.micro}")
|
||||
check_warn("Python 3.11+ recommended for RL Training tools (tinker requires >= 3.11)")
|
||||
elif py_version >= (3, 8):
|
||||
check_warn(f"Python {py_version.major}.{py_version.minor}.{py_version.micro}", "(3.10+ recommended)")
|
||||
else:
|
||||
@@ -100,7 +103,7 @@ def run_doctor(args):
|
||||
check_ok(name)
|
||||
except ImportError:
|
||||
check_fail(name, "(missing)")
|
||||
issues.append(f"Install {name}: pip install {module}")
|
||||
issues.append(f"Install {name}: uv pip install {module}")
|
||||
|
||||
for module, name in optional_packages:
|
||||
try:
|
||||
@@ -263,6 +266,39 @@ def run_doctor(args):
|
||||
except Exception as e:
|
||||
check_warn("Anthropic API", f"({e})")
|
||||
|
||||
# =========================================================================
|
||||
# Check: Submodules
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Submodules", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
# mini-swe-agent (terminal tool backend)
|
||||
mini_swe_dir = PROJECT_ROOT / "mini-swe-agent"
|
||||
if mini_swe_dir.exists() and (mini_swe_dir / "pyproject.toml").exists():
|
||||
try:
|
||||
__import__("minisweagent")
|
||||
check_ok("mini-swe-agent", "(terminal backend)")
|
||||
except ImportError:
|
||||
check_warn("mini-swe-agent found but not installed", "(run: uv pip install -e ./mini-swe-agent)")
|
||||
issues.append("Install mini-swe-agent: uv pip install -e ./mini-swe-agent")
|
||||
else:
|
||||
check_warn("mini-swe-agent not found", "(run: git submodule update --init --recursive)")
|
||||
|
||||
# tinker-atropos (RL training backend)
|
||||
tinker_dir = PROJECT_ROOT / "tinker-atropos"
|
||||
if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists():
|
||||
if py_version >= (3, 11):
|
||||
try:
|
||||
__import__("tinker_atropos")
|
||||
check_ok("tinker-atropos", "(RL training backend)")
|
||||
except ImportError:
|
||||
check_warn("tinker-atropos found but not installed", "(run: uv pip install -e ./tinker-atropos)")
|
||||
issues.append("Install tinker-atropos: uv pip install -e ./tinker-atropos")
|
||||
else:
|
||||
check_warn("tinker-atropos requires Python 3.11+", f"(current: {py_version.major}.{py_version.minor})")
|
||||
else:
|
||||
check_warn("tinker-atropos not found", "(run: git submodule update --init --recursive)")
|
||||
|
||||
# =========================================================================
|
||||
# Check: Tool Availability
|
||||
# =========================================================================
|
||||
|
||||
@@ -119,6 +119,7 @@ def cmd_uninstall(args):
|
||||
def cmd_update(args):
|
||||
"""Update Hermes Agent to the latest version."""
|
||||
import subprocess
|
||||
import shutil
|
||||
|
||||
print("🦋 Updating Hermes Agent...")
|
||||
print()
|
||||
@@ -163,13 +164,21 @@ def cmd_update(args):
|
||||
print("→ Pulling updates...")
|
||||
subprocess.run(["git", "pull", "origin", branch], cwd=PROJECT_ROOT, check=True)
|
||||
|
||||
# Reinstall Python dependencies
|
||||
# Reinstall Python dependencies (prefer uv for speed, fall back to pip)
|
||||
print("→ Updating Python dependencies...")
|
||||
venv_pip = PROJECT_ROOT / "venv" / "bin" / "pip"
|
||||
if venv_pip.exists():
|
||||
subprocess.run([str(venv_pip), "install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True)
|
||||
uv_bin = shutil.which("uv")
|
||||
if uv_bin:
|
||||
subprocess.run(
|
||||
[uv_bin, "pip", "install", "-e", ".", "--quiet"],
|
||||
cwd=PROJECT_ROOT, check=True,
|
||||
env={**os.environ, "VIRTUAL_ENV": str(PROJECT_ROOT / "venv")}
|
||||
)
|
||||
else:
|
||||
subprocess.run(["pip", "install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True)
|
||||
venv_pip = PROJECT_ROOT / "venv" / "bin" / "pip"
|
||||
if venv_pip.exists():
|
||||
subprocess.run([str(venv_pip), "install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True)
|
||||
else:
|
||||
subprocess.run(["pip", "install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True)
|
||||
|
||||
# Check for Node.js deps
|
||||
if (PROJECT_ROOT / "package.json").exists():
|
||||
|
||||
@@ -501,11 +501,12 @@ def run_setup_wizard(args):
|
||||
# =========================================================================
|
||||
print_header("Default Model")
|
||||
|
||||
current_model = config.get('model', 'anthropic/claude-sonnet-4')
|
||||
current_model = config.get('model', 'anthropic/claude-opus-4.6')
|
||||
print_info(f"Current: {current_model}")
|
||||
|
||||
model_choices = [
|
||||
"anthropic/claude-sonnet-4.5 (recommended)",
|
||||
"anthropic/claude-opus-4.6 (recommended)",
|
||||
"anthropic/claude-sonnet-4.5",
|
||||
"anthropic/claude-opus-4.5",
|
||||
"openai/gpt-5.2",
|
||||
"openai/gpt-5.2-codex",
|
||||
@@ -518,27 +519,31 @@ def run_setup_wizard(args):
|
||||
f"Keep current ({current_model})"
|
||||
]
|
||||
|
||||
model_idx = prompt_choice("Select default model:", model_choices, 10) # Default: keep current
|
||||
model_idx = prompt_choice("Select default model:", model_choices, 11) # Default: keep current
|
||||
|
||||
model_map = {
|
||||
0: "anthropic/claude-sonnet-4.5",
|
||||
1: "anthropic/claude-opus-4.5",
|
||||
2: "openai/gpt-5.2",
|
||||
3: "openai/gpt-5.2-codex",
|
||||
4: "google/gemini-3-pro-preview",
|
||||
5: "google/gemini-3-flash-preview",
|
||||
6: "z-ai/glm-4.7",
|
||||
7: "moonshotai/kimi-k2.5",
|
||||
8: "minimax/minimax-m2.1",
|
||||
0: "anthropic/claude-opus-4.6",
|
||||
1: "anthropic/claude-sonnet-4.5",
|
||||
2: "anthropic/claude-opus-4.5",
|
||||
3: "openai/gpt-5.2",
|
||||
4: "openai/gpt-5.2-codex",
|
||||
5: "google/gemini-3-pro-preview",
|
||||
6: "google/gemini-3-flash-preview",
|
||||
7: "z-ai/glm-4.7",
|
||||
8: "moonshotai/kimi-k2.5",
|
||||
9: "minimax/minimax-m2.1",
|
||||
}
|
||||
|
||||
if model_idx in model_map:
|
||||
config['model'] = model_map[model_idx]
|
||||
elif model_idx == 9: # Custom
|
||||
custom = prompt("Enter model name (e.g., anthropic/claude-sonnet-4.5)")
|
||||
# Also update LLM_MODEL in .env so it stays in sync (cli.py reads .env first)
|
||||
save_env_value("LLM_MODEL", model_map[model_idx])
|
||||
elif model_idx == 10: # Custom
|
||||
custom = prompt("Enter model name (e.g., anthropic/claude-opus-4.6)")
|
||||
if custom:
|
||||
config['model'] = custom
|
||||
# else: Keep current (model_idx == 10)
|
||||
save_env_value("LLM_MODEL", custom)
|
||||
# else: Keep current (model_idx == 11)
|
||||
|
||||
# =========================================================================
|
||||
# Step 4: Terminal Backend
|
||||
@@ -652,6 +657,32 @@ def run_setup_wizard(args):
|
||||
print_info("Modal Cloud Configuration:")
|
||||
print_info("Get credentials at: https://modal.com/settings")
|
||||
|
||||
# Check if swe-rex[modal] is installed, install if missing
|
||||
try:
|
||||
from swerex.deployment.modal import ModalDeployment
|
||||
print_info("swe-rex[modal] package: installed ✓")
|
||||
except ImportError:
|
||||
print_info("Installing required package: swe-rex[modal]...")
|
||||
import subprocess
|
||||
import shutil
|
||||
# Prefer uv for speed, fall back to pip
|
||||
uv_bin = shutil.which("uv")
|
||||
if uv_bin:
|
||||
result = subprocess.run(
|
||||
[uv_bin, "pip", "install", "swe-rex[modal]>=1.4.0"],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "pip", "install", "swe-rex[modal]>=1.4.0"],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.returncode == 0:
|
||||
print_success("swe-rex[modal] installed (includes modal + boto3)")
|
||||
else:
|
||||
print_warning("Failed to install swe-rex[modal] — install manually:")
|
||||
print_info(' uv pip install "swe-rex[modal]>=1.4.0"')
|
||||
|
||||
# Always show current status and allow reconfiguration
|
||||
current_token = get_env_value('MODAL_TOKEN_ID')
|
||||
if current_token:
|
||||
@@ -917,6 +948,24 @@ def run_setup_wizard(args):
|
||||
save_env_value("BROWSERBASE_API_KEY", api_key)
|
||||
if project_id:
|
||||
save_env_value("BROWSERBASE_PROJECT_ID", project_id)
|
||||
|
||||
# Check if Node.js dependencies are installed (required for browser tools)
|
||||
import shutil
|
||||
node_modules = PROJECT_ROOT / "node_modules" / "agent-browser"
|
||||
if not node_modules.exists() and shutil.which("npm"):
|
||||
print_info(" Installing Node.js dependencies for browser tools...")
|
||||
import subprocess
|
||||
result = subprocess.run(
|
||||
["npm", "install", "--silent"],
|
||||
capture_output=True, text=True, cwd=str(PROJECT_ROOT)
|
||||
)
|
||||
if result.returncode == 0:
|
||||
print_success(" Node.js dependencies installed")
|
||||
else:
|
||||
print_warning(" npm install failed — run manually: cd ~/.hermes/hermes-agent && npm install")
|
||||
elif not node_modules.exists():
|
||||
print_warning(" Node.js not found — browser tools require: npm install (in the hermes-agent directory)")
|
||||
|
||||
print_success(" Configured ✓")
|
||||
print()
|
||||
|
||||
@@ -950,6 +999,11 @@ def run_setup_wizard(args):
|
||||
tinker_configured = get_env_value('TINKER_API_KEY')
|
||||
wandb_configured = get_env_value('WANDB_API_KEY')
|
||||
|
||||
# Check Python version requirement upfront
|
||||
rl_python_ok = sys.version_info >= (3, 11)
|
||||
if not rl_python_ok:
|
||||
print_warning(f" Requires Python 3.11+ (current: {sys.version_info.major}.{sys.version_info.minor})")
|
||||
|
||||
if tinker_configured and wandb_configured:
|
||||
print_success(" Status: Configured ✓")
|
||||
if prompt_yes_no(" Update RL training credentials?", False):
|
||||
@@ -969,18 +1023,55 @@ def run_setup_wizard(args):
|
||||
print_warning(" Status: Not configured (tools will be disabled)")
|
||||
|
||||
if prompt_yes_no(" Set up RL Training?", False):
|
||||
print_info(" Get Tinker key at: https://tinker-console.thinkingmachines.ai/keys")
|
||||
print_info(" Get WandB key at: https://wandb.ai/authorize")
|
||||
api_key = prompt(" Tinker API key", password=True)
|
||||
if api_key:
|
||||
save_env_value("TINKER_API_KEY", api_key)
|
||||
wandb_key = prompt(" WandB API key", password=True)
|
||||
if wandb_key:
|
||||
save_env_value("WANDB_API_KEY", wandb_key)
|
||||
if api_key and wandb_key:
|
||||
print_success(" Configured ✓")
|
||||
# Check Python version before proceeding
|
||||
if not rl_python_ok:
|
||||
print_error(f" Python 3.11+ required (current: {sys.version_info.major}.{sys.version_info.minor})")
|
||||
print_info(" Upgrade Python and reinstall to enable RL training tools")
|
||||
else:
|
||||
print_warning(" Partially configured (both keys required)")
|
||||
print_info(" Get Tinker key at: https://tinker-console.thinkingmachines.ai/keys")
|
||||
print_info(" Get WandB key at: https://wandb.ai/authorize")
|
||||
api_key = prompt(" Tinker API key", password=True)
|
||||
if api_key:
|
||||
save_env_value("TINKER_API_KEY", api_key)
|
||||
wandb_key = prompt(" WandB API key", password=True)
|
||||
if wandb_key:
|
||||
save_env_value("WANDB_API_KEY", wandb_key)
|
||||
|
||||
# Check if tinker-atropos submodule is installed
|
||||
try:
|
||||
__import__("tinker_atropos")
|
||||
except ImportError:
|
||||
tinker_dir = PROJECT_ROOT / "tinker-atropos"
|
||||
if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists():
|
||||
print_info(" Installing tinker-atropos submodule...")
|
||||
import subprocess
|
||||
import shutil
|
||||
# Prefer uv for speed, fall back to pip
|
||||
uv_bin = shutil.which("uv")
|
||||
if uv_bin:
|
||||
result = subprocess.run(
|
||||
[uv_bin, "pip", "install", "-e", str(tinker_dir)],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "pip", "install", "-e", str(tinker_dir)],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.returncode == 0:
|
||||
print_success(" tinker-atropos installed")
|
||||
else:
|
||||
print_warning(" tinker-atropos install failed — run manually:")
|
||||
print_info(' uv pip install -e "./tinker-atropos"')
|
||||
else:
|
||||
print_warning(" tinker-atropos submodule not found — run:")
|
||||
print_info(" git submodule update --init --recursive")
|
||||
print_info(' uv pip install -e "./tinker-atropos"')
|
||||
|
||||
if api_key and wandb_key:
|
||||
print_success(" Configured ✓")
|
||||
else:
|
||||
print_warning(" Partially configured (both keys required)")
|
||||
|
||||
# =========================================================================
|
||||
# Save config and show summary
|
||||
|
||||
353
local_server.py
Normal file
@@ -0,0 +1,353 @@
|
||||
"""
Local OpenAI-compatible server implementation for Hermes-Agent (Atropos integration).

Extends the Atropos APIServer to work with local OpenAI-compatible APIs (e.g. vLLM, SGLang),
providing tokens_and_logprobs_completion support via client-side tokenization.
"""

import asyncio
import os
import warnings
from typing import Any, List, Optional

import openai
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.completion import Completion

from atroposlib.envs.server_handling.server_baseline import (
    APIServer,
    APIServerConfig,
    ReasoningConfig,
)


class LocalServer(APIServer):
    """
    OpenAI-compatible local server with tokens_and_logprobs support.

    Uses an OpenAI-compatible API (typically at a /v1 endpoint) and handles
    token extraction via client-side tokenization.

    Note: Many local servers don't return per-token logprobs in the standard API,
    so this implementation uses placeholder logprobs (0.0) for PoC purposes.
    For production training, use vLLM/SGLang servers that return real logprobs.
    """

    def __init__(
        self,
        config: APIServerConfig,
        tokenizer: Optional[Any] = None,
        tokenizer_name: str = "gpt2",
        reasoning_config: Optional[ReasoningConfig] = None,
    ):
        """
        Initialize the local server.

        Args:
            config: Server configuration
            tokenizer: Pre-initialized tokenizer (optional)
            tokenizer_name: Name of tokenizer to load if tokenizer not provided
            reasoning_config: Optional reasoning configuration
        """
        # Build the OpenAI client pointing to the server's /v1 endpoint
        base_url = config.base_url
        if base_url and not base_url.endswith("/v1"):
            base_url = f"{base_url.rstrip('/')}/v1"

        self.openai = openai.AsyncClient(
            api_key=config.api_key or "local",  # Local servers often ignore auth
            base_url=base_url,
            timeout=config.timeout,
        )

        # Initialize tokenizer
        if tokenizer is not None:
            self.tokenizer = tokenizer
        else:
            try:
                from transformers import AutoTokenizer  # type: ignore
            except ModuleNotFoundError as exc:
                raise ModuleNotFoundError(
                    "Missing optional dependency 'transformers'. Pass a tokenizer instance to LocalServer, "
                    "or install transformers to enable `tokenizer_name` auto-loading."
                ) from exc
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Add a simple chat template if the tokenizer doesn't have one
        # This is needed for ManagedServer's chat_completion to work
        if not hasattr(self.tokenizer, 'chat_template') or self.tokenizer.chat_template is None:
            # Simple ChatML-style template
            self.tokenizer.chat_template = (
                "{% for message in messages %}"
                "{% if message['role'] == 'system' %}<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
                "{% elif message['role'] == 'user' %}<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
                "{% elif message['role'] == 'assistant' %}<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
                "{% endif %}"
                "{% endfor %}"
                "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
            )

        super().__init__(config, reasoning_config=reasoning_config)
        # Local servers are treated as always-healthy unless a status task is enabled.
        self.server_healthy = True
|
||||
|
||||
    @classmethod
    def from_env(
        cls,
        base_url: Optional[str] = None,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        tokenizer_name: str = "gpt2",
        **kwargs,
    ) -> "LocalServer":
        """
        Create a LocalServer from environment variables (or explicit overrides).

        Env vars (checked in order):
        - base URL: ATROPOS_SERVER_BASE_URL, OPENAI_BASE_URL, LOCAL_LLM_BASE_URL, LLM_BASE_URL
        - model: ATROPOS_SERVER_MODEL, LLM_MODEL, LOCAL_LLM_MODEL
        - api key: ATROPOS_SERVER_API_KEY, OPENAI_API_KEY, LOCAL_LLM_API_KEY, LLM_API_KEY
        """
        from dotenv import load_dotenv
        load_dotenv()

        base_url = (
            base_url
            or os.getenv("ATROPOS_SERVER_BASE_URL")
            or os.getenv("OPENAI_BASE_URL")
            or os.getenv("LOCAL_LLM_BASE_URL")
            or os.getenv("LLM_BASE_URL")
            or "http://localhost:11434"
        )
        model = (
            model
            or os.getenv("ATROPOS_SERVER_MODEL")
            or os.getenv("LLM_MODEL")
            or os.getenv("LOCAL_LLM_MODEL")
            or "hermes3:8b"
        )
        api_key = (
            api_key
            or os.getenv("ATROPOS_SERVER_API_KEY")
            or os.getenv("OPENAI_API_KEY")
            or os.getenv("LOCAL_LLM_API_KEY")
            or os.getenv("LLM_API_KEY")
        )

        config = APIServerConfig(
            model_name=model,
            base_url=base_url,
            api_key=api_key or "local",
            timeout=kwargs.get("timeout", 120),
            num_max_requests_at_once=kwargs.get("num_max_requests_at_once", 4),
            num_requests_for_eval=kwargs.get("num_requests_for_eval", 4),
            health_check=False,  # Local dev servers often lack /health
        )

        return cls(config, tokenizer_name=tokenizer_name)
|
||||
|
||||
    async def check_server_status_task(self, chat_completion: bool = True):
        """
        Check if the server is healthy.

        For local development, we generally assume the server is healthy.
        """
        while True:
            try:
                # Simple health check via a minimal completion
                if chat_completion:
                    await self.openai.chat.completions.create(
                        model=self.config.model_name,
                        messages=[{"role": "user", "content": "hi"}],
                        max_tokens=1,
                    )
                else:
                    await self.openai.completions.create(
                        model=self.config.model_name,
                        prompt="hi",
                        max_tokens=1,
                    )
                self.server_healthy = True
            except Exception:
                self.server_healthy = False
            await asyncio.sleep(5)
|
||||
|
||||
    async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
        """
        Wrapper for chat completion using an OpenAI-compatible API.
        """
        assert kwargs.get("model") is not None, "Model is required!"
        assert kwargs.get("messages") is not None, "Messages are required!"

        n = kwargs.get("n", 1)

        # Some OpenAI-compatible servers don't support n > 1, so we make multiple requests.
        if n > 1:
            completion_list = await asyncio.gather(
                *[self.openai.chat.completions.create(**{**kwargs, "n": 1}) for _ in range(n)]
            )
            # Merge completions
            completions = completion_list[0]
            for c in completion_list[1:]:
                for choice in c.choices:
                    choice.index = len(completions.choices)
                    completions.choices.append(choice)
            return completions
        else:
            return await self.openai.chat.completions.create(**kwargs)
|
||||
|
||||
    async def _completion_wrapper(self, **kwargs) -> Completion:
        """
        Wrapper for completion using an OpenAI-compatible API.
        """
        assert kwargs.get("model") is not None, "Model is required!"
        assert kwargs.get("prompt") is not None, "Prompt is required!"

        n = kwargs.get("n", 1)

        # Some OpenAI-compatible servers don't support n > 1.
        if n > 1:
            completion_list = await asyncio.gather(
                *[self.openai.completions.create(**{**kwargs, "n": 1}) for _ in range(n)]
            )
            completions = completion_list[0]
            for c in completion_list[1:]:
                for choice in c.choices:
                    choice.index = len(completions.choices)
                    completions.choices.append(choice)
            return completions
        else:
            return await self.openai.completions.create(**kwargs)
|
||||
|
||||
    async def _tokens_and_logprobs_completion_wrapper(
        self, **kwargs
    ) -> tuple[List[int], List[List[int]], List[List[float]], List[str]]:
        """
        Wrapper for tokens and logprobs completion.

        Returns:
            Tuple of (prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons)

        Note: Many OpenAI-compatible local servers don't return per-token logprobs,
        so we use placeholder logprobs (0.0). For real training, use vLLM/SGLang.
        """
        model = kwargs.get("model")
        assert model is not None, "Model is required!"

        # Handle input_ids (from ManagedServer) or prompt
        if "input_ids" in kwargs:
            prompt_tokens = kwargs.pop("input_ids")
            prompt = self.tokenizer.decode(prompt_tokens)
            kwargs.pop("prompt", None)
        else:
            prompt = kwargs.pop("prompt", "")
            prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=True)

        n = kwargs.pop("n", 1)
        max_tokens = kwargs.pop("max_tokens", 256)
        temperature = kwargs.pop("temperature", 0.7)
        stop = kwargs.pop("stop", None)

        # Make completion requests
        completions = []
        for _ in range(n):
            try:
                response = await self.openai.completions.create(
                    model=model,
                    prompt=prompt,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    stop=stop,
                )
                completions.append(response)
            except Exception as e:
                # Fallback to chat completion if completion endpoint not supported
                warnings.warn(f"Completion API failed, trying chat: {e}")
                response = await self.openai.chat.completions.create(
                    model=model,
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=max_tokens,
                    temperature=temperature,
                    stop=stop,
                )
                # Stored as-is; the extraction loop below handles both response shapes
                completions.append(response)

        output_tokens_list = []
        output_logprobs_list = []
        finish_reasons = []

        for completion in completions:
            # Extract text from response
            if hasattr(completion.choices[0], "text"):
                # Completion API response
                text = completion.choices[0].text
                finish_reason = completion.choices[0].finish_reason or "stop"
            else:
                # Chat completion API response
                text = completion.choices[0].message.content or ""
                finish_reason = completion.choices[0].finish_reason or "stop"

            # Tokenize output
            output_tokens = self.tokenizer.encode(text, add_special_tokens=False)

            # Placeholder logprobs (many local servers don't provide per-token logprobs).
            # In production, use vLLM/SGLang which return real logprobs
            output_logprobs = [0.0] * len(output_tokens)

            output_tokens_list.append(output_tokens)
            output_logprobs_list.append(output_logprobs)
            finish_reasons.append(finish_reason)

        return prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons
|
||||
|
||||
def managed_server(self, tokenizer=None, track_tree: bool = False):
|
||||
"""
|
||||
Create a ManagedServer context manager for this server.
|
||||
|
||||
Args:
|
||||
tokenizer: Optional tokenizer override
|
||||
track_tree: Whether to maintain tree structure for multi-turn
|
||||
|
||||
Returns:
|
||||
ManagedServer context manager
|
||||
"""
|
||||
# ManagedServer itself is imported lazily in ManagedServerContext.__aenter__
|
||||
|
||||
return ManagedServerContext(
|
||||
self,
|
||||
tokenizer=tokenizer or self.tokenizer,
|
||||
track_tree=track_tree,
|
||||
)
|
||||
|
||||
|
||||
class ManagedServerContext:
|
||||
"""
|
||||
Context manager wrapper for ManagedServer.
|
||||
|
||||
Usage:
|
||||
async with server.managed_server(tokenizer=tokenizer) as managed:
|
||||
response = await managed.chat_completion(...)
|
||||
state = managed.get_state()
|
||||
"""
|
||||
|
||||
def __init__(self, server: LocalServer, tokenizer, track_tree: bool = False):
|
||||
self.server = server
|
||||
self.tokenizer = tokenizer
|
||||
self.track_tree = track_tree
|
||||
self.managed = None
|
||||
|
||||
async def __aenter__(self):
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
|
||||
self.managed = ManagedServer(
|
||||
self.server,
|
||||
tokenizer=self.tokenizer,
|
||||
track_tree=self.track_tree,
|
||||
)
|
||||
return self.managed
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.managed:
|
||||
self.managed.reset()
|
||||
return False
|
||||
136
memory-bank/activeContext.md
Normal file
@@ -0,0 +1,136 @@
|
||||
# Active Context
|
||||
|
||||
## Current Task: SWE Smith Oracle Env with Modal Backend
|
||||
|
||||
### Goal
|
||||
Run this command:
|
||||
```bash
|
||||
python environments/swe_smith_oracle_env.py process \
|
||||
--env.use_wandb false \
|
||||
--env.total_steps 2 \
|
||||
--env.group_size 1 \
|
||||
--env.max_items 2 \
|
||||
--env.tool_pool_mode modal \
|
||||
--env.modal_image python:3.11 \
|
||||
--env.modal_slots_per_sandbox 10 \
|
||||
--env.modal_min_sandboxes 1
|
||||
```
|
||||
|
||||
### What's Done
|
||||
1. ✅ **agent_loop.py** - Added `tool_handler` parameter
|
||||
- New param: `tool_handler=None` in `__init__`
|
||||
- When `self.tool_handler` is set, it's called INSTEAD of `handle_function_call()`
|
||||
- Signature: `async tool_handler(tool_name, args, task_id) -> str`
|
||||
- Shows `[sandbox]` instead of backend name in terminal preview
|
||||
|
||||
2. ✅ **Phase 2 ManagedServer + SGLang** - Fully working (previous session)
|
||||
|
||||
3. ✅ **hermes_base_env.py** - Sandbox routing in collect_trajectory() (THIS SESSION)
|
||||
- Refactored `collect_trajectory()` into:
|
||||
- `_use_sandbox_backend()` - checks if sandbox should be used
|
||||
- `_collect_trajectory_local()` - existing path (ToolContext + handle_function_call)
|
||||
- `_collect_trajectory_sandbox()` - NEW sandbox path with slot lifecycle
|
||||
- `_run_agent_loop()` - shared agent loop for Phase 1/2, accepts tool_handler
|
||||
- `_build_scored_item()` - shared scored item construction
|
||||
- Sandbox path:
|
||||
1. `backend.acquire(task_id)` → Slot
|
||||
2. `exec_tool` callable wrapping `backend.execute_batch([(slot, tool_name, args)])`
|
||||
3. `setup_trajectory_workspace(item, exec_tool=exec_tool)` → workspace_meta
|
||||
4. `sandbox_tool_handler` routes terminal→sandbox, other→local
|
||||
5. `_run_agent_loop(tool_handler=sandbox_tool_handler)`
|
||||
6. `verify_and_score_trajectory(item, result, exec_tool=exec_tool)`
|
||||
7. `backend.release(slot, reset_workspace=True)` in finally
|
||||
- Added `handle_function_call` import for non-terminal tool fallback
|
||||
|
||||
4. ✅ **swe_smith_oracle_env.py** - Sandbox hooks (THIS SESSION)
|
||||
- `setup_trajectory_workspace()` - bare repo cache + git worktree (ported from atropos/envs/swe_smith_oracle_env.py)
|
||||
- `verify_and_score_trajectory()` - install deps + run pytest in sandbox
|
||||
- `compute_reward()` retained for local (non-sandbox) path
|
||||
- Uses `exec_tool("bash", {"command": cmd}, timeout=600)` → `ExecutionResult`
|
||||
|
||||
5. ✅ **All tests pass**:
|
||||
- Syntax checks (ast.parse) on both files
|
||||
- Import checks (both modules import cleanly)
|
||||
- Method existence checks (all new methods present)
|
||||
- Signature checks (exec_tool, trajectory_id, workspace_meta params)
|
||||
- Backend integration (ModalSandboxConfig.from_agent_env_config, create_tool_backend)
|
||||
- `_use_sandbox_backend()` logic (True when modal+backend set, False otherwise)
|
||||
|
||||
6. ✅ **End-to-end test with Qwen 3 8B + Modal sandbox** (THIS SESSION)
|
||||
- RunPod endpoint: `0tx0ruuuo4f10c` (Qwen/Qwen3-8B via SGLang)
|
||||
- 5 terminal tool calls executed IN sandbox: `ls`, `git status`, `git log`, `cat parse.py`, `cat tests/`
|
||||
- In-sandbox verification: install deps + pytest → score=0.0 (model inspected but didn't fix)
|
||||
- Full token tracking with logprobs via Phase 2 ManagedServer
|
||||
- Key finding: the Llama-3-8B chat template silently drops the `tools=` param; Qwen 3 has full Hermes format support
|
||||
|
||||
### Current Task: Integrate Slot Pool Backend into tools/terminal_tool.py
|
||||
|
||||
#### Step 1: Add `_SlotPoolEnvironment` to `tools/terminal_tool.py`
|
||||
- New class alongside existing `_LocalEnvironment`, `_DockerEnvironment`, etc.
|
||||
- Routes through `atropos/backends/` (ModalToolBackend or NomadToolBackend)
|
||||
- N:M slot multiplexing: 5-10 sandboxes × 10 slots each = 50-100 concurrent
|
||||
- Singleton `_SlotPoolManager` (like `_ModalPoolManager`) manages backend lifecycle
|
||||
- `execute()` acquires slot → `backend.execute_batch([(slot, "bash", ...)])` → returns `{"output": ..., "returncode": ...}`
|
||||
- `cleanup()` releases slot back to pool
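
A rough sketch of the N:M multiplexing described in Step 1. Everything here is illustrative: `ToySlotPool`, `Slot`, and `run_tasks` are hypothetical names, not the real `_SlotPoolManager` or `atropos/backends/` API, and the actual pool may hand out slots differently.

```python
import asyncio
from dataclasses import dataclass


@dataclass(frozen=True)
class Slot:
    """Hypothetical handle: one workspace inside one sandbox."""
    sandbox_id: int
    slot_id: int


class ToySlotPool:
    """Toy N:M pool: num_sandboxes x slots_per_sandbox concurrent slots."""

    def __init__(self, num_sandboxes: int, slots_per_sandbox: int):
        self._free: asyncio.Queue = asyncio.Queue()
        for s in range(num_sandboxes):
            for k in range(slots_per_sandbox):
                self._free.put_nowait(Slot(s, k))
        self.capacity = num_sandboxes * slots_per_sandbox

    async def acquire(self) -> Slot:
        # Suspends the caller while every slot is in use.
        return await self._free.get()

    def release(self, slot: Slot) -> None:
        self._free.put_nowait(slot)

    @property
    def available(self) -> int:
        return self._free.qsize()


async def run_tasks(pool: ToySlotPool, n_tasks: int) -> int:
    """Run n_tasks through the pool and return the peak number of slots in use."""
    in_use = 0
    peak = 0

    async def one_task() -> None:
        nonlocal in_use, peak
        slot = await pool.acquire()
        in_use += 1
        peak = max(peak, in_use)
        try:
            await asyncio.sleep(0)  # stand-in for a command running in the sandbox
        finally:
            in_use -= 1
            pool.release(slot)

    await asyncio.gather(*(one_task() for _ in range(n_tasks)))
    return peak
```

With 5 sandboxes and 10 slots each, the pool caps concurrency at 50 regardless of how many tasks are queued; extra tasks simply wait in `acquire()`.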
|
||||
|
||||
#### Step 2: Wire into `_create_environment()`
|
||||
- `TERMINAL_ENV=slot_pool` → `_SlotPoolEnvironment(...)`
|
||||
- Sub-config: `TERMINAL_SLOT_BACKEND=modal` or `TERMINAL_SLOT_BACKEND=nomad`
|
||||
- Reuse existing `TERMINAL_MODAL_*` and Nomad env vars for configuration
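
A minimal sketch of the selection logic in Step 2. Only the variable names `TERMINAL_ENV` and `TERMINAL_SLOT_BACKEND` come from these notes; `select_terminal_backend`, the default values, and the error behaviour are assumptions made for illustration.

```python
import os
from typing import Mapping, Optional, Tuple

VALID_SLOT_BACKENDS = ("modal", "nomad")


def select_terminal_backend(
    env: Optional[Mapping[str, str]] = None,
) -> Tuple[str, Optional[str]]:
    """Return (terminal_env, slot_backend); slot_backend is None unless slot_pool."""
    env = os.environ if env is None else env
    terminal_env = env.get("TERMINAL_ENV", "local")  # default is an assumption
    if terminal_env != "slot_pool":
        return terminal_env, None
    slot_backend = env.get("TERMINAL_SLOT_BACKEND", "modal")  # default is an assumption
    if slot_backend not in VALID_SLOT_BACKENDS:
        raise ValueError(f"Unknown TERMINAL_SLOT_BACKEND: {slot_backend!r}")
    return terminal_env, slot_backend
```

Passing the mapping explicitly keeps the function testable without touching the process environment.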
|
||||
|
||||
#### Step 3: Remove redundant `atropos/tools/` files
|
||||
- DELETE: `hermes_external_tools.py`, `build_registry.py`, `sandbox_stubs.py`, `toolset_resolver.py`
|
||||
- KEEP: `base.py` (ToolCall/ToolResult types), `tool_executor.py` (batched queue), `terminal_stateful_tool.py`, `tmux_tool.py`
|
||||
|
||||
#### Step 4: Clean up `atropos/envs/` and `atropos/agent/` (defer)
|
||||
- Remove `atropos/envs/agent_env.py` → replaced by `environments/hermes_base_env.py`
|
||||
- Remove `atropos/agent/atropos_agent.py` → replaced by `environments/agent_loop.py`
|
||||
|
||||
#### Later
|
||||
- Test with Tinker trainer (blocked on billing)
|
||||
- Add more environments (endless-terminals, terminalbench 2)
|
||||
|
||||
### Key Architecture Insight
|
||||
Two separate sandbox integration points:
|
||||
1. **`tools/terminal_tool.py` with `TERMINAL_ENV=slot_pool`** — for hermes CLI, batch_runner, any code using `handle_function_call("terminal", ...)`. Uses `_SlotPoolEnvironment` which wraps `atropos/backends/`.
|
||||
2. **`environments/hermes_base_env.py` with `tool_pool_mode=modal/nomad`** — for RL environments. Uses `_collect_trajectory_sandbox()` which directly acquires slots and creates `sandbox_tool_handler`.
|
||||
|
||||
Both use the same underlying `atropos/backends/` (ModalToolBackend, NomadToolBackend) with the same slot pool.
|
||||
|
||||
### Architecture Summary
|
||||
|
||||
```
|
||||
environments/hermes_base_env.py (HermesAgentBaseEnv)
|
||||
│
|
||||
├── tool_pool_mode="default" (existing path)
|
||||
│ └── collect_trajectory() → HermesAgentLoop(tool_handler=None)
|
||||
│ → handle_function_call() → hermes terminal tool (local)
|
||||
│
|
||||
└── tool_pool_mode="modal" or "nomad" (new path)
|
||||
└── collect_trajectory():
|
||||
1. slot = backend.acquire(task_id)
|
||||
2. exec_tool = lambda routing through backend.execute_batch
|
||||
3. setup_trajectory_workspace(item, exec_tool=exec_tool) [subclass hook]
|
||||
4. HermesAgentLoop(tool_handler=sandbox_tool_handler)
|
||||
→ terminal calls → backend.execute_batch([(slot, "bash", ...)])
|
||||
5. verify_and_score_trajectory(item, result, exec_tool=exec_tool) [subclass hook]
|
||||
6. backend.release(slot, reset_workspace=True)
|
||||
|
||||
atropos/backends/modal_backend.py (ModalToolBackend)
|
||||
└── acquire(trajectory_id) → Slot
|
||||
└── execute_batch([(slot, "bash", {"command": "..."})]) → [ExecutionResult]
|
||||
└── release(slot, reset_workspace=True)
|
||||
```
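
The sandbox path in the diagram above could be sketched roughly as follows. `FakeBackend` and `FakeResult` are hypothetical in-memory stand-ins, and whether the real `acquire` / `execute_batch` / `release` are coroutines is an assumption; the point is only the ordering and the `finally` release.

```python
import asyncio
import json
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class FakeResult:
    """Stand-in for ExecutionResult (field names taken from these notes)."""
    success: bool
    output: str = ""
    error: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)


class FakeBackend:
    """Hypothetical backend that records calls instead of running them."""

    def __init__(self) -> None:
        self.log: List[str] = []

    async def acquire(self, task_id: str) -> str:
        self.log.append("acquire")
        return f"slot-for-{task_id}"

    async def execute_batch(
        self, batch: List[Tuple[str, str, Dict[str, Any]]]
    ) -> List[FakeResult]:
        self.log.extend(f"exec:{args['command']}" for _, _, args in batch)
        return [FakeResult(success=True, output="ok") for _ in batch]

    async def release(self, slot: str, reset_workspace: bool = True) -> None:
        self.log.append("release")


async def collect_in_sandbox(backend: FakeBackend, task_id: str) -> float:
    slot = await backend.acquire(task_id)
    try:
        async def exec_tool(tool: str, args: Dict[str, Any]) -> FakeResult:
            # Structured result, used by the setup/verify hooks.
            return (await backend.execute_batch([(slot, tool, args)]))[0]

        async def tool_handler(tool: str, args: Dict[str, Any], task_id: str) -> str:
            # JSON string, used by the agent loop. Here every call runs as bash
            # in the slot; the real handler routes non-terminal tools locally.
            res = await exec_tool("bash", args)
            return json.dumps({"output": res.output, "success": res.success})

        await exec_tool("bash", {"command": "git worktree add ..."})  # setup hook
        await tool_handler("terminal", {"command": "ls"}, task_id)  # one agent step
        verdict = await exec_tool("bash", {"command": "pytest"})  # verify hook
        return 1.0 if verdict.success else 0.0
    finally:
        # Runs even if setup, the agent loop, or verification raises.
        await backend.release(slot, reset_workspace=True)
```

Releasing in `finally` is what keeps a failed trajectory from leaking its slot.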
|
||||
|
||||
### Key Files to Modify
|
||||
1. `environments/hermes_base_env.py` - Add sandbox path in `collect_trajectory()`
|
||||
2. `environments/swe_smith_oracle_env.py` - Override `setup_trajectory_workspace()` and `verify_and_score_trajectory()` to use exec_tool
|
||||
|
||||
### Important Notes
|
||||
- `exec_tool` returns `ExecutionResult` (from `atropos/slots/executor.py`) with `.success`, `.output`, `.error`, `.metadata`
|
||||
- `tool_handler` returns JSON string (for agent loop message format)
|
||||
- These are DIFFERENT interfaces for different purposes:
|
||||
- `exec_tool`: used by env hooks (setup/verify) - returns structured result
|
||||
- `tool_handler`: used by agent loop - returns JSON string like hermes tools do
|
||||
- `ModalToolBackend.execute_batch` calls `_ModalSandboxWithSlots.execute`, which runs `sandbox.exec("bash", "-c", command)` on Modal
|
||||
- For the SWE env, the worktree setup pattern from `atropos/envs/swe_smith_oracle_env.py` should be reused (bare repo cache + worktree add)
|
||||
55
memory-bank/productContext.md
Normal file
@@ -0,0 +1,55 @@
|
||||
# Product Context: Hermes-Agent
|
||||
|
||||
## Why This Project Exists
|
||||
|
||||
Hermes-Agent addresses several key challenges in the AI agent space:
|
||||
|
||||
1. **Unified Tool Interface** - Provides a clean, consistent interface for LLMs to use various tools (web, terminal, browser, vision, etc.) without requiring custom integration for each model provider.
|
||||
|
||||
2. **Training Data Generation** - Enables efficient generation of high-quality tool-calling trajectories for fine-tuning LLMs, with features like batch processing, checkpointing, and trajectory compression.
|
||||
|
||||
3. **Flexible Deployment** - Supports multiple execution environments (local, Docker, Singularity, Modal, SSH) to accommodate different security and isolation requirements.
|
||||
|
||||
4. **Developer Experience** - Offers a beautiful, interactive CLI with kawaii-style feedback that makes working with AI agents enjoyable.
|
||||
|
||||
## Problems It Solves
|
||||
|
||||
### For AI Researchers
|
||||
- **Data Generation at Scale**: Parallel batch processing with content-based checkpointing for fault tolerance
|
||||
- **Clean Trajectories**: Trajectory compression to fit token budgets while preserving important information
|
||||
- **Toolset Distributions**: Probability-based tool selection for varied training data
|
||||
|
||||
### For Developers
|
||||
- **Tool Orchestration**: Logical grouping of tools into toolsets (research, development, debugging, etc.)
|
||||
- **Session Persistence**: Conversation history and session logging for debugging
|
||||
- **Multi-Model Support**: Works with any OpenAI-compatible API (OpenRouter, local models, etc.)
|
||||
|
||||
### For MLOps
|
||||
- **Skills System**: On-demand knowledge documents for specific tools/frameworks (Axolotl, vLLM, TRL, etc.)
|
||||
- **Sandboxed Execution**: Terminal commands can run in isolated environments (Docker, Singularity, Modal)
|
||||
- **Configurable Backends**: Easy switching between local and cloud execution
|
||||
|
||||
## How It Should Work
|
||||
|
||||
### User Flow (CLI)
|
||||
1. User launches `./hermes`
|
||||
2. Beautiful welcome banner displays with caduceus logo, model info, and available tools
|
||||
3. User types a natural language request
|
||||
4. Agent processes request, potentially calling tools with animated feedback
|
||||
5. Agent responds with results, conversation continues
|
||||
6. Session is automatically logged for debugging
|
||||
|
||||
### User Flow (Batch Processing)
|
||||
1. User prepares JSONL file with prompts
|
||||
2. Runs `batch_runner.py` with distribution and worker count
|
||||
3. System processes prompts in parallel, saves checkpoints
|
||||
4. Completed trajectories saved to `data/<run_name>/trajectories.jsonl`
|
||||
5. Optional: compress trajectories with `trajectory_compressor.py`
|
||||
|
||||
## User Experience Goals
|
||||
|
||||
- **Delightful Interaction**: Kawaii ASCII faces, animated spinners, cute messages
|
||||
- **Informative Feedback**: Clear progress indication during tool execution
|
||||
- **Configurable Personalities**: From "helpful" to "pirate" to "Shakespeare"
|
||||
- **Easy Configuration**: YAML config file + environment variables + CLI flags
|
||||
- **Graceful Degradation**: Missing tools/APIs don't break the system, just disable features
|
||||
134
memory-bank/progress.md
Normal file
@@ -0,0 +1,134 @@
|
||||
# Progress
|
||||
|
||||
## Current Sprint: Phase 2 ManagedServer + SGLang Working (Feb 10, 2026)
|
||||
|
||||
### ✅ Phase 2 End-to-End Pipeline VERIFIED
|
||||
Full pipeline working: GSM8k env → collect_trajectory → ManagedServer → VLLMServer (SGLang patched) → tokens + logprobs + masks.
|
||||
|
||||
Test results:
|
||||
- 212 tokens with logprobs and masks from single trajectory
|
||||
- Reward: 1.0 (correct answer)
|
||||
- ScoredDataItem has all required fields: tokens, masks, scores, advantages, ref_logprobs, messages
|
||||
- RunPod SGLang endpoint (b9zmuyn1carwya) with Llama-3-8B-Instruct
|
||||
|
||||
### Consolidation Checklist
|
||||
- [x] Install atropos `tool_call_support` branch (PR #366)
|
||||
- [x] Create `environments/gsm8k_agent_env.py` using `HermesAgentBaseEnv`
|
||||
- [x] Create `environments/agent_loop.py` with proper OpenAI-spec tool calling
|
||||
- [x] Create `environments/tool_call_parsers/` with 13 parsers
|
||||
- [x] Create `environments/patches.py` for SGLang compatibility
|
||||
- [x] Add sandbox pool support to `HermesAgentBaseEnv`
|
||||
- [x] Test Phase 1 (OpenAI server type) with Nous API — WORKS
|
||||
- [x] Test Phase 2 (ManagedServer) with RunPod SGLang — WORKS
|
||||
- [x] Port SWE env to `HermesAgentBaseEnv` with multiplexed sandboxing
|
||||
- [x] End-to-end test: Qwen 3 8B + Modal sandbox + tool calls in sandbox + pytest verification
|
||||
- [x] Add `_SlotPoolEnvironment` to `tools/terminal_tool.py` (TERMINAL_ENV=slot_pool)
|
||||
- [x] Remove redundant `atropos/tools/` files (4 of 8)
|
||||
- [ ] Remove redundant `atropos/agent/` and `atropos/envs/agent_env.py` (deferred)
|
||||
- [ ] Test end-to-end with Tinker trainer (blocked on billing)
|
||||
|
||||
### ✅ End-to-End SWE + Modal Sandbox Verified (Feb 10, 2026)
|
||||
- Qwen 3 8B on RunPod SGLang (endpoint `0tx0ruuuo4f10c`)
|
||||
- Phase 2 ManagedServer with hermes tool call parser
|
||||
- 5 terminal commands executed in Modal sandbox: ls, git status, git log, cat parse.py, cat tests/
|
||||
- In-sandbox verification: install deps + pytest → score 0.0 (model inspected but didn't fix)
|
||||
- Full token tracking with logprobs via /generate endpoint
|
||||
- Key finding: Llama-3-8B template drops tools= silently; Qwen 3 has full Hermes tool format
|
||||
|
||||
## Completed Features
|
||||
|
||||
### ✅ Phase 2 ManagedServer + SGLang (Feb 10, 2026)
|
||||
- SGLang patch in `environments/patches.py` monkey-patches VLLMServer
|
||||
- Handles SGLang's different request/response format vs VLLM
|
||||
- Handles RunPod's double-JSON wrapping
|
||||
- Full chain verified: ManagedServer → VLLMServer → _tokens_and_logprobs_comp (retry) → patched wrapper → /generate endpoint
|
||||
- SequenceNode tracking: tokens, logprobs, masked_tokens all populated
|
||||
- **Key discovery**: The AttributeError from earlier was NOT in our current code — likely from a prior code state
|
||||
|
||||
### ✅ Phase 1 OpenAI Server Mode (Feb 9-10, 2026)
|
||||
- GSM8k env works with Nous API (OpenRouter-style endpoint)
|
||||
- Terminal tool calls properly dispatched
|
||||
- Tool call parsing handled natively by server (VLLM/SGLang /v1/chat/completions)
|
||||
- Reward computation verified (math_verify for robust LaTeX comparison)
|
||||
|
||||
### ✅ Sandbox Pool Integration (Feb 10, 2026)
|
||||
- Config fields added to `HermesAgentEnvConfig` for Nomad and Modal
|
||||
- `_start_sandbox_backend()` / `_stop_sandbox_backend()` lifecycle methods
|
||||
- Optional hooks: `setup_trajectory_workspace()`, `verify_and_score_trajectory()`
|
||||
- Integrated into `env_manager()` and `process_manager()` cleanup
|
||||
|
||||
### ✅ Tool Call Parsers (Feb 9-10, 2026)
|
||||
- 13 parsers, including: hermes, llama3_json, llama4_json, qwen, qwen3_coder, deepseek_v3, deepseek_v31, glm45, glm47, mistral, kimi_k2, longcat
|
||||
- Registry pattern: `get_parser("hermes")` returns parser instance
|
||||
- Each parser: `.parse(text) → (content, tool_calls)`
|
||||
- Used by ManagedServer in Phase 2 to extract structured tool_calls from raw completion
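
A toy version of the registry pattern, for orientation only. The `get_parser(name)` and `.parse(text) → (content, tool_calls)` shapes come from these notes; `register`, `BaseToyParser`, `ToyHermesParser`, and the regex are assumptions and do not mirror the real code in `environments/tool_call_parsers/`.

```python
import json
import re
from typing import Dict, List, Optional, Tuple, Type

_REGISTRY: Dict[str, Type["BaseToyParser"]] = {}


def register(name: str):
    """Class decorator that records a parser class under `name`."""
    def deco(cls):
        _REGISTRY[name] = cls
        return cls
    return deco


class BaseToyParser:
    def parse(self, text: str) -> Tuple[Optional[str], List[dict]]:
        raise NotImplementedError


@register("hermes")
class ToyHermesParser(BaseToyParser):
    """Extracts <tool_call>{json}</tool_call> blocks from raw completion text."""

    _PATTERN = re.compile(r"<tool_call>\s*(.*?)\s*</tool_call>", re.DOTALL)

    def parse(self, text: str) -> Tuple[Optional[str], List[dict]]:
        tool_calls = []
        for raw in self._PATTERN.findall(text):
            try:
                tool_calls.append(json.loads(raw))
            except json.JSONDecodeError:
                continue  # skip malformed calls rather than failing the turn
        # Whatever is left after removing the tool-call blocks is the content.
        content = self._PATTERN.sub("", text).strip() or None
        return content, tool_calls


def get_parser(name: str) -> BaseToyParser:
    try:
        return _REGISTRY[name]()
    except KeyError:
        raise KeyError(f"No parser registered under {name!r}") from None
```

Usage: `content, tool_calls = get_parser("hermes").parse(completion_text)`.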
|
||||
|
||||
### ✅ Modal Backend Integration (Feb 8, 2026)
|
||||
- `ModalToolBackend` with slot-based multiplexing
|
||||
- Multi-profile support (CPU, GPU, high-memory)
|
||||
- Auto-scaling sandbox pool via Modal Sandboxes
|
||||
|
||||
### ✅ Main Branch Merge (Feb 9, 2026)
|
||||
- Merged 22,560 lines, 79 files, 5 conflicts resolved
|
||||
- New: hermes_cli/, file_operations, RL training tools, gateway, cron
|
||||
|
||||
### ✅ Tinker RL Training Setup (Feb 9, 2026)
|
||||
- tinker 0.12.0 + tinker-atropos installed
|
||||
- GSM8k agent config created
|
||||
- Pipeline verified: Tinker API connection works, all imports pass
|
||||
- **Blocked on billing** (Tinker 402 error)
|
||||
|
||||
### ✅ Singularity/Apptainer Sandbox (Feb 6, 2026)
|
||||
- Nomad raw_exec driver for HPC clusters
|
||||
- All sandbox operations tested and working
|
||||
|
||||
### ✅ Memory Bank (Feb 5, 2026)
|
||||
- Project documentation structure initialized
|
||||
|
||||
## What to KEEP vs REMOVE
|
||||
|
||||
### KEEP (valuable infrastructure):
|
||||
| Component | Location | Purpose |
|
||||
|-----------|----------|---------|
|
||||
| Modal backend | `atropos/backends/modal_backend.py` | Cloud sandbox pool |
|
||||
| Nomad backend | `atropos/backends/nomad_backend.py` | Docker/Singularity sandboxes |
|
||||
| Slot pool | `atropos/slots/` | Container multiplexing |
|
||||
| Nomad client | `atropos/nomad/` | Nomad API |
|
||||
| Sandbox server | `atropos/sandbox_server.py` | HTTP server in containers |
|
||||
| Dockerfile | `atropos/Dockerfile` | Container image |
|
||||
| Agent loop | `environments/agent_loop.py` | Proper OpenAI-spec tool calling |
|
||||
| Base env | `environments/hermes_base_env.py` | Phase 1/2 with parsers |
|
||||
| Tool parsers | `environments/tool_call_parsers/` | 13 model parsers |
|
||||
| SGLang patch | `environments/patches.py` | SGLang compatibility |
|
||||
|
||||
### REMOVE (redundant with environments/):
|
||||
| Component | Location | Replaced By |
|
||||
|-----------|----------|-------------|
|
||||
| ICL agent | `atropos/agent/atropos_agent.py` | `environments/agent_loop.py` |
|
||||
| AgentEnv | `atropos/envs/agent_env.py` | `environments/hermes_base_env.py` |
|
||||
| Tool registry | `atropos/tools/` | `model_tools.py` + `tools/` |
|
||||
| GSM8k ICL env | `tinker-atropos/.../gsm8k_agent.py` | `environments/gsm8k_agent_env.py` |
|
||||
|
||||
## Known Issues
|
||||
- Tinker billing (402 error) - user's payment didn't process
|
||||
- `bwrap_available: false` in Singularity containers
|
||||
- Llama-3-8B-Instruct doesn't reliably produce tool calls via Phase 2 (needs Hermes-format model)
|
||||
- Model answered GSM8k correctly but didn't actually USE the terminal tool (computed mentally)
|
||||
|
||||
## Evolution of Decisions
|
||||
|
||||
### Agent Architecture
|
||||
- **v1 (our branch)**: ICL-based agent with `<tool_call>` XML tags in system prompt
|
||||
- **v2 (Teknium's)**: Proper OpenAI-spec tool calling with `tools=` parameter
|
||||
- **Decision**: Adopt v2, consolidate into `environments/`, keep sandbox backends from v1
|
||||
|
||||
### Environment Organization
|
||||
- **Before**: Two parallel systems (`atropos/envs/` and `environments/`)
|
||||
- **After**: Single system in `environments/`, using `HermesAgentBaseEnv` as base class
|
||||
- Sandbox backends remain in `atropos/backends/` but integrate via terminal backend config
|
||||
|
||||
### Phase 2 SGLang Support
|
||||
- **Problem**: VLLMServer is hardcoded for VLLM's `/generate` format; SGLang's request/response format differs
|
||||
- **Solution**: Monkey-patch `_tokens_and_logprobs_completion_wrapper` in `environments/patches.py`
|
||||
- **Applied**: Automatically at import time via `apply_patches()` in `hermes_base_env.py`
|
||||
- **Handles**: SGLang format differences AND RunPod's double-JSON wrapping
|
||||
44
memory-bank/projectbrief.md
Normal file
@@ -0,0 +1,44 @@
|
||||
# Project Brief: Hermes-Agent
|
||||
|
||||
## Overview
|
||||
Hermes-Agent is an AI agent harness for LLMs with advanced tool-calling capabilities, featuring a flexible toolsets system for organizing and managing tools. Named after Hermes, the Greek messenger god, it serves as a bridge between human intent and AI-powered task execution.
|
||||
|
||||
## Core Requirements
|
||||
|
||||
### Primary Goals
|
||||
1. **Interactive CLI Experience** - Beautiful terminal interface with animated feedback, personalities, and session management
|
||||
2. **Flexible Tool System** - Modular tools organized into logical toolsets for different use cases
|
||||
3. **Batch Processing** - Process multiple prompts in parallel with checkpointing and statistics
|
||||
4. **Multi-Backend Support** - Support for local, Docker, Singularity, Modal, and SSH terminal backends
|
||||
5. **Training Data Generation** - Save conversation trajectories in formats suitable for LLM fine-tuning
|
||||
|
||||
### Target Users
|
||||
- AI researchers generating training data
|
||||
- Developers needing an AI assistant with tool access
|
||||
- MLOps practitioners automating workflows
|
||||
- Anyone needing a powerful CLI-based AI agent
|
||||
|
||||
## Scope
|
||||
|
||||
### In Scope
|
||||
- Interactive CLI with rich formatting and kawaii-style feedback
|
||||
- Web tools (search, extract, crawl via Firecrawl)
|
||||
- Terminal tools (command execution across multiple backends)
|
||||
- Browser automation (via agent-browser + Browserbase)
|
||||
- Vision tools (image analysis)
|
||||
- Image generation (FLUX via FAL.ai)
|
||||
- Mixture-of-Agents reasoning
|
||||
- Skills system for on-demand knowledge
|
||||
- Batch processing with parallel workers
|
||||
- Trajectory compression for training
|
||||
|
||||
### Out of Scope (Current)
|
||||
- Proactive suggestions (agent only runs on request)
|
||||
- Clipboard integration (no local system access)
|
||||
- Real-time streaming of thinking/reasoning (deferred)
|
||||
|
||||
## Success Metrics
|
||||
- Clean, maintainable tool architecture
|
||||
- Reliable tool execution with proper error handling
|
||||
- Efficient context management for long conversations
|
||||
- High-quality trajectory data for training
|
||||
267
memory-bank/systemPatterns.md
Normal file
@@ -0,0 +1,267 @@
|
||||
# System Patterns: Hermes-Agent
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ CLI (cli.py) │
|
||||
│ - Rich welcome banner with caduceus │
|
||||
│ - prompt_toolkit for input with history │
|
||||
│ - Kawaii-style feedback and personalities │
|
||||
└────────────────────────────┬────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ AIAgent (run_agent.py) │
|
||||
│ - Conversation loop with tool calling │
|
||||
│ - KawaiiSpinner for animated feedback │
|
||||
│ - Retry logic with exponential backoff │
|
||||
│ - Session logging to logs/ directory │
|
||||
└────────────────────────────┬────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Tool Routing (model_tools.py) │
|
||||
│ - get_tool_definitions() - returns tools for API calls │
|
||||
│ - handle_function_call() - dispatches to tool handlers │
|
||||
│ - Toolset filtering (enabled/disabled) │
|
||||
└────────────────────────────┬────────────────────────────────────┘
|
||||
│
|
||||
┌─────────────────┼─────────────────┐
|
||||
▼ ▼ ▼
|
||||
┌───────────┐ ┌───────────┐ ┌───────────┐
|
||||
│ Web Tools │ │ Terminal │ │ Browser │
|
||||
│ (Firecrawl)│ │ (mini-swe)│ │(agent-brw)│
|
||||
└───────────┘ └───────────┘ └───────────┘
|
||||
│ │ │
|
||||
└─────────────────┼─────────────────┘
|
||||
▼
|
||||
┌───────────────┐
|
||||
│ Toolsets │
|
||||
│ (toolsets.py)│
|
||||
│ Composition │
|
||||
└───────────────┘
|
||||
```
|
||||
|
||||
## Key Design Patterns
|
||||
|
||||
### 1. Toolset Composition Pattern
|
||||
Toolsets can include other toolsets, allowing flexible composition:
|
||||
|
||||
```python
|
||||
TOOLSETS = {
|
||||
"web": {"tools": ["web_search", "web_extract"], "includes": []},
|
||||
"debugging": {"tools": ["terminal"], "includes": ["web"]},
|
||||
"full_stack": {"tools": [], "includes": ["web", "terminal", "vision", "browser"]}
|
||||
}
|
||||
```
|
||||
|
||||
Resolution is recursive with cycle detection.
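
One way recursive resolution with cycle detection might look. This is a sketch, not the actual `resolve_toolset()` from `toolsets.py`: the toy `TOOLSETS` table, the `_visiting` parameter, and the choice to silently skip a cycle (rather than raise) are all assumptions.

```python
from typing import Dict, List, Optional, Set

# Toy table for the example; "loop_a"/"loop_b" exist only to show a cycle.
TOOLSETS: Dict[str, dict] = {
    "web": {"tools": ["web_search", "web_extract"], "includes": []},
    "debugging": {"tools": ["terminal"], "includes": ["web"]},
    "loop_a": {"tools": ["a"], "includes": ["loop_b"]},
    "loop_b": {"tools": ["b"], "includes": ["loop_a"]},
}


def resolve_toolset(name: str, _visiting: Optional[Set[str]] = None) -> List[str]:
    """Return the sorted, de-duplicated tools of `name` and everything it includes."""
    visiting = set() if _visiting is None else _visiting
    if name in visiting:
        return []  # cycle: this toolset is already being resolved higher up
    if name not in TOOLSETS:
        raise KeyError(f"Unknown toolset: {name!r}")
    visiting.add(name)
    tools = set(TOOLSETS[name]["tools"])
    for included in TOOLSETS[name]["includes"]:
        tools.update(resolve_toolset(included, visiting))
    visiting.discard(name)  # leave the current resolution path
    return sorted(tools)
```

Tracking only the current resolution path (add on entry, discard on exit) lets two sibling toolsets include the same child without being mistaken for a cycle.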
|
||||
|
||||
### 2. Graceful Degradation Pattern
|
||||
Each tool module has a `check_*_requirements()` function:
|
||||
- Tools are only loaded if requirements are met
|
||||
- Missing API keys disable tools rather than crashing the system
|
||||
- Import errors are caught and tools marked unavailable
|
||||
|
||||
```python
|
||||
try:
|
||||
from tools.web_tools import web_search_tool, check_firecrawl_api_key
|
||||
except ModuleNotFoundError:
|
||||
web_search_tool = None
|
||||
def check_firecrawl_api_key(): return False
|
||||
```
|
||||
|
||||
### 3. Session Isolation Pattern (task_id)
|
||||
Stateful tools (terminal, browser) use `task_id` to isolate concurrent sessions:
|
||||
- Each batch worker gets unique task_id
|
||||
- VMs and browser sessions are tracked per task_id
|
||||
- Cleanup functions release resources: `cleanup_vm(task_id)`, `cleanup_browser(task_id)`
|
||||
|
||||
### 4. Trajectory Format Pattern
|
||||
Conversations are saved in ShareGPT format for training:
|
||||
|
||||
```json
|
||||
{"from": "system", "value": "System prompt with <tools>...</tools>"}
|
||||
{"from": "human", "value": "User message"}
|
||||
{"from": "gpt", "value": "<think>reasoning</think>\n<tool_call>{...}</tool_call>"}
|
||||
{"from": "tool", "value": "<tool_response>{...}</tool_response>"}
|
||||
{"from": "gpt", "value": "Final response"}
|
||||
```
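
For illustration, a conversion from OpenAI-style messages to the ShareGPT turns shown above might look like this. The role mapping and the `<tool_response>` wrapping follow the example; `to_sharegpt` and `ROLE_TO_FROM` are hypothetical names, and the real saver may also render tool calls and `<think>` blocks, which this sketch does not.

```python
from typing import Dict, List

# Assumed mapping from OpenAI roles to ShareGPT "from" values.
ROLE_TO_FROM = {"system": "system", "user": "human", "assistant": "gpt", "tool": "tool"}


def to_sharegpt(messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Map role/content messages to from/value turns, wrapping tool output."""
    turns = []
    for msg in messages:
        value = msg.get("content") or ""
        if msg["role"] == "tool":
            value = f"<tool_response>{value}</tool_response>"
        turns.append({"from": ROLE_TO_FROM[msg["role"]], "value": value})
    return turns
```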
|
||||
|
||||
### 5. Ephemeral System Prompt Pattern
|
||||
Guide model behavior during data collection without saving to trajectories:
|
||||
- `ephemeral_system_prompt` influences execution
|
||||
- Only standard tool-calling system prompt saved to trajectories
|
||||
- Keeps training data clean
|
||||
|
||||
### 6. Retry with Validation Pattern
|
||||
The agent validates responses before accepting:
|
||||
- Check tool names against `valid_tool_names` set
|
||||
- Validate JSON arguments can be parsed
|
||||
- Check for content after `<think>` blocks
|
||||
- Roll back to last valid state on persistent failures
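
The first three checks could be sketched as below. `validate_tool_call` and `has_content_after_think` are hypothetical helpers, not functions from `run_agent.py`, and the requirement that arguments decode to a JSON object is an added assumption.

```python
import json
from typing import Optional, Set


def validate_tool_call(name: str, raw_args: str, valid_tool_names: Set[str]) -> Optional[str]:
    """Return None if the call is acceptable, else a short error message."""
    if name not in valid_tool_names:
        return f"unknown tool: {name}"
    try:
        parsed = json.loads(raw_args)
    except json.JSONDecodeError as exc:
        return f"invalid JSON arguments: {exc.msg}"
    if not isinstance(parsed, dict):
        return "arguments must be a JSON object"  # assumption, see lead-in
    return None


def has_content_after_think(text: str) -> bool:
    """True if something other than whitespace follows the last </think> tag.

    When there is no </think> tag, the whole text is checked instead.
    """
    return bool(text.rpartition("</think>")[2].strip())
```

A caller would retry the request on a non-None message and roll back after repeated failures.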
|
||||
|
||||
## Component Relationships
|
||||
|
||||
### AIAgent Class
|
||||
- Central orchestrator for conversations
|
||||
- Manages conversation history
|
||||
- Calls OpenAI-compatible API
|
||||
- Routes tool calls to handlers
|
||||
- Provides animated feedback (KawaiiSpinner)
|
||||
|
||||
### Tool Modules (tools/*.py)
|
||||
- Self-contained tool implementations
|
||||
- Export: handler function + check function + schema
|
||||
- Return JSON strings (never raw dicts)
|
||||
- Accept optional `task_id` for stateful tools
|
||||
|
||||
### Toolsets System (toolsets.py)
|
||||
- Defines logical groupings of tools
|
||||
- Supports composition via `includes`
|
||||
- `resolve_toolset()` recursively resolves all tools
|
||||
- `validate_toolset()` checks if name is valid
|
||||
|
||||
### Model Tools (model_tools.py)
|
||||
- Aggregates all tool definitions
|
||||
- Routes function calls to correct handlers
|
||||
- Filters tools based on enabled/disabled toolsets
|
||||
- Bridge between agent and tool implementations
|
||||
|
||||
## Critical Implementation Paths
|
||||
|
||||
### Tool Execution Flow
|
||||
1. AIAgent receives tool_calls from API response
|
||||
2. Validates tool names against `valid_tool_names`
|
||||
3. Validates JSON arguments can be parsed
|
||||
4. Calls `handle_function_call()` with tool name, args, task_id
|
||||
5. `handle_function_call()` routes to appropriate handler
|
||||
6. Tool executes, returns JSON string
|
||||
7. Result added to conversation as tool message
|
||||
8. Loop continues until natural language response
|
||||
|
||||
### Configuration Loading Flow
|
||||
1. `cli.py` calls `load_cli_config()`
|
||||
2. Loads `cli-config.yaml`, merges with defaults
|
||||
3. Sets environment variables for terminal config
|
||||
4. `AIAgent` reads env vars when initializing terminal tool
|
||||
5. Terminal tool creates appropriate backend based on `TERMINAL_ENV`
|
||||
|
||||
## RL Training Architecture (Consolidated)
|
||||
|
||||
### Environment System (`environments/`)
|
||||
|
||||
The canonical way to build agentic RL environments in Hermes-Agent:
|
||||
|
||||
```
|
||||
environments/
|
||||
├── agent_loop.py ← HermesAgentLoop: OpenAI-spec tool calling
|
||||
├── hermes_base_env.py ← HermesAgentBaseEnv: base class for all envs
|
||||
├── tool_context.py ← ToolContext: reward function tool access
|
||||
├── tool_call_parsers/ ← 13 model parsers (hermes, qwen, deepseek, etc.)
|
||||
├── terminal_test_env.py ← Example: file creation tasks
|
||||
├── hermes_swe_env.py ← SWE environment
|
||||
└── gsm8k_agent_env.py ← GSM8k with Python REPL (TODO)
|
||||
```
|
||||
|
||||
### Two-Phase Operation
|
||||
- **Phase 1 (OpenAI server)**: Native tool_calls from VLLM/SGLang/OpenRouter
|
||||
- Good for: SFT data gen, testing, evaluation
|
||||
- Server handles tool call parsing via `/v1/chat/completions`
|
||||
- **Phase 2 (ManagedServer)**: Client-side tool call parser + logprob tracking
|
||||
- Required for: RL training (exact token IDs + logprobs for GRPO/PPO)
|
||||
- Uses `/generate` endpoint for raw token output
|
||||
- Parser registry selects per-model parser (hermes, qwen, llama, etc.)
|
||||
- **Verified working** with RunPod SGLang endpoint (Feb 10, 2026)
|
||||
|
||||
### Phase 2 Call Chain (Verified)
|
||||
```
|
||||
collect_trajectory()
|
||||
→ ServerManager.managed_server(tokenizer, tool_call_parser)
|
||||
→ ManagedServer(server=VLLMServer)
|
||||
→ ManagedServer.chat_completion(messages, tools, n, max_tokens, temp)
|
||||
→ _convert_messages_to_prompt(messages, tools=tools) [apply_chat_template]
|
||||
→ _compute_input_ids(prompt, extending_node)
|
||||
→ VLLMServer.tokens_and_logprobs_completion(**kwargs) [public method]
|
||||
→ _tokens_and_logprobs_comp(stat_dict, **kwargs) [retry decorator, semaphore]
|
||||
→ _tokens_and_logprobs_completion_wrapper(**kwargs) [patched for SGLang]
|
||||
→ aiohttp POST to /generate
|
||||
→ Returns (prompt_tokens, [output_tokens], [output_logprobs], [finish_reasons])
|
||||
→ _create_sequence_node(...) [stores in current_nodes]
|
||||
→ tool_call_parser.parse(completion_text) [if parser configured]
|
||||
→ Returns ChatCompletion with tool_calls
|
||||
```
|
||||
|
||||
### SGLang Compatibility Patch (`environments/patches.py`)
|
||||
VLLMServer's `_tokens_and_logprobs_completion_wrapper` is monkey-patched to handle SGLang's different request/response format. Applied automatically at import time via `apply_patches()`.

```
SGLang request:  {"input_ids": [...], "sampling_params": {...}, "return_logprob": true}
SGLang response: {"meta_info": {"output_token_logprobs": [[logprob, token_id, text], ...]}}

VLLM request:    {"prompt": {"prompt_token_ids": [...]}, "logprobs": 0}
VLLM response:   {"logprobs": [[{token_id: logprob}]], "finish_reasons": [...]}
```
|
||||
|
||||
Also handles RunPod serverless double-JSON wrapping (response body wrapped in quotes).
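
As a rough illustration of the response handling described above, the sketch below decodes a body that may be double-encoded and then splits SGLang's `[logprob, token_id, text]` triples into parallel lists. The function names `decode_body` and `extract_sglang_logprobs` are hypothetical and are not taken from `environments/patches.py`; only the field layout comes from the format shown above.

```python
import json
from typing import Any, List, Tuple


def decode_body(raw: str) -> Any:
    """Parse a JSON body, unwrapping one extra layer if it was double-encoded."""
    data = json.loads(raw)
    # A double-wrapped body decodes to a string that itself contains JSON.
    if isinstance(data, str):
        data = json.loads(data)
    return data


def extract_sglang_logprobs(response: dict) -> Tuple[List[int], List[float]]:
    """Split SGLang's [logprob, token_id, text] triples into parallel lists."""
    triples = response["meta_info"]["output_token_logprobs"]
    token_ids = [int(t[1]) for t in triples]
    logprobs = [float(t[0]) for t in triples]
    return token_ids, logprobs
```

Checking `isinstance(data, str)` after the first decode keeps the normal single-encoded path unchanged, so the same helper could serve both RunPod serverless and direct SGLang endpoints.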
|
||||
|
||||
### Key Design: Proper Tool Calling (NOT ICL)
|
||||
```python
# CORRECT: pass tools= to chat_completion()
response = await server.chat_completion(
    messages=messages,
    tools=tool_schemas,  # ← tokenizer.apply_chat_template(tools=...) formats these
    temperature=1.0,
)
# Response has response.choices[0].message.tool_calls (structured objects)

# WRONG (old approach): embed tools in system prompt as XML
system_prompt = f"<tools>{json.dumps(tools)}</tools>"  # ← ICL, not proper training format
```
|
||||
|
||||
### Sandbox Backends (`atropos/backends/`)
|
||||
|
||||
Infrastructure for scaled sandbox execution, integrated into HermesAgentBaseEnv:
|
||||
|
||||
```
ToolBackend (Protocol)
├── NomadToolBackend → SlotPool → NomadClient + SandboxExecutor (HTTP)
│   ├── Docker driver (default)
│   └── Singularity driver (HPC)
└── ModalToolBackend → _ModalSandboxPool → modal.Sandbox.exec() (direct)
    └── _ModalMultiProfileManager (multi-profile support)
```
|
||||
|
||||
Two execution modes in HermesAgentBaseEnv (controlled by `tool_pool_mode` config):
|
||||
- `default` - Local tool execution via handle_function_call() + ToolContext
|
||||
- `modal` / `nomad` - Sandbox routing: slot acquire → setup workspace → agent loop → verify → release
|
||||
|
||||
Sandbox routing architecture:
|
||||
```
collect_trajectory()
├── tool_pool_mode="default" → _collect_trajectory_local()
│   └── _run_agent_loop(tool_handler=None) → compute_reward(ctx)
│
└── tool_pool_mode="modal"/"nomad" → _collect_trajectory_sandbox()
    ├── backend.acquire(task_id) → Slot
    ├── exec_tool = backend.execute_batch wrapper → ExecutionResult
    ├── setup_trajectory_workspace(item, exec_tool) [subclass hook]
    ├── _run_agent_loop(tool_handler=sandbox_tool_handler)
    │   └── terminal → backend.execute_batch → JSON string
    │   └── other tools → handle_function_call (local)
    ├── verify_and_score_trajectory(item, result, exec_tool) [subclass hook]
    └── backend.release(slot, reset_workspace=True) [finally]
```
|
||||
|
||||
Key interfaces:
|
||||
- `exec_tool(tool_name, args, timeout)` → `ExecutionResult` (for env hooks)
|
||||
- `tool_handler(tool_name, args, task_id)` → JSON string (for agent loop)
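
The sandbox routing above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the code in `HermesAgentBaseEnv`: `FakeBackend`, its `run_tool` method, `run_in_sandbox`, and the `run_local` callback are hypothetical stand-ins, and the real `acquire`/`execute_batch`/`release` signatures may differ.

```python
import asyncio
import json
from typing import Any, Awaitable, Callable, Dict


class FakeBackend:
    """Hypothetical stand-in for a ToolBackend (Nomad or Modal)."""

    def __init__(self) -> None:
        self.released = False

    async def acquire(self, task_id: str) -> str:
        return f"slot-{task_id}"

    async def run_tool(self, slot: str, tool_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
        return {"slot": slot, "tool": tool_name, "output": f"ran {args.get('command', '')}"}

    async def release(self, slot: str, reset_workspace: bool = True) -> None:
        self.released = True


async def run_in_sandbox(
    backend: FakeBackend,
    task_id: str,
    agent_loop: Callable[[Callable[..., Awaitable[str]]], Awaitable[Any]],
    run_local: Callable[[str, Dict[str, Any]], str],
) -> Any:
    """Acquire a slot, route terminal calls to it, and always release it."""
    slot = await backend.acquire(task_id)

    async def tool_handler(tool_name: str, args: Dict[str, Any], task_id: str) -> str:
        if tool_name == "terminal":
            # Terminal commands run inside the sandbox slot.
            result = await backend.run_tool(slot, tool_name, args)
            return json.dumps(result, ensure_ascii=False)
        # Every other tool is handled locally.
        return run_local(tool_name, args)

    try:
        return await agent_loop(tool_handler)
    finally:
        # Release the slot even if the agent loop raises.
        await backend.release(slot, reset_workspace=True)
```

Putting the release in a `finally` block mirrors the `[finally]` annotation in the diagram: a failed rollout should not leak a slot from the pool.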
|
||||
|
||||
### Training Pipeline (Tinker + Atropos)
|
||||
```
|
||||
Terminal 1: run-api (port 8000) ← Atropos Rollout API
|
||||
Terminal 2: launch_training.py (port 8001) ← Tinker Trainer + inference
|
||||
Terminal 3: environment.py serve ← Environment (rollouts)
|
||||
```
|
||||
113
memory-bank/techContext.md
Normal file
@@ -0,0 +1,113 @@
|
||||
# Technical Context: Hermes-Agent
|
||||
|
||||
## Technologies Used
|
||||
|
||||
### Core Stack
|
||||
- **Python 3.11+** - Primary language
|
||||
- **OpenAI SDK** - For LLM API interactions (OpenAI-compatible)
|
||||
- **OpenRouter** - Default LLM provider (supports multiple models)
|
||||
- **Rich** - Terminal formatting and panels
|
||||
- **prompt_toolkit** - Interactive input with history
|
||||
- **Fire** - CLI argument parsing
|
||||
- **PyYAML** - Configuration files
|
||||
- **python-dotenv** - Environment variable management
|
||||
|
||||
### Tool Dependencies
|
||||
- **Firecrawl** - Web search and extraction (`FIRECRAWL_API_KEY`)
|
||||
- **mini-swe-agent** - Terminal tool backend (local/docker/singularity/modal/ssh)
|
||||
- **agent-browser** - Browser automation (npm package)
|
||||
- **Browserbase** - Cloud browser execution (`BROWSERBASE_API_KEY`)
|
||||
- **FAL.ai** - Image generation with FLUX (`FAL_KEY`)
|
||||
- **Nous API** - Vision and MoA tools (`NOUS_API_KEY`)
|
||||
|
||||
### Optional Dependencies
|
||||
- **Modal** - Cloud compute for sandboxed environments
|
||||
- **Singularity/Apptainer** - Rootless containers (HPC environments)
|
||||
- **Docker** - Container isolation
|
||||
|
||||
## Development Setup
|
||||
|
||||
### Quick Start
|
||||
```bash
|
||||
# Clone with submodules
|
||||
git clone --recurse-submodules https://github.com/NousResearch/Hermes-Agent.git
|
||||
cd Hermes-Agent
|
||||
|
||||
# Create virtual environment
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
pip install -e ./mini-swe-agent
|
||||
|
||||
# Install browser tools (optional)
|
||||
npm install
|
||||
|
||||
# Configure environment
|
||||
cp .env.example .env
|
||||
# Edit .env with your API keys
|
||||
```
|
||||
|
||||
### Key Configuration Files
|
||||
- `.env` - API keys and secrets
|
||||
- `cli-config.yaml` - CLI configuration (model, terminal, toolsets, personalities)
|
||||
- `configs/` - Batch run scripts and configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
**Required for Full Functionality:**
|
||||
- `OPENROUTER_API_KEY` - Primary LLM access
|
||||
- `FIRECRAWL_API_KEY` - Web tools
|
||||
- `NOUS_API_KEY` - Vision and reasoning tools
|
||||
- `FAL_KEY` - Image generation
|
||||
|
||||
**Terminal Backend:**
|
||||
- `TERMINAL_ENV` - Backend type: `local`, `docker`, `singularity`, `modal`, `ssh`
|
||||
- `TERMINAL_CWD` - Working directory
|
||||
- `TERMINAL_DOCKER_IMAGE` / `TERMINAL_SINGULARITY_IMAGE` - Container images
|
||||
- `TERMINAL_SSH_HOST/USER/KEY` - SSH backend config
|
||||
- `SUDO_PASSWORD` - Optional sudo support
|
||||
|
||||
**Browser:**
|
||||
- `BROWSERBASE_API_KEY` - Browser automation
|
||||
- `BROWSERBASE_PROJECT_ID` - Browserbase project
|
||||
|
||||
## Technical Constraints
|
||||
|
||||
1. **Context Window Limits** - Long tool outputs can exhaust context; trajectory compression helps
|
||||
2. **API Rate Limits** - OpenRouter and tool APIs have rate limits; exponential backoff implemented
|
||||
3. **Tool Availability** - Tools gracefully degrade if dependencies/keys missing
|
||||
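4. **Async Compatibility** - Some tools are async; sync call sites run them via `asyncio.run()`, which raises if an event loop is already running (as under Atropos), so those call sites need a thread-based fallback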
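
For the async compatibility constraint, one possible shape for a shared helper is sketched below. The name `run_coroutine_sync` is hypothetical and does not exist in the codebase; the 120-second default simply echoes the timeout used in the `web_extract` call site in `model_tools.py`.

```python
import asyncio
import concurrent.futures
from typing import Any, Callable, Coroutine


def run_coroutine_sync(
    make_coro: Callable[[], Coroutine[Any, Any, Any]],
    timeout: float = 120.0,
) -> Any:
    """Run an async callable to completion from synchronous code."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so asyncio.run() is safe.
        return asyncio.run(make_coro())
    # A loop is already running here and asyncio.run() would raise, so run the
    # coroutine on a fresh loop in a worker thread and block for the result.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(lambda: asyncio.run(make_coro())).result(timeout=timeout)
```

Only the `get_running_loop()` probe sits inside the `try`, so a `RuntimeError` raised by the tool itself is not mistaken for "no running loop" and the tool is not executed a second time. Note that blocking in this way still stalls the calling event loop until the worker finishes.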
|
||||
|
||||
## Dependency Graph
|
||||
|
||||
```
tools/*.py → tools/__init__.py → model_tools.py → toolsets.py → toolset_distributions.py
                                       ↑
run_agent.py ──────────────────────────┘
cli.py → run_agent.py (uses AIAgent with quiet_mode=True)
batch_runner.py → run_agent.py + toolset_distributions.py
```
|
||||
|
||||
## Tool Usage Patterns
|
||||
|
||||
### Adding a New Tool
|
||||
1. Create `tools/your_tool.py` with handler + requirements check
|
||||
2. Export in `tools/__init__.py`
|
||||
3. Register in `model_tools.py` (definitions + handler routing)
|
||||
4. Add to toolset in `toolsets.py`
|
||||
5. Optionally add to `toolset_distributions.py` for batch processing
|
||||
|
||||
### Tool Handler Pattern
|
||||
```python
import json
from typing import Optional


def your_tool(param: str, task_id: Optional[str] = None) -> str:
    """Execute tool and return JSON string result."""
    try:
        # Placeholder result; a real tool would do its work with `param` here.
        result = {"success": True, "data": "..."}
        return json.dumps(result, ensure_ascii=False)
    except Exception as e:
        return json.dumps({"error": str(e)}, ensure_ascii=False)
```
|
||||
|
||||
All tool handlers MUST return a JSON string, never raw dicts.
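
One way to enforce the JSON-string rule is a small decorator, sketched here as an assumption rather than something the repository provides. The name `returns_json_string` is hypothetical; the error and success shapes follow the handler pattern above.

```python
import json
from functools import wraps
from typing import Any, Callable


def returns_json_string(handler: Callable[..., Any]) -> Callable[..., str]:
    """Wrap a handler so callers always receive a valid JSON string."""

    @wraps(handler)
    def wrapper(*args: Any, **kwargs: Any) -> str:
        try:
            result = handler(*args, **kwargs)
        except Exception as exc:
            return json.dumps({"error": str(exc)}, ensure_ascii=False)
        if isinstance(result, str):
            try:
                json.loads(result)
                return result  # Already a JSON string; pass it through.
            except json.JSONDecodeError:
                # Plain text is wrapped rather than passed through raw.
                return json.dumps({"success": True, "data": result}, ensure_ascii=False)
        # Dicts, lists, and other serializable values are encoded here.
        return json.dumps(result, ensure_ascii=False)

    return wrapper
```

A handler that accidentally returns a raw dict would then still reach the agent loop as a JSON string.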
|
||||
Submodule mini-swe-agent updated: 07aa6a7385...ee36b3d4e5
134
modal_profiles.yaml.example
Normal file
@@ -0,0 +1,134 @@
|
||||
# Modal Sandbox Profiles Configuration
|
||||
# =====================================
|
||||
# This file defines different sandbox profiles for heterogeneous workloads.
|
||||
# Copy to modal_profiles.yaml and customize as needed.
|
||||
#
|
||||
# Usage:
|
||||
# terminal_tool("python train.py", profile="pytorch-gpu")
|
||||
# terminal_tool("npm test", profile="node")
|
||||
#
|
||||
# Each profile can specify:
|
||||
# - image: Docker image to use
|
||||
# - gpu: GPU type (null, "T4", "A10G", "A100", "H100")
|
||||
# - cpu: CPU cores (float)
|
||||
# - memory: Memory in MB
|
||||
# - min_pool: Minimum warm sandboxes (cost vs latency tradeoff)
|
||||
# - max_pool: Maximum sandboxes (hard cost cap)
|
||||
# - idle_timeout: Server-side auto-cleanup in seconds
|
||||
# - max_lifetime: Maximum sandbox lifetime in seconds
|
||||
# - scale_down_idle: Client-side scale-down threshold in seconds
|
||||
# - workdir: Working directory inside container
|
||||
# - secrets: List of Modal Secret names to inject (created via dashboard/CLI)
|
||||
# - env_vars: Dict of environment variables to pass directly
|
||||
# - use_dotenv: If true, loads local .env file into sandbox
|
||||
#
|
||||
# SECRETS SETUP:
|
||||
# Create secrets via Modal dashboard or CLI:
|
||||
# modal secret create huggingface-token HF_TOKEN=hf_xxx
|
||||
# modal secret create openai-key OPENAI_API_KEY=sk-xxx
|
||||
# Then reference by name in profile's secrets list.
|
||||
|
||||
# Default profile used when no profile specified
|
||||
default_profile: default
|
||||
|
||||
profiles:
|
||||
# Default Python environment - good for most tasks
|
||||
default:
|
||||
image: python:3.11
|
||||
gpu: null
|
||||
cpu: 1.0
|
||||
memory: 2048
|
||||
min_pool: 1 # Keep 1 warm for fast response
|
||||
max_pool: 5
|
||||
idle_timeout: 120 # Modal terminates if idle 2 min
|
||||
max_lifetime: 3600 # Max 1 hour
|
||||
scale_down_idle: 180
|
||||
workdir: /workspace
|
||||
secrets: [] # Add secret names here: ["my-api-keys"]
|
||||
env_vars: {} # Add env vars here: {DEBUG: "1"}
|
||||
use_dotenv: false # Set to true to load local .env
|
||||
|
||||
# PyTorch with GPU for ML training/inference
|
||||
pytorch-gpu:
|
||||
image: pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
|
||||
gpu: T4 # Options: T4, A10G, A100, H100
|
||||
cpu: 4.0
|
||||
memory: 16384 # 16GB
|
||||
min_pool: 0 # Don't keep GPU sandboxes warm (expensive!)
|
||||
max_pool: 2
|
||||
idle_timeout: 60 # Shorter idle timeout for GPU (cost)
|
||||
max_lifetime: 1800 # 30 min max for GPU tasks
|
||||
scale_down_idle: 60
|
||||
workdir: /workspace
|
||||
# ML-specific secrets
|
||||
secrets:
|
||||
- huggingface-token # HF_TOKEN env var
|
||||
- wandb-key # WANDB_API_KEY env var
|
||||
env_vars:
|
||||
CUDA_VISIBLE_DEVICES: "0"
|
||||
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
|
||||
|
||||
# High-end GPU for large models
|
||||
pytorch-a100:
|
||||
image: pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
|
||||
gpu: A100
|
||||
cpu: 8.0
|
||||
memory: 65536 # 64GB
|
||||
min_pool: 0
|
||||
max_pool: 1 # Only 1 at a time (very expensive)
|
||||
idle_timeout: 30
|
||||
max_lifetime: 3600
|
||||
scale_down_idle: 30
|
||||
workdir: /workspace
|
||||
|
||||
# Node.js for JavaScript/TypeScript tasks
|
||||
node:
|
||||
image: node:18
|
||||
gpu: null
|
||||
cpu: 1.0
|
||||
memory: 2048
|
||||
min_pool: 0 # Create on-demand
|
||||
max_pool: 3
|
||||
idle_timeout: 120
|
||||
max_lifetime: 3600
|
||||
scale_down_idle: 180
|
||||
workdir: /workspace
|
||||
|
||||
# High memory for data processing
|
||||
high-memory:
|
||||
image: python:3.11
|
||||
gpu: null
|
||||
cpu: 4.0
|
||||
memory: 32768 # 32GB
|
||||
min_pool: 0
|
||||
max_pool: 2
|
||||
idle_timeout: 120
|
||||
max_lifetime: 3600
|
||||
scale_down_idle: 180
|
||||
workdir: /workspace
|
||||
|
||||
# Rust development environment
|
||||
rust:
|
||||
image: rust:1.75
|
||||
gpu: null
|
||||
cpu: 2.0
|
||||
memory: 4096
|
||||
min_pool: 0
|
||||
max_pool: 2
|
||||
idle_timeout: 120
|
||||
max_lifetime: 3600
|
||||
scale_down_idle: 180
|
||||
workdir: /workspace
|
||||
|
||||
# Go development environment
|
||||
golang:
|
||||
image: golang:1.21
|
||||
gpu: null
|
||||
cpu: 2.0
|
||||
memory: 4096
|
||||
min_pool: 0
|
||||
max_pool: 2
|
||||
idle_timeout: 120
|
||||
max_lifetime: 3600
|
||||
scale_down_idle: 180
|
||||
workdir: /workspace
|
||||
180
model_tools.py
@@ -700,13 +700,21 @@ def get_file_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"description": "Read a file with pagination support. Returns content with line numbers in 'LINE_NUM|CONTENT' format. For binary files (images), returns base64-encoded data. If file not found, suggests similar filenames.",
|
||||
"description": (
|
||||
"Read a file with pagination support. Preferred over 'cat' in the terminal because it "
|
||||
"provides line numbers, handles binary/image files, and suggests similar filenames if "
|
||||
"the file is not found.\n\n"
|
||||
"**Output format:** Each line is returned as 'LINE_NUM|CONTENT' for easy reference.\n"
|
||||
"**Binary files:** Detected automatically; images (png/jpg/gif/webp) are returned as base64 with MIME type and dimensions.\n"
|
||||
"**Large files:** Use offset and limit to paginate. The response includes total line count and a hint for the next page.\n"
|
||||
"**Paths:** Supports absolute paths, relative paths (from working directory), and ~ expansion."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to read (absolute or relative)"
|
||||
"description": "Path to the file to read (absolute, relative, or ~/path)"
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
@@ -729,17 +737,25 @@ def get_file_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "write_file",
|
||||
"description": "Write content to a file. Creates parent directories automatically. Returns bytes written and lint check results for supported languages.",
|
||||
"description": (
|
||||
"Write content to a file, completely replacing any existing content. Creates parent "
|
||||
"directories automatically if they don't exist. Preferred over 'echo' or heredoc in the "
|
||||
"terminal because it safely handles special characters, newlines, and shell metacharacters "
|
||||
"without escaping issues.\n\n"
|
||||
"**Important:** This OVERWRITES the entire file. To make targeted edits to an existing file, "
|
||||
"use the 'patch' tool instead.\n"
|
||||
"**Paths:** Supports absolute paths, relative paths, and ~ expansion."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to write (will be created if doesn't exist)"
|
||||
"description": "Path to the file to write (will be created if it doesn't exist, overwritten if it does)"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Content to write to the file"
|
||||
"description": "Complete content to write to the file"
|
||||
}
|
||||
},
|
||||
"required": ["path", "content"]
|
||||
@@ -750,36 +766,48 @@ def get_file_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "patch",
|
||||
"description": "Modify files using either simple string replacement or V4A patch format. Mode 'replace' does find-and-replace with fuzzy matching. Mode 'patch' applies multi-file changes using V4A format (*** Begin/End Patch). Auto-runs syntax checks on modified files.",
|
||||
"description": (
|
||||
"Modify existing files using targeted edits. Preferred over 'sed' or manual rewriting because "
|
||||
"it uses intelligent fuzzy matching that tolerates minor whitespace and indentation differences, "
|
||||
"and auto-runs syntax checks (Python, JS, TS, Go, Rust) after editing.\n\n"
|
||||
"**Replace mode (recommended):** Find a unique string in the file and replace it. Uses a "
|
||||
"9-strategy fuzzy matching chain (exact → line-trimmed → whitespace-normalized → "
|
||||
"indentation-flexible → context-aware) so small formatting differences won't cause failures. "
|
||||
"Returns a unified diff showing exactly what changed.\n\n"
|
||||
"**Patch mode:** Apply multi-file changes using V4A patch format for large-scale edits across "
|
||||
"multiple files in one call.\n\n"
|
||||
"**Auto-lint:** After every edit, automatically runs syntax checks and reports errors so you "
|
||||
"can fix them immediately."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["replace", "patch"],
|
||||
"description": "Edit mode: 'replace' for string replacement, 'patch' for V4A patch format",
|
||||
"description": "Edit mode: 'replace' for targeted find-and-replace, 'patch' for V4A multi-file patches",
|
||||
"default": "replace"
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path (required for 'replace' mode)"
|
||||
"description": "File path to edit (required for 'replace' mode)"
|
||||
},
|
||||
"old_string": {
|
||||
"type": "string",
|
||||
"description": "Text to find and replace (required for 'replace' mode). Must be unique in file unless replace_all=true"
|
||||
"description": "Text to find in the file (required for 'replace' mode). Must be unique in the file unless replace_all=true. Include enough surrounding context to ensure uniqueness."
|
||||
},
|
||||
"new_string": {
|
||||
"type": "string",
|
||||
"description": "Replacement text (required for 'replace' mode)"
|
||||
"description": "Replacement text (required for 'replace' mode). Can be empty string to delete the matched text."
|
||||
},
|
||||
"replace_all": {
|
||||
"type": "boolean",
|
||||
"description": "Replace all occurrences instead of requiring unique match (default: false)",
|
||||
"description": "Replace all occurrences instead of requiring a unique match (default: false)",
|
||||
"default": False
|
||||
},
|
||||
"patch": {
|
||||
"type": "string",
|
||||
"description": "V4A format patch content (required for 'patch' mode). Format: *** Begin Patch / *** Update File: path / @@ context @@ / -removed / +added / *** End Patch"
|
||||
"description": "V4A format patch content (required for 'patch' mode). Format:\n*** Begin Patch\n*** Update File: path/to/file\n@@ context hint @@\n context line\n-removed line\n+added line\n*** End Patch"
|
||||
}
|
||||
},
|
||||
"required": ["mode"]
|
||||
@@ -790,7 +818,16 @@ def get_file_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"description": "Search for content in files or search for files by name. Use target='content' to search inside files (like grep), or target='files' to find files by name pattern (like glob/find). Results sorted by modification time (newest first).",
|
||||
"description": (
|
||||
"Search for content inside files or find files by name. Preferred over 'grep' or 'find' "
|
||||
"in the terminal because it uses ripgrep (fast) with automatic fallback to grep, handles "
|
||||
"pagination, and returns structured results sorted by modification time (newest first).\n\n"
|
||||
"**Content search (target='content'):** Regex-powered search inside files with optional "
|
||||
"file type filtering and context lines. Three output modes: full matches with line numbers, "
|
||||
"file paths only, or match counts per file.\n\n"
|
||||
"**File search (target='files'):** Find files by glob pattern (e.g., '*.py', '*config*'). "
|
||||
"Results sorted by modification time so recently changed files appear first."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -801,12 +838,12 @@ def get_file_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"target": {
|
||||
"type": "string",
|
||||
"enum": ["content", "files"],
|
||||
"description": "Search mode: 'content' searches inside files, 'files' searches for files by name",
|
||||
"description": "Search mode: 'content' searches inside files (like grep/rg), 'files' searches for files by name (like find/glob)",
|
||||
"default": "content"
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Directory or file to search in (default: current directory)",
|
||||
"description": "Directory or file to search in (default: current working directory)",
|
||||
"default": "."
|
||||
},
|
||||
"file_glob": {
|
||||
@@ -815,7 +852,7 @@ def get_file_tool_definitions() -> List[Dict[str, Any]]:
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results (default: 50)",
|
||||
"description": "Maximum number of results to return (default: 50)",
|
||||
"default": 50
|
||||
},
|
||||
"offset": {
|
||||
@@ -826,12 +863,12 @@ def get_file_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"output_mode": {
|
||||
"type": "string",
|
||||
"enum": ["content", "files_only", "count"],
|
||||
"description": "For target='content': 'content' shows matches, 'files_only' shows file paths, 'count' shows match counts per file",
|
||||
"description": "Output format for content search: 'content' shows matching lines with line numbers, 'files_only' lists file paths, 'count' shows match counts per file",
|
||||
"default": "content"
|
||||
},
|
||||
"context": {
|
||||
"type": "integer",
|
||||
"description": "Lines of context around matches (only for target='content', output_mode='content')",
|
||||
"description": "Number of lines to show before and after each match (only for target='content', output_mode='content')",
|
||||
"default": 0
|
||||
}
|
||||
},
|
||||
@@ -909,6 +946,53 @@ def get_all_tool_names() -> List[str]:
|
||||
return tool_names
|
||||
|
||||
|
||||
# Master mapping of every tool name → its toolset.
|
||||
# This is the single source of truth for all valid tool names in the system.
|
||||
# Import TOOL_TO_TOOLSET_MAP from here whenever you need to check valid tools.
|
||||
TOOL_TO_TOOLSET_MAP = {
|
||||
"web_search": "web_tools",
|
||||
"web_extract": "web_tools",
|
||||
"terminal": "terminal_tools",
|
||||
"vision_analyze": "vision_tools",
|
||||
"mixture_of_agents": "moa_tools",
|
||||
"image_generate": "image_tools",
|
||||
# Skills tools
|
||||
"skills_categories": "skills_tools",
|
||||
"skills_list": "skills_tools",
|
||||
"skill_view": "skills_tools",
|
||||
# Browser automation tools
|
||||
"browser_navigate": "browser_tools",
|
||||
"browser_snapshot": "browser_tools",
|
||||
"browser_click": "browser_tools",
|
||||
"browser_type": "browser_tools",
|
||||
"browser_scroll": "browser_tools",
|
||||
"browser_back": "browser_tools",
|
||||
"browser_press": "browser_tools",
|
||||
"browser_close": "browser_tools",
|
||||
"browser_get_images": "browser_tools",
|
||||
"browser_vision": "browser_tools",
|
||||
# Cronjob management tools
|
||||
"schedule_cronjob": "cronjob_tools",
|
||||
"list_cronjobs": "cronjob_tools",
|
||||
"remove_cronjob": "cronjob_tools",
|
||||
# RL Training tools
|
||||
"rl_list_environments": "rl_tools",
|
||||
"rl_select_environment": "rl_tools",
|
||||
"rl_get_current_config": "rl_tools",
|
||||
"rl_edit_config": "rl_tools",
|
||||
"rl_start_training": "rl_tools",
|
||||
"rl_check_status": "rl_tools",
|
||||
"rl_stop_training": "rl_tools",
|
||||
"rl_get_results": "rl_tools",
|
||||
"rl_list_runs": "rl_tools",
|
||||
# File manipulation tools
|
||||
"read_file": "file_tools",
|
||||
"write_file": "file_tools",
|
||||
"patch": "file_tools",
|
||||
"search": "file_tools",
|
||||
}
|
||||
|
||||
|
||||
def get_toolset_for_tool(tool_name: str) -> str:
|
||||
"""
|
||||
Get the toolset that a tool belongs to.
|
||||
@@ -919,50 +1003,7 @@ def get_toolset_for_tool(tool_name: str) -> str:
|
||||
Returns:
|
||||
str: Name of the toolset, or "unknown" if not found
|
||||
"""
|
||||
toolset_mapping = {
|
||||
"web_search": "web_tools",
|
||||
"web_extract": "web_tools",
|
||||
"terminal": "terminal_tools",
|
||||
"vision_analyze": "vision_tools",
|
||||
"mixture_of_agents": "moa_tools",
|
||||
"image_generate": "image_tools",
|
||||
# Skills tools
|
||||
"skills_categories": "skills_tools",
|
||||
"skills_list": "skills_tools",
|
||||
"skill_view": "skills_tools",
|
||||
# Browser automation tools
|
||||
"browser_navigate": "browser_tools",
|
||||
"browser_snapshot": "browser_tools",
|
||||
"browser_click": "browser_tools",
|
||||
"browser_type": "browser_tools",
|
||||
"browser_scroll": "browser_tools",
|
||||
"browser_back": "browser_tools",
|
||||
"browser_press": "browser_tools",
|
||||
"browser_close": "browser_tools",
|
||||
"browser_get_images": "browser_tools",
|
||||
"browser_vision": "browser_tools",
|
||||
# Cronjob management tools
|
||||
"schedule_cronjob": "cronjob_tools",
|
||||
"list_cronjobs": "cronjob_tools",
|
||||
"remove_cronjob": "cronjob_tools",
|
||||
# RL Training tools
|
||||
"rl_list_environments": "rl_tools",
|
||||
"rl_select_environment": "rl_tools",
|
||||
"rl_get_current_config": "rl_tools",
|
||||
"rl_edit_config": "rl_tools",
|
||||
"rl_start_training": "rl_tools",
|
||||
"rl_check_status": "rl_tools",
|
||||
"rl_stop_training": "rl_tools",
|
||||
"rl_get_results": "rl_tools",
|
||||
"rl_list_runs": "rl_tools",
|
||||
# File manipulation tools
|
||||
"read_file": "file_tools",
|
||||
"write_file": "file_tools",
|
||||
"patch": "file_tools",
|
||||
"search": "file_tools",
|
||||
}
|
||||
|
||||
return toolset_mapping.get(tool_name, "unknown")
|
||||
return TOOL_TO_TOOLSET_MAP.get(tool_name, "unknown")
|
||||
|
||||
|
||||
def get_tool_definitions(
|
||||
@@ -1191,8 +1232,19 @@ def handle_web_function_call(function_name: str, function_args: Dict[str, Any])
|
||||
urls = function_args.get("urls", [])
|
||||
# Limit URLs to prevent abuse
|
||||
urls = urls[:5] if isinstance(urls, list) else []
|
||||
# Run async function in event loop
|
||||
return asyncio.run(web_extract_tool(urls, "markdown"))
|
||||
# Run async function -- use existing loop if available (Atropos),
|
||||
# otherwise create one (normal CLI)
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
# Already in an async context (Atropos) -- run in a thread
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
return pool.submit(
|
||||
lambda: asyncio.run(web_extract_tool(urls, "markdown"))
|
||||
).result(timeout=120)
|
||||
except RuntimeError:
|
||||
# No running loop (normal CLI) -- use asyncio.run directly
|
||||
return asyncio.run(web_extract_tool(urls, "markdown"))
|
||||
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown web function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
37
nomad-dev.hcl
Normal file
@@ -0,0 +1,37 @@
|
||||
# Nomad Development Configuration (Hermes-Agent)
|
||||
# Run with: nomad agent -dev -config=nomad-dev.hcl
|
||||
#
|
||||
# This is intended for local development only.
|
||||
|
||||
client {
|
||||
enabled = true
|
||||
|
||||
options {
|
||||
# Enable Docker volume mounts for persistent slot workspaces
|
||||
"docker.volumes.enabled" = "true"
|
||||
}
|
||||
}
|
||||
|
||||
# Docker driver plugin configuration
|
||||
plugin "docker" {
|
||||
config {
|
||||
# CRITICAL: Enable volume mounts
|
||||
volumes {
|
||||
enabled = true
|
||||
}
|
||||
|
||||
# Allow privileged containers if needed
|
||||
allow_privileged = false
|
||||
|
||||
# Garbage collection settings
|
||||
gc {
|
||||
image = true
|
||||
# NOTE: For local dev we often rely on locally built images like `atropos-sandbox:local`.
|
||||
# A short image GC delay can delete these between runs, causing confusing "Failed to pull"
|
||||
# crash loops. Keep this comfortably long; tighten it for CI/production if needed.
|
||||
image_delay = "24h"
|
||||
container = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
31
nomad-singularity.hcl
Normal file
@@ -0,0 +1,31 @@
|
||||
# Nomad Configuration for Singularity/Apptainer Sandbox
|
||||
# Run with: nomad agent -dev -config=nomad-singularity.hcl
|
||||
#
|
||||
# This uses the raw_exec driver to run Apptainer containers.
|
||||
# Suitable for HPC environments where Docker cannot run without sudo.
|
||||
|
||||
client {
|
||||
enabled = true
|
||||
|
||||
options {
|
||||
# Enable raw_exec driver for Singularity/Apptainer
|
||||
"driver.raw_exec.enable" = "1"
|
||||
}
|
||||
}
|
||||
|
||||
# raw_exec driver plugin configuration
|
||||
plugin "raw_exec" {
|
||||
config {
|
||||
enabled = true
|
||||
}
|
||||
}
|
||||
|
||||
# Optional: If you have the nomad-driver-singularity plugin installed,
|
||||
# uncomment the following instead of using raw_exec:
|
||||
# plugin "singularity" {
|
||||
# config {
|
||||
# enabled = true
|
||||
# # Allow bind mounts
|
||||
# bind_paths = ["/tmp", "/var/tmp"]
|
||||
# }
|
||||
# }
|
||||
@@ -19,9 +19,12 @@ dependencies = [
|
||||
"rich",
|
||||
"tenacity",
|
||||
"pyyaml",
|
||||
"prompt_toolkit",
|
||||
"requests",
|
||||
"jinja2",
|
||||
"pydantic>=2.0",
|
||||
# Interactive CLI (prompt_toolkit is used directly by cli.py)
|
||||
"prompt_toolkit",
|
||||
# Tools
|
||||
"firecrawl-py",
|
||||
"fal-client",
|
||||
@@ -32,19 +35,51 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
modal = ["modal", "boto3"]
|
||||
modal = ["swe-rex[modal]>=1.4.0"]
|
||||
dev = ["pytest", "pytest-asyncio"]
|
||||
messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0"]
|
||||
messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0", "aiohttp>=3.9.0"]
|
||||
cron = ["croniter"]
|
||||
cli = ["simple-term-menu"]
|
||||
all = ["croniter", "python-telegram-bot>=20.0", "discord.py>=2.0", "simple-term-menu"]
|
||||
# Install Atropos + Tinker training integration from source.
|
||||
# Uses tool_call_support branch for ManagedServer tool calling (PR #366).
|
||||
atropos = [
|
||||
"atroposlib @ git+https://github.com/NousResearch/atropos.git@tool_call_support",
|
||||
"tinker @ git+https://github.com/thinking-machines-lab/tinker.git",
|
||||
# Atropos integration runtime deps (kept optional for Hermes-only users)
|
||||
"aiohttp",
|
||||
"fastapi",
|
||||
"uvicorn",
|
||||
"pyte",
|
||||
"torch",
|
||||
"wandb",
|
||||
"math-verify",
|
||||
]
|
||||
all = [
|
||||
"hermes-agent[modal]",
|
||||
"hermes-agent[messaging]",
|
||||
"hermes-agent[cron]",
|
||||
"hermes-agent[cli]",
|
||||
"hermes-agent[dev]",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
hermes = "hermes_cli.main:main"
|
||||
hermes-agent = "run_agent:main"
|
||||
hermes-atropos-sandbox-smoke = "atropos.envs.sandbox_terminal_smoke_env:SandboxTerminalSmokeEnv.cli"
|
||||
hermes-atropos-toolserver-smoke = "atropos.envs.toolserver_smoke_env:ToolServerSmokeEnv.cli"
|
||||
|
||||
[tool.setuptools]
|
||||
py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajectory_compressor", "toolset_distributions", "cli"]
|
||||
py-modules = [
|
||||
"run_agent",
|
||||
"model_tools",
|
||||
"toolsets",
|
||||
"batch_runner",
|
||||
"trajectory_compressor",
|
||||
"toolset_distributions",
|
||||
"atropos_compatible_agent",
|
||||
"local_server",
|
||||
"cli",
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["tools", "hermes_cli", "gateway", "cron"]
|
||||
include = ["tools", "hermes_cli", "gateway", "cron", "atropos", "atropos.*"]
|
||||
|
||||
@@ -6,6 +6,10 @@ httpx
|
||||
rich
|
||||
tenacity
|
||||
prompt_toolkit
|
||||
pyyaml
|
||||
requests
|
||||
jinja2
|
||||
pydantic>=2.0
|
||||
|
||||
# Web tools
|
||||
firecrawl-py
|
||||
@@ -15,10 +19,6 @@ fal-client
|
||||
|
||||
# mini-swe-agent dependencies (for terminal tool)
|
||||
# Note: Install mini-swe-agent itself with: pip install -e ./mini-swe-agent
|
||||
pyyaml
|
||||
requests
|
||||
jinja2
|
||||
pydantic>=2.0
|
||||
litellm>=1.75.5
|
||||
typer
|
||||
platformdirs
|
||||
@@ -27,18 +27,17 @@ platformdirs
|
||||
# Requires Docker installed and user in 'docker' group
|
||||
|
||||
# Optional: For Modal backend (cloud execution)
|
||||
# modal
|
||||
# boto3
|
||||
# swe-rex[modal]>=1.4.0 # Includes modal + boto3 + swe-rex runtime
|
||||
|
||||
# Optional: For cron expression parsing (cronjob scheduling)
|
||||
croniter
|
||||
|
||||
# Optional: For messaging platform integrations (gateway)
|
||||
# Telegram: pip install python-telegram-bot
|
||||
# Telegram
|
||||
python-telegram-bot>=20.0
|
||||
|
||||
# Discord: pip install discord.py
|
||||
# Discord
|
||||
discord.py>=2.0
|
||||
|
||||
# WhatsApp: Requires Node.js bridge (see docs/messaging.md)
|
||||
# aiohttp # For WhatsApp bridge communication
|
||||
# WhatsApp bridge communication + general async HTTP (used by gateway)
|
||||
aiohttp>=3.9.0
|
||||
|
||||
368
run_agent.py
@@ -30,7 +30,6 @@ import threading
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from openai import OpenAI
|
||||
import fire
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
@@ -66,6 +65,7 @@ _MODEL_CACHE_TTL = 3600 # 1 hour cache TTL
|
||||
DEFAULT_CONTEXT_LENGTHS = {
|
||||
"anthropic/claude-opus-4": 200000,
|
||||
"anthropic/claude-opus-4.5": 200000,
|
||||
"anthropic/claude-opus-4.6": 200000,
|
||||
"anthropic/claude-sonnet-4": 200000,
|
||||
"anthropic/claude-sonnet-4-20250514": 200000,
|
||||
"anthropic/claude-haiku-4.5": 200000,
|
||||
@@ -206,7 +206,7 @@ class ContextCompressor:
|
||||
self,
|
||||
model: str,
|
||||
threshold_percent: float = 0.85,
|
||||
summary_model: str = "google/gemini-2.0-flash-001",
|
||||
summary_model: str = "google/gemini-3-flash-preview",
|
||||
protect_first_n: int = 3,
|
||||
protect_last_n: int = 4,
|
||||
summary_target_tokens: int = 500,
|
||||
@@ -584,7 +584,7 @@ class AIAgent:
|
||||
self,
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
model: str = "anthropic/claude-sonnet-4-20250514", # OpenRouter format
|
||||
model: str = "anthropic/claude-opus-4.6", # OpenRouter format
|
||||
max_iterations: int = 60, # Default tool-calling iterations
|
||||
tool_delay: float = 1.0,
|
||||
enabled_toolsets: List[str] = None,
|
||||
@@ -601,6 +601,9 @@ class AIAgent:
|
||||
provider_sort: str = None,
|
||||
session_id: str = None,
|
||||
tool_progress_callback: callable = None,
|
||||
max_tokens: int = None,
|
||||
reasoning_config: Dict[str, Any] = None,
|
||||
prefill_messages: List[Dict[str, Any]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the AI Agent.
|
||||
@@ -625,6 +628,12 @@ class AIAgent:
|
||||
provider_sort (str): Sort providers by price/throughput/latency (optional)
|
||||
session_id (str): Pre-generated session ID for logging (optional, auto-generated if not provided)
|
||||
tool_progress_callback (callable): Callback function(tool_name, args_preview) for progress notifications
|
||||
max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
|
||||
reasoning_config (Dict): OpenRouter reasoning configuration override (e.g. {"effort": "none"} to disable thinking).
|
||||
If None, OpenRouter requests default to {"enabled": True, "effort": "xhigh"}.
|
||||
prefill_messages (List[Dict]): Messages to prepend to conversation history as prefilled context.
|
||||
Useful for injecting a few-shot example or priming the model's response style.
|
||||
Example: [{"role": "user", "content": "Hi!"}, {"role": "assistant", "content": "Hello!"}]
|
||||
"""
|
||||
self.model = model
|
||||
self.max_iterations = max_iterations
|
||||
@@ -653,6 +662,11 @@ class AIAgent:
|
||||
self.enabled_toolsets = enabled_toolsets
|
||||
self.disabled_toolsets = disabled_toolsets
|
||||
|
||||
# Model response configuration
|
||||
self.max_tokens = max_tokens # None = use model default
|
||||
self.reasoning_config = reasoning_config # None = use default (xhigh for OpenRouter)
|
||||
self.prefill_messages = prefill_messages or [] # Prefilled conversation turns
|
||||
|
||||
# Configure logging
|
||||
if self.verbose_logging:
|
||||
logging.basicConfig(
|
||||
@@ -781,7 +795,7 @@ class AIAgent:
|
||||
# Compresses conversation when approaching model's context limit
|
||||
# Configuration via environment variables (can be set in .env or cli-config.yaml)
|
||||
compression_threshold = float(os.getenv("CONTEXT_COMPRESSION_THRESHOLD", "0.85"))
|
||||
compression_model = os.getenv("CONTEXT_COMPRESSION_MODEL", "google/gemini-2.0-flash-001")
|
||||
compression_model = os.getenv("CONTEXT_COMPRESSION_MODEL", "google/gemini-3-flash-preview")
|
||||
compression_enabled = os.getenv("CONTEXT_COMPRESSION_ENABLED", "true").lower() in ("true", "1", "yes")
|
||||
|
||||
self.context_compressor = ContextCompressor(
|
||||
@@ -1086,6 +1100,43 @@ class AIAgent:
|
||||
|
||||
return json.dumps(formatted_tools, ensure_ascii=False)
|
||||
|
||||
@staticmethod
|
||||
def _convert_scratchpad_to_think(content: str) -> str:
|
||||
"""
|
||||
Convert <REASONING_SCRATCHPAD> tags to <think> tags in content.
|
||||
|
||||
When native thinking/reasoning is disabled and the model is prompted to
|
||||
reason inside <REASONING_SCRATCHPAD> XML tags instead, this converts those
|
||||
to the standard <think> format used in our trajectory storage.
|
||||
|
||||
Args:
|
||||
content: Assistant message content that may contain scratchpad tags
|
||||
|
||||
Returns:
|
||||
Content with scratchpad tags replaced by think tags
|
||||
"""
|
||||
if not content or "<REASONING_SCRATCHPAD>" not in content:
|
||||
return content
|
||||
return content.replace("<REASONING_SCRATCHPAD>", "<think>").replace("</REASONING_SCRATCHPAD>", "</think>")
|
||||
|
||||
@staticmethod
|
||||
def _has_incomplete_scratchpad(content: str) -> bool:
|
||||
"""
|
||||
Check if content has an opening <REASONING_SCRATCHPAD> without a closing tag.
|
||||
|
||||
This indicates the model ran out of output tokens mid-reasoning, producing
|
||||
a broken turn that shouldn't be saved. The caller should retry or discard.
|
||||
|
||||
Args:
|
||||
content: Assistant message content to check
|
||||
|
||||
Returns:
|
||||
True if there's an unclosed scratchpad tag
|
||||
"""
|
||||
if not content:
|
||||
return False
|
||||
return "<REASONING_SCRATCHPAD>" in content and "</REASONING_SCRATCHPAD>" not in content
|
||||
|
||||
def _convert_to_trajectory_format(self, messages: List[Dict[str, Any]], user_query: str, completed: bool) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert internal message format to trajectory format for saving.
|
||||
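The two static helpers added in this hunk can be tried in isolation. Below is a minimal standalone sketch that mirrors their logic as module-level functions; the function names are illustrative and are not the ones used in `run_agent.py`.

```python
def convert_scratchpad_to_think(content: str) -> str:
    """Rewrite <REASONING_SCRATCHPAD> tags as <think> tags (illustrative copy)."""
    # Empty content, or content without an opening tag, is returned untouched.
    if not content or "<REASONING_SCRATCHPAD>" not in content:
        return content
    return (
        content.replace("<REASONING_SCRATCHPAD>", "<think>")
        .replace("</REASONING_SCRATCHPAD>", "</think>")
    )


def has_incomplete_scratchpad(content: str) -> bool:
    """True when an opening tag is present and no closing tag appears anywhere."""
    if not content:
        return False
    return (
        "<REASONING_SCRATCHPAD>" in content
        and "</REASONING_SCRATCHPAD>" not in content
    )


closed = "<REASONING_SCRATCHPAD>plan</REASONING_SCRATCHPAD>Answer"
truncated = "<REASONING_SCRATCHPAD>plan that was cut off"
converted = convert_scratchpad_to_think(closed)
```

Because the check is a plain substring test, a message with one closed block followed by a second, unclosed block would not be flagged as incomplete.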
@@ -1120,14 +1171,19 @@ class AIAgent:
|
||||
"value": system_msg
|
||||
})
|
||||
|
||||
# Add the initial user message
|
||||
# Add the actual user prompt (from the dataset) as the first human message
|
||||
trajectory.append({
|
||||
"from": "human",
|
||||
"value": user_query
|
||||
})
|
||||
|
||||
# Process remaining messages
|
||||
i = 1 # Skip the first user message as we already added it
|
||||
# Calculate where agent responses start in the messages list.
|
||||
# Prefill messages are ephemeral (only used to prime model response style)
|
||||
# so we skip them entirely in the saved trajectory.
|
||||
# Layout: [*prefill_msgs, actual_user_msg, ...agent_responses...]
|
||||
num_prefill = len(self.prefill_messages) if self.prefill_messages else 0
|
||||
i = num_prefill + 1 # Skip prefill messages + the actual user message (already added above)
|
||||
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
|
||||
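To make the new start index concrete, here is a small sketch of the message layout described in the comments above (`[*prefill_msgs, actual_user_msg, ...agent_responses...]`). The helper name `first_agent_index` and the sample messages are hypothetical, used only for illustration.

```python
from typing import Any, Dict, List, Optional


def first_agent_index(prefill_messages: Optional[List[Dict[str, Any]]]) -> int:
    """Index of the first agent response: skip prefill turns plus the user prompt."""
    num_prefill = len(prefill_messages) if prefill_messages else 0
    return num_prefill + 1


prefill = [
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello!"},
]
messages = prefill + [
    {"role": "user", "content": "List the files in this repo"},
    {"role": "assistant", "content": "Running ls"},
]

start = first_agent_index(prefill)
agent_turns = messages[start:]
```

The offset is only right if the prefill turns were actually placed at the front of `messages`. Elsewhere in this diff they are injected only when no `conversation_history` is passed, so a caller that supplies both may need to check the layout.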
@@ -1138,12 +1194,14 @@ class AIAgent:
|
||||
# Add <think> tags around reasoning for trajectory storage
|
||||
content = ""
|
||||
|
||||
# Prepend reasoning in <think> tags if available
|
||||
# Prepend reasoning in <think> tags if available (native thinking tokens)
|
||||
if msg.get("reasoning") and msg["reasoning"].strip():
|
||||
content = f"<think>\n{msg['reasoning']}\n</think>\n"
|
||||
|
||||
if msg.get("content") and msg["content"].strip():
|
||||
content += msg["content"] + "\n"
|
||||
# Convert any <REASONING_SCRATCHPAD> tags to <think> tags
|
||||
# (used when native thinking is disabled and model reasons via XML)
|
||||
content += self._convert_scratchpad_to_think(msg["content"]) + "\n"
|
||||
|
||||
# Add tool calls wrapped in XML tags
|
||||
for tool_call in msg["tool_calls"]:
|
||||
@@ -1163,6 +1221,11 @@ class AIAgent:
|
||||
}
|
||||
content += f"<tool_call>\n{json.dumps(tool_call_json, ensure_ascii=False)}\n</tool_call>\n"
|
||||
|
||||
# Ensure every gpt turn has a <think> block (empty if no reasoning)
|
||||
# so the format is consistent for training data
|
||||
if "<think>" not in content:
|
||||
content = "<think>\n</think>\n" + content
|
||||
|
||||
trajectory.append({
|
||||
"from": "gpt",
|
||||
"value": content.rstrip()
|
||||
@@ -1206,11 +1269,18 @@ class AIAgent:
|
||||
# Add <think> tags around reasoning for trajectory storage
|
||||
content = ""
|
||||
|
||||
# Prepend reasoning in <think> tags if available
|
||||
# Prepend reasoning in <think> tags if available (native thinking tokens)
|
||||
if msg.get("reasoning") and msg["reasoning"].strip():
|
||||
content = f"<think>\n{msg['reasoning']}\n</think>\n"
|
||||
|
||||
content += msg["content"] or ""
|
||||
# Convert any <REASONING_SCRATCHPAD> tags to <think> tags
|
||||
# (used when native thinking is disabled and model reasons via XML)
|
||||
raw_content = msg["content"] or ""
|
||||
content += self._convert_scratchpad_to_think(raw_content)
|
||||
|
||||
# Ensure every gpt turn has a <think> block (empty if no reasoning)
|
||||
if "<think>" not in content:
|
||||
content = "<think>\n</think>\n" + content
|
||||
|
||||
trajectory.append({
|
||||
"from": "gpt",
|
||||
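The `gpt` turn assembly in the two hunks above can be summarized as one function. This is a sketch under assumptions: `build_gpt_value` is a hypothetical name, tool-call serialization is left out, and the trailing `rstrip()` is taken from the tool-call branch shown earlier.

```python
from typing import Optional


def build_gpt_value(reasoning: Optional[str], raw_content: Optional[str]) -> str:
    """Assemble a trajectory 'gpt' value that always contains a <think> block."""
    content = ""
    # Native reasoning tokens, when present, go first inside <think> tags.
    if reasoning and reasoning.strip():
        content = f"<think>\n{reasoning}\n</think>\n"
    # Scratchpad-style reasoning embedded in the text is rewritten to <think>.
    text = raw_content or ""
    content += text.replace("<REASONING_SCRATCHPAD>", "<think>").replace(
        "</REASONING_SCRATCHPAD>", "</think>"
    )
    # No reasoning of either kind: prepend an empty block so every turn
    # has the same shape.
    if "<think>" not in content:
        content = "<think>\n</think>\n" + content
    return content.rstrip()


plain = build_gpt_value(None, "Done.")
native = build_gpt_value("check the path first", "Done.")
scratch = build_gpt_value(None, "<REASONING_SCRATCHPAD>p</REASONING_SCRATCHPAD>Done.")
```

The empty-block rule keys on the literal string `<think>`, so a response that merely mentions that tag in its text would not get an empty block prepended.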
@@ -1261,6 +1331,66 @@ class AIAgent:
|
||||
except Exception as e:
|
||||
print(f"⚠️ Failed to save trajectory: {e}")
|
||||
|
||||
def _log_api_payload(self, turn_number: int, api_kwargs: Dict[str, Any], response=None):
|
||||
"""
|
||||
[TEMPORARY DEBUG] Log the full API payload and response token metrics
|
||||
for each agent turn to a per-session JSONL file for inspection.
|
||||
|
||||
Writes one JSON line per turn to logs/payload_<session_id>.jsonl.
|
||||
Tool schemas are summarized (just names) to keep logs readable.
|
||||
|
||||
Args:
|
||||
turn_number: Which API call this is (1-indexed)
|
||||
api_kwargs: The full kwargs dict being passed to chat.completions.create
|
||||
response: The API response object (optional, added after the call completes)
|
||||
"""
|
||||
try:
|
||||
payload_log_file = self.logs_dir / f"payload_{self.session_id}.jsonl"
|
||||
|
||||
# Build a serializable copy of the request payload
|
||||
payload = {
|
||||
"turn": turn_number,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"model": api_kwargs.get("model"),
|
||||
"max_tokens": api_kwargs.get("max_tokens"),
|
||||
"extra_body": api_kwargs.get("extra_body"),
|
||||
"num_tools": len(api_kwargs.get("tools") or []),
|
||||
"tool_names": [t["function"]["name"] for t in (api_kwargs.get("tools") or [])],
|
||||
"messages": api_kwargs.get("messages", []),
|
||||
}
|
||||
|
||||
# Add response token metrics if available
|
||||
if response is not None:
|
||||
try:
|
||||
usage_raw = response.usage.model_dump() if hasattr(response.usage, 'model_dump') else {}
|
||||
payload["response"] = {
|
||||
# Core token counts
|
||||
"prompt_tokens": usage_raw.get("prompt_tokens"),
|
||||
"completion_tokens": usage_raw.get("completion_tokens"),
|
||||
"total_tokens": usage_raw.get("total_tokens"),
|
||||
# Completion breakdown (reasoning tokens, etc.)
|
||||
"completion_tokens_details": usage_raw.get("completion_tokens_details"),
|
||||
# Prompt breakdown (cached tokens, etc.)
|
||||
"prompt_tokens_details": usage_raw.get("prompt_tokens_details"),
|
||||
# Cost tracking
|
||||
"cost": usage_raw.get("cost"),
|
||||
"is_byok": usage_raw.get("is_byok"),
|
||||
"cost_details": usage_raw.get("cost_details"),
|
||||
# Provider info (top-level field from OpenRouter)
|
||||
"provider": getattr(response, 'provider', None),
|
||||
"response_model": getattr(response, 'model', None),
|
||||
}
|
||||
except Exception:
|
||||
payload["response"] = {"error": "failed to extract usage"}
|
||||
|
||||
with open(payload_log_file, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(payload, ensure_ascii=False, default=str) + "\n")
|
||||
|
||||
except Exception as e:
|
||||
# Fail quietly: debug logging must never interrupt the agent (warn only in verbose mode)
|
||||
if self.verbose_logging:
|
||||
logging.warning(f"Failed to log API payload: {e}")
|
||||
|
||||
def _save_session_log(self, messages: List[Dict[str, Any]] = None):
|
||||
"""
|
||||
Save the current session trajectory to the logs directory.
|
||||
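The per-turn JSONL pattern used by `_log_api_payload` can be illustrated outside the agent. The sketch below is a simplified stand-in, not the method itself: `append_payload`, `read_payloads`, the temp directory, and the sample request are all assumptions made for the example.

```python
import json
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List


def append_payload(log_file: Path, turn: int, api_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Append one JSON line describing a request; tool schemas are reduced to names."""
    tools = api_kwargs.get("tools") or []
    payload = {
        "turn": turn,
        "timestamp": datetime.now().isoformat(),
        "model": api_kwargs.get("model"),
        "max_tokens": api_kwargs.get("max_tokens"),
        "num_tools": len(tools),
        "tool_names": [t["function"]["name"] for t in tools],
        "messages": api_kwargs.get("messages", []),
    }
    with open(log_file, "a", encoding="utf-8") as f:
        # default=str keeps the write from failing on non-JSON-serializable values.
        f.write(json.dumps(payload, ensure_ascii=False, default=str) + "\n")
    return payload


def read_payloads(log_file: Path) -> List[Dict[str, Any]]:
    """Load every logged turn back as a list of dicts."""
    with open(log_file, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


log_file = Path(tempfile.mkdtemp()) / "payload_demo-session.jsonl"
request = {
    "model": "example/model",  # placeholder model ID
    "messages": [{"role": "user", "content": "hi"}],
    "tools": [{"type": "function", "function": {"name": "web_search"}}],
}
append_payload(log_file, 1, request)
append_payload(log_file, 2, request)
entries = read_payloads(log_file)
```

One line per turn keeps the file appendable and lets a single turn be inspected without parsing the whole session.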
@@ -1276,10 +1406,12 @@ class AIAgent:
|
||||
return
|
||||
|
||||
try:
|
||||
# Extract the first user message for the trajectory format
|
||||
# The first message should be the user's initial query
|
||||
# Extract the actual user query for the trajectory format.
|
||||
# Skip prefill messages (they're ephemeral and shouldn't appear in trajectories)
|
||||
# so the first user message we find is the real task prompt.
|
||||
first_user_query = ""
|
||||
for msg in messages:
|
||||
start_idx = len(self.prefill_messages) if self.prefill_messages else 0
|
||||
for msg in messages[start_idx:]:
|
||||
if msg.get("role") == "user":
|
||||
first_user_query = msg.get("content", "")
|
||||
break
|
||||
@@ -1373,6 +1505,12 @@ class AIAgent:
|
||||
# Initialize conversation
|
||||
messages = conversation_history or []
|
||||
|
||||
# Inject prefill messages at the start of conversation (before user's actual prompt)
|
||||
# This is used for few-shot priming, e.g., a greeting exchange to set response style
|
||||
if self.prefill_messages and not conversation_history:
|
||||
for prefill_msg in self.prefill_messages:
|
||||
messages.append(prefill_msg.copy())
|
||||
|
||||
# Add user message
|
||||
messages.append({
|
||||
"role": "user",
|
||||
@@ -1442,6 +1580,16 @@ class AIAgent:
|
||||
if active_system_prompt:
|
||||
# Insert system message at the beginning
|
||||
api_messages = [{"role": "system", "content": active_system_prompt}] + api_messages
|
||||
|
||||
if os.getenv("HERMES_DEBUG_OPENAI_REQUEST") == "1":
|
||||
meta = {
|
||||
"model": self.model,
|
||||
"base_url": self.base_url,
|
||||
"messages": api_messages,
|
||||
"tools": self.tools if self.tools else None,
|
||||
}
|
||||
print("\n=== HERMES_DEBUG_OPENAI_REQUEST ===", flush=True)
|
||||
print(json.dumps(meta, ensure_ascii=False, indent=2)[:200_000], flush=True)
|
||||
|
||||
# Calculate approximate request size for logging
|
||||
total_chars = sum(len(str(msg)) for msg in api_messages)
|
||||
@@ -1455,12 +1603,13 @@ class AIAgent:
|
||||
print(f"{self.log_prefix} 📊 Request size: {len(api_messages)} messages, ~{approx_tokens:,} tokens (~{total_chars:,} chars)")
|
||||
print(f"{self.log_prefix} 🔧 Available tools: {len(self.tools) if self.tools else 0}")
|
||||
else:
|
||||
# Animated thinking spinner in quiet mode
|
||||
face = random.choice(KawaiiSpinner.KAWAII_THINKING)
|
||||
verb = random.choice(KawaiiSpinner.THINKING_VERBS)
|
||||
spinner_type = random.choice(['brain', 'sparkle', 'pulse', 'moon', 'star'])
|
||||
thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type)
|
||||
thinking_spinner.start()
|
||||
# Animated thinking spinner in quiet mode (disable for wrappers/non-TTY usage)
|
||||
if os.getenv("HERMES_DISABLE_SPINNER") != "1":
|
||||
face = random.choice(KawaiiSpinner.KAWAII_THINKING)
|
||||
verb = random.choice(KawaiiSpinner.THINKING_VERBS)
|
||||
spinner_type = random.choice(['brain', 'sparkle', 'pulse', 'moon', 'star'])
|
||||
thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type)
|
||||
thinking_spinner.start()
|
||||
|
||||
# Log request details if verbose
|
||||
if self.verbose_logging:
|
||||
@@ -1493,6 +1642,10 @@ class AIAgent:
|
||||
"timeout": 600.0 # 10 minute timeout for very long responses
|
||||
}
|
||||
|
||||
# Add max_tokens if configured (overrides model default)
|
||||
if self.max_tokens is not None:
|
||||
api_kwargs["max_tokens"] = self.max_tokens
|
||||
|
||||
# Add extra_body for OpenRouter (provider preferences + reasoning)
|
||||
extra_body = {}
|
||||
|
||||
@@ -1500,17 +1653,30 @@ class AIAgent:
|
||||
if provider_preferences:
|
||||
extra_body["provider"] = provider_preferences
|
||||
|
||||
# Enable reasoning with xhigh effort for OpenRouter
|
||||
# Configure reasoning for OpenRouter
|
||||
# If reasoning_config is explicitly provided, use it (allows disabling/customizing)
|
||||
# Otherwise, default to xhigh effort for OpenRouter models
|
||||
if "openrouter" in self.base_url.lower():
|
||||
extra_body["reasoning"] = {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
}
|
||||
if self.reasoning_config is not None:
|
||||
extra_body["reasoning"] = self.reasoning_config
|
||||
else:
|
||||
extra_body["reasoning"] = {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
}
|
||||
|
||||
if extra_body:
|
||||
api_kwargs["extra_body"] = extra_body
|
||||
|
||||
response = self.client.chat.completions.create(**api_kwargs)
|
||||
|
||||
if os.getenv("HERMES_DEBUG_OPENAI_RESPONSE") == "1":
|
||||
try:
|
||||
dumped = response.model_dump()
|
||||
except Exception:
|
||||
dumped = getattr(response, "__dict__", {"repr": repr(response)})
|
||||
print("\n=== HERMES_DEBUG_OPENAI_RESPONSE: ChatCompletion (raw) ===", flush=True)
|
||||
print(json.dumps(dumped, ensure_ascii=False, indent=2), flush=True)
|
||||
|
||||
api_duration = time.time() - api_start_time
|
||||
|
||||
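The reasoning override introduced in this hunk reduces to a small decision: an explicit `reasoning_config` wins, otherwise OpenRouter requests get the `xhigh` default. The following sketch isolates that decision. `build_extra_body` is a hypothetical helper, and how `provider_preferences` is produced is not shown in the hunk, so it is passed in as a plain argument here.

```python
from typing import Any, Dict, Optional

DEFAULT_REASONING = {"enabled": True, "effort": "xhigh"}


def build_extra_body(
    base_url: str,
    reasoning_config: Optional[Dict[str, Any]] = None,
    provider_preferences: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build the extra_body dict for a chat completion request (illustrative)."""
    extra_body: Dict[str, Any] = {}
    if provider_preferences:
        extra_body["provider"] = provider_preferences
    # Reasoning settings are only attached for OpenRouter endpoints.
    if "openrouter" in base_url.lower():
        if reasoning_config is not None:
            extra_body["reasoning"] = reasoning_config
        else:
            extra_body["reasoning"] = dict(DEFAULT_REASONING)
    return extra_body


default_body = build_extra_body("https://openrouter.ai/api/v1")
disabled_body = build_extra_body("https://openrouter.ai/api/v1", {"effort": "none"})
local_body = build_extra_body("http://localhost:8000/v1", {"effort": "none"})
```

Since the test is `is not None`, an empty dict counts as an explicit override and is forwarded as is.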
@@ -1527,6 +1693,9 @@ class AIAgent:
|
||||
# Log response with provider info if available
|
||||
resp_model = getattr(response, 'model', 'N/A') if response else 'N/A'
|
||||
logging.debug(f"API Response received - Model: {resp_model}, Usage: {response.usage if hasattr(response, 'usage') else 'N/A'}")
|
||||
|
||||
# [DEBUG] Log the full API payload + response token metrics
|
||||
self._log_api_payload(api_call_count, api_kwargs, response=response)
|
||||
|
||||
# Validate response has valid choices before proceeding
|
||||
if response is None or not hasattr(response, 'choices') or response.choices is None or len(response.choices) == 0:
|
||||
@@ -1589,7 +1758,20 @@ class AIAgent:
|
||||
wait_time = min(5 * (2 ** (retry_count - 1)), 120) # 5s, 10s, 20s, 40s, 80s, 120s
|
||||
print(f"{self.log_prefix}⏳ Retrying in {wait_time}s (extended backoff for possible rate limit)...")
|
||||
logging.warning(f"Invalid API response (retry {retry_count}/{max_retries}): {', '.join(error_details)} | Provider: {provider_name}")
|
||||
time.sleep(wait_time)
|
||||
|
||||
# Sleep in small increments to stay responsive to interrupts
|
||||
sleep_end = time.time() + wait_time
|
||||
while time.time() < sleep_end:
|
||||
if self._interrupt_requested:
|
||||
print(f"{self.log_prefix}⚡ Interrupt detected during retry wait, aborting.")
|
||||
return {
|
||||
"final_response": "Operation interrupted.",
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"interrupted": True,
|
||||
}
|
||||
time.sleep(0.2)
|
||||
continue # Retry the API call
|
||||
|
||||
# Check finish_reason before proceeding
|
||||
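The retry wait in this hunk replaces one long `time.sleep(wait_time)` with short sleeps that poll an interrupt flag. A minimal sketch of the same idea follows. Both helper names are hypothetical, and the sketch uses `time.monotonic()` rather than `time.time()`, a substitution made here so the wait is unaffected by wall-clock adjustments.

```python
import time
from typing import Callable


def backoff_seconds(retry_count: int, base: int = 5, cap: int = 120) -> int:
    """Exponential backoff with a ceiling: 5, 10, 20, 40, 80, then capped at 120."""
    return min(base * (2 ** (retry_count - 1)), cap)


def interruptible_sleep(
    wait_time: float,
    is_interrupted: Callable[[], bool],
    poll_interval: float = 0.2,
) -> bool:
    """Sleep for wait_time seconds; return False early if an interrupt is seen."""
    sleep_end = time.monotonic() + wait_time
    while time.monotonic() < sleep_end:
        if is_interrupted():
            return False
        # Never sleep past the deadline on the final iteration.
        remaining = sleep_end - time.monotonic()
        time.sleep(max(0.0, min(poll_interval, remaining)))
    return True


waits = [backoff_seconds(n) for n in range(1, 8)]
finished = interruptible_sleep(0.05, lambda: False, poll_interval=0.01)
aborted = interruptible_sleep(5.0, lambda: True)
```

With a 200 ms poll interval, the worst-case delay before an interrupt is noticed is about 0.2 s, compared with up to 120 s for a single blocking sleep.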
@@ -1668,6 +1850,41 @@ class AIAgent:
|
||||
print(f"{self.log_prefix} 📝 Error: {str(api_error)[:200]}")
|
||||
print(f"{self.log_prefix} 📊 Request context: {len(api_messages)} messages, ~{approx_tokens:,} tokens, {len(self.tools) if self.tools else 0} tools")
|
||||
|
||||
# Check for interrupt before deciding to retry
|
||||
if self._interrupt_requested:
|
||||
print(f"{self.log_prefix}⚡ Interrupt detected during error handling, aborting retries.")
|
||||
return {
|
||||
"final_response": "Operation interrupted.",
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"interrupted": True,
|
||||
}
|
||||
|
||||
# Check for non-retryable client errors (4xx HTTP status codes).
|
||||
# These indicate a problem with the request itself (bad model ID,
|
||||
# invalid API key, forbidden, etc.) and will never succeed on retry.
|
||||
is_client_error = any(phrase in error_msg for phrase in [
|
||||
'error code: 400', 'error code: 401', 'error code: 403',
|
||||
'error code: 404', 'error code: 422',
|
||||
'is not a valid model', 'invalid model', 'model not found',
|
||||
'invalid api key', 'invalid_api_key', 'authentication',
|
||||
'unauthorized', 'forbidden', 'not found',
|
||||
])
|
||||
|
||||
if is_client_error:
|
||||
print(f"{self.log_prefix}❌ Non-retryable client error detected. Aborting immediately.")
|
||||
print(f"{self.log_prefix} 💡 This type of error won't be fixed by retrying.")
|
||||
logging.error(f"{self.log_prefix}Non-retryable client error: {api_error}")
|
||||
return {
|
||||
"final_response": None,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"failed": True,
|
||||
"error": str(api_error),
|
||||
}
|
||||
|
||||
# Check for non-retryable errors (context length exceeded)
|
||||
is_context_length_error = any(phrase in error_msg for phrase in [
|
||||
'context length', 'maximum context', 'token limit',
|
||||
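The non-retryable check above is a substring match against the error text. The sketch below reproduces it as a standalone function. It assumes `error_msg` is the lowercased string form of the exception, which is not visible in this hunk, and `is_client_error` is written here as a free function for illustration.

```python
CLIENT_ERROR_PHRASES = (
    "error code: 400", "error code: 401", "error code: 403",
    "error code: 404", "error code: 422",
    "is not a valid model", "invalid model", "model not found",
    "invalid api key", "invalid_api_key", "authentication",
    "unauthorized", "forbidden", "not found",
)


def is_client_error(error: Exception) -> bool:
    """True when the error text contains any phrase treated as non-retryable."""
    error_msg = str(error).lower()  # assumption: matching is case-insensitive
    return any(phrase in error_msg for phrase in CLIENT_ERROR_PHRASES)


bad_model = is_client_error(RuntimeError("Error code: 404 - model not found"))
rate_limited = is_client_error(RuntimeError("Error code: 429 - rate limited"))
dropped = is_client_error(ConnectionError("Connection reset by peer"))
```

Phrase matching is broad by design: any message that contains "not found" or "authentication" is treated as fatal, even when it comes from a transient upstream failure.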
@@ -1708,7 +1925,21 @@ class AIAgent:
|
||||
print(f"⚠️ OpenAI-compatible API call failed (attempt {retry_count}/{max_retries}): {str(api_error)[:100]}")
|
||||
print(f"⏳ Retrying in {wait_time}s...")
|
||||
logging.warning(f"API retry {retry_count}/{max_retries} after error: {api_error}")
|
||||
time.sleep(wait_time)
|
||||
|
||||
# Sleep in small increments so we can respond to interrupts quickly
|
||||
# instead of blocking the entire wait_time in one sleep() call
|
||||
sleep_end = time.time() + wait_time
|
||||
while time.time() < sleep_end:
|
||||
if self._interrupt_requested:
|
||||
print(f"{self.log_prefix}⚡ Interrupt detected during retry wait, aborting.")
|
||||
return {
|
||||
"final_response": "Operation interrupted.",
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"interrupted": True,
|
||||
}
|
||||
time.sleep(0.2) # Check interrupt every 200ms
|
||||
|
||||
try:
|
||||
assistant_message = response.choices[0].message
|
||||
@@ -1717,6 +1948,48 @@ class AIAgent:
|
||||
if assistant_message.content and not self.quiet_mode:
|
||||
print(f"{self.log_prefix}🤖 Assistant: {assistant_message.content[:100]}{'...' if len(assistant_message.content) > 100 else ''}")
|
||||
|
||||
# Check for incomplete <REASONING_SCRATCHPAD> (opened but never closed)
|
||||
# This means the model ran out of output tokens mid-reasoning — retry up to 2 times
|
||||
if self._has_incomplete_scratchpad(assistant_message.content or ""):
|
||||
if not hasattr(self, '_incomplete_scratchpad_retries'):
|
||||
self._incomplete_scratchpad_retries = 0
|
||||
self._incomplete_scratchpad_retries += 1
|
||||
|
||||
print(f"{self.log_prefix}⚠️ Incomplete <REASONING_SCRATCHPAD> detected (opened but never closed)")
|
||||
|
||||
if self._incomplete_scratchpad_retries <= 2:
|
||||
print(f"{self.log_prefix}🔄 Retrying API call ({self._incomplete_scratchpad_retries}/2)...")
|
||||
# Don't add the broken message, just retry
|
||||
continue
|
||||
else:
|
||||
# Max retries - discard this turn and save as partial
|
||||
print(f"{self.log_prefix}❌ Max retries (2) for incomplete scratchpad. Saving as partial.")
|
||||
self._incomplete_scratchpad_retries = 0
|
||||
|
||||
rolled_back_messages = self._get_messages_up_to_last_assistant(messages)
|
||||
|
||||
try:
|
||||
cleanup_vm(effective_task_id)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
cleanup_browser(effective_task_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"final_response": None,
|
||||
"messages": rolled_back_messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"partial": True,
|
||||
"error": "Incomplete REASONING_SCRATCHPAD after 2 retries"
|
||||
}
|
||||
|
||||
# Reset incomplete scratchpad counter on clean response
|
||||
if hasattr(self, '_incomplete_scratchpad_retries'):
|
||||
self._incomplete_scratchpad_retries = 0
|
||||
|
||||
# Check for tool calls
|
||||
if assistant_message.tool_calls:
|
||||
if not self.quiet_mode:
|
||||
@@ -1882,7 +2155,7 @@ class AIAgent:
|
||||
tool_start_time = time.time()
|
||||
|
||||
# Execute the tool - with animated spinner in quiet mode
|
||||
if self.quiet_mode:
|
||||
if self.quiet_mode and os.getenv("HERMES_DISABLE_SPINNER") != "1":
|
||||
# Tool-specific spinner animations
|
||||
tool_spinners = {
|
||||
'web_search': ('arrows', ['🔍', '🌐', '📡', '🔎']),
|
||||
@@ -1912,6 +2185,9 @@ class AIAgent:
|
||||
tool_duration = time.time() - tool_start_time
|
||||
cute_msg = self._get_cute_tool_message(function_name, function_args, tool_duration)
|
||||
spinner.stop(cute_msg)
|
||||
elif self.quiet_mode:
|
||||
function_result = handle_function_call(function_name, function_args, effective_task_id)
|
||||
tool_duration = time.time() - tool_start_time
|
||||
else:
|
||||
function_result = handle_function_call(function_name, function_args, effective_task_id)
|
||||
tool_duration = time.time() - tool_start_time
|
||||
@@ -2069,13 +2345,28 @@ class AIAgent:
|
||||
if self.ephemeral_system_prompt:
|
||||
api_messages = [{"role": "system", "content": self.ephemeral_system_prompt}] + api_messages
|
||||
|
||||
summary_response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=api_messages,
|
||||
# Build extra_body for summary call (same reasoning config as main loop)
|
||||
summary_extra_body = {}
|
||||
if "openrouter" in self.base_url.lower():
|
||||
if self.reasoning_config is not None:
|
||||
summary_extra_body["reasoning"] = self.reasoning_config
|
||||
else:
|
||||
summary_extra_body["reasoning"] = {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
}
|
||||
|
||||
summary_kwargs = {
|
||||
"model": self.model,
|
||||
"messages": api_messages,
|
||||
# No tools parameter - forces text response
|
||||
extra_headers=self.extra_headers,
|
||||
extra_body=self.extra_body,
|
||||
)
|
||||
}
|
||||
if self.max_tokens is not None:
|
||||
summary_kwargs["max_tokens"] = self.max_tokens
|
||||
if summary_extra_body:
|
||||
summary_kwargs["extra_body"] = summary_extra_body
|
||||
|
||||
summary_response = self.client.chat.completions.create(**summary_kwargs)
|
||||
|
||||
if summary_response.choices and summary_response.choices[0].message.content:
|
||||
final_response = summary_response.choices[0].message.content
|
||||
@@ -2151,7 +2442,7 @@ class AIAgent:
|
||||
|
||||
def main(
|
||||
query: str = None,
|
||||
model: str = "anthropic/claude-sonnet-4-20250514",
|
||||
model: str = "anthropic/claude-opus-4.6",
|
||||
api_key: str = None,
|
||||
base_url: str = "https://openrouter.ai/api/v1",
|
||||
max_turns: int = 10,
|
||||
@@ -2365,4 +2656,11 @@ def main(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
import fire # type: ignore
|
||||
except ModuleNotFoundError as exc:
|
||||
raise SystemExit(
|
||||
"Missing optional dependency 'fire'. Install hermes-agent with its CLI extras or add `fire` "
|
||||
f"to your environment. Original error: {exc}"
|
||||
) from exc
|
||||
fire.Fire(main)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Hermes Agent Installer for Windows
|
||||
# ============================================================================
|
||||
# Installation script for Windows (PowerShell).
|
||||
# Uses uv for fast Python provisioning and package management.
|
||||
#
|
||||
# Usage:
|
||||
# irm https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.ps1 | iex
|
||||
@@ -27,6 +28,7 @@ $ErrorActionPreference = "Stop"
|
||||
|
||||
$RepoUrlSsh = "git@github.com:NousResearch/hermes-agent.git"
|
||||
$RepoUrlHttps = "https://github.com/NousResearch/hermes-agent.git"
|
||||
$PythonVersion = "3.11"
|
||||
|
||||
# ============================================================================
|
||||
# Helper functions
|
||||
@@ -52,12 +54,12 @@ function Write-Success {
|
||||
Write-Host "✓ $Message" -ForegroundColor Green
|
||||
}
|
||||
|
||||
function Write-Warning {
|
||||
function Write-Warn {
|
||||
param([string]$Message)
|
||||
Write-Host "⚠ $Message" -ForegroundColor Yellow
|
||||
}
|
||||
|
||||
function Write-Error {
|
||||
function Write-Err {
|
||||
param([string]$Message)
|
||||
Write-Host "✗ $Message" -ForegroundColor Red
|
||||
}
|
||||
@@ -66,33 +68,93 @@ function Write-Error {
|
||||
# Dependency checks
|
||||
# ============================================================================
|
||||
|
||||
function Test-Python {
|
||||
Write-Info "Checking Python..."
|
||||
function Install-Uv {
|
||||
Write-Info "Checking for uv package manager..."
|
||||
|
||||
# Try different python commands
|
||||
$pythonCmds = @("python3", "python", "py -3")
|
||||
# Check if uv is already available
|
||||
if (Get-Command uv -ErrorAction SilentlyContinue) {
|
||||
$version = uv --version
|
||||
$script:UvCmd = "uv"
|
||||
Write-Success "uv found ($version)"
|
||||
return $true
|
||||
}
|
||||
|
||||
foreach ($cmd in $pythonCmds) {
|
||||
try {
|
||||
$version = & $cmd.Split()[0] $cmd.Split()[1..99] -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')" 2>$null
|
||||
if ($version) {
|
||||
$major, $minor = $version.Split('.')
|
||||
if ([int]$major -ge 3 -and [int]$minor -ge 10) {
|
||||
$script:PythonCmd = $cmd
|
||||
Write-Success "Python $version found"
|
||||
return $true
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
# Try next command
|
||||
# Check common install locations
|
||||
$uvPaths = @(
|
||||
"$env:USERPROFILE\.local\bin\uv.exe",
|
||||
"$env:USERPROFILE\.cargo\bin\uv.exe"
|
||||
)
|
||||
foreach ($uvPath in $uvPaths) {
|
||||
if (Test-Path $uvPath) {
|
||||
$script:UvCmd = $uvPath
|
||||
$version = & $uvPath --version
|
||||
Write-Success "uv found at $uvPath ($version)"
|
||||
return $true
|
||||
}
|
||||
}
|
||||
|
||||
Write-Error "Python 3.10+ not found"
|
||||
Write-Info "Please install Python 3.10 or newer from:"
|
||||
Write-Info " https://www.python.org/downloads/"
|
||||
Write-Info ""
|
||||
Write-Info "Make sure to check 'Add Python to PATH' during installation"
|
||||
# Install uv
|
||||
Write-Info "Installing uv (fast Python package manager)..."
|
||||
try {
|
||||
powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex" 2>&1 | Out-Null
|
||||
|
||||
# Find the installed binary
|
||||
$uvExe = "$env:USERPROFILE\.local\bin\uv.exe"
|
||||
if (-not (Test-Path $uvExe)) {
|
||||
$uvExe = "$env:USERPROFILE\.cargo\bin\uv.exe"
|
||||
}
|
||||
if (-not (Test-Path $uvExe)) {
|
||||
# Refresh PATH and try again
|
||||
$env:Path = [Environment]::GetEnvironmentVariable("Path", "User") + ";" + [Environment]::GetEnvironmentVariable("Path", "Machine")
|
||||
if (Get-Command uv -ErrorAction SilentlyContinue) {
|
||||
$uvExe = (Get-Command uv).Source
|
||||
}
|
||||
}
|
||||
|
||||
if (Test-Path $uvExe) {
|
||||
$script:UvCmd = $uvExe
|
||||
$version = & $uvExe --version
|
||||
Write-Success "uv installed ($version)"
|
||||
return $true
|
||||
}
|
||||
|
||||
Write-Err "uv installed but not found on PATH"
|
||||
Write-Info "Try restarting your terminal and re-running"
|
||||
return $false
|
||||
} catch {
|
||||
Write-Err "Failed to install uv"
|
||||
Write-Info "Install manually: https://docs.astral.sh/uv/getting-started/installation/"
|
||||
return $false
|
||||
}
|
||||
}
|
||||
|
||||
function Test-Python {
|
||||
Write-Info "Checking Python $PythonVersion..."
|
||||
|
||||
# Let uv find or install Python
|
||||
try {
|
||||
$pythonPath = & $UvCmd python find $PythonVersion 2>$null
|
||||
if ($pythonPath) {
|
||||
$ver = & $pythonPath --version 2>$null
|
||||
Write-Success "Python found: $ver"
|
||||
return $true
|
||||
}
|
||||
} catch { }
|
||||
|
||||
# Python not found — use uv to install it (no admin needed!)
|
||||
Write-Info "Python $PythonVersion not found, installing via uv..."
|
||||
try {
|
||||
& $UvCmd python install $PythonVersion 2>&1 | Out-Null
|
||||
$pythonPath = & $UvCmd python find $PythonVersion 2>$null
|
||||
if ($pythonPath) {
|
||||
$ver = & $pythonPath --version 2>$null
|
||||
Write-Success "Python installed: $ver"
|
||||
return $true
|
||||
}
|
||||
} catch { }
|
||||
|
||||
Write-Err "Failed to install Python $PythonVersion"
|
||||
Write-Info "Install Python $PythonVersion manually, then re-run this script"
|
||||
return $false
|
||||
}
|
||||
|
||||
@@ -105,7 +167,7 @@ function Test-Git {
|
||||
return $true
|
||||
}
|
||||
|
||||
Write-Error "Git not found"
|
||||
Write-Err "Git not found"
|
||||
Write-Info "Please install Git from:"
|
||||
Write-Info " https://git-scm.com/download/win"
|
||||
return $false
|
||||
@@ -121,7 +183,7 @@ function Test-Node {
|
||||
return $true
|
||||
}
|
||||
|
||||
Write-Warning "Node.js not found (browser tools will be limited)"
|
||||
Write-Warn "Node.js not found (browser tools will be limited)"
|
||||
Write-Info "To install Node.js (optional):"
|
||||
Write-Info " https://nodejs.org/en/download/"
|
||||
$script:HasNode = $false
|
||||
@@ -138,7 +200,7 @@ function Test-Ripgrep {
|
||||
return $true
|
||||
}
|
||||
|
||||
Write-Warning "ripgrep not found (file search will use findstr fallback)"
|
||||
Write-Warn "ripgrep not found (file search will use findstr fallback)"
|
||||
|
||||
# Check what package managers are available
|
||||
$hasWinget = Get-Command winget -ErrorAction SilentlyContinue
|
||||
@@ -185,7 +247,7 @@ function Test-Ripgrep {
|
||||
} catch { }
|
||||
}
|
||||
|
||||
Write-Warning "Auto-install failed. You can install manually:"
|
||||
Write-Warn "Auto-install failed. You can install manually:"
|
||||
} else {
|
||||
Write-Info "Skipping ripgrep installation. To install manually:"
|
||||
}
|
||||
@@ -216,13 +278,12 @@ function Install-Repository {
|
||||
git pull origin $Branch
|
||||
Pop-Location
|
||||
} else {
|
||||
Write-Error "Directory exists but is not a git repository: $InstallDir"
|
||||
Write-Err "Directory exists but is not a git repository: $InstallDir"
|
||||
Write-Info "Remove it or choose a different directory with -InstallDir"
|
||||
exit 1
|
||||
}
|
||||
} else {
|
||||
# Try SSH first (for private repo access), fall back to HTTPS
|
||||
# Use --recurse-submodules to also clone mini-swe-agent and tinker-atropos
|
||||
Write-Info "Trying SSH clone..."
|
||||
$sshResult = git clone --branch $Branch --recurse-submodules $RepoUrlSsh $InstallDir 2>&1
|
||||
|
||||
@@ -235,7 +296,7 @@ function Install-Repository {
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Success "Cloned via HTTPS"
|
||||
} else {
|
||||
Write-Error "Failed to clone repository"
|
||||
Write-Err "Failed to clone repository"
|
||||
Write-Info "For private repo access, ensure your SSH key is added to GitHub:"
|
||||
Write-Info " ssh-add ~/.ssh/id_rsa"
|
||||
Write-Info " ssh -T git@github.com # Test connection"
|
||||
@@ -244,7 +305,7 @@ function Install-Repository {
|
||||
}
|
||||
}
|
||||
|
||||
# Ensure submodules are initialized and updated (for existing installs or if --recurse failed)
|
||||
# Ensure submodules are initialized and updated
|
||||
Write-Info "Initializing submodules (mini-swe-agent, tinker-atropos)..."
|
||||
Push-Location $InstallDir
|
||||
git submodule update --init --recursive
|
||||
@@ -260,23 +321,21 @@ function Install-Venv {
|
||||
return
|
||||
}
|
||||
|
||||
Write-Info "Creating virtual environment..."
|
||||
Write-Info "Creating virtual environment with Python $PythonVersion..."
|
||||
|
||||
Push-Location $InstallDir
|
||||
|
||||
if (-not (Test-Path "venv")) {
|
||||
& $PythonCmd -m venv venv
|
||||
if (Test-Path "venv") {
|
||||
Write-Info "Virtual environment already exists, recreating..."
|
||||
Remove-Item -Recurse -Force "venv"
|
||||
}
|
||||
|
||||
# Activate
|
||||
& .\venv\Scripts\Activate.ps1
|
||||
|
||||
# Upgrade pip
|
||||
pip install --upgrade pip wheel setuptools | Out-Null
|
||||
# uv creates the venv and pins the Python version in one step
|
||||
& $UvCmd venv venv --python $PythonVersion
|
||||
|
||||
Pop-Location
|
||||
|
||||
Write-Success "Virtual environment ready"
|
||||
Write-Success "Virtual environment ready (Python $PythonVersion)"
|
||||
}
|
||||
|
||||
function Install-Dependencies {
|
||||
@@ -285,14 +344,15 @@ function Install-Dependencies {
|
||||
Push-Location $InstallDir
|
||||
|
||||
if (-not $NoVenv) {
|
||||
& .\venv\Scripts\Activate.ps1
|
||||
# Tell uv to install into our venv (no activation needed)
|
||||
$env:VIRTUAL_ENV = "$InstallDir\venv"
|
||||
}
|
||||
|
||||
# Install main package
|
||||
# Install main package with all extras
|
||||
try {
|
||||
pip install -e ".[all]" 2>&1 | Out-Null
|
||||
& $UvCmd pip install -e ".[all]" 2>&1 | Out-Null
|
||||
} catch {
|
||||
pip install -e "." | Out-Null
|
||||
& $UvCmd pip install -e "." | Out-Null
|
||||
}
|
||||
|
||||
Write-Success "Main package installed"
|
||||
@@ -301,25 +361,25 @@ function Install-Dependencies {
|
||||
Write-Info "Installing mini-swe-agent (terminal tool backend)..."
|
||||
if (Test-Path "mini-swe-agent\pyproject.toml") {
|
||||
try {
|
||||
pip install -e ".\mini-swe-agent" 2>&1 | Out-Null
|
||||
& $UvCmd pip install -e ".\mini-swe-agent" 2>&1 | Out-Null
|
||||
Write-Success "mini-swe-agent installed"
|
||||
} catch {
|
||||
Write-Warning "mini-swe-agent install failed (terminal tools may not work)"
|
||||
Write-Warn "mini-swe-agent install failed (terminal tools may not work)"
|
||||
}
|
||||
} else {
|
||||
Write-Warning "mini-swe-agent not found (run: git submodule update --init)"
|
||||
Write-Warn "mini-swe-agent not found (run: git submodule update --init)"
|
||||
}
|
||||
|
||||
Write-Info "Installing tinker-atropos (RL training backend)..."
|
||||
if (Test-Path "tinker-atropos\pyproject.toml") {
|
||||
try {
|
||||
pip install -e ".\tinker-atropos" 2>&1 | Out-Null
|
||||
& $UvCmd pip install -e ".\tinker-atropos" 2>&1 | Out-Null
|
||||
Write-Success "tinker-atropos installed"
|
||||
} catch {
|
||||
Write-Warning "tinker-atropos install failed (RL tools may not work)"
|
||||
Write-Warn "tinker-atropos install failed (RL tools may not work)"
|
||||
}
|
||||
} else {
|
||||
Write-Warning "tinker-atropos not found (run: git submodule update --init)"
|
||||
Write-Warn "tinker-atropos not found (run: git submodule update --init)"
|
||||
}
|
||||
|
||||
Pop-Location
|
||||
@@ -328,41 +388,44 @@ function Install-Dependencies {
|
||||
}
|
||||
|
||||
function Set-PathVariable {
|
||||
Write-Info "Setting up PATH..."
|
||||
Write-Info "Setting up hermes command..."
|
||||
|
||||
if ($NoVenv) {
|
||||
$binDir = "$InstallDir"
|
||||
$hermesBin = "$InstallDir"
|
||||
} else {
|
||||
$binDir = "$InstallDir\venv\Scripts"
|
||||
$hermesBin = "$InstallDir\venv\Scripts"
|
||||
}
|
||||
|
||||
# Add to user PATH
|
||||
# Add the venv Scripts dir to user PATH so hermes is globally available
|
||||
# On Windows, the hermes.exe in venv\Scripts\ has the venv Python baked in
|
||||
$currentPath = [Environment]::GetEnvironmentVariable("Path", "User")
|
||||
|
||||
if ($currentPath -notlike "*$binDir*") {
|
||||
if ($currentPath -notlike "*$hermesBin*") {
|
||||
[Environment]::SetEnvironmentVariable(
|
||||
"Path",
|
||||
"$binDir;$currentPath",
|
||||
"$hermesBin;$currentPath",
|
||||
"User"
|
||||
)
|
||||
Write-Success "Added to user PATH"
|
||||
Write-Success "Added to user PATH: $hermesBin"
|
||||
} else {
|
||||
Write-Info "PATH already configured"
|
||||
}
|
||||
|
||||
# Update current session
|
||||
$env:Path = "$binDir;$env:Path"
|
||||
$env:Path = "$hermesBin;$env:Path"
|
||||
|
||||
Write-Success "hermes command ready"
|
||||
}
|
||||
|
||||
function Copy-ConfigTemplates {
|
||||
Write-Info "Setting up configuration files..."
|
||||
|
||||
# Create ~/.hermes directory structure (config at top level, code in subdir)
|
||||
# Create ~/.hermes directory structure
|
||||
New-Item -ItemType Directory -Force -Path "$HermesHome\cron" | Out-Null
|
||||
New-Item -ItemType Directory -Force -Path "$HermesHome\sessions" | Out-Null
|
||||
New-Item -ItemType Directory -Force -Path "$HermesHome\logs" | Out-Null
|
||||
|
||||
# Create .env at ~/.hermes/.env (top level, easy to find)
|
||||
# Create .env
|
||||
$envPath = "$HermesHome\.env"
|
||||
if (-not (Test-Path $envPath)) {
|
||||
$examplePath = "$InstallDir\.env.example"
|
||||
@@ -370,7 +433,6 @@ function Copy-ConfigTemplates {
|
||||
Copy-Item $examplePath $envPath
|
||||
Write-Success "Created ~/.hermes/.env from template"
|
||||
} else {
|
||||
# Create empty .env if no example exists
|
||||
New-Item -ItemType File -Force -Path $envPath | Out-Null
|
||||
Write-Success "Created ~/.hermes/.env"
|
||||
}
|
||||
@@ -378,7 +440,7 @@ function Copy-ConfigTemplates {
|
||||
Write-Info "~/.hermes/.env already exists, keeping it"
|
||||
}
|
||||
|
||||
# Create config.yaml at ~/.hermes/config.yaml (top level, easy to find)
|
||||
# Create config.yaml
|
||||
$configPath = "$HermesHome\config.yaml"
|
||||
if (-not (Test-Path $configPath)) {
|
||||
$examplePath = "$InstallDir\cli-config.yaml.example"
|
||||
@@ -407,7 +469,7 @@ function Install-NodeDeps {
|
||||
npm install --silent 2>&1 | Out-Null
|
||||
Write-Success "Node.js dependencies installed"
|
||||
} catch {
|
||||
Write-Warning "npm install failed (browser tools may not work)"
|
||||
Write-Warn "npm install failed (browser tools may not work)"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -426,12 +488,13 @@ function Invoke-SetupWizard {
|
||||
|
||||
Push-Location $InstallDir
|
||||
|
||||
# Run hermes setup using the venv Python directly (no activation needed)
|
||||
if (-not $NoVenv) {
|
||||
& .\venv\Scripts\Activate.ps1
|
||||
& ".\venv\Scripts\python.exe" -m hermes_cli.main setup
|
||||
} else {
|
||||
python -m hermes_cli.main setup
|
||||
}
|
||||
|
||||
python -m hermes_cli.main setup
|
||||
|
||||
Pop-Location
|
||||
}
|
||||
|
||||
@@ -478,7 +541,6 @@ function Write-Completion {
|
||||
Write-Host "⚡ Restart your terminal for PATH changes to take effect" -ForegroundColor Yellow
|
||||
Write-Host ""
|
||||
|
||||
# Show notes about optional tools
|
||||
if (-not $HasNode) {
|
||||
Write-Host "Note: Node.js was not found. Browser automation tools" -ForegroundColor Yellow
|
||||
Write-Host "will have limited functionality." -ForegroundColor Yellow
|
||||
@@ -500,6 +562,7 @@ function Write-Completion {
|
||||
function Main {
|
||||
Write-Banner
|
||||
|
||||
if (-not (Install-Uv)) { exit 1 }
|
||||
if (-not (Test-Python)) { exit 1 }
|
||||
if (-not (Test-Git)) { exit 1 }
|
||||
Test-Node # Optional, doesn't fail
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# Hermes Agent Installer
|
||||
# ============================================================================
|
||||
# Installation script for Linux and macOS.
|
||||
# Uses uv for fast Python provisioning and package management.
|
||||
#
|
||||
# Usage:
|
||||
# curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash
|
||||
@@ -29,7 +30,7 @@ REPO_URL_SSH="git@github.com:NousResearch/hermes-agent.git"
|
||||
REPO_URL_HTTPS="https://github.com/NousResearch/hermes-agent.git"
|
||||
HERMES_HOME="$HOME/.hermes"
|
||||
INSTALL_DIR="${HERMES_INSTALL_DIR:-$HERMES_HOME/hermes-agent}"
|
||||
PYTHON_MIN_VERSION="3.10"
|
||||
PYTHON_VERSION="3.11"
|
||||
|
||||
# Options
|
||||
USE_VENV=true
|
||||
@@ -64,7 +65,7 @@ while [[ $# -gt 0 ]]; do
|
||||
echo " --no-venv Don't create virtual environment"
|
||||
echo " --skip-setup Skip interactive setup wizard"
|
||||
echo " --branch NAME Git branch to install (default: main)"
|
||||
echo " --dir PATH Installation directory (default: ~/.hermes-agent)"
|
||||
echo " --dir PATH Installation directory (default: ~/.hermes/hermes-agent)"
|
||||
echo " -h, --help Show this help"
|
||||
exit 0
|
||||
;;
|
||||
@@ -146,50 +147,80 @@ detect_os() {
|
||||
# Dependency checks
|
||||
# ============================================================================
|
||||
|
||||
check_python() {
|
||||
log_info "Checking Python..."
|
||||
install_uv() {
|
||||
log_info "Checking for uv package manager..."
|
||||
|
||||
# Try different python commands
|
||||
for cmd in python3.12 python3.11 python3.10 python3 python; do
|
||||
if command -v $cmd &> /dev/null; then
|
||||
PYTHON_CMD=$cmd
|
||||
PYTHON_VERSION=$($cmd -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
|
||||
|
||||
# Check version
|
||||
if python3 -c "import sys; exit(0 if sys.version_info >= (3, 10) else 1)" 2>/dev/null; then
|
||||
log_success "Python $PYTHON_VERSION found"
|
||||
return 0
|
||||
fi
|
||||
# Check common locations for uv
|
||||
if command -v uv &> /dev/null; then
|
||||
UV_CMD="uv"
|
||||
UV_VERSION=$($UV_CMD --version 2>/dev/null)
|
||||
log_success "uv found ($UV_VERSION)"
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Check ~/.local/bin (default uv install location) even if not on PATH yet
|
||||
if [ -x "$HOME/.local/bin/uv" ]; then
|
||||
UV_CMD="$HOME/.local/bin/uv"
|
||||
UV_VERSION=$($UV_CMD --version 2>/dev/null)
|
||||
log_success "uv found at ~/.local/bin ($UV_VERSION)"
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Check ~/.cargo/bin (alternative uv install location)
|
||||
if [ -x "$HOME/.cargo/bin/uv" ]; then
|
||||
UV_CMD="$HOME/.cargo/bin/uv"
|
||||
UV_VERSION=$($UV_CMD --version 2>/dev/null)
|
||||
log_success "uv found at ~/.cargo/bin ($UV_VERSION)"
|
||||
return 0
|
||||
fi
|
||||
|
||||
# Install uv
|
||||
log_info "Installing uv (fast Python package manager)..."
|
||||
if curl -LsSf https://astral.sh/uv/install.sh | sh 2>/dev/null; then
|
||||
# uv installs to ~/.local/bin by default
|
||||
if [ -x "$HOME/.local/bin/uv" ]; then
|
||||
UV_CMD="$HOME/.local/bin/uv"
|
||||
elif [ -x "$HOME/.cargo/bin/uv" ]; then
|
||||
UV_CMD="$HOME/.cargo/bin/uv"
|
||||
elif command -v uv &> /dev/null; then
|
||||
UV_CMD="uv"
|
||||
else
|
||||
log_error "uv installed but not found on PATH"
|
||||
log_info "Try adding ~/.local/bin to your PATH and re-running"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
UV_VERSION=$($UV_CMD --version 2>/dev/null)
|
||||
log_success "uv installed ($UV_VERSION)"
|
||||
else
|
||||
log_error "Failed to install uv"
|
||||
log_info "Install manually: https://docs.astral.sh/uv/getting-started/installation/"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
check_python() {
|
||||
log_info "Checking Python $PYTHON_VERSION..."
|
||||
|
||||
log_error "Python 3.10+ not found"
|
||||
log_info "Please install Python 3.10 or newer:"
|
||||
# Let uv handle Python — it can download and manage Python versions
|
||||
# First check if a suitable Python is already available
|
||||
if $UV_CMD python find "$PYTHON_VERSION" &> /dev/null; then
|
||||
PYTHON_PATH=$($UV_CMD python find "$PYTHON_VERSION")
|
||||
PYTHON_FOUND_VERSION=$($PYTHON_PATH --version 2>/dev/null)
|
||||
log_success "Python found: $PYTHON_FOUND_VERSION"
|
||||
return 0
|
||||
fi
|
||||
|
||||
case "$OS" in
|
||||
linux)
|
||||
case "$DISTRO" in
|
||||
ubuntu|debian)
|
||||
log_info " sudo apt update && sudo apt install python3.11 python3.11-venv"
|
||||
;;
|
||||
fedora)
|
||||
log_info " sudo dnf install python3.11"
|
||||
;;
|
||||
arch)
|
||||
log_info " sudo pacman -S python"
|
||||
;;
|
||||
*)
|
||||
log_info " Use your package manager to install Python 3.10+"
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
macos)
|
||||
log_info " brew install python@3.11"
|
||||
log_info " Or download from https://www.python.org/downloads/"
|
||||
;;
|
||||
esac
|
||||
|
||||
exit 1
|
||||
# Python not found — use uv to install it (no sudo needed!)
|
||||
log_info "Python $PYTHON_VERSION not found, installing via uv..."
|
||||
if $UV_CMD python install "$PYTHON_VERSION"; then
|
||||
PYTHON_PATH=$($UV_CMD python find "$PYTHON_VERSION")
|
||||
PYTHON_FOUND_VERSION=$($PYTHON_PATH --version 2>/dev/null)
|
||||
log_success "Python installed: $PYTHON_FOUND_VERSION"
|
||||
else
|
||||
log_error "Failed to install Python $PYTHON_VERSION"
|
||||
log_info "Install Python $PYTHON_VERSION manually, then re-run this script"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
check_git() {
|
||||
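Review note: the lookup order in `install_uv` above (PATH first, then `~/.local/bin`, then `~/.cargo/bin`) is repeated after the install step and again in `setup-hermes.sh`. A minimal sketch of how it could be factored into one helper follows. The names `first_executable` and `find_uv` are hypothetical and do not exist in the repository.

```shell
# first_executable CANDIDATE...
# Print the first argument that is an executable regular file and return 0.
# Return 1 if none of the candidates qualifies.
first_executable() {
    local candidate
    for candidate in "$@"; do
        if [ -f "$candidate" ] && [ -x "$candidate" ]; then
            echo "$candidate"
            return 0
        fi
    done
    return 1
}

# find_uv (hypothetical): same order as the installer uses.
# 1. uv already on PATH, 2. ~/.local/bin/uv, 3. ~/.cargo/bin/uv
find_uv() {
    if command -v uv > /dev/null 2>&1; then
        echo "uv"
        return 0
    fi
    first_executable "$HOME/.local/bin/uv" "$HOME/.cargo/bin/uv"
}
```

With such a helper, both the pre-install check and the post-install check could reduce to `UV_CMD="$(find_uv)"`, with the error branch handling a non-zero return.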
@@ -294,7 +325,6 @@ check_ripgrep() {
|
||||
# Check if we can use sudo
|
||||
CAN_SUDO=false
|
||||
if command -v sudo &> /dev/null; then
|
||||
# Check if user has sudo access (without actually running sudo)
|
||||
if sudo -n true 2>/dev/null || sudo -v 2>/dev/null; then
|
||||
CAN_SUDO=true
|
||||
fi
|
||||
@@ -328,7 +358,6 @@ check_ripgrep() {
|
||||
esac
|
||||
else
|
||||
log_warn "sudo not available - cannot auto-install system packages"
|
||||
# Try cargo as fallback if available
|
||||
if command -v cargo &> /dev/null; then
|
||||
log_info "Trying cargo install (no sudo required)..."
|
||||
if cargo install ripgrep 2>/dev/null; then
|
||||
@@ -371,7 +400,6 @@ check_ripgrep() {
|
||||
log_info " https://github.com/BurntSushi/ripgrep#installation"
|
||||
;;
|
||||
esac
|
||||
# Show cargo alternative for users without sudo
|
||||
if command -v cargo &> /dev/null; then
|
||||
log_info " Or without sudo: cargo install ripgrep"
|
||||
fi
|
||||
@@ -440,39 +468,36 @@ setup_venv() {
|
||||
return 0
|
||||
fi
|
||||
|
||||
log_info "Creating virtual environment..."
|
||||
log_info "Creating virtual environment with Python $PYTHON_VERSION..."
|
||||
|
||||
if [ -d "venv" ]; then
|
||||
log_info "Virtual environment already exists"
|
||||
else
|
||||
$PYTHON_CMD -m venv venv
|
||||
log_info "Virtual environment already exists, recreating..."
|
||||
rm -rf venv
|
||||
fi
|
||||
|
||||
# Activate
|
||||
source venv/bin/activate
|
||||
# uv creates the venv and pins the Python version in one step
|
||||
$UV_CMD venv venv --python "$PYTHON_VERSION"
|
||||
|
||||
# Upgrade pip
|
||||
pip install --upgrade pip wheel setuptools > /dev/null
|
||||
|
||||
log_success "Virtual environment ready"
|
||||
log_success "Virtual environment ready (Python $PYTHON_VERSION)"
|
||||
}
|
||||
|
||||
install_deps() {
|
||||
log_info "Installing dependencies..."
|
||||
|
||||
if [ "$USE_VENV" = true ]; then
|
||||
source venv/bin/activate
|
||||
# Tell uv to install into our venv (no need to activate)
|
||||
export VIRTUAL_ENV="$INSTALL_DIR/venv"
|
||||
fi
|
||||
|
||||
# Install the main package in editable mode with all extras
|
||||
pip install -e ".[all]" > /dev/null 2>&1 || pip install -e "." > /dev/null
|
||||
$UV_CMD pip install -e ".[all]" || $UV_CMD pip install -e "."
|
||||
|
||||
log_success "Main package installed"
|
||||
|
||||
# Install submodules
|
||||
log_info "Installing mini-swe-agent (terminal tool backend)..."
|
||||
if [ -d "mini-swe-agent" ] && [ -f "mini-swe-agent/pyproject.toml" ]; then
|
||||
pip install -e "./mini-swe-agent" > /dev/null 2>&1 || log_warn "mini-swe-agent install failed (terminal tools may not work)"
|
||||
$UV_CMD pip install -e "./mini-swe-agent" || log_warn "mini-swe-agent install failed (terminal tools may not work)"
|
||||
log_success "mini-swe-agent installed"
|
||||
else
|
||||
log_warn "mini-swe-agent not found (run: git submodule update --init)"
|
||||
@@ -480,7 +505,7 @@ install_deps() {
|
||||
|
||||
log_info "Installing tinker-atropos (RL training backend)..."
|
||||
if [ -d "tinker-atropos" ] && [ -f "tinker-atropos/pyproject.toml" ]; then
|
||||
pip install -e "./tinker-atropos" > /dev/null 2>&1 || log_warn "tinker-atropos install failed (RL tools may not work)"
|
||||
$UV_CMD pip install -e "./tinker-atropos" || log_warn "tinker-atropos install failed (RL tools may not work)"
|
||||
log_success "tinker-atropos installed"
|
||||
else
|
||||
log_warn "tinker-atropos not found (run: git submodule update --init)"
|
||||
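Review note: in `install_deps`, the submodule installs use `cmd || log_warn "..."` followed by an unconditional `log_success`, so a failed install appears to print both the warning and the success line. One way to report exactly one outcome is sketched below. `install_step` is a hypothetical helper, and the two log functions are stand-ins for the installer's own.

```shell
# Stand-ins for the installer's logging helpers (assumed behaviour).
log_success() { echo "OK: $*"; }
log_warn()    { echo "WARN: $*"; }

# install_step LABEL HINT CMD...
# Run CMD and log success or a warning, never both.
install_step() {
    local label="$1" hint="$2"
    shift 2
    if "$@"; then
        log_success "$label installed"
    else
        log_warn "$label install failed ($hint)"
    fi
}

# Intended usage inside install_deps (not executed here):
#   install_step "mini-swe-agent" "terminal tools may not work" \
#       "$UV_CMD" pip install -e "./mini-swe-agent"
```

Because the command runs inside an `if`, a failure does not abort a script running under `set -e`, which matches the non-fatal intent of the original `|| log_warn` form.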
@@ -490,53 +515,56 @@ install_deps() {
|
||||
}
|
||||
|
||||
setup_path() {
|
||||
log_info "Setting up PATH..."
|
||||
log_info "Setting up hermes command..."
|
||||
|
||||
# Determine the bin directory
|
||||
if [ "$USE_VENV" = true ]; then
|
||||
BIN_DIR="$INSTALL_DIR/venv/bin"
|
||||
HERMES_BIN="$INSTALL_DIR/venv/bin/hermes"
|
||||
else
|
||||
BIN_DIR="$HOME/.local/bin"
|
||||
mkdir -p "$BIN_DIR"
|
||||
HERMES_BIN="$(which hermes 2>/dev/null || echo "")"
|
||||
if [ -z "$HERMES_BIN" ]; then
|
||||
log_warn "hermes not found on PATH after install"
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
|
||||
# Create symlink in ~/.local/bin (standard user binary location, usually on PATH)
|
||||
mkdir -p "$HOME/.local/bin"
|
||||
ln -sf "$HERMES_BIN" "$HOME/.local/bin/hermes"
|
||||
log_success "Symlinked hermes → ~/.local/bin/hermes"
|
||||
|
||||
# Check if ~/.local/bin is on PATH; if not, add it to shell config
|
||||
if ! echo "$PATH" | tr ':' '\n' | grep -q "^$HOME/.local/bin$"; then
|
||||
SHELL_CONFIG=""
|
||||
if [ -n "$BASH_VERSION" ]; then
|
||||
if [ -f "$HOME/.bashrc" ]; then
|
||||
SHELL_CONFIG="$HOME/.bashrc"
|
||||
elif [ -f "$HOME/.bash_profile" ]; then
|
||||
SHELL_CONFIG="$HOME/.bash_profile"
|
||||
fi
|
||||
elif [ -n "$ZSH_VERSION" ] || [ -f "$HOME/.zshrc" ]; then
|
||||
SHELL_CONFIG="$HOME/.zshrc"
|
||||
fi
|
||||
|
||||
# Create a wrapper script
|
||||
cat > "$BIN_DIR/hermes" << EOF
|
||||
#!/bin/bash
|
||||
cd "$INSTALL_DIR"
|
||||
exec python -m hermes_cli.main "\$@"
|
||||
EOF
|
||||
chmod +x "$BIN_DIR/hermes"
|
||||
fi
|
||||
|
||||
# Add to PATH in shell config
|
||||
SHELL_CONFIG=""
|
||||
if [ -n "$BASH_VERSION" ]; then
|
||||
if [ -f "$HOME/.bashrc" ]; then
|
||||
SHELL_CONFIG="$HOME/.bashrc"
|
||||
elif [ -f "$HOME/.bash_profile" ]; then
|
||||
SHELL_CONFIG="$HOME/.bash_profile"
|
||||
PATH_LINE='export PATH="$HOME/.local/bin:$PATH"'
|
||||
|
||||
if [ -n "$SHELL_CONFIG" ]; then
|
||||
if ! grep -q '\.local/bin' "$SHELL_CONFIG" 2>/dev/null; then
|
||||
echo "" >> "$SHELL_CONFIG"
|
||||
echo "# Hermes Agent — ensure ~/.local/bin is on PATH" >> "$SHELL_CONFIG"
|
||||
echo "$PATH_LINE" >> "$SHELL_CONFIG"
|
||||
log_success "Added ~/.local/bin to PATH in $SHELL_CONFIG"
|
||||
else
|
||||
log_info "~/.local/bin already referenced in $SHELL_CONFIG"
|
||||
fi
|
||||
fi
|
||||
elif [ -n "$ZSH_VERSION" ] || [ -f "$HOME/.zshrc" ]; then
|
||||
SHELL_CONFIG="$HOME/.zshrc"
|
||||
else
|
||||
log_info "~/.local/bin already on PATH"
|
||||
fi
|
||||
|
||||
PATH_LINE="export PATH=\"$BIN_DIR:\$PATH\""
|
||||
# Export for current session so hermes works immediately
|
||||
export PATH="$HOME/.local/bin:$PATH"
|
||||
|
||||
if [ -n "$SHELL_CONFIG" ]; then
|
||||
if ! grep -q "hermes-agent" "$SHELL_CONFIG" 2>/dev/null; then
|
||||
echo "" >> "$SHELL_CONFIG"
|
||||
echo "# Hermes Agent" >> "$SHELL_CONFIG"
|
||||
echo "$PATH_LINE" >> "$SHELL_CONFIG"
|
||||
log_success "Added to $SHELL_CONFIG"
|
||||
else
|
||||
log_info "PATH already configured in $SHELL_CONFIG"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Also export for current session
|
||||
export PATH="$BIN_DIR:$PATH"
|
||||
|
||||
log_success "PATH configured"
|
||||
log_success "hermes command ready"
|
||||
}
|
||||
|
||||
copy_config_templates() {
|
||||
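Review note: `setup_path` tests PATH membership with `echo "$PATH" | tr ':' '\n' | grep -q "^$HOME/.local/bin$"`. The dots in that pattern are regex wildcards, so it could match a slightly different directory name. A sketch of a literal, fork-free alternative follows. `path_contains` is a hypothetical name, not part of the script.

```shell
# path_contains DIR
# Succeed only if DIR is an exact, whole entry of $PATH.
# Wrapping both sides in ':' means entries at the start or end of PATH
# match the same way as entries in the middle.
path_contains() {
    case ":$PATH:" in
        *":$1:"*) return 0 ;;
        *)        return 1 ;;
    esac
}

# Intended usage (not executed here):
#   if ! path_contains "$HOME/.local/bin"; then
#       ... append the export line to the shell config ...
#   fi
```

The quoted `"$1"` inside the `case` pattern is matched literally, so characters such as `.` or `*` in a directory name are not treated as wildcards.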
@@ -553,7 +581,6 @@ copy_config_templates() {
|
||||
cp "$INSTALL_DIR/.env.example" "$HERMES_HOME/.env"
|
||||
log_success "Created ~/.hermes/.env from template"
|
||||
else
|
||||
# Create empty .env if no example exists
|
||||
touch "$HERMES_HOME/.env"
|
||||
log_success "Created ~/.hermes/.env"
|
||||
fi
|
||||
@@ -601,12 +628,14 @@ run_setup_wizard() {
|
||||
log_info "Starting setup wizard..."
|
||||
echo ""
|
||||
|
||||
if [ "$USE_VENV" = true ]; then
|
||||
source "$INSTALL_DIR/venv/bin/activate"
|
||||
fi
|
||||
|
||||
cd "$INSTALL_DIR"
|
||||
python -m hermes_cli.main setup
|
||||
|
||||
# Run hermes setup using the venv Python directly (no activation needed)
|
||||
if [ "$USE_VENV" = true ]; then
|
||||
"$INSTALL_DIR/venv/bin/python" -m hermes_cli.main setup
|
||||
else
|
||||
python -m hermes_cli.main setup
|
||||
fi
|
||||
}
|
||||
|
||||
print_success() {
|
||||
@@ -673,6 +702,7 @@ main() {
|
||||
print_banner
|
||||
|
||||
detect_os
|
||||
install_uv
|
||||
check_python
|
||||
check_git
|
||||
check_node
|
||||
|
||||
62
scripts/launch_llama_cpp_glm47_flash.sh
Executable file
@@ -0,0 +1,62 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Launch a local llama.cpp OpenAI-compatible server running GLM-4.7-Flash (GGUF).
|
||||
#
|
||||
# Requires:
|
||||
# - `llama-server` installed (e.g. `brew install llama.cpp`)
|
||||
#
|
||||
# Note: the default port (8080) is the same one the Atropos sandbox_server
|
||||
# commonly uses in local dev; set LLAMA_CPP_PORT to run both side by side.
|
||||
#
|
||||
# Usage:
|
||||
# Hermes-Agent/scripts/launch_llama_cpp_glm47_flash.sh
|
||||
#
|
||||
# Override defaults:
|
||||
# LLAMA_CPP_HOST=127.0.0.1 LLAMA_CPP_PORT=8082 \
|
||||
# LLAMA_CPP_HF_REPO=ggml-org/GLM-4.7-Flash-GGUF \
|
||||
# LLAMA_CPP_HF_FILE=GLM-4.7-Flash-Q4_K.gguf \
|
||||
# Hermes-Agent/scripts/launch_llama_cpp_glm47_flash.sh
|
||||
|
||||
HOST="${LLAMA_CPP_HOST:-127.0.0.1}"
|
||||
PORT="${LLAMA_CPP_PORT:-8080}"
|
||||
HF_REPO="${LLAMA_CPP_HF_REPO:-ggml-org/GLM-4.7-Flash-GGUF}"
|
||||
HF_FILE="${LLAMA_CPP_HF_FILE:-GLM-4.7-Flash-Q4_K.gguf}"
|
||||
ALIAS="${LLAMA_CPP_ALIAS:-glm-4.7-flash}"
|
||||
|
||||
if ! command -v llama-server >/dev/null 2>&1; then
|
||||
echo "Error: llama-server not found in PATH."
|
||||
echo "Install via Homebrew: brew install llama.cpp"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Launching llama.cpp server..."
|
||||
echo " host: $HOST"
|
||||
echo " port: $PORT"
|
||||
echo " repo: $HF_REPO"
|
||||
echo " file: $HF_FILE"
|
||||
echo " alias: $ALIAS"
|
||||
echo
|
||||
echo "Suggested env vars for Hermes/Atropos integration:"
|
||||
echo " export ATROPOS_SERVER_BASE_URL=http://${HOST}:${PORT}"
|
||||
echo " export ATROPOS_SERVER_MODEL=${ALIAS}"
|
||||
echo " export ATROPOS_SERVER_API_KEY=local"
|
||||
echo
|
||||
|
||||
if command -v lsof >/dev/null 2>&1; then
|
||||
if lsof -nP -iTCP:"$PORT" -sTCP:LISTEN >/dev/null 2>&1; then
|
||||
echo "Error: port $PORT is already in use."
|
||||
echo "Pick a different port, e.g.:"
|
||||
echo " LLAMA_CPP_PORT=8082 Hermes-Agent/scripts/launch_llama_cpp_glm47_flash.sh"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
exec llama-server \
|
||||
--host "$HOST" \
|
||||
--port "$PORT" \
|
||||
--hf-repo "$HF_REPO" \
|
||||
--hf-file "$HF_FILE" \
|
||||
--alias "$ALIAS" \
|
||||
-c 32768 \
|
||||
-n -1
|
||||
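Review note: once the server is up, a quick smoke test against its OpenAI-compatible chat endpoint can confirm that the alias and port match the suggested `ATROPOS_SERVER_*` values. The sketch below only builds the request body. `chat_payload` is a hypothetical helper, and it assumes the message contains no double quotes or backslashes (it does no JSON escaping).

```shell
# chat_payload MODEL MESSAGE
# Print a minimal chat-completions request body as JSON.
# Assumption: neither argument needs JSON escaping.
chat_payload() {
    printf '{"model":"%s","messages":[{"role":"user","content":"%s"}]}' "$1" "$2"
}

# Intended usage against a running server (not executed here):
#   curl -s "http://${HOST}:${PORT}/v1/chat/completions" \
#       -H "Content-Type: application/json" \
#       -d "$(chat_payload "$ALIAS" "Say hello")"
```

A JSON response that includes a `choices` array suggests the model loaded and the endpoint is reachable. For anything beyond a smoke test, a real JSON encoder would be safer than `printf`.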
70
scripts/launch_llama_cpp_hermes_4_36b.sh
Executable file
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Launch a local llama.cpp OpenAI-compatible server running Hermes 4.3 36B (GGUF).
|
||||
#
|
||||
# Requires:
|
||||
# - `llama-server` installed (e.g. `brew install llama.cpp`)
|
||||
#
|
||||
# Note: Port choice can conflict with other local dev servers. If 8080 is already
|
||||
# in use, override via `LLAMA_CPP_PORT=...`.
|
||||
#
|
||||
# Usage:
|
||||
# Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh
|
||||
#
|
||||
# Override defaults:
|
||||
# LLAMA_CPP_HOST=127.0.0.1 LLAMA_CPP_PORT=8082 \
|
||||
# LLAMA_CPP_HF_REPO=NousResearch/Hermes-4.3-36B-GGUF \
|
||||
# LLAMA_CPP_HF_FILE=hermes-4_3_36b-Q4_K_M.gguf \
|
||||
# LLAMA_CPP_ALIAS=hermes-4-36b \
|
||||
# LLAMA_CPP_PARALLEL=4 LLAMA_CPP_THREADS_HTTP=4 \
|
||||
# Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh
|
||||
|
||||
HOST="${LLAMA_CPP_HOST:-127.0.0.1}"
|
||||
PORT="${LLAMA_CPP_PORT:-8080}"
|
||||
HF_REPO="${LLAMA_CPP_HF_REPO:-NousResearch/Hermes-4.3-36B-GGUF}"
|
||||
HF_FILE="${LLAMA_CPP_HF_FILE:-hermes-4_3_36b-Q4_K_M.gguf}"
|
||||
ALIAS="${LLAMA_CPP_ALIAS:-hermes-4-36b}"
|
||||
PARALLEL="${LLAMA_CPP_PARALLEL:-4}"
|
||||
THREADS_HTTP="${LLAMA_CPP_THREADS_HTTP:-4}"
|
||||
|
||||
if ! command -v llama-server >/dev/null 2>&1; then
|
||||
echo "Error: llama-server not found in PATH."
|
||||
echo "Install via Homebrew: brew install llama.cpp"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Launching llama.cpp server..."
|
||||
echo " host: $HOST"
|
||||
echo " port: $PORT"
|
||||
echo " repo: $HF_REPO"
|
||||
echo " file: $HF_FILE"
|
||||
echo " alias: $ALIAS"
|
||||
echo " slots: $PARALLEL"
|
||||
echo
|
||||
echo "Suggested env vars for Hermes/Atropos integration:"
|
||||
echo " export ATROPOS_SERVER_BASE_URL=http://${HOST}:${PORT}"
|
||||
echo " export ATROPOS_SERVER_MODEL=${ALIAS}"
|
||||
echo " export ATROPOS_TOKENIZER_NAME=NousResearch/Hermes-4.3-36B"
|
||||
echo " export ATROPOS_SERVER_API_KEY=local"
|
||||
echo
|
||||
|
||||
if command -v lsof >/dev/null 2>&1; then
|
||||
if lsof -nP -iTCP:"$PORT" -sTCP:LISTEN >/dev/null 2>&1; then
|
||||
echo "Error: port $PORT is already in use."
|
||||
echo "Pick a different port, e.g.:"
|
||||
echo " LLAMA_CPP_PORT=8082 Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
exec llama-server \
|
||||
--host "$HOST" \
|
||||
--port "$PORT" \
|
||||
--hf-repo "$HF_REPO" \
|
||||
--hf-file "$HF_FILE" \
|
||||
--alias "$ALIAS" \
|
||||
--parallel "$PARALLEL" \
|
||||
--threads-http "$THREADS_HTTP" \
|
||||
-c 32768 \
|
||||
-n -1
|
||||
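Review note: this script combines `-c 32768` with `--parallel 4`. Assuming the server divides the total context evenly across slots (an assumption; the behaviour depends on the llama.cpp version and KV-cache settings), each concurrent request would get a quarter of the window. The arithmetic is sketched below with a hypothetical helper, `ctx_per_slot`.

```shell
# ctx_per_slot TOTAL_CTX SLOTS
# Print the per-slot context length under the even-split assumption.
ctx_per_slot() {
    echo $(( $1 / $2 ))
}

# With the script's defaults: 32768 tokens shared by 4 slots.
#   ctx_per_slot 32768 4
```

If long agent trajectories are expected, it may be worth lowering `LLAMA_CPP_PARALLEL` or raising `-c`, and checking the per-slot value the server reports at startup.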
152
setup-hermes.sh
@@ -3,16 +3,18 @@
|
||||
# Hermes Agent Setup Script
|
||||
# ============================================================================
|
||||
# Quick setup for developers who cloned the repo manually.
|
||||
# Uses uv for fast Python provisioning and package management.
|
||||
#
|
||||
# Usage:
|
||||
# ./setup-hermes.sh
|
||||
#
|
||||
# This script:
|
||||
# 1. Creates a virtual environment (if not exists)
|
||||
# 2. Installs dependencies
|
||||
# 3. Creates .env from template (if not exists)
|
||||
# 4. Installs the 'hermes' CLI command
|
||||
# 5. Runs the setup wizard (optional)
|
||||
# 1. Installs uv if not present
|
||||
# 2. Creates a virtual environment with Python 3.11 via uv
|
||||
# 3. Installs all dependencies (main package + submodules)
|
||||
# 4. Creates .env from template (if not exists)
|
||||
# 5. Symlinks the 'hermes' CLI command into ~/.local/bin
|
||||
# 6. Runs the setup wizard (optional)
|
||||
# ============================================================================
|
||||
|
||||
set -e
|
||||
@@ -21,38 +23,75 @@ set -e
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[0;33m'
|
||||
CYAN='\033[0;36m'
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m'
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
PYTHON_VERSION="3.11"
|
||||
|
||||
echo ""
|
||||
echo -e "${CYAN}🦋 Hermes Agent Setup${NC}"
|
||||
echo ""
|
||||
|
||||
# ============================================================================
|
||||
# Python check
|
||||
# Install / locate uv
|
||||
# ============================================================================
|
||||
|
||||
echo -e "${CYAN}→${NC} Checking Python..."
|
||||
echo -e "${CYAN}→${NC} Checking for uv..."
|
||||
|
||||
PYTHON_CMD=""
|
||||
for cmd in python3.12 python3.11 python3.10 python3 python; do
|
||||
if command -v $cmd &> /dev/null; then
|
||||
if $cmd -c "import sys; exit(0 if sys.version_info >= (3, 10) else 1)" 2>/dev/null; then
|
||||
PYTHON_CMD=$cmd
|
||||
break
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
if [ -z "$PYTHON_CMD" ]; then
|
||||
echo -e "${YELLOW}✗${NC} Python 3.10+ required"
|
||||
exit 1
|
||||
UV_CMD=""
|
||||
if command -v uv &> /dev/null; then
|
||||
UV_CMD="uv"
|
||||
elif [ -x "$HOME/.local/bin/uv" ]; then
|
||||
UV_CMD="$HOME/.local/bin/uv"
|
||||
elif [ -x "$HOME/.cargo/bin/uv" ]; then
|
||||
UV_CMD="$HOME/.cargo/bin/uv"
|
||||
fi
|
||||
|
||||
PYTHON_VERSION=$($PYTHON_CMD -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
|
||||
echo -e "${GREEN}✓${NC} Python $PYTHON_VERSION found"
|
||||
if [ -n "$UV_CMD" ]; then
|
||||
UV_VERSION=$($UV_CMD --version 2>/dev/null)
|
||||
echo -e "${GREEN}✓${NC} uv found ($UV_VERSION)"
|
||||
else
|
||||
echo -e "${CYAN}→${NC} Installing uv..."
|
||||
if curl -LsSf https://astral.sh/uv/install.sh | sh 2>/dev/null; then
|
||||
if [ -x "$HOME/.local/bin/uv" ]; then
|
||||
UV_CMD="$HOME/.local/bin/uv"
|
||||
elif [ -x "$HOME/.cargo/bin/uv" ]; then
|
||||
UV_CMD="$HOME/.cargo/bin/uv"
|
||||
fi
|
||||
|
||||
if [ -n "$UV_CMD" ]; then
|
||||
UV_VERSION=$($UV_CMD --version 2>/dev/null)
|
||||
echo -e "${GREEN}✓${NC} uv installed ($UV_VERSION)"
|
||||
else
|
||||
echo -e "${RED}✗${NC} uv installed but not found. Add ~/.local/bin to PATH and retry."
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
echo -e "${RED}✗${NC} Failed to install uv. Visit https://docs.astral.sh/uv/"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
# ============================================================================
|
||||
# Python check (uv can provision it automatically)
|
||||
# ============================================================================
|
||||
|
||||
echo -e "${CYAN}→${NC} Checking Python $PYTHON_VERSION..."
|
||||
|
||||
if $UV_CMD python find "$PYTHON_VERSION" &> /dev/null; then
|
||||
PYTHON_PATH=$($UV_CMD python find "$PYTHON_VERSION")
|
||||
PYTHON_FOUND_VERSION=$($PYTHON_PATH --version 2>/dev/null)
|
||||
echo -e "${GREEN}✓${NC} $PYTHON_FOUND_VERSION found"
|
||||
else
|
||||
echo -e "${CYAN}→${NC} Python $PYTHON_VERSION not found, installing via uv..."
|
||||
$UV_CMD python install "$PYTHON_VERSION"
|
||||
PYTHON_PATH=$($UV_CMD python find "$PYTHON_VERSION")
|
||||
PYTHON_FOUND_VERSION=$($PYTHON_PATH --version 2>/dev/null)
|
||||
echo -e "${GREEN}✓${NC} $PYTHON_FOUND_VERSION installed"
|
||||
fi
|
||||
|
||||
# ============================================================================
|
||||
# Virtual environment
|
||||
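The hunk above looks for uv in the same three places twice: once before attempting an install and once after. A minimal sketch of how that lookup could be factored into a single helper follows. The function name `locate_uv` is hypothetical and does not appear in the script; the search order (PATH, then `~/.local/bin`, then `~/.cargo/bin`) is taken from the diff.

```shell
# Hypothetical helper (not part of the script in this diff).
# Prints the first usable uv binary and returns 0, or returns 1 if none is found.
locate_uv() {
    # Prefer whatever is already on PATH.
    if command -v uv >/dev/null 2>&1; then
        command -v uv
        return 0
    fi
    # Fall back to the two install locations the script checks.
    for candidate in "$HOME/.local/bin/uv" "$HOME/.cargo/bin/uv"; do
        if [ -x "$candidate" ]; then
            printf '%s\n' "$candidate"
            return 0
        fi
    done
    return 1
}
```

With a helper like this, the script could call `UV_CMD=$(locate_uv)` before the install attempt and again afterwards, rather than repeating the `elif` chain.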
@@ -60,15 +99,16 @@ echo -e "${GREEN}✓${NC} Python $PYTHON_VERSION found"
|
||||
|
||||
echo -e "${CYAN}→${NC} Setting up virtual environment..."
|
||||
|
||||
if [ ! -d "venv" ]; then
|
||||
$PYTHON_CMD -m venv venv
|
||||
echo -e "${GREEN}✓${NC} Created venv"
|
||||
else
|
||||
echo -e "${GREEN}✓${NC} venv exists"
|
||||
if [ -d "venv" ]; then
|
||||
echo -e "${CYAN}→${NC} Removing old venv..."
|
||||
rm -rf venv
|
||||
fi
|
||||
|
||||
source venv/bin/activate
|
||||
pip install --upgrade pip wheel setuptools > /dev/null
|
||||
$UV_CMD venv venv --python "$PYTHON_VERSION"
|
||||
echo -e "${GREEN}✓${NC} venv created (Python $PYTHON_VERSION)"
|
||||
|
||||
# Tell uv to install into this venv (no activation needed for uv)
|
||||
export VIRTUAL_ENV="$SCRIPT_DIR/venv"
|
||||
|
||||
# ============================================================================
|
||||
# Dependencies
|
||||
@@ -76,10 +116,34 @@ pip install --upgrade pip wheel setuptools > /dev/null
|
||||
|
||||
echo -e "${CYAN}→${NC} Installing dependencies..."
|
||||
|
||||
pip install -e ".[all]" > /dev/null 2>&1 || pip install -e "." > /dev/null
|
||||
$UV_CMD pip install -e ".[all]" || $UV_CMD pip install -e "."
|
||||
|
||||
echo -e "${GREEN}✓${NC} Dependencies installed"
|
||||
|
||||
# ============================================================================
|
||||
# Submodules (terminal backend + RL training)
|
||||
# ============================================================================
|
||||
|
||||
echo -e "${CYAN}→${NC} Installing submodules..."
|
||||
|
||||
# mini-swe-agent (terminal tool backend)
|
||||
if [ -d "mini-swe-agent" ] && [ -f "mini-swe-agent/pyproject.toml" ]; then
|
||||
$UV_CMD pip install -e "./mini-swe-agent" && \
|
||||
echo -e "${GREEN}✓${NC} mini-swe-agent installed" || \
|
||||
echo -e "${YELLOW}⚠${NC} mini-swe-agent install failed (terminal tools may not work)"
|
||||
else
|
||||
echo -e "${YELLOW}⚠${NC} mini-swe-agent not found (run: git submodule update --init --recursive)"
|
||||
fi
|
||||
|
||||
# tinker-atropos (RL training backend)
|
||||
if [ -d "tinker-atropos" ] && [ -f "tinker-atropos/pyproject.toml" ]; then
|
||||
$UV_CMD pip install -e "./tinker-atropos" && \
|
||||
echo -e "${GREEN}✓${NC} tinker-atropos installed" || \
|
||||
echo -e "${YELLOW}⚠${NC} tinker-atropos install failed (RL tools may not work)"
|
||||
else
|
||||
echo -e "${YELLOW}⚠${NC} tinker-atropos not found (run: git submodule update --init --recursive)"
|
||||
fi
|
||||
|
||||
# ============================================================================
|
||||
# Optional: ripgrep (for faster file search)
|
||||
# ============================================================================
|
||||
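The two submodule blocks above use the `cmd && echo ok || echo warn` idiom. One known wrinkle of that idiom is that the warning branch also runs if the success `echo` itself fails. A minimal sketch of the same logic with an explicit `if`/`elif`/`else` is shown below. The function name `install_optional` and its argument order are assumptions made for illustration, and the messages are simplified.

```shell
# Hypothetical helper (not part of the script in this diff).
# Usage: install_optional NAME DIR COMMAND [ARGS...]
# Runs COMMAND only when DIR contains a pyproject.toml, and never aborts the caller.
install_optional() {
    name=$1
    dir=$2
    shift 2
    if [ ! -f "$dir/pyproject.toml" ]; then
        echo "skip: $name not found"
    elif "$@"; then
        echo "ok: $name installed"
    else
        echo "warn: $name install failed"
    fi
}
```

A call shaped like the diff's own commands would be `install_optional mini-swe-agent mini-swe-agent "$UV_CMD" pip install -e ./mini-swe-agent`. Because the function always returns 0, it stays safe under `set -e`, matching the script's intent that a missing submodule is only a warning.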
@@ -141,14 +205,17 @@ else
|
||||
fi
|
||||
|
||||
# ============================================================================
|
||||
# PATH setup
|
||||
# PATH setup — symlink hermes into ~/.local/bin
|
||||
# ============================================================================
|
||||
|
||||
echo -e "${CYAN}→${NC} Setting up hermes command..."
|
||||
|
||||
BIN_DIR="$SCRIPT_DIR/venv/bin"
|
||||
HERMES_BIN="$SCRIPT_DIR/venv/bin/hermes"
|
||||
mkdir -p "$HOME/.local/bin"
|
||||
ln -sf "$HERMES_BIN" "$HOME/.local/bin/hermes"
|
||||
echo -e "${GREEN}✓${NC} Symlinked hermes → ~/.local/bin/hermes"
|
||||
|
||||
# Add to shell config if not already there
|
||||
# Ensure ~/.local/bin is on PATH in shell config
|
||||
SHELL_CONFIG=""
|
||||
if [ -f "$HOME/.zshrc" ]; then
|
||||
SHELL_CONFIG="$HOME/.zshrc"
|
||||
@@ -159,13 +226,17 @@ elif [ -f "$HOME/.bash_profile" ]; then
|
||||
fi
|
||||
|
||||
if [ -n "$SHELL_CONFIG" ]; then
|
||||
if ! grep -q "hermes-agent" "$SHELL_CONFIG" 2>/dev/null; then
|
||||
echo "" >> "$SHELL_CONFIG"
|
||||
echo "# Hermes Agent" >> "$SHELL_CONFIG"
|
||||
echo "export PATH=\"$BIN_DIR:\$PATH\"" >> "$SHELL_CONFIG"
|
||||
echo -e "${GREEN}✓${NC} Added to $SHELL_CONFIG"
|
||||
if ! echo "$PATH" | tr ':' '\n' | grep -q "^$HOME/.local/bin$"; then
|
||||
if ! grep -q '\.local/bin' "$SHELL_CONFIG" 2>/dev/null; then
|
||||
echo "" >> "$SHELL_CONFIG"
|
||||
echo "# Hermes Agent — ensure ~/.local/bin is on PATH" >> "$SHELL_CONFIG"
|
||||
echo 'export PATH="$HOME/.local/bin:$PATH"' >> "$SHELL_CONFIG"
|
||||
echo -e "${GREEN}✓${NC} Added ~/.local/bin to PATH in $SHELL_CONFIG"
|
||||
else
|
||||
echo -e "${GREEN}✓${NC} ~/.local/bin already in $SHELL_CONFIG"
|
||||
fi
|
||||
else
|
||||
echo -e "${GREEN}✓${NC} PATH already in $SHELL_CONFIG"
|
||||
echo -e "${GREEN}✓${NC} ~/.local/bin already on PATH"
|
||||
fi
|
||||
fi
|
||||
|
||||
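The PATH hunk above tests membership with `tr ':' '\n' | grep -q "^$HOME/.local/bin$"`. Since `grep` treats that pattern as a regular expression, the `.` in `.local` matches any character. A minimal sketch of a literal-match alternative in plain shell follows, together with an append that only writes a line when it is missing. Both function names are hypothetical. Note that `append_once` matches the whole line exactly, which is stricter than the substring check the script uses.

```shell
# Hypothetical helpers (not part of the script in this diff).

# Return 0 if directory $1 is an entry of the colon-separated list $2 (default: $PATH).
path_contains() {
    dir=$1
    path_value=${2-$PATH}
    # Wrapping both sides in colons makes every entry matchable as ":entry:".
    case ":$path_value:" in
        *":$dir:"*) return 0 ;;
        *) return 1 ;;
    esac
}

# Append line $2 to file $1 unless an identical line is already present.
append_once() {
    file=$1
    line=$2
    if ! grep -qxF -- "$line" "$file" 2>/dev/null; then
        printf '%s\n' "$line" >> "$file"
    fi
}
```

For example, `path_contains "$HOME/.local/bin" || append_once "$SHELL_CONFIG" 'export PATH="$HOME/.local/bin:$PATH"'` would keep repeated runs of the setup script from adding duplicate lines.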
@@ -199,5 +270,6 @@ read -p "Would you like to run the setup wizard now? [Y/n] " -n 1 -r
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then
|
||||
echo ""
|
||||
python -m hermes_cli.main setup
|
||||
# Run directly with venv Python (no activation needed)
|
||||
"$SCRIPT_DIR/venv/bin/python" -m hermes_cli.main setup
|
||||
fi
|
||||
|
||||