Compare commits
50 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 24c13bc412 | |||
| 06e9422324 | |||
| 907616a692 | |||
| 33a00d9b8e | |||
| a2312076da | |||
| 499490d06a | |||
| 35b2250b36 | |||
| 395392e5de | |||
| 2041b354a9 | |||
| 3951eab399 | |||
| 62001e3bf5 | |||
| c8b30e9efa | |||
| f82c3081f2 | |||
| a69924631c | |||
| 4619d1c8ef | |||
| 98d945f6de | |||
| 507b77c4ac | |||
| b99c2a2644 | |||
| 975c849308 | |||
| 9dc27880cd | |||
| 3b9c53e6db | |||
| 05dd31131f | |||
| 36ea883d45 | |||
| 6be8cdeeca | |||
| 0bc914b00c | |||
| 411e7f8ff4 | |||
| eb2e6b73fe | |||
| 664acf7426 | |||
| fd1c3da305 | |||
| 4d619bcd21 | |||
| beac2ee06a | |||
| 487487406d | |||
| 87464821d8 | |||
| 661d8f4d6c | |||
| bf13a848ef | |||
| 88286f6da3 | |||
| 5b82190460 | |||
| ea7aa0b0d4 | |||
| 7130fa50cb | |||
| 5a9c98a771 | |||
| 6cb4fe948a | |||
| 30221d8c20 | |||
| b5b1fef20a | |||
| 16fb41f9cc | |||
| 4939130485 | |||
| 8dccd6569e | |||
| db348dc467 | |||
| 88722e230d | |||
| 68fb0efe0e | |||
| e38c274f8d |
+115
@@ -0,0 +1,115 @@
|
||||
# Cline's Memory Bank
|
||||
|
||||
I am Cline, an expert software engineer with a unique characteristic: my memory resets completely between sessions. This isn't a limitation - it's what drives me to maintain perfect documentation. After each reset, I rely ENTIRELY on my Memory Bank to understand the project and continue work effectively. I MUST read ALL memory bank files at the start of EVERY task - this is not optional.
|
||||
|
||||
## Memory Bank Structure
|
||||
|
||||
The Memory Bank consists of core files and optional context files, all in Markdown format. Files build upon each other in a clear hierarchy:
|
||||
|
||||
flowchart TD
|
||||
PB[projectbrief.md] --> PC[productContext.md]
|
||||
PB --> SP[systemPatterns.md]
|
||||
PB --> TC[techContext.md]
|
||||
|
||||
PC --> AC[activeContext.md]
|
||||
SP --> AC
|
||||
TC --> AC
|
||||
|
||||
AC --> P[progress.md]
|
||||
|
||||
### Core Files (Required)
|
||||
1. `projectbrief.md`
|
||||
- Foundation document that shapes all other files
|
||||
- Created at project start if it doesn't exist
|
||||
- Defines core requirements and goals
|
||||
- Source of truth for project scope
|
||||
|
||||
2. `productContext.md`
|
||||
- Why this project exists
|
||||
- Problems it solves
|
||||
- How it should work
|
||||
- User experience goals
|
||||
|
||||
3. `activeContext.md`
|
||||
- Current work focus
|
||||
- Recent changes
|
||||
- Next steps
|
||||
- Active decisions and considerations
|
||||
- Important patterns and preferences
|
||||
- Learnings and project insights
|
||||
|
||||
4. `systemPatterns.md`
|
||||
- System architecture
|
||||
- Key technical decisions
|
||||
- Design patterns in use
|
||||
- Component relationships
|
||||
- Critical implementation paths
|
||||
|
||||
5. `techContext.md`
|
||||
- Technologies used
|
||||
- Development setup
|
||||
- Technical constraints
|
||||
- Dependencies
|
||||
- Tool usage patterns
|
||||
|
||||
6. `progress.md`
|
||||
- What works
|
||||
- What's left to build
|
||||
- Current status
|
||||
- Known issues
|
||||
- Evolution of project decisions
|
||||
|
||||
### Additional Context
|
||||
Create additional files/folders within memory-bank/ when they help organize:
|
||||
- Complex feature documentation
|
||||
- Integration specifications
|
||||
- API documentation
|
||||
- Testing strategies
|
||||
- Deployment procedures
|
||||
|
||||
## Core Workflows
|
||||
|
||||
### Plan Mode
|
||||
flowchart TD
|
||||
Start[Start] --> ReadFiles[Read Memory Bank]
|
||||
ReadFiles --> CheckFiles{Files Complete?}
|
||||
|
||||
CheckFiles -->|No| Plan[Create Plan]
|
||||
Plan --> Document[Document in Chat]
|
||||
|
||||
CheckFiles -->|Yes| Verify[Verify Context]
|
||||
Verify --> Strategy[Develop Strategy]
|
||||
Strategy --> Present[Present Approach]
|
||||
|
||||
### Act Mode
|
||||
flowchart TD
|
||||
Start[Start] --> Context[Check Memory Bank]
|
||||
Context --> Update[Update Documentation]
|
||||
Update --> Execute[Execute Task]
|
||||
Execute --> Document[Document Changes]
|
||||
|
||||
## Documentation Updates
|
||||
|
||||
Memory Bank updates occur when:
|
||||
1. Discovering new project patterns
|
||||
2. After implementing significant changes
|
||||
3. When user requests with **update memory bank** (MUST review ALL files)
|
||||
4. When context needs clarification
|
||||
|
||||
flowchart TD
|
||||
Start[Update Process]
|
||||
|
||||
subgraph Process
|
||||
P1[Review ALL Files]
|
||||
P2[Document Current State]
|
||||
P3[Clarify Next Steps]
|
||||
P4[Document Insights & Patterns]
|
||||
|
||||
P1 --> P2 --> P3 --> P4
|
||||
end
|
||||
|
||||
Start --> Process
|
||||
|
||||
Note: When triggered by **update memory bank**, I MUST review every memory bank file, even if some don't require updates. Focus particularly on activeContext.md and progress.md as they track current state.
|
||||
|
||||
REMEMBER: After every memory reset, I begin completely fresh. The Memory Bank is my only link to previous work. It must be maintained with precision and clarity, as my effectiveness depends entirely on its accuracy.
|
||||
+139
-9
@@ -1,12 +1,68 @@
|
||||
# Hermes Agent Environment Configuration
|
||||
# Copy this file to .env and fill in your API keys
|
||||
|
||||
# =============================================================================
|
||||
# CORE SETTINGS
|
||||
# =============================================================================
|
||||
# Agent backend:
|
||||
# - openai : default Hermes-Agent loop (OpenAI function-calling via OpenAI SDK)
|
||||
# - atropos : Atroposlib ServerManager/ManagedServer-backed loop (training/env integration)
|
||||
HERMES_BACKEND=openai
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LOCAL / SELF-HOSTED OPENAI-COMPATIBLE ENDPOINTS (vLLM, SGLang, llama.cpp, etc.)
|
||||
# =============================================================================
|
||||
# For local development (matches the Atropos test env defaults):
|
||||
# ATROPOS_SERVER_BASE_URL=http://127.0.0.1:8080
|
||||
# ATROPOS_SERVER_MODEL=hermes-4-36b
|
||||
# For hosted inference (Nous Research inference API):
|
||||
ATROPOS_SERVER_BASE_URL=
|
||||
ATROPOS_SERVER_MODEL=
|
||||
ATROPOS_TOKENIZER_NAME=
|
||||
# Set this to your Nous API key (Bearer token).
|
||||
ATROPOS_SERVER_API_KEY=
|
||||
|
||||
# Debugging (prints to stdout; use with care)
|
||||
# HERMES_DEBUG_ATROPOS_REQUEST=1
|
||||
# HERMES_DEBUG_ATROPOS_RESPONSE=1
|
||||
# HERMES_DEBUG_OPENAI_REQUEST=1
|
||||
# HERMES_DEBUG_OPENAI_RESPONSE=1
|
||||
|
||||
# =============================================================================
|
||||
# LOCAL / SELF-HOSTED OPENAI-COMPATIBLE ENDPOINTS (vLLM, SGLang, llama.cpp, etc.)
|
||||
# =============================================================================
|
||||
# If you set ATROPOS_SERVER_BASE_URL or OPENAI_BASE_URL, Hermes will use it instead
|
||||
# of OpenRouter.
|
||||
#
|
||||
# Local server convenience (base URL without /v1):
|
||||
# llama.cpp example (see `Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh`):
|
||||
# ATROPOS_SERVER_BASE_URL=http://127.0.0.1:8080
|
||||
# ATROPOS_SERVER_MODEL=hermes-4-36b
|
||||
# ATROPOS_TOKENIZER_NAME=NousResearch/Hermes-4.3-36B
|
||||
# ATROPOS_SERVER_API_KEY=local
|
||||
#
|
||||
# Hosted Nous inference API:
|
||||
# ATROPOS_SERVER_BASE_URL=https://inference-api.nousresearch.com
|
||||
# ATROPOS_SERVER_MODEL=Hermes-4.3-36B
|
||||
# ATROPOS_TOKENIZER_NAME=NousResearch/Hermes-4.3-36B
|
||||
# ATROPOS_SERVER_API_KEY=sk-... (Bearer token)
|
||||
#
|
||||
# If you plan to run GRPO-style group sampling (e.g. `--env.group_size 4`) against
|
||||
# llama.cpp, start the server with at least that many slots, e.g.:
|
||||
# LLAMA_CPP_PARALLEL=4 Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh
|
||||
#
|
||||
# Generic OpenAI-compatible (base URL should include /v1):
|
||||
# OPENAI_BASE_URL=http://127.0.0.1:8080/v1
|
||||
# OPENAI_API_KEY=local
|
||||
|
||||
# =============================================================================
|
||||
# LLM PROVIDER (OpenRouter)
|
||||
# =============================================================================
|
||||
# OpenRouter provides access to many models through one API
|
||||
# All LLM calls go through OpenRouter - no direct provider keys needed
|
||||
# Get your key at: https://openrouter.ai/keys
|
||||
OPENROUTER_BASE_URL=https://openrouter.ai/api/v1
|
||||
OPENROUTER_API_KEY=
|
||||
|
||||
# Default model to use (OpenRouter format: provider/model)
|
||||
@@ -42,10 +98,9 @@ TERMINAL_ENV=local
|
||||
|
||||
|
||||
# Container images (for singularity/docker/modal backends)
|
||||
TERMINAL_DOCKER_IMAGE=nikolaik/python-nodejs:python3.11-nodejs20
|
||||
TERMINAL_SINGULARITY_IMAGE=docker://nikolaik/python-nodejs:python3.11-nodejs20
|
||||
TERMINAL_MODAL_IMAGE=nikolaik/python-nodejs:python3.11-nodejs20
|
||||
|
||||
TERMINAL_DOCKER_IMAGE=python:3.11
|
||||
TERMINAL_SINGULARITY_IMAGE=docker://python:3.11
|
||||
TERMINAL_MODAL_IMAGE=python:3.11
|
||||
|
||||
# Working directory for terminal commands
|
||||
# For CLI: "." means current directory (resolved automatically from config.yaml)
|
||||
@@ -93,12 +148,87 @@ TERMINAL_LIFETIME_SECONDS=300
|
||||
# SUDO_PASSWORD=your_password_here
|
||||
|
||||
# =============================================================================
|
||||
# MODAL CLOUD BACKEND (Optional - for TERMINAL_ENV=modal)
|
||||
# MODAL CLOUD BACKEND (for TERMINAL_ENV=modal)
|
||||
# =============================================================================
|
||||
# Modal uses CLI authentication, not environment variables.
|
||||
# Run: pip install modal && modal setup
|
||||
# This will authenticate via browser and store credentials locally.
|
||||
# No API key needed in .env - Modal handles auth automatically.
|
||||
# Modal provides cloud sandboxes with per-second billing and auto-scaling.
|
||||
# This implementation uses a warm pool of sandboxes for cost efficiency.
|
||||
#
|
||||
# SETUP:
|
||||
# pip install modal && modal setup
|
||||
# (Authenticates via browser, stores credentials locally)
|
||||
#
|
||||
# FEATURES:
|
||||
# - Auto-scaling warm sandbox pool (no cold start after first use)
|
||||
# - Named sandbox recovery (reconnects after restart)
|
||||
# - Profile-based heterogeneous environments (CPU, GPU, different images)
|
||||
# - Server-side idle_timeout protection against orphaned sandboxes
|
||||
|
||||
# Modal app name (groups all sandboxes, used for recovery)
|
||||
TERMINAL_MODAL_APP_NAME=hermes-sandbox
|
||||
|
||||
# Default profile when none specified
|
||||
TERMINAL_MODAL_DEFAULT_PROFILE=default
|
||||
|
||||
# Profile config file (optional - YAML format, see modal_profiles.yaml)
|
||||
# TERMINAL_MODAL_PROFILES_FILE=modal_profiles.yaml
|
||||
|
||||
# --- Default Profile Settings (used if no YAML file) ---
|
||||
# These apply when no profile is specified or for the "default" profile
|
||||
TERMINAL_MODAL_IMAGE=python:3.11
|
||||
TERMINAL_MODAL_MIN_POOL=1
|
||||
TERMINAL_MODAL_MAX_POOL=5
|
||||
TERMINAL_MODAL_IDLE_TIMEOUT=120
|
||||
TERMINAL_MODAL_MAX_LIFETIME=3600
|
||||
TERMINAL_MODAL_SCALE_DOWN_IDLE=180
|
||||
|
||||
# --- Custom Profile Example: pytorch-gpu ---
|
||||
# Uncomment to enable a GPU profile for ML tasks
|
||||
# Usage: terminal_tool("python train.py", profile="pytorch-gpu")
|
||||
#
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_IMAGE=pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_GPU=T4
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_MEMORY=16384
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_MIN_POOL=0
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_MAX_POOL=2
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_IDLE_TIMEOUT=60
|
||||
|
||||
# --- Custom Profile Example: node ---
|
||||
# Uncomment to enable a Node.js profile
|
||||
# Usage: terminal_tool("npm test", profile="node")
|
||||
#
|
||||
# TERMINAL_MODAL_PROFILE_node_IMAGE=node:18
|
||||
# TERMINAL_MODAL_PROFILE_node_MIN_POOL=0
|
||||
# TERMINAL_MODAL_PROFILE_node_MAX_POOL=3
|
||||
|
||||
# =============================================================================
|
||||
# MODAL SECRETS (Secure credential injection)
|
||||
# =============================================================================
|
||||
# Modal Secrets allow you to securely pass API keys, passwords, and other
|
||||
# sensitive data to your sandboxes without exposing them in code or logs.
|
||||
#
|
||||
# SETUP SECRETS:
|
||||
# 1. Via Dashboard: https://modal.com/secrets
|
||||
# 2. Via CLI: modal secret create my-secret KEY1=value1 KEY2=value2
|
||||
# 3. Via CLI with env: modal secret create my-secret API_KEY="$API_KEY"
|
||||
#
|
||||
# LIST SECRETS:
|
||||
# modal secret list
|
||||
#
|
||||
# DELETE SECRETS:
|
||||
# modal secret delete my-secret
|
||||
|
||||
# Global secrets applied to ALL profiles (comma-separated secret names)
|
||||
# These secrets must be created on Modal dashboard or via CLI first
|
||||
# TERMINAL_MODAL_SECRETS=my-api-keys,database-creds
|
||||
|
||||
# Per-profile secrets (comma-separated secret names)
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_SECRETS=huggingface-token,wandb-key
|
||||
|
||||
# Per-profile environment variables (semicolon-separated KEY=VALUE pairs)
|
||||
# TERMINAL_MODAL_PROFILE_default_ENV_VARS=DEBUG=1;LOG_LEVEL=info
|
||||
|
||||
# Load local .env file into sandbox (useful for development)
|
||||
# TERMINAL_MODAL_PROFILE_default_USE_DOTENV=true
|
||||
|
||||
# =============================================================================
|
||||
# BROWSER TOOL CONFIGURATION (agent-browser + Browserbase)
|
||||
|
||||
+24
@@ -46,3 +46,27 @@ testlogs
|
||||
|
||||
# CLI config (may contain sensitive SSH paths)
|
||||
cli-config.yaml
|
||||
|
||||
.DS_Store
|
||||
|
||||
# artifacts
|
||||
*.jsonl
|
||||
*.html
|
||||
*.json
|
||||
*.log
|
||||
*.csv
|
||||
|
||||
# Singularity/Apptainer images (large binary files)
|
||||
*.sif
|
||||
|
||||
# Test files
|
||||
test_singularity_*.py
|
||||
test_*.py
|
||||
!tests/test_*.py
|
||||
|
||||
# Nomad data
|
||||
/tmp/NomadClient*/
|
||||
|
||||
*.egg-info*
|
||||
wandb
|
||||
logs
|
||||
@@ -1,142 +0,0 @@
|
||||
# Project Notes
|
||||
|
||||
*Maintained by Hermes — last updated June 2025*
|
||||
|
||||
---
|
||||
|
||||
## 1. Kandinsky (Multimodal Transformer)
|
||||
- **Repo:** https://github.com/samherring99/kandinsky
|
||||
- **Local path:** `~/Desktop/Projects/kandinsky`
|
||||
- **Description:** An anything-to-anything transformer combining text, image, and audio modalities. Trains on Pokemon BLIP captions paired with Gen 1 Pokemon audio cries. Uses audio tokenization adapted from nanoGPT.
|
||||
- **Status:** Early POC. Training code exists (`model.py`) and dataset creation (`create_dataset.py`) works. Audio heads are producing the same sound — unclear if it's a training issue or data issue.
|
||||
- **TODO:**
|
||||
- Debug why audio heads produce identical output
|
||||
- Investigate if model needs more training time
|
||||
- Design a data pipeline for better/more training data
|
||||
- General repo cleanup (requirements.txt, proper CLI, etc.)
|
||||
|
||||
---
|
||||
|
||||
## 2. NightwingGameSim (LLM → GameBoy ROM Generator)
|
||||
- **Repo:** https://github.com/samherring99/NightwingGameSim
|
||||
- **Local path:** `~/Desktop/Projects/NightwingGameSim`
|
||||
- **Description:** AI-powered pipeline that turns natural language prompts into playable GameBoy ROM files. Generates C code, compiles with GBDK, outputs `.gb` files. Supports Claude API, local Llama, and RAG backends.
|
||||
- **Status:** Functional — generation pipeline works end-to-end with Claude 4 system prompt. Has tests, docs, examples, and retry logic.
|
||||
- **TODO:**
|
||||
- Harden the repo, clean up structure
|
||||
- Build a better testing pipeline
|
||||
- Come up with better prompt ideas / examples
|
||||
|
||||
---
|
||||
|
||||
## 3. ContentBasedMIR (Music Information Retrieval)
|
||||
- **Repo:** https://github.com/samherring99/ContentBasedMIR
|
||||
- **Local path:** `~/Desktop/Projects/ContentBasedMIR`
|
||||
- **Description:** Music similarity analysis using Spotify API track data. Extracts 54 audio features per song and visualizes similarity matrices for music recommendation.
|
||||
- **Status:** Early stage. Can download Spotify track analysis data and plot similarity matrices. Needs significant expansion.
|
||||
- **TODO:**
|
||||
- Expand analysis pipeline with more features
|
||||
- Integrate with text message data for personalized recommendations
|
||||
- Build out visualization and exploration tools
|
||||
- General modernization (dependencies, structure)
|
||||
|
||||
---
|
||||
|
||||
## 4. MessageRetrieval (iMessage RAG/SQL)
|
||||
- **Repo:** https://github.com/samherring99/MessageRetrieval
|
||||
- **Local path:** `~/Desktop/Projects/MessageRetrieval`
|
||||
- **Description:** Natural language querying over iMessage data using SQL generation (text2SQL) instead of vector embeddings. Uses LLM-as-Judge pattern for scoring and ranking retrieved messages.
|
||||
- **Status:** Has initial text2SQL pipeline and summarization tool. Recently worked on with Claude Code. Needs testing.
|
||||
- **TODO:**
|
||||
- Test out the recent Claude Code work
|
||||
- Build "iMessage Jarvis" — answer questions about texts
|
||||
- Improve SQL generation prompts and accuracy
|
||||
- Better error handling and UX
|
||||
|
||||
---
|
||||
|
||||
## 5. Grailed Embedding Search
|
||||
- **Repo:** https://github.com/samherring99/grailed-embedding-search
|
||||
- **Local path:** `~/Desktop/Projects/grailed-embedding-search`
|
||||
- **Description:** Semantic similarity search over Grailed fashion listings using CLIP embeddings and FAISS. Search by image URL or text description to find visually similar products.
|
||||
- **Status:** Functional core pipeline. CLIP ViT-B/32 embeds product cover photos into 512-dim vectors, indexed with FAISS cosine similarity. Has CLI, batch embedding, persistent index save/load, and logging.
|
||||
- **Recent work (June 2025):**
|
||||
- PR #1 — Initial cleanup: docstrings, type hints, `.gitignore`, `requirements.txt`, README rewrite
|
||||
- PR #2 — Feature improvements: persistent FAISS save/load, batch embedding, CLI (`cli.py`), proper logging throughout, lazy Grailed client, `fetch_details` toggle
|
||||
- **TODO:**
|
||||
- Embedding cache (avoid re-embedding known product URLs)
|
||||
- Async/threaded image downloads for faster batch indexing
|
||||
- Search result visualization (matplotlib grid of cover photos)
|
||||
- Filter by category, designer, price range before search
|
||||
- Web UI (Gradio or Streamlit)
|
||||
|
||||
---
|
||||
|
||||
## 6. NightwingNBA (Sports Analytics)
|
||||
- **Repo:** https://github.com/samherring99/NightwingNBA
|
||||
- **Local path:** `~/Desktop/Projects/NightwingNBA`
|
||||
- **Description:** NBA game prediction system. Builds a database of game data, trains a PyTorch model, and makes daily predictions. Has full pipeline: build DB → write data → train → predict.
|
||||
- **Status:** Functional pipeline exists. Has database building, training, prediction, and daily update scripts.
|
||||
- **TODO:**
|
||||
- Explore and potentially revive
|
||||
- Update data sources if stale
|
||||
- Improve model accuracy
|
||||
- Add visualization/reporting
|
||||
|
||||
---
|
||||
|
||||
## 7. Stable Audio Sample Explorer
|
||||
- **Repo:** https://github.com/samherring99/stable-audio-sample-explorer
|
||||
- **Local path:** `~/Desktop/Projects/stable-audio-sample-explorer`
|
||||
- **Description:** Tool for exploring audio samples generated by Stable Audio.
|
||||
- **Status:** 🪦 **Dead** — no active work needed per Sam.
|
||||
|
||||
---
|
||||
|
||||
## 8. NightwingArt (Art Tools)
|
||||
- **Repo:** https://github.com/samherring99/NightwingArt
|
||||
- **Local path:** `~/Desktop/Projects/NightwingArt`
|
||||
- **Description:** Collection of art tooling scripts — video editing, clip splicing with beat matching, damage effects, and general image manipulation.
|
||||
- **Status:** Maintenance mode. Tools exist for various effects. Work happens as-needed.
|
||||
- **TODO:**
|
||||
- Add tools as needed for new art projects
|
||||
|
||||
---
|
||||
|
||||
## 9. Claude-based VST Building ⚠️ *Needs new repo*
|
||||
- **Description:** Generate VST audio plugins for DAWs from English language prompts. LLM-powered audio plugin creation.
|
||||
- **Status:** Concept only — no repo exists yet.
|
||||
- **TODO:**
|
||||
- Create repo
|
||||
- Research VST SDK / JUCE framework
|
||||
- Design prompt → code → compile pipeline
|
||||
|
||||
---
|
||||
|
||||
## 10. Government Auction Site Scraper ⚠️ *Needs new repo*
|
||||
- **Description:** Tool that monitors and scrapes government auction sites in San Francisco for deals.
|
||||
- **Status:** Concept only — no repo exists yet.
|
||||
- **TODO:**
|
||||
- Create repo
|
||||
- Research SF government auction sites and their structure
|
||||
- Build scraper + notification system
|
||||
|
||||
---
|
||||
|
||||
## Priority Assessment
|
||||
|
||||
| Project | Activity Level | Suggested Priority |
|
||||
|---------|---------------|-------------------|
|
||||
| NightwingGameSim | Active | 🔴 High |
|
||||
| MessageRetrieval | Active | 🔴 High |
|
||||
| Kandinsky | Active | 🟡 Medium |
|
||||
| ContentBasedMIR | Exploratory | 🟡 Medium |
|
||||
| Grailed Embedding Search | Early | 🟡 Medium |
|
||||
| NightwingNBA | Dormant | 🟢 Low |
|
||||
| NightwingArt | As-needed | 🟢 Low |
|
||||
| VST Builder | Concept | 🔵 Future |
|
||||
| Gov Auction Scraper | Concept | 🔵 Future |
|
||||
| Stable Audio Explorer | Dead | ⚫ None |
|
||||
|
||||
|
||||
|
||||
@@ -37,9 +37,8 @@ All your settings are stored in `~/.hermes/` for easy access:
|
||||
|
||||
```
|
||||
~/.hermes/
|
||||
├── config.yaml # Settings (model, terminal, TTS, compression, etc.)
|
||||
├── config.yaml # Settings (model, terminal, compression, etc.)
|
||||
├── .env # API keys and secrets
|
||||
├── SOUL.md # Optional: global persona (agent embodies this personality)
|
||||
├── cron/ # Scheduled jobs
|
||||
├── sessions/ # Gateway sessions
|
||||
└── logs/ # Logs
|
||||
@@ -77,8 +76,6 @@ You need at least one LLM provider:
|
||||
| Web scraping | [Firecrawl](https://firecrawl.dev/) | `FIRECRAWL_API_KEY` |
|
||||
| Browser automation | [Browserbase](https://browserbase.com/) | `BROWSERBASE_API_KEY`, `BROWSERBASE_PROJECT_ID` |
|
||||
| Image generation | [FAL](https://fal.ai/) | `FAL_KEY` |
|
||||
| Premium TTS voices | [ElevenLabs](https://elevenlabs.io/) | `ELEVENLABS_API_KEY` |
|
||||
| OpenAI TTS voices | [OpenAI](https://platform.openai.com/api-keys) | `OPENAI_API_KEY` |
|
||||
| RL Training | [Tinker](https://tinker-console.thinkingmachines.ai/) + [WandB](https://wandb.ai/) | `TINKER_API_KEY`, `WANDB_API_KEY` |
|
||||
| Messaging | Telegram, Discord | `TELEGRAM_BOT_TOKEN`, `DISCORD_BOT_TOKEN` |
|
||||
|
||||
@@ -131,58 +128,7 @@ hermes --toolsets "web,terminal"
|
||||
hermes --list-tools
|
||||
```
|
||||
|
||||
**Available toolsets:** `web`, `terminal`, `browser`, `vision`, `creative`, `reasoning`, `skills`, `tts`, `cronjob`, and more.
|
||||
|
||||
### 🔊 Text-to-Speech
|
||||
|
||||
Convert text to speech with three providers:
|
||||
|
||||
| Provider | Quality | Cost | API Key |
|
||||
|----------|---------|------|---------|
|
||||
| **Edge TTS** (default) | Good | Free | None needed |
|
||||
| **ElevenLabs** | Excellent | Paid | `ELEVENLABS_API_KEY` |
|
||||
| **OpenAI TTS** | Good | Paid | `OPENAI_API_KEY` |
|
||||
|
||||
On Telegram, audio plays as native voice bubbles. On Discord/WhatsApp, sent as audio files. In CLI mode, saved to `~/voice-memos/`.
|
||||
|
||||
**Configure in `~/.hermes/config.yaml`:**
|
||||
```yaml
|
||||
tts:
|
||||
provider: "edge" # "edge" | "elevenlabs" | "openai"
|
||||
edge:
|
||||
voice: "en-US-AriaNeural" # 322 voices, 74 languages
|
||||
elevenlabs:
|
||||
voice_id: "pNInz6obpgDQGcFmaJgB" # Adam
|
||||
model_id: "eleven_multilingual_v2"
|
||||
openai:
|
||||
model: "gpt-4o-mini-tts"
|
||||
voice: "alloy" # alloy, echo, fable, onyx, nova, shimmer
|
||||
```
|
||||
|
||||
> **Note:** Telegram voice bubbles require `ffmpeg` for Opus conversion (Edge TTS only outputs MP3). Install with `apt install ffmpeg` or `brew install ffmpeg`. Without ffmpeg, audio is sent as a file instead of a voice bubble.
|
||||
|
||||
### 📄 Context Files (SOUL.md, AGENTS.md, .cursorrules)
|
||||
|
||||
Drop these files in your project directory and the agent automatically picks them up:
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `AGENTS.md` | Project-specific instructions, coding conventions, tool usage guidelines |
|
||||
| `SOUL.md` | Persona definition -- the agent embodies this personality and tone |
|
||||
| `.cursorrules` | Cursor IDE rules (also detected) |
|
||||
| `.cursor/rules/*.mdc` | Cursor rule files (also detected) |
|
||||
|
||||
- **AGENTS.md** is hierarchical: if subdirectories also have `AGENTS.md`, all are combined (like Codex/Cline).
|
||||
- **SOUL.md** checks cwd first, then `~/.hermes/SOUL.md` as a global fallback.
|
||||
- All context files are capped at 20,000 characters with smart truncation.
|
||||
|
||||
### 🛡️ Exec Approval (Messaging Platforms)
|
||||
|
||||
When the agent tries to run a potentially dangerous command (rm -rf, chmod 777, etc.) on Telegram/Discord/WhatsApp, instead of blocking it silently, it asks the user for approval:
|
||||
|
||||
> ⚠️ This command is potentially dangerous (recursive delete). Reply "yes" to approve.
|
||||
|
||||
Reply "yes"/"y" to approve or "no"/"n" to deny. In CLI mode, the existing interactive approval prompt (once/session/always/deny) is preserved.
|
||||
**Available toolsets:** `web`, `terminal`, `browser`, `vision`, `creative`, `reasoning`, `skills`, `cronjob`, and more.
|
||||
|
||||
### 🖥️ Terminal Backend
|
||||
|
||||
@@ -1049,6 +995,137 @@ All variables go in `~/.hermes/.env`. Run `hermes config set VAR value` to set t
|
||||
|
||||
---
|
||||
|
||||
## RL Training with Tinker
|
||||
|
||||
Hermes-Agent includes an RL training integration with [Tinker](https://thinkingmachines.ai/tinker/) (Thinking Machines) and [Atropos](https://github.com/NousResearch/atropos) for training language models with reinforcement learning from agent trajectories.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
1. **Install with Atropos extras** (includes Tinker SDK, atroposlib, torch, wandb):
|
||||
```bash
|
||||
pip install -e ".[atropos]"
|
||||
```
|
||||
|
||||
2. **Initialize the tinker-atropos submodule**:
|
||||
```bash
|
||||
git submodule update --init
|
||||
pip install -e ./tinker-atropos
|
||||
```
|
||||
|
||||
3. **Get API keys**:
|
||||
- `TINKER_API_KEY` from [Tinker Console](https://tinker-console.thinkingmachines.ai/keys) (requires billing setup)
|
||||
- `WANDB_API_KEY` from [Weights & Biases](https://wandb.ai/settings) (for metrics tracking)
|
||||
|
||||
4. **Add keys to your `.env` file**:
|
||||
```bash
|
||||
# Add to .env or ~/.hermes/.env
|
||||
TINKER_API_KEY=your_tinker_key
|
||||
WANDB_API_KEY=your_wandb_key
|
||||
```
|
||||
|
||||
### Architecture
|
||||
|
||||
The RL training pipeline uses three processes that communicate over HTTP:
|
||||
|
||||
```
|
||||
┌──────────────────────┐ ┌─────────────────────┐ ┌────────────────────────┐
|
||||
│ Atropos Rollout API │ │ Tinker Trainer │ │ Environment │
|
||||
│ (port 8000) │◄──│ (port 8001) │◄──│ (worker) │
|
||||
│ │ │ │ │ │
|
||||
│ • Collects batches │ │ • LoRA training │ │ • Generates prompts │
|
||||
│ • Coordinates env │ │ • Inference server │ │ • Calls inference API │
|
||||
│ and trainer │ │ • Weight updates │ │ • Scores responses │
|
||||
│ │ │ • WandB logging │ │ • Sends scored batches │
|
||||
└──────────────────────┘ └─────────────────────┘ └────────────────────────┘
|
||||
```
|
||||
|
||||
### Quick Start: GSM8k Agent Training
|
||||
|
||||
This example trains a model on math problems using a Python REPL tool — the model learns to write and execute Python code to solve math:
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start Atropos Rollout API
|
||||
cd tinker-atropos
|
||||
source ../.venv/bin/activate
|
||||
set -a && source ../.env && set +a
|
||||
run-api
|
||||
|
||||
# Terminal 2: Start Tinker Trainer + Inference Server
|
||||
cd tinker-atropos
|
||||
source ../.venv/bin/activate
|
||||
set -a && source ../.env && set +a
|
||||
python launch_training.py --config configs/gsm8k_agent.yaml
|
||||
|
||||
# Terminal 3: Start GSM8k Agent Environment
|
||||
cd tinker-atropos
|
||||
source ../.venv/bin/activate
|
||||
set -a && source ../.env && set +a
|
||||
python tinker_atropos/environments/gsm8k_agent.py serve --config configs/gsm8k_agent.yaml
|
||||
```
|
||||
|
||||
### Available Environments
|
||||
|
||||
| Environment | File | Description |
|
||||
|------------|------|-------------|
|
||||
| `gsm8k` | `gsm8k_tinker.py` | Standard GSM8k math (no tools) |
|
||||
| `gsm8k_agent` | `gsm8k_agent.py` | GSM8k with Python REPL tool calling |
|
||||
|
||||
### Configuration
|
||||
|
||||
Configs are YAML files in `tinker-atropos/configs/` with three sections:
|
||||
|
||||
```yaml
|
||||
env: # Atropos environment settings
|
||||
group_size: 4 # Parallel rollouts per problem
|
||||
batch_size: 16 # Training batch size
|
||||
tokenizer_name: "Qwen/Qwen3-4B-Instruct-2507"
|
||||
max_token_length: 2048 # Max generation length
|
||||
total_steps: 20 # Training steps
|
||||
|
||||
openai: # Inference server (served by Tinker trainer)
|
||||
- model_name: "Qwen/Qwen3-4B-Instruct-2507"
|
||||
base_url: "http://localhost:8001/v1"
|
||||
|
||||
tinker: # Tinker training parameters
|
||||
lora_rank: 16 # LoRA rank (lower = faster, less capacity)
|
||||
learning_rate: 0.00005 # Learning rate
|
||||
max_token_trainer_length: 4096 # Max tokens for training
|
||||
wandb_project: "hermes-agent-rl"
|
||||
```
|
||||
|
||||
### RL CLI (Agent-Driven Training)
|
||||
|
||||
For interactive training management via the Hermes agent:
|
||||
|
||||
```bash
|
||||
# Interactive mode - let the agent manage training
|
||||
python rl_cli.py --interactive
|
||||
|
||||
# List available environments
|
||||
python rl_cli.py --list-environments
|
||||
|
||||
# Direct task
|
||||
python rl_cli.py "Train a model on GSM8k with tool use"
|
||||
```
|
||||
|
||||
### Sandbox Backends for Agent Environments
|
||||
|
||||
For agent environments that need isolated tool execution (e.g., SWE tasks), Hermes-Agent supports multiple sandbox backends:
|
||||
|
||||
| Backend | Use Case | Command |
|
||||
|---------|----------|---------|
|
||||
| **Nomad + Docker** | Default, local development | `--env.tool_pool_mode nomad` |
|
||||
| **Nomad + Singularity** | HPC clusters without Docker | `--env.tool_pool_mode nomad --env.driver singularity` |
|
||||
| **Modal** | Cloud-based, auto-scaling | `--env.tool_pool_mode modal` |
|
||||
|
||||
See [docs/MODAL_BACKEND.md](docs/MODAL_BACKEND.md) for Modal backend details.
|
||||
|
||||
### Cost
|
||||
|
||||
Check the [Tinker Rate Card](https://tinker-console.thinkingmachines.ai/rate-card) for available models and pricing.
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
```bash
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
# Dockerfile for atropos-agent sandbox server
|
||||
# Runs inside Nomad containers to handle tool execution
|
||||
# Includes bubblewrap for namespace-based slot isolation
|
||||
|
||||
FROM python:3.11-slim
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
# Bubblewrap for namespace isolation
|
||||
bubblewrap \
|
||||
# `script` for PTY allocation (used for stable tmux+asciinema startup)
|
||||
util-linux \
|
||||
# Git for SWE-style tasks (cloning repos)
|
||||
git \
|
||||
# tmux for stateful terminal sessions (Phase 4.7+)
|
||||
tmux \
|
||||
# Common tools agents might need
|
||||
curl \
|
||||
wget \
|
||||
jq \
|
||||
# Cleanup
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Python dependencies (sandbox server + optional terminal recording)
|
||||
RUN pip install --no-cache-dir aiohttp asciinema
|
||||
|
||||
# Copy the sandbox server
|
||||
COPY sandbox_server.py /app/sandbox_server.py
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Create data directory for slot workspaces
|
||||
RUN mkdir -p /data
|
||||
|
||||
# Verify bubblewrap is installed and working
|
||||
RUN bwrap --version
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
# Default command - can be overridden by Nomad job spec
|
||||
CMD ["python", "sandbox_server.py", "--port", "8080", "--slots", "10", "--data-dir", "/data"]
|
||||
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
Atropos integration for Hermes-Agent.
|
||||
|
||||
This package is intentionally optional: Hermes-Agent should work without Atropos.
|
||||
If you import anything from `atropos.*` without having `atroposlib` installed,
|
||||
we raise a clear error with install instructions.
|
||||
|
||||
Install (recommended, from repo checkout):
|
||||
uv sync --extra atropos
|
||||
|
||||
Or (pip / editable):
|
||||
pip install -e '.[atropos]'
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def _require_atroposlib() -> None:
|
||||
try:
|
||||
import atroposlib # noqa: F401
|
||||
except ModuleNotFoundError as exc: # pragma: no cover
|
||||
raise ModuleNotFoundError(
|
||||
"Hermes-Agent Atropos integration requires `atroposlib`, but it is not installed.\n"
|
||||
"Install it with:\n"
|
||||
" uv sync --extra atropos\n"
|
||||
"or:\n"
|
||||
" pip install -e '.[atropos]'\n"
|
||||
) from exc
|
||||
|
||||
|
||||
_require_atroposlib()
|
||||
|
||||
# Re-export the most commonly used pieces for convenience.
|
||||
# Agent imports are eager (always available).
|
||||
from .agent import AgentConfig, AgentResult, AgentStep, AtroposAgent, SequenceData # noqa: E402
|
||||
|
||||
# Env imports are lazy to avoid pulling in deleted atropos.tools dependencies.
|
||||
# Use: from atropos.envs import AgentEnv, AgentEnvConfig (if needed)
|
||||
|
||||
__all__ = [
|
||||
"AtroposAgent",
|
||||
"AgentConfig",
|
||||
"AgentResult",
|
||||
"AgentStep",
|
||||
"SequenceData",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Agent abstractions for atropos-agent.
|
||||
|
||||
Provides the core AtroposAgent class for running ReACT-style agent loops.
|
||||
"""
|
||||
|
||||
from .atropos_agent import AgentConfig, AgentResult, AgentStep, AtroposAgent, SequenceData
|
||||
|
||||
__all__ = [
|
||||
"AtroposAgent",
|
||||
"AgentConfig",
|
||||
"AgentResult",
|
||||
"AgentStep",
|
||||
"SequenceData",
|
||||
]
|
||||
@@ -0,0 +1,850 @@
|
||||
"""
|
||||
ReACT-style agent implementation for atropos-agent.
|
||||
|
||||
This module provides the core AtroposAgent class that implements a basic
|
||||
Reason-Act-Observe loop with tool calling capabilities.
|
||||
|
||||
Uses ManagedServer from atroposlib for automatic token/logprob tracking,
|
||||
making trajectories ready for RL training.
|
||||
|
||||
The agent uses Hermes-style XML tags for tool calls:
|
||||
- <think>...</think> for reasoning
|
||||
- <tool_call>{"name": "...", "arguments": {...}}</tool_call> for actions
|
||||
- <tool_response>...</tool_response> for observations
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from uuid import uuid4
|
||||
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Optional, Union
|
||||
|
||||
from dotenv import load_dotenv
|
||||
import httpx
|
||||
|
||||
from ..tools import ToolCall, ToolRegistry, ToolResult
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# Default system prompt with tool calling instructions.
|
||||
AGENT_SYSTEM_PROMPT = """You are a deep thinking AI. You MUST enclose your internal reasoning inside <think>...</think> tags.
|
||||
|
||||
You are a function calling AI model.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags.
|
||||
You must call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.
|
||||
You can ONLY respond without a tool call if you are totally certain you have the final answer to the user's question or task
|
||||
After calling & executing a function, you will be provided with function results within <tool_response></tool_response> XML tags.
|
||||
|
||||
Here are the available tools:
|
||||
<tools>
|
||||
{tools_json}
|
||||
</tools>
|
||||
|
||||
Use the following JSON schema for each tool call you will make:
|
||||
{"title": "FunctionCall", "type": "object", "properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"]}
|
||||
|
||||
## REQUIRED TOOL FORMAT
|
||||
|
||||
When you decide to call a tool, your assistant message MUST be:
|
||||
1) exactly one <think>...</think> block, followed by
|
||||
2) one or more <tool_call>...</tool_call> blocks,
|
||||
and NOTHING else in that message.
|
||||
|
||||
If you need to explain anything, put it inside <think>. Do NOT write natural language outside <think> or <tool_call>.
|
||||
|
||||
For each function call return a JSON object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
||||
<tool_call>
|
||||
{"name": "<function-name>", "arguments": {"arg1": "value1"}}
|
||||
</tool_call>
|
||||
|
||||
Each <tool_call> must be on its own and contain ONLY the JSON object (no extra text).
|
||||
The JSON inside <tool_call> MUST be valid JSON with double quotes.
|
||||
|
||||
Do NOT output <tool_response> in an assistant message.
|
||||
|
||||
After you receive tool results, you may either call more tools (same required format) or provide the final answer.
|
||||
When providing the final answer, do NOT include any <tool_call> blocks.
|
||||
|
||||
## TERMINAL TOOL NOTES
|
||||
|
||||
- Commands execute under POSIX `/bin/sh` (not bash).
|
||||
- Each tool call runs in a fresh shell: environment changes (like `cd` or venv activation) do not persist across tool calls.
|
||||
- Avoid bash-only features like `source`, `[[ ... ]]`, or process substitution.
|
||||
- Prefer explicit venv usage:
|
||||
- `python -m venv .venv && . .venv/bin/activate && python -m pip install -e .` (POSIX `.` activation), or
|
||||
- `.venv/bin/python -m pip install -e .` (no activation required).
|
||||
|
||||
## ICL (examples)
|
||||
|
||||
User: Show the current directory.
|
||||
Assistant:
|
||||
<think>I should run pwd.</think>
|
||||
<tool_call>
|
||||
{"name": "terminal", "arguments": {"command": "pwd"}}
|
||||
</tool_call>
|
||||
User: <tool_response>{"success": true, "output": "/tmp\\n"}</tool_response>
|
||||
Assistant: /tmp
|
||||
|
||||
User: List files, then count them.
|
||||
Assistant:
|
||||
<think>I should count files.</think>
|
||||
<tool_call>
|
||||
{"name": "terminal", "arguments": {"command": "ls -1 | wc -l"}}
|
||||
</tool_call>
|
||||
User: <tool_response>{"success": true, "output": "3\\n"}</tool_response>
|
||||
Assistant: 3
|
||||
|
||||
User: Run pwd, then print ok (two tool calls).
|
||||
Assistant:
|
||||
<think>I should run two commands.</think>
|
||||
<tool_call>
|
||||
{"name": "terminal", "arguments": {"command": "pwd"}}
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
{"name": "terminal", "arguments": {"command": "echo ok"}}
|
||||
</tool_call>
|
||||
User: <tool_response>{"success": true, "output": "/tmp\\n"}</tool_response>
|
||||
User: <tool_response>{"success": true, "output": "ok\\n"}</tool_response>
|
||||
Assistant: ok
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentConfig:
|
||||
"""Configuration for the AtroposAgent."""
|
||||
|
||||
# Generation parameters
|
||||
temperature: Optional[float] = 0.7
|
||||
# Default to "let the backend decide" (important for tool-tag completions that may be longer).
|
||||
max_tokens: Optional[int] = None
|
||||
|
||||
# Agent behavior
|
||||
max_steps: int = 50
|
||||
system_prompt: Optional[str] = None
|
||||
tool_delay_s: float = 0.0
|
||||
|
||||
# Working directory for tools
|
||||
working_dir: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SequenceData:
|
||||
"""Token/logprob data from a single completion."""
|
||||
|
||||
full_text: str
|
||||
tokens: List[int]
|
||||
masked_tokens: List[int] # -100 for prompt, actual IDs for completion
|
||||
logprobs: List[float] # 1.0 for prompt, actual values for completion
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
@classmethod
|
||||
def from_sequence_node(cls, node) -> "SequenceData":
|
||||
"""Create from a ManagedServer SequenceNode."""
|
||||
return cls(
|
||||
full_text=node.full_text,
|
||||
tokens=node.tokens,
|
||||
masked_tokens=node.masked_tokens,
|
||||
logprobs=node.logprobs,
|
||||
metadata=getattr(node, "metadata", None),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentStep:
|
||||
"""A single step in the agent's trajectory."""
|
||||
|
||||
step_number: int
|
||||
assistant_message: str
|
||||
tool_calls: List[ToolCall] = field(default_factory=list)
|
||||
tool_results: List[ToolResult] = field(default_factory=list)
|
||||
sequence_data: Optional[SequenceData] = None # Token data from this step
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
return len(self.tool_calls) > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResult:
|
||||
"""Result of running an agent trajectory."""
|
||||
|
||||
success: bool
|
||||
final_response: str
|
||||
steps: List[AgentStep] = field(default_factory=list)
|
||||
total_tokens: int = 0
|
||||
error: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Full trajectory token data for RL training
|
||||
trajectory_data: Optional[SequenceData] = None
|
||||
|
||||
@property
|
||||
def num_steps(self) -> int:
|
||||
return len(self.steps)
|
||||
|
||||
@property
|
||||
def total_tool_calls(self) -> int:
|
||||
return sum(len(step.tool_calls) for step in self.steps)
|
||||
|
||||
def to_messages(self) -> List[Dict[str, str]]:
|
||||
"""Convert trajectory to messages format for logging."""
|
||||
messages = []
|
||||
for step in self.steps:
|
||||
messages.append({"role": "assistant", "content": step.assistant_message})
|
||||
if step.tool_results:
|
||||
# Combine all tool responses
|
||||
responses = "\n".join(r.to_xml() for r in step.tool_results)
|
||||
messages.append({"role": "user", "content": responses})
|
||||
return messages
|
||||
|
||||
def to_scored_data(self, score: float) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Convert to format suitable for ScoredDataGroup.
|
||||
|
||||
Args:
|
||||
score: The score for this trajectory
|
||||
|
||||
Returns:
|
||||
Dict with tokens, masks, scores suitable for training, or None if no data
|
||||
"""
|
||||
if self.trajectory_data is None:
|
||||
return None
|
||||
|
||||
return {
|
||||
"tokens": self.trajectory_data.tokens,
|
||||
"masks": self.trajectory_data.masked_tokens,
|
||||
"scores": score,
|
||||
"logprobs": self.trajectory_data.logprobs,
|
||||
}
|
||||
|
||||
|
||||
class AtroposAgent:
|
||||
"""
|
||||
A ReACT-style agent that uses LLMs with tool calling.
|
||||
|
||||
This implementation wraps ManagedServer for automatic token/logprob tracking,
|
||||
making trajectories ready for RL training.
|
||||
|
||||
Example:
|
||||
# `server` may be an Atropos `ServerManager` (recommended) or a single `APIServer`.
|
||||
# In practice, environments usually construct this via `BaseEnv`.
|
||||
server = ...
|
||||
tools = ToolRegistry()
|
||||
tools.register(BashTool())
|
||||
|
||||
agent = AtroposAgent(server=server, tools=tools)
|
||||
result = await agent.run("List the files in the current directory")
|
||||
|
||||
# Access token data for training
|
||||
if result.trajectory_data:
|
||||
print(f"Tokens: {result.trajectory_data.tokens}")
|
||||
print(f"Masked: {result.trajectory_data.masked_tokens}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server, # ServerManager or APIServer
|
||||
tools: Optional[ToolRegistry] = None,
|
||||
config: Optional[AgentConfig] = None,
|
||||
tokenizer: Optional[Any] = None,
|
||||
execute_tool: Optional[Callable[[ToolCall], Awaitable[ToolResult]]] = None,
|
||||
):
|
||||
self.server = server
|
||||
self.tools = tools or ToolRegistry()
|
||||
self.config = config or AgentConfig()
|
||||
self.tokenizer = tokenizer or getattr(server, "tokenizer", None)
|
||||
self.execute_tool = execute_tool or self.tools.execute
|
||||
|
||||
@asynccontextmanager
|
||||
async def _managed(self) -> AsyncGenerator[Any, None]:
|
||||
"""
|
||||
Yield a ManagedServer-like object.
|
||||
|
||||
- If `self.server` is a ServerManager, use its `managed_server()` context manager.
|
||||
- If `self.server` is a single APIServer, wrap it in `ManagedServer` directly.
|
||||
"""
|
||||
if os.getenv("ATROPOS_BYPASS_MANAGED_SERVER") == "1":
|
||||
yield _DirectChatCompletionClient(server=self.server)
|
||||
return
|
||||
if hasattr(self.server, "managed_server"):
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
yield managed
|
||||
else:
|
||||
managed = ManagedServer(server=self.server, tokenizer=self.tokenizer)
|
||||
try:
|
||||
yield managed
|
||||
finally:
|
||||
managed.reset()
|
||||
|
||||
def _build_system_prompt(self) -> str:
|
||||
"""Build the system prompt with tool descriptions."""
|
||||
if self.config.system_prompt:
|
||||
return self.config.system_prompt
|
||||
|
||||
tools_json = self.tools.get_prompt_tool_definitions_json()
|
||||
# Avoid `str.format()` here because the prompt contains many literal `{}` braces
|
||||
# in JSON examples; we only want to substitute the single `{tools_json}` token.
|
||||
return AGENT_SYSTEM_PROMPT.replace("{tools_json}", tools_json)
|
||||
|
||||
def _infer_server_model_for_debug(self) -> Optional[str]:
|
||||
"""
|
||||
Best-effort inference of the configured model name for debug payload saving.
|
||||
|
||||
ManagedServer/server_manager typically injects `model` internally, so `chat_kwargs`
|
||||
may not contain it. For replaying saved payloads via curl, it's useful to persist it.
|
||||
"""
|
||||
servers = getattr(self.server, "servers", None)
|
||||
if isinstance(servers, list) and servers:
|
||||
s0 = servers[0]
|
||||
cfg = getattr(s0, "config", None)
|
||||
model = getattr(cfg, "model_name", None) or getattr(s0, "model_name", None)
|
||||
if isinstance(model, str) and model:
|
||||
return model
|
||||
model = getattr(self.server, "model_name", None) or getattr(self.server, "model", None)
|
||||
if isinstance(model, str) and model:
|
||||
return model
|
||||
return None
|
||||
|
||||
def _infer_server_base_url_for_debug(self) -> Optional[str]:
|
||||
"""
|
||||
Best-effort inference of the configured base_url for debug logging.
|
||||
|
||||
This is helpful when diagnosing hangs / retries at the transport layer.
|
||||
"""
|
||||
servers = getattr(self.server, "servers", None)
|
||||
if isinstance(servers, list) and servers:
|
||||
s0 = servers[0]
|
||||
cfg = getattr(s0, "config", None)
|
||||
base_url = getattr(cfg, "base_url", None) or getattr(s0, "base_url", None)
|
||||
if isinstance(base_url, str) and base_url:
|
||||
return base_url
|
||||
base_url = getattr(self.server, "base_url", None)
|
||||
if isinstance(base_url, str) and base_url:
|
||||
return base_url
|
||||
return None
|
||||
|
||||
def _extract_response_metadata(self, response: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract lightweight, JSON-serializable metadata from an OpenAI-style response.
|
||||
|
||||
This is useful for debugging training runs, especially when ManagedServer state
|
||||
tracking is unavailable (e.g. OpenAI-compatible chat endpoints).
|
||||
"""
|
||||
meta: Dict[str, Any] = {}
|
||||
try:
|
||||
rid = getattr(response, "id", None)
|
||||
if isinstance(rid, str) and rid:
|
||||
meta["id"] = rid
|
||||
model = getattr(response, "model", None)
|
||||
if isinstance(model, str) and model:
|
||||
meta["model"] = model
|
||||
created = getattr(response, "created", None)
|
||||
if isinstance(created, int):
|
||||
meta["created"] = created
|
||||
system_fingerprint = getattr(response, "system_fingerprint", None)
|
||||
if isinstance(system_fingerprint, str) and system_fingerprint:
|
||||
meta["system_fingerprint"] = system_fingerprint
|
||||
|
||||
choices = getattr(response, "choices", None)
|
||||
if isinstance(choices, list) and choices:
|
||||
fr = getattr(choices[0], "finish_reason", None)
|
||||
if isinstance(fr, str) and fr:
|
||||
meta["finish_reason"] = fr
|
||||
|
||||
usage = getattr(response, "usage", None)
|
||||
if usage is not None:
|
||||
if hasattr(usage, "model_dump"):
|
||||
meta["usage"] = usage.model_dump()
|
||||
elif isinstance(usage, dict):
|
||||
meta["usage"] = usage
|
||||
except Exception:
|
||||
pass
|
||||
return meta
|
||||
|
||||
def _debug_dump_request(self, *, step_num: int, chat_kwargs: Dict[str, Any]) -> None:
|
||||
if os.getenv("ATROPOS_DEBUG_AGENT_REQUEST") != "1":
|
||||
return
|
||||
try:
|
||||
# Avoid dumping megabytes by default; messages can be huge.
|
||||
meta = {
|
||||
"step": step_num,
|
||||
"base_url": self._infer_server_base_url_for_debug(),
|
||||
"model": chat_kwargs.get("model") or self._infer_server_model_for_debug(),
|
||||
"chat_kwargs_keys": sorted(list(chat_kwargs.keys())),
|
||||
"n": chat_kwargs.get("n"),
|
||||
"max_tokens": chat_kwargs.get("max_tokens"),
|
||||
"temperature": chat_kwargs.get("temperature"),
|
||||
"num_messages": len(chat_kwargs.get("messages") or []),
|
||||
}
|
||||
print("\n=== ATROPOS_DEBUG_AGENT_REQUEST ===", flush=True)
|
||||
print(meta, flush=True)
|
||||
|
||||
if os.getenv("ATROPOS_DEBUG_AGENT_REQUEST_FULL") == "1":
|
||||
payload = dict(chat_kwargs)
|
||||
# Make the payload more legible and less huge.
|
||||
try:
|
||||
dumped = json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
except Exception:
|
||||
dumped = repr(payload)
|
||||
print("\n=== ATROPOS_DEBUG_AGENT_REQUEST_FULL ===", flush=True)
|
||||
print(dumped[:200_000], flush=True)
|
||||
|
||||
# Optional: save the FULL request payload to disk (no truncation).
|
||||
save_dir = os.getenv("ATROPOS_DEBUG_AGENT_REQUEST_SAVE_DIR")
|
||||
if save_dir:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
payload: Dict[str, Any] = dict(chat_kwargs)
|
||||
if "model" not in payload:
|
||||
model = self._infer_server_model_for_debug()
|
||||
if model:
|
||||
payload["model"] = model
|
||||
# Use a unique filename so parallel trajectories don't clobber each other.
|
||||
fname = os.path.join(
|
||||
save_dir,
|
||||
f"atropos_agent_request_step{step_num}_{int(time.time()*1000)}_{os.getpid()}_{uuid4().hex}.json",
|
||||
)
|
||||
with open(fname, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, ensure_ascii=False, indent=2)
|
||||
print(f"[AtroposAgent] saved request payload: {fname}", flush=True)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def _debug_dump_response(self, *, step_num: int, response: Any) -> None:
|
||||
if os.getenv("ATROPOS_DEBUG_AGENT_RESPONSE") != "1":
|
||||
return
|
||||
print("\n=== ATROPOS_DEBUG_AGENT_RESPONSE ===", flush=True)
|
||||
print({"step": step_num, "type": type(response).__name__}, flush=True)
|
||||
try:
|
||||
dumped = response.model_dump() # openai pydantic model
|
||||
except Exception:
|
||||
dumped = getattr(response, "__dict__", {"repr": repr(response)})
|
||||
# Keep the dump bounded; we only need enough to see the assistant message content.
|
||||
text = str(dumped)
|
||||
print(text[:200_000], flush=True)
|
||||
|
||||
async def _chat_completion_with_debug(
|
||||
self, *, managed: Any, step_num: int, chat_kwargs: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Call `managed.chat_completion()` with optional timeout + richer failure logging.
|
||||
|
||||
Debug env vars:
|
||||
- `ATROPOS_AGENT_CHAT_TIMEOUT_S`: if set, wraps the await in `asyncio.wait_for`.
|
||||
- `ATROPOS_DEBUG_AGENT_WAIT_EVERY_S`: if set, prints a heartbeat while waiting.
|
||||
"""
|
||||
# Hard guardrail: never allow a single chat completion to block for too long.
|
||||
# This is essential for RL data-gen stability; long hangs should be treated as failures (score=0).
|
||||
timeout_s_raw = os.getenv("ATROPOS_AGENT_CHAT_TIMEOUT_S")
|
||||
timeout_s_default = 240.0
|
||||
timeout_s = float(timeout_s_raw) if timeout_s_raw else timeout_s_default
|
||||
timeout_s = min(timeout_s, 240.0)
|
||||
|
||||
wait_every_raw = os.getenv("ATROPOS_DEBUG_AGENT_WAIT_EVERY_S")
|
||||
wait_every_s = float(wait_every_raw) if wait_every_raw else None
|
||||
|
||||
async def _await_call() -> Any:
|
||||
if not wait_every_s or wait_every_s <= 0:
|
||||
return await managed.chat_completion(**chat_kwargs)
|
||||
|
||||
# Heartbeat mode: wait in chunks without cancelling the underlying request.
|
||||
# NOTE: do NOT use `asyncio.wait_for(task, timeout=...)` here, because a timeout
|
||||
# will cancel the task and surface as `CancelledError` on the next loop.
|
||||
task = asyncio.create_task(managed.chat_completion(**chat_kwargs))
|
||||
t0 = time.perf_counter()
|
||||
try:
|
||||
while True:
|
||||
done, _pending = await asyncio.wait({task}, timeout=wait_every_s)
|
||||
if task in done:
|
||||
return task.result()
|
||||
|
||||
waited = time.perf_counter() - t0
|
||||
print(
|
||||
f"[AtroposAgent] step={step_num} still waiting for chat_completion... ({waited:.1f}s)",
|
||||
flush=True,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
task.cancel()
|
||||
raise
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(_await_call(), timeout=timeout_s)
|
||||
except asyncio.TimeoutError as e:
|
||||
print("\n=== ATROPOS_DEBUG_AGENT_CHAT_TIMEOUT ===", flush=True)
|
||||
print({"step": step_num, "timeout_s": timeout_s}, flush=True)
|
||||
raise RuntimeError(f"chat_completion timed out after {timeout_s:.1f}s") from e
|
||||
except asyncio.CancelledError:
|
||||
# Treat cancellation as a hard failure rather than crashing the whole env run.
|
||||
# (Atropos/BaseEnv may cancel tasks during shutdown or retries.)
|
||||
raise RuntimeError("chat_completion cancelled") from None
|
||||
except Exception as e:
|
||||
detail: Dict[str, Any] = {
|
||||
"step": step_num,
|
||||
"exc_type": type(e).__name__,
|
||||
"exc_str": str(e),
|
||||
}
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
try:
|
||||
detail["status_code"] = e.response.status_code
|
||||
detail["response_text"] = e.response.text[:20_000]
|
||||
except Exception:
|
||||
pass
|
||||
elif isinstance(e, httpx.RequestError):
|
||||
detail["request"] = repr(getattr(e, "request", None))
|
||||
|
||||
print("\n=== ATROPOS_DEBUG_AGENT_CHAT_FAILURE ===", flush=True)
|
||||
print(detail, flush=True)
|
||||
raise
|
||||
|
||||
async def run(
|
||||
self,
|
||||
task: str,
|
||||
initial_messages: Optional[List[Dict[str, str]]] = None,
|
||||
) -> AgentResult:
|
||||
"""
|
||||
Run the agent on a task using ManagedServer for token tracking.
|
||||
|
||||
Args:
|
||||
task: The task/prompt for the agent
|
||||
initial_messages: Optional additional context messages
|
||||
|
||||
Returns:
|
||||
AgentResult with the trajectory, final response, and token data
|
||||
"""
|
||||
messages = [
|
||||
{"role": "system", "content": self._build_system_prompt()},
|
||||
]
|
||||
|
||||
if initial_messages:
|
||||
messages.extend(initial_messages)
|
||||
|
||||
messages.append({"role": "user", "content": task})
|
||||
|
||||
steps = []
|
||||
final_response = ""
|
||||
final_node = None
|
||||
final_prompt_messages: Optional[List[Dict[str, str]]] = None
|
||||
last_node = None
|
||||
last_prompt_messages: Optional[List[Dict[str, str]]] = None
|
||||
last_response_text: str = ""
|
||||
|
||||
# Use ManagedServer for automatic token tracking
|
||||
async with self._managed() as managed:
|
||||
for step_num in range(self.config.max_steps):
|
||||
# ReACT loop iteration here, just call -> tools -> observe until done (no tools called)
|
||||
try:
|
||||
# Keep a copy of the prompt messages used for this completion.
|
||||
# Useful for reconstructing tokens/masks when state tracking is unavailable.
|
||||
prompt_messages = list(messages)
|
||||
chat_kwargs: Dict[str, Any] = {"messages": messages, "n": 1}
|
||||
if self.config.max_tokens is not None:
|
||||
chat_kwargs["max_tokens"] = self.config.max_tokens
|
||||
if self.config.temperature is not None:
|
||||
chat_kwargs["temperature"] = self.config.temperature
|
||||
|
||||
t_req = time.perf_counter()
|
||||
print(
|
||||
f"[AtroposAgent] step={step_num+1} chat_completion start "
|
||||
f"(messages={len(messages)}, max_tokens={self.config.max_tokens}, temp={self.config.temperature})",
|
||||
flush=True,
|
||||
)
|
||||
self._debug_dump_request(step_num=step_num + 1, chat_kwargs=chat_kwargs)
|
||||
response = await self._chat_completion_with_debug(
|
||||
managed=managed, step_num=step_num + 1, chat_kwargs=chat_kwargs
|
||||
)
|
||||
self._debug_dump_response(step_num=step_num + 1, response=response)
|
||||
response_meta = self._extract_response_metadata(response)
|
||||
print(
|
||||
f"[AtroposAgent] step={step_num+1} chat_completion done in {time.perf_counter() - t_req:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
current_node = None
|
||||
if hasattr(managed, "get_state"):
|
||||
state = managed.get_state()
|
||||
nodes = state.get("nodes", [])
|
||||
current_node = nodes[-1] if nodes else None
|
||||
|
||||
except Exception as e:
|
||||
return AgentResult(
|
||||
success=False,
|
||||
final_response="",
|
||||
steps=steps,
|
||||
error=f"Generation error: {str(e)}",
|
||||
)
|
||||
|
||||
msg = response.choices[0].message
|
||||
# Some OpenAI-compatible servers populate `message.reasoning` and leave `content=""`.
|
||||
response_text = (msg.content or "") or (getattr(msg, "reasoning", None) or "")
|
||||
tool_calls = ToolCall.parse_from_text(response_text)
|
||||
last_node = current_node
|
||||
last_prompt_messages = prompt_messages
|
||||
last_response_text = response_text
|
||||
|
||||
step_sequence_data = SequenceData.from_sequence_node(current_node) if current_node else None
|
||||
if step_sequence_data is None:
|
||||
if response_meta:
|
||||
# We still want metadata for debugging even if token/logprob state tracking is unavailable.
|
||||
step_sequence_data = SequenceData(
|
||||
full_text=response_text,
|
||||
tokens=[],
|
||||
masked_tokens=[],
|
||||
logprobs=[],
|
||||
metadata=response_meta,
|
||||
)
|
||||
else:
|
||||
merged = dict(response_meta)
|
||||
node_meta = step_sequence_data.metadata
|
||||
if isinstance(node_meta, dict):
|
||||
merged.update(node_meta)
|
||||
step_sequence_data.metadata = merged or step_sequence_data.metadata
|
||||
|
||||
step = AgentStep(
|
||||
step_number=step_num + 1,
|
||||
assistant_message=response_text,
|
||||
tool_calls=tool_calls,
|
||||
sequence_data=step_sequence_data,
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
steps.append(step)
|
||||
final_response = response_text
|
||||
final_node = current_node
|
||||
final_prompt_messages = prompt_messages
|
||||
break
|
||||
|
||||
messages.append({"role": "assistant", "content": response_text})
|
||||
|
||||
tool_responses = []
|
||||
for call in tool_calls:
|
||||
result = await self.execute_tool(call)
|
||||
step.tool_results.append(result)
|
||||
tool_responses.append(result.to_xml())
|
||||
if self.config.tool_delay_s > 0:
|
||||
await asyncio.sleep(self.config.tool_delay_s)
|
||||
|
||||
steps.append(step)
|
||||
|
||||
responses_text = "\n".join(tool_responses)
|
||||
# Tool observations are represented as user content with Hermes-style tags.
|
||||
# This is compatible with most OpenAI-compatible chat APIs and ensures
|
||||
# tokenizers/chat templates include tool outputs during training.
|
||||
messages.append({"role": "user", "content": responses_text})
|
||||
|
||||
else:
|
||||
# Reached max steps without completing
|
||||
# Return a failure result but include the last observed completion so callers can
|
||||
# record the trajectory (score=0) without triggering retries.
|
||||
final_response = last_response_text or final_response
|
||||
final_node = last_node
|
||||
final_prompt_messages = last_prompt_messages
|
||||
trajectory_data = None
|
||||
if final_node:
|
||||
trajectory_data = SequenceData.from_sequence_node(final_node)
|
||||
elif final_prompt_messages is not None and self.tokenizer is not None:
|
||||
if hasattr(self.tokenizer, "apply_chat_template"):
|
||||
prompt_text = self.tokenizer.apply_chat_template(
|
||||
final_prompt_messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=False)
|
||||
else:
|
||||
prompt_text = "\n".join([f"{m['role']}: {m['content']}" for m in final_prompt_messages])
|
||||
prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=True)
|
||||
output_tokens = self.tokenizer.encode(final_response, add_special_tokens=False)
|
||||
tokens = prompt_tokens + output_tokens
|
||||
masked_tokens = ([-100] * len(prompt_tokens)) + output_tokens
|
||||
logprobs = ([1.0] * len(prompt_tokens)) + ([0.0] * len(output_tokens))
|
||||
trajectory_data = SequenceData(
|
||||
full_text=f"{prompt_text}{final_response}",
|
||||
tokens=tokens,
|
||||
masked_tokens=masked_tokens,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
# Preserve response metadata (if any) even on failure trajectories.
|
||||
try:
|
||||
if trajectory_data is not None and steps:
|
||||
last_step = steps[-1]
|
||||
if last_step.sequence_data and isinstance(last_step.sequence_data.metadata, dict):
|
||||
trajectory_data.metadata = dict(last_step.sequence_data.metadata)
|
||||
except Exception:
|
||||
pass
|
||||
return AgentResult(
|
||||
success=False,
|
||||
final_response=final_response,
|
||||
steps=steps,
|
||||
error=f"Reached maximum steps ({self.config.max_steps})",
|
||||
trajectory_data=trajectory_data,
|
||||
)
|
||||
|
||||
# Build result with trajectory data
|
||||
trajectory_data = None
|
||||
if final_node:
|
||||
trajectory_data = SequenceData.from_sequence_node(final_node)
|
||||
elif final_prompt_messages is not None and self.tokenizer is not None:
|
||||
if hasattr(self.tokenizer, "apply_chat_template"):
|
||||
prompt_text = self.tokenizer.apply_chat_template(
|
||||
final_prompt_messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=False)
|
||||
else:
|
||||
prompt_text = "\n".join([f"{m['role']}: {m['content']}" for m in final_prompt_messages])
|
||||
prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=True)
|
||||
output_tokens = self.tokenizer.encode(final_response, add_special_tokens=False)
|
||||
tokens = prompt_tokens + output_tokens
|
||||
masked_tokens = ([-100] * len(prompt_tokens)) + output_tokens
|
||||
logprobs = ([1.0] * len(prompt_tokens)) + ([0.0] * len(output_tokens))
|
||||
trajectory_data = SequenceData(
|
||||
full_text=f"{prompt_text}{final_response}",
|
||||
tokens=tokens,
|
||||
masked_tokens=masked_tokens,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
# Ensure trajectory_data carries the most recent metadata we observed (if any).
|
||||
try:
|
||||
if trajectory_data is not None and steps:
|
||||
last_step = steps[-1]
|
||||
if last_step.sequence_data and isinstance(last_step.sequence_data.metadata, dict):
|
||||
trajectory_data.metadata = dict(last_step.sequence_data.metadata)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return AgentResult(
|
||||
success=True,
|
||||
final_response=final_response,
|
||||
steps=steps,
|
||||
trajectory_data=trajectory_data,
|
||||
)
|
||||
|
||||
async def run_single_turn(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
execute_tools: bool = True,
|
||||
) -> tuple[str, List[ToolResult], Optional[SequenceData]]:
|
||||
"""
|
||||
Run a single turn of the agent (one LLM call + tool execution).
|
||||
|
||||
This is useful for integration with BaseEnv where you want more
|
||||
control over the loop.
|
||||
|
||||
Args:
|
||||
messages: The conversation history
|
||||
execute_tools: Whether to execute parsed tool calls
|
||||
|
||||
Returns:
|
||||
Tuple of (response_text, tool_results, sequence_data)
|
||||
"""
|
||||
async with self._managed() as managed:
|
||||
chat_kwargs: Dict[str, Any] = {"messages": messages, "n": 1}
|
||||
if self.config.max_tokens is not None:
|
||||
chat_kwargs["max_tokens"] = self.config.max_tokens
|
||||
if self.config.temperature is not None:
|
||||
chat_kwargs["temperature"] = self.config.temperature
|
||||
|
||||
self._debug_dump_request(step_num=1, chat_kwargs=chat_kwargs)
|
||||
response = await self._chat_completion_with_debug(managed=managed, step_num=1, chat_kwargs=chat_kwargs)
|
||||
self._debug_dump_response(step_num=1, response=response)
|
||||
|
||||
current_node = None
|
||||
if hasattr(managed, "get_state"):
|
||||
state = managed.get_state()
|
||||
nodes = state.get("nodes", [])
|
||||
current_node = nodes[-1] if nodes else None
|
||||
|
||||
msg = response.choices[0].message
|
||||
response_text = (msg.content or "") or (getattr(msg, "reasoning", None) or "")
|
||||
tool_results = []
|
||||
|
||||
if execute_tools:
|
||||
tool_calls = ToolCall.parse_from_text(response_text)
|
||||
for call in tool_calls:
|
||||
result = await self.execute_tool(call)
|
||||
tool_results.append(result)
|
||||
|
||||
sequence_data = SequenceData.from_sequence_node(current_node) if current_node else None
|
||||
|
||||
return response_text, tool_results, sequence_data
|
||||
|
||||
|
||||
class _DirectChatCompletionClient:
|
||||
"""
|
||||
Minimal stand-in for ManagedServer that calls the OpenAI-compatible endpoint directly.
|
||||
|
||||
This is for isolating issues where `ManagedServer.chat_completion()` hangs or misbehaves.
|
||||
It intentionally does NOT do token/logprob tracking.
|
||||
"""
|
||||
|
||||
def __init__(self, server: Any):
|
||||
self._server = server
|
||||
|
||||
def _server_config(self) -> tuple[str, str, str]:
|
||||
# ServerManager case: first configured server.
|
||||
servers = getattr(self._server, "servers", None)
|
||||
if isinstance(servers, list) and servers:
|
||||
s0 = servers[0]
|
||||
cfg = getattr(s0, "config", None)
|
||||
base_url = getattr(cfg, "base_url", None) or getattr(s0, "base_url", None)
|
||||
api_key = getattr(cfg, "api_key", None) or getattr(s0, "api_key", None)
|
||||
model = getattr(cfg, "model_name", None) or getattr(s0, "model_name", None)
|
||||
if isinstance(base_url, str) and isinstance(api_key, str) and isinstance(model, str):
|
||||
return base_url.rstrip("/"), api_key, model
|
||||
|
||||
# APIServer-like fallback.
|
||||
base_url = getattr(self._server, "base_url", None)
|
||||
api_key = getattr(self._server, "api_key", None)
|
||||
model = getattr(self._server, "model_name", None) or getattr(self._server, "model", None)
|
||||
if isinstance(base_url, str) and isinstance(api_key, str) and isinstance(model, str):
|
||||
return base_url.rstrip("/"), api_key, model
|
||||
|
||||
raise RuntimeError("Unable to resolve server base_url/api_key/model for direct chat completion")
|
||||
|
||||
async def chat_completion(self, *, messages: List[Dict[str, str]], n: int = 1, **kwargs: Any) -> Any:
|
||||
base_url, api_key, model = self._server_config()
|
||||
url = f"{base_url}/chat/completions"
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"n": n,
|
||||
}
|
||||
# Pass through common generation kwargs.
|
||||
for k in ("max_tokens", "temperature", "top_p", "presence_penalty", "frequency_penalty", "stop"):
|
||||
if k in kwargs and kwargs[k] is not None:
|
||||
payload[k] = kwargs[k]
|
||||
|
||||
timeout_s = float(os.getenv("ATROPOS_DIRECT_REQUEST_TIMEOUT_S") or "120")
|
||||
print(f"[AtroposAgent] DIRECT chat_completion POST {url} (timeout={timeout_s}s)", flush=True)
|
||||
async with httpx.AsyncClient(timeout=timeout_s) as client:
|
||||
resp = await client.post(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
|
||||
json=payload,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# Return a very small object compatible with the code paths that read
|
||||
# `response.choices[0].message.content`.
|
||||
class _Msg:
|
||||
def __init__(self, d: Dict[str, Any]):
|
||||
self.content = d.get("content")
|
||||
self.reasoning = d.get("reasoning")
|
||||
|
||||
class _Choice:
|
||||
def __init__(self, d: Dict[str, Any]):
|
||||
self.message = _Msg(d.get("message") or {})
|
||||
|
||||
class _Resp:
|
||||
def __init__(self, d: Dict[str, Any]):
|
||||
self._d = d
|
||||
self.choices = [_Choice(c) for c in (d.get("choices") or [])]
|
||||
|
||||
def model_dump(self) -> Dict[str, Any]:
|
||||
return self._d
|
||||
|
||||
return _Resp(data)
|
||||
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
FastAPI services for atropos-agent.
|
||||
|
||||
- tool_executor_server: queued/batched sandbox tool execution (Phase 4)
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
Tool Executor API (Phase 4)
|
||||
|
||||
This service provides a queued, batched execution layer on top of a ToolBackend.
|
||||
It mirrors the stateful FastAPI + app.state pattern used in:
|
||||
atropos/atroposlib/api/server.py
|
||||
|
||||
Run (dev):
|
||||
uv run uvicorn atropos_agent.api.tool_executor_server:app --host 0.0.0.0 --port 9001
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, Header, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..backends.nomad_backend import NomadBackendConfig, NomadToolBackend
|
||||
from ..tools import ToolRegistry, build_tool_registry
|
||||
from ..tools.base import (
|
||||
ArtifactArchiveRequestPayload,
|
||||
ArtifactArchiveResponsePayload,
|
||||
ArtifactListRequestPayload,
|
||||
ArtifactListResponsePayload,
|
||||
ArtifactReadRequestPayload,
|
||||
ArtifactReadResponsePayload,
|
||||
ToolExecutorExecuteRequest,
|
||||
ToolExecutorReleaseRequest,
|
||||
ToolResultPayload,
|
||||
)
|
||||
from ..tools.tool_executor import ToolExecutor, ToolExecutorConfig
|
||||
|
||||
|
||||
class ToolExecutorServerConfig(BaseModel):
|
||||
nomad_address: str = Field(default="http://localhost:4646")
|
||||
job_id: str = Field(default="atropos-sandbox-tool-executor")
|
||||
image: str = Field(default="atropos-sandbox:local")
|
||||
slots_per_container: int = Field(default=10)
|
||||
min_containers: int = Field(default=1)
|
||||
max_containers: int = Field(default=10)
|
||||
privileged: bool = Field(default=False)
|
||||
acquire_timeout_s: float = Field(default=30.0)
|
||||
|
||||
batch_window_ms: int = Field(default=20)
|
||||
max_batch_size: int = Field(default=200)
|
||||
allow_network: bool = Field(default=True)
|
||||
|
||||
tool_server_url: Optional[str] = Field(default=None)
|
||||
tool_server_token: Optional[str] = Field(default=None)
|
||||
|
||||
token: Optional[str] = Field(default=None, description="Bearer token required for requests (optional in dev).")
|
||||
|
||||
purge_job_on_shutdown: bool = Field(default=True)
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "ToolExecutorServerConfig":
|
||||
# In dev, prefer loading secrets/config from the repo-local `.env` (not committed).
|
||||
try:
|
||||
from dotenv import load_dotenv # type: ignore
|
||||
except Exception: # pragma: no cover
|
||||
load_dotenv = None # type: ignore[assignment]
|
||||
if load_dotenv is not None:
|
||||
env_path = Path(__file__).resolve().parents[2] / ".env"
|
||||
if env_path.exists():
|
||||
load_dotenv(dotenv_path=env_path)
|
||||
|
||||
def _get_bool(name: str, default: bool) -> bool:
|
||||
raw = os.getenv(name)
|
||||
if raw is None:
|
||||
return default
|
||||
return raw.strip().lower() in {"1", "true", "yes", "y", "on"}
|
||||
|
||||
return cls(
|
||||
nomad_address=os.getenv("TOOL_EXECUTOR_NOMAD_ADDRESS", "http://localhost:4646"),
|
||||
job_id=os.getenv("TOOL_EXECUTOR_JOB_ID", "atropos-sandbox-tool-executor"),
|
||||
image=os.getenv("TOOL_EXECUTOR_IMAGE", "atropos-sandbox:local"),
|
||||
slots_per_container=int(os.getenv("TOOL_EXECUTOR_SLOTS", "10")),
|
||||
min_containers=int(os.getenv("TOOL_EXECUTOR_MIN_CONTAINERS", "1")),
|
||||
max_containers=int(os.getenv("TOOL_EXECUTOR_MAX_CONTAINERS", "10")),
|
||||
privileged=_get_bool("TOOL_EXECUTOR_PRIVILEGED", False),
|
||||
acquire_timeout_s=float(os.getenv("TOOL_EXECUTOR_ACQUIRE_TIMEOUT_S", "30.0")),
|
||||
batch_window_ms=int(os.getenv("TOOL_EXECUTOR_BATCH_WINDOW_MS", "20")),
|
||||
max_batch_size=int(os.getenv("TOOL_EXECUTOR_MAX_BATCH_SIZE", "200")),
|
||||
allow_network=_get_bool("TOOL_EXECUTOR_ALLOW_NETWORK", True),
|
||||
tool_server_url=os.getenv("TOOL_EXECUTOR_TOOL_SERVER_URL") or None,
|
||||
tool_server_token=os.getenv("TOOL_EXECUTOR_TOOL_SERVER_TOKEN") or None,
|
||||
token=os.getenv("TOOL_EXECUTOR_TOKEN") or None,
|
||||
purge_job_on_shutdown=_get_bool("TOOL_EXECUTOR_PURGE_JOB_ON_SHUTDOWN", True),
|
||||
)
|
||||
|
||||
|
||||
app = FastAPI(title="Atropos-Agent Tool Executor")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root() -> Dict[str, str]:
|
||||
return {"message": "Atropos-Agent Tool Executor"}
|
||||
|
||||
|
||||
def _check_auth(cfg: ToolExecutorServerConfig, authorization: Optional[str]) -> None:
|
||||
if not cfg.token:
|
||||
return
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Authorization header")
|
||||
if not authorization.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Authorization header")
|
||||
token = authorization.split(" ", 1)[1].strip()
|
||||
if token != cfg.token:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def _startup() -> None:
|
||||
cfg = ToolExecutorServerConfig.from_env()
|
||||
|
||||
# Default to Atropos "full" tool surface: sandbox + external (if tool_server_url provided).
|
||||
tools: ToolRegistry = build_tool_registry(
|
||||
enabled_toolsets=["full"],
|
||||
disabled_toolsets=None,
|
||||
tool_server_url=cfg.tool_server_url,
|
||||
)
|
||||
|
||||
backend = NomadToolBackend(
|
||||
NomadBackendConfig(
|
||||
nomad_address=cfg.nomad_address,
|
||||
sandbox_job_id=cfg.job_id,
|
||||
sandbox_image=cfg.image,
|
||||
slots_per_container=cfg.slots_per_container,
|
||||
min_containers=cfg.min_containers,
|
||||
max_containers=cfg.max_containers,
|
||||
privileged=cfg.privileged,
|
||||
acquire_timeout_s=cfg.acquire_timeout_s,
|
||||
purge_job_on_start=False,
|
||||
)
|
||||
)
|
||||
await backend.start()
|
||||
|
||||
executor = ToolExecutor(
|
||||
backend=backend,
|
||||
tools=tools,
|
||||
config=ToolExecutorConfig(
|
||||
batch_window_ms=cfg.batch_window_ms,
|
||||
max_batch_size=cfg.max_batch_size,
|
||||
allow_network=cfg.allow_network,
|
||||
tool_server_url=cfg.tool_server_url,
|
||||
tool_server_token=cfg.tool_server_token,
|
||||
),
|
||||
)
|
||||
await executor.start()
|
||||
|
||||
app.state.cfg = cfg
|
||||
app.state.backend = backend
|
||||
app.state.executor = executor
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def _shutdown() -> None:
|
||||
executor: Optional[ToolExecutor] = getattr(app.state, "executor", None)
|
||||
backend: Optional[NomadToolBackend] = getattr(app.state, "backend", None)
|
||||
cfg: Optional[ToolExecutorServerConfig] = getattr(app.state, "cfg", None)
|
||||
|
||||
if executor is not None:
|
||||
await executor.close()
|
||||
|
||||
if backend is not None:
|
||||
await backend.stop(purge=bool(cfg.purge_job_on_shutdown) if cfg else False)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> Dict[str, Any]:
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.get("/status")
|
||||
async def status_endpoint() -> Dict[str, Any]:
|
||||
executor: ToolExecutor = app.state.executor
|
||||
backend: NomadToolBackend = app.state.backend
|
||||
|
||||
return {
|
||||
"queue_size": executor.queue_size(),
|
||||
"total_requests": executor.total_requests,
|
||||
"total_errors": executor.total_errors,
|
||||
"pool": backend.get_stats(),
|
||||
}
|
||||
|
||||
|
||||
@app.post("/execute", response_model=ToolResultPayload)
|
||||
async def execute_tool(
|
||||
req: ToolExecutorExecuteRequest,
|
||||
authorization: Optional[str] = Header(default=None),
|
||||
status_code: int = status.HTTP_200_OK, # noqa: B008
|
||||
) -> ToolResultPayload:
|
||||
cfg: ToolExecutorServerConfig = app.state.cfg
|
||||
_check_auth(cfg, authorization)
|
||||
|
||||
executor: ToolExecutor = app.state.executor
|
||||
result = await executor.execute(
|
||||
trajectory_id=req.trajectory_id,
|
||||
call=req.tool.to_tool_call(),
|
||||
timeout_s=req.timeout_s,
|
||||
)
|
||||
return ToolResultPayload.from_tool_result(result)
|
||||
|
||||
|
||||
@app.post("/release")
|
||||
async def release_trajectory(
|
||||
req: ToolExecutorReleaseRequest,
|
||||
authorization: Optional[str] = Header(default=None),
|
||||
) -> Dict[str, Any]:
|
||||
cfg: ToolExecutorServerConfig = app.state.cfg
|
||||
_check_auth(cfg, authorization)
|
||||
|
||||
executor: ToolExecutor = app.state.executor
|
||||
await executor.release_trajectory(req.trajectory_id, reset_workspace=req.reset_workspace)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.post("/artifacts/read", response_model=ArtifactReadResponsePayload)
|
||||
async def artifacts_read(
|
||||
req: ArtifactReadRequestPayload,
|
||||
authorization: Optional[str] = Header(default=None),
|
||||
) -> ArtifactReadResponsePayload:
|
||||
cfg: ToolExecutorServerConfig = app.state.cfg
|
||||
_check_auth(cfg, authorization)
|
||||
|
||||
executor: ToolExecutor = app.state.executor
|
||||
return await executor.read_artifact(req)
|
||||
|
||||
|
||||
@app.post("/artifacts/list", response_model=ArtifactListResponsePayload)
|
||||
async def artifacts_list(
|
||||
req: ArtifactListRequestPayload,
|
||||
authorization: Optional[str] = Header(default=None),
|
||||
) -> ArtifactListResponsePayload:
|
||||
cfg: ToolExecutorServerConfig = app.state.cfg
|
||||
_check_auth(cfg, authorization)
|
||||
|
||||
executor: ToolExecutor = app.state.executor
|
||||
return await executor.list_artifacts(req)
|
||||
|
||||
|
||||
@app.post("/artifacts/archive", response_model=ArtifactArchiveResponsePayload)
|
||||
async def artifacts_archive(
|
||||
req: ArtifactArchiveRequestPayload,
|
||||
authorization: Optional[str] = Header(default=None),
|
||||
) -> ArtifactArchiveResponsePayload:
|
||||
cfg: ToolExecutorServerConfig = app.state.cfg
|
||||
_check_auth(cfg, authorization)
|
||||
|
||||
executor: ToolExecutor = app.state.executor
|
||||
return await executor.archive_artifacts(req)
|
||||
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
External ToolServer (Phase 4.5+).
|
||||
|
||||
This server executes tools that must NOT run inside the sandbox, typically
|
||||
because they require credentials or access to external services.
|
||||
|
||||
Run (dev):
|
||||
uv run uvicorn atropos_agent.api.tool_server:app --host 0.0.0.0 --port 9002
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, Header, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..tools import ToolRegistry, build_tool_registry
|
||||
from ..tools.base import ToolResultPayload, ToolServerExecuteRequest
|
||||
|
||||
|
||||
class ToolServerConfig(BaseModel):
|
||||
token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Bearer token required for requests (optional in dev).",
|
||||
)
|
||||
max_concurrency: int = Field(default=16, ge=1, description="Max concurrent tool executions.")
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "ToolServerConfig":
|
||||
# In dev, prefer loading secrets from the repo-local `.env` (not committed).
|
||||
try:
|
||||
from dotenv import load_dotenv # type: ignore
|
||||
except Exception: # pragma: no cover
|
||||
load_dotenv = None # type: ignore[assignment]
|
||||
if load_dotenv is not None:
|
||||
env_path = Path(__file__).resolve().parents[2] / ".env"
|
||||
if env_path.exists():
|
||||
load_dotenv(dotenv_path=env_path)
|
||||
|
||||
token = os.getenv("TOOL_SERVER_TOKEN") or None
|
||||
max_concurrency = int(os.getenv("TOOL_SERVER_MAX_CONCURRENCY", "16"))
|
||||
return cls(token=token, max_concurrency=max_concurrency)
|
||||
|
||||
|
||||
app = FastAPI(title="Atropos-Agent Tool Server")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root() -> Dict[str, str]:
|
||||
return {"message": "Atropos-Agent Tool Server"}
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def _startup() -> None:
|
||||
cfg = ToolServerConfig.from_env()
|
||||
|
||||
# External-only registry. It will only include tools that are enabled by toolsets and
|
||||
# whose Hermes requirements/keys are satisfied in this process.
|
||||
tools: ToolRegistry = build_tool_registry(
|
||||
enabled_toolsets=["all"],
|
||||
disabled_toolsets=["terminal", "sandbox", "filesystem", "terminal_stateful", "default"],
|
||||
tool_server_url="enabled",
|
||||
)
|
||||
|
||||
app.state.cfg = cfg
|
||||
app.state.tools = tools
|
||||
app.state.semaphore = asyncio.Semaphore(cfg.max_concurrency)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> Dict[str, Any]:
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.get("/tools")
|
||||
async def list_tools() -> Dict[str, Any]:
|
||||
tools: ToolRegistry = app.state.tools
|
||||
return {"tools": [s.to_dict() for s in tools.get_schemas()]}
|
||||
|
||||
|
||||
def _check_auth(cfg: ToolServerConfig, authorization: Optional[str]) -> None:
|
||||
if not cfg.token:
|
||||
return
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Authorization header")
|
||||
if not authorization.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Authorization header")
|
||||
token = authorization.split(" ", 1)[1].strip()
|
||||
if token != cfg.token:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")
|
||||
|
||||
|
||||
@app.post("/execute", response_model=ToolResultPayload)
|
||||
async def execute_tool(
|
||||
req: ToolServerExecuteRequest,
|
||||
authorization: Optional[str] = Header(default=None),
|
||||
) -> ToolResultPayload:
|
||||
cfg: ToolServerConfig = app.state.cfg
|
||||
_check_auth(cfg, authorization)
|
||||
|
||||
tools: ToolRegistry = app.state.tools
|
||||
sem: asyncio.Semaphore = app.state.semaphore
|
||||
|
||||
tool = tools.get(req.tool.name)
|
||||
if tool is None:
|
||||
return ToolResultPayload(
|
||||
success=False,
|
||||
error=f"Unknown tool: {req.tool.name}",
|
||||
uniq_id=req.tool.uniq_id,
|
||||
)
|
||||
|
||||
async with sem:
|
||||
try:
|
||||
kwargs = dict(req.tool.arguments)
|
||||
sig = inspect.signature(tool.execute).parameters
|
||||
# Some tools can benefit from extra context.
|
||||
if req.trajectory_id and "trajectory_id" in sig:
|
||||
kwargs["trajectory_id"] = req.trajectory_id
|
||||
if req.slot_id and "slot_id" in sig:
|
||||
kwargs["slot_id"] = req.slot_id
|
||||
if req.container_addr and "container_addr" in sig:
|
||||
kwargs["container_addr"] = req.container_addr
|
||||
if "task_id" in sig:
|
||||
kwargs["task_id"] = req.trajectory_id
|
||||
result = await tool.execute(**kwargs)
|
||||
except Exception as e:
|
||||
return ToolResultPayload(
|
||||
success=False,
|
||||
error=f"Tool execution error: {e}",
|
||||
uniq_id=req.tool.uniq_id,
|
||||
)
|
||||
|
||||
if result.uniq_id is None:
|
||||
result.uniq_id = req.tool.uniq_id
|
||||
return ToolResultPayload.from_tool_result(result)
|
||||
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .base import ToolBackend
|
||||
from .modal_backend import ModalSandboxConfig, ModalToolBackend
|
||||
from .nomad_backend import NomadBackendConfig, NomadToolBackend
|
||||
|
||||
|
||||
def create_tool_backend(cfg: Any) -> ToolBackend:
|
||||
mode = str(getattr(cfg, "tool_pool_mode", "nomad")).strip().lower()
|
||||
if mode == "nomad":
|
||||
return NomadToolBackend(NomadBackendConfig.from_agent_env_config(cfg))
|
||||
if mode == "modal":
|
||||
return ModalToolBackend(ModalSandboxConfig.from_agent_env_config(cfg))
|
||||
raise ValueError(f"Unknown tool_pool_mode: {mode}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ToolBackend",
|
||||
"create_tool_backend",
|
||||
"NomadBackendConfig",
|
||||
"NomadToolBackend",
|
||||
"ModalSandboxConfig",
|
||||
"ModalToolBackend",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
Backend interfaces for AgentEnv tool execution.
|
||||
|
||||
The goal of this module is to decouple ToolExecutor / AgentEnv from any single
|
||||
execution backend (Nomad/Docker today; Modal later).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol, Tuple
|
||||
|
||||
from ..slots.executor import ExecutionResult
|
||||
from ..slots.slot import Slot
|
||||
|
||||
|
||||
class ToolBackend(Protocol):
|
||||
"""
|
||||
Minimal interface required by ToolExecutor.
|
||||
|
||||
Backends provide:
|
||||
- lifecycle (start/stop)
|
||||
- slot acquisition/release (workspace affinity)
|
||||
- batched tool execution across slots
|
||||
- optional artifact helpers (for env verification / demos)
|
||||
"""
|
||||
|
||||
@property
|
||||
def default_timeout_s(self) -> Optional[float]:
|
||||
"""Default sandbox execution timeout in seconds (if any)."""
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the backend (provision workers/containers, health checks, etc)."""
|
||||
|
||||
async def stop(self, *, purge: bool = False) -> None:
|
||||
"""Stop the backend and optionally purge remote resources."""
|
||||
|
||||
async def acquire(self, trajectory_id: Optional[str] = None) -> Slot:
|
||||
"""Acquire a slot for a trajectory (workspace affinity)."""
|
||||
|
||||
async def release(self, slot: Slot, *, reset_workspace: bool = False) -> None:
|
||||
"""Release a slot back to the pool."""
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
requests: List[Tuple[Slot, str, Dict[str, Any]]],
|
||||
*,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> List[ExecutionResult]:
|
||||
"""Execute a batch of sandbox tool calls and return results in order."""
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Optional artifact helpers (supported by the Nomad sandbox-server today)
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
async def read_artifact(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str,
|
||||
*,
|
||||
encoding: str = "text",
|
||||
max_bytes: Optional[int] = None,
|
||||
include_sha256: bool = False,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def list_artifacts(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str = ".",
|
||||
*,
|
||||
recursive: bool = False,
|
||||
max_entries: Optional[int] = None,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def archive_artifacts(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str = ".",
|
||||
*,
|
||||
archive_format: str = "tar.gz",
|
||||
max_bytes: Optional[int] = None,
|
||||
max_entries: Optional[int] = None,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Nomad/Docker tool backend.
|
||||
|
||||
This backend is the current default for AgentEnv: it provisions a Nomad job
|
||||
running `sandbox_server.py` and multiplexes stateless slots inside each container.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from ..slots import Slot, SlotPool, SlotPoolConfig
|
||||
from ..slots.executor import ExecutionResult
|
||||
from .base import ToolBackend
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NomadBackendConfig:
|
||||
nomad_address: str
|
||||
sandbox_job_id: str
|
||||
sandbox_image: str
|
||||
slots_per_container: int
|
||||
min_containers: int
|
||||
max_containers: int
|
||||
privileged: bool
|
||||
acquire_timeout_s: float
|
||||
purge_job_on_start: bool
|
||||
# Driver selection: "docker" or "singularity"
|
||||
driver: str = "docker"
|
||||
# Path to .sif file for singularity driver (required if driver="singularity")
|
||||
singularity_image: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_agent_env_config(cls, cfg: Any) -> "NomadBackendConfig":
|
||||
return cls(
|
||||
nomad_address=str(getattr(cfg, "nomad_address")),
|
||||
sandbox_job_id=str(getattr(cfg, "sandbox_job_id")),
|
||||
sandbox_image=str(getattr(cfg, "sandbox_image")),
|
||||
slots_per_container=int(getattr(cfg, "slots_per_container")),
|
||||
min_containers=int(getattr(cfg, "min_containers")),
|
||||
max_containers=int(getattr(cfg, "max_containers")),
|
||||
privileged=bool(getattr(cfg, "privileged")),
|
||||
acquire_timeout_s=float(getattr(cfg, "acquire_timeout_s")),
|
||||
purge_job_on_start=bool(getattr(cfg, "purge_job_on_start", False)),
|
||||
driver=str(getattr(cfg, "driver", "docker")),
|
||||
singularity_image=getattr(cfg, "singularity_image", None),
|
||||
)
|
||||
|
||||
|
||||
class NomadToolBackend(ToolBackend):
|
||||
def __init__(self, config: NomadBackendConfig):
|
||||
self.config = config
|
||||
self.pool = SlotPool(
|
||||
SlotPoolConfig(
|
||||
nomad_address=config.nomad_address,
|
||||
job_id=config.sandbox_job_id,
|
||||
image=config.sandbox_image,
|
||||
slots_per_container=config.slots_per_container,
|
||||
min_containers=config.min_containers,
|
||||
max_containers=config.max_containers,
|
||||
privileged=config.privileged,
|
||||
acquire_timeout=config.acquire_timeout_s,
|
||||
purge_job_on_start=bool(config.purge_job_on_start),
|
||||
driver=config.driver,
|
||||
singularity_image=config.singularity_image,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def default_timeout_s(self) -> Optional[float]:
|
||||
t = getattr(self.pool.executor, "timeout", None)
|
||||
total = getattr(t, "total", None)
|
||||
try:
|
||||
return float(total) if total is not None else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def start(self) -> None:
|
||||
await self.pool.start()
|
||||
|
||||
async def stop(self, *, purge: bool = False) -> None:
|
||||
await self.pool.stop(purge_job=purge)
|
||||
|
||||
async def acquire(self, trajectory_id: Optional[str] = None) -> Slot:
|
||||
return await self.pool.acquire(trajectory_id)
|
||||
|
||||
async def release(self, slot: Slot, *, reset_workspace: bool = False) -> None:
|
||||
await self.pool.release(slot, reset_workspace=reset_workspace)
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
requests: List[Tuple[Slot, str, Dict[str, Any]]],
|
||||
*,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> List[ExecutionResult]:
|
||||
return await self.pool.execute_batch(requests, timeout=timeout_s)
|
||||
|
||||
async def read_artifact(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str,
|
||||
*,
|
||||
encoding: str = "text",
|
||||
max_bytes: Optional[int] = None,
|
||||
include_sha256: bool = False,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
return await self.pool.executor.read_artifact(
|
||||
slot,
|
||||
path,
|
||||
encoding=encoding,
|
||||
max_bytes=max_bytes,
|
||||
include_sha256=include_sha256,
|
||||
timeout=timeout_s,
|
||||
)
|
||||
|
||||
async def list_artifacts(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str = ".",
|
||||
*,
|
||||
recursive: bool = False,
|
||||
max_entries: Optional[int] = None,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
return await self.pool.executor.list_artifacts(
|
||||
slot,
|
||||
path,
|
||||
recursive=recursive,
|
||||
max_entries=max_entries,
|
||||
timeout=timeout_s,
|
||||
)
|
||||
|
||||
async def archive_artifacts(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str = ".",
|
||||
*,
|
||||
archive_format: str = "tar.gz",
|
||||
max_bytes: Optional[int] = None,
|
||||
max_entries: Optional[int] = None,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
return await self.pool.executor.archive_artifacts(
|
||||
slot,
|
||||
path,
|
||||
archive_format=archive_format,
|
||||
max_bytes=max_bytes,
|
||||
max_entries=max_entries,
|
||||
timeout=timeout_s,
|
||||
)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
return self.pool.get_stats()
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
Environment implementations for atropos-agent.
|
||||
|
||||
NOTE: AgentEnv is the OLD environment system, replaced by
|
||||
environments/hermes_base_env.py (HermesAgentBaseEnv).
|
||||
Import is lazy to avoid pulling in deleted dependencies.
|
||||
"""
|
||||
|
||||
|
||||
def __getattr__(name):
|
||||
"""Lazy import to avoid breaking when old dependencies are removed."""
|
||||
if name in ("AgentEnv", "AgentEnvConfig"):
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
return {"AgentEnv": AgentEnv, "AgentEnvConfig": AgentEnvConfig}[name]
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
__all__ = ["AgentEnv", "AgentEnvConfig"]
|
||||
@@ -0,0 +1,537 @@
|
||||
"""
|
||||
AgentEnv - Atropos BaseEnv extension for agent/tool-call workloads.
|
||||
|
||||
AgentEnv is responsible for starting the sandbox tool execution backend and
|
||||
providing helpers for running agent trajectories with queued/batched tool calls.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Tuple, TypeVar
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, Item, ScoredDataGroup, ScoredDataItem
|
||||
from atroposlib.envs.server_handling.server_baseline import AsyncSemWithAdaptiveWeight
|
||||
|
||||
from ..agent import AgentConfig, AgentResult, AtroposAgent
|
||||
from ..backends import ToolBackend, create_tool_backend
|
||||
from ..tools import ToolRegistry, build_tool_registry
|
||||
from ..tools.tool_executor import ToolExecutor, ToolExecutorConfig
|
||||
|
||||
# Main BaseEnv child classes. Child class THESE to get agent+tooling functionality easily.
|
||||
|
||||
class AgentEnvConfig(BaseEnvConfig):
|
||||
tool_pool_mode: str = Field(default="nomad", description="Tool execution backend ('nomad' or 'modal')")
|
||||
|
||||
allow_network: bool = Field(
|
||||
default=True,
|
||||
description="Whether sandbox bash commands may access the network (env policy).",
|
||||
)
|
||||
require_sandbox: bool = Field(
|
||||
default=False,
|
||||
description="Fail closed if bubblewrap sandboxing is unavailable/unusable for stateless sandbox tools.",
|
||||
)
|
||||
require_stateful_sandbox: bool = Field(
|
||||
default=False,
|
||||
description="Fail closed if bubblewrap/PID isolation is unavailable for stateful terminal tools (tmux).",
|
||||
)
|
||||
tool_batch_window_ms: int = Field(default=20, description="ToolExecutor batching window (ms)")
|
||||
tool_max_batch_size: int = Field(default=200, description="ToolExecutor maximum batch size")
|
||||
|
||||
# nomad mode settings. TODO: Add Modal support, split this into own config
|
||||
nomad_address: str = Field(default="http://localhost:4646", description="Nomad API address")
|
||||
sandbox_job_id: str = Field(default="atropos-sandbox-agent-env", description="Nomad job id for sandbox containers")
|
||||
sandbox_image: str = Field(default="atropos-sandbox:local", description="Docker image for sandbox containers")
|
||||
slots_per_container: int = Field(default=10, description="Nomad mode: slots per container")
|
||||
min_containers: int = Field(default=1, description="Nomad mode: minimum containers")
|
||||
max_containers: int = Field(default=10, description="Nomad mode: maximum containers")
|
||||
privileged: bool = Field(default=False, description="Nomad mode: run container privileged")
|
||||
acquire_timeout_s: float = Field(default=30.0, description="Slot acquisition timeout (seconds)")
|
||||
purge_job_on_start: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Nomad mode: stop/purge the sandbox job on startup. This is helpful in local dev and training runs "
|
||||
"to recover from previous crashes that leave the job in a restart backoff state."
|
||||
),
|
||||
)
|
||||
purge_job_on_shutdown: bool = Field(default=True, description="Nomad mode: stop/purge job on shutdown")
|
||||
|
||||
# Nomad driver selection (docker or singularity)
|
||||
driver: str = Field(
|
||||
default="docker",
|
||||
description="Nomad task driver: 'docker' (default) or 'singularity' (for HPC without sudo Docker)",
|
||||
)
|
||||
singularity_image: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Path to .sif file for Singularity driver (required if driver='singularity')",
|
||||
)
|
||||
|
||||
# Modal mode settings
|
||||
modal_app_name: str = Field(default="atropos-sandbox", description="Modal app name prefix")
|
||||
modal_image: str = Field(default="python:3.11", description="Modal: container image")
|
||||
modal_gpu: Optional[str] = Field(default=None, description="Modal: GPU type (None, 'T4', 'A10G', 'A100', 'H100')")
|
||||
modal_cpu: float = Field(default=1.0, description="Modal: CPU cores")
|
||||
modal_memory: int = Field(default=2048, description="Modal: memory in MB")
|
||||
modal_slots_per_sandbox: int = Field(default=10, description="Modal: slots per sandbox")
|
||||
modal_min_sandboxes: int = Field(default=1, description="Modal: minimum sandboxes")
|
||||
modal_max_sandboxes: int = Field(default=5, description="Modal: maximum sandboxes")
|
||||
modal_idle_timeout: int = Field(default=120, description="Modal: server-side idle timeout (seconds)")
|
||||
modal_max_lifetime: int = Field(default=3600, description="Modal: max sandbox lifetime (seconds)")
|
||||
modal_acquire_timeout: float = Field(default=60.0, description="Modal: slot acquisition timeout (seconds)")
|
||||
modal_execution_timeout: float = Field(default=30.0, description="Modal: default command execution timeout (seconds)")
|
||||
modal_secrets: str = Field(default="", description="Modal: comma-separated list of Modal Secret names")
|
||||
modal_env_vars: str = Field(default="", description="Modal: semicolon-separated KEY=VALUE pairs for env vars")
|
||||
modal_workspace_base: str = Field(default="/data", description="Modal: workspace base directory in sandbox")
|
||||
|
||||
# basic agent defaults
|
||||
agent_max_steps: int = Field(default=50, description="Max ReACT steps per trajectory")
|
||||
agent_temperature: float = Field(default=0.7, description="Sampling temperature")
|
||||
agent_max_tokens: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Max tokens per model response (default: let backend decide)",
|
||||
)
|
||||
agent_tool_delay_s: float = Field(default=0.0, description="Delay between tool calls (seconds)")
|
||||
|
||||
# tool selection
|
||||
enabled_toolsets: List[str] = Field(
|
||||
default_factory=lambda: ["default"],
|
||||
description="Toolsets to enable (Hermes-style grouping).",
|
||||
)
|
||||
disabled_toolsets: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Toolsets to disable (applied after enabled_toolsets).",
|
||||
)
|
||||
|
||||
# external ToolServer routing (Phase 4.5+)
|
||||
tool_server_url: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Base URL for external ToolServer (enables external tools).",
|
||||
)
|
||||
tool_server_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Bearer token for ToolServer auth (optional in dev).",
|
||||
)
|
||||
|
||||
AgentEnvConfigT = TypeVar("AgentEnvConfigT", bound="AgentEnvConfig")
|
||||
|
||||
|
||||
class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
|
||||
env_config_cls = AgentEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AgentEnvConfigT,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.config: AgentEnvConfigT = config
|
||||
|
||||
self.tools: ToolRegistry = self.build_tools()
|
||||
|
||||
self._backend: Optional[ToolBackend] = None
|
||||
self._tool_executor: Optional[ToolExecutor] = None
|
||||
self._tool_server_inprocess: bool = False
|
||||
self._trajectory_workspace_meta: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def build_tools(self) -> ToolRegistry:
|
||||
"""Wraps original Hermes-Agent ToolRegistry for atropos AgentEnv use.
|
||||
See Hermes-Agent docs for toolsets and available tools etc.
|
||||
"""
|
||||
return build_tool_registry(
|
||||
enabled_toolsets=self.config.enabled_toolsets or ["default"],
|
||||
disabled_toolsets=self.config.disabled_toolsets or None,
|
||||
tool_server_url=self.config.tool_server_url,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def build_task(self, item: Item) -> str:
|
||||
"""Return the user-facing task string for the agent."""
|
||||
|
||||
@abstractmethod
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
"""Return a scalar score for this trajectory."""
|
||||
|
||||
async def setup_trajectory_workspace(
|
||||
self,
|
||||
item: Item,
|
||||
*,
|
||||
trajectory_id: str,
|
||||
exec_tool: Callable[["ToolCall"], Awaitable["ToolResult"]],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Optional hook: prepare the sandbox workspace before the agent starts.
|
||||
|
||||
Examples:
|
||||
- clone a repo and checkout a commit
|
||||
- write fixture files (e.g. images) for external-tool demos
|
||||
- pre-install dependencies
|
||||
|
||||
Default: no-op.
|
||||
"""
|
||||
_ = (item, trajectory_id, exec_tool)
|
||||
return {}
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
final_response: str,
|
||||
*,
|
||||
trajectory_id: str,
|
||||
exec_tool: Callable[["ToolCall"], Awaitable["ToolResult"]],
|
||||
agent_result: Optional[AgentResult] = None,
|
||||
workspace_meta: Optional[Dict[str, Any]] = None,
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
"""
|
||||
Optional hook: run in-sandbox verification before scoring.
|
||||
|
||||
Many agent envs need to execute verification inside the same trajectory
|
||||
workspace (e.g. pytest) before releasing/resetting the slot.
|
||||
|
||||
Default: calls `score_trajectory()` and returns empty metadata.
|
||||
"""
|
||||
_ = (trajectory_id, exec_tool, agent_result, workspace_meta) # default ignores in-workspace verification
|
||||
score = await self.score_trajectory(item, final_response)
|
||||
return score, {}
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
return AgentConfig(
|
||||
max_steps=self.config.agent_max_steps,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.agent_max_tokens,
|
||||
tool_delay_s=self.config.agent_tool_delay_s,
|
||||
)
|
||||
|
||||
async def setup(self) -> None:
|
||||
print(f"[AgentEnv] setup(): starting tool backend ({self.config.tool_pool_mode})", flush=True)
|
||||
await self._start_tool_backend()
|
||||
print("[AgentEnv] setup(): configuring server concurrency", flush=True)
|
||||
self._configure_server_concurrency()
|
||||
print("[AgentEnv] setup(): running env-specific setup_agent_env()", flush=True)
|
||||
await self.setup_agent_env()
|
||||
print("[AgentEnv] setup(): done", flush=True)
|
||||
|
||||
def _configure_server_concurrency(self) -> None:
|
||||
"""
|
||||
Ensure the LLM server concurrency isn't accidentally capped below `group_size`.
|
||||
|
||||
In `BaseEnv process` mode, groups are collected concurrently and if the underlying
|
||||
ServerManager/OpenAIServer semaphore is left at 1, we serialize inference even
|
||||
when `--env.group_size` is > 1.
|
||||
"""
|
||||
desired = int(getattr(self.config, "group_size", 1) or 1)
|
||||
if desired <= 1:
|
||||
return
|
||||
|
||||
servers = getattr(self.server, "servers", None)
|
||||
if not isinstance(servers, list) or not servers:
|
||||
return
|
||||
|
||||
for s in servers:
|
||||
sem = getattr(s, "sem", None)
|
||||
eval_sem = getattr(s, "eval_sem", None)
|
||||
# Only increase; never shrink.
|
||||
if sem is not None and getattr(sem, "max_val", 0) < desired:
|
||||
s.sem = AsyncSemWithAdaptiveWeight(desired)
|
||||
if hasattr(s, "config") and hasattr(s.config, "num_max_requests_at_once"):
|
||||
s.config.num_max_requests_at_once = desired
|
||||
if eval_sem is not None and getattr(eval_sem, "max_val", 0) < desired:
|
||||
s.eval_sem = AsyncSemWithAdaptiveWeight(desired)
|
||||
if hasattr(s, "config") and hasattr(s.config, "num_requests_for_eval"):
|
||||
s.config.num_requests_for_eval = desired
|
||||
|
||||
@abstractmethod
|
||||
async def setup_agent_env(self) -> None:
|
||||
"""Subclass hook for env-specific setup."""
|
||||
|
||||
async def evaluate(self, *args, **kwargs): # noqa: ARG002
|
||||
"""
|
||||
Default eval hook (no-op).
|
||||
|
||||
Atropos BaseEnv requires an `evaluate()` implementation. Many agent envs
|
||||
won't have a meaningful evaluation path during early PoC work; they can
|
||||
override this when needed.
|
||||
"""
|
||||
return {}
|
||||
|
||||
async def env_manager(self):
|
||||
try:
|
||||
return await super().env_manager()
|
||||
finally:
|
||||
await self.shutdown_tool_backend()
|
||||
|
||||
async def process_manager(self):
|
||||
try:
|
||||
return await super().process_manager()
|
||||
finally:
|
||||
await self.shutdown_tool_backend()
|
||||
|
||||
async def _start_tool_backend(self) -> None:
|
||||
if self._tool_executor is not None:
|
||||
return
|
||||
|
||||
tool_server_url = self.config.tool_server_url
|
||||
tool_server_client = None
|
||||
if tool_server_url == "inprocess":
|
||||
import httpx
|
||||
from ..api.tool_server import app as tool_server_app
|
||||
|
||||
await tool_server_app.router.startup()
|
||||
tool_server_client = httpx.AsyncClient(
|
||||
transport=httpx.ASGITransport(app=tool_server_app),
|
||||
base_url="http://toolserver",
|
||||
)
|
||||
tool_server_url = "http://toolserver"
|
||||
self._tool_server_inprocess = True
|
||||
|
||||
backend = create_tool_backend(self.config)
|
||||
await backend.start()
|
||||
|
||||
executor = ToolExecutor(
|
||||
backend=backend,
|
||||
tools=self.tools,
|
||||
config=ToolExecutorConfig(
|
||||
batch_window_ms=self.config.tool_batch_window_ms,
|
||||
max_batch_size=self.config.tool_max_batch_size,
|
||||
allow_network=self.config.allow_network,
|
||||
require_sandbox=self.config.require_sandbox,
|
||||
require_stateful_sandbox=self.config.require_stateful_sandbox,
|
||||
tool_server_url=tool_server_url,
|
||||
tool_server_token=self.config.tool_server_token,
|
||||
),
|
||||
)
|
||||
await executor.start()
|
||||
if tool_server_client is not None:
|
||||
executor._tool_server_client = tool_server_client # type: ignore[attr-defined]
|
||||
|
||||
self._backend = backend
|
||||
self._tool_executor = executor
|
||||
|
||||
async def shutdown_tool_backend(self) -> None:
|
||||
executor = self._tool_executor
|
||||
backend = self._backend
|
||||
inprocess_tool_server = self._tool_server_inprocess
|
||||
self._tool_executor = None
|
||||
self._backend = None
|
||||
self._tool_server_inprocess = False
|
||||
|
||||
if executor is not None:
|
||||
await executor.close()
|
||||
if backend is not None:
|
||||
await backend.stop(purge=bool(self.config.purge_job_on_shutdown))
|
||||
if inprocess_tool_server:
|
||||
from ..api.tool_server import app as tool_server_app
|
||||
|
||||
await tool_server_app.router.shutdown()
|
||||
|
||||
async def collect_trajectory(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[ScoredDataItem], List[Item]]:
|
||||
if self._tool_executor is None:
|
||||
raise RuntimeError("Tool backend not started")
|
||||
|
||||
trajectory_id = str(uuid.uuid4())
|
||||
t0 = time.perf_counter()
|
||||
print(f"[AgentEnv] collect_trajectory(): tid={trajectory_id} start", flush=True)
|
||||
task = self.build_task(item)
|
||||
agent_config = self.build_agent_config(item)
|
||||
if os.getenv("ATROPOS_DEBUG_PRINT_TASK") == "1":
|
||||
print(f"Starting trajectory {trajectory_id} with task: {task}", flush=True)
|
||||
else:
|
||||
# Avoid printing the full task prompt by default (can be huge/noisy).
|
||||
one_line = " ".join(str(task).splitlines()).strip()
|
||||
preview = one_line[:240] + ("…" if len(one_line) > 240 else "")
|
||||
print(f"Starting trajectory {trajectory_id} (task preview): {preview}", flush=True)
|
||||
|
||||
async def _exec(call):
|
||||
return await self._tool_executor.execute(trajectory_id, call)
|
||||
|
||||
agent = AtroposAgent(
|
||||
server=self.server,
|
||||
tokenizer=self.tokenizer,
|
||||
tools=self.tools,
|
||||
config=agent_config,
|
||||
execute_tool=_exec,
|
||||
)
|
||||
|
||||
try:
|
||||
print(f"[AgentEnv] tid={trajectory_id} setup_trajectory_workspace() start", flush=True)
|
||||
workspace_meta = await self.setup_trajectory_workspace(item, trajectory_id=trajectory_id, exec_tool=_exec)
|
||||
if not isinstance(workspace_meta, dict):
|
||||
workspace_meta = {}
|
||||
self._trajectory_workspace_meta[trajectory_id] = workspace_meta
|
||||
print(
|
||||
f"[AgentEnv] tid={trajectory_id} setup_trajectory_workspace() done in {time.perf_counter() - t0:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
print(f"[AgentEnv] tid={trajectory_id} agent.run() start", flush=True)
|
||||
result = await agent.run(task)
|
||||
print(
|
||||
f"[AgentEnv] tid={trajectory_id} agent.run() done in {time.perf_counter() - t0:.2f}s "
|
||||
f"success={result.success} tool_calls={result.total_tool_calls}",
|
||||
flush=True,
|
||||
)
|
||||
if not result.success or result.trajectory_data is None:
|
||||
# Do not trigger BaseEnv retries for agent failures.
|
||||
# Record the trajectory with score 0.0 so training/eval can see the failure mode.
|
||||
messages = [{"role": "system", "content": agent._build_system_prompt()}] # noqa: SLF001
|
||||
messages.append({"role": "user", "content": task})
|
||||
for step in result.steps:
|
||||
messages.append({"role": "assistant", "content": step.assistant_message})
|
||||
if step.tool_results:
|
||||
tool_text = "\n".join(r.to_xml() for r in step.tool_results)
|
||||
messages.append({"role": "user", "content": tool_text})
|
||||
|
||||
scored: ScoredDataItem = {
|
||||
"tokens": (result.trajectory_data.tokens if result.trajectory_data else []),
|
||||
"masks": (result.trajectory_data.masked_tokens if result.trajectory_data else []),
|
||||
"scores": 0.0,
|
||||
}
|
||||
if result.trajectory_data is not None:
|
||||
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
|
||||
if getattr(result.trajectory_data, "metadata", None):
|
||||
scored["overrides"] = {"managed_metadata": result.trajectory_data.metadata}
|
||||
if self.config.include_messages:
|
||||
# Record a final failure marker as a user-side tool_response-like block so it survives templates.
|
||||
import json
|
||||
|
||||
err = result.error or "agent_failed"
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"<tool_response>{json.dumps({'success': False, 'error': err})}</tool_response>",
|
||||
}
|
||||
)
|
||||
scored["messages"] = messages
|
||||
return scored, []
|
||||
|
||||
print(f"[AgentEnv] tid={trajectory_id} verify_and_score_trajectory() start", flush=True)
|
||||
score, score_metadata = await self.verify_and_score_trajectory(
|
||||
item,
|
||||
result.final_response,
|
||||
trajectory_id=trajectory_id,
|
||||
exec_tool=_exec,
|
||||
agent_result=result,
|
||||
workspace_meta=workspace_meta,
|
||||
)
|
||||
print(
|
||||
f"[AgentEnv] tid={trajectory_id} verify_and_score_trajectory() done in {time.perf_counter() - t0:.2f}s "
|
||||
f"score={score}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": agent._build_system_prompt()}] # noqa: SLF001
|
||||
messages.append({"role": "user", "content": task})
|
||||
for step in result.steps:
|
||||
messages.append({"role": "assistant", "content": step.assistant_message})
|
||||
if step.tool_results:
|
||||
tool_text = "\n".join(r.to_xml() for r in step.tool_results)
|
||||
messages.append({"role": "user", "content": tool_text})
|
||||
|
||||
# Optional: allow env verification to attach additional messages (e.g. install logs).
|
||||
if self.config.include_messages and isinstance(score_metadata, dict):
|
||||
extra = score_metadata.get("verification_messages")
|
||||
if isinstance(extra, list):
|
||||
for m in extra:
|
||||
if isinstance(m, dict) and isinstance(m.get("role"), str) and isinstance(m.get("content"), str):
|
||||
messages.append({"role": m["role"], "content": m["content"]})
|
||||
|
||||
scored: ScoredDataItem = {
|
||||
"tokens": result.trajectory_data.tokens,
|
||||
"masks": result.trajectory_data.masked_tokens,
|
||||
"scores": score,
|
||||
}
|
||||
# Atroposlib expects policy logprobs at the *group* level under `inference_logprobs`.
|
||||
# We stash per-item values here and lift them into the group in `collect_trajectories()`.
|
||||
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
|
||||
if getattr(result.trajectory_data, "metadata", None):
|
||||
scored["overrides"] = {"managed_metadata": result.trajectory_data.metadata}
|
||||
if self.config.include_messages:
|
||||
scored["messages"] = messages
|
||||
|
||||
return scored, []
|
||||
finally:
|
||||
self._trajectory_workspace_meta.pop(trajectory_id, None)
|
||||
print(f"[AgentEnv] tid={trajectory_id} release_trajectory(reset_workspace=True)", flush=True)
|
||||
await self._tool_executor.release_trajectory(trajectory_id, reset_workspace=True)
|
||||
print(f"[AgentEnv] collect_trajectory(): tid={trajectory_id} done in {time.perf_counter() - t0:.2f}s", flush=True)
|
||||
|
||||
async def collect_trajectories(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
|
||||
tasks = [self.collect_trajectory(item) for _ in range(self.config.group_size)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
backlog: List[Item] = []
|
||||
items: List[ScoredDataItem] = []
|
||||
for scored, b in results:
|
||||
backlog.extend(b)
|
||||
if scored is not None:
|
||||
items.append(scored)
|
||||
|
||||
if len(items) != self.config.group_size:
|
||||
return None, backlog
|
||||
|
||||
group: ScoredDataGroup = ScoredDataGroup(
|
||||
tokens=[],
|
||||
masks=[],
|
||||
scores=[],
|
||||
advantages=[],
|
||||
ref_logprobs=[],
|
||||
messages=[] if self.config.include_messages else None,
|
||||
inference_logprobs=[],
|
||||
group_overrides={},
|
||||
overrides=[],
|
||||
images=[],
|
||||
generation_params=None,
|
||||
)
|
||||
|
||||
for it in items:
|
||||
group["tokens"].append(it["tokens"])
|
||||
group["masks"].append(it["masks"])
|
||||
group["scores"].append(it["scores"])
|
||||
# policy logprobs (for PPO/GRPO training) if present
|
||||
lp = it.get("inference_logprobs") # type: ignore[typeddict-item]
|
||||
if lp is not None:
|
||||
group["inference_logprobs"].append(lp)
|
||||
group["overrides"].append(it.get("overrides") or {}) # type: ignore[typeddict-item]
|
||||
if group.get("messages") is not None and it.get("messages") is not None:
|
||||
group["messages"].append(it["messages"])
|
||||
|
||||
return group, backlog
|
||||
|
||||
async def run_agent(self, task: str, *, trajectory_id: Optional[str] = None) -> Tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
Run the AtroposAgent on a single task and return (final_response, debug).
|
||||
|
||||
This is a helper intended for simple environments and tests.
|
||||
"""
|
||||
if self._tool_executor is None:
|
||||
raise RuntimeError("Tool backend not started")
|
||||
|
||||
tid = trajectory_id or str(uuid.uuid4())
|
||||
|
||||
async def _exec(call):
|
||||
return await self._tool_executor.execute(tid, call)
|
||||
|
||||
agent = AtroposAgent(
|
||||
server=self.server,
|
||||
tokenizer=self.tokenizer,
|
||||
tools=self.tools,
|
||||
config=AgentConfig(
|
||||
max_steps=self.config.agent_max_steps,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.agent_max_tokens,
|
||||
),
|
||||
execute_tool=_exec,
|
||||
)
|
||||
result = await agent.run(task)
|
||||
await self._tool_executor.release_trajectory(tid, reset_workspace=True)
|
||||
return result.final_response, {"success": result.success, "error": result.error, "tool_calls": result.total_tool_calls}
|
||||
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
Hermes-Agent + Atropos (Nomad sandbox) compatibility smoke environment.
|
||||
|
||||
This environment is intended to validate, end-to-end:
|
||||
BaseEnv.process -> AgentEnv -> ToolExecutor (batched) -> Nomad SlotPool -> sandbox_server
|
||||
|
||||
It forces the model to use a sandbox tool by asking it to run a command that
|
||||
generates a high-entropy token inside the sandbox, then repeat it exactly.
|
||||
|
||||
Run (process mode):
|
||||
uv run python -m atropos.envs.hermes_compat_test_env process --env.use_wandb false --env.total_steps 2 --env.group_size 1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, Item
|
||||
|
||||
from ..agent import AgentConfig, AgentResult
|
||||
from ..tools import ToolCall
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def _forced_tool_item() -> Item:
|
||||
# Use double quotes in the shell command and show JSON escaping explicitly.
|
||||
# This avoids invalid JSON escapes like `\\'` (not valid JSON) that some models produce.
|
||||
cmd = 'python -c "import secrets; print(secrets.token_hex(16))"'
|
||||
return {
|
||||
"command": cmd,
|
||||
"prompt": (
|
||||
"You are acting as an agent inside a sandboxed environment.\n"
|
||||
"You MUST use the terminal tool to execute commands.\n"
|
||||
"Run this exact command:\n"
|
||||
f"{cmd}\n"
|
||||
"When you call the tool, use valid JSON inside <tool_call>. Example:\n"
|
||||
'<tool_call>{"name": "terminal", "arguments": {"command": '
|
||||
'"python -c \\\\"import secrets; print(secrets.token_hex(16))\\\\""}}'
|
||||
"</tool_call>\n"
|
||||
"Then respond with EXACTLY what it printed (the hex token) and nothing else.\n"
|
||||
"Do not guess. Do not explain."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class HermesCompatTestEnvConfig(AgentEnvConfig):
|
||||
server_base_url: str = Field(
|
||||
default="http://127.0.0.1:8080",
|
||||
description="Base URL for an OpenAI-compatible chat server (without /v1).",
|
||||
)
|
||||
server_model: str = Field(default="hermes-4-36b", description="Model name")
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||
|
||||
|
||||
class HermesCompatTestEnv(AgentEnv[HermesCompatTestEnvConfig]):
|
||||
name = "hermes_compat_test_env"
|
||||
env_config_cls = HermesCompatTestEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: HermesCompatTestEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._iter = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[HermesCompatTestEnvConfig, List[APIServerConfig]]:
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||
|
||||
env_config = HermesCompatTestEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
include_messages=True,
|
||||
ensure_scores_are_not_same=False,
|
||||
total_steps=2,
|
||||
batch_size=1,
|
||||
server_base_url=base_url,
|
||||
server_model=model,
|
||||
# Tooling: sandbox-only terminal.
|
||||
enabled_toolsets=["terminal"],
|
||||
disabled_toolsets=[],
|
||||
# Default to Nomad sandboxing; users can override via --env.* args.
|
||||
sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
|
||||
# In local dev it's common for a previous crash to leave the job in backoff.
|
||||
purge_job_on_start=True,
|
||||
purge_job_on_shutdown=True,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=f"{base_url.rstrip('/')}/v1",
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=1,
|
||||
num_requests_for_eval=1,
|
||||
timeout=120,
|
||||
)
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup_agent_env(self) -> None:
|
||||
return None
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
self._iter += 1
|
||||
return _forced_tool_item()
|
||||
|
||||
def build_task(self, item: Item) -> str:
|
||||
return str(item.get("prompt") or "")
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
# Avoid imposing max_tokens by default; tool-tag responses can be long for some models.
|
||||
return AgentConfig(
|
||||
max_steps=min(8, int(self.config.agent_max_steps)),
|
||||
temperature=0.2,
|
||||
max_tokens=None,
|
||||
)
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
# Scoring happens in verify_and_score_trajectory so we can inspect tool results.
|
||||
_ = (item, final_response)
|
||||
return 0.0
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
final_response: str,
|
||||
*,
|
||||
trajectory_id: str, # noqa: ARG002
|
||||
exec_tool, # noqa: ARG002
|
||||
agent_result: AgentResult | None = None,
|
||||
workspace_meta: Dict[str, Any] | None = None, # noqa: ARG002
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
if agent_result is None:
|
||||
return 0.0, {"error": "Missing agent_result"}
|
||||
|
||||
observed: str = ""
|
||||
tool_ok = False
|
||||
for step in agent_result.steps:
|
||||
for res in step.tool_results:
|
||||
if not res.success:
|
||||
return 0.0, {"error": res.error, "output": res.output}
|
||||
out = (res.output or "").strip()
|
||||
if out:
|
||||
observed = out.splitlines()[-1].strip()
|
||||
tool_ok = True
|
||||
|
||||
final = (final_response or "").strip()
|
||||
score = 1.0 if tool_ok and agent_result.total_tool_calls > 0 and observed and final == observed else 0.0
|
||||
return score, {"observed": observed, "tool_calls": agent_result.total_tool_calls, "command": item.get("command")}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
HermesCompatTestEnv.cli()
|
||||
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
Nomad sandbox terminal smoke environment (training-oriented).
|
||||
|
||||
Validates, end-to-end:
|
||||
BaseEnv.process -> AgentEnv -> ToolExecutor (batched) -> Nomad SlotPool -> sandbox_server
|
||||
|
||||
It forces the model to use a sandbox tool by asking it to run a command that
|
||||
generates a high-entropy token inside the sandbox, then repeat it exactly.
|
||||
|
||||
Run (process mode):
|
||||
uv run python -m atropos.envs.sandbox_terminal_smoke_env process --env.use_wandb false --env.total_steps 2 --env.group_size 1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, Item
|
||||
|
||||
from ..agent import AgentConfig, AgentResult
|
||||
from ..tools import ToolCall
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
STRICT_TOOLCALL_SYSTEM_PROMPT = None
|
||||
|
||||
|
||||
def _forced_tool_item() -> Item:
|
||||
# Use double quotes in the shell command and show JSON escaping explicitly.
|
||||
# This avoids invalid JSON escapes like `\\'` (not valid JSON) that some models produce.
|
||||
cmd = 'python -c "import secrets; print(secrets.token_hex(16))"'
|
||||
return {
|
||||
"command": cmd,
|
||||
"prompt": (
|
||||
"You MUST use the terminal tool.\n"
|
||||
"Run this exact command:\n"
|
||||
f"{cmd}\n"
|
||||
"When you call the tool, use valid JSON inside <tool_call>. Example:\n"
|
||||
'<tool_call>{"name": "terminal", "arguments": {"command": '
|
||||
'"python -c \\\\"import secrets; print(secrets.token_hex(16))\\\\""}}'
|
||||
"</tool_call>\n"
|
||||
"Then respond with EXACTLY what it printed (the hex token) and nothing else.\n"
|
||||
"Do not guess. Do not explain."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class SandboxTerminalSmokeEnvConfig(AgentEnvConfig):
|
||||
server_base_url: str = Field(
|
||||
default="http://127.0.0.1:8080",
|
||||
description="Base URL for an OpenAI-compatible chat server (without /v1).",
|
||||
)
|
||||
server_model: str = Field(default="hermes-4-36b", description="Model name")
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||
|
||||
|
||||
class SandboxTerminalSmokeEnv(AgentEnv[SandboxTerminalSmokeEnvConfig]):
|
||||
name = "sandbox_terminal_smoke_env"
|
||||
env_config_cls = SandboxTerminalSmokeEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SandboxTerminalSmokeEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._iter = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[SandboxTerminalSmokeEnvConfig, List[APIServerConfig]]:
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||
|
||||
env_config = SandboxTerminalSmokeEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
include_messages=True,
|
||||
ensure_scores_are_not_same=False,
|
||||
total_steps=2,
|
||||
batch_size=1,
|
||||
server_base_url=base_url,
|
||||
server_model=model,
|
||||
# Tooling: sandbox-only terminal.
|
||||
enabled_toolsets=["terminal"],
|
||||
disabled_toolsets=[],
|
||||
# Default to Nomad sandboxing; users can override via --env.* args.
|
||||
sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
|
||||
purge_job_on_start=True,
|
||||
purge_job_on_shutdown=True,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=f"{base_url.rstrip('/')}/v1",
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=1,
|
||||
num_requests_for_eval=1,
|
||||
timeout=120,
|
||||
)
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup_agent_env(self) -> None:
|
||||
return None
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
self._iter += 1
|
||||
return _forced_tool_item()
|
||||
|
||||
def build_task(self, item: Item) -> str:
|
||||
return str(item.get("prompt") or "")
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
# Avoid imposing max_tokens by default; tool-tag responses can be long for some models.
|
||||
return AgentConfig(
|
||||
max_steps=min(8, int(self.config.agent_max_steps)),
|
||||
temperature=0.2,
|
||||
max_tokens=None,
|
||||
system_prompt=STRICT_TOOLCALL_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
# Scoring happens in verify_and_score_trajectory so we can inspect tool results.
|
||||
_ = (item, final_response)
|
||||
return 0.0
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
final_response: str,
|
||||
*,
|
||||
trajectory_id: str, # noqa: ARG002
|
||||
exec_tool, # noqa: ARG002
|
||||
agent_result: AgentResult | None = None,
|
||||
workspace_meta: Dict[str, Any] | None = None, # noqa: ARG002
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
if agent_result is None:
|
||||
return 0.0, {"error": "Missing agent_result"}
|
||||
|
||||
observed: str = ""
|
||||
tool_ok = False
|
||||
for step in agent_result.steps:
|
||||
for res in step.tool_results:
|
||||
if not res.success:
|
||||
return 0.0, {"error": res.error, "output": res.output}
|
||||
out = (res.output or "").strip()
|
||||
if out:
|
||||
observed = out.splitlines()[-1].strip()
|
||||
tool_ok = True
|
||||
|
||||
final = (final_response or "").strip()
|
||||
score = 1.0 if tool_ok and agent_result.total_tool_calls > 0 and observed and final == observed else 0.0
|
||||
return score, {"observed": observed, "tool_calls": agent_result.total_tool_calls, "command": item.get("command")}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
SandboxTerminalSmokeEnv.cli()
|
||||
@@ -0,0 +1,418 @@
|
||||
"""
|
||||
SWE-smith-oracle environment.
|
||||
|
||||
This environment is intentionally minimal:
|
||||
- prepares a sandbox workspace by cloning a public GitHub repo at `base_commit`
|
||||
- runs an AtroposAgent tool loop to apply a fix
|
||||
- verifies by running pytest nodeids from the dataset (reward = pass/fail)
|
||||
- Python only (no multi-language support currently, need to properly bauild & add to dropbox)
|
||||
- TODO: Get the other nonpython sandboxes up and running, then add a config knob to switch between them per row
|
||||
- oh and add to dockerhub
|
||||
|
||||
Dataset: NousResearch/SWE-smith-oracle (train; does NOT use SWE-bench eval set).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, Item
|
||||
|
||||
from ..agent import AgentConfig
|
||||
from ..tools import ToolCall
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
|
||||
class SweSmithOracleEnvConfig(AgentEnvConfig):
|
||||
dataset_name: str = Field(default="NousResearch/SWE-smith-oracle")
|
||||
dataset_split: str = Field(default="train")
|
||||
max_items: int = Field(default=0, description="0 = no limit")
|
||||
shuffle: bool = Field(default=True)
|
||||
seed: int = Field(default=0)
|
||||
|
||||
python_only: bool = Field(default=True, description="Filter to Python-evaluable rows")
|
||||
score_include_fail_to_pass: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"If true (default), score tests on PASS_TO_PASS ∪ FAIL_TO_PASS. "
|
||||
"Disable to only run PASS_TO_PASS (faster but weaker signal)."
|
||||
),
|
||||
)
|
||||
|
||||
prompt_mode: str = Field(
|
||||
default="problem_statement",
|
||||
description="Task prompt content: 'problem_statement' (fast) or 'problem_statement+text' (slower, includes dataset 'text').",
|
||||
)
|
||||
|
||||
repo_base_url: str = Field(default="https://github.com", description="Base URL for repo cloning")
|
||||
install_timeout_s: float = Field(default=600.0)
|
||||
test_timeout_s: float = Field(default=600.0)
|
||||
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||
|
||||
|
||||
class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
|
||||
"""
|
||||
SWE-smith-oracle AgentEnv.
|
||||
|
||||
This is designed for benchmarking multiplexed slot execution vs naive container-per-trajectory.
|
||||
"""
|
||||
|
||||
name = "swe_smith_oracle_env"
|
||||
env_config_cls = SweSmithOracleEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SweSmithOracleEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._dataset = None
|
||||
self._indices: List[int] = []
|
||||
self._cursor = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[SweSmithOracleEnvConfig, List[APIServerConfig]]:
|
||||
# Defaults for running the env via CLI in offline `process` mode.
|
||||
# Override via env vars or `--env.*` flags as needed.
|
||||
base_url_raw = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
base_url = base_url_raw.rstrip("/")
|
||||
if not base_url.endswith("/v1"):
|
||||
base_url = f"{base_url}/v1"
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||
|
||||
env_config = SweSmithOracleEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1,
|
||||
batch_size=1,
|
||||
steps_per_eval=1,
|
||||
max_token_length=8192,
|
||||
inference_weight=1.0,
|
||||
wandb_name="swe_smith_oracle",
|
||||
enabled_toolsets=["terminal"],
|
||||
disabled_toolsets=[],
|
||||
sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
|
||||
purge_job_on_start=True,
|
||||
purge_job_on_shutdown=True,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=1,
|
||||
num_requests_for_eval=1,
|
||||
timeout=int(os.getenv("ATROPOS_SERVER_TIMEOUT_S") or "300"),
|
||||
),
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup_agent_env(self) -> None:
|
||||
from datasets import load_dataset
|
||||
|
||||
t0 = time.perf_counter()
|
||||
print(
|
||||
f"[SweSmithOracleEnv] loading dataset {self.config.dataset_name}:{self.config.dataset_split} "
|
||||
f"(python_only={self.config.python_only}, max_items={self.config.max_items or 'all'})",
|
||||
flush=True,
|
||||
)
|
||||
ds = load_dataset(self.config.dataset_name, split=self.config.dataset_split)
|
||||
self._dataset = ds
|
||||
|
||||
indices: List[int] = []
|
||||
for idx in range(len(ds)):
|
||||
row = ds[idx]
|
||||
if self.config.python_only and not self._is_python_row(row):
|
||||
continue
|
||||
indices.append(idx)
|
||||
|
||||
if self.config.shuffle:
|
||||
rnd = random.Random(self.config.seed)
|
||||
rnd.shuffle(indices)
|
||||
|
||||
if self.config.max_items and self.config.max_items > 0:
|
||||
indices = indices[: self.config.max_items]
|
||||
|
||||
self._indices = indices
|
||||
self._cursor = 0
|
||||
|
||||
print(
|
||||
f"[SweSmithOracleEnv] loaded {len(self._indices)} items from {self.config.dataset_name}:{self.config.dataset_split} "
|
||||
f"in {time.perf_counter() - t0:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
def _is_python_row(self, row: Dict[str, Any]) -> bool:
|
||||
nodeids = row.get("PASS_TO_PASS")
|
||||
if not isinstance(nodeids, list) or not nodeids:
|
||||
return False
|
||||
for nid in nodeids:
|
||||
if not isinstance(nid, str) or ".py::" not in nid:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
print(f"[SweSmithOracleEnv] get_next_item() cursor={self._cursor}/{len(self._indices)}", flush=True)
|
||||
if not self._dataset or not self._indices:
|
||||
raise RuntimeError("Dataset not initialized (did setup() run?)")
|
||||
if self._cursor >= len(self._indices):
|
||||
self._cursor = 0
|
||||
idx = self._indices[self._cursor]
|
||||
self._cursor += 1
|
||||
return dict(self._dataset[idx])
|
||||
|
||||
def _repo_name(self, item: Item) -> str:
|
||||
repo = item.get("repo") or ""
|
||||
if isinstance(repo, str) and "/" in repo:
|
||||
return repo.split("/")[-1]
|
||||
return "repo"
|
||||
|
||||
def build_task(self, item: Item) -> str:
|
||||
repo = item.get("repo") or ""
|
||||
base_commit = item.get("base_commit") or ""
|
||||
problem = str(item.get("problem_statement") or "")
|
||||
context = str(item.get("text") or "")
|
||||
|
||||
nodeids = self._tests_for_item(item)
|
||||
tests_list = "\n".join(f"- {t}" for t in nodeids)
|
||||
|
||||
repo_dir = self._repo_name(item)
|
||||
|
||||
tests_block = (
|
||||
"Run these tests to verify:\n"
|
||||
f"{tests_list}\n\n"
|
||||
"When done, briefly describe what you changed and confirm tests pass."
|
||||
)
|
||||
|
||||
prompt_mode = (self.config.prompt_mode or "problem_statement").strip().lower()
|
||||
if prompt_mode not in {"problem_statement", "problem_statement+text"}:
|
||||
raise ValueError(
|
||||
f"Invalid prompt_mode={self.config.prompt_mode!r}. "
|
||||
"Expected 'problem_statement' or 'problem_statement+text'."
|
||||
)
|
||||
|
||||
context_block = ""
|
||||
if prompt_mode == "problem_statement+text" and context:
|
||||
# Note: We intentionally do NOT truncate/cap here. This mode is for debugging / richer prompts and can be slow.
|
||||
context_block = f"\nAdditional context:\n{context}\n"
|
||||
|
||||
return (
|
||||
"You are a senior software engineer. Fix the repository so the specified tests pass.\n\n"
|
||||
f"Repository: {repo} (checked out at base_commit={base_commit})\n"
|
||||
f"Workspace path: ./{repo_dir}\n\n"
|
||||
"Constraints:\n"
|
||||
"- You MUST use the terminal tool to inspect, edit, and verify the repository. Do not respond with a patch file.\n"
|
||||
f"- Start by inspecting the repo (e.g. `ls`, `cd ./{repo_dir}`, `git status`).\n"
|
||||
"- Use a workspace-local virtualenv (e.g. inside the repo at ./.venv) to avoid cross-run contamination.\n"
|
||||
"- Use non-interactive commands only.\n\n"
|
||||
"- Terminal commands run under POSIX /bin/sh and each tool call runs in a fresh shell (no persisted env vars).\n"
|
||||
" Avoid bash-only `source`; prefer `. .venv/bin/activate` or `.venv/bin/python ...`.\n\n"
|
||||
"Problem statement:\n"
|
||||
f"{problem}\n\n"
|
||||
f"{context_block}\n"
|
||||
f"{tests_block}"
|
||||
)
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
# SWE tasks are longer than the simple test env.
|
||||
return AgentConfig(
|
||||
max_steps=self.config.agent_max_steps,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.agent_max_tokens,
|
||||
tool_delay_s=self.config.agent_tool_delay_s,
|
||||
)
|
||||
|
||||
async def setup_trajectory_workspace(self, item: Item, *, trajectory_id: str, exec_tool) -> Dict[str, Any]:
|
||||
t0 = time.perf_counter()
|
||||
repo = item.get("repo")
|
||||
base_commit = item.get("base_commit")
|
||||
instance_id = item.get("instance_id") or item.get("id") or item.get("problem_id")
|
||||
if not isinstance(repo, str) or not isinstance(base_commit, str):
|
||||
raise RuntimeError("Invalid dataset row: missing repo/base_commit")
|
||||
|
||||
repo_dir = self._repo_name(item)
|
||||
clone_url = f"{self.config.repo_base_url.rstrip('/')}/{repo}.git"
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): "
|
||||
f"repo={repo} base_commit={base_commit} instance_id={instance_id} dir=./{repo_dir}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Repo setup strategy:
|
||||
# - Maintain a shared, per-container bare repo cache under /data/repo_cache
|
||||
# - For each trajectory, create an isolated git worktree under the slot workspace
|
||||
# This avoids cloning/fetching full repos per trajectory and is crucial for multiplexing.
|
||||
|
||||
def _repo_cache_slug(repo_name: str) -> str:
|
||||
return repo_name.replace("/", "__")
|
||||
|
||||
repo_slug = _repo_cache_slug(repo)
|
||||
cache_root = "/data/repo_cache"
|
||||
bare_repo = f"{cache_root}/{repo_slug}.git"
|
||||
lock_file = f"{cache_root}/.locks/{repo_slug}.lock"
|
||||
|
||||
# Use flock to serialize operations that mutate the shared bare repo (fetch/worktree).
|
||||
# util-linux (flock) is included in the sandbox image.
|
||||
worktree_cmd = (
|
||||
"set -e; "
|
||||
f"rm -rf {repo_dir}; "
|
||||
f"mkdir -p {cache_root}/.locks; "
|
||||
f": > {lock_file}; "
|
||||
f"flock -x {lock_file} sh -lc '"
|
||||
f"set -e; "
|
||||
"export GIT_TERMINAL_PROMPT=0; "
|
||||
"export GIT_LFS_SKIP_SMUDGE=1; "
|
||||
f"if [ ! -d \"{bare_repo}\" ]; then "
|
||||
f" git init --bare \"{bare_repo}\"; "
|
||||
f" git -C \"{bare_repo}\" remote add origin \"{clone_url}\"; "
|
||||
"fi; "
|
||||
f"git -C \"{bare_repo}\" remote set-url origin \"{clone_url}\"; "
|
||||
f"git -C \"{bare_repo}\" worktree prune || true; "
|
||||
f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
|
||||
f" git -C \"{bare_repo}\" fetch --depth 1 origin \"{base_commit}\" || true; "
|
||||
"fi; "
|
||||
f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
|
||||
f" git -C \"{bare_repo}\" fetch --prune origin; "
|
||||
"fi; "
|
||||
f"git --git-dir=\"{bare_repo}\" worktree add --detach \"{repo_dir}\" \"{base_commit}\"; "
|
||||
"'"
|
||||
)
|
||||
|
||||
print(f"[SweSmithOracleEnv] tid={trajectory_id} preparing worktree from repo cache", flush=True)
|
||||
res = await exec_tool(
|
||||
ToolCall(
|
||||
name="terminal",
|
||||
arguments={"command": worktree_cmd, "timeout": self.config.install_timeout_s},
|
||||
)
|
||||
)
|
||||
if not res.success:
|
||||
raise RuntimeError(
|
||||
"git worktree setup failed "
|
||||
f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): {res.error}\n{res.output}"
|
||||
)
|
||||
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): worktree ready in {time.perf_counter() - t0:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
return {"repo_dir": repo_dir, "base_commit": base_commit}
|
||||
|
||||
def _tests_for_item(self, item: Item) -> List[str]:
|
||||
tests: List[str] = []
|
||||
if self.config.score_include_fail_to_pass:
|
||||
for key in ("PASS_TO_PASS", "FAIL_TO_PASS"):
|
||||
nodeids = item.get(key)
|
||||
if isinstance(nodeids, list):
|
||||
tests.extend([n for n in nodeids if isinstance(n, str)])
|
||||
else:
|
||||
nodeids = item.get("PASS_TO_PASS")
|
||||
if isinstance(nodeids, list):
|
||||
tests.extend([n for n in nodeids if isinstance(n, str)])
|
||||
# Stable order for reproducibility.
|
||||
return sorted(dict.fromkeys(tests))
|
||||
|
||||
def _chunk_nodeids(self, nodeids: List[str], max_per_chunk: int = 50) -> List[List[str]]:
|
||||
chunks: List[List[str]] = []
|
||||
for i in range(0, len(nodeids), max_per_chunk):
|
||||
chunks.append(nodeids[i : i + max_per_chunk])
|
||||
return chunks
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
final_response: str, # noqa: ARG002
|
||||
*,
|
||||
trajectory_id: str,
|
||||
exec_tool,
|
||||
agent_result=None,
|
||||
workspace_meta: Optional[Dict[str, Any]] = None,
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
_ = trajectory_id
|
||||
repo_dir = self._repo_name(item)
|
||||
|
||||
# Training correctness: do not reward trajectories that never actually used tools.
|
||||
if agent_result is not None and getattr(agent_result, "total_tool_calls", 0) <= 0:
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} verify (dataset_tests): no tool calls; score=0.0",
|
||||
flush=True,
|
||||
)
|
||||
return 0.0, {
|
||||
"verification_mode": "dataset_tests",
|
||||
"error": "No tool calls were made by the agent",
|
||||
}
|
||||
|
||||
nodeids = self._tests_for_item(item)
|
||||
if not nodeids:
|
||||
return 0.0, {"error": "No tests provided"}
|
||||
|
||||
print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (dataset_tests): ensuring venv + deps", flush=True)
|
||||
setup_cmd = (
|
||||
f"cd {repo_dir} && "
|
||||
"python -m venv .venv && "
|
||||
". .venv/bin/activate && "
|
||||
"python -m pip install -U pip setuptools wheel && "
|
||||
"python -m pip install -e . && "
|
||||
"python -m pip install pytest"
|
||||
)
|
||||
setup_res = await exec_tool(
|
||||
ToolCall(name="terminal", arguments={"command": setup_cmd, "timeout": self.config.install_timeout_s})
|
||||
)
|
||||
verification_messages = [{"role": "user", "content": setup_res.to_xml()}]
|
||||
if not setup_res.success:
|
||||
return 0.0, {
|
||||
"verification_mode": "dataset_tests",
|
||||
"phase": "install",
|
||||
"error": setup_res.error,
|
||||
"output": setup_res.output,
|
||||
"verification_messages": verification_messages,
|
||||
}
|
||||
|
||||
chunks = self._chunk_nodeids(nodeids, max_per_chunk=50)
|
||||
for chunk_idx, chunk in enumerate(chunks):
|
||||
joined = " ".join(chunk)
|
||||
cmd = f"cd {repo_dir} && . .venv/bin/activate && python -m pytest -q {joined}"
|
||||
res = await exec_tool(
|
||||
ToolCall(
|
||||
name="terminal",
|
||||
arguments={"command": cmd, "timeout": self.config.test_timeout_s},
|
||||
)
|
||||
)
|
||||
verification_messages.append({"role": "user", "content": res.to_xml()})
|
||||
if not res.success:
|
||||
return 0.0, {
|
||||
"verification_mode": "dataset_tests",
|
||||
"phase": "pytest",
|
||||
"failed_chunk": chunk_idx,
|
||||
"error": res.error,
|
||||
"output": res.output,
|
||||
"verification_messages": verification_messages,
|
||||
}
|
||||
|
||||
return 1.0, {"verification_mode": "dataset_tests", "passed": True, "verification_messages": verification_messages}
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
# Not used; scoring happens in verify_and_score_trajectory.
|
||||
_ = (item, final_response)
|
||||
return 0.0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
SweSmithOracleEnv.cli()
|
||||
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Simple test environment for validating the atropos-agent setup.
|
||||
|
||||
This environment uses a local OpenAI-compatible server for LLM testing to verify:
|
||||
- BaseEnv extension works correctly
|
||||
- API communication via OpenAI-compatible endpoint
|
||||
- Basic trajectory collection
|
||||
|
||||
This is a minimal environment for testing, not production use.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
Item,
|
||||
)
|
||||
|
||||
from ..agent import AgentConfig
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# Simple test prompts for validation
|
||||
TEST_PROMPTS = [
|
||||
{
|
||||
"prompt": "What is 2 + 2? Answer with just the number.",
|
||||
"expected": "4",
|
||||
},
|
||||
{
|
||||
"prompt": "What is the capital of France? Answer with just the city name.",
|
||||
"expected": "Paris",
|
||||
},
|
||||
{
|
||||
"prompt": "What color is the sky on a clear day? Answer with just the color.",
|
||||
"expected": "Blue",
|
||||
},
|
||||
{
|
||||
"prompt": "How many days are in a week? Answer with just the number.",
|
||||
"expected": "7",
|
||||
},
|
||||
{
|
||||
"prompt": "What is 10 * 5? Answer with just the number.",
|
||||
"expected": "50",
|
||||
},
|
||||
]
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"You are a helpful assistant. Answer questions concisely and directly. "
|
||||
"When asked for a simple answer, provide just that answer without explanation."
|
||||
)
|
||||
|
||||
|
||||
class SimpleTestEnvConfig(AgentEnvConfig):
|
||||
"""Configuration for the simple test environment."""
|
||||
|
||||
server_base_url: str = Field(
|
||||
default="http://127.0.0.1:8080",
|
||||
description="Base URL for an OpenAI-compatible server (without /v1)",
|
||||
)
|
||||
server_model: str = Field(
|
||||
default="hermes-4-36b",
|
||||
description="Model name",
|
||||
)
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||
|
||||
|
||||
class SimpleTestEnv(AgentEnv[SimpleTestEnvConfig]):
|
||||
"""
|
||||
A simple test environment to validate the atropos-agent setup.
|
||||
|
||||
Uses a local OpenAI-compatible LLM endpoint with basic question-answering tasks.
|
||||
Scoring is based on whether the response contains the expected answer.
|
||||
"""
|
||||
|
||||
name = "simple_test_env"
|
||||
env_config_cls = SimpleTestEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SimpleTestEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.iter = 0
|
||||
self.test_prompts = TEST_PROMPTS
|
||||
self.percent_correct_buffer: List[float] = []
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[SimpleTestEnvConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Initialize configuration with local server settings from environment variables.
|
||||
"""
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||
|
||||
env_config = SimpleTestEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=4,
|
||||
use_wandb=False, # Disable wandb for simple testing
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=10,
|
||||
batch_size=16,
|
||||
steps_per_eval=5,
|
||||
max_token_length=2048,
|
||||
inference_weight=1.0,
|
||||
wandb_name="simple_test",
|
||||
server_base_url=base_url,
|
||||
server_model=model,
|
||||
)
|
||||
|
||||
# OpenAI-compatible servers typically expose chat completions at /v1.
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=f"{base_url}/v1",
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=4,
|
||||
num_requests_for_eval=8,
|
||||
timeout=120, # Local models may be slower
|
||||
),
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup_agent_env(self):
|
||||
"""Setup the environment - load test data."""
|
||||
print(f"SimpleTestEnv setup complete. {len(self.test_prompts)} test prompts loaded.")
|
||||
print(f"Using server at: {self.config.server_base_url}")
|
||||
print(f"Model: {self.config.server_model}")
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
"""Get the next test prompt."""
|
||||
item = self.test_prompts[self.iter % len(self.test_prompts)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def build_task(self, item: Item) -> str:
|
||||
return item["prompt"]
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
return AgentConfig(
|
||||
max_steps=5,
|
||||
temperature=0.7,
|
||||
max_tokens=256,
|
||||
system_prompt=SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
expected = item["expected"].lower()
|
||||
response_lower = (final_response or "").lower()
|
||||
score = 1.0 if expected in response_lower else 0.0
|
||||
self.percent_correct_buffer.append(score)
|
||||
return score
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""
|
||||
Simple evaluation - run through all test prompts once.
|
||||
"""
|
||||
correct = 0
|
||||
total = len(self.test_prompts)
|
||||
|
||||
for item in self.test_prompts:
|
||||
messages = [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": item["prompt"]},
|
||||
]
|
||||
|
||||
response = await self.server.chat_completion(
|
||||
messages=messages,
|
||||
n=1,
|
||||
max_tokens=256,
|
||||
temperature=0.0, # Greedy for eval
|
||||
split="eval",
|
||||
)
|
||||
|
||||
response_text = response.choices[0].message.content or ""
|
||||
expected = item["expected"].lower()
|
||||
|
||||
if expected in response_text.lower():
|
||||
correct += 1
|
||||
|
||||
accuracy = correct / total
|
||||
print(f"Evaluation: {correct}/{total} = {accuracy:.2%} accuracy")
|
||||
return {"eval_accuracy": accuracy}
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log metrics (simplified for testing)."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self.percent_correct_buffer:
|
||||
avg_correct = sum(self.percent_correct_buffer) / len(self.percent_correct_buffer)
|
||||
wandb_metrics["train/percent_correct"] = avg_correct
|
||||
print(f"Train accuracy: {avg_correct:.2%}")
|
||||
self.percent_correct_buffer = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Allow running as CLI
|
||||
SimpleTestEnv.cli()
|
||||
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
ToolServer routing smoke environment.
|
||||
|
||||
Validates that:
|
||||
- sandbox tools run through Nomad SlotPool (terminal -> bash in sandbox)
|
||||
- external tools run through ToolServer (skills_list)
|
||||
|
||||
This env uses ToolServer in-process by default (`tool_server_url="inprocess"`),
|
||||
so it is self-contained for local testing.
|
||||
|
||||
Run:
|
||||
uv run python -m atropos.envs.toolserver_smoke_env process --env.use_wandb false --env.total_steps 1 --env.group_size 1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, Item
|
||||
|
||||
from ..agent import AgentConfig, AgentResult
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class ToolServerSmokeEnvConfig(AgentEnvConfig):
|
||||
server_base_url: str = Field(
|
||||
default="http://127.0.0.1:8080",
|
||||
description="Base URL for an OpenAI-compatible chat server (without /v1).",
|
||||
)
|
||||
server_model: str = Field(default="hermes-4-36b", description="Model name")
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||
|
||||
|
||||
class ToolServerSmokeEnv(AgentEnv[ToolServerSmokeEnvConfig]):
|
||||
name = "toolserver_smoke_env"
|
||||
env_config_cls = ToolServerSmokeEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ToolServerSmokeEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._iter = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[ToolServerSmokeEnvConfig, List[APIServerConfig]]:
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||
|
||||
env_config = ToolServerSmokeEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
include_messages=True,
|
||||
ensure_scores_are_not_same=False,
|
||||
total_steps=1,
|
||||
batch_size=1,
|
||||
server_base_url=base_url,
|
||||
server_model=model,
|
||||
enabled_toolsets=["terminal", "skills"],
|
||||
disabled_toolsets=[],
|
||||
# Self-contained ToolServer for local smoke.
|
||||
tool_server_url="inprocess",
|
||||
sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
|
||||
purge_job_on_start=True,
|
||||
purge_job_on_shutdown=True,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=f"{base_url.rstrip('/')}/v1",
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=1,
|
||||
num_requests_for_eval=1,
|
||||
timeout=120,
|
||||
)
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup_agent_env(self) -> None:
|
||||
return None
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
self._iter += 1
|
||||
return {
|
||||
"prompt": (
|
||||
"You MUST call exactly one tool per assistant message.\n"
|
||||
"\n"
|
||||
"Step 1) Call the skills_list tool (no arguments), then stop.\n"
|
||||
"Step 2) After you receive the tool response, call the terminal tool to run:\n"
|
||||
"python -c \"print('ok')\"\n"
|
||||
"Step 3) After you receive the terminal tool response, answer with just: ok\n"
|
||||
"\n"
|
||||
"Tool call format requirements:\n"
|
||||
"- Every tool call MUST be a complete XML block with a closing tag.\n"
|
||||
"- Do NOT emit a second <tool_call> in the same assistant message.\n"
|
||||
"\n"
|
||||
"Example:\n"
|
||||
"<tool_call>{\"name\": \"skills_list\", \"arguments\": {}}</tool_call>\n"
|
||||
"Do not include anything else in your final answer."
|
||||
)
|
||||
}
|
||||
|
||||
def build_task(self, item: Item) -> str:
|
||||
return str(item.get("prompt") or "")
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
return AgentConfig(
|
||||
max_steps=min(10, int(self.config.agent_max_steps)),
|
||||
temperature=0.2,
|
||||
max_tokens=None,
|
||||
)
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
_ = (item, final_response)
|
||||
return 0.0
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
final_response: str,
|
||||
*,
|
||||
trajectory_id: str, # noqa: ARG002
|
||||
exec_tool, # noqa: ARG002
|
||||
agent_result: AgentResult | None = None,
|
||||
workspace_meta: Dict[str, Any] | None = None, # noqa: ARG002
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
if agent_result is None:
|
||||
return 0.0, {"error": "Missing agent_result"}
|
||||
|
||||
called = {c.name for s in agent_result.steps for c in s.tool_calls}
|
||||
need = {"skills_list", "terminal"}
|
||||
if not need.issubset(called):
|
||||
return 0.0, {"error": f"Missing tool calls: {sorted(need - called)}", "called": sorted(called)}
|
||||
|
||||
terminal_ok = False
|
||||
for step in agent_result.steps:
|
||||
for call, res in zip(step.tool_calls, step.tool_results):
|
||||
if call.name != "terminal":
|
||||
continue
|
||||
if res.success and (res.output or "").strip().splitlines()[-1].strip() == "ok":
|
||||
terminal_ok = True
|
||||
|
||||
score = 1.0 if terminal_ok and (final_response or "").strip() == "ok" else 0.0
|
||||
return score, {"called": sorted(called), "final": (final_response or "").strip()}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ToolServerSmokeEnv.cli()
|
||||
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Nomad integration for atropos-agent.
|
||||
|
||||
Provides:
|
||||
- NomadClient: Client for Nomad HTTP API
|
||||
- Job templates for sandbox containers
|
||||
"""
|
||||
|
||||
from .client import NomadClient
|
||||
|
||||
__all__ = ["NomadClient"]
|
||||
@@ -0,0 +1,500 @@
|
||||
"""
|
||||
Nomad API Client for atropos-agent.
|
||||
|
||||
Provides a simple async client for interacting with the Nomad HTTP API:
|
||||
- Submit/stop jobs
|
||||
- Query allocations
|
||||
- Get allocation addresses
|
||||
- Scale jobs up/down
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
||||
class AllocationStatus(Enum):
|
||||
"""Nomad allocation status."""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETE = "complete"
|
||||
FAILED = "failed"
|
||||
LOST = "lost"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Allocation:
|
||||
"""Information about a Nomad allocation."""
|
||||
id: str
|
||||
job_id: str
|
||||
task_group: str
|
||||
node_id: str
|
||||
status: AllocationStatus
|
||||
# Network info for reaching the allocation
|
||||
address: Optional[str] = None
|
||||
port: Optional[int] = None
|
||||
|
||||
@property
|
||||
def http_address(self) -> Optional[str]:
|
||||
"""Get full HTTP address for the allocation."""
|
||||
if self.address and self.port:
|
||||
return f"http://{self.address}:{self.port}"
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class JobStatus:
|
||||
"""Status of a Nomad job."""
|
||||
id: str
|
||||
name: str
|
||||
status: str
|
||||
allocations: List[Allocation] = field(default_factory=list)
|
||||
count: int = 0 # Number of task groups
|
||||
|
||||
|
||||
class NomadClient:
|
||||
"""
|
||||
Async client for Nomad HTTP API.
|
||||
|
||||
Usage:
|
||||
client = NomadClient(address="http://localhost:4646")
|
||||
|
||||
# Submit a job
|
||||
await client.submit_job(job_spec)
|
||||
|
||||
# Get allocations
|
||||
allocs = await client.get_job_allocations("sandbox-python")
|
||||
|
||||
# Scale job
|
||||
await client.scale_job("sandbox-python", count=5)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
address: str = "http://localhost:4646",
|
||||
token: Optional[str] = None,
|
||||
timeout: float = 30.0,
|
||||
):
|
||||
self.address = address.rstrip("/")
|
||||
self.token = token or os.environ.get("NOMAD_TOKEN")
|
||||
self.timeout = aiohttp.ClientTimeout(total=timeout)
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create HTTP session."""
|
||||
if self._session is None or self._session.closed:
|
||||
headers = {}
|
||||
if self.token:
|
||||
headers["X-Nomad-Token"] = self.token
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=self.timeout,
|
||||
headers=headers,
|
||||
)
|
||||
return self._session
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP session."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Make an HTTP request to Nomad API."""
|
||||
session = await self._get_session()
|
||||
url = f"{self.address}{path}"
|
||||
|
||||
try:
|
||||
async with session.request(method, url, json=data) as response:
|
||||
if response.status == 404:
|
||||
return {"error": "not_found", "status": 404}
|
||||
|
||||
text = await response.text()
|
||||
if not text:
|
||||
return {"status": response.status}
|
||||
|
||||
try:
|
||||
result = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return {"text": text, "status": response.status}
|
||||
|
||||
if response.status >= 400:
|
||||
return {"error": result, "status": response.status}
|
||||
|
||||
return result if isinstance(result, dict) else {"data": result, "status": response.status}
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
return {"error": str(e), "status": 0}
|
||||
|
||||
# Job Operations
|
||||
|
||||
async def submit_job(self, job_spec: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Submit a job to Nomad.
|
||||
|
||||
Args:
|
||||
job_spec: Job specification dict (HCL converted to JSON)
|
||||
|
||||
Returns:
|
||||
Response with EvalID if successful
|
||||
"""
|
||||
return await self._request("POST", "/v1/jobs", {"Job": job_spec})
|
||||
|
||||
async def stop_job(self, job_id: str, purge: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Stop (and optionally purge) a job.
|
||||
|
||||
Args:
|
||||
job_id: Job identifier
|
||||
purge: If True, completely remove the job
|
||||
"""
|
||||
path = f"/v1/job/{job_id}"
|
||||
if purge:
|
||||
path += "?purge=true"
|
||||
return await self._request("DELETE", path)
|
||||
|
||||
async def get_job(self, job_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get job details."""
|
||||
result = await self._request("GET", f"/v1/job/{job_id}")
|
||||
if "error" in result and result.get("status") == 404:
|
||||
return None
|
||||
return result
|
||||
|
||||
async def get_job_status(self, job_id: str) -> Optional[JobStatus]:
|
||||
"""Get job status with allocations."""
|
||||
job = await self.get_job(job_id)
|
||||
if not job:
|
||||
return None
|
||||
|
||||
allocs = await self.get_job_allocations(job_id)
|
||||
|
||||
# Get count from task groups
|
||||
count = 0
|
||||
task_groups = job.get("TaskGroups", [])
|
||||
for tg in task_groups:
|
||||
count += tg.get("Count", 1)
|
||||
|
||||
return JobStatus(
|
||||
id=job_id,
|
||||
name=job.get("Name", job_id),
|
||||
status=job.get("Status", "unknown"),
|
||||
allocations=allocs,
|
||||
count=count,
|
||||
)
|
||||
|
||||
# Allocation Operations
|
||||
|
||||
async def get_job_allocations(self, job_id: str) -> List[Allocation]:
|
||||
"""Get all allocations for a job."""
|
||||
result = await self._request("GET", f"/v1/job/{job_id}/allocations")
|
||||
|
||||
if "error" in result:
|
||||
return []
|
||||
|
||||
allocs_data = result.get("data", result) if isinstance(result, dict) else result
|
||||
if not isinstance(allocs_data, list):
|
||||
return []
|
||||
|
||||
allocations = []
|
||||
for alloc_data in allocs_data:
|
||||
# Parse allocation info
|
||||
alloc_id = alloc_data.get("ID", "")
|
||||
status_str = alloc_data.get("ClientStatus", "unknown")
|
||||
|
||||
try:
|
||||
status = AllocationStatus(status_str)
|
||||
except ValueError:
|
||||
status = AllocationStatus.PENDING
|
||||
|
||||
# Get network info - need to fetch detailed allocation for this
|
||||
address = None
|
||||
port = None
|
||||
|
||||
# First try the summary data
|
||||
resources = alloc_data.get("AllocatedResources") or {}
|
||||
shared = resources.get("Shared") or {}
|
||||
networks = shared.get("Networks") or []
|
||||
|
||||
# If no networks in summary, fetch detailed allocation
|
||||
if not networks and alloc_id:
|
||||
detailed = await self.get_allocation(alloc_id)
|
||||
if detailed:
|
||||
resources = detailed.get("AllocatedResources") or {}
|
||||
shared = resources.get("Shared") or {}
|
||||
networks = shared.get("Networks") or []
|
||||
|
||||
if networks:
|
||||
network = networks[0]
|
||||
address = network.get("IP")
|
||||
# Look for dynamic ports OR reserved ports (Singularity/raw_exec uses reserved)
|
||||
dyn_ports = network.get("DynamicPorts") or []
|
||||
reserved_ports = network.get("ReservedPorts") or []
|
||||
for dp in dyn_ports + reserved_ports:
|
||||
if dp.get("Label") == "http":
|
||||
port = dp.get("Value")
|
||||
break
|
||||
|
||||
allocations.append(Allocation(
|
||||
id=alloc_id,
|
||||
job_id=job_id,
|
||||
task_group=alloc_data.get("TaskGroup", ""),
|
||||
node_id=alloc_data.get("NodeID", ""),
|
||||
status=status,
|
||||
address=address,
|
||||
port=port,
|
||||
))
|
||||
|
||||
return allocations
|
||||
|
||||
async def get_allocation(self, alloc_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get detailed allocation info."""
|
||||
result = await self._request("GET", f"/v1/allocation/{alloc_id}")
|
||||
if "error" in result and result.get("status") == 404:
|
||||
return None
|
||||
return result
|
||||
|
||||
# Scaling Operations
|
||||
|
||||
async def scale_job(self, job_id: str, count: int, task_group: str = "sandbox") -> Dict[str, Any]:
|
||||
"""
|
||||
Scale a job's task group to specified count.
|
||||
|
||||
Args:
|
||||
job_id: Job identifier
|
||||
count: Desired number of allocations
|
||||
task_group: Name of task group to scale
|
||||
"""
|
||||
payload = {
|
||||
"Count": count,
|
||||
"Target": {
|
||||
"Group": task_group,
|
||||
},
|
||||
}
|
||||
return await self._request("POST", f"/v1/job/{job_id}/scale", payload)
|
||||
|
||||
async def get_job_scale_status(self, job_id: str) -> Dict[str, int]:
|
||||
"""
|
||||
Get current scale status for a job.
|
||||
|
||||
Returns:
|
||||
Dict mapping task group name to count
|
||||
"""
|
||||
result = await self._request("GET", f"/v1/job/{job_id}/scale")
|
||||
|
||||
if "error" in result:
|
||||
return {}
|
||||
|
||||
task_groups = result.get("TaskGroups", {})
|
||||
return {
|
||||
name: info.get("Running", 0)
|
||||
for name, info in task_groups.items()
|
||||
}
|
||||
|
||||
# Health Check
|
||||
|
||||
async def is_healthy(self) -> bool:
|
||||
"""Check if Nomad is reachable and healthy."""
|
||||
try:
|
||||
result = await self._request("GET", "/v1/status/leader")
|
||||
return "error" not in result
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_leader(self) -> Optional[str]:
|
||||
"""Get current Nomad leader address."""
|
||||
result = await self._request("GET", "/v1/status/leader")
|
||||
if isinstance(result, dict) and "data" in result:
|
||||
return result["data"]
|
||||
return None
|
||||
|
||||
|
||||
def load_job_template(
|
||||
template_name: str = "sandbox",
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Load and configure a job template.
|
||||
|
||||
Args:
|
||||
template_name: Name of template (e.g., "sandbox")
|
||||
**kwargs: Template variables to substitute
|
||||
|
||||
Returns:
|
||||
Job specification dict ready for Nomad API
|
||||
"""
|
||||
# Default job template for sandbox container
|
||||
if template_name == "sandbox":
|
||||
return create_sandbox_job(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown template: {template_name}")
|
||||
|
||||
|
||||
def create_sandbox_job(
|
||||
job_id: str = "atropos-sandbox",
|
||||
image: str = "atropos-sandbox:local", # Use :local tag to avoid registry pull
|
||||
count: int = 1,
|
||||
slots_per_container: int = 10,
|
||||
privileged: bool = False,
|
||||
cpu: int = 500,
|
||||
memory: int = 512,
|
||||
port: int = 8080,
|
||||
datacenter: str = "dc1",
|
||||
driver: str = "docker", # "docker" or "singularity"
|
||||
singularity_image: str = None, # Path to .sif file for singularity driver
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a sandbox job specification.
|
||||
|
||||
This job runs the sandbox_server.py inside a container,
|
||||
with the specified number of slots for agent workspaces.
|
||||
|
||||
Args:
|
||||
job_id: Unique job identifier
|
||||
image: Docker image to use (for docker driver)
|
||||
count: Number of container instances
|
||||
slots_per_container: Number of slots per container
|
||||
privileged: Run container in privileged mode (recommended for bubblewrap)
|
||||
cpu: CPU allocation in MHz
|
||||
memory: Memory allocation in MB
|
||||
port: HTTP port for sandbox server
|
||||
datacenter: Nomad datacenter
|
||||
driver: Container driver - "docker" or "singularity"
|
||||
singularity_image: Path to .sif file (required if driver="singularity")
|
||||
|
||||
Returns:
|
||||
Job specification dict
|
||||
"""
|
||||
# Build task config based on driver
|
||||
if driver == "singularity":
|
||||
if not singularity_image:
|
||||
raise ValueError("singularity_image path required when driver='singularity'")
|
||||
|
||||
# Use raw_exec driver to run apptainer via shell for variable expansion
|
||||
# The container binds the allocation directory for workspace persistence
|
||||
# For raw_exec, we use static port since Nomad's dynamic port mapping doesn't
|
||||
# work the same as Docker - the process runs directly on the host.
|
||||
shell_cmd = (
|
||||
f'apptainer run '
|
||||
f'--bind "$NOMAD_ALLOC_DIR/data:/data" '
|
||||
f'--pwd /app '
|
||||
f'--env PYTHONUNBUFFERED=1 '
|
||||
f'{singularity_image} '
|
||||
f'python sandbox_server.py '
|
||||
f'--port {port} '
|
||||
f'--slots {slots_per_container} '
|
||||
f'--data-dir /data'
|
||||
)
|
||||
task_config = {
|
||||
"command": "/bin/sh",
|
||||
"args": ["-c", shell_cmd],
|
||||
}
|
||||
task_driver = "raw_exec"
|
||||
else:
|
||||
# Docker driver (default)
|
||||
task_config = {
|
||||
"image": image,
|
||||
"force_pull": False, # Use local image, don't try to pull
|
||||
"ports": ["http"],
|
||||
"privileged": privileged,
|
||||
"command": "python",
|
||||
"args": [
|
||||
"sandbox_server.py",
|
||||
"--port", str(port),
|
||||
"--slots", str(slots_per_container),
|
||||
"--data-dir", "/data",
|
||||
],
|
||||
# Note: On Linux, you can mount persistent storage:
|
||||
# "volumes": ["${NOMAD_ALLOC_DIR}/data:/data"],
|
||||
# On macOS/Docker Desktop, skip volumes for PoC
|
||||
# (container /data is ephemeral but works for testing)
|
||||
}
|
||||
task_driver = "docker"
|
||||
|
||||
# For Singularity/raw_exec, use static ports since the process runs directly on host.
|
||||
# For Docker, use dynamic ports with port mapping.
|
||||
if driver == "singularity":
|
||||
network_config = {
|
||||
"Mode": "host",
|
||||
"ReservedPorts": [
|
||||
{
|
||||
"Label": "http",
|
||||
"Value": port,
|
||||
}
|
||||
],
|
||||
}
|
||||
else:
|
||||
network_config = {
|
||||
"Mode": "host",
|
||||
"DynamicPorts": [
|
||||
{
|
||||
"Label": "http",
|
||||
"To": port,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
return {
|
||||
"ID": job_id,
|
||||
"Name": job_id,
|
||||
"Type": "service",
|
||||
"Datacenters": [datacenter],
|
||||
"TaskGroups": [
|
||||
{
|
||||
"Name": "sandbox",
|
||||
"Count": count,
|
||||
# Speed up deployments and avoid Consul checks. Without this, Nomad may
|
||||
# keep an "active deployment" around for the default MinHealthyTime,
|
||||
# which blocks immediate scaling under load.
|
||||
"Update": {
|
||||
"HealthCheck": "task_states",
|
||||
"MinHealthyTime": 0,
|
||||
},
|
||||
"Networks": [network_config],
|
||||
"Tasks": [
|
||||
{
|
||||
"Name": "sandbox-server",
|
||||
"Driver": task_driver,
|
||||
"Config": task_config,
|
||||
"Env": {
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"NOMAD_ALLOC_DIR": "${NOMAD_ALLOC_DIR}",
|
||||
},
|
||||
"Resources": {
|
||||
"CPU": cpu,
|
||||
"MemoryMB": memory,
|
||||
},
|
||||
# Note: Services with Checks require Consul, which we skip for the PoC
|
||||
}
|
||||
],
|
||||
"RestartPolicy": {
|
||||
"Attempts": 3,
|
||||
"Interval": 300_000_000_000, # 5 minutes
|
||||
"Delay": 10_000_000_000, # 10 seconds
|
||||
"Mode": "delay",
|
||||
},
|
||||
"ReschedulePolicy": {
|
||||
"Attempts": 5,
|
||||
"Interval": 3600_000_000_000, # 1 hour
|
||||
"Delay": 30_000_000_000, # 30 seconds
|
||||
"DelayFunction": "exponential",
|
||||
"MaxDelay": 300_000_000_000, # 5 minutes
|
||||
"Unlimited": False,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
Slot-based multiplexing for atropos-agent.
|
||||
|
||||
Provides:
|
||||
- Slot: Isolated workspace for a single trajectory
|
||||
- SlotPool: Manages slots across Nomad allocations
|
||||
- SandboxExecutor: Executes tools in sandbox containers
|
||||
"""
|
||||
|
||||
from .executor import SandboxExecutor
|
||||
from .pool import SlotPool, SlotPoolConfig
|
||||
from .slot import Slot, SlotState
|
||||
|
||||
__all__ = [
|
||||
"Slot",
|
||||
"SlotState",
|
||||
"SlotPool",
|
||||
"SlotPoolConfig",
|
||||
"SandboxExecutor",
|
||||
]
|
||||
@@ -0,0 +1,457 @@
|
||||
"""
|
||||
SandboxExecutor - HTTP client for sandbox container communication.
|
||||
|
||||
Sends tool execution requests to sandbox_server.py running inside Nomad containers.
|
||||
Supports single and batch execution for efficiency.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .slot import Slot, SlotState
|
||||
from ..tools.base import ToolCall, ToolResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionRequest:
|
||||
"""Request to execute a tool in a slot."""
|
||||
slot: Slot
|
||||
tool_name: str
|
||||
args: Dict[str, Any]
|
||||
execution_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
timeout: float = 30.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionResult:
|
||||
"""Result from sandbox execution."""
|
||||
success: bool
|
||||
output: str = ""
|
||||
error: str = ""
|
||||
execution_id: str = ""
|
||||
slot_id: str = ""
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_tool_result(self) -> ToolResult:
|
||||
"""Convert to ToolResult for agent consumption."""
|
||||
return ToolResult(
|
||||
success=self.success,
|
||||
output=self.output,
|
||||
error=self.error,
|
||||
metadata=self.metadata,
|
||||
uniq_id=self.execution_id,
|
||||
)
|
||||
|
||||
|
||||
class SandboxExecutor:
|
||||
"""
|
||||
HTTP client for executing tools in sandbox containers.
|
||||
|
||||
Communicates with sandbox_server.py running inside Nomad allocations.
|
||||
Supports both single execution and batched parallel execution.
|
||||
|
||||
Usage:
|
||||
executor = SandboxExecutor()
|
||||
|
||||
# Single execution
|
||||
result = await executor.execute(slot, "bash", {"command": "ls"})
|
||||
|
||||
# Batch execution
|
||||
results = await executor.execute_batch([
|
||||
(slot1, "bash", {"command": "ls"}),
|
||||
(slot2, "write_file", {"path": "test.txt", "content": "hello"}),
|
||||
])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timeout: float = 30.0,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
):
|
||||
self.timeout = aiohttp.ClientTimeout(total=timeout)
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create HTTP session."""
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession(timeout=self.timeout)
|
||||
return self._session
|
||||
|
||||
async def close(self):
|
||||
"""Close HTTP session."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
slot: Slot,
|
||||
tool_name: str,
|
||||
args: Dict[str, Any],
|
||||
timeout: Optional[float] = None,
|
||||
) -> ExecutionResult:
|
||||
"""
|
||||
Execute a tool in a slot's workspace.
|
||||
|
||||
Args:
|
||||
slot: Slot to execute in
|
||||
tool_name: Name of tool (bash, read_file, write_file)
|
||||
args: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
ExecutionResult with output or error
|
||||
"""
|
||||
execution_id = str(uuid.uuid4())
|
||||
exec_timeout = timeout or self.timeout.total or 30.0
|
||||
|
||||
# Mark slot as executing
|
||||
original_state = slot.state
|
||||
try:
|
||||
if slot.state == SlotState.ACQUIRED:
|
||||
slot.start_execution(execution_id)
|
||||
|
||||
result = await self._send_execute_request(
|
||||
container_addr=slot.container_addr,
|
||||
slot_id=slot.slot_id,
|
||||
tool_name=tool_name,
|
||||
args=args,
|
||||
execution_id=execution_id,
|
||||
timeout=exec_timeout,
|
||||
)
|
||||
result.slot_id = slot.slot_id
|
||||
return result
|
||||
|
||||
finally:
|
||||
# Restore slot state
|
||||
if slot.state == SlotState.EXECUTING:
|
||||
slot.end_execution()
|
||||
|
||||
async def _send_execute_request(
|
||||
self,
|
||||
container_addr: str,
|
||||
slot_id: str,
|
||||
tool_name: str,
|
||||
args: Dict[str, Any],
|
||||
execution_id: str,
|
||||
timeout: float,
|
||||
) -> ExecutionResult:
|
||||
"""Send execution request to sandbox server with retry logic."""
|
||||
session = await self._get_session()
|
||||
url = f"{container_addr}/execute"
|
||||
|
||||
payload = {
|
||||
"slot_id": slot_id,
|
||||
"tool": tool_name,
|
||||
"args": args,
|
||||
"execution_id": execution_id,
|
||||
"timeout": timeout,
|
||||
}
|
||||
|
||||
last_error = None
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
async with session.post(url, json=payload) as response:
|
||||
data = await response.json()
|
||||
|
||||
return ExecutionResult(
|
||||
success=data.get("success", False),
|
||||
output=data.get("output", ""),
|
||||
error=data.get("error", ""),
|
||||
execution_id=data.get("execution_id", execution_id),
|
||||
metadata=data.get("metadata", {}),
|
||||
)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
last_error = str(e)
|
||||
if attempt < self.max_retries - 1:
|
||||
await asyncio.sleep(self.retry_delay * (attempt + 1))
|
||||
continue
|
||||
except asyncio.TimeoutError:
|
||||
last_error = f"Request timed out after {timeout}s"
|
||||
break
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
break
|
||||
|
||||
return ExecutionResult(
|
||||
success=False,
|
||||
error=f"Failed after {self.max_retries} attempts: {last_error}",
|
||||
execution_id=execution_id,
|
||||
)
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
requests: List[Tuple[Slot, str, Dict[str, Any]]],
|
||||
timeout: Optional[float] = None,
|
||||
) -> List[ExecutionResult]:
|
||||
"""
|
||||
Execute multiple tools in parallel across slots.
|
||||
|
||||
This is the key optimization - we batch tool calls to maximize
|
||||
container utilization while agents are waiting for LLM responses.
|
||||
|
||||
Args:
|
||||
requests: List of (slot, tool_name, args) tuples
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
List of ExecutionResults in same order as requests
|
||||
"""
|
||||
if not requests:
|
||||
return []
|
||||
|
||||
# Group requests by container address for batch API
|
||||
by_container: Dict[str, List[Tuple[int, Slot, str, Dict[str, Any], str]]] = {}
|
||||
|
||||
for idx, (slot, tool_name, args) in enumerate(requests):
|
||||
execution_id = str(uuid.uuid4())
|
||||
container = slot.container_addr
|
||||
|
||||
if container not in by_container:
|
||||
by_container[container] = []
|
||||
by_container[container].append((idx, slot, tool_name, args, execution_id))
|
||||
|
||||
# Mark slots as executing
|
||||
if slot.state == SlotState.ACQUIRED:
|
||||
slot.start_execution(execution_id)
|
||||
|
||||
# Execute batches in parallel
|
||||
exec_timeout = timeout or self.timeout.total or 30.0
|
||||
batch_tasks = []
|
||||
|
||||
for container_addr, batch_requests in by_container.items():
|
||||
task = self._send_batch_request(
|
||||
container_addr=container_addr,
|
||||
batch_requests=batch_requests,
|
||||
timeout=exec_timeout,
|
||||
)
|
||||
batch_tasks.append(task)
|
||||
|
||||
# Gather all batch results
|
||||
batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
|
||||
# Collect results in original order
|
||||
results: List[Optional[ExecutionResult]] = [None] * len(requests)
|
||||
|
||||
for batch_result in batch_results:
|
||||
if isinstance(batch_result, Exception):
|
||||
# Mark all in this batch as failed
|
||||
continue
|
||||
|
||||
for idx, result in batch_result:
|
||||
results[idx] = result
|
||||
|
||||
# Fill in any missing results
|
||||
for idx, result in enumerate(results):
|
||||
if result is None:
|
||||
slot, tool_name, args = requests[idx]
|
||||
results[idx] = ExecutionResult(
|
||||
success=False,
|
||||
error="Batch execution failed",
|
||||
slot_id=slot.slot_id,
|
||||
)
|
||||
|
||||
# End execution on all slots
|
||||
for slot, _, _ in requests:
|
||||
if slot.state == SlotState.EXECUTING:
|
||||
slot.end_execution()
|
||||
|
||||
return results # type: ignore
|
||||
|
||||
async def _send_batch_request(
|
||||
self,
|
||||
container_addr: str,
|
||||
batch_requests: List[Tuple[int, Slot, str, Dict[str, Any], str]],
|
||||
timeout: float,
|
||||
) -> List[Tuple[int, ExecutionResult]]:
|
||||
"""Send batch execution request to a single container."""
|
||||
session = await self._get_session()
|
||||
url = f"{container_addr}/batch"
|
||||
|
||||
# Build batch payload
|
||||
payload = [
|
||||
{
|
||||
"slot_id": slot.slot_id,
|
||||
"tool": tool_name,
|
||||
"args": args,
|
||||
"execution_id": execution_id,
|
||||
"timeout": timeout,
|
||||
}
|
||||
for _, slot, tool_name, args, execution_id in batch_requests
|
||||
]
|
||||
|
||||
try:
|
||||
async with session.post(url, json=payload) as response:
|
||||
data = await response.json()
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise ValueError(f"Expected list response, got {type(data)}")
|
||||
|
||||
results = []
|
||||
for i, (idx, slot, _, _, execution_id) in enumerate(batch_requests):
|
||||
if i < len(data):
|
||||
item = data[i]
|
||||
result = ExecutionResult(
|
||||
success=item.get("success", False),
|
||||
output=item.get("output", ""),
|
||||
error=item.get("error", ""),
|
||||
execution_id=item.get("execution_id", execution_id),
|
||||
slot_id=slot.slot_id,
|
||||
metadata=item.get("metadata", {}),
|
||||
)
|
||||
else:
|
||||
result = ExecutionResult(
|
||||
success=False,
|
||||
error="Missing result in batch response",
|
||||
execution_id=execution_id,
|
||||
slot_id=slot.slot_id,
|
||||
)
|
||||
results.append((idx, result))
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
# Return error for all requests in batch
|
||||
return [
|
||||
(idx, ExecutionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
execution_id=execution_id,
|
||||
slot_id=slot.slot_id,
|
||||
))
|
||||
for idx, slot, _, _, execution_id in batch_requests
|
||||
]
|
||||
|
||||
async def reset_slot(self, slot: Slot) -> ExecutionResult:
|
||||
"""
|
||||
Reset a slot's workspace (delete all files).
|
||||
|
||||
Useful when reusing a slot for a new trajectory.
|
||||
"""
|
||||
session = await self._get_session()
|
||||
url = f"{slot.container_addr}/reset"
|
||||
|
||||
try:
|
||||
async with session.post(url, json={"slot_id": slot.slot_id}) as response:
|
||||
data = await response.json()
|
||||
return ExecutionResult(
|
||||
success=data.get("success", False),
|
||||
output=data.get("output", ""),
|
||||
error=data.get("error", ""),
|
||||
slot_id=slot.slot_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ExecutionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
slot_id=slot.slot_id,
|
||||
)
|
||||
|
||||
async def health_check(self, container_addr: str) -> bool:
|
||||
"""Check if a sandbox container is healthy."""
|
||||
session = await self._get_session()
|
||||
url = f"{container_addr}/health"
|
||||
|
||||
try:
|
||||
async with session.get(url) as response:
|
||||
data = await response.json()
|
||||
return data.get("status") == "ok"
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_container_status(
|
||||
self,
|
||||
container_addr: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get status info from a sandbox container."""
|
||||
session = await self._get_session()
|
||||
url = f"{container_addr}/health"
|
||||
|
||||
try:
|
||||
async with session.get(url) as response:
|
||||
return await response.json()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Artifact helpers (optional)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def _post_json(
|
||||
self,
|
||||
url: str,
|
||||
payload: Dict[str, Any],
|
||||
timeout: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
session = await self._get_session()
|
||||
try:
|
||||
async with session.post(url, json=payload, timeout=timeout) as response:
|
||||
data = await response.json()
|
||||
if isinstance(data, dict):
|
||||
data.setdefault("http_status", response.status)
|
||||
return data
|
||||
return {"success": False, "error": f"Unexpected response type: {type(data)}", "http_status": response.status}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def read_artifact(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str,
|
||||
*,
|
||||
encoding: str = "text",
|
||||
max_bytes: Optional[int] = None,
|
||||
include_sha256: bool = False,
|
||||
timeout: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{slot.container_addr}/artifacts/read"
|
||||
payload: Dict[str, Any] = {"slot_id": slot.slot_id, "path": path, "encoding": encoding, "include_sha256": include_sha256}
|
||||
if max_bytes is not None:
|
||||
payload["max_bytes"] = max_bytes
|
||||
return await self._post_json(url, payload, timeout=timeout)
|
||||
|
||||
async def list_artifacts(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str = ".",
|
||||
*,
|
||||
recursive: bool = False,
|
||||
max_entries: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{slot.container_addr}/artifacts/list"
|
||||
payload: Dict[str, Any] = {"slot_id": slot.slot_id, "path": path, "recursive": recursive}
|
||||
if max_entries is not None:
|
||||
payload["max_entries"] = max_entries
|
||||
return await self._post_json(url, payload, timeout=timeout)
|
||||
|
||||
async def archive_artifacts(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str = ".",
|
||||
*,
|
||||
archive_format: str = "tar.gz",
|
||||
max_bytes: Optional[int] = None,
|
||||
max_entries: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{slot.container_addr}/artifacts/archive"
|
||||
payload: Dict[str, Any] = {"slot_id": slot.slot_id, "path": path, "format": archive_format}
|
||||
if max_bytes is not None:
|
||||
payload["max_bytes"] = max_bytes
|
||||
if max_entries is not None:
|
||||
payload["max_entries"] = max_entries
|
||||
return await self._post_json(url, payload, timeout=timeout)
|
||||
@@ -0,0 +1,659 @@
|
||||
"""
|
||||
SlotPool - Manages slots across Nomad allocations.
|
||||
|
||||
The SlotPool is the core abstraction for slot-based multiplexing:
|
||||
- Tracks available/acquired slots across containers
|
||||
- Handles slot acquisition and release
|
||||
- Auto-scales Nomad job count based on demand
|
||||
- Provides batched tool execution
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from ..nomad.client import (
|
||||
Allocation,
|
||||
AllocationStatus,
|
||||
NomadClient,
|
||||
create_sandbox_job,
|
||||
)
|
||||
from .executor import ExecutionResult, SandboxExecutor
|
||||
from .slot import Slot, SlotState, create_slots_for_allocation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlotPoolConfig:
|
||||
"""Configuration for SlotPool."""
|
||||
|
||||
# Nomad settings
|
||||
nomad_address: str = "http://localhost:4646"
|
||||
job_id: str = "atropos-sandbox"
|
||||
datacenter: str = "dc1"
|
||||
|
||||
# Container settings
|
||||
image: str = "atropos-sandbox:local" # Use :local tag to avoid registry pull
|
||||
slots_per_container: int = 10
|
||||
privileged: bool = False
|
||||
cpu: int = 500 # MHz
|
||||
memory: int = 512 # MB
|
||||
|
||||
# Driver selection: "docker" or "singularity"
|
||||
driver: str = "docker"
|
||||
# Path to .sif file for singularity driver (required if driver="singularity")
|
||||
singularity_image: Optional[str] = None
|
||||
|
||||
# Scaling settings
|
||||
min_containers: int = 1
|
||||
max_containers: int = 10
|
||||
|
||||
# Timeouts
|
||||
acquire_timeout: float = 30.0 # Seconds between acquire polls (also triggers scale-up attempts)
|
||||
health_check_interval: float = 30.0 # Seconds between health checks
|
||||
scale_cooldown: float = 60.0 # Seconds between scale operations
|
||||
|
||||
# Job lifecycle
|
||||
purge_job_on_start: bool = False # Purge any pre-existing job before starting (local dev/training friendly)
|
||||
|
||||
# Local Docker image convenience (macOS/Nomad dev mode)
|
||||
auto_build_local_image: bool = True # If image endswith :local and is missing, build it from the bundled Dockerfile.
|
||||
dockerfile_path: Optional[str] = None # Override Dockerfile path (default: Hermes-Agent/atropos/Dockerfile).
|
||||
docker_build_context: Optional[str] = None # Override build context (default: Hermes-Agent/atropos).
|
||||
|
||||
|
||||
class SlotPool:
|
||||
"""
|
||||
Manages a pool of slots across Nomad allocations.
|
||||
|
||||
The SlotPool:
|
||||
- Deploys sandbox containers to Nomad
|
||||
- Tracks slots across all running containers
|
||||
- Handles slot acquisition/release
|
||||
- Auto-scales based on demand
|
||||
- Provides batched execution via SandboxExecutor
|
||||
|
||||
Usage:
|
||||
config = SlotPoolConfig(
|
||||
nomad_address="http://localhost:4646",
|
||||
job_id="my-sandbox",
|
||||
slots_per_container=10,
|
||||
)
|
||||
|
||||
pool = SlotPool(config)
|
||||
await pool.start()
|
||||
|
||||
# Acquire a slot
|
||||
slot = await pool.acquire()
|
||||
|
||||
# Execute tool
|
||||
result = await pool.execute(slot, "bash", {"command": "ls"})
|
||||
|
||||
# Release slot
|
||||
await pool.release(slot)
|
||||
|
||||
# Shutdown
|
||||
await pool.stop()
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[SlotPoolConfig] = None):
|
||||
self.config = config or SlotPoolConfig()
|
||||
|
||||
# Nomad client
|
||||
self.nomad = NomadClient(address=self.config.nomad_address)
|
||||
|
||||
# Sandbox executor for tool execution
|
||||
self.executor = SandboxExecutor()
|
||||
|
||||
# Slot tracking
|
||||
self._slots: Dict[str, Slot] = {} # slot_key -> Slot
|
||||
self._available_queue: asyncio.Queue[str] = asyncio.Queue()
|
||||
self._lock = asyncio.Lock()
|
||||
self._scale_lock = asyncio.Lock()
|
||||
|
||||
# State
|
||||
self._started = False
|
||||
self._health_task: Optional[asyncio.Task] = None
|
||||
self._scale_task: Optional[asyncio.Task] = None
|
||||
self._last_scale_time = 0.0
|
||||
|
||||
def _default_dockerfile_path(self) -> Path:
|
||||
# Hermes-Agent/atropos/Dockerfile lives next to this module in source checkouts.
|
||||
return Path(__file__).resolve().parents[1] / "Dockerfile"
|
||||
|
||||
def _default_build_context(self) -> Path:
|
||||
return Path(__file__).resolve().parents[1]
|
||||
|
||||
def _docker_image_exists(self, image: str) -> bool:
|
||||
try:
|
||||
proc = subprocess.run(
|
||||
["docker", "image", "inspect", image],
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
check=False,
|
||||
env={**os.environ, "DOCKER_CLI_HINTS": "false"},
|
||||
)
|
||||
return proc.returncode == 0
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
|
||||
def _try_build_local_image(self, image: str) -> None:
|
||||
dockerfile = Path(self.config.dockerfile_path) if self.config.dockerfile_path else self._default_dockerfile_path()
|
||||
context = Path(self.config.docker_build_context) if self.config.docker_build_context else self._default_build_context()
|
||||
|
||||
if not dockerfile.exists():
|
||||
raise RuntimeError(
|
||||
f"Sandbox Dockerfile not found at {dockerfile}. "
|
||||
"Build the sandbox image manually or set --env.purge_job_on_start false and provide a non-local image."
|
||||
)
|
||||
if not context.exists():
|
||||
raise RuntimeError(f"Docker build context not found at {context}")
|
||||
|
||||
# Prefer buildx+--load to ensure the image ends up in the local daemon (required by Nomad's docker driver).
|
||||
buildx_cmd = [
|
||||
"docker",
|
||||
"buildx",
|
||||
"build",
|
||||
"--load",
|
||||
"-t",
|
||||
image,
|
||||
"-f",
|
||||
str(dockerfile),
|
||||
str(context),
|
||||
]
|
||||
proc = subprocess.run(buildx_cmd, check=False, env={**os.environ, "DOCKER_CLI_HINTS": "false"})
|
||||
if proc.returncode == 0:
|
||||
return
|
||||
|
||||
# Fallback to classic docker build if buildx isn't available.
|
||||
build_cmd = ["docker", "build", "-t", image, "-f", str(dockerfile), str(context)]
|
||||
proc2 = subprocess.run(build_cmd, check=False, env={**os.environ, "DOCKER_CLI_HINTS": "false"})
|
||||
if proc2.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"Failed to build local sandbox image {image}. "
|
||||
f"Tried: {' '.join(buildx_cmd)} and {' '.join(build_cmd)}"
|
||||
)
|
||||
|
||||
def _ensure_local_image(self) -> None:
|
||||
image = (self.config.image or "").strip()
|
||||
if not image.endswith(":local"):
|
||||
return
|
||||
if not self.config.auto_build_local_image:
|
||||
return
|
||||
|
||||
if self._docker_image_exists(image):
|
||||
return
|
||||
|
||||
logger.info(f"Local sandbox image {image} not found; building it now...")
|
||||
self._try_build_local_image(image)
|
||||
|
||||
def _slot_key(self, alloc_id: str, slot_id: str) -> str:
|
||||
"""Generate unique key for a slot."""
|
||||
return f"{alloc_id}:{slot_id}"
|
||||
|
||||
@property
|
||||
def total_slots(self) -> int:
|
||||
"""Total number of slots in pool."""
|
||||
return len(self._slots)
|
||||
|
||||
@property
|
||||
def available_slots(self) -> int:
|
||||
"""Number of available slots."""
|
||||
return sum(1 for s in self._slots.values() if s.is_available)
|
||||
|
||||
@property
|
||||
def acquired_slots(self) -> int:
|
||||
"""Number of acquired slots."""
|
||||
return sum(1 for s in self._slots.values() if s.is_acquired)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Start the slot pool.
|
||||
|
||||
- Checks if Nomad is healthy
|
||||
- Deploys sandbox job if not running
|
||||
- Discovers existing allocations
|
||||
- Starts health check background task
|
||||
"""
|
||||
if self._started:
|
||||
return
|
||||
|
||||
logger.info(f"Starting SlotPool (job_id={self.config.job_id})")
|
||||
|
||||
try:
|
||||
# Make sure local sandbox images exist before Nomad tries to pull them.
|
||||
# This is a common footgun in macOS dev mode with :local tags.
|
||||
self._ensure_local_image()
|
||||
|
||||
# Check Nomad health
|
||||
if not await self.nomad.is_healthy():
|
||||
raise RuntimeError(f"Nomad is not reachable at {self.config.nomad_address}")
|
||||
|
||||
if self.config.purge_job_on_start:
|
||||
logger.info(f"Purging any existing Nomad job: {self.config.job_id}")
|
||||
await self.nomad.stop_job(self.config.job_id, purge=True)
|
||||
|
||||
# Check if job exists (after optional purge)
|
||||
job = await self.nomad.get_job(self.config.job_id)
|
||||
|
||||
if job is None:
|
||||
# Deploy new job
|
||||
logger.info(f"Deploying sandbox job: {self.config.job_id} (driver={self.config.driver})")
|
||||
job_spec = create_sandbox_job(
|
||||
job_id=self.config.job_id,
|
||||
image=self.config.image,
|
||||
count=self.config.min_containers,
|
||||
slots_per_container=self.config.slots_per_container,
|
||||
privileged=self.config.privileged,
|
||||
cpu=self.config.cpu,
|
||||
memory=self.config.memory,
|
||||
datacenter=self.config.datacenter,
|
||||
driver=self.config.driver,
|
||||
singularity_image=self.config.singularity_image,
|
||||
)
|
||||
result = await self.nomad.submit_job(job_spec)
|
||||
if "error" in result:
|
||||
raise RuntimeError(f"Failed to submit job: {result}")
|
||||
|
||||
# Wait for allocations to be running (even if the job already existed).
|
||||
await self._wait_for_healthy_allocations(self.config.min_containers)
|
||||
|
||||
# Discover existing allocations and slots
|
||||
await self._refresh_slots()
|
||||
|
||||
# Start health check task
|
||||
self._health_task = asyncio.create_task(self._health_check_loop())
|
||||
|
||||
self._started = True
|
||||
logger.info(f"SlotPool started: {self.total_slots} slots available")
|
||||
except Exception:
|
||||
# Ensure aiohttp sessions are not leaked if we fail to start.
|
||||
await self.stop(purge_job=False)
|
||||
raise
|
||||
|
||||
async def stop(self, purge_job: bool = False) -> None:
|
||||
"""
|
||||
Stop the slot pool.
|
||||
|
||||
Args:
|
||||
purge_job: If True, also stop the Nomad job
|
||||
"""
|
||||
logger.info("Stopping SlotPool")
|
||||
|
||||
# Cancel health check task
|
||||
if self._health_task:
|
||||
self._health_task.cancel()
|
||||
try:
|
||||
await self._health_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
self._health_task = None
|
||||
|
||||
if self._scale_task:
|
||||
self._scale_task.cancel()
|
||||
try:
|
||||
await self._scale_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
self._scale_task = None
|
||||
|
||||
# Optionally stop the job (do this even if start() never completed).
|
||||
if purge_job:
|
||||
logger.info(f"Stopping Nomad job: {self.config.job_id}")
|
||||
await self.nomad.stop_job(self.config.job_id, purge=True)
|
||||
|
||||
# Close connections
|
||||
await self.executor.close()
|
||||
await self.nomad.close()
|
||||
|
||||
self._started = False
|
||||
self._slots.clear()
|
||||
|
||||
# Clear the queue
|
||||
while not self._available_queue.empty():
|
||||
try:
|
||||
self._available_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
async def acquire(self, trajectory_id: Optional[str] = None) -> Slot:
|
||||
"""
|
||||
Acquire an available slot.
|
||||
|
||||
If no slots are available, waits up to acquire_timeout seconds.
|
||||
If still no slots, attempts to scale up.
|
||||
|
||||
Args:
|
||||
trajectory_id: Optional ID of trajectory acquiring the slot
|
||||
|
||||
Returns:
|
||||
Acquired Slot
|
||||
|
||||
Raises:
|
||||
asyncio.TimeoutError: If no slot becomes available
|
||||
"""
|
||||
if not self._started:
|
||||
raise RuntimeError("SlotPool not started")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Try to get an available slot
|
||||
slot_key = await asyncio.wait_for(
|
||||
self._available_queue.get(),
|
||||
timeout=self.config.acquire_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# Try to scale up, but keep waiting even if scaling isn't possible.
|
||||
# In practice, slots may become available shortly (e.g. contention),
|
||||
# and scaling may be temporarily blocked by Nomad deployments.
|
||||
await self._try_scale_up()
|
||||
continue
|
||||
|
||||
slot = self._slots.get(slot_key)
|
||||
if slot is None:
|
||||
# Slot was removed; discard stale queue entry and retry.
|
||||
continue
|
||||
|
||||
try:
|
||||
slot.acquire(trajectory_id)
|
||||
except RuntimeError:
|
||||
# Slot isn't actually available (e.g. duplicate queue entry); retry.
|
||||
continue
|
||||
|
||||
logger.debug(f"Acquired slot {slot.slot_id} (alloc={slot.alloc_id[:8]})")
|
||||
return slot
|
||||
|
||||
async def release(self, slot: Slot, reset_workspace: bool = False) -> None:
|
||||
"""
|
||||
Release a slot back to the pool.
|
||||
|
||||
Args:
|
||||
slot: Slot to release
|
||||
reset_workspace: If True, clear the workspace files
|
||||
"""
|
||||
slot_key = self._slot_key(slot.alloc_id, slot.slot_id)
|
||||
|
||||
if slot_key not in self._slots:
|
||||
logger.warning(f"Releasing unknown slot: {slot_key}")
|
||||
return
|
||||
|
||||
# Optionally reset workspace
|
||||
if reset_workspace:
|
||||
await self.executor.reset_slot(slot)
|
||||
|
||||
slot.release()
|
||||
await self._available_queue.put(slot_key)
|
||||
|
||||
logger.debug(f"Released slot {slot.slot_id}")
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
slot: Slot,
|
||||
tool_name: str,
|
||||
args: Dict[str, Any],
|
||||
timeout: Optional[float] = None,
|
||||
) -> ExecutionResult:
|
||||
"""
|
||||
Execute a tool in a slot's workspace.
|
||||
|
||||
Args:
|
||||
slot: Slot to execute in
|
||||
tool_name: Name of tool (bash, read_file, write_file)
|
||||
args: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
ExecutionResult
|
||||
"""
|
||||
return await self.executor.execute(slot, tool_name, args, timeout)
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
requests: List[Tuple[Slot, str, Dict[str, Any]]],
|
||||
timeout: Optional[float] = None,
|
||||
) -> List[ExecutionResult]:
|
||||
"""
|
||||
Execute multiple tools in parallel.
|
||||
|
||||
This is the key optimization - batch execution across multiple slots
|
||||
maximizes container utilization.
|
||||
|
||||
Args:
|
||||
requests: List of (slot, tool_name, args) tuples
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
List of ExecutionResults in same order
|
||||
"""
|
||||
return await self.executor.execute_batch(requests, timeout)
|
||||
|
||||
async def _refresh_slots(self) -> None:
|
||||
"""Refresh slot inventory from Nomad allocations."""
|
||||
async with self._lock:
|
||||
allocs = await self.nomad.get_job_allocations(self.config.job_id)
|
||||
|
||||
# Track which slots we've seen
|
||||
seen_keys = set()
|
||||
|
||||
for alloc in allocs:
|
||||
if alloc.status != AllocationStatus.RUNNING:
|
||||
continue
|
||||
|
||||
if not alloc.http_address:
|
||||
continue
|
||||
|
||||
# Check container health
|
||||
healthy = await self.executor.health_check(alloc.http_address)
|
||||
if not healthy:
|
||||
continue
|
||||
|
||||
# Create slots for this allocation
|
||||
for i in range(self.config.slots_per_container):
|
||||
slot_id = f"slot_{i}"
|
||||
slot_key = self._slot_key(alloc.id, slot_id)
|
||||
seen_keys.add(slot_key)
|
||||
|
||||
if slot_key not in self._slots:
|
||||
# New slot
|
||||
slot = Slot(
|
||||
slot_id=slot_id,
|
||||
alloc_id=alloc.id,
|
||||
container_addr=alloc.http_address,
|
||||
)
|
||||
self._slots[slot_key] = slot
|
||||
await self._available_queue.put(slot_key)
|
||||
logger.debug(f"Added slot: {slot_key}")
|
||||
|
||||
# Remove slots from dead allocations
|
||||
for slot_key in list(self._slots.keys()):
|
||||
if slot_key not in seen_keys:
|
||||
slot = self._slots.pop(slot_key)
|
||||
logger.debug(f"Removed slot: {slot_key}")
|
||||
|
||||
async def _wait_for_healthy_allocations(
|
||||
self,
|
||||
min_count: int,
|
||||
timeout: float = 120.0
|
||||
) -> None:
|
||||
"""Wait for allocations to become healthy."""
|
||||
import time
|
||||
start = time.time()
|
||||
|
||||
def _summarize_alloc_detail(detail: Dict[str, Any]) -> str:
|
||||
task_states = detail.get("TaskStates") or {}
|
||||
parts: List[str] = []
|
||||
if isinstance(task_states, dict):
|
||||
for task_name, st in task_states.items():
|
||||
events = (st or {}).get("Events") or []
|
||||
if isinstance(events, list) and events:
|
||||
# Include a few recent events; the latest can be a generic restart message
|
||||
# while the true root cause is slightly earlier (e.g. image pull failure).
|
||||
recent = events[-3:]
|
||||
msgs: List[str] = []
|
||||
for ev in recent:
|
||||
desc = ev.get("DisplayMessage") or ev.get("Message") or ev.get("Type") or ""
|
||||
if desc:
|
||||
msgs.append(desc)
|
||||
if msgs:
|
||||
parts.append(f"{task_name}: " + " | ".join(msgs))
|
||||
return "; ".join(parts)
|
||||
|
||||
def _alloc_events_lower(detail: Dict[str, Any]) -> str:
|
||||
task_states = detail.get("TaskStates") or {}
|
||||
texts: List[str] = []
|
||||
if isinstance(task_states, dict):
|
||||
for _task_name, st in task_states.items():
|
||||
events = (st or {}).get("Events") or []
|
||||
if isinstance(events, list):
|
||||
for ev in events[-10:]:
|
||||
desc = ev.get("DisplayMessage") or ev.get("Message") or ev.get("Type") or ""
|
||||
if desc:
|
||||
texts.append(desc)
|
||||
return " ".join(texts).lower()
|
||||
|
||||
while time.time() - start < timeout:
|
||||
allocs = await self.nomad.get_job_allocations(self.config.job_id)
|
||||
|
||||
healthy_count = 0
|
||||
for alloc in allocs:
|
||||
if alloc.status == AllocationStatus.RUNNING and alloc.http_address:
|
||||
if await self.executor.health_check(alloc.http_address):
|
||||
healthy_count += 1
|
||||
|
||||
# Fast-fail on obvious driver/image errors to avoid waiting out the full timeout.
|
||||
if alloc.id:
|
||||
detail = await self.nomad.get_allocation(alloc.id)
|
||||
if isinstance(detail, dict):
|
||||
summary = _summarize_alloc_detail(detail)
|
||||
lowered = _alloc_events_lower(detail) or summary.lower()
|
||||
if "failed to pull" in lowered or "pull access denied" in lowered:
|
||||
raise RuntimeError(
|
||||
"Nomad allocation failed to start due to a Docker image pull error. "
|
||||
f"Allocation {alloc.id[:8]}: {summary}\n"
|
||||
"If you're using a local image tag (e.g. `atropos-sandbox:local`) on macOS, "
|
||||
"make sure the image is loaded into Docker, e.g.:\n"
|
||||
" docker buildx build --load -t atropos-sandbox:local -f Hermes-Agent/atropos/Dockerfile Hermes-Agent/atropos"
|
||||
)
|
||||
if "exceeded allowed attempts" in lowered:
|
||||
raise RuntimeError(
|
||||
"Nomad allocation is crash-looping and has entered restart backoff. "
|
||||
f"Allocation {alloc.id[:8]}: {summary}\n"
|
||||
"Inspect logs with:\n"
|
||||
f" nomad alloc logs -stderr -task sandbox-server {alloc.id}\n"
|
||||
"Common causes include: missing local Docker image tag, container entrypoint error, "
|
||||
"or sandbox-server startup failure."
|
||||
)
|
||||
|
||||
if healthy_count >= min_count:
|
||||
return
|
||||
|
||||
await asyncio.sleep(2.0)
|
||||
|
||||
# Timed out: include allocation status detail to help debugging.
|
||||
allocs = await self.nomad.get_job_allocations(self.config.job_id)
|
||||
alloc_lines: List[str] = []
|
||||
for alloc in allocs[:10]:
|
||||
addr = alloc.http_address or "-"
|
||||
line = f"{alloc.id[:8]} status={alloc.status.value} http={addr}"
|
||||
detail = await self.nomad.get_allocation(alloc.id)
|
||||
if isinstance(detail, dict):
|
||||
summary = _summarize_alloc_detail(detail)
|
||||
if summary:
|
||||
line += f" detail={summary}"
|
||||
alloc_lines.append(line)
|
||||
|
||||
hint = (
|
||||
"Timed out waiting for healthy sandbox allocations.\n"
|
||||
f"Job: {self.config.job_id}, desired_healthy: {min_count}\n"
|
||||
"Allocations:\n - " + "\n - ".join(alloc_lines)
|
||||
)
|
||||
raise RuntimeError(hint)
|
||||
|
||||
async def _try_scale_up(self) -> bool:
|
||||
"""Attempt to scale up the job."""
|
||||
import time
|
||||
|
||||
async with self._scale_lock:
|
||||
# Check cooldown
|
||||
if time.time() - self._last_scale_time < self.config.scale_cooldown:
|
||||
return False
|
||||
|
||||
# Check max containers
|
||||
status = await self.nomad.get_job_status(self.config.job_id)
|
||||
if status is None:
|
||||
return False
|
||||
|
||||
current_count = status.count
|
||||
if current_count >= self.config.max_containers:
|
||||
logger.warning(f"Cannot scale up: already at max ({self.config.max_containers})")
|
||||
return False
|
||||
|
||||
# Scale up
|
||||
new_count = min(current_count + 1, self.config.max_containers)
|
||||
logger.info(f"Scaling up from {current_count} to {new_count} containers")
|
||||
|
||||
scale_resp = await self.nomad.scale_job(
|
||||
self.config.job_id,
|
||||
count=new_count,
|
||||
task_group="sandbox",
|
||||
)
|
||||
|
||||
# Nomad may return non-JSON errors (e.g. plain text) with a status field.
|
||||
if isinstance(scale_resp, dict) and scale_resp.get("status", 200) >= 400:
|
||||
logger.warning(f"Scale request rejected: {scale_resp}")
|
||||
self._last_scale_time = time.time()
|
||||
return False
|
||||
|
||||
self._last_scale_time = time.time()
|
||||
|
||||
# Wait for new allocation in the background so contended acquires can still
|
||||
# make progress (e.g. by grabbing slots released by other trajectories).
|
||||
if self._scale_task is None or self._scale_task.done():
|
||||
self._scale_task = asyncio.create_task(self._wait_for_scale(new_count))
|
||||
|
||||
return True
|
||||
|
||||
async def _wait_for_scale(self, desired_count: int) -> None:
|
||||
try:
|
||||
await self._wait_for_healthy_allocations(desired_count, timeout=60.0)
|
||||
await self._refresh_slots()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to scale up: {e}")
|
||||
|
||||
async def _health_check_loop(self) -> None:
|
||||
"""Background task to monitor container health."""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.config.health_check_interval)
|
||||
await self._refresh_slots()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Health check error: {e}")
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get pool statistics."""
|
||||
slots_by_state = {}
|
||||
for slot in self._slots.values():
|
||||
state = slot.state.value
|
||||
slots_by_state[state] = slots_by_state.get(state, 0) + 1
|
||||
|
||||
container_count = len({s.alloc_id for s in self._slots.values()}) if self._slots else 0
|
||||
|
||||
return {
|
||||
"total_slots": self.total_slots,
|
||||
"available_slots": self.available_slots,
|
||||
"acquired_slots": self.acquired_slots,
|
||||
"containers": container_count,
|
||||
"slots_by_state": slots_by_state,
|
||||
"started": self._started,
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
Slot abstraction for atropos-agent.
|
||||
|
||||
A Slot represents an isolated workspace for a single agent trajectory.
|
||||
Slots are hosted on Nomad allocations and provide workspace isolation
|
||||
via filesystem directories.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
import uuid
|
||||
|
||||
|
||||
class SlotState(Enum):
|
||||
"""State of a slot in the pool."""
|
||||
AVAILABLE = "available" # Ready to be acquired
|
||||
ACQUIRED = "acquired" # Assigned to a trajectory
|
||||
EXECUTING = "executing" # Currently executing a tool
|
||||
RELEASING = "releasing" # Being released back to pool
|
||||
ERROR = "error" # In error state
|
||||
|
||||
|
||||
@dataclass
|
||||
class Slot:
|
||||
"""
|
||||
An isolated workspace for a single agent trajectory.
|
||||
|
||||
Slots are the unit of scheduling - each trajectory runs in its own slot,
|
||||
with an isolated workspace directory. Multiple slots share a container.
|
||||
|
||||
Attributes:
|
||||
slot_id: Unique identifier for this slot (e.g., "slot_0")
|
||||
alloc_id: Nomad allocation ID hosting this slot
|
||||
container_addr: HTTP address of the sandbox server (e.g., "http://10.0.0.1:8080")
|
||||
workspace_dir: Path to workspace in container (e.g., "/data/slot_0")
|
||||
state: Current state of the slot
|
||||
trajectory_id: ID of trajectory currently using this slot (if acquired)
|
||||
metadata: Additional metadata
|
||||
"""
|
||||
slot_id: str
|
||||
alloc_id: str
|
||||
container_addr: str
|
||||
workspace_dir: str = ""
|
||||
state: SlotState = SlotState.AVAILABLE
|
||||
trajectory_id: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Set default workspace_dir if not provided."""
|
||||
if not self.workspace_dir:
|
||||
self.workspace_dir = f"/data/{self.slot_id}"
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
"""Check if slot is available for acquisition."""
|
||||
return self.state == SlotState.AVAILABLE
|
||||
|
||||
@property
|
||||
def is_acquired(self) -> bool:
|
||||
"""Check if slot is currently acquired."""
|
||||
return self.state in (SlotState.ACQUIRED, SlotState.EXECUTING)
|
||||
|
||||
def acquire(self, trajectory_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Mark slot as acquired by a trajectory.
|
||||
|
||||
Args:
|
||||
trajectory_id: Optional ID of acquiring trajectory
|
||||
"""
|
||||
if not self.is_available:
|
||||
raise RuntimeError(f"Cannot acquire slot {self.slot_id}: state is {self.state}")
|
||||
|
||||
self.state = SlotState.ACQUIRED
|
||||
self.trajectory_id = trajectory_id or str(uuid.uuid4())
|
||||
|
||||
def start_execution(self, execution_id: Optional[str] = None) -> None:
|
||||
"""Mark slot as executing."""
|
||||
if self.state != SlotState.ACQUIRED:
|
||||
raise RuntimeError(f"Cannot start execution on slot {self.slot_id}: state is {self.state}")
|
||||
|
||||
self.state = SlotState.EXECUTING
|
||||
if execution_id:
|
||||
self.metadata["current_execution_id"] = execution_id
|
||||
|
||||
def end_execution(self) -> None:
|
||||
"""Mark execution as complete, return to acquired state."""
|
||||
if self.state != SlotState.EXECUTING:
|
||||
raise RuntimeError(f"Cannot end execution on slot {self.slot_id}: state is {self.state}")
|
||||
|
||||
self.state = SlotState.ACQUIRED
|
||||
self.metadata.pop("current_execution_id", None)
|
||||
|
||||
def release(self) -> None:
|
||||
"""Release slot back to available state."""
|
||||
self.state = SlotState.AVAILABLE
|
||||
self.trajectory_id = None
|
||||
self.metadata.pop("current_execution_id", None)
|
||||
|
||||
def mark_error(self, error: str) -> None:
|
||||
"""Mark slot as in error state."""
|
||||
self.state = SlotState.ERROR
|
||||
self.metadata["error"] = error
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for serialization."""
|
||||
return {
|
||||
"slot_id": self.slot_id,
|
||||
"alloc_id": self.alloc_id,
|
||||
"container_addr": self.container_addr,
|
||||
"workspace_dir": self.workspace_dir,
|
||||
"state": self.state.value,
|
||||
"trajectory_id": self.trajectory_id,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Slot":
|
||||
"""Create from dictionary."""
|
||||
return cls(
|
||||
slot_id=data["slot_id"],
|
||||
alloc_id=data["alloc_id"],
|
||||
container_addr=data["container_addr"],
|
||||
workspace_dir=data.get("workspace_dir", ""),
|
||||
state=SlotState(data.get("state", "available")),
|
||||
trajectory_id=data.get("trajectory_id"),
|
||||
metadata=data.get("metadata", {}),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Slot({self.slot_id}, state={self.state.value}, alloc={self.alloc_id[:8]}...)"
|
||||
|
||||
|
||||
def create_slots_for_allocation(
|
||||
alloc_id: str,
|
||||
container_addr: str,
|
||||
num_slots: int = 10,
|
||||
) -> list["Slot"]:
|
||||
"""
|
||||
Create slots for a Nomad allocation.
|
||||
|
||||
Args:
|
||||
alloc_id: Nomad allocation ID
|
||||
container_addr: HTTP address of sandbox server
|
||||
num_slots: Number of slots to create
|
||||
|
||||
Returns:
|
||||
List of Slot objects
|
||||
"""
|
||||
slots = []
|
||||
for i in range(num_slots):
|
||||
slot_id = f"slot_{i}"
|
||||
slots.append(Slot(
|
||||
slot_id=slot_id,
|
||||
alloc_id=alloc_id,
|
||||
container_addr=container_addr,
|
||||
workspace_dir=f"/data/{slot_id}",
|
||||
))
|
||||
return slots
|
||||
@@ -0,0 +1,2 @@
|
||||
"""Terminal helpers for stateful sandbox interactions."""
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import pyte
|
||||
|
||||
|
||||
class AsciinemaStreamDecoder:
|
||||
def __init__(self, *, default_width: int = 80, default_height: int = 24) -> None:
|
||||
self._default_width = max(1, int(default_width))
|
||||
self._default_height = max(1, int(default_height))
|
||||
self._buffer = ""
|
||||
self._has_header = False
|
||||
self.width = self._default_width
|
||||
self.height = self._default_height
|
||||
self._screen = pyte.Screen(self.width, self.height)
|
||||
self._stream = pyte.Stream(self._screen)
|
||||
|
||||
def reset(self) -> None:
|
||||
self._buffer = ""
|
||||
self._has_header = False
|
||||
self.width = self._default_width
|
||||
self.height = self._default_height
|
||||
self._screen = pyte.Screen(self.width, self.height)
|
||||
self._stream = pyte.Stream(self._screen)
|
||||
|
||||
def feed(self, chunk: str | bytes) -> None:
|
||||
if not chunk:
|
||||
return
|
||||
if isinstance(chunk, bytes):
|
||||
chunk = chunk.decode("utf-8", errors="replace")
|
||||
self._buffer += chunk
|
||||
while True:
|
||||
line, sep, rest = self._buffer.partition("\n")
|
||||
if not sep:
|
||||
break
|
||||
self._buffer = rest
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
parsed = self._parse_json_line(line)
|
||||
if parsed is None:
|
||||
continue
|
||||
if not self._has_header:
|
||||
if isinstance(parsed, dict):
|
||||
self._init_from_header(parsed)
|
||||
continue
|
||||
if isinstance(parsed, list):
|
||||
self._has_header = True
|
||||
self._apply_event(parsed)
|
||||
continue
|
||||
continue
|
||||
if isinstance(parsed, list):
|
||||
self._apply_event(parsed)
|
||||
|
||||
def render(self) -> str:
|
||||
return "\n".join(self._screen.display)
|
||||
|
||||
def _parse_json_line(self, line: str) -> Any | None:
|
||||
try:
|
||||
return json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
def _init_from_header(self, header: dict[str, Any]) -> None:
|
||||
width = _coerce_int(
|
||||
header.get("width") or header.get("columns") or header.get("cols"),
|
||||
self._default_width,
|
||||
)
|
||||
height = _coerce_int(
|
||||
header.get("height") or header.get("rows") or header.get("lines"),
|
||||
self._default_height,
|
||||
)
|
||||
self.width = max(1, width)
|
||||
self.height = max(1, height)
|
||||
self._screen = pyte.Screen(self.width, self.height)
|
||||
self._stream = pyte.Stream(self._screen)
|
||||
self._has_header = True
|
||||
|
||||
def _apply_event(self, event: list[Any]) -> None:
|
||||
if len(event) < 2:
|
||||
return
|
||||
event_type = event[1]
|
||||
payload = event[2] if len(event) > 2 else ""
|
||||
if event_type == "o":
|
||||
if isinstance(payload, str):
|
||||
self._stream.feed(payload)
|
||||
elif event_type == "r":
|
||||
width, height = _parse_resize(payload)
|
||||
if width and height:
|
||||
self.width = width
|
||||
self.height = height
|
||||
self._screen.resize(width, height)
|
||||
|
||||
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return int(default)
|
||||
|
||||
|
||||
def _parse_resize(payload: Any) -> tuple[int, int]:
|
||||
if isinstance(payload, str) and "x" in payload:
|
||||
left, right = payload.lower().split("x", 1)
|
||||
return _coerce_int(left, 0), _coerce_int(right, 0)
|
||||
if isinstance(payload, dict):
|
||||
width = _coerce_int(payload.get("width") or payload.get("columns") or payload.get("cols"), 0)
|
||||
height = _coerce_int(payload.get("height") or payload.get("rows") or payload.get("lines"), 0)
|
||||
return width, height
|
||||
if isinstance(payload, list) and len(payload) >= 2:
|
||||
return _coerce_int(payload[0], 0), _coerce_int(payload[1], 0)
|
||||
return 0, 0
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
Tool abstractions for atropos-agent.
|
||||
|
||||
Provides base Tool class, ToolCall/ToolResult types, and specialized tools.
|
||||
|
||||
Kept modules:
|
||||
- base.py: ToolSchema, ToolCall, ToolResult, Tool ABC, ToolRegistry
|
||||
- tool_executor.py: Batched execution queue with slot routing
|
||||
- terminal_stateful_tool.py: Persistent terminal sessions
|
||||
- tmux_tool.py: Tmux-based streaming terminal
|
||||
|
||||
Removed (replaced by hermes-agent equivalents):
|
||||
- build_registry.py → model_tools.py + toolsets.py
|
||||
- sandbox_stubs.py → atropos/backends/ execute() methods
|
||||
- hermes_external_tools.py → environments/agent_loop.py handle_function_call()
|
||||
- toolset_resolver.py → toolsets.py
|
||||
"""
|
||||
|
||||
from .base import Tool, ToolCall, ToolRegistry, ToolResult, ToolSchema
|
||||
from .terminal_stateful_tool import TerminalStatefulTool
|
||||
from .tmux_tool import TmuxTool
|
||||
|
||||
__all__ = [
|
||||
"Tool",
|
||||
"ToolCall",
|
||||
"ToolRegistry",
|
||||
"ToolResult",
|
||||
"ToolSchema",
|
||||
"TerminalStatefulTool",
|
||||
"TmuxTool",
|
||||
]
|
||||
@@ -0,0 +1,423 @@
|
||||
"""
|
||||
Base Tool abstraction for atropos-agent.
|
||||
|
||||
Tools follow a simple pattern:
|
||||
1. Define schema (name, description, parameters)
|
||||
2. Implement execute() method
|
||||
3. Return ToolResult with output/error
|
||||
|
||||
Tool calls use Hermes-style XML tags:
|
||||
<tool_call>{"name": "bash", "arguments": {"command": "ls"}}</tool_call>
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolSchema:
|
||||
"""JSON Schema for a tool's parameters."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
parameters: Dict[str, Any] = field(default_factory=dict)
|
||||
required: List[str] = field(default_factory=list)
|
||||
external: bool = False # Whether the tool must be executed via an external ToolServer (secret proxy) and not inside the sandbox.
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to OpenAI-compatible function schema."""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": self.parameters,
|
||||
"required": self.required,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def to_prompt_description(self) -> str:
|
||||
"""Convert to human-readable description for system prompt."""
|
||||
params_desc = []
|
||||
for name, spec in self.parameters.items():
|
||||
req = "(required)" if name in self.required else "(optional)"
|
||||
desc = spec.get("description", "")
|
||||
param_type = spec.get("type", "string")
|
||||
params_desc.append(f" - {name} ({param_type}) {req}: {desc}")
|
||||
|
||||
params_str = "\n".join(params_desc) if params_desc else " (no parameters)"
|
||||
return f"**{self.name}**: {self.description}\nParameters:\n{params_str}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
"""A parsed tool call from model output."""
|
||||
|
||||
name: str
|
||||
arguments: Dict[str, Any]
|
||||
raw_text: str = "" # Original XML/JSON text
|
||||
uniq_id: str = field(default_factory=lambda: str(uuid.uuid4())) # Unique tool-call id for traceability/reconstruction.
|
||||
|
||||
@classmethod
|
||||
def parse_from_text(cls, text: str) -> List["ToolCall"]:
|
||||
"""
|
||||
Extract tool calls from text using Hermes-style XML tags.
|
||||
|
||||
Supported formats (STRICT: requires well-formed closing tags):
|
||||
- Hermes JSON wrapper:
|
||||
<tool_call>{"name": "...", "arguments": {...}}</tool_call>
|
||||
- GLM/llama.cpp style:
|
||||
<tool_call>terminal{"command":"ls -la"}</tool_call>
|
||||
"""
|
||||
calls: List["ToolCall"] = []
|
||||
|
||||
if not text:
|
||||
return calls
|
||||
|
||||
def _append_from_payload(*, name: str, arguments: Dict[str, Any], raw: str, uniq_id: Optional[str] = None) -> None:
|
||||
if not isinstance(name, str) or not name:
|
||||
return
|
||||
if not isinstance(arguments, dict):
|
||||
return
|
||||
calls.append(
|
||||
cls(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
raw_text=raw,
|
||||
uniq_id=uniq_id or str(uuid.uuid4()),
|
||||
)
|
||||
)
|
||||
|
||||
# STRICT parsing: only accept well-formed <tool_call>...</tool_call> blocks.
|
||||
pattern = r"<tool_call>\s*(.*?)\s*</tool_call>"
|
||||
for inner in re.findall(pattern, text, re.DOTALL):
|
||||
cleaned = (inner or "").strip()
|
||||
if not cleaned:
|
||||
continue
|
||||
|
||||
# Hermes JSON wrapper.
|
||||
if cleaned.startswith("{"):
|
||||
try:
|
||||
data = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
uniq_id = data.get("uniq_id") or data.get("id") or None
|
||||
_append_from_payload(
|
||||
name=data.get("name", ""),
|
||||
arguments=data.get("arguments", {}),
|
||||
raw=inner,
|
||||
uniq_id=uniq_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# GLM/llama.cpp style: terminal{...}
|
||||
m = re.match(r"^\s*([A-Za-z0-9_.:\\-]+)\s*(\{.*\})\s*$", cleaned, re.DOTALL)
|
||||
if not m:
|
||||
continue
|
||||
name = m.group(1)
|
||||
args_text = m.group(2)
|
||||
try:
|
||||
args = json.loads(args_text)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
_append_from_payload(name=name, arguments=args, raw=inner)
|
||||
|
||||
return calls
|
||||
|
||||
@classmethod
|
||||
def has_tool_call(cls, text: str) -> bool:
|
||||
"""Check if text contains any tool calls."""
|
||||
return bool(re.search(r"<tool_call>", text))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""Result from executing a tool."""
|
||||
|
||||
success: bool
|
||||
output: str = ""
|
||||
error: str = ""
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
uniq_id: Optional[str] = None # Should match ToolCall.uniq_id for async execution tracking.
|
||||
|
||||
def to_xml(self) -> str:
|
||||
"""Format as XML for including in conversation."""
|
||||
data = {
|
||||
"success": self.success,
|
||||
"output": self.output,
|
||||
}
|
||||
if self.uniq_id:
|
||||
data["uniq_id"] = self.uniq_id
|
||||
if self.error:
|
||||
data["error"] = self.error
|
||||
if self.metadata:
|
||||
data["metadata"] = self.metadata
|
||||
return f"<tool_response>{json.dumps(data)}</tool_response>"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"success": self.success,
|
||||
"output": self.output,
|
||||
"error": self.error,
|
||||
"metadata": self.metadata,
|
||||
"uniq_id": self.uniq_id,
|
||||
}
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
"""
|
||||
Abstract base class for tools.
|
||||
|
||||
Subclasses must implement:
|
||||
- schema: ToolSchema describing the tool
|
||||
- execute(): async method that performs the tool action
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def schema(self) -> ToolSchema:
|
||||
"""Return the tool's schema."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Tool name (from schema)."""
|
||||
return self.schema.name
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""
|
||||
Execute the tool with given arguments.
|
||||
|
||||
Args:
|
||||
**kwargs: Tool-specific arguments
|
||||
|
||||
Returns:
|
||||
ToolResult with success/failure and output
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_available(self) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Return whether this tool should be exposed/executable in the current process.
|
||||
|
||||
Tools that depend on optional binaries/services/env vars can override this
|
||||
to avoid advertising a tool that will fail at runtime.
|
||||
"""
|
||||
return True, None
|
||||
|
||||
async def __call__(self, **kwargs) -> ToolResult:
|
||||
"""Allow calling tool instance directly."""
|
||||
return await self.execute(**kwargs)
|
||||
|
||||
# Note: This is only wrapping declarations for the external ToolServer (for execution on external process tools), and tools preinstalled in envs
|
||||
class ToolRegistry:
|
||||
"""Registry of available tools."""
|
||||
|
||||
def __init__(self):
|
||||
self._tools: Dict[str, Tool] = {}
|
||||
|
||||
def register(self, tool: Tool) -> None:
|
||||
"""Register a tool."""
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def get(self, name: str) -> Optional[Tool]:
|
||||
"""Get a tool by name."""
|
||||
return self._tools.get(name)
|
||||
|
||||
def list_tools(self) -> List[Tool]:
|
||||
"""List all registered tools."""
|
||||
return list(self._tools.values())
|
||||
|
||||
def get_schemas(self) -> List[ToolSchema]:
|
||||
"""Get schemas for all registered tools."""
|
||||
return [tool.schema for tool in self._tools.values()]
|
||||
|
||||
def get_prompt_description(self) -> str:
|
||||
"""Generate tool descriptions for system prompt."""
|
||||
descriptions = [tool.schema.to_prompt_description() for tool in self._tools.values()]
|
||||
return "\n\n".join(descriptions)
|
||||
|
||||
def get_prompt_tool_definitions_json(self) -> str:
|
||||
"""
|
||||
Return a Hermes-style JSON list of tool definitions for use inside a `<tools>...</tools>` block.
|
||||
|
||||
Hermes trajectories historically use a simplified schema list:
|
||||
[{"name": ..., "description": ..., "parameters": {...}, "required": null}, ...]
|
||||
"""
|
||||
formatted: List[Dict[str, Any]] = []
|
||||
for tool in self._tools.values():
|
||||
fn = tool.schema.to_dict().get("function", {})
|
||||
formatted.append(
|
||||
{
|
||||
"name": fn.get("name", tool.name),
|
||||
"description": fn.get("description", ""),
|
||||
"parameters": fn.get("parameters", {}),
|
||||
# Keep parity with Hermes saved trajectories (required is typically null there).
|
||||
"required": None,
|
||||
}
|
||||
)
|
||||
return json.dumps(formatted, ensure_ascii=False)
|
||||
|
||||
async def execute(self, call: ToolCall) -> ToolResult:
|
||||
"""Execute a tool call."""
|
||||
tool = self.get(call.name)
|
||||
if tool is None:
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=f"Unknown tool: {call.name}",
|
||||
uniq_id=call.uniq_id,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await tool.execute(**call.arguments)
|
||||
if result.uniq_id is None:
|
||||
result.uniq_id = call.uniq_id
|
||||
return result
|
||||
except Exception as e:
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=f"Tool execution error: {str(e)}",
|
||||
uniq_id=call.uniq_id,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# FastAPI / transport models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ToolCallPayload(BaseModel):
|
||||
name: str
|
||||
arguments: Dict[str, Any] = Field(default_factory=dict)
|
||||
uniq_id: str
|
||||
|
||||
@classmethod
|
||||
def from_tool_call(cls, call: ToolCall) -> "ToolCallPayload":
|
||||
return cls(name=call.name, arguments=call.arguments, uniq_id=call.uniq_id)
|
||||
|
||||
def to_tool_call(self) -> ToolCall:
|
||||
return ToolCall(name=self.name, arguments=self.arguments, uniq_id=self.uniq_id)
|
||||
|
||||
|
||||
class ToolResultPayload(BaseModel):
|
||||
success: bool
|
||||
output: str = ""
|
||||
error: str = ""
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
uniq_id: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_tool_result(cls, result: ToolResult) -> "ToolResultPayload":
|
||||
return cls(
|
||||
success=result.success,
|
||||
output=result.output,
|
||||
error=result.error,
|
||||
metadata=result.metadata,
|
||||
uniq_id=result.uniq_id,
|
||||
)
|
||||
|
||||
def to_tool_result(self) -> ToolResult:
|
||||
return ToolResult(
|
||||
success=self.success,
|
||||
output=self.output,
|
||||
error=self.error,
|
||||
metadata=self.metadata,
|
||||
uniq_id=self.uniq_id,
|
||||
)
|
||||
|
||||
|
||||
class ToolExecutorExecuteRequest(BaseModel):
|
||||
trajectory_id: str
|
||||
tool: ToolCallPayload
|
||||
timeout_s: Optional[float] = None
|
||||
|
||||
|
||||
class ToolExecutorReleaseRequest(BaseModel):
|
||||
trajectory_id: str
|
||||
reset_workspace: bool = False
|
||||
|
||||
|
||||
class ToolServerExecuteRequest(BaseModel):
|
||||
trajectory_id: Optional[str] = None
|
||||
tool: ToolCallPayload
|
||||
timeout_s: Optional[float] = None
|
||||
# Optional sandbox context for tools that need workspace artifacts.
|
||||
# This is set by ToolExecutor and is NOT model-controlled.
|
||||
slot_id: Optional[str] = None
|
||||
container_addr: Optional[str] = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Artifact transport models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ArtifactReadRequestPayload(BaseModel):
|
||||
trajectory_id: str
|
||||
path: str
|
||||
encoding: Literal["text", "base64"] = "text"
|
||||
max_bytes: Optional[int] = None
|
||||
include_sha256: bool = False
|
||||
|
||||
|
||||
class ArtifactReadResponsePayload(BaseModel):
|
||||
success: bool
|
||||
content: str = ""
|
||||
error: str = ""
|
||||
encoding: str = "text"
|
||||
truncated: bool = False
|
||||
bytes: int = 0
|
||||
file_size: Optional[int] = None
|
||||
path: str = ""
|
||||
mime: Optional[str] = None
|
||||
sha256: Optional[str] = None
|
||||
|
||||
|
||||
class ArtifactListRequestPayload(BaseModel):
|
||||
trajectory_id: str
|
||||
path: str = "."
|
||||
recursive: bool = False
|
||||
max_entries: Optional[int] = None
|
||||
|
||||
|
||||
class ArtifactListEntryPayload(BaseModel):
|
||||
path: str
|
||||
is_dir: bool
|
||||
size: int
|
||||
mtime: float
|
||||
|
||||
|
||||
class ArtifactListResponsePayload(BaseModel):
|
||||
success: bool
|
||||
entries: List[ArtifactListEntryPayload] = Field(default_factory=list)
|
||||
truncated: bool = False
|
||||
error: str = ""
|
||||
|
||||
|
||||
class ArtifactArchiveRequestPayload(BaseModel):
|
||||
trajectory_id: str
|
||||
path: str = "."
|
||||
format: Literal["tar.gz", "tgz"] = "tar.gz"
|
||||
max_bytes: Optional[int] = None
|
||||
max_entries: Optional[int] = None
|
||||
|
||||
|
||||
class ArtifactArchiveResponsePayload(BaseModel):
|
||||
success: bool
|
||||
content: str = ""
|
||||
error: str = ""
|
||||
encoding: str = "base64"
|
||||
format: str = "tar.gz"
|
||||
bytes: int = 0
|
||||
entry_count: int = 0
|
||||
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Stateful terminal tool schema.
|
||||
|
||||
This is a sandbox tool that routes to the sandbox server as `bash_stateful`
|
||||
via ToolExecutor mapping. It exists to expose an explicit, opt-in terminal
|
||||
primitive suitable for stateful workflows (e.g. tmux sessions / TUIs).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from .base import Tool, ToolResult, ToolSchema
|
||||
|
||||
|
||||
class TerminalStatefulTool(Tool):
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
return ToolSchema(
|
||||
name="terminal_stateful",
|
||||
description=(
|
||||
"Execute a command in the sandbox, allowing stateful/background processes to persist "
|
||||
"across tool calls within the same trajectory slot (e.g. tmux sessions). "
|
||||
"Use sparingly; output is still non-interactive."
|
||||
),
|
||||
parameters={
|
||||
"command": {"type": "string", "description": "The command to execute"},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Command timeout in seconds (optional).",
|
||||
"minimum": 1,
|
||||
},
|
||||
},
|
||||
required=["command"],
|
||||
)
|
||||
|
||||
def is_available(self) -> tuple[bool, str | None]:
|
||||
return True, None
|
||||
|
||||
async def execute(self, command: str, timeout: Optional[int] = None) -> ToolResult:
|
||||
_ = (command, timeout)
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error="terminal_stateful must be executed via ToolExecutor inside the sandbox",
|
||||
)
|
||||
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
tmux tool schema (sandbox).
|
||||
|
||||
This is a sandbox tool that provides basic tmux session control suitable for
|
||||
TUI-style terminal interactions:
|
||||
- send keys (arrow keys, enter, etc.)
|
||||
- capture the current screen buffer
|
||||
|
||||
Execution is routed by ToolExecutor to the sandbox server's `tmux` backend.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .base import Tool, ToolResult, ToolSchema
|
||||
|
||||
|
||||
class TmuxTool(Tool):
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
return ToolSchema(
|
||||
name="tmux",
|
||||
description=(
|
||||
"Control a per-trajectory tmux session inside the sandbox (stateful terminal). "
|
||||
"Use this for TUI-style interactions: send keys and capture the current screen."
|
||||
),
|
||||
parameters={
|
||||
"action": {
|
||||
"type": "string",
|
||||
"description": "Action to perform: start | send_keys | stream | stop.",
|
||||
"enum": ["start", "send_keys", "stream", "stop", "capture"],
|
||||
},
|
||||
"keys": {
|
||||
"description": "Keys to send (string or list of strings) when action=send_keys.",
|
||||
},
|
||||
"block": {
|
||||
"type": "boolean",
|
||||
"description": "If true, wait for shell command completion (only valid at a shell prompt).",
|
||||
"default": False,
|
||||
},
|
||||
"min_wait_s": {
|
||||
"type": "number",
|
||||
"description": "For non-blocking send_keys, sleep this long after sending keys (seconds).",
|
||||
"default": 0.0,
|
||||
},
|
||||
"max_wait_s": {
|
||||
"type": "number",
|
||||
"description": "For blocking send_keys, max time to wait for completion (seconds).",
|
||||
},
|
||||
"capture_entire": {
|
||||
"type": "boolean",
|
||||
"description": "Deprecated. Streaming is preferred.",
|
||||
"default": False,
|
||||
},
|
||||
"max_bytes": {
|
||||
"type": "integer",
|
||||
"description": "Max bytes to return per stream call.",
|
||||
},
|
||||
"reset": {
|
||||
"type": "boolean",
|
||||
"description": "If true, reset stream offset to the beginning of the asciinema recording.",
|
||||
"default": False,
|
||||
},
|
||||
"pane_width": {
|
||||
"type": "integer",
|
||||
"description": "Pane width for action=start (columns).",
|
||||
"minimum": 20,
|
||||
},
|
||||
"pane_height": {
|
||||
"type": "integer",
|
||||
"description": "Pane height for action=start (rows).",
|
||||
"minimum": 10,
|
||||
},
|
||||
},
|
||||
required=["action"],
|
||||
)
|
||||
|
||||
def is_available(self) -> tuple[bool, str | None]:
|
||||
return True, None
|
||||
|
||||
async def execute(self, **kwargs: Dict[str, Any]) -> ToolResult:
|
||||
# This tool is intended to be executed via ToolExecutor -> sandbox server.
|
||||
# We keep a safe fallback for non-sandbox contexts.
|
||||
action = str(kwargs.get("action") or "").strip()
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=f"tmux tool must be executed in the sandbox (got action={action!r})",
|
||||
)
|
||||
@@ -0,0 +1,500 @@
|
||||
"""
|
||||
ToolExecutor - queued, batched tool dispatch for multiplexed agent trajectories.
|
||||
|
||||
This component is responsible for:
|
||||
- Maintaining trajectory -> Slot affinity (workspace continuity)
|
||||
- Batching sandbox tool calls across trajectories to maximize container utilization
|
||||
- Routing external tools (ToolSchema.external=True) to a ToolServer (Phase 4.5)
|
||||
|
||||
For now, only sandbox tools are executed:
|
||||
- bash
|
||||
- read_file
|
||||
- write_file
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import (
|
||||
ArtifactArchiveRequestPayload,
|
||||
ArtifactArchiveResponsePayload,
|
||||
ArtifactListRequestPayload,
|
||||
ArtifactListResponsePayload,
|
||||
ArtifactReadRequestPayload,
|
||||
ArtifactReadResponsePayload,
|
||||
ToolCall,
|
||||
ToolCallPayload,
|
||||
ToolRegistry,
|
||||
ToolResult,
|
||||
ToolResultPayload,
|
||||
ToolServerExecuteRequest,
|
||||
)
|
||||
from ..backends.base import ToolBackend
|
||||
from ..slots import Slot
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolExecutorConfig:
|
||||
batch_window_ms: int = 20
|
||||
max_batch_size: int = 200
|
||||
allow_network: bool = True
|
||||
require_sandbox: bool = False
|
||||
require_stateful_sandbox: bool = False
|
||||
tool_server_url: Optional[str] = None
|
||||
tool_server_token: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _QueuedToolRequest:
|
||||
trajectory_id: str
|
||||
call: ToolCall
|
||||
timeout_s: Optional[float]
|
||||
future: asyncio.Future
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
backend: ToolBackend,
|
||||
tools: ToolRegistry,
|
||||
config: Optional[ToolExecutorConfig] = None,
|
||||
) -> None:
|
||||
self.backend = backend
|
||||
self.tools = tools
|
||||
self.config = config or ToolExecutorConfig()
|
||||
|
||||
self._queue: asyncio.Queue[Optional[_QueuedToolRequest]] = asyncio.Queue()
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._stopping = asyncio.Event()
|
||||
|
||||
self._slots_lock = asyncio.Lock()
|
||||
self._slot_by_trajectory: Dict[str, Slot] = {}
|
||||
|
||||
self._tool_server_client: Optional[httpx.AsyncClient] = None
|
||||
self._tool_server_lock = asyncio.Lock()
|
||||
|
||||
# lightweight stats for status endpoints
|
||||
self.total_requests: int = 0
|
||||
self.total_errors: int = 0
|
||||
self.latencies_s: List[float] = []
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._task is None:
|
||||
self._task = asyncio.create_task(self._run_loop())
|
||||
|
||||
def queue_size(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
async def close(self) -> None:
|
||||
self._stopping.set()
|
||||
await self._queue.put(None)
|
||||
if self._task:
|
||||
await self._task
|
||||
self._task = None
|
||||
|
||||
client = self._tool_server_client
|
||||
self._tool_server_client = None
|
||||
if client is not None:
|
||||
await client.aclose()
|
||||
|
||||
# Best-effort release any remaining slots.
|
||||
async with self._slots_lock:
|
||||
slots = list(self._slot_by_trajectory.items())
|
||||
self._slot_by_trajectory.clear()
|
||||
|
||||
for _, slot in slots:
|
||||
try:
|
||||
await self.backend.release(slot, reset_workspace=False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
trajectory_id: str,
|
||||
call: ToolCall,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> ToolResult:
|
||||
if self._task is None:
|
||||
raise RuntimeError("ToolExecutor not started (call start() first)")
|
||||
|
||||
# Allow tool args to suggest a timeout (Hermes-compatible terminal tool),
|
||||
# but never let the model choose "infinite" timeouts.
|
||||
if timeout_s is None:
|
||||
raw_timeout = call.arguments.get("timeout")
|
||||
if isinstance(raw_timeout, (int, float)):
|
||||
timeout_s = float(raw_timeout)
|
||||
if timeout_s is not None:
|
||||
timeout_s = max(1.0, min(float(timeout_s), 600.0))
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
fut: asyncio.Future = loop.create_future()
|
||||
started = time.perf_counter()
|
||||
await self._queue.put(_QueuedToolRequest(trajectory_id=trajectory_id, call=call, timeout_s=timeout_s, future=fut))
|
||||
try:
|
||||
result: ToolResult = await fut
|
||||
return result
|
||||
finally:
|
||||
self.latencies_s.append(time.perf_counter() - started)
|
||||
|
||||
async def release_trajectory(self, trajectory_id: str, reset_workspace: bool = False) -> None:
|
||||
async with self._slots_lock:
|
||||
slot = self._slot_by_trajectory.pop(trajectory_id, None)
|
||||
|
||||
if slot is not None:
|
||||
await self.backend.release(slot, reset_workspace=reset_workspace)
|
||||
|
||||
async def _get_slot_if_present(self, trajectory_id: str) -> Optional[Slot]:
|
||||
async with self._slots_lock:
|
||||
return self._slot_by_trajectory.get(trajectory_id)
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Artifact helpers (optional)
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
async def read_artifact(self, req: ArtifactReadRequestPayload) -> ArtifactReadResponsePayload:
|
||||
slot = await self._get_slot_if_present(req.trajectory_id)
|
||||
if slot is None:
|
||||
return ArtifactReadResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)")
|
||||
data = await self.backend.read_artifact(
|
||||
slot,
|
||||
req.path,
|
||||
encoding=req.encoding,
|
||||
max_bytes=req.max_bytes,
|
||||
include_sha256=req.include_sha256,
|
||||
)
|
||||
if isinstance(data, dict):
|
||||
data = dict(data)
|
||||
data.pop("http_status", None)
|
||||
try:
|
||||
return ArtifactReadResponsePayload(**(data or {}))
|
||||
except Exception as e:
|
||||
return ArtifactReadResponsePayload(success=False, error=f"Invalid artifact read response: {e}")
|
||||
|
||||
async def list_artifacts(self, req: ArtifactListRequestPayload) -> ArtifactListResponsePayload:
|
||||
slot = await self._get_slot_if_present(req.trajectory_id)
|
||||
if slot is None:
|
||||
return ArtifactListResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)")
|
||||
data = await self.backend.list_artifacts(
|
||||
slot,
|
||||
req.path,
|
||||
recursive=req.recursive,
|
||||
max_entries=req.max_entries,
|
||||
)
|
||||
if isinstance(data, dict):
|
||||
data = dict(data)
|
||||
data.pop("http_status", None)
|
||||
try:
|
||||
return ArtifactListResponsePayload(**(data or {}))
|
||||
except Exception as e:
|
||||
return ArtifactListResponsePayload(success=False, error=f"Invalid artifact list response: {e}")
|
||||
|
||||
async def archive_artifacts(self, req: ArtifactArchiveRequestPayload) -> ArtifactArchiveResponsePayload:
|
||||
slot = await self._get_slot_if_present(req.trajectory_id)
|
||||
if slot is None:
|
||||
return ArtifactArchiveResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)")
|
||||
data = await self.backend.archive_artifacts(
|
||||
slot,
|
||||
req.path,
|
||||
archive_format=req.format,
|
||||
max_bytes=req.max_bytes,
|
||||
max_entries=req.max_entries,
|
||||
)
|
||||
if isinstance(data, dict):
|
||||
data = dict(data)
|
||||
data.pop("http_status", None)
|
||||
try:
|
||||
return ArtifactArchiveResponsePayload(**(data or {}))
|
||||
except Exception as e:
|
||||
return ArtifactArchiveResponsePayload(success=False, error=f"Invalid artifact archive response: {e}")
|
||||
|
||||
async def _get_or_acquire_slot(self, trajectory_id: str) -> Slot:
|
||||
async with self._slots_lock:
|
||||
existing = self._slot_by_trajectory.get(trajectory_id)
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
slot = await self.backend.acquire(trajectory_id)
|
||||
|
||||
async with self._slots_lock:
|
||||
existing = self._slot_by_trajectory.get(trajectory_id)
|
||||
if existing is not None:
|
||||
# Another coroutine won the race; return its slot.
|
||||
await self.backend.release(slot, reset_workspace=False)
|
||||
return existing
|
||||
self._slot_by_trajectory[trajectory_id] = slot
|
||||
return slot
|
||||
|
||||
async def _run_loop(self) -> None:
|
||||
pending: List[_QueuedToolRequest] = []
|
||||
deadline: Optional[float] = None
|
||||
|
||||
batch_window_s = max(0.0, self.config.batch_window_ms / 1000.0)
|
||||
max_batch = max(1, self.config.max_batch_size)
|
||||
|
||||
while True:
|
||||
if self._stopping.is_set() and self._queue.empty() and not pending:
|
||||
break
|
||||
|
||||
timeout = None
|
||||
if pending and deadline is not None:
|
||||
timeout = max(0.0, deadline - time.perf_counter())
|
||||
|
||||
try:
|
||||
item = await asyncio.wait_for(self._queue.get(), timeout=timeout)
|
||||
if item is None:
|
||||
continue
|
||||
pending.append(item)
|
||||
if len(pending) == 1:
|
||||
deadline = time.perf_counter() + batch_window_s
|
||||
if len(pending) < max_batch:
|
||||
continue
|
||||
except asyncio.TimeoutError:
|
||||
# batch window elapsed
|
||||
pass
|
||||
|
||||
if not pending:
|
||||
deadline = None
|
||||
continue
|
||||
|
||||
batch = pending
|
||||
pending = []
|
||||
deadline = None
|
||||
|
||||
await self._execute_batch(batch)
|
||||
|
||||
async def _get_tool_server_client(self) -> httpx.AsyncClient:
|
||||
url = self.config.tool_server_url
|
||||
if not url:
|
||||
raise RuntimeError("ToolServer not configured")
|
||||
|
||||
if self._tool_server_client is not None:
|
||||
return self._tool_server_client
|
||||
|
||||
async with self._tool_server_lock:
|
||||
if self._tool_server_client is None:
|
||||
self._tool_server_client = httpx.AsyncClient(base_url=url.rstrip("/"))
|
||||
return self._tool_server_client
|
||||
|
||||
def _tool_server_headers(self) -> Dict[str, str]:
|
||||
token = self.config.tool_server_token
|
||||
if not token:
|
||||
return {}
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async def _execute_external(self, req: _QueuedToolRequest) -> ToolResult:
|
||||
client = await self._get_tool_server_client()
|
||||
slot_id: Optional[str] = None
|
||||
container_addr: Optional[str] = None
|
||||
slot = await self._get_slot_if_present(req.trajectory_id)
|
||||
if slot is not None:
|
||||
slot_id = slot.slot_id
|
||||
container_addr = slot.container_addr
|
||||
|
||||
payload = ToolServerExecuteRequest(
|
||||
trajectory_id=req.trajectory_id,
|
||||
tool=ToolCallPayload.from_tool_call(req.call),
|
||||
timeout_s=req.timeout_s,
|
||||
slot_id=slot_id,
|
||||
container_addr=container_addr,
|
||||
)
|
||||
|
||||
try:
|
||||
resp = await client.post(
|
||||
"/execute",
|
||||
json=payload.model_dump(),
|
||||
headers=self._tool_server_headers(),
|
||||
timeout=req.timeout_s,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
parsed = ToolResultPayload(**data)
|
||||
result = parsed.to_tool_result()
|
||||
if result.uniq_id is None:
|
||||
result.uniq_id = req.call.uniq_id
|
||||
return result
|
||||
except Exception as e:
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=f"External tool failed: {e}",
|
||||
uniq_id=req.call.uniq_id,
|
||||
)
|
||||
|
||||
async def _execute_batch(self, batch: List[_QueuedToolRequest]) -> None:
|
||||
# Resolve tool schemas once per request and separate sandbox/external/unknown.
|
||||
sandbox_items: List[_QueuedToolRequest] = []
|
||||
external_items: List[_QueuedToolRequest] = []
|
||||
unknown_items: List[_QueuedToolRequest] = []
|
||||
|
||||
for it in batch:
|
||||
tool = self.tools.get(it.call.name)
|
||||
if tool is None:
|
||||
unknown_items.append(it)
|
||||
continue
|
||||
|
||||
schema = tool.schema
|
||||
if not schema.external:
|
||||
sandbox_items.append(it)
|
||||
else:
|
||||
external_items.append(it)
|
||||
|
||||
for it in unknown_items:
|
||||
self.total_requests += 1
|
||||
self.total_errors += 1
|
||||
if not it.future.done():
|
||||
it.future.set_result(
|
||||
ToolResult(
|
||||
success=False,
|
||||
error=f"Unknown tool: {it.call.name}",
|
||||
uniq_id=it.call.uniq_id,
|
||||
)
|
||||
)
|
||||
|
||||
if external_items:
|
||||
if not self.config.tool_server_url:
|
||||
for it in external_items:
|
||||
self.total_requests += 1
|
||||
self.total_errors += 1
|
||||
if not it.future.done():
|
||||
it.future.set_result(
|
||||
ToolResult(
|
||||
success=False,
|
||||
error=f"External tool not available (ToolServer not configured): {it.call.name}",
|
||||
uniq_id=it.call.uniq_id,
|
||||
)
|
||||
)
|
||||
else:
|
||||
results = await asyncio.gather(*[self._execute_external(it) for it in external_items])
|
||||
for it, res in zip(external_items, results):
|
||||
self.total_requests += 1
|
||||
if not getattr(res, "success", False):
|
||||
self.total_errors += 1
|
||||
if not it.future.done():
|
||||
it.future.set_result(res)
|
||||
|
||||
if not sandbox_items:
|
||||
return
|
||||
|
||||
# Acquire slots for the distinct trajectories in this batch.
|
||||
try:
|
||||
traj_ids = list({it.trajectory_id for it in sandbox_items})
|
||||
slots = await asyncio.gather(*[self._get_or_acquire_slot(tid) for tid in traj_ids])
|
||||
slot_by_traj = dict(zip(traj_ids, slots))
|
||||
except Exception as e:
|
||||
for it in sandbox_items:
|
||||
self.total_requests += 1
|
||||
self.total_errors += 1
|
||||
if not it.future.done():
|
||||
it.future.set_result(
|
||||
ToolResult(
|
||||
success=False,
|
||||
error=f"Failed to acquire slot: {e}",
|
||||
uniq_id=it.call.uniq_id,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Group by timeout so we don't accidentally make short timeouts wait on long ones.
|
||||
by_timeout: Dict[float, List[_QueuedToolRequest]] = {}
|
||||
default_timeout = self.backend.default_timeout_s
|
||||
|
||||
for it in sandbox_items:
|
||||
t = it.timeout_s
|
||||
if t is None:
|
||||
t = default_timeout
|
||||
if t is None:
|
||||
t = 30.0
|
||||
by_timeout.setdefault(float(t), []).append(it)
|
||||
|
||||
for timeout_s, items in by_timeout.items():
|
||||
requests = []
|
||||
dispatched: List[_QueuedToolRequest] = []
|
||||
for it in items:
|
||||
slot = slot_by_traj[it.trajectory_id]
|
||||
tool_name = it.call.name
|
||||
args = dict(it.call.arguments)
|
||||
|
||||
# Hermes compatibility: treat `terminal` as an alias of sandbox `bash`.
|
||||
if tool_name == "terminal":
|
||||
if args.get("background"):
|
||||
self.total_requests += 1
|
||||
self.total_errors += 1
|
||||
if not it.future.done():
|
||||
it.future.set_result(
|
||||
ToolResult(
|
||||
success=False,
|
||||
error="terminal background execution is not supported in sandbox",
|
||||
uniq_id=it.call.uniq_id,
|
||||
)
|
||||
)
|
||||
continue
|
||||
tool_name = "bash"
|
||||
# `timeout` is handled at the ToolExecutor level, not passed to the sandbox tool args.
|
||||
args.pop("timeout", None)
|
||||
elif tool_name == "terminal_stateful":
|
||||
tool_name = "bash_stateful"
|
||||
args.pop("timeout", None)
|
||||
elif tool_name == "tmux":
|
||||
# `tmux` is a sandbox tool backed by the stateful session manager.
|
||||
# Network policy is env-controlled.
|
||||
args.pop("allow_network", None)
|
||||
|
||||
if tool_name == "bash":
|
||||
# Network policy is set by the environment/executor, not by the model.
|
||||
args.pop("allow_network", None)
|
||||
args.pop("require_sandbox", None)
|
||||
args["allow_network"] = bool(self.config.allow_network)
|
||||
args["require_sandbox"] = bool(self.config.require_sandbox)
|
||||
# `timeout` is handled at the ToolExecutor level, not passed to the sandbox tool args.
|
||||
args.pop("timeout", None)
|
||||
elif tool_name == "bash_stateful":
|
||||
# Network policy is set by the environment/executor, not by the model.
|
||||
args.pop("allow_network", None)
|
||||
args.pop("require_sandbox", None)
|
||||
args.pop("require_stateful_sandbox", None)
|
||||
args["allow_network"] = bool(self.config.allow_network)
|
||||
args["require_stateful_sandbox"] = bool(self.config.require_stateful_sandbox)
|
||||
args.pop("timeout", None)
|
||||
elif tool_name == "tmux":
|
||||
# Network policy applies to the underlying stateful session.
|
||||
args.pop("allow_network", None)
|
||||
args.pop("require_sandbox", None)
|
||||
args.pop("require_stateful_sandbox", None)
|
||||
args["allow_network"] = bool(self.config.allow_network)
|
||||
args["require_stateful_sandbox"] = bool(self.config.require_stateful_sandbox)
|
||||
|
||||
requests.append((slot, tool_name, args))
|
||||
dispatched.append(it)
|
||||
|
||||
results = None
|
||||
try:
|
||||
if not dispatched:
|
||||
continue
|
||||
results = await self.backend.execute_batch(requests, timeout_s=timeout_s)
|
||||
except Exception as e:
|
||||
for it in items:
|
||||
self.total_requests += 1
|
||||
self.total_errors += 1
|
||||
if not it.future.done():
|
||||
it.future.set_result(
|
||||
ToolResult(
|
||||
success=False,
|
||||
error=f"Batch execution failed: {e}",
|
||||
uniq_id=it.call.uniq_id,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
for it, res in zip(dispatched, results):
|
||||
self.total_requests += 1
|
||||
if not getattr(res, "success", False):
|
||||
self.total_errors += 1
|
||||
tool_result = res.to_tool_result()
|
||||
tool_result.uniq_id = it.call.uniq_id
|
||||
if not it.future.done():
|
||||
it.future.set_result(tool_result)
|
||||
@@ -0,0 +1,415 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Atropos-compatible Hermes agent runner.
|
||||
|
||||
This is a minimal subclass of Hermes-Agent's `AIAgent` that swaps the OpenAI
|
||||
function-calling backend for Atroposlib's `ManagedServer`/`ServerManager` backend
|
||||
and uses Hermes-style XML tool tags:
|
||||
|
||||
- <tool_call>{"name": "...", "arguments": {...}}</tool_call>
|
||||
- <tool_response>{...}</tool_response>
|
||||
|
||||
Tool observations are appended as `role="user"` messages containing one or more
|
||||
`<tool_response>` blocks so they survive common chat templates during tokenization.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import warnings
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
from model_tools import cleanup_vm, handle_function_call
|
||||
from run_agent import AIAgent
|
||||
|
||||
_TOOL_CALL_RE = re.compile(r"<tool_call>\\s*(.*?)\\s*</tool_call>", re.DOTALL)
|
||||
|
||||
|
||||
ATROPOS_TOOL_SYSTEM_PROMPT = """You are a helpful AI assistant with access to tools.
|
||||
|
||||
## Available Tools
|
||||
<tools>
|
||||
{tool_descriptions}
|
||||
</tools>
|
||||
|
||||
## How to Use Tools
|
||||
To call a tool, output:
|
||||
<tool_call>{{"name": "tool_name", "arguments": {{"arg1": "value1"}}}}</tool_call>
|
||||
|
||||
You may include optional reasoning in <think>...</think> before tool calls.
|
||||
|
||||
After each tool call, you will receive tool results as:
|
||||
<tool_response>{{...}}</tool_response>
|
||||
|
||||
Continue until finished, then provide a final response with no <tool_call> blocks.
|
||||
"""
|
||||
|
||||
|
||||
class AtroposAIAgent(AIAgent):
|
||||
"""
|
||||
Hermes `AIAgent` variant that uses Atroposlib ServerManager/ManagedServer.
|
||||
|
||||
Notes:
|
||||
- The default Hermes `AIAgent` remains unchanged; this class is opt-in.
|
||||
- The underlying server must expose `managed_server(tokenizer=...)` OR be a single
|
||||
APIServer-compatible object usable by Atroposlib's `ManagedServer`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
server: Any,
|
||||
tokenizer: Any = None,
|
||||
model: str = "local",
|
||||
max_iterations: int = 10,
|
||||
tool_delay: float = 0.0,
|
||||
enabled_toolsets: Optional[List[str]] = None,
|
||||
disabled_toolsets: Optional[List[str]] = None,
|
||||
save_trajectories: bool = False,
|
||||
verbose_logging: bool = False,
|
||||
quiet_mode: bool = False,
|
||||
ephemeral_system_prompt: Optional[str] = None,
|
||||
log_prefix_chars: int = 100,
|
||||
log_prefix: str = "",
|
||||
session_id: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
# Call parent init mainly to reuse tool selection + trajectory saving utilities.
|
||||
super().__init__(
|
||||
base_url="http://unused",
|
||||
api_key="dummy-key",
|
||||
model=model,
|
||||
max_iterations=max_iterations,
|
||||
tool_delay=tool_delay,
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
disabled_toolsets=disabled_toolsets,
|
||||
save_trajectories=save_trajectories,
|
||||
verbose_logging=verbose_logging,
|
||||
quiet_mode=quiet_mode,
|
||||
ephemeral_system_prompt=ephemeral_system_prompt,
|
||||
log_prefix_chars=log_prefix_chars,
|
||||
log_prefix=log_prefix,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
self.server = server
|
||||
self.tokenizer = tokenizer
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
@asynccontextmanager
|
||||
async def _managed(self) -> AsyncGenerator[Any, None]:
|
||||
if hasattr(self.server, "managed_server"):
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r"Using OpenAIServer with managed_server does not allow for state tracking",
|
||||
category=UserWarning,
|
||||
)
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
yield managed
|
||||
return
|
||||
|
||||
# Fall back to directly wrapping a single server object.
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
|
||||
managed = ManagedServer(server=self.server, tokenizer=self.tokenizer)
|
||||
try:
|
||||
yield managed
|
||||
finally:
|
||||
managed.reset()
|
||||
|
||||
def _tool_descriptions_text(self) -> str:
|
||||
if not self.tools:
|
||||
return "(no tools available)"
|
||||
|
||||
parts: List[str] = []
|
||||
for tool in self.tools:
|
||||
fn = (tool or {}).get("function", {})
|
||||
name = fn.get("name", "")
|
||||
desc = (fn.get("description") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
if desc:
|
||||
parts.append(f"- {name}: {desc}")
|
||||
else:
|
||||
parts.append(f"- {name}")
|
||||
return "\n".join(parts) if parts else "(no tools available)"
|
||||
|
||||
def _build_system_prompt(self, system_message: Optional[str]) -> Optional[str]:
|
||||
tool_prompt = ATROPOS_TOOL_SYSTEM_PROMPT.format(
|
||||
tool_descriptions=self._tool_descriptions_text()
|
||||
)
|
||||
|
||||
parts: List[str] = []
|
||||
if system_message:
|
||||
parts.append(system_message)
|
||||
if self.ephemeral_system_prompt:
|
||||
parts.append(self.ephemeral_system_prompt)
|
||||
parts.append(tool_prompt)
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _parse_tool_calls(self, content: str) -> Tuple[List[Tuple[str, Dict[str, Any]]], List[str]]:
|
||||
"""
|
||||
Returns:
|
||||
(calls, errors)
|
||||
"""
|
||||
calls: List[Tuple[str, Dict[str, Any]]] = []
|
||||
errors: List[str] = []
|
||||
|
||||
for raw in _TOOL_CALL_RE.findall(content or ""):
|
||||
try:
|
||||
payload = json.loads(raw)
|
||||
except json.JSONDecodeError as exc:
|
||||
errors.append(f"Invalid JSON inside <tool_call>: {exc}")
|
||||
continue
|
||||
|
||||
name = payload.get("name")
|
||||
args = payload.get("arguments", {})
|
||||
if not isinstance(name, str) or not name:
|
||||
errors.append("Tool call missing 'name' string")
|
||||
continue
|
||||
if not isinstance(args, dict):
|
||||
errors.append("Tool call 'arguments' must be an object")
|
||||
continue
|
||||
|
||||
calls.append((name, args))
|
||||
|
||||
return calls, errors
|
||||
|
||||
async def run_conversation_async(
|
||||
self,
|
||||
user_message: str,
|
||||
system_message: Optional[str] = None,
|
||||
conversation_history: Optional[List[Dict[str, Any]]] = None,
|
||||
task_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
import uuid
|
||||
|
||||
effective_task_id = task_id or str(uuid.uuid4())
|
||||
|
||||
messages: List[Dict[str, Any]] = conversation_history.copy() if conversation_history else []
|
||||
messages.append({"role": "user", "content": user_message})
|
||||
|
||||
active_system_prompt = self._build_system_prompt(system_message)
|
||||
|
||||
api_call_count = 0
|
||||
final_response: Optional[str] = None
|
||||
managed_state: Optional[Dict[str, Any]] = None
|
||||
completed = False
|
||||
|
||||
try:
|
||||
async with self._managed() as managed:
|
||||
while api_call_count < self.max_iterations:
|
||||
api_call_count += 1
|
||||
|
||||
api_messages = messages.copy()
|
||||
if active_system_prompt:
|
||||
api_messages = [{"role": "system", "content": active_system_prompt}] + api_messages
|
||||
|
||||
chat_kwargs: Dict[str, Any] = {"messages": api_messages, "n": 1}
|
||||
if self.max_tokens is not None:
|
||||
chat_kwargs["max_tokens"] = self.max_tokens
|
||||
if self.temperature is not None:
|
||||
chat_kwargs["temperature"] = self.temperature
|
||||
|
||||
# Prefer OpenAI tool calling when supported by the backend:
|
||||
# - Many providers normalize Hermes-style <tool_call> tags into tool_calls when `tools` is provided.
|
||||
# - ManagedServer (atroposlib) does prompt->completion conversion and does not support `tools`.
|
||||
# Only pass `tools` when we're calling an OpenAI-compatible chat endpoint directly.
|
||||
tool_schemas = self.tools if self.tools else None
|
||||
managed_cls = type(managed).__name__
|
||||
if tool_schemas and managed_cls != "ManagedServer":
|
||||
chat_kwargs["tools"] = tool_schemas
|
||||
|
||||
if os.getenv("HERMES_DEBUG_ATROPOS_REQUEST") == "1":
|
||||
meta = {
|
||||
"managed_type": managed_cls,
|
||||
"model": getattr(getattr(managed, "config", None), "model_name", self.model),
|
||||
"base_url": getattr(getattr(managed, "config", None), "base_url", None),
|
||||
"kwargs": chat_kwargs,
|
||||
}
|
||||
# Avoid dumping megabytes of data accidentally.
|
||||
# (Messages can be large; this is still "full" but bounded.)
|
||||
print("\n=== HERMES_DEBUG_ATROPOS_REQUEST ===", flush=True)
|
||||
print(json.dumps(meta, ensure_ascii=False, indent=2)[:200_000], flush=True)
|
||||
|
||||
response = await managed.chat_completion(**chat_kwargs)
|
||||
|
||||
if os.getenv("HERMES_DEBUG_ATROPOS_RESPONSE") == "1":
|
||||
try:
|
||||
dumped = response.model_dump() # openai pydantic model
|
||||
except Exception:
|
||||
dumped = getattr(response, "__dict__", {"repr": repr(response)})
|
||||
print("\n=== HERMES_DEBUG_ATROPOS_RESPONSE: ChatCompletion (raw) ===", flush=True)
|
||||
print(json.dumps(dumped, ensure_ascii=False, indent=2), flush=True)
|
||||
|
||||
if hasattr(managed, "get_state"):
|
||||
managed_state = managed.get_state()
|
||||
|
||||
msg = response.choices[0].message
|
||||
assistant_content = (msg.content or "")
|
||||
msg_reasoning = getattr(msg, "reasoning", None)
|
||||
|
||||
# Use tool_calls if the backend provides them (preferred).
|
||||
structured_tool_calls = getattr(msg, "tool_calls", None)
|
||||
|
||||
# If the backend emits content="" but includes useful text in reasoning,
|
||||
# use it for parsing *only if needed* (e.g. tool tags).
|
||||
if assistant_content == "" and isinstance(msg_reasoning, str) and msg_reasoning:
|
||||
if os.getenv("HERMES_DEBUG_ATROPOS_RESPONSE") == "1":
|
||||
print("\n=== HERMES_DEBUG_ATROPOS_RESPONSE: message.reasoning present (content empty) ===", flush=True)
|
||||
print(msg_reasoning, flush=True)
|
||||
|
||||
assistant_msg: Dict[str, Any] = {"role": "assistant", "content": assistant_content}
|
||||
if structured_tool_calls:
|
||||
# Preserve tool_calls so the next request is consistent with OpenAI protocol.
|
||||
try:
|
||||
assistant_msg["tool_calls"] = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": tc.type,
|
||||
"function": {"name": tc.function.name, "arguments": tc.function.arguments},
|
||||
}
|
||||
for tc in structured_tool_calls
|
||||
]
|
||||
except Exception:
|
||||
# Best-effort; keep conversation moving.
|
||||
pass
|
||||
messages.append(assistant_msg)
|
||||
|
||||
# Mode A: OpenAI tool calling (preferred when supported)
|
||||
if structured_tool_calls:
|
||||
for tc in structured_tool_calls:
|
||||
tool_start = time.time()
|
||||
try:
|
||||
tool_args = json.loads(tc.function.arguments or "{}")
|
||||
except Exception:
|
||||
tool_args = {}
|
||||
tool_result = handle_function_call(tc.function.name, tool_args, effective_task_id)
|
||||
tool_duration = time.time() - tool_start
|
||||
|
||||
# Keep the raw tool result as tool content (OpenAI protocol expects role=tool).
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tc.id,
|
||||
"content": tool_result,
|
||||
}
|
||||
)
|
||||
|
||||
if self.tool_delay and self.tool_delay > 0:
|
||||
await asyncio.sleep(self.tool_delay)
|
||||
|
||||
# Continue loop after tool execution.
|
||||
continue
|
||||
|
||||
# Mode B: Hermes XML tool tags in assistant text (fallback).
|
||||
parse_source = assistant_content or (msg_reasoning or "")
|
||||
tool_calls, parse_errors = self._parse_tool_calls(parse_source)
|
||||
|
||||
if parse_errors and not tool_calls:
|
||||
# Ask the model to retry with valid tool JSON.
|
||||
err_text = "; ".join(parse_errors[:3])
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"<tool_response>{json.dumps({'error': err_text}, ensure_ascii=False)}</tool_response>\n"
|
||||
"The previous <tool_call> blocks were invalid. Please output valid JSON inside <tool_call>."
|
||||
),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if not tool_calls:
|
||||
# No tool calls: treat as final answer.
|
||||
final_response = (assistant_content or "").strip()
|
||||
completed = True
|
||||
break
|
||||
|
||||
tool_responses: List[str] = []
|
||||
for tool_name, tool_args in tool_calls:
|
||||
tool_start = time.time()
|
||||
tool_result = handle_function_call(tool_name, tool_args, effective_task_id)
|
||||
tool_duration = time.time() - tool_start
|
||||
|
||||
try:
|
||||
parsed = json.loads(tool_result)
|
||||
payload: Any = parsed
|
||||
except Exception:
|
||||
payload = tool_result
|
||||
|
||||
tool_payload = {
|
||||
"name": tool_name,
|
||||
"duration_s": round(tool_duration, 3),
|
||||
"result": payload,
|
||||
}
|
||||
tool_responses.append(
|
||||
f"<tool_response>{json.dumps(tool_payload, ensure_ascii=False)}</tool_response>"
|
||||
)
|
||||
|
||||
if self.tool_delay and self.tool_delay > 0:
|
||||
await asyncio.sleep(self.tool_delay)
|
||||
|
||||
messages.append({"role": "user", "content": "\n".join(tool_responses)})
|
||||
|
||||
if final_response is None:
|
||||
final_response = "I've reached the maximum number of iterations."
|
||||
|
||||
finally:
|
||||
try:
|
||||
cleanup_vm(effective_task_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Save trajectory using Hermes formatting (optional).
|
||||
self._save_trajectory(messages, user_message, completed=completed)
|
||||
|
||||
return {
|
||||
"final_response": final_response,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": completed,
|
||||
"managed_state": managed_state,
|
||||
"system_prompt": active_system_prompt,
|
||||
"task_id": effective_task_id,
|
||||
}
|
||||
|
||||
def run_conversation(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
Sync wrapper for convenience.
|
||||
|
||||
If called from within a running event loop (e.g. prompt_toolkit), this
|
||||
runs the async conversation in a dedicated thread to avoid nested loops.
|
||||
"""
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(self.run_conversation_async(*args, **kwargs))
|
||||
|
||||
import queue
|
||||
import threading
|
||||
|
||||
out: "queue.Queue[object]" = queue.Queue(maxsize=1)
|
||||
|
||||
def runner() -> None:
|
||||
try:
|
||||
out.put(asyncio.run(self.run_conversation_async(*args, **kwargs)))
|
||||
except BaseException as exc: # noqa: BLE001
|
||||
out.put(exc)
|
||||
|
||||
thread = threading.Thread(target=runner, daemon=True)
|
||||
thread.start()
|
||||
|
||||
result = out.get()
|
||||
if isinstance(result, BaseException):
|
||||
raise result
|
||||
return result # type: ignore[return-value]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -28,13 +28,18 @@ os.environ["HERMES_QUIET"] = "1" # Our own modules
|
||||
import yaml
|
||||
|
||||
# prompt_toolkit for fixed input area TUI
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.history import FileHistory
|
||||
from prompt_toolkit.styles import Style as PTStyle
|
||||
from prompt_toolkit.formatted_text import HTML
|
||||
from prompt_toolkit.patch_stdout import patch_stdout
|
||||
from prompt_toolkit.application import Application
|
||||
from prompt_toolkit.application import Application, get_app
|
||||
from prompt_toolkit.buffer import Buffer
|
||||
from prompt_toolkit.layout import Layout, HSplit, Window, FormattedTextControl
|
||||
from prompt_toolkit.layout.processors import BeforeInput
|
||||
from prompt_toolkit.widgets import TextArea
|
||||
from prompt_toolkit.key_binding import KeyBindings
|
||||
import asyncio
|
||||
import threading
|
||||
import queue
|
||||
|
||||
@@ -493,8 +498,6 @@ COMMANDS = {
|
||||
"/clear": "Clear screen and reset conversation (fresh start)",
|
||||
"/history": "Show conversation history",
|
||||
"/reset": "Reset conversation only (keep screen)",
|
||||
"/retry": "Retry the last message (resend to agent)",
|
||||
"/undo": "Remove the last user/assistant exchange",
|
||||
"/save": "Save the current conversation",
|
||||
"/config": "Show current configuration",
|
||||
"/cron": "Manage scheduled tasks (list, add, remove)",
|
||||
@@ -505,11 +508,7 @@ COMMANDS = {
|
||||
|
||||
def save_config_value(key_path: str, value: any) -> bool:
|
||||
"""
|
||||
Save a value to the active config file at the specified key path.
|
||||
|
||||
Respects the same lookup order as load_cli_config():
|
||||
1. ~/.hermes/config.yaml (user config - preferred, used if it exists)
|
||||
2. ./cli-config.yaml (project config - fallback)
|
||||
Save a value to cli-config.yaml at the specified key path.
|
||||
|
||||
Args:
|
||||
key_path: Dot-separated path like "agent.system_prompt"
|
||||
@@ -518,15 +517,9 @@ def save_config_value(key_path: str, value: any) -> bool:
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
# Use the same precedence as load_cli_config: user config first, then project config
|
||||
user_config_path = Path.home() / '.hermes' / 'config.yaml'
|
||||
project_config_path = Path(__file__).parent / 'cli-config.yaml'
|
||||
config_path = user_config_path if user_config_path.exists() else project_config_path
|
||||
config_path = Path(__file__).parent / 'cli-config.yaml'
|
||||
|
||||
try:
|
||||
# Ensure parent directory exists (for ~/.hermes/config.yaml on first use)
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load existing config
|
||||
if config_path.exists():
|
||||
with open(config_path, 'r') as f:
|
||||
@@ -638,8 +631,26 @@ class HermesCLI:
|
||||
short_uuid = uuid.uuid4().hex[:6]
|
||||
self.session_id = f"{timestamp_str}_{short_uuid}"
|
||||
|
||||
# History file for persistent input recall across sessions
|
||||
self._history_file = Path.home() / ".hermes_history"
|
||||
# Setup prompt_toolkit session with history
|
||||
self._setup_prompt_session()
|
||||
|
||||
def _setup_prompt_session(self):
|
||||
"""Setup prompt_toolkit session with history and styling."""
|
||||
history_file = Path.home() / ".hermes_history"
|
||||
|
||||
# Custom style for the prompt
|
||||
self.prompt_style = PTStyle.from_dict({
|
||||
'prompt': '#FFD700 bold',
|
||||
'input': '#FFF8DC',
|
||||
})
|
||||
|
||||
# Create prompt session with file history
|
||||
# Note: multiline disabled - Enter submits, use \ at end of line for continuation
|
||||
self.prompt_session = PromptSession(
|
||||
history=FileHistory(str(history_file)),
|
||||
style=self.prompt_style,
|
||||
enable_history_search=True,
|
||||
)
|
||||
|
||||
def _init_agent(self) -> bool:
|
||||
"""
|
||||
@@ -662,7 +673,6 @@ class HermesCLI:
|
||||
quiet_mode=True, # Suppress verbose output for clean CLI
|
||||
ephemeral_system_prompt=self.system_prompt if self.system_prompt else None,
|
||||
session_id=self.session_id, # Pass CLI's session ID to agent
|
||||
platform="cli", # CLI interface — agent uses terminal-friendly formatting
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -921,67 +931,6 @@ class HermesCLI:
|
||||
except Exception as e:
|
||||
print(f"(x_x) Failed to save: {e}")
|
||||
|
||||
def retry_last(self):
|
||||
"""Retry the last user message by removing the last exchange and re-sending.
|
||||
|
||||
Removes the last assistant response (and any tool-call messages) and
|
||||
the last user message, then re-sends that user message to the agent.
|
||||
Returns the message to re-send, or None if there's nothing to retry.
|
||||
"""
|
||||
if not self.conversation_history:
|
||||
print("(._.) No messages to retry.")
|
||||
return None
|
||||
|
||||
# Walk backwards to find the last user message
|
||||
last_user_idx = None
|
||||
for i in range(len(self.conversation_history) - 1, -1, -1):
|
||||
if self.conversation_history[i].get("role") == "user":
|
||||
last_user_idx = i
|
||||
break
|
||||
|
||||
if last_user_idx is None:
|
||||
print("(._.) No user message found to retry.")
|
||||
return None
|
||||
|
||||
# Extract the message text and remove everything from that point forward
|
||||
last_message = self.conversation_history[last_user_idx].get("content", "")
|
||||
self.conversation_history = self.conversation_history[:last_user_idx]
|
||||
|
||||
print(f"(^_^)b Retrying: \"{last_message[:60]}{'...' if len(last_message) > 60 else ''}\"")
|
||||
return last_message
|
||||
|
||||
def undo_last(self):
|
||||
"""Remove the last user/assistant exchange from conversation history.
|
||||
|
||||
Walks backwards and removes all messages from the last user message
|
||||
onward (including assistant responses, tool calls, etc.).
|
||||
"""
|
||||
if not self.conversation_history:
|
||||
print("(._.) No messages to undo.")
|
||||
return
|
||||
|
||||
# Walk backwards to find the last user message
|
||||
last_user_idx = None
|
||||
for i in range(len(self.conversation_history) - 1, -1, -1):
|
||||
if self.conversation_history[i].get("role") == "user":
|
||||
last_user_idx = i
|
||||
break
|
||||
|
||||
if last_user_idx is None:
|
||||
print("(._.) No user message found to undo.")
|
||||
return
|
||||
|
||||
# Count how many messages we're removing
|
||||
removed_count = len(self.conversation_history) - last_user_idx
|
||||
removed_msg = self.conversation_history[last_user_idx].get("content", "")
|
||||
|
||||
# Truncate history to before the last user message
|
||||
self.conversation_history = self.conversation_history[:last_user_idx]
|
||||
|
||||
print(f"(^_^)b Undid {removed_count} message(s). Removed: \"{removed_msg[:60]}{'...' if len(removed_msg) > 60 else ''}\"")
|
||||
remaining = len(self.conversation_history)
|
||||
print(f" {remaining} message(s) remaining in history.")
|
||||
|
||||
def _handle_prompt_command(self, cmd: str):
|
||||
"""Handle the /prompt command to view or set system prompt."""
|
||||
parts = cmd.split(maxsplit=1)
|
||||
@@ -1319,13 +1268,6 @@ class HermesCLI:
|
||||
elif cmd_lower.startswith("/personality"):
|
||||
# Use original case (handler lowercases the personality name itself)
|
||||
self._handle_personality_command(cmd_original)
|
||||
elif cmd_lower == "/retry":
|
||||
retry_msg = self.retry_last()
|
||||
if retry_msg and hasattr(self, '_pending_input'):
|
||||
# Re-queue the message so process_loop sends it to the agent
|
||||
self._pending_input.put(retry_msg)
|
||||
elif cmd_lower == "/undo":
|
||||
self.undo_last()
|
||||
elif cmd_lower == "/save":
|
||||
self.save_conversation()
|
||||
elif cmd_lower.startswith("/cron"):
|
||||
@@ -1360,9 +1302,8 @@ class HermesCLI:
|
||||
# Add user message to history
|
||||
self.conversation_history.append({"role": "user", "content": message})
|
||||
|
||||
# Visual separator after user input (adapt to terminal width, capped for readability)
|
||||
term_width = min(self.console.width, 120)
|
||||
print("─" * term_width, flush=True)
|
||||
# Visual separator after user input
|
||||
print("─" * 60, flush=True)
|
||||
|
||||
try:
|
||||
# Run the conversation with interrupt monitoring
|
||||
@@ -1420,20 +1361,14 @@ class HermesCLI:
|
||||
|
||||
if response:
|
||||
# Use simple print for compatibility with prompt_toolkit's patch_stdout
|
||||
# Adapt box width to terminal (cap at 120 for readability)
|
||||
box_width = min(self.console.width, 120)
|
||||
inner = box_width - 2 # account for border chars ╭/╰ and ╮/╯
|
||||
label = "⚕ Hermes"
|
||||
padding = inner - len(label) - 1 # -1 for the leading space
|
||||
|
||||
print()
|
||||
print("╭" + "─" * inner + "╮")
|
||||
print("│ " + label + " " * max(padding, 0) + "│")
|
||||
print("╰" + "─" * inner + "╯")
|
||||
print("╭" + "─" * 58 + "╮")
|
||||
print("│ ⚕ Hermes" + " " * 49 + "│")
|
||||
print("╰" + "─" * 58 + "╯")
|
||||
print()
|
||||
print(response)
|
||||
print()
|
||||
print("─" * box_width)
|
||||
print("─" * 60)
|
||||
|
||||
# If we have a pending message from interrupt, re-queue it for process_loop
|
||||
# instead of recursing (avoids unbounded recursion from rapid interrupts)
|
||||
@@ -1447,6 +1382,37 @@ class HermesCLI:
|
||||
print(f"Error: {e}")
|
||||
return None
|
||||
|
||||
def get_input(self) -> Optional[str]:
|
||||
"""
|
||||
Get user input using prompt_toolkit.
|
||||
|
||||
Enter submits. For multiline, end line with \\ to continue.
|
||||
|
||||
Returns:
|
||||
The user's input, or None if EOF/interrupt
|
||||
"""
|
||||
try:
|
||||
# Get first line
|
||||
line = self.prompt_session.prompt(
|
||||
HTML('<prompt>❯ </prompt>'),
|
||||
style=self.prompt_style,
|
||||
)
|
||||
|
||||
# Handle multi-line input (lines ending with \)
|
||||
lines = [line]
|
||||
while line.endswith("\\"):
|
||||
lines[-1] = line[:-1] # Remove trailing backslash
|
||||
line = self.prompt_session.prompt(
|
||||
HTML('<prompt> </prompt>'), # Continuation prompt
|
||||
style=self.prompt_style,
|
||||
)
|
||||
lines.append(line)
|
||||
|
||||
return "\n".join(lines).strip()
|
||||
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
return None
|
||||
|
||||
def run(self):
|
||||
"""Run the interactive CLI loop with persistent input at bottom."""
|
||||
self.show_banner()
|
||||
@@ -1460,6 +1426,9 @@ class HermesCLI:
|
||||
self._should_exit = False
|
||||
self._last_ctrl_c_time = 0 # Track double Ctrl+C for force exit
|
||||
|
||||
# Create a persistent input area using prompt_toolkit Application
|
||||
input_buffer = Buffer()
|
||||
|
||||
# Key bindings for the input area
|
||||
kb = KeyBindings()
|
||||
|
||||
@@ -1517,14 +1486,13 @@ class HermesCLI:
|
||||
self._should_exit = True
|
||||
event.app.exit()
|
||||
|
||||
# Create the input area widget with persistent history across sessions
|
||||
# Create the input area widget
|
||||
input_area = TextArea(
|
||||
height=1,
|
||||
prompt='❯ ',
|
||||
style='class:input-area',
|
||||
multiline=False,
|
||||
wrap_lines=False,
|
||||
history=FileHistory(str(self._history_file)),
|
||||
)
|
||||
|
||||
# Create a status line that shows when agent is working
|
||||
@@ -1577,7 +1545,6 @@ class HermesCLI:
|
||||
|
||||
# Check for commands
|
||||
if user_input.startswith("/"):
|
||||
print(f"\n⚙️ {user_input}")
|
||||
if not self.process_command(user_input):
|
||||
self._should_exit = True
|
||||
# Schedule app exit
|
||||
@@ -1589,9 +1556,6 @@ class HermesCLI:
|
||||
self._agent_running = True
|
||||
app.invalidate() # Refresh status line
|
||||
|
||||
# Echo the user's input so it stays visible in scrollback
|
||||
print(f"\n💬 You: {user_input}")
|
||||
|
||||
try:
|
||||
self.chat(user_input)
|
||||
finally:
|
||||
|
||||
@@ -0,0 +1,224 @@
|
||||
# Modal Backend
|
||||
|
||||
Hermes Agent uses [Modal](https://modal.com) for scalable, isolated cloud execution environments. There are two Modal integrations:
|
||||
|
||||
1. **Terminal Tool** (`tools/terminal_tool.py`) - For CLI/agent command execution
|
||||
2. **Atropos Backend** (`atropos/backends/modal_backend.py`) - For batch RL training workloads
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
## Terminal Tool (CLI/Agent)
|
||||
|
||||
The terminal tool provides a simple interface for executing commands in Modal sandboxes.
|
||||
|
||||
### Configuration
|
||||
|
||||
Set environment variables:
|
||||
|
||||
```bash
|
||||
export TERMINAL_ENV=modal
|
||||
export TERMINAL_MODAL_IMAGE=python:3.11
|
||||
export TERMINAL_MODAL_APP_NAME=hermes-sandbox
|
||||
```
|
||||
|
||||
Or use a YAML config file (`modal_profiles.yaml`):
|
||||
|
||||
```yaml
|
||||
profiles:
|
||||
default:
|
||||
image: python:3.11
|
||||
cpu: 1.0
|
||||
memory: 2048
|
||||
min_pool: 1
|
||||
max_pool: 5
|
||||
idle_timeout: 120
|
||||
|
||||
gpu:
|
||||
image: pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
|
||||
gpu: T4
|
||||
memory: 16384
|
||||
min_pool: 0
|
||||
max_pool: 2
|
||||
```
|
||||
|
||||
### Features
|
||||
|
||||
| Feature | Description |
|
||||
|---------|-------------|
|
||||
| **Sandbox Pool** | Pre-warmed sandboxes for low latency |
|
||||
| **Auto-scaling** | Grows/shrinks pool based on demand |
|
||||
| **Idle Timeout** | Sandboxes auto-terminate when unused |
|
||||
| **Profile Selection** | Different configs for different workloads |
|
||||
| **Credential Injection** | `modal.Secret` integration |
|
||||
|
||||
### Usage
|
||||
|
||||
```python
|
||||
from tools.terminal_tool import terminal_tool
|
||||
|
||||
# Simple command
|
||||
output = terminal_tool("echo hello", task_id="my-task")
|
||||
|
||||
# With profile selection
|
||||
output = terminal_tool("python train.py", task_id="training", profile="gpu")
|
||||
|
||||
# Cleanup when done
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
cleanup_vm("my-task")
|
||||
```
|
||||
|
||||
### Architecture
|
||||
|
||||
```
|
||||
_ModalPoolManager (singleton)
|
||||
├── "default" pool → [sandbox-0, sandbox-1, ...]
|
||||
└── "gpu" pool → [sandbox-0, ...]
|
||||
|
||||
Each pool:
|
||||
- Maintains min_pool warm sandboxes
|
||||
- Scales up to max_pool on demand
|
||||
- Background thread scales down idle sandboxes
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Atropos Backend (RL Training)
|
||||
|
||||
The Atropos backend is designed for high-throughput batch execution during reinforcement learning training.
|
||||
|
||||
### Key Concept: Slot-based Multiplexing
|
||||
|
||||
Instead of one sandbox per trajectory, multiple trajectories share sandboxes via **slots**:
|
||||
|
||||
```
|
||||
Sandbox (1 container)
|
||||
├── Slot 0 → Trajectory A (workspace: /data/slot_0)
|
||||
├── Slot 1 → Trajectory B (workspace: /data/slot_1)
|
||||
└── Slot 2 → Trajectory C (workspace: /data/slot_2)
|
||||
```
|
||||
|
||||
**Benefits**:
|
||||
- Fewer containers = lower cost
|
||||
- Shared warm-up time
|
||||
- Better GPU utilization
|
||||
|
||||
### Configuration
|
||||
|
||||
```python
|
||||
from atropos.backends.modal_backend import ModalSandboxConfig, ModalToolBackend
|
||||
|
||||
config = ModalSandboxConfig(
|
||||
name="default",
|
||||
image="python:3.11",
|
||||
cpu=1.0,
|
||||
memory=2048,
|
||||
slots_per_sandbox=10, # 10 trajectories per container
|
||||
min_sandboxes=1,
|
||||
max_sandboxes=5,
|
||||
)
|
||||
|
||||
backend = ModalToolBackend(config.with_app_name("my-training"))
|
||||
```
|
||||
|
||||
### Multi-Profile Support
|
||||
|
||||
Different trajectory types can request different resources:
|
||||
|
||||
```python
|
||||
backend = ModalToolBackend.with_profiles(
|
||||
app_name="rl-training",
|
||||
profiles={
|
||||
"default": ModalSandboxConfig(
|
||||
name="default",
|
||||
cpu=1.0,
|
||||
memory=2048,
|
||||
),
|
||||
"pytorch-gpu": ModalSandboxConfig(
|
||||
name="pytorch-gpu",
|
||||
image="pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime",
|
||||
gpu="T4",
|
||||
memory=16384,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# CPU task
|
||||
slot1 = await backend.acquire("traj-1", profile="default")
|
||||
|
||||
# GPU task
|
||||
slot2 = await backend.acquire("traj-2", profile="pytorch-gpu")
|
||||
```
|
||||
|
||||
### Batched Execution
|
||||
|
||||
The key optimization - execute many commands in parallel:
|
||||
|
||||
```python
|
||||
# Acquire slots for multiple trajectories
|
||||
slots = [await backend.acquire(f"traj-{i}") for i in range(50)]
|
||||
|
||||
# Execute batch across all slots in parallel
|
||||
results = await backend.execute_batch([
|
||||
(slot, "bash", {"command": "python step.py"})
|
||||
for slot in slots
|
||||
])
|
||||
|
||||
# Release slots
|
||||
for slot in slots:
|
||||
await backend.release(slot)
|
||||
```
|
||||
|
||||
### Architecture
|
||||
|
||||
```
|
||||
ModalToolBackend
|
||||
└── _ModalMultiProfileManager
|
||||
├── "default" → _ModalSandboxPool
|
||||
│ ├── Sandbox 0 (slots 0-9)
|
||||
│ └── Sandbox 1 (slots 0-9)
|
||||
│
|
||||
└── "pytorch-gpu" → _ModalSandboxPool
|
||||
└── Sandbox 0 (slots 0-9)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Credentials
|
||||
|
||||
Inject secrets securely using Modal's secret management:
|
||||
|
||||
```bash
|
||||
# Create secret in Modal dashboard or CLI
|
||||
modal secret create my-api-key API_KEY=sk-xxx
|
||||
```
|
||||
|
||||
```python
|
||||
# Reference in config
|
||||
config = ModalSandboxConfig(
|
||||
secrets=["my-api-key"], # Modal secret names
|
||||
env_vars={"DEBUG": "1"}, # Additional env vars
|
||||
)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Modal package not installed"
|
||||
```bash
|
||||
pip install modal
|
||||
modal token new # Authenticate
|
||||
```
|
||||
|
||||
### "Sandbox creation failed"
|
||||
- Check Modal dashboard for quota limits
|
||||
- Verify image exists and is accessible
|
||||
- Check secret names are correct
|
||||
|
||||
### Shutdown errors
|
||||
These are harmless warnings during Python interpreter shutdown:
|
||||
```
|
||||
[Modal] Error terminating ...: cannot schedule new futures after interpreter shutdown
|
||||
```
|
||||
|
||||
The sandboxes will auto-terminate via Modal's idle_timeout anyway.
|
||||
@@ -1,330 +0,0 @@
|
||||
# Hermes-Agent Atropos Environments
|
||||
|
||||
This directory contains the integration layer between **hermes-agent's** tool-calling capabilities and the **Atropos** RL training framework. It provides everything needed to run agentic LLMs through multi-turn tool-calling loops, score their output with arbitrary reward functions, and feed results into Atropos for training or evaluation.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
Atropos Framework
|
||||
┌───────────────────────┐
|
||||
│ BaseEnv │ (atroposlib)
|
||||
│ - Server management │
|
||||
│ - Worker scheduling │
|
||||
│ - Wandb logging │
|
||||
│ - CLI (serve/process/ │
|
||||
│ evaluate) │
|
||||
└───────────┬───────────┘
|
||||
│ inherits
|
||||
┌───────────┴───────────┐
|
||||
│ HermesAgentBaseEnv │ hermes_base_env.py
|
||||
│ - Terminal backend │
|
||||
│ - Tool resolution │
|
||||
│ - Agent loop │
|
||||
│ - ToolContext │
|
||||
│ - Async patches │
|
||||
└───────────┬───────────┘
|
||||
│ inherits
|
||||
┌─────────────────┼─────────────────┐
|
||||
│ │ │
|
||||
TerminalTestEnv HermesSweEnv TerminalBench2EvalEnv
|
||||
(stack testing) (SWE training) (TB2 benchmark eval)
|
||||
```
|
||||
|
||||
### Inheritance Chain
|
||||
|
||||
**BaseEnv** (from `atroposlib`) is the Atropos base class. It provides:
|
||||
- Server management (OpenAI-compatible API servers, VLLM, SGLang)
|
||||
- Worker scheduling for parallel rollouts
|
||||
- Wandb integration for metrics and rollout logging
|
||||
- CLI interface with three subcommands: `serve`, `process`, `evaluate`
|
||||
- `evaluate_log()` for saving eval results to JSON + samples.jsonl
|
||||
|
||||
**HermesAgentBaseEnv** (`hermes_base_env.py`) extends BaseEnv with hermes-agent specifics:
|
||||
- Sets `os.environ["TERMINAL_ENV"]` to configure the terminal backend (local, docker, modal, ssh, singularity)
|
||||
- Resolves hermes-agent toolsets via `_resolve_tools_for_group()` (calls `get_tool_definitions()` from `model_tools.py`)
|
||||
- Implements `collect_trajectory()` which runs the full agent loop and computes rewards
|
||||
- Supports two-phase operation (Phase 1: OpenAI server, Phase 2: VLLM ManagedServer)
|
||||
- Applies monkey patches for async-safe tool operation at import time
|
||||
|
||||
Concrete environments inherit from `HermesAgentBaseEnv` and implement:
|
||||
- `setup()` -- Load dataset, initialize state
|
||||
- `get_next_item()` -- Return the next item for rollout
|
||||
- `format_prompt()` -- Convert a dataset item into the user message
|
||||
- `compute_reward()` -- Score the rollout using ToolContext
|
||||
- `evaluate()` -- Periodic evaluation logic
|
||||
|
||||
## Core Components
|
||||
|
||||
### Agent Loop (`agent_loop.py`)
|
||||
|
||||
`HermesAgentLoop` is the reusable multi-turn agent engine. It runs the same pattern as hermes-agent's `run_agent.py`:
|
||||
|
||||
1. Send messages + tools to the API via `server.chat_completion()`
|
||||
2. If the response contains `tool_calls`, execute each one via `handle_function_call()` from `model_tools.py`
|
||||
3. Append tool results to the conversation and go back to step 1
|
||||
4. If the response has no tool_calls, the agent is done
|
||||
|
||||
Tool calls are executed in a thread pool (`run_in_executor`) so backends that use `asyncio.run()` internally (Modal, Docker) don't deadlock inside Atropos's event loop.
|
||||
|
||||
Returns an `AgentResult` containing the full conversation history, turn count, reasoning content per turn, tool errors, and optional ManagedServer state (for Phase 2).
|
||||
|
||||
### Tool Context (`tool_context.py`)
|
||||
|
||||
`ToolContext` is a per-rollout handle that gives reward/verification functions direct access to **all** hermes-agent tools, scoped to the rollout's `task_id`. The same `task_id` means the terminal/browser session is the SAME one the model used during its rollout -- all state (files, processes, browser tabs) is preserved.
|
||||
|
||||
```python
|
||||
async def compute_reward(self, item, result, ctx: ToolContext):
|
||||
# Run tests in the model's terminal sandbox
|
||||
test = ctx.terminal("pytest -v")
|
||||
if test["exit_code"] == 0:
|
||||
return 1.0
|
||||
|
||||
# Check if a file was created
|
||||
content = ctx.read_file("/workspace/solution.py")
|
||||
if content.get("content"):
|
||||
return 0.5
|
||||
|
||||
# Download files locally for verification (binary-safe)
|
||||
ctx.download_file("/remote/output.bin", "/local/output.bin")
|
||||
|
||||
return 0.0
|
||||
```
|
||||
|
||||
Available methods:
|
||||
- **Terminal**: `terminal(command, timeout)` -- run shell commands
|
||||
- **Files**: `read_file(path)`, `write_file(path, content)`, `search(query, path)`
|
||||
- **Transfers**: `upload_file()`, `upload_dir()`, `download_file()`, `download_dir()` -- binary-safe file transfers between host and sandbox
|
||||
- **Web**: `web_search(query)`, `web_extract(urls)`
|
||||
- **Browser**: `browser_navigate(url)`, `browser_snapshot()`
|
||||
- **Generic**: `call_tool(name, args)` -- call any hermes-agent tool by name
|
||||
- **Cleanup**: `cleanup()` -- release all resources (called automatically after `compute_reward`)
|
||||
|
||||
### Patches (`patches.py`)
|
||||
|
||||
**Problem**: Some hermes-agent tools use `asyncio.run()` internally (e.g., mini-swe-agent's Modal backend via SWE-ReX). This crashes when called from inside Atropos's event loop because `asyncio.run()` cannot be nested.
|
||||
|
||||
**Solution**: `patches.py` monkey-patches `SwerexModalEnvironment` to use a dedicated background thread (`_AsyncWorker`) with its own event loop. The calling code sees the same sync interface, but internally the async work happens on a separate thread that doesn't conflict with Atropos's loop.
|
||||
|
||||
What gets patched:
|
||||
- `SwerexModalEnvironment.__init__` -- creates Modal deployment on a background thread
|
||||
- `SwerexModalEnvironment.execute` -- runs commands on the same background thread
|
||||
- `SwerexModalEnvironment.stop` -- stops deployment on the background thread
|
||||
|
||||
The patches are:
|
||||
- **Idempotent** -- calling `apply_patches()` multiple times is safe
|
||||
- **Transparent** -- same interface and behavior, only the internal async execution changes
|
||||
- **Universal** -- works identically in normal CLI use (no running event loop)
|
||||
|
||||
Applied automatically at import time by `hermes_base_env.py`.
|
||||
|
||||
### Tool Call Parsers (`tool_call_parsers/`)
|
||||
|
||||
Client-side parsers that extract structured `tool_calls` from raw model output text. Used in **Phase 2** (VLLM server type) where ManagedServer's `/generate` endpoint returns raw text without tool call parsing.
|
||||
|
||||
Each parser is a standalone reimplementation of the corresponding VLLM parser's `extract_tool_calls()` logic. No VLLM dependency -- only standard library (`re`, `json`, `uuid`) and `openai` types.
|
||||
|
||||
Available parsers:
|
||||
- `hermes` -- Hermes/ChatML `<tool_call>` XML format
|
||||
- `mistral` -- Mistral `[TOOL_CALLS]` format
|
||||
- `llama3_json` -- Llama 3 JSON tool calling
|
||||
- `qwen` -- Qwen tool calling format
|
||||
- `qwen3_coder` -- Qwen3 Coder format
|
||||
- `deepseek_v3` -- DeepSeek V3 format
|
||||
- `deepseek_v3_1` -- DeepSeek V3.1 format
|
||||
- `kimi_k2` -- Kimi K2 format
|
||||
- `longcat` -- Longcat format
|
||||
- `glm45` / `glm47` -- GLM model formats
|
||||
|
||||
Usage:
|
||||
```python
|
||||
from environments.tool_call_parsers import get_parser
|
||||
|
||||
parser = get_parser("hermes")
|
||||
content, tool_calls = parser.parse(raw_model_output)
|
||||
```
|
||||
|
||||
In Phase 1 (OpenAI server type), these parsers are not needed -- the server handles tool call parsing natively.
|
||||
|
||||
## Two-Phase Operation
|
||||
|
||||
### Phase 1: OpenAI Server (Evaluation / SFT Data Generation)
|
||||
|
||||
Uses `server.chat_completion()` with `tools=` parameter. The server (VLLM, SGLang, OpenRouter, OpenAI) handles tool call parsing natively. Returns `ChatCompletion` objects with structured `tool_calls`.
|
||||
|
||||
- Good for: evaluation, SFT data generation, testing
|
||||
- Run with: `serve` (with `run-api`), `process`, or `evaluate` subcommands
|
||||
- Placeholder tokens are created for the Atropos pipeline
|
||||
|
||||
### Phase 2: VLLM ManagedServer (Full RL Training)
|
||||
|
||||
Uses ManagedServer for exact token IDs + logprobs via `/generate`. Client-side tool call parser (from `tool_call_parsers/`) reconstructs structured `tool_calls` from raw output.
|
||||
|
||||
- Good for: full RL training with GRPO/PPO
|
||||
- Run with: `serve` subcommand
|
||||
- Real tokens, masks, and logprobs flow through the pipeline
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
environments/
|
||||
├── README.md # This file
|
||||
├── __init__.py # Package exports
|
||||
├── hermes_base_env.py # Abstract base (HermesAgentBaseEnv)
|
||||
├── agent_loop.py # Multi-turn agent engine (HermesAgentLoop)
|
||||
├── tool_context.py # Per-rollout tool access for reward functions
|
||||
├── patches.py # Async-safety patches for Modal backend
|
||||
│
|
||||
├── tool_call_parsers/ # Phase 2 client-side parsers
|
||||
│ ├── __init__.py # Registry + base class
|
||||
│ ├── hermes_parser.py
|
||||
│ ├── mistral_parser.py
|
||||
│ ├── llama_parser.py
|
||||
│ ├── qwen_parser.py
|
||||
│ ├── qwen3_coder_parser.py
|
||||
│ ├── deepseek_v3_parser.py
|
||||
│ ├── deepseek_v3_1_parser.py
|
||||
│ ├── kimi_k2_parser.py
|
||||
│ ├── longcat_parser.py
|
||||
│ ├── glm45_parser.py
|
||||
│ └── glm47_parser.py
|
||||
│
|
||||
├── terminal_test_env/ # Stack validation environment
|
||||
│ └── terminal_test_env.py
|
||||
│
|
||||
├── hermes_swe_env/ # SWE-bench style training environment
|
||||
│ └── hermes_swe_env.py
|
||||
│
|
||||
└── benchmarks/ # Evaluation benchmarks
|
||||
└── terminalbench_2/
|
||||
└── terminalbench2_env.py
|
||||
```
|
||||
|
||||
## Concrete Environments
|
||||
|
||||
### TerminalTestEnv (`terminal_test_env/`)
|
||||
|
||||
A self-contained environment with inline tasks (no external dataset needed) for validating the full stack end-to-end. Each task asks the model to create a file at a known path, and the verifier checks the content matches.
|
||||
|
||||
```bash
|
||||
# Serve mode (needs run-api)
|
||||
run-api
|
||||
python environments/terminal_test_env/terminal_test_env.py serve
|
||||
|
||||
# Process mode (no run-api, saves to JSONL)
|
||||
python environments/terminal_test_env/terminal_test_env.py process \
|
||||
--env.data_path_to_save_groups terminal_test_output.jsonl
|
||||
```
|
||||
|
||||
### HermesSweEnv (`hermes_swe_env/`)
|
||||
|
||||
SWE-bench style training environment. The model gets a coding task, uses terminal + file + web tools to solve it, and the reward function runs tests in the same Modal sandbox.
|
||||
|
||||
```bash
|
||||
python environments/hermes_swe_env/hermes_swe_env.py serve \
|
||||
--openai.model_name YourModel \
|
||||
--env.dataset_name bigcode/humanevalpack \
|
||||
--env.terminal_backend modal
|
||||
```
|
||||
|
||||
### TerminalBench2EvalEnv (`benchmarks/terminalbench_2/`)
|
||||
|
||||
**Eval-only** environment for the Terminal-Bench 2.0 benchmark (89 tasks). Each task gets a pre-built Docker Hub image, a natural language instruction, and a test suite. The agent uses terminal + file tools to solve the task, then the test suite verifies correctness.
|
||||
|
||||
Follows the standard Atropos eval pattern (like GPQA, MMLU, etc.):
|
||||
- Run via `evaluate` subcommand (no `run-api` needed)
|
||||
- `setup()` loads the dataset, `evaluate()` runs all tasks
|
||||
- `rollout_and_score_eval()` handles per-task agent loop + test verification
|
||||
- Downloads verifier output locally for reliable reward checking (Harbor pattern)
|
||||
|
||||
```bash
|
||||
# Run full benchmark
|
||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
--openai.model_name anthropic/claude-opus-4.6
|
||||
|
||||
# Run subset of tasks
|
||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
--openai.model_name anthropic/claude-opus-4.6 \
|
||||
--env.task_filter fix-git,git-multibranch
|
||||
|
||||
# Skip specific tasks
|
||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
--openai.model_name anthropic/claude-opus-4.6 \
|
||||
--env.skip_tasks heavy-task,slow-task
|
||||
```
|
||||
|
||||
## Creating a New Environment
|
||||
|
||||
### Training Environment
|
||||
|
||||
1. Create a new directory under `environments/`
|
||||
2. Create your env file inheriting from `HermesAgentBaseEnv`
|
||||
3. Implement the four abstract methods + `evaluate()`
|
||||
|
||||
```python
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
|
||||
class MyEnvConfig(HermesAgentEnvConfig):
|
||||
pass # Add custom fields as needed
|
||||
|
||||
class MyEnv(HermesAgentBaseEnv):
|
||||
name = "my-env"
|
||||
env_config_cls = MyEnvConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls):
|
||||
env_config = MyEnvConfig(
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
terminal_backend="modal",
|
||||
# ... other config
|
||||
)
|
||||
server_configs = [APIServerConfig(...)]
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self):
|
||||
self.dataset = load_dataset(...)
|
||||
self.iter = 0
|
||||
|
||||
async def get_next_item(self):
|
||||
item = self.dataset[self.iter % len(self.dataset)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def format_prompt(self, item):
|
||||
return item["instruction"]
|
||||
|
||||
async def compute_reward(self, item, result, ctx):
|
||||
# ctx gives you full tool access to the rollout's sandbox
|
||||
test = ctx.terminal("pytest -v")
|
||||
return 1.0 if test["exit_code"] == 0 else 0.0
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
# Periodic evaluation logic
|
||||
...
|
||||
|
||||
if __name__ == "__main__":
|
||||
MyEnv.cli()
|
||||
```
|
||||
|
||||
### Eval-Only Environment (Benchmark)
|
||||
|
||||
For eval benchmarks, follow the pattern in `terminalbench2_env.py`:
|
||||
1. Create under `environments/benchmarks/your-benchmark/`
|
||||
2. Inherit from `HermesAgentBaseEnv`
|
||||
3. Set eval-only config: `eval_handling=STOP_TRAIN`, `steps_per_eval=1`, `total_steps=1`
|
||||
4. Stub the training methods (`collect_trajectories`, `score`)
|
||||
5. Implement `rollout_and_score_eval()` and `evaluate()`
|
||||
6. Run with `evaluate` subcommand
|
||||
|
||||
## Key Config Fields
|
||||
|
||||
| Field | Description | Default |
|
||||
|-------|-------------|---------|
|
||||
| `enabled_toolsets` | Which hermes toolsets to enable | `None` (all) |
|
||||
| `disabled_toolsets` | Toolsets to disable | `None` |
|
||||
| `distribution` | Probabilistic toolset distribution name | `None` |
|
||||
| `max_agent_turns` | Max LLM calls per rollout | `30` |
|
||||
| `agent_temperature` | Sampling temperature | `1.0` |
|
||||
| `terminal_backend` | `local`, `docker`, `modal`, `ssh`, `singularity` | `local` |
|
||||
| `system_prompt` | System message for the agent | `None` |
|
||||
| `tool_call_parser` | Parser name for Phase 2 | `hermes` |
|
||||
| `eval_handling` | `STOP_TRAIN`, `LIMIT_TRAIN`, `NONE` | `STOP_TRAIN` |
|
||||
@@ -4,18 +4,15 @@ Hermes-Agent Atropos Environments
|
||||
Provides a layered integration between hermes-agent's tool-calling capabilities
|
||||
and the Atropos RL training framework.
|
||||
|
||||
Core layers:
|
||||
Layers:
|
||||
- agent_loop: Reusable multi-turn agent loop with standard OpenAI-spec tool calling
|
||||
- tool_context: Per-rollout tool access handle for reward/verification functions
|
||||
- hermes_base_env: Abstract base environment (BaseEnv subclass) for Atropos
|
||||
- tool_call_parsers: Client-side tool call parser registry for Phase 2 (VLLM /generate)
|
||||
|
||||
Concrete environments:
|
||||
- terminal_test_env/: Simple file-creation tasks for testing the stack
|
||||
- hermes_swe_env/: SWE-bench style tasks with Modal sandboxes
|
||||
|
||||
Benchmarks (eval-only):
|
||||
- benchmarks/terminalbench_2/: Terminal-Bench 2.0 evaluation
|
||||
- terminal_test_env: Simple file-creation tasks for testing the stack
|
||||
- hermes_swe_env: SWE-bench style tasks with Modal sandboxes
|
||||
"""
|
||||
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
|
||||
+81
-95
@@ -15,7 +15,6 @@ import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
@@ -25,22 +24,7 @@ from model_tools import handle_function_call
|
||||
# Thread pool for running sync tool calls that internally use asyncio.run()
|
||||
# (e.g., mini-swe-agent's modal/docker backends). Running them in a separate
|
||||
# thread gives them a clean event loop so they don't deadlock inside Atropos's loop.
|
||||
# Size must be large enough for concurrent eval tasks (e.g., 89 TB2 tasks all
|
||||
# making tool calls). Too small = thread pool starvation, tasks queue for minutes.
|
||||
# Resized at runtime by HermesAgentBaseEnv.__init__ via resize_tool_pool().
|
||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=128)
|
||||
|
||||
|
||||
def resize_tool_pool(max_workers: int):
|
||||
"""
|
||||
Replace the global tool executor with a new one of the given size.
|
||||
|
||||
Called by HermesAgentBaseEnv.__init__ based on config.tool_pool_size.
|
||||
Safe to call before any tasks are submitted.
|
||||
"""
|
||||
global _tool_executor
|
||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
|
||||
logger.info("Tool thread pool resized to %d workers", max_workers)
|
||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -73,11 +57,11 @@ class AgentResult:
|
||||
# Tool errors encountered during the loop
|
||||
tool_errors: List[ToolError] = field(default_factory=list)
|
||||
|
||||
# Tool-call metrics (debugging / optional reward shaping)
|
||||
tool_calls_attempted: int = 0
|
||||
tool_calls_schema_valid: int = 0
|
||||
tool_calls_executed_ok: int = 0
|
||||
tool_calls_exec_error: int = 0
|
||||
# Tool-call metrics (for reward shaping + debugging)
|
||||
tool_calls_attempted: int = 0 # Valid tool name + attempted dispatch
|
||||
tool_calls_schema_valid: int = 0 # Arguments matched schema (no coercion)
|
||||
tool_calls_executed_ok: int = 0 # Tool ran and returned no error
|
||||
tool_calls_exec_error: int = 0 # Unknown tool / exception / tool returned error
|
||||
|
||||
|
||||
def _extract_reasoning_from_message(message) -> Optional[str]:
|
||||
@@ -141,7 +125,6 @@ class HermesAgentLoop:
|
||||
task_id: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
tool_handler=None,
|
||||
max_context_tokens: Optional[int] = None,
|
||||
):
|
||||
@@ -157,9 +140,6 @@ class HermesAgentLoop:
|
||||
task_id: Unique ID for terminal/browser session isolation
|
||||
temperature: Sampling temperature for generation
|
||||
max_tokens: Max tokens per generation (None for server default)
|
||||
extra_body: Extra parameters passed to the OpenAI client's create() call.
|
||||
Used for OpenRouter provider preferences, transforms, etc.
|
||||
e.g. {"provider": {"ignore": ["DeepInfra"]}}
|
||||
tool_handler: Optional async callable(tool_name, args, task_id) -> str.
|
||||
When provided, used INSTEAD of handle_function_call() for
|
||||
tool dispatch. This allows sandbox backends (Modal, Nomad)
|
||||
@@ -175,12 +155,13 @@ class HermesAgentLoop:
|
||||
self.task_id = task_id or str(uuid.uuid4())
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.extra_body = extra_body
|
||||
self.tool_handler = tool_handler
|
||||
self.max_context_tokens = max_context_tokens
|
||||
|
||||
|
||||
def _truncate_context(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Truncate conversation history to fit within max_context_tokens.
|
||||
"""
|
||||
Truncate conversation history to fit within max_context_tokens.
|
||||
|
||||
Strategy:
|
||||
- Keep system message (index 0) and initial user message (index 1) always
|
||||
@@ -189,9 +170,6 @@ class HermesAgentLoop:
|
||||
- If still too long, drop oldest middle messages entirely
|
||||
|
||||
Uses rough char/4 token estimate (fast, no tokenizer needed).
|
||||
|
||||
NOTE: This function mutates the provided list (it may pop/replace entries).
|
||||
Call it on a copy when you want to preserve the full trajectory.
|
||||
"""
|
||||
if self.max_context_tokens is None:
|
||||
return messages
|
||||
@@ -205,38 +183,42 @@ class HermesAgentLoop:
|
||||
total += 50 * len(m["tool_calls"]) # tool call overhead
|
||||
return total
|
||||
|
||||
if estimate_tokens(messages) <= self.max_context_tokens:
|
||||
est = estimate_tokens(messages)
|
||||
if est <= self.max_context_tokens:
|
||||
return messages
|
||||
|
||||
# Phase 1: Truncate tool result content in middle messages
|
||||
# Keep first 2 and last 6 messages untouched
|
||||
protect_head = 2
|
||||
protect_tail = max(0, min(6, len(messages) - protect_head))
|
||||
middle_start = protect_head
|
||||
middle_end = len(messages) - protect_tail
|
||||
|
||||
# Phase 1: truncate tool outputs in the middle
|
||||
if middle_start < middle_end:
|
||||
# Truncate tool results from oldest first
|
||||
for i in range(middle_start, middle_end):
|
||||
if messages[i].get("role") == "tool":
|
||||
content = messages[i].get("content", "") or ""
|
||||
if len(content) > 200:
|
||||
messages[i] = dict(messages[i])
|
||||
messages[i] = dict(messages[i]) # copy
|
||||
messages[i]["content"] = content[:100] + "\n...[truncated]...\n" + content[-50:]
|
||||
|
||||
if estimate_tokens(messages) <= self.max_context_tokens:
|
||||
est = estimate_tokens(messages)
|
||||
if est <= self.max_context_tokens:
|
||||
logger.debug("Context truncated (phase 1: tool results): %d tokens", est)
|
||||
return messages
|
||||
|
||||
# Phase 2: drop oldest middle messages (try to keep assistant+tool pairs)
|
||||
# Phase 2: Drop oldest middle messages entirely
|
||||
while middle_start < middle_end and estimate_tokens(messages) > self.max_context_tokens:
|
||||
# Remove the oldest middle message
|
||||
# But keep assistant+tool pairs together
|
||||
msg = messages[middle_start]
|
||||
messages.pop(middle_start)
|
||||
middle_end -= 1
|
||||
|
||||
# If we removed an assistant with tool_calls, also remove matching tool responses
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
tool_ids = {
|
||||
tc.get("id") or tc.get("tool_call_id", "")
|
||||
for tc in msg.get("tool_calls", [])
|
||||
if isinstance(tc, dict)
|
||||
}
|
||||
tool_ids = {tc.get("id") or tc.get("tool_call_id", "") for tc in msg.get("tool_calls", []) if isinstance(tc, dict)}
|
||||
# Remove tool responses for those IDs
|
||||
i = middle_start
|
||||
while i < middle_end:
|
||||
if messages[i].get("role") == "tool" and messages[i].get("tool_call_id", "") in tool_ids:
|
||||
@@ -245,37 +227,45 @@ class HermesAgentLoop:
|
||||
else:
|
||||
i += 1
|
||||
|
||||
est = estimate_tokens(messages)
|
||||
logger.info("Context truncated (phase 2: dropped messages): %d estimated tokens, %d messages remaining", est, len(messages))
|
||||
return messages
|
||||
|
||||
def _normalize_tool_args(self, tool_name: str, tool_args_raw: str) -> (Dict[str, Any], bool):
|
||||
"""Normalize tool arguments into a dict.
|
||||
|
||||
Returns: (args_dict, schema_valid)
|
||||
Returns:
|
||||
(args_dict, schema_valid)
|
||||
|
||||
schema_valid is True only when arguments decode directly into a dict
|
||||
(no double-decoding and no coercion/wrapping required).
|
||||
schema_valid is True only when the arguments decode directly into a dict
|
||||
(i.e. no double-decoding and no coercion/wrapping was needed).
|
||||
|
||||
Goal: keep environments robust (never crash on args format drift) while
|
||||
still allowing reward functions to penalize malformed formats if desired.
|
||||
This lets us keep the environment robust (never crash due to args format)
|
||||
while still scoring down malformed tool-call argument formats.
|
||||
"""
|
||||
try:
|
||||
decoded = json.loads(tool_args_raw)
|
||||
except json.JSONDecodeError:
|
||||
# Not JSON at all — treat as a plain string
|
||||
# Not valid JSON at all. Be robust: treat it as a plain string.
|
||||
# (Some parsers/providers may pass through non-JSON strings.)
|
||||
if tool_name == "terminal":
|
||||
return {"command": tool_args_raw}, False
|
||||
return {"input": tool_args_raw}, False
|
||||
|
||||
# Canonical case: decoded is already a dict
|
||||
if isinstance(decoded, dict):
|
||||
# For terminal tool, require a command key
|
||||
if tool_name == "terminal":
|
||||
cmd = decoded.get("command")
|
||||
if isinstance(cmd, str) and cmd.strip():
|
||||
return decoded, True
|
||||
# Common alternate key
|
||||
if isinstance(decoded.get("input"), str):
|
||||
return {"command": decoded.get("input")}, False
|
||||
return decoded, False
|
||||
return decoded, True
|
||||
|
||||
# Common drift case: decoded is a JSON string of an object
|
||||
if isinstance(decoded, str):
|
||||
s = decoded.strip()
|
||||
if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
|
||||
@@ -284,12 +274,17 @@ class HermesAgentLoop:
|
||||
except json.JSONDecodeError:
|
||||
decoded2 = None
|
||||
if isinstance(decoded2, dict):
|
||||
# Terminal tool: ensure command
|
||||
if tool_name == "terminal" and isinstance(decoded2.get("command"), str):
|
||||
return decoded2, False
|
||||
return decoded2, False
|
||||
|
||||
# Plain string (not JSON) — coerce to expected shape
|
||||
if tool_name == "terminal":
|
||||
return {"command": decoded}, False
|
||||
return {"input": decoded}, False
|
||||
|
||||
# Other JSON types (list/number/etc.) — wrap
|
||||
if tool_name == "terminal":
|
||||
return {"command": str(decoded)}, False
|
||||
return {"input": decoded}, False
|
||||
@@ -300,7 +295,12 @@ class HermesAgentLoop:
|
||||
|
||||
Args:
|
||||
messages: Initial conversation messages (system + user).
|
||||
Modified in-place as the conversation progresses.
|
||||
This list is treated as the FULL trajectory and is
|
||||
appended to as the conversation progresses.
|
||||
|
||||
Prompt truncation (to avoid context overflow) is applied
|
||||
on a copy of this list per turn, so we do not lose
|
||||
earlier messages for reward computation/debugging.
|
||||
|
||||
Returns:
|
||||
AgentResult with full conversation history, managed state, and metadata
|
||||
@@ -308,17 +308,16 @@ class HermesAgentLoop:
|
||||
reasoning_per_turn = []
|
||||
tool_errors: List[ToolError] = []
|
||||
|
||||
# Metrics to separate "attempted tool use" from "schema-valid tool use"
|
||||
tool_calls_attempted = 0
|
||||
tool_calls_schema_valid = 0
|
||||
tool_calls_executed_ok = 0
|
||||
tool_calls_exec_error = 0
|
||||
|
||||
import time as _time
|
||||
|
||||
for turn in range(self.max_turns):
|
||||
turn_start = _time.monotonic()
|
||||
|
||||
# Truncate prompt view on a copy (preserve full trajectory in `messages`)
|
||||
# Truncate context if approaching limit.
|
||||
# IMPORTANT: do this on a copy so we keep the full trajectory in `messages`
|
||||
# for reward computation + debugging, while only trimming the prompt view.
|
||||
prompt_messages = self._truncate_context(list(messages))
|
||||
|
||||
# Build the chat_completion kwargs
|
||||
@@ -336,18 +335,11 @@ class HermesAgentLoop:
|
||||
if self.max_tokens is not None:
|
||||
chat_kwargs["max_tokens"] = self.max_tokens
|
||||
|
||||
# Inject extra_body for provider-specific params (e.g., OpenRouter
|
||||
# provider preferences like banned/preferred providers, transforms)
|
||||
if self.extra_body:
|
||||
chat_kwargs["extra_body"] = self.extra_body
|
||||
|
||||
# Make the API call -- standard OpenAI spec
|
||||
api_start = _time.monotonic()
|
||||
try:
|
||||
response = await self.server.chat_completion(**chat_kwargs)
|
||||
except Exception as e:
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
logger.error("API call failed on turn %d (%.1fs): %s", turn + 1, api_elapsed, e)
|
||||
logger.error("API call failed on turn %d: %s", turn + 1, e)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
@@ -361,10 +353,8 @@ class HermesAgentLoop:
|
||||
tool_calls_exec_error=tool_calls_exec_error,
|
||||
)
|
||||
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
|
||||
if not response or not response.choices:
|
||||
logger.warning("Empty response on turn %d (api=%.1fs)", turn + 1, api_elapsed)
|
||||
logger.warning("Empty response on turn %d", turn + 1)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
@@ -418,7 +408,6 @@ class HermesAgentLoop:
|
||||
|
||||
# Validate tool name
|
||||
if tool_name not in self.valid_tool_names:
|
||||
tool_calls_exec_error += 1
|
||||
tool_result = json.dumps(
|
||||
{
|
||||
"error": f"Unknown tool '{tool_name}'. "
|
||||
@@ -435,28 +424,36 @@ class HermesAgentLoop:
|
||||
"Model called unknown tool '%s' on turn %d",
|
||||
tool_name, turn + 1,
|
||||
)
|
||||
tool_calls_exec_error += 1
|
||||
else:
|
||||
tool_calls_attempted += 1
|
||||
|
||||
# Normalize args into a dict so we never crash due to formatting.
|
||||
# Track schema_valid separately so reward shaping can penalize
|
||||
# non-canonical formats (e.g. stringified JSON).
|
||||
args, schema_valid = self._normalize_tool_args(tool_name, tool_args_raw)
|
||||
if schema_valid:
|
||||
tool_calls_schema_valid += 1
|
||||
|
||||
try:
|
||||
if tool_name == "terminal":
|
||||
import os
|
||||
backend = os.getenv("TERMINAL_ENV", "local")
|
||||
if self.tool_handler:
|
||||
backend = "sandbox"
|
||||
cmd_preview = str(args.get("command", ""))[:80]
|
||||
logger.info(
|
||||
"[%s] $ %s", self.task_id[:8], cmd_preview,
|
||||
)
|
||||
|
||||
tool_submit_time = _time.monotonic()
|
||||
print(f" 🖥️ [{backend}] $ {cmd_preview}")
|
||||
|
||||
if self.tool_handler:
|
||||
tool_result = await self.tool_handler(tool_name, args, self.task_id)
|
||||
# Use custom tool handler (sandbox backend routing)
|
||||
tool_result = await self.tool_handler(
|
||||
tool_name, args, self.task_id
|
||||
)
|
||||
else:
|
||||
# Run tool calls in a thread pool so backends that use
|
||||
# asyncio.run() internally (modal, docker) get a clean
|
||||
# event loop instead of deadlocking inside Atropos's loop.
|
||||
# Default: run via hermes-agent's handle_function_call
|
||||
# in a thread pool so backends that use asyncio.run()
|
||||
# internally (modal, docker) get a clean event loop
|
||||
# instead of deadlocking inside Atropos's loop.
|
||||
loop = asyncio.get_event_loop()
|
||||
tool_result = await loop.run_in_executor(
|
||||
_tool_executor,
|
||||
@@ -464,17 +461,6 @@ class HermesAgentLoop:
|
||||
tool_name, args, task_id=self.task_id
|
||||
),
|
||||
)
|
||||
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
|
||||
# Log slow tools and thread pool stats for debugging
|
||||
pool_active = _tool_executor._work_queue.qsize()
|
||||
if tool_elapsed > 30:
|
||||
logger.warning(
|
||||
"[%s] turn %d: %s took %.1fs (pool queue=%d)",
|
||||
self.task_id[:8], turn + 1, tool_name,
|
||||
tool_elapsed, pool_active,
|
||||
)
|
||||
except Exception as e:
|
||||
tool_calls_exec_error += 1
|
||||
tool_result = json.dumps(
|
||||
@@ -491,6 +477,7 @@ class HermesAgentLoop:
|
||||
tool_name, turn + 1, e,
|
||||
)
|
||||
else:
|
||||
# Count tool result errors (if tool returns structured JSON error)
|
||||
tool_err = False
|
||||
try:
|
||||
result_data = json.loads(tool_result)
|
||||
@@ -499,6 +486,7 @@ class HermesAgentLoop:
|
||||
if err:
|
||||
tool_err = True
|
||||
|
||||
# Keep existing behavior: treat negative exit_code as tool error
|
||||
exit_code = result_data.get("exit_code")
|
||||
if exit_code is not None and isinstance(exit_code, int) and exit_code < 0:
|
||||
tool_err = True
|
||||
@@ -509,6 +497,7 @@ class HermesAgentLoop:
|
||||
tool_result=tool_result[:500],
|
||||
))
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# Non-JSON tool output — assume ok
|
||||
pass
|
||||
|
||||
if tool_err:
|
||||
@@ -525,11 +514,10 @@ class HermesAgentLoop:
|
||||
}
|
||||
)
|
||||
|
||||
turn_elapsed = _time.monotonic() - turn_start
|
||||
logger.info(
|
||||
"[%s] turn %d: api=%.1fs, %d tools, turn_total=%.1fs",
|
||||
self.task_id[:8], turn + 1, api_elapsed,
|
||||
len(assistant_msg.tool_calls), turn_elapsed,
|
||||
logger.debug(
|
||||
"Turn %d: %d tool calls executed",
|
||||
turn + 1,
|
||||
len(assistant_msg.tool_calls),
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -542,10 +530,8 @@ class HermesAgentLoop:
|
||||
msg_dict["reasoning_content"] = reasoning
|
||||
messages.append(msg_dict)
|
||||
|
||||
turn_elapsed = _time.monotonic() - turn_start
|
||||
logger.info(
|
||||
"[%s] turn %d: api=%.1fs, no tools (finished), turn_total=%.1fs",
|
||||
self.task_id[:8], turn + 1, api_elapsed, turn_elapsed,
|
||||
logger.debug(
|
||||
"Turn %d: model finished naturally (no tool calls)", turn + 1
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
# Terminal-Bench 2.0 Evaluation -- Default Configuration
|
||||
#
|
||||
# Eval-only environment for the TB2 benchmark (89 terminal tasks).
|
||||
# Uses Modal terminal backend for per-task cloud-isolated sandboxes
|
||||
# and OpenRouter for inference.
|
||||
#
|
||||
# Usage:
|
||||
# python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
# --config environments/benchmarks/terminalbench_2/default.yaml
|
||||
#
|
||||
# # Override model:
|
||||
# python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
# --config environments/benchmarks/terminalbench_2/default.yaml \
|
||||
# --openai.model_name anthropic/claude-sonnet-4
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file"]
|
||||
max_agent_turns: 60
|
||||
max_token_length: 32000
|
||||
agent_temperature: 0.8
|
||||
terminal_backend: "modal"
|
||||
terminal_timeout: 300 # 5 min per command (builds, pip install)
|
||||
tool_pool_size: 128 # thread pool for 89 parallel tasks
|
||||
dataset_name: "NousResearch/terminal-bench-2"
|
||||
test_timeout: 600
|
||||
task_timeout: 1800 # 30 min wall-clock per task, auto-FAIL if exceeded
|
||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
use_wandb: true
|
||||
wandb_name: "terminal-bench-2"
|
||||
ensure_scores_are_not_same: false
|
||||
data_dir_to_save_evals: "environments/benchmarks/evals/terminal-bench-2"
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-opus-4.6"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
@@ -1,32 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Terminal-Bench 2.0 Evaluation
|
||||
#
|
||||
# Run from repo root:
|
||||
# bash environments/benchmarks/terminalbench_2/run_eval.sh
|
||||
#
|
||||
# Override model:
|
||||
# bash environments/benchmarks/terminalbench_2/run_eval.sh \
|
||||
# --openai.model_name anthropic/claude-sonnet-4
|
||||
#
|
||||
# Run a subset:
|
||||
# bash environments/benchmarks/terminalbench_2/run_eval.sh \
|
||||
# --env.task_filter fix-git,git-multibranch
|
||||
|
||||
mkdir -p logs evals/terminal-bench-2
|
||||
LOG_FILE="logs/terminalbench2_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "Terminal-Bench 2.0 Evaluation"
|
||||
echo "Log: $LOG_FILE"
|
||||
echo ""
|
||||
|
||||
export TERMINAL_ENV=modal
|
||||
export TERMINAL_TIMEOUT=300
|
||||
|
||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
--config environments/benchmarks/terminalbench_2/default.yaml \
|
||||
"$@" \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo ""
|
||||
echo "Log saved to: $LOG_FILE"
|
||||
@@ -1,904 +0,0 @@
|
||||
"""
|
||||
TerminalBench2Env -- Terminal-Bench 2.0 Evaluation Environment
|
||||
|
||||
Evaluates agentic LLMs on challenging terminal tasks from Terminal-Bench 2.0.
|
||||
Each task provides a unique Docker environment (pre-built on Docker Hub), a natural
|
||||
language instruction, and a test suite for verification. The agent uses terminal +
|
||||
file tools to complete the task, then the test suite runs inside the same sandbox.
|
||||
|
||||
This is an eval-only environment (not a training environment). It is designed to
|
||||
be run via the `evaluate` subcommand:
|
||||
|
||||
python environments/terminalbench2_env.py evaluate \\
|
||||
--env.dataset_name NousResearch/terminal-bench-2
|
||||
|
||||
The evaluate flow:
|
||||
1. setup() -- Loads the TB2 dataset from HuggingFace
|
||||
2. evaluate() -- Iterates over all tasks, running each through:
|
||||
a. rollout_and_score_eval() -- Per-task agent loop + test verification
|
||||
- Resolves Docker image (pre-built Hub image or Dockerfile fallback)
|
||||
- Registers per-task Modal sandbox via register_task_env_overrides()
|
||||
- Runs the HermesAgentLoop (terminal + file tools)
|
||||
- Uploads test suite and runs test.sh in the same sandbox
|
||||
- Returns binary pass/fail result
|
||||
b. Aggregates per-task, per-category, and overall pass rates
|
||||
c. Logs results via evaluate_log() and wandb
|
||||
|
||||
Key features:
|
||||
- Per-task Modal sandboxes using pre-built Docker Hub images
|
||||
- Binary reward: 1.0 if all tests pass, 0.0 otherwise
|
||||
- Concurrency-controlled parallel evaluation via asyncio.Semaphore
|
||||
- Per-task, per-category, and aggregate pass rate tracking
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tarfile
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path for imports
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import EvalHandlingEnum
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
from tools.terminal_tool import (
|
||||
register_task_env_overrides,
|
||||
clear_task_env_overrides,
|
||||
cleanup_vm,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Configuration
|
||||
# =============================================================================
|
||||
|
||||
class TerminalBench2EvalConfig(HermesAgentEnvConfig):
|
||||
"""
|
||||
Configuration for the Terminal-Bench 2.0 evaluation environment.
|
||||
|
||||
Extends HermesAgentEnvConfig with TB2-specific settings for dataset loading,
|
||||
test execution, task filtering, and eval concurrency.
|
||||
"""
|
||||
|
||||
# --- Dataset ---
|
||||
dataset_name: str = Field(
|
||||
default="NousResearch/terminal-bench-2",
|
||||
description="HuggingFace dataset containing TB2 tasks.",
|
||||
)
|
||||
|
||||
# --- Test execution ---
|
||||
test_timeout: int = Field(
|
||||
default=180,
|
||||
description="Timeout in seconds for running the test suite after agent completes.",
|
||||
)
|
||||
|
||||
# --- Image strategy ---
|
||||
force_build: bool = Field(
|
||||
default=False,
|
||||
description="If True, always build from Dockerfile (ignore docker_image). "
|
||||
"Useful for testing custom Dockerfiles.",
|
||||
)
|
||||
|
||||
# --- Task filtering (comma-separated from CLI) ---
|
||||
task_filter: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Comma-separated task names to run (e.g., 'fix-git,git-multibranch'). "
|
||||
"If not set, all tasks are run.",
|
||||
)
|
||||
skip_tasks: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Comma-separated task names to skip on top of the default skip list.",
|
||||
)
|
||||
|
||||
# --- Per-task wall-clock timeout ---
|
||||
task_timeout: int = Field(
|
||||
default=1800,
|
||||
description="Maximum wall-clock seconds per task (agent loop + verification). "
|
||||
"Tasks exceeding this are scored as FAIL. Default 30 minutes.",
|
||||
)
|
||||
|
||||
|
||||
# Tasks that cannot run properly on Modal and are excluded from scoring.
|
||||
MODAL_INCOMPATIBLE_TASKS = {
|
||||
"qemu-startup", # Needs KVM/hardware virtualization
|
||||
"qemu-alpine-ssh", # Needs KVM/hardware virtualization
|
||||
"crack-7z-hash", # Password brute-force -- too slow for cloud sandbox timeouts
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tar extraction helper
|
||||
# =============================================================================
|
||||
|
||||
def _extract_base64_tar(b64_data: str, target_dir: Path):
|
||||
"""Extract a base64-encoded tar.gz archive into target_dir."""
|
||||
if not b64_data:
|
||||
return
|
||||
raw = base64.b64decode(b64_data)
|
||||
buf = io.BytesIO(raw)
|
||||
with tarfile.open(fileobj=buf, mode="r:gz") as tar:
|
||||
tar.extractall(path=str(target_dir))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main Environment
|
||||
# =============================================================================
|
||||
|
||||
class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
Terminal-Bench 2.0 evaluation environment (eval-only, no training).
|
||||
|
||||
Inherits from HermesAgentBaseEnv for:
|
||||
- Terminal backend setup (os.environ["TERMINAL_ENV"])
|
||||
- Tool resolution via _resolve_tools_for_group()
|
||||
- Monkey patches for async-safe tool operation
|
||||
- Wandb trajectory formatting
|
||||
|
||||
The evaluate flow (triggered by `environment.py evaluate`):
|
||||
1. setup() -- Load dataset from HuggingFace
|
||||
2. evaluate() -- Run all tasks through rollout_and_score_eval()
|
||||
|
||||
Each task in rollout_and_score_eval():
|
||||
1. Resolve Docker image (pre-built Hub image or Dockerfile fallback)
|
||||
2. Register per-task Modal sandbox override
|
||||
3. Run HermesAgentLoop with terminal + file tools
|
||||
4. Upload test suite and execute test.sh in the same sandbox
|
||||
5. Check /logs/verifier/reward.txt for pass/fail
|
||||
6. Clean up sandbox, overrides, and temp files
|
||||
"""
|
||||
|
||||
name = "terminal-bench-2"
|
||||
env_config_cls = TerminalBench2EvalConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[TerminalBench2EvalConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Default configuration for Terminal-Bench 2.0 evaluation.
|
||||
|
||||
Uses eval-only settings:
|
||||
- eval_handling=STOP_TRAIN so the eval flow runs cleanly
|
||||
- steps_per_eval=1, total_steps=1 so eval triggers immediately
|
||||
- group_size=1 (one rollout per group, each task is expensive)
|
||||
|
||||
Uses Modal terminal backend (cloud-isolated sandbox per task) and
|
||||
OpenRouter with Claude for inference.
|
||||
"""
|
||||
env_config = TerminalBench2EvalConfig(
|
||||
# Terminal + file tools only (the agent interacts via shell commands)
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
|
||||
# Agent settings -- TB2 tasks are complex, need many turns
|
||||
max_agent_turns=60,
|
||||
max_token_length=16000,
|
||||
agent_temperature=0.6,
|
||||
system_prompt=None,
|
||||
|
||||
# Modal backend for per-task cloud-isolated sandboxes
|
||||
terminal_backend="modal",
|
||||
terminal_timeout=300, # 5 min per command (builds, pip install, etc.)
|
||||
|
||||
# Test execution timeout (TB2 test scripts can install deps like pytest)
|
||||
test_timeout=180,
|
||||
|
||||
# 89 tasks run in parallel, each needs a thread for tool calls
|
||||
tool_pool_size=128,
|
||||
|
||||
# --- Eval-only Atropos settings ---
|
||||
# These settings make the env work as an eval-only environment:
|
||||
# - STOP_TRAIN: pauses training during eval (standard for eval envs)
|
||||
# - steps_per_eval=1, total_steps=1: eval triggers immediately
|
||||
# - group_size=1: one rollout per group (each task is expensive)
|
||||
eval_handling=EvalHandlingEnum.STOP_TRAIN,
|
||||
group_size=1,
|
||||
steps_per_eval=1,
|
||||
total_steps=1,
|
||||
|
||||
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
use_wandb=True,
|
||||
wandb_name="terminal-bench-2",
|
||||
ensure_scores_are_not_same=False, # Binary rewards may all be 0 or 1
|
||||
)
|
||||
|
||||
# OpenRouter with Claude -- API key loaded from .env
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-sonnet-4",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
# =========================================================================
|
||||
# Setup -- load dataset
|
||||
# =========================================================================
|
||||
|
||||
async def setup(self):
|
||||
"""Load the Terminal-Bench 2.0 dataset from HuggingFace."""
|
||||
from datasets import load_dataset
|
||||
|
||||
# Auto-set terminal_lifetime to task_timeout + 120s so sandboxes
|
||||
# never get killed during an active task, but still get cleaned up
|
||||
# promptly after the task times out.
|
||||
lifetime = self.config.task_timeout + 120
|
||||
self.config.terminal_lifetime = lifetime
|
||||
os.environ["TERMINAL_LIFETIME_SECONDS"] = str(lifetime)
|
||||
print(f" Terminal lifetime auto-set to {lifetime}s (task_timeout + 120s)")
|
||||
|
||||
print(f"Loading TB2 dataset from: {self.config.dataset_name}")
|
||||
ds = load_dataset(self.config.dataset_name, split="train")
|
||||
|
||||
# Apply task filters (comma-separated strings from CLI)
|
||||
tasks = list(ds)
|
||||
if self.config.task_filter:
|
||||
allowed = {name.strip() for name in self.config.task_filter.split(",")}
|
||||
tasks = [t for t in tasks if t["task_name"] in allowed]
|
||||
print(f" Filtered to {len(tasks)} tasks: {sorted(allowed)}")
|
||||
|
||||
# Skip tasks incompatible with the current backend (e.g., QEMU on Modal)
|
||||
# plus any user-specified skip_tasks
|
||||
skip = set(MODAL_INCOMPATIBLE_TASKS) if self.config.terminal_backend == "modal" else set()
|
||||
if self.config.skip_tasks:
|
||||
skip |= {name.strip() for name in self.config.skip_tasks.split(",")}
|
||||
if skip:
|
||||
before = len(tasks)
|
||||
tasks = [t for t in tasks if t["task_name"] not in skip]
|
||||
skipped = before - len(tasks)
|
||||
if skipped > 0:
|
||||
print(f" Skipped {skipped} incompatible tasks: {sorted(skip & {t['task_name'] for t in ds})}")
|
||||
|
||||
self.all_eval_items = tasks
|
||||
self.iter = 0
|
||||
|
||||
# Build category index for per-category metrics
|
||||
self.category_index: Dict[str, List[int]] = defaultdict(list)
|
||||
for i, task in enumerate(self.all_eval_items):
|
||||
self.category_index[task.get("category", "unknown")].append(i)
|
||||
|
||||
# Reward tracking for wandb logging
|
||||
self.eval_metrics: List[Tuple[str, float]] = []
|
||||
|
||||
# Streaming JSONL writer -- saves each task's full conversation
|
||||
# immediately on completion so data is preserved even on Ctrl+C.
|
||||
# Timestamped filename so each run produces a unique file.
|
||||
import datetime
|
||||
log_dir = os.path.join(os.path.dirname(__file__), "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
run_ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
self._streaming_path = os.path.join(log_dir, f"samples_{run_ts}.jsonl")
|
||||
self._streaming_file = open(self._streaming_path, "w")
|
||||
self._streaming_lock = __import__("threading").Lock()
|
||||
print(f" Streaming results to: {self._streaming_path}")
|
||||
|
||||
print(f"TB2 ready: {len(self.all_eval_items)} tasks across {len(self.category_index)} categories")
|
||||
for cat, indices in sorted(self.category_index.items()):
|
||||
print(f" {cat}: {len(indices)} tasks")
|
||||
|
||||
def _save_result(self, result: Dict[str, Any]):
|
||||
"""Write a single task result to the streaming JSONL file immediately."""
|
||||
if not hasattr(self, "_streaming_file") or self._streaming_file.closed:
|
||||
return
|
||||
with self._streaming_lock:
|
||||
self._streaming_file.write(json.dumps(result, ensure_ascii=False, default=str) + "\n")
|
||||
self._streaming_file.flush()
|
||||
|
||||
# =========================================================================
|
||||
# Training pipeline stubs -- NOT used in eval-only mode
|
||||
# =========================================================================
|
||||
# These satisfy the abstract method requirements from HermesAgentBaseEnv.
|
||||
# The evaluate subcommand calls setup() -> evaluate() directly, bypassing
|
||||
# the training pipeline entirely.
|
||||
|
||||
async def get_next_item(self):
|
||||
"""Return next item (stub -- not used in eval-only mode)."""
|
||||
item = self.all_eval_items[self.iter % len(self.all_eval_items)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def format_prompt(self, item: Dict[str, Any]) -> str:
|
||||
"""Return the task's instruction as the user prompt."""
|
||||
return item["instruction"]
|
||||
|
||||
async def compute_reward(self, item, result, ctx) -> float:
|
||||
"""Compute reward (stub -- actual verification is in rollout_and_score_eval)."""
|
||||
return 0.0
|
||||
|
||||
async def collect_trajectories(self, item):
|
||||
"""Collect trajectories (stub -- not used in eval-only mode)."""
|
||||
return None, []
|
||||
|
||||
async def score(self, rollout_group_data):
|
||||
"""Score rollouts (stub -- not used in eval-only mode)."""
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Docker image resolution
|
||||
# =========================================================================
|
||||
|
||||
def _resolve_task_image(
|
||||
self, item: Dict[str, Any], task_name: str
|
||||
) -> Tuple[str, Optional[Path]]:
|
||||
"""
|
||||
Resolve the Docker image for a task, with fallback to Dockerfile.
|
||||
|
||||
Strategy (mirrors Harbor's approach):
|
||||
1. If force_build=True, always build from Dockerfile in environment_tar
|
||||
2. If docker_image is available, use the pre-built Docker Hub image (fast)
|
||||
3. Otherwise, extract Dockerfile from environment_tar and build (slow)
|
||||
|
||||
Returns:
|
||||
(modal_image, temp_dir) -- modal_image is a Docker Hub name or a
|
||||
Dockerfile path. temp_dir is set if we extracted files that need
|
||||
cleanup later.
|
||||
"""
|
||||
docker_image = item.get("docker_image", "")
|
||||
environment_tar = item.get("environment_tar", "")
|
||||
|
||||
# Fast path: use pre-built Docker Hub image
|
||||
if docker_image and not self.config.force_build:
|
||||
logger.info("Task %s: using pre-built image %s", task_name, docker_image)
|
||||
return docker_image, None
|
||||
|
||||
# Slow path: extract Dockerfile from environment_tar and build
|
||||
if environment_tar:
|
||||
task_dir = Path(tempfile.mkdtemp(prefix=f"tb2-{task_name}-"))
|
||||
_extract_base64_tar(environment_tar, task_dir)
|
||||
dockerfile_path = task_dir / "Dockerfile"
|
||||
if dockerfile_path.exists():
|
||||
logger.info(
|
||||
"Task %s: building from Dockerfile (force_build=%s, docker_image=%s)",
|
||||
task_name, self.config.force_build, bool(docker_image),
|
||||
)
|
||||
return str(dockerfile_path), task_dir
|
||||
|
||||
# Neither available -- fall back to Hub image if force_build was True
|
||||
if docker_image:
|
||||
logger.warning(
|
||||
"Task %s: force_build=True but no environment_tar, "
|
||||
"falling back to docker_image %s", task_name, docker_image,
|
||||
)
|
||||
return docker_image, None
|
||||
|
||||
return "", None
|
||||
|
||||
# =========================================================================
|
||||
# Per-task evaluation -- agent loop + test verification
|
||||
# =========================================================================
|
||||
|
||||
async def rollout_and_score_eval(self, eval_item: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Evaluate a single TB2 task: run the agent loop, then verify with tests.
|
||||
|
||||
This is the core evaluation method. For each task it:
|
||||
1. Resolves the Docker image and registers the Modal sandbox override
|
||||
2. Runs HermesAgentLoop with terminal + file tools
|
||||
3. Uploads the test suite into the sandbox
|
||||
4. Executes test.sh and checks the result
|
||||
5. Cleans up the sandbox and temp files
|
||||
|
||||
Args:
|
||||
eval_item: A single TB2 task dict from the dataset
|
||||
|
||||
Returns:
|
||||
Dict with 'passed' (bool), 'reward' (float), 'task_name' (str),
|
||||
'category' (str), and optional debug info
|
||||
"""
|
||||
task_name = eval_item.get("task_name", "unknown")
|
||||
category = eval_item.get("category", "unknown")
|
||||
task_id = str(uuid.uuid4())
|
||||
task_dir = None # Set if we extract a Dockerfile (needs cleanup)
|
||||
|
||||
from tqdm import tqdm
|
||||
tqdm.write(f" [START] {task_name} (task_id={task_id[:8]})")
|
||||
task_start = time.time()
|
||||
|
||||
try:
|
||||
# --- 1. Resolve Docker image ---
|
||||
modal_image, task_dir = self._resolve_task_image(eval_item, task_name)
|
||||
if not modal_image:
|
||||
logger.error("Task %s: no docker_image or environment_tar, skipping", task_name)
|
||||
return {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
"error": "no_image",
|
||||
}
|
||||
|
||||
# --- 2. Register per-task Modal image override ---
|
||||
register_task_env_overrides(task_id, {"modal_image": modal_image})
|
||||
logger.info(
|
||||
"Task %s: registered image override for task_id %s",
|
||||
task_name, task_id[:8],
|
||||
)
|
||||
|
||||
# --- 3. Resolve tools and build messages ---
|
||||
tools, valid_names = self._resolve_tools_for_group()
|
||||
|
||||
messages: List[Dict[str, Any]] = []
|
||||
if self.config.system_prompt:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
messages.append({"role": "user", "content": self.format_prompt(eval_item)})
|
||||
|
||||
# --- 4. Run agent loop ---
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
# --- 5. Verify -- run test suite in the agent's sandbox ---
|
||||
# Skip verification if the agent produced no meaningful output
|
||||
only_system_and_user = all(
|
||||
msg.get("role") in ("system", "user") for msg in result.messages
|
||||
)
|
||||
if result.turns_used == 0 or only_system_and_user:
|
||||
logger.warning(
|
||||
"Task %s: agent produced no output (turns=%d). Reward=0.",
|
||||
task_name, result.turns_used,
|
||||
)
|
||||
reward = 0.0
|
||||
else:
|
||||
# Run tests in a thread so the blocking ctx.terminal() calls
|
||||
# don't freeze the entire event loop (which would stall all
|
||||
# other tasks, tqdm updates, and timeout timers).
|
||||
ctx = ToolContext(task_id)
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
reward = await loop.run_in_executor(
|
||||
None, # default thread pool
|
||||
self._run_tests, eval_item, ctx, task_name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Task %s: test verification failed: %s", task_name, e)
|
||||
reward = 0.0
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
|
||||
passed = reward == 1.0
|
||||
status = "PASS" if passed else "FAIL"
|
||||
elapsed = time.time() - task_start
|
||||
tqdm.write(f" [{status}] {task_name} (turns={result.turns_used}, {elapsed:.0f}s)")
|
||||
logger.info(
|
||||
"Task %s: reward=%.1f, turns=%d, finished=%s",
|
||||
task_name, reward, result.turns_used, result.finished_naturally,
|
||||
)
|
||||
|
||||
out = {
|
||||
"passed": passed,
|
||||
"reward": reward,
|
||||
"task_name": task_name,
|
||||
"category": category,
|
||||
"turns_used": result.turns_used,
|
||||
"finished_naturally": result.finished_naturally,
|
||||
"messages": result.messages,
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.time() - task_start
|
||||
logger.error("Task %s: rollout failed: %s", task_name, e, exc_info=True)
|
||||
tqdm.write(f" [ERROR] {task_name}: {e} ({elapsed:.0f}s)")
|
||||
out = {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
"error": str(e),
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
finally:
|
||||
# --- Cleanup: clear overrides, sandbox, and temp files ---
|
||||
clear_task_env_overrides(task_id)
|
||||
try:
|
||||
cleanup_vm(task_id)
|
||||
except Exception as e:
|
||||
logger.debug("VM cleanup for %s: %s", task_id[:8], e)
|
||||
if task_dir and task_dir.exists():
|
||||
shutil.rmtree(task_dir, ignore_errors=True)
|
||||
|
||||
def _run_tests(
|
||||
self, item: Dict[str, Any], ctx: ToolContext, task_name: str
|
||||
) -> float:
|
||||
"""
|
||||
Upload and execute the test suite in the agent's sandbox, then
|
||||
download the verifier output locally to read the reward.
|
||||
|
||||
Follows Harbor's verification pattern:
|
||||
1. Upload tests/ directory into the sandbox
|
||||
2. Execute test.sh inside the sandbox
|
||||
3. Download /logs/verifier/ directory to a local temp dir
|
||||
4. Read reward.txt locally with native Python I/O
|
||||
|
||||
Downloading locally avoids issues with the file_read tool on
|
||||
the Modal VM and matches how Harbor handles verification.
|
||||
|
||||
TB2 test scripts (test.sh) typically:
|
||||
1. Install pytest via uv/pip
|
||||
2. Run pytest against the test files in /tests/
|
||||
3. Write results to /logs/verifier/reward.txt
|
||||
|
||||
Args:
|
||||
item: The TB2 task dict (contains tests_tar, test_sh)
|
||||
ctx: ToolContext scoped to this task's sandbox
|
||||
task_name: For logging
|
||||
|
||||
Returns:
|
||||
1.0 if tests pass, 0.0 otherwise
|
||||
"""
|
||||
tests_tar = item.get("tests_tar", "")
|
||||
test_sh = item.get("test_sh", "")
|
||||
|
||||
if not test_sh:
|
||||
logger.warning("Task %s: no test_sh content, reward=0", task_name)
|
||||
return 0.0
|
||||
|
||||
# Create required directories in the sandbox
|
||||
ctx.terminal("mkdir -p /tests /logs/verifier")
|
||||
|
||||
# Upload test files into the sandbox (binary-safe via base64)
|
||||
if tests_tar:
|
||||
tests_temp = Path(tempfile.mkdtemp(prefix=f"tb2-tests-{task_name}-"))
|
||||
try:
|
||||
_extract_base64_tar(tests_tar, tests_temp)
|
||||
ctx.upload_dir(str(tests_temp), "/tests")
|
||||
except Exception as e:
|
||||
logger.warning("Task %s: failed to upload test files: %s", task_name, e)
|
||||
finally:
|
||||
shutil.rmtree(tests_temp, ignore_errors=True)
|
||||
|
||||
# Write the test runner script (test.sh)
|
||||
ctx.write_file("/tests/test.sh", test_sh)
|
||||
ctx.terminal("chmod +x /tests/test.sh")
|
||||
|
||||
# Execute the test suite
|
||||
logger.info(
|
||||
"Task %s: running test suite (timeout=%ds)",
|
||||
task_name, self.config.test_timeout,
|
||||
)
|
||||
test_result = ctx.terminal(
|
||||
"bash /tests/test.sh",
|
||||
timeout=self.config.test_timeout,
|
||||
)
|
||||
|
||||
exit_code = test_result.get("exit_code", -1)
|
||||
output = test_result.get("output", "")
|
||||
|
||||
# Download the verifier output directory locally, then read reward.txt
|
||||
# with native Python I/O. This avoids issues with file_read on the
|
||||
# Modal VM and matches Harbor's verification pattern.
|
||||
reward = 0.0
|
||||
local_verifier_dir = Path(tempfile.mkdtemp(prefix=f"tb2-verifier-{task_name}-"))
|
||||
try:
|
||||
ctx.download_dir("/logs/verifier", str(local_verifier_dir))
|
||||
|
||||
reward_file = local_verifier_dir / "reward.txt"
|
||||
if reward_file.exists() and reward_file.stat().st_size > 0:
|
||||
content = reward_file.read_text().strip()
|
||||
if content == "1":
|
||||
reward = 1.0
|
||||
elif content == "0":
|
||||
reward = 0.0
|
||||
else:
|
||||
# Unexpected content -- try parsing as float
|
||||
try:
|
||||
reward = float(content)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Task %s: reward.txt content unexpected (%r), "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, content, exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
else:
|
||||
# reward.txt not written -- fall back to exit code
|
||||
logger.warning(
|
||||
"Task %s: reward.txt not found after download, "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Task %s: failed to download verifier dir: %s, "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, e, exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
finally:
|
||||
shutil.rmtree(local_verifier_dir, ignore_errors=True)
|
||||
|
||||
# Log test output for debugging failures
|
||||
if reward == 0.0:
|
||||
output_preview = output[-500:] if output else "(no output)"
|
||||
logger.info(
|
||||
"Task %s: FAIL (exit_code=%d)\n%s",
|
||||
task_name, exit_code, output_preview,
|
||||
)
|
||||
|
||||
return reward
|
||||
|
||||
# =========================================================================
|
||||
# Evaluate -- main entry point for the eval subcommand
|
||||
# =========================================================================
|
||||
|
||||
async def _eval_with_timeout(self, item: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Wrap rollout_and_score_eval with a per-task wall-clock timeout.
|
||||
|
||||
If the task exceeds task_timeout seconds, it's automatically scored
|
||||
as FAIL. This prevents any single task from hanging indefinitely.
|
||||
"""
|
||||
task_name = item.get("task_name", "unknown")
|
||||
category = item.get("category", "unknown")
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self.rollout_and_score_eval(item),
|
||||
timeout=self.config.task_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
from tqdm import tqdm
|
||||
elapsed = self.config.task_timeout
|
||||
tqdm.write(f" [TIMEOUT] {task_name} (exceeded {elapsed}s wall-clock limit)")
|
||||
logger.error("Task %s: wall-clock timeout after %ds", task_name, elapsed)
|
||||
out = {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
"error": f"timeout ({elapsed}s)",
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
async def evaluate(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
Run Terminal-Bench 2.0 evaluation over all tasks.
|
||||
|
||||
This is the main entry point when invoked via:
|
||||
python environments/terminalbench2_env.py evaluate
|
||||
|
||||
Runs all tasks through rollout_and_score_eval() via asyncio.gather()
|
||||
(same pattern as GPQA and other Atropos eval envs). Each task is
|
||||
wrapped with a wall-clock timeout so hung tasks auto-fail.
|
||||
|
||||
Suppresses noisy Modal/terminal output (HERMES_QUIET) so the tqdm
|
||||
bar stays visible.
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Route all logging through tqdm.write() so the progress bar stays
|
||||
# pinned at the bottom while log lines scroll above it.
|
||||
from tqdm import tqdm
|
||||
|
||||
class _TqdmHandler(logging.Handler):
|
||||
def emit(self, record):
|
||||
try:
|
||||
tqdm.write(self.format(record))
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
handler = _TqdmHandler()
|
||||
handler.setFormatter(logging.Formatter(
|
||||
"%(asctime)s [%(name)s] %(levelname)s: %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
))
|
||||
root = logging.getLogger()
|
||||
root.handlers = [handler] # Replace any existing handlers
|
||||
root.setLevel(logging.INFO)
|
||||
|
||||
# Silence noisy third-party loggers that flood the output
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING) # Every HTTP request
|
||||
logging.getLogger("openai").setLevel(logging.WARNING) # OpenAI client retries
|
||||
logging.getLogger("rex-deploy").setLevel(logging.WARNING) # Swerex deployment
|
||||
logging.getLogger("rex_image_builder").setLevel(logging.WARNING) # Image builds
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("Starting Terminal-Bench 2.0 Evaluation")
|
||||
print(f"{'='*60}")
|
||||
print(f" Dataset: {self.config.dataset_name}")
|
||||
print(f" Total tasks: {len(self.all_eval_items)}")
|
||||
print(f" Max agent turns: {self.config.max_agent_turns}")
|
||||
print(f" Task timeout: {self.config.task_timeout}s")
|
||||
print(f" Terminal backend: {self.config.terminal_backend}")
|
||||
print(f" Tool thread pool: {self.config.tool_pool_size}")
|
||||
print(f" Terminal timeout: {self.config.terminal_timeout}s/cmd")
|
||||
print(f" Terminal lifetime: {self.config.terminal_lifetime}s (auto: task_timeout + 120)")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Fire all tasks with wall-clock timeout, track live accuracy on the bar
|
||||
total_tasks = len(self.all_eval_items)
|
||||
eval_tasks = [
|
||||
asyncio.ensure_future(self._eval_with_timeout(item))
|
||||
for item in self.all_eval_items
|
||||
]
|
||||
|
||||
results = []
|
||||
passed_count = 0
|
||||
pbar = tqdm(total=total_tasks, desc="Evaluating TB2", dynamic_ncols=True)
|
||||
try:
|
||||
for coro in asyncio.as_completed(eval_tasks):
|
||||
result = await coro
|
||||
results.append(result)
|
||||
if result and result.get("passed"):
|
||||
passed_count += 1
|
||||
done = len(results)
|
||||
pct = (passed_count / done * 100) if done else 0
|
||||
pbar.set_postfix_str(f"pass={passed_count}/{done} ({pct:.1f}%)")
|
||||
pbar.update(1)
|
||||
except (KeyboardInterrupt, asyncio.CancelledError):
|
||||
pbar.close()
|
||||
print(f"\n\nInterrupted! Cleaning up {len(eval_tasks)} tasks...")
|
||||
# Cancel all pending tasks
|
||||
for task in eval_tasks:
|
||||
task.cancel()
|
||||
# Let cancellations propagate (finally blocks run cleanup_vm)
|
||||
await asyncio.gather(*eval_tasks, return_exceptions=True)
|
||||
# Belt-and-suspenders: clean up any remaining sandboxes
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
cleanup_all_environments()
|
||||
print("All sandboxes cleaned up.")
|
||||
return
|
||||
finally:
|
||||
pbar.close()
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
# Filter out None results (shouldn't happen, but be safe)
|
||||
valid_results = [r for r in results if r is not None]
|
||||
|
||||
if not valid_results:
|
||||
print("Warning: No valid evaluation results obtained")
|
||||
return
|
||||
|
||||
# ---- Compute metrics ----
|
||||
total = len(valid_results)
|
||||
passed = sum(1 for r in valid_results if r.get("passed"))
|
||||
overall_pass_rate = passed / total if total > 0 else 0.0
|
||||
|
||||
# Per-category breakdown
|
||||
cat_results: Dict[str, List[Dict]] = defaultdict(list)
|
||||
for r in valid_results:
|
||||
cat_results[r.get("category", "unknown")].append(r)
|
||||
|
||||
# Build metrics dict
|
||||
eval_metrics = {
|
||||
"eval/pass_rate": overall_pass_rate,
|
||||
"eval/total_tasks": total,
|
||||
"eval/passed_tasks": passed,
|
||||
"eval/evaluation_time_seconds": end_time - start_time,
|
||||
}
|
||||
|
||||
# Per-category metrics
|
||||
for category, cat_items in sorted(cat_results.items()):
|
||||
cat_passed = sum(1 for r in cat_items if r.get("passed"))
|
||||
cat_total = len(cat_items)
|
||||
cat_pass_rate = cat_passed / cat_total if cat_total > 0 else 0.0
|
||||
cat_key = category.replace(" ", "_").replace("-", "_").lower()
|
||||
eval_metrics[f"eval/pass_rate_{cat_key}"] = cat_pass_rate
|
||||
|
||||
# Store metrics for wandb_log
|
||||
self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]
|
||||
|
||||
# ---- Print summary ----
|
||||
print(f"\n{'='*60}")
|
||||
print("Terminal-Bench 2.0 Evaluation Results")
|
||||
print(f"{'='*60}")
|
||||
print(f"Overall Pass Rate: {overall_pass_rate:.4f} ({passed}/{total})")
|
||||
print(f"Evaluation Time: {end_time - start_time:.1f} seconds")
|
||||
|
||||
print("\nCategory Breakdown:")
|
||||
for category, cat_items in sorted(cat_results.items()):
|
||||
cat_passed = sum(1 for r in cat_items if r.get("passed"))
|
||||
cat_total = len(cat_items)
|
||||
cat_rate = cat_passed / cat_total if cat_total > 0 else 0.0
|
||||
print(f" {category}: {cat_rate:.1%} ({cat_passed}/{cat_total})")
|
||||
|
||||
# Print individual task results
|
||||
print("\nTask Results:")
|
||||
for r in sorted(valid_results, key=lambda x: x.get("task_name", "")):
|
||||
status = "PASS" if r.get("passed") else "FAIL"
|
||||
turns = r.get("turns_used", "?")
|
||||
error = r.get("error", "")
|
||||
extra = f" (error: {error})" if error else ""
|
||||
print(f" [{status}] {r['task_name']} (turns={turns}){extra}")
|
||||
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Build sample records for evaluate_log (includes full conversations)
|
||||
samples = [
|
||||
{
|
||||
"task_name": r.get("task_name"),
|
||||
"category": r.get("category"),
|
||||
"passed": r.get("passed"),
|
||||
"reward": r.get("reward"),
|
||||
"turns_used": r.get("turns_used"),
|
||||
"error": r.get("error"),
|
||||
"messages": r.get("messages"),
|
||||
}
|
||||
for r in valid_results
|
||||
]
|
||||
|
||||
# Log evaluation results
|
||||
try:
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
generation_parameters={
|
||||
"temperature": self.config.agent_temperature,
|
||||
"max_tokens": self.config.max_token_length,
|
||||
"max_agent_turns": self.config.max_agent_turns,
|
||||
"terminal_backend": self.config.terminal_backend,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error logging evaluation results: {e}")
|
||||
|
||||
# Close streaming file
|
||||
if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
|
||||
self._streaming_file.close()
|
||||
print(f" Live results saved to: {self._streaming_path}")
|
||||
|
||||
# Kill all remaining sandboxes. Timed-out tasks leave orphaned thread
|
||||
# pool workers still executing commands -- cleanup_all stops them.
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
print("\nCleaning up all sandboxes...")
|
||||
cleanup_all_environments()
|
||||
|
||||
# Shut down the tool thread pool so orphaned workers from timed-out
|
||||
# tasks are killed immediately instead of retrying against dead
|
||||
# sandboxes and spamming the console with TimeoutError warnings.
|
||||
from environments.agent_loop import _tool_executor
|
||||
_tool_executor.shutdown(wait=False, cancel_futures=True)
|
||||
print("Done.")
|
||||
|
||||
# =========================================================================
|
||||
# Wandb logging
|
||||
# =========================================================================
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log TB2-specific metrics to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
# Add stored eval metrics
|
||||
for metric_name, metric_value in self.eval_metrics:
|
||||
wandb_metrics[metric_name] = metric_value
|
||||
self.eval_metrics = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TerminalBench2EvalEnv.cli()
|
||||
@@ -4,8 +4,7 @@
|
||||
# Uses terminal + file + web toolsets.
|
||||
#
|
||||
# Usage:
|
||||
# python environments/hermes_swe_env/hermes_swe_env.py serve \
|
||||
# --config environments/hermes_swe_env/default.yaml
|
||||
# python environments/hermes_swe_env.py serve --config environments/configs/swe_default.yaml
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file", "web"]
|
||||
+3
-2
@@ -6,8 +6,9 @@
|
||||
#
|
||||
# Usage:
|
||||
# run-api
|
||||
# python environments/terminal_test_env/terminal_test_env.py serve \
|
||||
# --config environments/terminal_test_env/default.yaml
|
||||
# python environments/terminal_test_env.py serve
|
||||
# # Or with config file:
|
||||
# python environments/terminal_test_env.py serve --config environments/configs/terminal_test_default.yaml
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file"]
|
||||
@@ -0,0 +1,350 @@
|
||||
"""
|
||||
GSM8kAgentEnv -- Math Reasoning with Tool Use (Python REPL)
|
||||
|
||||
An agentic RL environment where models solve GSM8k math problems using
|
||||
a Python interpreter tool. Uses proper OpenAI-spec tool calling via
|
||||
HermesAgentBaseEnv (not ICL).
|
||||
|
||||
The model:
|
||||
1. Receives a math problem
|
||||
2. Can call the `terminal` tool to run Python code (`python3 -c "..."`)
|
||||
3. Provides a final answer in \\boxed{} format
|
||||
4. Gets reward: 1.0 if correct, 0.0 if wrong
|
||||
|
||||
Usage:
|
||||
# Phase 1 (OpenRouter, no training):
|
||||
python environments/gsm8k_agent_env.py process \\
|
||||
--env.data_path_to_save_groups gsm8k_agent_output.jsonl
|
||||
|
||||
# Phase 2 (VLLM + Tinker training):
|
||||
run-api
|
||||
python launch_training.py --config configs/gsm8k_agent.yaml
|
||||
python environments/gsm8k_agent_env.py serve
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Math verification helpers
|
||||
# =============================================================================
|
||||
|
||||
def _verify_math_answer(model_response: str, gold_answer: str) -> bool:
|
||||
"""
|
||||
Verify if the model's response contains the correct answer.
|
||||
Uses math_verify for robust LaTeX comparison, falls back to string matching.
|
||||
"""
|
||||
try:
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
|
||||
gold_parsed = parse(
|
||||
f"\\boxed{{{gold_answer}}}",
|
||||
extraction_mode="first_match",
|
||||
extraction_config=[LatexExtractionConfig()],
|
||||
)
|
||||
|
||||
# Strip <think> blocks if present
|
||||
answer_text = model_response
|
||||
if "</think>" in answer_text:
|
||||
answer_text = answer_text.split("</think>")[-1]
|
||||
|
||||
answer_parsed = parse(
|
||||
answer_text,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
|
||||
return bool(verify(answer_parsed, gold_parsed))
|
||||
|
||||
except ImportError:
|
||||
# Fallback: simple string matching for \\boxed{answer}
|
||||
import re
|
||||
pattern = r'\\boxed\{([^}]+)\}'
|
||||
matches = re.findall(pattern, model_response)
|
||||
if matches:
|
||||
model_answer = matches[-1].strip().replace(",", "")
|
||||
gold_clean = gold_answer.strip().replace(",", "")
|
||||
return model_answer == gold_clean
|
||||
return False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Environment Config
|
||||
# =============================================================================
|
||||
|
||||
class GSM8kAgentEnvConfig(HermesAgentEnvConfig):
|
||||
"""Config with defaults for GSM8k agent environment."""
|
||||
pass
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Environment
|
||||
# =============================================================================
|
||||
|
||||
class GSM8kAgentEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
GSM8k math environment with Python REPL tool calling.
|
||||
|
||||
Models solve grade-school math problems by reasoning step by step
|
||||
and using Python (via the terminal tool) for calculations.
|
||||
|
||||
Exercises the full agentic RL training loop:
|
||||
- Model receives math problem
|
||||
- Makes tool calls to compute (python3 -c "...")
|
||||
- Provides final answer in \\boxed{}
|
||||
- Reward: binary (1.0 correct, 0.0 wrong)
|
||||
"""
|
||||
|
||||
name = "gsm8k-agent"
|
||||
env_config_cls = GSM8kAgentEnvConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[GSM8kAgentEnvConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Default config using terminal tool.
|
||||
|
||||
Reads from environment variables (set in .env):
|
||||
ATROPOS_SERVER_BASE_URL - Inference server URL
|
||||
ATROPOS_SERVER_MODEL - Model name on the server
|
||||
ATROPOS_TOKENIZER_NAME - HuggingFace tokenizer name
|
||||
ATROPOS_SERVER_API_KEY - API key for the server
|
||||
"""
|
||||
# Resolve inference server settings from env
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "https://openrouter.ai/api/v1"
|
||||
)
|
||||
if not base_url.rstrip("/").endswith("/v1"):
|
||||
base_url = base_url.rstrip("/") + "/v1"
|
||||
|
||||
model = (
|
||||
os.getenv("ATROPOS_SERVER_MODEL")
|
||||
or os.getenv("LLM_MODEL")
|
||||
or "Hermes-4.3-36B"
|
||||
)
|
||||
|
||||
api_key = (
|
||||
os.getenv("ATROPOS_SERVER_API_KEY")
|
||||
or os.getenv("NOUS_API_KEY")
|
||||
or os.getenv("OPENROUTER_API_KEY")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or ""
|
||||
)
|
||||
|
||||
tokenizer = (
|
||||
os.getenv("ATROPOS_TOKENIZER_NAME")
|
||||
or os.getenv("ATROPOS_TOKENIZER")
|
||||
or "NousResearch/Hermes-4.3-36B"
|
||||
)
|
||||
|
||||
env_config = GSM8kAgentEnvConfig(
|
||||
# Terminal + file toolsets (same as terminal_test_env.py)
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
# Agent settings
|
||||
max_agent_turns=5, # Math problems don't need many turns
|
||||
max_token_length=2048, # Room for reasoning + code
|
||||
agent_temperature=1.0,
|
||||
system_prompt=(
|
||||
"You are a helpful math assistant. You have access to a terminal "
|
||||
"where you can run Python code to help solve problems.\n\n"
|
||||
"When you need to calculate something, use the terminal tool with "
|
||||
"a command like: python3 -c \"print(2 + 2)\"\n\n"
|
||||
"When you have the final answer, write it inside \\boxed{} like: \\boxed{42}\n\n"
|
||||
"Work step by step. Use Python to verify your reasoning."
|
||||
),
|
||||
# Terminal backend (local for testing, modal for production)
|
||||
terminal_backend=os.getenv("TERMINAL_ENV", "local"),
|
||||
# Parser -- hermes format for Hermes models
|
||||
tool_call_parser="hermes",
|
||||
# Atropos settings
|
||||
group_size=4,
|
||||
tokenizer_name=tokenizer,
|
||||
steps_per_eval=5,
|
||||
total_steps=10,
|
||||
use_wandb=bool(os.getenv("WANDB_API_KEY")),
|
||||
wandb_name="gsm8k-agent",
|
||||
ensure_scores_are_not_same=False,
|
||||
# No external dataset (we load GSM8k ourselves)
|
||||
dataset_name=None,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url=base_url,
|
||||
model_name=model,
|
||||
server_type="openai",
|
||||
api_key=api_key,
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self):
|
||||
"""Load GSM8k dataset."""
|
||||
from datasets import load_dataset
|
||||
|
||||
self.train = load_dataset("gsm8k", "main", split="train").shuffle(seed=42)
|
||||
test_data = load_dataset("gsm8k", "main", split="test").shuffle(seed=42)
|
||||
self.test = [
|
||||
{
|
||||
"question": item["question"],
|
||||
"gold_answer": item["answer"].split("#")[-1].strip().replace(",", ""),
|
||||
}
|
||||
for item in test_data
|
||||
]
|
||||
self.iter = 0
|
||||
self.reward_buffer: List[float] = []
|
||||
self.tool_use_buffer: List[int] = []
|
||||
print(f"[GSM8kAgentEnv] Loaded {len(self.train)} train, {len(self.test)} test examples")
|
||||
|
||||
async def get_next_item(self) -> Dict[str, str]:
|
||||
"""Cycle through training problems."""
|
||||
item = self.train[self.iter % len(self.train)]
|
||||
self.iter += 1
|
||||
return {
|
||||
"question": item["question"],
|
||||
"gold_answer": item["answer"].split("#")[-1].strip().replace(",", ""),
|
||||
}
|
||||
|
||||
def format_prompt(self, item: Dict[str, str]) -> str:
|
||||
"""Format the math problem as a user message."""
|
||||
return item["question"]
|
||||
|
||||
async def compute_reward(
|
||||
self, item: Dict[str, str], result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Score: verify the model's \\boxed{} answer against the gold answer.
|
||||
|
||||
The agent has full access to terminal via ctx, but for GSM8k we just
|
||||
check the final answer from the conversation.
|
||||
"""
|
||||
# Get the last assistant message content
|
||||
final_text = ""
|
||||
for msg in reversed(result.messages):
|
||||
if msg.get("role") == "assistant" and msg.get("content"):
|
||||
final_text = msg["content"]
|
||||
break
|
||||
|
||||
correct = _verify_math_answer(final_text, item["gold_answer"])
|
||||
reward = 1.0 if correct else 0.0
|
||||
|
||||
self.reward_buffer.append(reward)
|
||||
# Count tool calls in this trajectory
|
||||
tool_call_count = sum(
|
||||
len(msg.get("tool_calls", []))
|
||||
for msg in result.messages
|
||||
if msg.get("role") == "assistant"
|
||||
)
|
||||
self.tool_use_buffer.append(tool_call_count)
|
||||
|
||||
return reward
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""Evaluate on a subset of the test set (greedy, no tools for speed)."""
|
||||
start_time = time.time()
|
||||
correct = 0
|
||||
total = 0
|
||||
samples = []
|
||||
|
||||
eval_subset = self.test[:30] # Small subset for quick eval
|
||||
|
||||
for item in eval_subset:
|
||||
try:
|
||||
completion = await self.server.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": self.config.system_prompt or ""},
|
||||
{"role": "user", "content": item["question"]},
|
||||
],
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length,
|
||||
temperature=0.0,
|
||||
split="eval",
|
||||
)
|
||||
|
||||
response = completion.choices[0].message.content or ""
|
||||
is_correct = _verify_math_answer(response, item["gold_answer"])
|
||||
|
||||
if is_correct:
|
||||
correct += 1
|
||||
total += 1
|
||||
|
||||
samples.append({
|
||||
"question": item["question"],
|
||||
"gold_answer": item["gold_answer"],
|
||||
"response": response[:500],
|
||||
"correct": is_correct,
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Eval failed: %s", e)
|
||||
total += 1
|
||||
|
||||
percent_correct = correct / total if total > 0 else 0
|
||||
end_time = time.time()
|
||||
|
||||
await self.evaluate_log(
|
||||
metrics={"eval/percent_correct": percent_correct, "eval/total": total},
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log training metrics."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self.reward_buffer:
|
||||
wandb_metrics["train/percent_correct"] = sum(self.reward_buffer) / len(self.reward_buffer)
|
||||
wandb_metrics["train/total_rollouts"] = len(self.reward_buffer)
|
||||
self.reward_buffer = []
|
||||
|
||||
if self.tool_use_buffer:
|
||||
wandb_metrics["train/avg_tool_calls"] = sum(self.tool_use_buffer) / len(self.tool_use_buffer)
|
||||
wandb_metrics["train/tool_use_rate"] = sum(1 for t in self.tool_use_buffer if t > 0) / len(self.tool_use_buffer)
|
||||
self.tool_use_buffer = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
GSM8kAgentEnv.cli()
|
||||
+435
-103
@@ -45,7 +45,7 @@ if _env_path.exists():
|
||||
# This patches SwerexModalEnvironment to use a background thread instead of
|
||||
# asyncio.run(), which would deadlock inside Atropos. Safe for normal CLI too.
|
||||
from environments.patches import apply_patches
|
||||
apply_patches()
|
||||
# apply_patches() # DISABLED: sglang patch breaks native vLLM /generate
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
BaseEnv,
|
||||
@@ -64,7 +64,7 @@ from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
# Import hermes-agent toolset infrastructure
|
||||
from model_tools import get_tool_definitions
|
||||
from model_tools import get_tool_definitions, handle_function_call
|
||||
from toolset_distributions import sample_toolsets_from_distribution
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -117,18 +117,6 @@ class HermesAgentEnvConfig(BaseEnvConfig):
|
||||
description="Terminal backend: 'local', 'docker', 'modal', 'ssh', 'singularity'. "
|
||||
"Modal recommended for production RL (cloud isolation per rollout).",
|
||||
)
|
||||
terminal_timeout: int = Field(
|
||||
default=120,
|
||||
description="Per-command timeout in seconds for terminal tool calls. "
|
||||
"Commands exceeding this are killed. Increase for tasks with long-running "
|
||||
"commands (compilation, pip install, etc.).",
|
||||
)
|
||||
terminal_lifetime: int = Field(
|
||||
default=3600,
|
||||
description="Sandbox inactivity lifetime in seconds. The cleanup thread kills "
|
||||
"sandboxes that have been idle longer than this. Must be longer than "
|
||||
"the longest gap between tool calls (e.g., waiting for LLM response).",
|
||||
)
|
||||
|
||||
# --- Dataset ---
|
||||
dataset_name: Optional[str] = Field(
|
||||
@@ -144,14 +132,6 @@ class HermesAgentEnvConfig(BaseEnvConfig):
|
||||
description="Which field in the dataset contains the prompt.",
|
||||
)
|
||||
|
||||
# --- Thread pool ---
|
||||
tool_pool_size: int = Field(
|
||||
default=128,
|
||||
description="Thread pool size for tool execution. Each concurrent task needs a "
|
||||
"thread for tool calls. Must be large enough for parallel evaluation. "
|
||||
"Too small = thread pool starvation.",
|
||||
)
|
||||
|
||||
# --- Phase 2: Tool call parsing ---
|
||||
tool_call_parser: str = Field(
|
||||
default="hermes",
|
||||
@@ -160,22 +140,48 @@ class HermesAgentEnvConfig(BaseEnvConfig):
|
||||
"Options: hermes, mistral, llama3_json, qwen, deepseek_v3, etc.",
|
||||
)
|
||||
|
||||
# --- Provider-specific parameters ---
|
||||
# Passed as extra_body to the OpenAI client's chat.completions.create() call.
|
||||
# Useful for OpenRouter provider preferences, transforms, route settings, etc.
|
||||
# Example YAML:
|
||||
# extra_body:
|
||||
# provider:
|
||||
# ignore: ["DeepInfra", "Fireworks"]
|
||||
# order: ["Together"]
|
||||
# transforms: ["middle-out"]
|
||||
extra_body: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Extra body parameters passed to the OpenAI client's "
|
||||
"chat.completions.create(). Used for OpenRouter provider preferences, "
|
||||
"transforms, and other provider-specific settings.",
|
||||
# --- Sandbox pool mode (optional, for scaled environments) ---
|
||||
tool_pool_mode: str = Field(
|
||||
default="default",
|
||||
description="Tool execution mode: 'default' (terminal tool per task_id), "
|
||||
"'nomad' (slot pool via Nomad/Docker/Singularity), or 'modal' (Modal sandbox pool).",
|
||||
)
|
||||
|
||||
# Sandbox pool: shared settings
|
||||
allow_network: bool = Field(default=True, description="Whether sandbox bash commands may access the network.")
|
||||
require_sandbox: bool = Field(default=False, description="Fail closed if bubblewrap is unavailable.")
|
||||
purge_job_on_start: bool = Field(default=False, description="Purge existing sandbox job on startup.")
|
||||
purge_job_on_shutdown: bool = Field(default=True, description="Purge sandbox job on shutdown.")
|
||||
acquire_timeout_s: float = Field(default=30.0, description="Slot acquisition timeout (seconds).")
|
||||
|
||||
# Sandbox pool: Nomad settings
|
||||
nomad_address: str = Field(default="http://localhost:4646", description="Nomad API address.")
|
||||
sandbox_job_id: str = Field(default="atropos-sandbox", description="Nomad job id for sandbox containers.")
|
||||
sandbox_image: str = Field(default="atropos-sandbox:local", description="Docker image for sandbox containers.")
|
||||
slots_per_container: int = Field(default=10, description="Nomad: slots per container.")
|
||||
min_containers: int = Field(default=1, description="Nomad: minimum containers.")
|
||||
max_containers: int = Field(default=10, description="Nomad: maximum containers.")
|
||||
privileged: bool = Field(default=False, description="Nomad: run container privileged.")
|
||||
driver: str = Field(default="docker", description="Nomad task driver: 'docker' or 'singularity'.")
|
||||
singularity_image: Optional[str] = Field(default=None, description="Path to .sif file for Singularity driver.")
|
||||
|
||||
# Sandbox pool: Modal settings
|
||||
modal_app_name: str = Field(default="atropos-sandbox", description="Modal app name prefix.")
|
||||
modal_image: str = Field(default="python:3.11", description="Modal: container image.")
|
||||
modal_gpu: Optional[str] = Field(default=None, description="Modal: GPU type (None, 'T4', 'A10G', 'A100', 'H100').")
|
||||
modal_cpu: float = Field(default=1.0, description="Modal: CPU cores.")
|
||||
modal_memory: int = Field(default=2048, description="Modal: memory in MB.")
|
||||
modal_slots_per_sandbox: int = Field(default=10, description="Modal: slots per sandbox.")
|
||||
modal_min_sandboxes: int = Field(default=1, description="Modal: minimum sandboxes.")
|
||||
modal_max_sandboxes: int = Field(default=5, description="Modal: maximum sandboxes.")
|
||||
modal_idle_timeout: int = Field(default=120, description="Modal: idle timeout (seconds).")
|
||||
modal_max_lifetime: int = Field(default=3600, description="Modal: max sandbox lifetime (seconds).")
|
||||
modal_acquire_timeout: float = Field(default=60.0, description="Modal: slot acquisition timeout (seconds).")
|
||||
modal_execution_timeout: float = Field(default=30.0, description="Modal: command execution timeout (seconds).")
|
||||
modal_secrets: str = Field(default="", description="Modal: comma-separated Modal Secret names.")
|
||||
modal_env_vars: str = Field(default="", description="Modal: semicolon-separated KEY=VALUE pairs.")
|
||||
modal_workspace_base: str = Field(default="/data", description="Modal: workspace base directory.")
|
||||
|
||||
|
||||
class HermesAgentBaseEnv(BaseEnv):
|
||||
"""
|
||||
@@ -211,23 +217,10 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
|
||||
# Set terminal environment variables so hermes tools pick them up.
|
||||
# These can all be overridden per-environment via config fields instead
|
||||
# of requiring users to set shell env vars.
|
||||
# Set terminal backend environment variable so hermes tools pick it up
|
||||
if config.terminal_backend:
|
||||
os.environ["TERMINAL_ENV"] = config.terminal_backend
|
||||
os.environ["TERMINAL_TIMEOUT"] = str(config.terminal_timeout)
|
||||
os.environ["TERMINAL_LIFETIME_SECONDS"] = str(config.terminal_lifetime)
|
||||
print(
|
||||
f"🖥️ Terminal: backend={config.terminal_backend}, "
|
||||
f"timeout={config.terminal_timeout}s, lifetime={config.terminal_lifetime}s"
|
||||
)
|
||||
|
||||
# Resize the agent loop's thread pool for tool execution.
|
||||
# This must be large enough for the number of concurrent tasks
|
||||
# (e.g., 89 parallel TB2 eval tasks each need a thread for tool calls).
|
||||
from environments.agent_loop import resize_tool_pool
|
||||
resize_tool_pool(config.tool_pool_size)
|
||||
print(f"🖥️ Terminal backend: {config.terminal_backend}")
|
||||
|
||||
# Current group's resolved tools (set in collect_trajectories)
|
||||
self._current_group_tools: Optional[Tuple[List[Dict], Set[str]]] = None
|
||||
@@ -235,6 +228,9 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
# Tool error tracking for wandb logging
|
||||
self._tool_error_buffer: List[Dict[str, Any]] = []
|
||||
|
||||
# Sandbox pool backend (only used when tool_pool_mode != "default")
|
||||
self._sandbox_backend = None
|
||||
|
||||
# =========================================================================
|
||||
# Toolset resolution (per-group)
|
||||
# =========================================================================
|
||||
@@ -274,6 +270,12 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
# =========================================================================
|
||||
|
||||
def _use_managed_server(self) -> bool:
|
||||
import sys
|
||||
result = self._use_managed_server_inner()
|
||||
print(f"HERMES_DEBUG _use_managed_server={result}, servers={len(self.server.servers) if hasattr(self.server, 'servers') else 'N/A'}, type={type(self.server.servers[0]).__name__ if hasattr(self.server, 'servers') and self.server.servers else 'N/A'}", file=sys.stderr, flush=True)
|
||||
return result
|
||||
|
||||
def _use_managed_server_inner(self) -> bool:
|
||||
"""
|
||||
Determine if we should use ManagedServer (Phase 2) or direct server (Phase 1).
|
||||
|
||||
@@ -291,6 +293,154 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
from atroposlib.envs.server_handling.openai_server import OpenAIServer
|
||||
return not isinstance(server, OpenAIServer)
|
||||
|
||||
# =========================================================================
|
||||
# Sandbox pool backend (tool_pool_mode != "default")
|
||||
# =========================================================================
|
||||
|
||||
async def _start_sandbox_backend(self) -> None:
|
||||
"""
|
||||
Configure the slot pool backend if tool_pool_mode is not 'default'.
|
||||
|
||||
Sets TERMINAL_ENV=slot_pool and configures env vars so that ALL hermes
|
||||
tools (terminal, file, etc.) automatically route through the sandbox
|
||||
pool via _SlotPoolEnvironment in terminal_tool.py.
|
||||
"""
|
||||
if self.config.tool_pool_mode == "default":
|
||||
return
|
||||
|
||||
mode = self.config.tool_pool_mode
|
||||
logger.info("Configuring slot pool backend (mode=%s)", mode)
|
||||
|
||||
# Set TERMINAL_ENV=slot_pool so terminal_tool.py uses _SlotPoolEnvironment
|
||||
os.environ["TERMINAL_ENV"] = "slot_pool"
|
||||
|
||||
# Set the backend type (modal or nomad)
|
||||
if mode == "modal":
|
||||
os.environ["TERMINAL_SLOT_BACKEND"] = "modal"
|
||||
# Forward modal config from env config to slot pool env vars
|
||||
os.environ.setdefault("TERMINAL_MODAL_IMAGE", self.config.modal_image)
|
||||
os.environ.setdefault("TERMINAL_MODAL_SLOTS", str(self.config.modal_slots_per_sandbox))
|
||||
os.environ.setdefault("TERMINAL_MODAL_MIN", str(self.config.modal_min_sandboxes))
|
||||
os.environ.setdefault("TERMINAL_MODAL_MAX", str(self.config.modal_max_sandboxes))
|
||||
os.environ.setdefault("TERMINAL_MODAL_IDLE_TIMEOUT", str(self.config.modal_idle_timeout))
|
||||
os.environ.setdefault("TERMINAL_MODAL_MAX_LIFETIME", str(self.config.modal_max_lifetime))
|
||||
os.environ.setdefault("TERMINAL_MODAL_ACQUIRE_TIMEOUT", str(self.config.modal_acquire_timeout))
|
||||
os.environ.setdefault("TERMINAL_MODAL_EXEC_TIMEOUT", str(self.config.modal_execution_timeout))
|
||||
os.environ.setdefault("TERMINAL_MODAL_WORKSPACE", self.config.modal_workspace_base)
|
||||
if self.config.modal_gpu:
|
||||
os.environ.setdefault("TERMINAL_MODAL_GPU", self.config.modal_gpu)
|
||||
elif mode == "nomad":
|
||||
os.environ["TERMINAL_SLOT_BACKEND"] = "nomad"
|
||||
os.environ.setdefault("TERMINAL_NOMAD_ADDRESS", self.config.nomad_address)
|
||||
os.environ.setdefault("TERMINAL_NOMAD_IMAGE", self.config.sandbox_image)
|
||||
os.environ.setdefault("TERMINAL_NOMAD_DRIVER", self.config.driver)
|
||||
os.environ.setdefault("TERMINAL_NOMAD_SLOTS", str(self.config.slots_per_container))
|
||||
os.environ.setdefault("TERMINAL_NOMAD_MIN", str(self.config.min_containers))
|
||||
os.environ.setdefault("TERMINAL_NOMAD_MAX", str(self.config.max_containers))
|
||||
|
||||
# Eagerly start the _SlotPoolManager so the backend is ready
|
||||
# before any trajectories try to use it
|
||||
from tools.terminal_tool import _SlotPoolManager
|
||||
_SlotPoolManager.get_instance() # Triggers _start() which creates sandboxes
|
||||
|
||||
self._sandbox_backend = True # Flag that sandbox mode is active
|
||||
print(f"🔧 Slot pool started: TERMINAL_ENV=slot_pool, backend={mode}")
|
||||
|
||||
async def _stop_sandbox_backend(self) -> None:
|
||||
"""Stop the slot pool backend."""
|
||||
if self._sandbox_backend:
|
||||
logger.info("Stopping slot pool backend")
|
||||
try:
|
||||
from tools.terminal_tool import _SlotPoolManager
|
||||
_SlotPoolManager.reset_instance()
|
||||
except Exception as e:
|
||||
logger.warning("Slot pool shutdown: %s", e)
|
||||
self._sandbox_backend = None
|
||||
|
||||
# =========================================================================
|
||||
# Optional hooks for sandbox environments
|
||||
# =========================================================================
|
||||
|
||||
async def setup_trajectory_workspace(
|
||||
self,
|
||||
item: Item,
|
||||
*,
|
||||
trajectory_id: str,
|
||||
exec_tool,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Optional hook: prepare the sandbox workspace before the agent starts.
|
||||
|
||||
Override in subclasses for environments that need workspace setup
|
||||
(e.g., git clone, worktree creation, dependency installation).
|
||||
|
||||
Args:
|
||||
item: The dataset item being rolled out
|
||||
trajectory_id: Unique ID for this trajectory
|
||||
exec_tool: Callable to execute tool calls in the sandbox
|
||||
|
||||
Returns:
|
||||
Dict of workspace metadata (passed to verify_and_score_trajectory)
|
||||
"""
|
||||
return {}
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
result: AgentResult,
|
||||
*,
|
||||
trajectory_id: str,
|
||||
exec_tool,
|
||||
workspace_meta: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[float, Dict[str, Any]]:
|
||||
"""
|
||||
Optional hook: run in-sandbox verification before scoring.
|
||||
|
||||
Override in subclasses for environments that need to verify results
|
||||
inside the sandbox (e.g., run pytest, check file contents).
|
||||
|
||||
Default: calls compute_reward() with ToolContext.
|
||||
|
||||
Args:
|
||||
item: The dataset item
|
||||
result: The agent's rollout result
|
||||
trajectory_id: Unique ID for this trajectory
|
||||
exec_tool: Callable to execute tool calls in the sandbox
|
||||
workspace_meta: Metadata from setup_trajectory_workspace
|
||||
|
||||
Returns:
|
||||
Tuple of (reward, metadata_dict)
|
||||
"""
|
||||
ctx = ToolContext(trajectory_id)
|
||||
try:
|
||||
reward = await self.compute_reward(item, result, ctx)
|
||||
except Exception as e:
|
||||
logger.error("compute_reward failed: %s", e)
|
||||
reward = 0.0
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
return reward, {}
|
||||
|
||||
# =========================================================================
|
||||
# Lifecycle hooks for env_manager/process_manager cleanup
|
||||
# =========================================================================
|
||||
|
||||
async def env_manager(self):
|
||||
"""Start sandbox backend, run env, then clean up."""
|
||||
await self._start_sandbox_backend()
|
||||
try:
|
||||
return await super().env_manager()
|
||||
finally:
|
||||
await self._stop_sandbox_backend()
|
||||
|
||||
async def process_manager(self):
|
||||
"""Start sandbox backend, run process, then clean up."""
|
||||
await self._start_sandbox_backend()
|
||||
try:
|
||||
return await super().process_manager()
|
||||
finally:
|
||||
await self._stop_sandbox_backend()
|
||||
|
||||
# =========================================================================
|
||||
# Core Atropos integration
|
||||
# =========================================================================
|
||||
@@ -434,6 +584,13 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
def _use_sandbox_backend(self) -> bool:
|
||||
"""Check if we should route tool execution through a sandbox backend."""
|
||||
return (
|
||||
self.config.tool_pool_mode != "default"
|
||||
and self._sandbox_backend is not None
|
||||
)
|
||||
|
||||
async def collect_trajectory(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
|
||||
@@ -442,12 +599,19 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
|
||||
This is called group_size times in parallel by collect_trajectories().
|
||||
Each call gets its own task_id for terminal/browser session isolation.
|
||||
|
||||
When tool_pool_mode != "default", routes tool execution through the
|
||||
sandbox backend (Modal, Nomad) with slot-based multiplexing:
|
||||
1. Acquire a slot from the sandbox pool
|
||||
2. Setup workspace via subclass hook (e.g., git clone + worktree)
|
||||
3. Run agent loop with terminal calls routed through sandbox
|
||||
4. Verify and score in-sandbox via subclass hook (e.g., pytest)
|
||||
5. Release the slot
|
||||
"""
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# Get group-level tools (resolved once in collect_trajectories)
|
||||
if self._current_group_tools is None:
|
||||
# Fallback: resolve per-trajectory if called outside collect_trajectories
|
||||
tools, valid_names = self._resolve_tools_for_group()
|
||||
else:
|
||||
tools, valid_names = self._current_group_tools
|
||||
@@ -458,11 +622,194 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
messages.append({"role": "user", "content": self.format_prompt(item)})
|
||||
|
||||
# Run the agent loop
|
||||
result: AgentResult
|
||||
# Dispatch to the appropriate path
|
||||
if self._use_sandbox_backend():
|
||||
return await self._collect_trajectory_sandbox(
|
||||
item, task_id, tools, valid_names, messages
|
||||
)
|
||||
else:
|
||||
return await self._collect_trajectory_local(
|
||||
item, task_id, tools, valid_names, messages
|
||||
)
|
||||
|
||||
async def _collect_trajectory_local(
|
||||
self,
|
||||
item: Item,
|
||||
task_id: str,
|
||||
tools: List[Dict[str, Any]],
|
||||
valid_names: Set[str],
|
||||
messages: List[Dict[str, Any]],
|
||||
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
|
||||
"""
|
||||
Default (local) trajectory collection path.
|
||||
|
||||
Uses hermes-agent's handle_function_call() for tool execution.
|
||||
Reward computed via compute_reward() with ToolContext.
|
||||
"""
|
||||
result = await self._run_agent_loop(
|
||||
task_id, tools, valid_names, messages, tool_handler=None
|
||||
)
|
||||
|
||||
# Skip reward if the agent loop produced no meaningful work
|
||||
only_system_and_user = all(
|
||||
msg.get("role") in ("system", "user") for msg in result.messages
|
||||
)
|
||||
if result.turns_used == 0 or only_system_and_user:
|
||||
logger.warning(
|
||||
"Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
|
||||
result.turns_used, len(result.messages),
|
||||
)
|
||||
reward = 0.0
|
||||
else:
|
||||
ctx = ToolContext(task_id)
|
||||
try:
|
||||
reward = await self.compute_reward(item, result, ctx)
|
||||
except Exception as e:
|
||||
logger.error("compute_reward failed: %s", e)
|
||||
reward = 0.0
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
|
||||
return self._build_scored_item(item, result, reward)
|
||||
|
||||
async def _collect_trajectory_sandbox(
|
||||
self,
|
||||
item: Item,
|
||||
task_id: str,
|
||||
tools: List[Dict[str, Any]],
|
||||
valid_names: Set[str],
|
||||
messages: List[Dict[str, Any]],
|
||||
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
|
||||
"""
|
||||
Sandbox trajectory collection path (Modal, Nomad).
|
||||
|
||||
Uses TERMINAL_ENV=slot_pool so ALL hermes tools (terminal, file, web)
|
||||
automatically route through the sandbox pool via _SlotPoolEnvironment.
|
||||
No per-tool routing needed — the slot pool is the terminal backend.
|
||||
|
||||
Flow:
|
||||
1. Pre-warm terminal env (acquires a slot in the pool)
|
||||
2. Setup workspace via subclass hook (e.g., git clone + worktree)
|
||||
3. Run agent loop with tool_handler=None (all tools use handle_function_call)
|
||||
4. Verify and score in-sandbox via subclass hook (e.g., pytest)
|
||||
5. Release the slot via cleanup_vm()
|
||||
"""
|
||||
from tools.terminal_tool import _SlotPoolManager, cleanup_vm
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class _ExecResult:
|
||||
"""Lightweight result for exec_tool compatibility with env hooks."""
|
||||
success: bool
|
||||
output: str = ""
|
||||
error: str = ""
|
||||
metadata: Dict[str, Any] = None
|
||||
def __post_init__(self):
|
||||
if self.metadata is None:
|
||||
self.metadata = {}
|
||||
|
||||
try:
|
||||
# 1. Pre-warm: trigger terminal env creation → acquires slot
|
||||
logger.info("Pre-warming sandbox slot for task %s", task_id)
|
||||
loop = asyncio.get_event_loop()
|
||||
warmup = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: handle_function_call(
|
||||
"terminal", {"command": "echo slot_ready"}, task_id=task_id
|
||||
),
|
||||
)
|
||||
logger.info("Sandbox slot acquired for task %s", task_id)
|
||||
|
||||
# 2. Create exec_tool for setup/verify hooks
|
||||
# Routes through handle_function_call → terminal_tool → same _SlotPoolEnvironment
|
||||
async def exec_tool(tool_name: str, args: Dict[str, Any], timeout: float = 300) -> _ExecResult:
|
||||
command = args.get("command", "")
|
||||
result_json = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: handle_function_call(
|
||||
"terminal",
|
||||
{"command": command, "timeout": int(timeout)},
|
||||
task_id=task_id,
|
||||
),
|
||||
)
|
||||
try:
|
||||
result_dict = json.loads(result_json)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
result_dict = {"output": str(result_json), "exit_code": 1}
|
||||
returncode = result_dict.get("exit_code", result_dict.get("returncode", 1))
|
||||
output = result_dict.get("output", "")
|
||||
return _ExecResult(
|
||||
success=(returncode == 0),
|
||||
output=output,
|
||||
error=result_dict.get("error", "") if returncode != 0 else "",
|
||||
metadata={"returncode": returncode},
|
||||
)
|
||||
|
||||
# 3. Setup workspace (subclass hook: git clone, worktree, etc.)
|
||||
workspace_meta = await self.setup_trajectory_workspace(
|
||||
item, trajectory_id=task_id, exec_tool=exec_tool
|
||||
)
|
||||
|
||||
# 4. Run agent loop — tool_handler=None means ALL tools go through
|
||||
# handle_function_call() → terminal_tool() → _SlotPoolEnvironment
|
||||
# → same sandbox slot. File tools also route through same env.
|
||||
result = await self._run_agent_loop(
|
||||
task_id, tools, valid_names, messages,
|
||||
tool_handler=None,
|
||||
)
|
||||
|
||||
# 5. Skip verification if no meaningful work
|
||||
only_system_and_user = all(
|
||||
msg.get("role") in ("system", "user") for msg in result.messages
|
||||
)
|
||||
if result.turns_used == 0 or only_system_and_user:
|
||||
logger.warning(
|
||||
"Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
|
||||
result.turns_used, len(result.messages),
|
||||
)
|
||||
reward = 0.0
|
||||
else:
|
||||
# 6. Verify and score in-sandbox (subclass hook: pytest, etc.)
|
||||
reward, score_meta = await self.verify_and_score_trajectory(
|
||||
item, result,
|
||||
trajectory_id=task_id,
|
||||
exec_tool=exec_tool,
|
||||
workspace_meta=workspace_meta,
|
||||
)
|
||||
logger.info("Sandbox reward for task %s: %.2f", task_id, reward)
|
||||
|
||||
return self._build_scored_item(item, result, reward)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Sandbox trajectory failed for task %s: %s", task_id, e, exc_info=True)
|
||||
dummy_result = AgentResult(
|
||||
messages=messages, turns_used=0, finished_naturally=False
|
||||
)
|
||||
return self._build_scored_item(item, dummy_result, 0.0)
|
||||
|
||||
finally:
|
||||
# Release the slot back to the pool
|
||||
try:
|
||||
cleanup_vm(task_id)
|
||||
logger.info("Released sandbox slot for task %s", task_id)
|
||||
except Exception as e:
|
||||
logger.error("Failed to release slot for task %s: %s", task_id, e)
|
||||
|
||||
async def _run_agent_loop(
|
||||
self,
|
||||
task_id: str,
|
||||
tools: List[Dict[str, Any]],
|
||||
valid_names: Set[str],
|
||||
messages: List[Dict[str, Any]],
|
||||
tool_handler=None,
|
||||
) -> AgentResult:
|
||||
"""
|
||||
Run the agent loop in either Phase 1 or Phase 2 mode.
|
||||
|
||||
Shared between local and sandbox paths -- the only difference is
|
||||
the tool_handler parameter (None for local, sandbox callable for sandbox).
|
||||
"""
|
||||
if self._use_managed_server():
|
||||
# Phase 2: ManagedServer with parser -- exact tokens + logprobs
|
||||
# Load the tool call parser from registry based on config
|
||||
from environments.tool_call_parsers import get_parser
|
||||
try:
|
||||
tc_parser = get_parser(self.config.tool_call_parser)
|
||||
@@ -478,7 +825,13 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
tokenizer=self.tokenizer,
|
||||
tool_call_parser=tc_parser,
|
||||
) as managed:
|
||||
_max_ctx = self.config.max_token_length if (self.config.max_token_length and self.config.max_token_length > 0) else None
|
||||
# Calculate max prompt tokens
|
||||
# Context budget = max_token_length (prompt can be as long as generation budget)
|
||||
# This ensures prompt + generation stays under typical model context limits
|
||||
# E.g., max_token_length=16384 → 16384 prompt + 16384 gen = 32K < 40960 model limit
|
||||
_max_ctx = None
|
||||
if self.config.max_token_length and self.config.max_token_length > 0:
|
||||
_max_ctx = self.config.max_token_length
|
||||
agent = HermesAgentLoop(
|
||||
server=managed,
|
||||
tool_schemas=tools,
|
||||
@@ -487,17 +840,18 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
tool_handler=tool_handler,
|
||||
max_context_tokens=_max_ctx,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
return await agent.run(messages)
|
||||
except NotImplementedError:
|
||||
# DummyManagedServer not allowed -- fall back to Phase 1
|
||||
logger.warning(
|
||||
"ManagedServer not available (OpenAI server?). "
|
||||
"Falling back to direct server mode."
|
||||
)
|
||||
_max_ctx = self.config.max_token_length if (self.config.max_token_length and self.config.max_token_length > 0) else None
|
||||
_max_ctx = None
|
||||
if self.config.max_token_length and self.config.max_token_length > 0:
|
||||
_max_ctx = self.config.max_token_length
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
@@ -506,13 +860,14 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
tool_handler=tool_handler,
|
||||
max_context_tokens=_max_ctx,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
return await agent.run(messages)
|
||||
else:
|
||||
# Phase 1: OpenAI server -- native tool_calls, placeholder tokens
|
||||
_max_ctx = self.config.max_token_length if (self.config.max_token_length and self.config.max_token_length > 0) else None
|
||||
_max_ctx = None
|
||||
if self.config.max_token_length and self.config.max_token_length > 0:
|
||||
_max_ctx = self.config.max_token_length
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
@@ -521,34 +876,22 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
tool_handler=tool_handler,
|
||||
max_context_tokens=_max_ctx,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
return await agent.run(messages)
|
||||
|
||||
# Skip reward computation if the agent loop produced no meaningful work
|
||||
# (e.g., API call failed on turn 1). No point spinning up a Modal sandbox
|
||||
# just to verify files that were never created.
|
||||
only_system_and_user = all(
|
||||
msg.get("role") in ("system", "user") for msg in result.messages
|
||||
)
|
||||
if result.turns_used == 0 or only_system_and_user:
|
||||
logger.warning(
|
||||
"Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
|
||||
result.turns_used, len(result.messages),
|
||||
)
|
||||
reward = 0.0
|
||||
else:
|
||||
# Compute reward using ToolContext (gives verifier full tool access)
|
||||
ctx = ToolContext(task_id)
|
||||
try:
|
||||
reward = await self.compute_reward(item, result, ctx)
|
||||
except Exception as e:
|
||||
logger.error("compute_reward failed: %s", e)
|
||||
reward = 0.0
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
def _build_scored_item(
|
||||
self,
|
||||
item: Item,
|
||||
result: AgentResult,
|
||||
reward: float,
|
||||
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
|
||||
"""
|
||||
Build a ScoredDataItem from an AgentResult and reward.
|
||||
|
||||
Shared between local and sandbox paths.
|
||||
"""
|
||||
# Track tool errors for wandb logging
|
||||
if result.tool_errors:
|
||||
for err in result.tool_errors:
|
||||
@@ -561,28 +904,19 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
})
|
||||
|
||||
# Build ScoredDataItem from ManagedServer state
|
||||
# Phase 2: real tokens/masks/logprobs from SequenceNodes
|
||||
# Phase 1: placeholder tokens (still need a valid ScoredDataItem for the pipeline)
|
||||
nodes = (result.managed_state or {}).get("nodes", [])
|
||||
|
||||
if nodes:
|
||||
# Phase 2 (or DummyManagedServer): use actual node data
|
||||
node = nodes[-1] # Final sequence node = full trajectory
|
||||
node = nodes[-1]
|
||||
scored_item: Dict[str, Any] = {
|
||||
"tokens": node.tokens,
|
||||
"masks": node.masked_tokens,
|
||||
"scores": reward,
|
||||
}
|
||||
|
||||
# Include logprobs if available (Phase 2)
|
||||
if hasattr(node, "logprobs") and node.logprobs:
|
||||
scored_item["advantages"] = None # Computed by trainer
|
||||
scored_item["advantages"] = None
|
||||
scored_item["ref_logprobs"] = None
|
||||
else:
|
||||
# Phase 1 with no managed state: create placeholder tokens
|
||||
# so the data pipeline doesn't break. These are NOT suitable
|
||||
# for training but allow process mode (SFT data gen) to work.
|
||||
# Tokenize the full conversation to get approximate tokens.
|
||||
full_text = "\n".join(
|
||||
msg.get("content", "") for msg in result.messages if msg.get("content")
|
||||
)
|
||||
@@ -593,13 +927,11 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
|
||||
scored_item = {
|
||||
"tokens": tokens,
|
||||
"masks": [-100] + tokens[1:], # Mask first token as prompt
|
||||
"masks": [-100] + tokens[1:],
|
||||
"scores": reward,
|
||||
}
|
||||
|
||||
# Always include messages for wandb rollout display and data logging
|
||||
scored_item["messages"] = result.messages
|
||||
|
||||
return scored_item, []
|
||||
|
||||
# =========================================================================
|
||||
|
||||
@@ -36,7 +36,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path for imports
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
@@ -171,6 +171,126 @@ def _patch_swerex_modal():
|
||||
logger.debug("Patched SwerexModalEnvironment for async-safe operation")
|
||||
|
||||
|
||||
def _patch_vllm_server_for_sglang():
|
||||
"""
|
||||
(Mainly for Runpod serverless compat)
|
||||
|
||||
Monkey patch VLLMServer._tokens_and_logprobs_completion_wrapper to handle
|
||||
SGLang's /generate response format.
|
||||
|
||||
VLLMServer expects:
|
||||
Request: {"prompt": {"prompt_token_ids": [...]}, "logprobs": 0}
|
||||
Response: {"logprobs": [[{token_id: logprob}]], "finish_reasons": [...]}
|
||||
|
||||
SGLang returns:
|
||||
Request: {"input_ids": [...], "sampling_params": {...}, "return_logprob": true}
|
||||
Response: {"text": "...", "meta_info": {"output_token_logprobs": [[logprob, token_id, text], ...]}}
|
||||
|
||||
This patch makes VLLMServer work with SGLang endpoints (e.g., RunPod SGLang workers).
|
||||
"""
|
||||
try:
|
||||
import aiohttp
|
||||
from atroposlib.envs.server_handling.vllm_server import VLLMServer
|
||||
except ImportError:
|
||||
logger.debug("atroposlib VLLMServer not available, skipping SGLang patch")
|
||||
return
|
||||
|
||||
# Save the original method
|
||||
_original_wrapper = VLLMServer._tokens_and_logprobs_completion_wrapper
|
||||
|
||||
async def _sglang_compatible_wrapper(self, **kwargs):
|
||||
"""
|
||||
Patched wrapper that tries the original VLLMServer format first,
|
||||
then falls back to SGLang format if that fails.
|
||||
"""
|
||||
assert kwargs.get("model") is not None, "Model is required!"
|
||||
assert kwargs.get("prompt") is not None or kwargs.get("input_ids") is not None, "Prompt or input_ids required!"
|
||||
|
||||
# Get prompt tokens
|
||||
if "input_ids" in kwargs:
|
||||
prompt_tokens = kwargs.pop("input_ids")
|
||||
kwargs.pop("prompt", None)
|
||||
else:
|
||||
prompt_tokens = self.tokenizer.encode(kwargs.pop("prompt"))
|
||||
|
||||
# Check for double BOS
|
||||
if (len(prompt_tokens) >= 2
|
||||
and prompt_tokens[0] == self.tokenizer.bos_token_id == prompt_tokens[1]):
|
||||
prompt_tokens = prompt_tokens[1:]
|
||||
|
||||
# Normalize kwargs
|
||||
max_tokens = kwargs.pop("max_new_tokens", kwargs.pop("max_completion_tokens", kwargs.pop("max_tokens", 2048)))
|
||||
n = kwargs.pop("n", 1)
|
||||
temperature = kwargs.pop("temperature", 1.0)
|
||||
kwargs.pop("model", None)
|
||||
|
||||
# Build SGLang-compatible request
|
||||
request_data = {
|
||||
"input_ids": prompt_tokens,
|
||||
"sampling_params": {
|
||||
"max_new_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
"n": n,
|
||||
},
|
||||
"return_logprob": True,
|
||||
"top_logprobs_num": 0,
|
||||
}
|
||||
|
||||
generate_url = f"{self.config.base_url.replace('/v1', '')}/generate"
|
||||
|
||||
headers = {}
|
||||
if self.config.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.config.api_key}"
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
generate_url,
|
||||
json=request_data,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=self.config.timeout),
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
raw_text = await response.text()
|
||||
|
||||
# RunPod wraps JSON responses in quotes — may need double-parse
|
||||
import json
|
||||
results = json.loads(raw_text)
|
||||
if isinstance(results, str):
|
||||
results = json.loads(results)
|
||||
|
||||
# Parse SGLang response format
|
||||
meta = results.get("meta_info", {})
|
||||
output_token_logprobs_raw = meta.get("output_token_logprobs", [])
|
||||
|
||||
# SGLang format: [[logprob, token_id, token_text], ...]
|
||||
output_tokens = []
|
||||
output_logprobs = []
|
||||
for entry in output_token_logprobs_raw:
|
||||
if isinstance(entry, (list, tuple)) and len(entry) >= 2:
|
||||
logprob, token_id = entry[0], entry[1]
|
||||
output_tokens.append(int(token_id))
|
||||
output_logprobs.append(float(logprob))
|
||||
|
||||
# Get finish reason
|
||||
finish_reason_raw = meta.get("finish_reason", "stop")
|
||||
if isinstance(finish_reason_raw, dict):
|
||||
finish_reason = finish_reason_raw.get("type", "stop")
|
||||
else:
|
||||
finish_reason = str(finish_reason_raw)
|
||||
|
||||
return (
|
||||
prompt_tokens,
|
||||
[output_tokens],
|
||||
[output_logprobs],
|
||||
[finish_reason],
|
||||
)
|
||||
|
||||
# Apply the patch
|
||||
VLLMServer._tokens_and_logprobs_completion_wrapper = _sglang_compatible_wrapper
|
||||
logger.info("Patched VLLMServer for SGLang /generate compatibility")
|
||||
|
||||
|
||||
def apply_patches():
|
||||
"""
|
||||
Apply all monkey patches needed for Atropos compatibility.
|
||||
@@ -184,5 +304,6 @@ def apply_patches():
|
||||
return
|
||||
|
||||
_patch_swerex_modal()
|
||||
# _patch_vllm_server_for_sglang()
|
||||
|
||||
_patches_applied = True
|
||||
|
||||
@@ -0,0 +1,620 @@
|
||||
"""
|
||||
SWE-smith-oracle environment (ported to HermesAgentBaseEnv).
|
||||
|
||||
Trains models to fix real GitHub repositories:
|
||||
- Clones a public GitHub repo at a specific commit
|
||||
- Runs an agent loop with terminal tool to apply a fix
|
||||
- Verifies by running pytest with nodeids from the dataset
|
||||
- Reward: 1.0 if all tests pass, 0.0 otherwise
|
||||
|
||||
Dataset: NousResearch/SWE-smith-oracle (train split; does NOT use SWE-bench eval set).
|
||||
|
||||
Usage:
|
||||
# Process mode (OpenAI server, no training):
|
||||
python environments/swe_smith_oracle_env.py process \\
|
||||
--env.data_path_to_save_groups data/swe_oracle_output.jsonl
|
||||
|
||||
# With Modal sandbox backend:
|
||||
python environments/swe_smith_oracle_env.py process \\
|
||||
--env.tool_pool_mode modal \\
|
||||
--env.modal_image python:3.11
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import ScoredDataGroup
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Config
|
||||
# =============================================================================
|
||||
|
||||
class SweSmithOracleEnvConfig(HermesAgentEnvConfig):
|
||||
"""Config for SWE-smith-oracle environment."""
|
||||
|
||||
dataset_name: str = Field(default="NousResearch/SWE-smith-oracle")
|
||||
dataset_split: str = Field(default="train")
|
||||
max_items: int = Field(default=0, description="0 = no limit")
|
||||
shuffle: bool = Field(default=True)
|
||||
seed: int = Field(default=0)
|
||||
|
||||
python_only: bool = Field(default=True, description="Filter to Python-evaluable rows")
|
||||
score_include_fail_to_pass: bool = Field(
|
||||
default=True,
|
||||
description="Score tests on PASS_TO_PASS ∪ FAIL_TO_PASS. "
|
||||
"Disable to only run PASS_TO_PASS (faster but weaker signal).",
|
||||
)
|
||||
|
||||
prompt_mode: str = Field(
|
||||
default="problem_statement",
|
||||
description="'problem_statement' (fast) or 'problem_statement+text' (includes dataset 'text').",
|
||||
)
|
||||
|
||||
repo_base_url: str = Field(default="https://github.com", description="Base URL for repo cloning")
|
||||
install_timeout_s: float = Field(default=600.0)
|
||||
test_timeout_s: float = Field(default=600.0)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Environment
|
||||
# =============================================================================
|
||||
|
||||
class SweSmithOracleEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
SWE-smith-oracle environment for training models to fix real GitHub repos.
|
||||
|
||||
Uses proper OpenAI-spec tool calling via HermesAgentBaseEnv.
|
||||
The model gets terminal access to inspect, edit, and test the repository.
|
||||
"""
|
||||
|
||||
name = "swe-smith-oracle"
|
||||
env_config_cls = SweSmithOracleEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SweSmithOracleEnvConfig,
|
||||
server_configs,
|
||||
slurm=False,
|
||||
testing=False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._dataset = None
|
||||
self._indices: List[int] = []
|
||||
self._cursor = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[SweSmithOracleEnvConfig, List[APIServerConfig]]:
|
||||
"""Default config — reads from ATROPOS_SERVER_* env vars."""
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
if not base_url.rstrip("/").endswith("/v1"):
|
||||
base_url = base_url.rstrip("/") + "/v1"
|
||||
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "Hermes-4.3-36B"
|
||||
api_key = (
|
||||
os.getenv("ATROPOS_SERVER_API_KEY")
|
||||
or os.getenv("NOUS_API_KEY")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or "local"
|
||||
)
|
||||
|
||||
env_config = SweSmithOracleEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1,
|
||||
batch_size=1,
|
||||
steps_per_eval=1,
|
||||
max_token_length=8192,
|
||||
wandb_name="swe_smith_oracle",
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
terminal_backend=os.getenv("TERMINAL_ENV", "local"),
|
||||
# Longer agent turns for SWE tasks
|
||||
max_agent_turns=50,
|
||||
agent_temperature=0.7,
|
||||
system_prompt=(
|
||||
"You are a senior software engineer. You have access to a terminal "
|
||||
"to inspect and fix repositories. Use non-interactive commands only. "
|
||||
"Each terminal command runs in a fresh shell."
|
||||
),
|
||||
tool_call_parser="hermes",
|
||||
# Sandbox settings (used when tool_pool_mode != "default")
|
||||
sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
|
||||
purge_job_on_start=True,
|
||||
purge_job_on_shutdown=True,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
server_type="vllm",
|
||||
health_check=False,
|
||||
timeout=int(os.getenv("ATROPOS_SERVER_TIMEOUT_S") or "300"),
|
||||
),
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
# =========================================================================
|
||||
# Dataset loading
|
||||
# =========================================================================
|
||||
|
||||
async def setup(self):
|
||||
"""Load SWE-smith-oracle dataset."""
|
||||
from datasets import load_dataset
|
||||
|
||||
t0 = time.perf_counter()
|
||||
print(
|
||||
f"[SweSmithOracleEnv] loading dataset {self.config.dataset_name}:{self.config.dataset_split} "
|
||||
f"(python_only={self.config.python_only}, max_items={self.config.max_items or 'all'})",
|
||||
flush=True,
|
||||
)
|
||||
ds = load_dataset(self.config.dataset_name, split=self.config.dataset_split)
|
||||
self._dataset = ds
|
||||
|
||||
indices: List[int] = []
|
||||
for idx in range(len(ds)):
|
||||
row = ds[idx]
|
||||
if self.config.python_only and not self._is_python_row(row):
|
||||
continue
|
||||
indices.append(idx)
|
||||
|
||||
if self.config.shuffle:
|
||||
rnd = random.Random(self.config.seed)
|
||||
rnd.shuffle(indices)
|
||||
|
||||
if self.config.max_items and self.config.max_items > 0:
|
||||
indices = indices[: self.config.max_items]
|
||||
|
||||
self._indices = indices
|
||||
self._cursor = 0
|
||||
|
||||
print(
|
||||
f"[SweSmithOracleEnv] loaded {len(self._indices)} items "
|
||||
f"in {time.perf_counter() - t0:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
def _is_python_row(self, row: Dict[str, Any]) -> bool:
|
||||
nodeids = row.get("PASS_TO_PASS")
|
||||
if not isinstance(nodeids, list) or not nodeids:
|
||||
return False
|
||||
return all(isinstance(nid, str) and ".py::" in nid for nid in nodeids)
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
if not self._dataset or not self._indices:
|
||||
raise RuntimeError("Dataset not initialized")
|
||||
if self._cursor >= len(self._indices):
|
||||
self._cursor = 0
|
||||
idx = self._indices[self._cursor]
|
||||
self._cursor += 1
|
||||
return dict(self._dataset[idx])
|
||||
|
||||
# =========================================================================
|
||||
# Prompt formatting
|
||||
# =========================================================================
|
||||
|
||||
def _repo_name(self, item: Item) -> str:
|
||||
repo = item.get("repo") or ""
|
||||
if isinstance(repo, str) and "/" in repo:
|
||||
return repo.split("/")[-1]
|
||||
return "repo"
|
||||
|
||||
def format_prompt(self, item: Item) -> str:
|
||||
"""Build the SWE task prompt."""
|
||||
repo = item.get("repo") or ""
|
||||
base_commit = item.get("base_commit") or ""
|
||||
problem = str(item.get("problem_statement") or "")
|
||||
context = str(item.get("text") or "")
|
||||
repo_dir = self._repo_name(item)
|
||||
|
||||
nodeids = self._tests_for_item(item)
|
||||
tests_list = "\n".join(f"- {t}" for t in nodeids)
|
||||
|
||||
context_block = ""
|
||||
prompt_mode = (self.config.prompt_mode or "problem_statement").strip().lower()
|
||||
if prompt_mode == "problem_statement+text" and context:
|
||||
context_block = f"\nAdditional context:\n{context}\n"
|
||||
|
||||
return (
|
||||
f"Fix the repository so the specified tests pass.\n\n"
|
||||
f"Repository: {repo} (checked out at base_commit={base_commit})\n"
|
||||
f"Workspace path: ./{repo_dir}\n\n"
|
||||
"Constraints:\n"
|
||||
"- Use the terminal tool to inspect, edit, and verify the repository.\n"
|
||||
f"- Start by inspecting the repo (e.g. `ls`, `cd ./{repo_dir}`, `git status`).\n"
|
||||
"- Use a workspace-local virtualenv (.venv) to avoid cross-run contamination.\n"
|
||||
"- Use non-interactive commands only.\n"
|
||||
"- Prefer `. .venv/bin/activate` or `.venv/bin/python ...` (POSIX compatible).\n\n"
|
||||
f"Problem statement:\n{problem}\n\n"
|
||||
f"{context_block}"
|
||||
f"Run these tests to verify:\n{tests_list}\n\n"
|
||||
"When done, briefly describe what you changed and confirm tests pass."
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Test helpers
|
||||
# =========================================================================
|
||||
|
||||
def _tests_for_item(self, item: Item) -> List[str]:
|
||||
tests: List[str] = []
|
||||
if self.config.score_include_fail_to_pass:
|
||||
for key in ("PASS_TO_PASS", "FAIL_TO_PASS"):
|
||||
nodeids = item.get(key)
|
||||
if isinstance(nodeids, list):
|
||||
tests.extend([n for n in nodeids if isinstance(n, str)])
|
||||
else:
|
||||
nodeids = item.get("PASS_TO_PASS")
|
||||
if isinstance(nodeids, list):
|
||||
tests.extend([n for n in nodeids if isinstance(n, str)])
|
||||
return sorted(dict.fromkeys(tests))
|
||||
|
||||
def _chunk_nodeids(self, nodeids: List[str], max_per_chunk: int = 50) -> List[List[str]]:
|
||||
return [nodeids[i : i + max_per_chunk] for i in range(0, len(nodeids), max_per_chunk)]
|
||||
|
||||
# =========================================================================
|
||||
# Sandbox hooks: setup_trajectory_workspace + verify_and_score_trajectory
|
||||
# =========================================================================
|
||||
|
||||
async def setup_trajectory_workspace(
|
||||
self, item: Item, *, trajectory_id: str, exec_tool
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Prepare a sandbox workspace: bare repo cache + git worktree.
|
||||
|
||||
Uses flock-serialized bare repo cache under /data/repo_cache so
|
||||
multiple trajectories sharing a sandbox don't clone the same repo
|
||||
in parallel. Each trajectory gets an isolated worktree at the
|
||||
specified base_commit.
|
||||
|
||||
Args:
|
||||
item: Dataset row with repo, base_commit, etc.
|
||||
trajectory_id: Unique trajectory ID
|
||||
exec_tool: async callable(tool_name, args, timeout) -> ExecutionResult
|
||||
|
||||
Returns:
|
||||
Dict with repo_dir, base_commit metadata
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
t0 = _time.perf_counter()
|
||||
repo = item.get("repo")
|
||||
base_commit = item.get("base_commit")
|
||||
instance_id = item.get("instance_id") or item.get("id") or item.get("problem_id")
|
||||
if not isinstance(repo, str) or not isinstance(base_commit, str):
|
||||
raise RuntimeError("Invalid dataset row: missing repo/base_commit")
|
||||
|
||||
repo_dir = self._repo_name(item)
|
||||
clone_url = f"{self.config.repo_base_url.rstrip('/')}/{repo}.git"
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): "
|
||||
f"repo={repo} base_commit={base_commit} instance_id={instance_id} dir=./{repo_dir}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Bare repo cache + worktree strategy (same as atropos/envs/swe_smith_oracle_env.py)
|
||||
repo_slug = repo.replace("/", "__")
|
||||
cache_root = "/data/repo_cache"
|
||||
bare_repo = f"{cache_root}/{repo_slug}.git"
|
||||
lock_file = f"{cache_root}/.locks/{repo_slug}.lock"
|
||||
|
||||
worktree_cmd = (
|
||||
"set -e; "
|
||||
f"rm -rf {repo_dir}; "
|
||||
f"mkdir -p {cache_root}/.locks; "
|
||||
f": > {lock_file}; "
|
||||
f"flock -x {lock_file} sh -lc '"
|
||||
f"set -e; "
|
||||
"export GIT_TERMINAL_PROMPT=0; "
|
||||
"export GIT_LFS_SKIP_SMUDGE=1; "
|
||||
f"if [ ! -d \"{bare_repo}\" ]; then "
|
||||
f" git init --bare \"{bare_repo}\"; "
|
||||
f" git -C \"{bare_repo}\" remote add origin \"{clone_url}\"; "
|
||||
"fi; "
|
||||
f"git -C \"{bare_repo}\" remote set-url origin \"{clone_url}\"; "
|
||||
f"git -C \"{bare_repo}\" worktree prune || true; "
|
||||
f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
|
||||
f" git -C \"{bare_repo}\" fetch --depth 1 origin \"{base_commit}\" || true; "
|
||||
"fi; "
|
||||
f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
|
||||
f" git -C \"{bare_repo}\" fetch --prune origin; "
|
||||
"fi; "
|
||||
f"git --git-dir=\"{bare_repo}\" worktree add --detach \"{repo_dir}\" \"{base_commit}\"; "
|
||||
"'"
|
||||
)
|
||||
|
||||
print(f"[SweSmithOracleEnv] tid={trajectory_id} preparing worktree from repo cache", flush=True)
|
||||
res = await exec_tool(
|
||||
"bash",
|
||||
{"command": worktree_cmd},
|
||||
timeout=self.config.install_timeout_s,
|
||||
)
|
||||
if not res.success:
|
||||
raise RuntimeError(
|
||||
f"git worktree setup failed "
|
||||
f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): "
|
||||
f"{res.error}\n{res.output}"
|
||||
)
|
||||
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): "
|
||||
f"worktree ready in {_time.perf_counter() - t0:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
return {"repo_dir": repo_dir, "base_commit": base_commit}
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
result: AgentResult,
|
||||
*,
|
||||
trajectory_id: str,
|
||||
exec_tool,
|
||||
workspace_meta: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[float, Dict[str, Any]]:
|
||||
"""
|
||||
In-sandbox verification: install deps + run pytest with dataset nodeids.
|
||||
|
||||
Args:
|
||||
item: Dataset row
|
||||
result: Agent's rollout result
|
||||
trajectory_id: Unique trajectory ID
|
||||
exec_tool: async callable(tool_name, args, timeout) -> ExecutionResult
|
||||
workspace_meta: From setup_trajectory_workspace (has repo_dir)
|
||||
|
||||
Returns:
|
||||
(reward, metadata) tuple
|
||||
"""
|
||||
repo_dir = (workspace_meta or {}).get("repo_dir") or self._repo_name(item)
|
||||
|
||||
# Don't reward trajectories that never used tools
|
||||
tool_call_count = sum(
|
||||
len(msg.get("tool_calls", []))
|
||||
for msg in result.messages
|
||||
if msg.get("role") == "assistant"
|
||||
)
|
||||
if tool_call_count == 0:
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} verify: no tool calls; score=0.0",
|
||||
flush=True,
|
||||
)
|
||||
return 0.0, {"error": "No tool calls were made by the agent"}
|
||||
|
||||
nodeids = self._tests_for_item(item)
|
||||
if not nodeids:
|
||||
return 0.0, {"error": "No tests provided"}
|
||||
|
||||
# Install dependencies
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} verify: installing deps + running tests",
|
||||
flush=True,
|
||||
)
|
||||
setup_cmd = (
|
||||
f"cd {repo_dir} && "
|
||||
"python -m venv .venv && "
|
||||
". .venv/bin/activate && "
|
||||
"python -m pip install -U pip setuptools wheel && "
|
||||
"python -m pip install -e . && "
|
||||
"python -m pip install pytest"
|
||||
)
|
||||
setup_res = await exec_tool(
|
||||
"bash", {"command": setup_cmd}, timeout=self.config.install_timeout_s
|
||||
)
|
||||
if not setup_res.success:
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} install failed; score=0.0",
|
||||
flush=True,
|
||||
)
|
||||
return 0.0, {
|
||||
"phase": "install",
|
||||
"error": setup_res.error,
|
||||
"output": setup_res.output,
|
||||
}
|
||||
|
||||
# Run test chunks
|
||||
chunks = self._chunk_nodeids(nodeids, max_per_chunk=50)
|
||||
for chunk_idx, chunk in enumerate(chunks):
|
||||
joined = " ".join(chunk)
|
||||
cmd = f"cd {repo_dir} && . .venv/bin/activate && python -m pytest -q {joined}"
|
||||
res = await exec_tool(
|
||||
"bash", {"command": cmd}, timeout=self.config.test_timeout_s
|
||||
)
|
||||
if not res.success:
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} tests failed (chunk {chunk_idx}); score=0.0",
|
||||
flush=True,
|
||||
)
|
||||
return 0.0, {
|
||||
"phase": "pytest",
|
||||
"failed_chunk": chunk_idx,
|
||||
"error": res.error,
|
||||
"output": res.output,
|
||||
}
|
||||
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} all tests passed; score=1.0",
|
||||
flush=True,
|
||||
)
|
||||
return 1.0, {"passed": True}
|
||||
|
||||
# =========================================================================
|
||||
# Reward: run pytest in the terminal (local / non-sandbox path)
|
||||
# =========================================================================
|
||||
|
||||
async def compute_reward(
|
||||
self, item: Item, result: AgentResult, ctx: ToolContext
|
||||
) -> float:
|
||||
"""
|
||||
Verify by running pytest with the dataset's nodeids.
|
||||
|
||||
Reward structure (shaped to give training signal even when model can't solve tasks):
|
||||
- 0.0: No tool calls at all
|
||||
- 0.05: Per valid tool call (up to 0.3 max for tool-call shaping)
|
||||
- 0.4: Successfully installed deps
|
||||
- 1.0: All tests pass
|
||||
|
||||
The partial rewards for tool calls help the model learn to USE tools
|
||||
before it can learn to use them CORRECTLY. This is critical for cold-start
|
||||
training where the base model barely makes any tool calls.
|
||||
"""
|
||||
repo_dir = self._repo_name(item)
|
||||
|
||||
# Count tool calls (assistant messages that have tool_calls).
|
||||
# NOTE: we keep scoring policy here intentionally simple and env-specific.
|
||||
# The agent loop exposes additional tool-call metrics (attempted/schema_valid/
|
||||
# executed_ok/exec_error) that other environments may choose to use for
|
||||
# reward shaping, but we don't hard-require any particular calling format here.
|
||||
tool_call_count = sum(
|
||||
len(msg.get("tool_calls", []))
|
||||
for msg in result.messages
|
||||
if msg.get("role") == "assistant"
|
||||
)
|
||||
|
||||
if tool_call_count == 0:
|
||||
print(f"[SweSmithOracleEnv] No tool calls made; score=0.0", flush=True)
|
||||
return 0.0
|
||||
|
||||
# Partial reward: 0.05 per tool call, capped at 0.3
|
||||
tool_call_reward = min(tool_call_count * 0.05, 0.3)
|
||||
|
||||
# Debug: log tool-call quality metrics if present
|
||||
attempted = getattr(result, "tool_calls_attempted", None)
|
||||
schema_valid = getattr(result, "tool_calls_schema_valid", None)
|
||||
executed_ok = getattr(result, "tool_calls_executed_ok", None)
|
||||
exec_error = getattr(result, "tool_calls_exec_error", None)
|
||||
if attempted is not None:
|
||||
print(
|
||||
f"[SweSmithOracleEnv] Tool calls: total={tool_call_count}, attempted={attempted}, schema_valid={schema_valid}, ok={executed_ok}, err={exec_error}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
nodeids = self._tests_for_item(item)
|
||||
if not nodeids:
|
||||
# No tests defined — just reward tool usage
|
||||
print(f"[SweSmithOracleEnv] No tests defined; score={tool_call_reward:.2f} (tool calls)", flush=True)
|
||||
return tool_call_reward
|
||||
|
||||
# Install deps + run tests
|
||||
print(f"[SweSmithOracleEnv] Verifying: installing deps + running tests", flush=True)
|
||||
setup_result = ctx.terminal(
|
||||
f"cd {repo_dir} && "
|
||||
"python -m venv .venv && "
|
||||
". .venv/bin/activate && "
|
||||
"python -m pip install -U pip setuptools wheel && "
|
||||
"python -m pip install -e . && "
|
||||
"python -m pip install pytest",
|
||||
timeout=int(self.config.install_timeout_s),
|
||||
)
|
||||
if setup_result.get("exit_code", 1) != 0:
|
||||
print(f"[SweSmithOracleEnv] Install failed; score={tool_call_reward:.2f} (tool calls only)", flush=True)
|
||||
return tool_call_reward
|
||||
|
||||
# Partial reward for successful install
|
||||
install_reward = 0.4
|
||||
|
||||
# Run test chunks
|
||||
chunks = self._chunk_nodeids(nodeids, max_per_chunk=50)
|
||||
for chunk_idx, chunk in enumerate(chunks):
|
||||
joined = " ".join(chunk)
|
||||
test_result = ctx.terminal(
|
||||
f"cd {repo_dir} && . .venv/bin/activate && python -m pytest -q {joined}",
|
||||
timeout=int(self.config.test_timeout_s),
|
||||
)
|
||||
if test_result.get("exit_code", 1) != 0:
|
||||
print(f"[SweSmithOracleEnv] Tests failed (chunk {chunk_idx}); score={install_reward:.2f} (install ok)", flush=True)
|
||||
return install_reward
|
||||
|
||||
print(f"[SweSmithOracleEnv] All tests passed; score=1.0", flush=True)
|
||||
return 1.0
|
||||
|
||||
# =========================================================================
|
||||
# Token truncation — keep start of trajectory, truncate from end
|
||||
# =========================================================================
|
||||
|
||||
def _build_scored_item(self, item, result, reward):
|
||||
"""
|
||||
Override to truncate tokens/masks from the END to fit within max_token_len.
|
||||
|
||||
Intuition (from NeurIPS finding): the start of the trajectory is most important
|
||||
for shifting the model distribution. Truncating from the end only costs ~2-3%
|
||||
vs handling the full sequence, but avoids the "Token length is too long" discard
|
||||
that throws away entire groups including valid training signal.
|
||||
"""
|
||||
scored_item, remaining = super()._build_scored_item(item, result, reward)
|
||||
if scored_item is None:
|
||||
return scored_item, remaining
|
||||
|
||||
# Use config.max_token_length as the truncation limit.
|
||||
# self.max_token_len comes from the trainer via /info, but may be -1
|
||||
# if the trainer hasn't registered yet (race condition).
|
||||
max_len = self.max_token_len
|
||||
if max_len <= 0:
|
||||
# Fallback to config value
|
||||
max_len = getattr(self.config, 'max_token_length', 0)
|
||||
if max_len <= 0:
|
||||
return scored_item, remaining
|
||||
|
||||
# Leave some margin (64 tokens) to avoid edge cases with padding alignment
|
||||
truncate_to = max_len - 64
|
||||
|
||||
tokens = scored_item.get("tokens")
|
||||
masks = scored_item.get("masks")
|
||||
|
||||
if tokens is not None and len(tokens) >= max_len:
|
||||
orig_len = len(tokens)
|
||||
scored_item["tokens"] = tokens[:truncate_to]
|
||||
if masks is not None and len(masks) >= max_len:
|
||||
scored_item["masks"] = masks[:truncate_to]
|
||||
logger.info(
|
||||
"Truncated trajectory from %d to %d tokens (max_token_len=%d)",
|
||||
orig_len, truncate_to, max_len,
|
||||
)
|
||||
|
||||
return scored_item, remaining
|
||||
|
||||
# =========================================================================
|
||||
# Evaluation (minimal for now)
|
||||
# =========================================================================
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""Placeholder evaluation — SWE tasks are too expensive for frequent eval."""
|
||||
start_time = time.time()
|
||||
await self.evaluate_log(
|
||||
metrics={"eval/placeholder": 0.0},
|
||||
samples=[],
|
||||
start_time=start_time,
|
||||
end_time=time.time(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
SweSmithOracleEnv.cli()
|
||||
+1
-1
@@ -36,7 +36,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path for imports
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
@@ -129,14 +129,11 @@ class ToolContext:
|
||||
|
||||
def write_file(self, path: str, content: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Write a TEXT file in the rollout's filesystem.
|
||||
|
||||
Uses a shell heredoc under the hood, so this is only safe for text content.
|
||||
For binary files (images, compiled artifacts, etc.), use upload_file() instead.
|
||||
Write a file in the rollout's filesystem.
|
||||
|
||||
Args:
|
||||
path: File path to write
|
||||
content: Text content to write
|
||||
content: Content to write
|
||||
|
||||
Returns:
|
||||
Dict with success status or error
|
||||
@@ -149,177 +146,6 @@ class ToolContext:
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
def upload_file(self, local_path: str, remote_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Upload a local file to the rollout's sandbox (binary-safe).
|
||||
|
||||
Unlike write_file() which passes content through a shell heredoc (text-only),
|
||||
this method base64-encodes the file and decodes it inside the sandbox.
|
||||
Safe for any file type: binaries, images, archives, etc.
|
||||
|
||||
For large files (>1MB), the content is split into chunks to avoid
|
||||
hitting shell command-length limits.
|
||||
|
||||
Args:
|
||||
local_path: Path to a local file on the host
|
||||
remote_path: Destination path inside the sandbox
|
||||
|
||||
Returns:
|
||||
Dict with 'exit_code' and 'output'
|
||||
"""
|
||||
import base64
|
||||
from pathlib import Path as _Path
|
||||
|
||||
local = _Path(local_path)
|
||||
if not local.exists():
|
||||
return {"exit_code": -1, "output": f"Local file not found: {local_path}"}
|
||||
|
||||
raw = local.read_bytes()
|
||||
b64 = base64.b64encode(raw).decode("ascii")
|
||||
|
||||
# Ensure parent directory exists in the sandbox
|
||||
parent = str(_Path(remote_path).parent)
|
||||
if parent not in (".", "/"):
|
||||
self.terminal(f"mkdir -p {parent}", timeout=10)
|
||||
|
||||
# For small files, single command is fine
|
||||
chunk_size = 60_000 # ~60KB per chunk (well within shell limits)
|
||||
if len(b64) <= chunk_size:
|
||||
result = self.terminal(
|
||||
f"printf '%s' '{b64}' | base64 -d > {remote_path}",
|
||||
timeout=30,
|
||||
)
|
||||
else:
|
||||
# For larger files, write base64 in chunks then decode
|
||||
tmp_b64 = "/tmp/_hermes_upload.b64"
|
||||
self.terminal(f": > {tmp_b64}", timeout=5) # truncate
|
||||
for i in range(0, len(b64), chunk_size):
|
||||
chunk = b64[i : i + chunk_size]
|
||||
self.terminal(f"printf '%s' '{chunk}' >> {tmp_b64}", timeout=15)
|
||||
result = self.terminal(
|
||||
f"base64 -d {tmp_b64} > {remote_path} && rm -f {tmp_b64}",
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def upload_dir(self, local_dir: str, remote_dir: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Upload an entire local directory to the rollout's sandbox (binary-safe).
|
||||
|
||||
Recursively uploads all files, preserving directory structure.
|
||||
|
||||
Args:
|
||||
local_dir: Path to a local directory on the host
|
||||
remote_dir: Destination directory inside the sandbox
|
||||
|
||||
Returns:
|
||||
List of results, one per file uploaded
|
||||
"""
|
||||
from pathlib import Path as _Path
|
||||
|
||||
local = _Path(local_dir)
|
||||
if not local.exists() or not local.is_dir():
|
||||
return [{"exit_code": -1, "output": f"Local directory not found: {local_dir}"}]
|
||||
|
||||
results = []
|
||||
for file_path in sorted(local.rglob("*")):
|
||||
if file_path.is_file():
|
||||
relative = file_path.relative_to(local)
|
||||
target = f"{remote_dir}/{relative}"
|
||||
results.append(self.upload_file(str(file_path), target))
|
||||
return results
|
||||
|
||||
def download_file(self, remote_path: str, local_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Download a file from the rollout's sandbox to the host (binary-safe).
|
||||
|
||||
The inverse of upload_file(). Base64-encodes the file inside the sandbox,
|
||||
reads the encoded data through the terminal, and decodes it locally.
|
||||
Safe for any file type.
|
||||
|
||||
Args:
|
||||
remote_path: Path to the file inside the sandbox
|
||||
local_path: Destination path on the host
|
||||
|
||||
Returns:
|
||||
Dict with 'success' (bool) and 'bytes' (int) or 'error' (str)
|
||||
"""
|
||||
import base64
|
||||
from pathlib import Path as _Path
|
||||
|
||||
# Base64-encode the file inside the sandbox and capture output
|
||||
result = self.terminal(
|
||||
f"base64 {remote_path} 2>/dev/null",
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
if result.get("exit_code", -1) != 0:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Failed to read remote file: {result.get('output', '')}",
|
||||
}
|
||||
|
||||
b64_data = result.get("output", "").strip()
|
||||
if not b64_data:
|
||||
return {"success": False, "error": f"Remote file is empty or missing: {remote_path}"}
|
||||
|
||||
try:
|
||||
raw = base64.b64decode(b64_data)
|
||||
except Exception as e:
|
||||
return {"success": False, "error": f"Base64 decode failed: {e}"}
|
||||
|
||||
# Write to local host filesystem
|
||||
local = _Path(local_path)
|
||||
local.parent.mkdir(parents=True, exist_ok=True)
|
||||
local.write_bytes(raw)
|
||||
|
||||
return {"success": True, "bytes": len(raw)}
|
||||
|
||||
def download_dir(self, remote_dir: str, local_dir: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Download a directory from the rollout's sandbox to the host (binary-safe).
|
||||
|
||||
Lists all files in the remote directory, then downloads each one.
|
||||
Preserves directory structure.
|
||||
|
||||
Args:
|
||||
remote_dir: Path to the directory inside the sandbox
|
||||
local_dir: Destination directory on the host
|
||||
|
||||
Returns:
|
||||
List of results, one per file downloaded
|
||||
"""
|
||||
from pathlib import Path as _Path
|
||||
|
||||
# List files in the remote directory
|
||||
ls_result = self.terminal(
|
||||
f"find {remote_dir} -type f 2>/dev/null",
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
if ls_result.get("exit_code", -1) != 0:
|
||||
return [{"success": False, "error": f"Failed to list remote dir: {remote_dir}"}]
|
||||
|
||||
file_list = ls_result.get("output", "").strip()
|
||||
if not file_list:
|
||||
return [{"success": False, "error": f"Remote directory is empty or missing: {remote_dir}"}]
|
||||
|
||||
results = []
|
||||
for remote_file in file_list.splitlines():
|
||||
remote_file = remote_file.strip()
|
||||
if not remote_file:
|
||||
continue
|
||||
# Compute the relative path to preserve directory structure
|
||||
if remote_file.startswith(remote_dir):
|
||||
relative = remote_file[len(remote_dir):].lstrip("/")
|
||||
else:
|
||||
relative = _Path(remote_file).name
|
||||
local_file = str(_Path(local_dir) / relative)
|
||||
results.append(self.download_file(remote_file, local_file))
|
||||
|
||||
return results
|
||||
|
||||
def search(self, query: str, path: str = ".") -> Dict[str, Any]:
|
||||
"""
|
||||
Search for text in the rollout's filesystem.
|
||||
|
||||
+15
-166
@@ -6,11 +6,10 @@ and implement the required methods.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any, Callable, Awaitable, Tuple
|
||||
from typing import Dict, List, Optional, Any, Callable, Awaitable
|
||||
from enum import Enum
|
||||
|
||||
import sys
|
||||
@@ -178,123 +177,6 @@ class BasePlatformAdapter(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send an image natively via the platform API.
|
||||
|
||||
Override in subclasses to send images as proper attachments
|
||||
instead of plain-text URLs. Default falls back to sending the
|
||||
URL as a text message.
|
||||
"""
|
||||
# Fallback: send URL as text (subclasses override for native images)
|
||||
text = f"{caption}\n{image_url}" if caption else image_url
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
||||
@staticmethod
|
||||
def extract_images(content: str) -> Tuple[List[Tuple[str, str]], str]:
|
||||
"""
|
||||
Extract image URLs from markdown and HTML image tags in a response.
|
||||
|
||||
Finds patterns like:
|
||||
- 
|
||||
- <img src="https://example.com/image.png">
|
||||
- <img src="https://example.com/image.png"></img>
|
||||
|
||||
Args:
|
||||
content: The response text to scan.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of (url, alt_text) pairs, cleaned content with image tags removed).
|
||||
"""
|
||||
images = []
|
||||
cleaned = content
|
||||
|
||||
# Match markdown images: 
|
||||
md_pattern = r'!\[([^\]]*)\]\((https?://[^\s\)]+)\)'
|
||||
for match in re.finditer(md_pattern, content):
|
||||
alt_text = match.group(1)
|
||||
url = match.group(2)
|
||||
# Only extract URLs that look like actual images
|
||||
if any(url.lower().endswith(ext) or ext in url.lower() for ext in
|
||||
['.png', '.jpg', '.jpeg', '.gif', '.webp', 'fal.media', 'fal-cdn', 'replicate.delivery']):
|
||||
images.append((url, alt_text))
|
||||
|
||||
# Match HTML img tags: <img src="url"> or <img src="url"></img> or <img src="url"/>
|
||||
html_pattern = r'<img\s+src=["\']?(https?://[^\s"\'<>]+)["\']?\s*/?>\s*(?:</img>)?'
|
||||
for match in re.finditer(html_pattern, content):
|
||||
url = match.group(1)
|
||||
images.append((url, ""))
|
||||
|
||||
# Remove matched image tags from content if we found images
|
||||
if images:
|
||||
cleaned = re.sub(md_pattern, '', cleaned)
|
||||
cleaned = re.sub(html_pattern, '', cleaned)
|
||||
# Clean up leftover blank lines
|
||||
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip()
|
||||
|
||||
return images, cleaned
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send an audio file as a native voice message via the platform API.
|
||||
|
||||
Override in subclasses to send audio as voice bubbles (Telegram)
|
||||
or file attachments (Discord). Default falls back to sending the
|
||||
file path as text.
|
||||
"""
|
||||
text = f"🔊 Audio: {audio_path}"
|
||||
if caption:
|
||||
text = f"{caption}\n{text}"
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
||||
@staticmethod
|
||||
def extract_media(content: str) -> Tuple[List[Tuple[str, bool]], str]:
|
||||
"""
|
||||
Extract MEDIA:<path> tags and [[audio_as_voice]] directives from response text.
|
||||
|
||||
The TTS tool returns responses like:
|
||||
[[audio_as_voice]]
|
||||
MEDIA:/path/to/audio.ogg
|
||||
|
||||
Args:
|
||||
content: The response text to scan.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of (path, is_voice) pairs, cleaned content with tags removed).
|
||||
"""
|
||||
media = []
|
||||
cleaned = content
|
||||
|
||||
# Check for [[audio_as_voice]] directive
|
||||
has_voice_tag = "[[audio_as_voice]]" in content
|
||||
cleaned = cleaned.replace("[[audio_as_voice]]", "")
|
||||
|
||||
# Extract MEDIA:<path> tags (path may contain spaces)
|
||||
media_pattern = r'MEDIA:(\S+)'
|
||||
for match in re.finditer(media_pattern, content):
|
||||
path = match.group(1).strip()
|
||||
if path:
|
||||
media.append((path, has_voice_tag))
|
||||
|
||||
# Remove MEDIA tags from content
|
||||
if media:
|
||||
cleaned = re.sub(media_pattern, '', cleaned)
|
||||
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip()
|
||||
|
||||
return media, cleaned
|
||||
|
||||
async def _keep_typing(self, chat_id: str, interval: float = 2.0) -> None:
|
||||
"""
|
||||
Continuously send typing indicator until cancelled.
|
||||
@@ -349,56 +231,23 @@ class BasePlatformAdapter(ABC):
|
||||
|
||||
# Send response if any
|
||||
if response:
|
||||
# Extract MEDIA:<path> tags (from TTS tool) before other processing
|
||||
media_files, response = self.extract_media(response)
|
||||
result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=response,
|
||||
reply_to=event.message_id
|
||||
)
|
||||
|
||||
# Extract image URLs and send them as native platform attachments
|
||||
images, text_content = self.extract_images(response)
|
||||
|
||||
# Send the text portion first (if any remains after extractions)
|
||||
if text_content:
|
||||
result = await self.send(
|
||||
# Log send failures (don't raise - user already saw tool progress)
|
||||
if not result.success:
|
||||
print(f"[{self.name}] Failed to send response: {result.error}")
|
||||
# Try sending without markdown as fallback
|
||||
fallback_result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=text_content,
|
||||
content=f"(Response formatting failed, plain text:)\n\n{response[:3500]}",
|
||||
reply_to=event.message_id
|
||||
)
|
||||
|
||||
# Log send failures (don't raise - user already saw tool progress)
|
||||
if not result.success:
|
||||
print(f"[{self.name}] Failed to send response: {result.error}")
|
||||
# Try sending without markdown as fallback
|
||||
fallback_result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=f"(Response formatting failed, plain text:)\n\n{text_content[:3500]}",
|
||||
reply_to=event.message_id
|
||||
)
|
||||
if not fallback_result.success:
|
||||
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
||||
|
||||
# Send extracted images as native attachments
|
||||
for image_url, alt_text in images:
|
||||
try:
|
||||
img_result = await self.send_image(
|
||||
chat_id=event.source.chat_id,
|
||||
image_url=image_url,
|
||||
caption=alt_text if alt_text else None,
|
||||
)
|
||||
if not img_result.success:
|
||||
print(f"[{self.name}] Failed to send image: {img_result.error}")
|
||||
except Exception as img_err:
|
||||
print(f"[{self.name}] Error sending image: {img_err}")
|
||||
|
||||
# Send extracted audio/voice files as native attachments
|
||||
for audio_path, is_voice in media_files:
|
||||
try:
|
||||
voice_result = await self.send_voice(
|
||||
chat_id=event.source.chat_id,
|
||||
audio_path=audio_path,
|
||||
)
|
||||
if not voice_result.success:
|
||||
print(f"[{self.name}] Failed to send voice: {voice_result.error}")
|
||||
except Exception as voice_err:
|
||||
print(f"[{self.name}] Error sending voice: {voice_err}")
|
||||
if not fallback_result.success:
|
||||
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
||||
|
||||
# Check if there's a pending message that was queued during our processing
|
||||
if session_key in self._pending_messages:
|
||||
@@ -437,7 +286,7 @@ class BasePlatformAdapter(ABC):
|
||||
|
||||
def get_pending_message(self, session_key: str) -> Optional[MessageEvent]:
|
||||
"""Get and clear any pending message for a session."""
|
||||
return self._pending_messages.pop(session_key, None)
|
||||
return self._pending_messages.get(session_key)
|
||||
|
||||
def build_source(
|
||||
self,
|
||||
|
||||
@@ -8,7 +8,6 @@ Uses discord.py library for:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
try:
|
||||
@@ -174,99 +173,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send audio as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import io
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
if not os.path.exists(audio_path):
|
||||
return SendResult(success=False, error=f"Audio file not found: {audio_path}")
|
||||
|
||||
# Determine filename from path
|
||||
filename = os.path.basename(audio_path)
|
||||
|
||||
with open(audio_path, "rb") as f:
|
||||
file = discord.File(io.BytesIO(f.read()), filename=filename)
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send audio: {e}")
|
||||
return await super().send_voice(chat_id, audio_path, caption, reply_to)
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send an image natively as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
# Download the image and send as a Discord file attachment
|
||||
# (Discord renders attachments inline, unlike plain URLs)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"Failed to download image: HTTP {resp.status}")
|
||||
|
||||
image_data = await resp.read()
|
||||
|
||||
# Determine filename from URL or content type
|
||||
content_type = resp.headers.get("content-type", "image/png")
|
||||
ext = "png"
|
||||
if "jpeg" in content_type or "jpg" in content_type:
|
||||
ext = "jpg"
|
||||
elif "gif" in content_type:
|
||||
ext = "gif"
|
||||
elif "webp" in content_type:
|
||||
ext = "webp"
|
||||
|
||||
import io
|
||||
file = discord.File(io.BytesIO(image_data), filename=f"image.{ext}")
|
||||
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
except ImportError:
|
||||
print(f"[{self.name}] aiohttp not installed, falling back to URL. Run: pip install aiohttp")
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send image attachment, falling back to URL: {e}")
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Send typing indicator."""
|
||||
if self._client:
|
||||
@@ -326,36 +232,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
async def _handle_message(self, message: DiscordMessage) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
# In server channels (not DMs), require the bot to be @mentioned
|
||||
# UNLESS the channel is in the free-response list.
|
||||
#
|
||||
# Config:
|
||||
# DISCORD_FREE_RESPONSE_CHANNELS: Comma-separated channel IDs where the
|
||||
# bot responds to every message without needing a mention.
|
||||
# DISCORD_REQUIRE_MENTION: Set to "false" to disable mention requirement
|
||||
# globally (all channels become free-response). Default: "true".
|
||||
|
||||
if not isinstance(message.channel, discord.DMChannel):
|
||||
# Check if this channel is in the free-response list
|
||||
free_channels_raw = os.getenv("DISCORD_FREE_RESPONSE_CHANNELS", "")
|
||||
free_channels = {ch.strip() for ch in free_channels_raw.split(",") if ch.strip()}
|
||||
channel_id = str(message.channel.id)
|
||||
|
||||
# Global override: if DISCORD_REQUIRE_MENTION=false, all channels are free
|
||||
require_mention = os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
|
||||
|
||||
is_free_channel = channel_id in free_channels
|
||||
|
||||
if require_mention and not is_free_channel:
|
||||
# Must be @mentioned to respond
|
||||
if self._client.user not in message.mentions:
|
||||
return # Silently ignore messages that don't mention the bot
|
||||
|
||||
# Strip the bot mention from the message text so the agent sees clean input
|
||||
if self._client.user and self._client.user in message.mentions:
|
||||
message.content = message.content.replace(f"<@{self._client.user.id}>", "").strip()
|
||||
message.content = message.content.replace(f"<@!{self._client.user.id}>", "").strip()
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
if message.content.startswith("/"):
|
||||
|
||||
@@ -174,69 +174,6 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send audio as a native Telegram voice message or audio file."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import os
|
||||
if not os.path.exists(audio_path):
|
||||
return SendResult(success=False, error=f"Audio file not found: {audio_path}")
|
||||
|
||||
with open(audio_path, "rb") as audio_file:
|
||||
# .ogg files -> send as voice (round playable bubble)
|
||||
if audio_path.endswith(".ogg") or audio_path.endswith(".opus"):
|
||||
msg = await self._bot.send_voice(
|
||||
chat_id=int(chat_id),
|
||||
voice=audio_file,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
)
|
||||
else:
|
||||
# .mp3 and others -> send as audio file
|
||||
msg = await self._bot.send_audio(
|
||||
chat_id=int(chat_id),
|
||||
audio=audio_file,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send voice/audio: {e}")
|
||||
return await super().send_voice(chat_id, audio_path, caption, reply_to)
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send an image natively as a Telegram photo."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
# Telegram can send photos directly from URLs
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
photo=image_url,
|
||||
caption=caption[:1024] if caption else None, # Telegram caption limit
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send photo, falling back to URL: {e}")
|
||||
# Fallback: send as text link
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Send typing indicator."""
|
||||
if self._bot:
|
||||
|
||||
+18
-103
@@ -35,9 +35,6 @@ load_dotenv()
|
||||
# Gateway runs in quiet mode - suppress debug output and use cwd directly (no temp dirs)
|
||||
os.environ["HERMES_QUIET"] = "1"
|
||||
|
||||
# Enable interactive exec approval for dangerous commands on messaging platforms
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
|
||||
# Set terminal working directory for messaging platforms
|
||||
# Uses MESSAGING_CWD if set, otherwise defaults to home directory
|
||||
# This is separate from CLI which uses the directory where `hermes` is run
|
||||
@@ -80,10 +77,6 @@ class GatewayRunner:
|
||||
# Key: session_key, Value: AIAgent instance
|
||||
self._running_agents: Dict[str, Any] = {}
|
||||
self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt
|
||||
|
||||
# Track pending exec approvals per session
|
||||
# Key: session_key, Value: {"command": str, "pattern_key": str}
|
||||
self._pending_approvals: Dict[str, Dict[str, str]] = {}
|
||||
|
||||
async def start(self) -> bool:
|
||||
"""
|
||||
@@ -253,25 +246,6 @@ class GatewayRunner:
|
||||
if command == "stop":
|
||||
return await self._handle_stop_command(event)
|
||||
|
||||
# Check for pending exec approval responses
|
||||
session_key_preview = f"agent:main:{source.platform.value}:{source.chat_type}:{source.chat_id}" if source.chat_type != "dm" else f"agent:main:{source.platform.value}:dm"
|
||||
if session_key_preview in self._pending_approvals:
|
||||
user_text = event.text.strip().lower()
|
||||
if user_text in ("yes", "y", "approve", "ok", "go", "do it"):
|
||||
approval = self._pending_approvals.pop(session_key_preview)
|
||||
cmd = approval["command"]
|
||||
pattern_key = approval.get("pattern_key", "")
|
||||
print(f"[gateway] ✅ User approved dangerous command: {cmd[:60]}...")
|
||||
# Approve for session and re-run via terminal_tool with force=True
|
||||
from tools.terminal_tool import terminal_tool, _session_approved_patterns
|
||||
_session_approved_patterns.add(pattern_key)
|
||||
result = terminal_tool(command=cmd, force=True)
|
||||
return f"✅ Command approved and executed.\n\n```\n{result[:3500]}\n```"
|
||||
elif user_text in ("no", "n", "deny", "cancel", "nope"):
|
||||
self._pending_approvals.pop(session_key_preview)
|
||||
return "❌ Command denied."
|
||||
# If it's not clearly an approval/denial, fall through to normal processing
|
||||
|
||||
# Get or create session
|
||||
session_entry = self.session_store.get_or_create_session(source)
|
||||
session_key = session_entry.session_key
|
||||
@@ -308,17 +282,6 @@ class GatewayRunner:
|
||||
session_key=session_key
|
||||
)
|
||||
|
||||
# Check if the agent encountered a dangerous command needing approval
|
||||
# The terminal tool stores the last pending approval globally
|
||||
try:
|
||||
from tools.terminal_tool import _last_pending_approval
|
||||
if _last_pending_approval:
|
||||
self._pending_approvals[session_key] = _last_pending_approval.copy()
|
||||
# Clear the global so it doesn't leak to other sessions
|
||||
_last_pending_approval.clear()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Append to transcript
|
||||
self.session_store.append_to_transcript(
|
||||
session_entry.session_id,
|
||||
@@ -455,35 +418,23 @@ class GatewayRunner:
|
||||
return
|
||||
last_tool[0] = tool_name
|
||||
|
||||
# Build progress message with primary argument preview
|
||||
# Build progress message
|
||||
tool_emojis = {
|
||||
"terminal": "💻",
|
||||
"web_search": "🔍",
|
||||
"web_extract": "📄",
|
||||
"read_file": "📖",
|
||||
"write_file": "✍️",
|
||||
"patch": "🔧",
|
||||
"search": "🔎",
|
||||
"list_directory": "📂",
|
||||
"image_generate": "🎨",
|
||||
"text_to_speech": "🔊",
|
||||
"browser_navigate": "🌐",
|
||||
"browser_click": "👆",
|
||||
"browser_type": "⌨️",
|
||||
"browser_snapshot": "📸",
|
||||
"moa_query": "🧠",
|
||||
"mixture_of_agents": "🧠",
|
||||
"vision_analyze": "👁️",
|
||||
"skill_view": "📚",
|
||||
"skills_list": "📋",
|
||||
}
|
||||
emoji = tool_emojis.get(tool_name, "⚙️")
|
||||
|
||||
if preview:
|
||||
# Truncate preview to keep messages clean
|
||||
if len(preview) > 40:
|
||||
preview = preview[:37] + "..."
|
||||
msg = f"{emoji} {tool_name}... \"{preview}\""
|
||||
if tool_name == "terminal" and preview:
|
||||
msg = f"{emoji} `{preview}`..."
|
||||
else:
|
||||
msg = f"{emoji} {tool_name}..."
|
||||
|
||||
@@ -529,10 +480,6 @@ class GatewayRunner:
|
||||
# Read from env var or use default (same as CLI)
|
||||
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "60"))
|
||||
|
||||
# Map platform enum to the platform hint key the agent understands.
|
||||
# Platform.LOCAL ("local") maps to "cli"; others pass through as-is.
|
||||
platform_key = "cli" if source.platform == Platform.LOCAL else source.platform.value
|
||||
|
||||
agent = AIAgent(
|
||||
model=os.getenv("HERMES_MODEL", "anthropic/claude-opus-4.6"),
|
||||
max_iterations=max_iterations,
|
||||
@@ -541,42 +488,19 @@ class GatewayRunner:
|
||||
ephemeral_system_prompt=context_prompt,
|
||||
session_id=session_id,
|
||||
tool_progress_callback=progress_callback if tool_progress_enabled else None,
|
||||
platform=platform_key, # Tells the agent which interface to format for
|
||||
)
|
||||
|
||||
# Store agent reference for interrupt support
|
||||
agent_holder[0] = agent
|
||||
|
||||
# Convert history to agent format.
|
||||
# Two cases:
|
||||
# 1. Normal path (from transcript): simple {role, content, timestamp} dicts
|
||||
# - Strip timestamps, keep role+content
|
||||
# 2. Interrupt path (from agent result["messages"]): full agent messages
|
||||
# that may include tool_calls, tool_call_id, reasoning, etc.
|
||||
# - These must be passed through intact so the API sees valid
|
||||
# assistant→tool sequences (dropping tool_calls causes 500 errors)
|
||||
# Convert transcript history to agent format
|
||||
# Transcript has timestamps; agent expects {"role": ..., "content": ...}
|
||||
agent_history = []
|
||||
for msg in history:
|
||||
role = msg.get("role")
|
||||
if not role:
|
||||
continue
|
||||
|
||||
# Check if this is a rich agent message (has tool_calls or tool_call_id)
|
||||
# If so, pass it through with full structure intact
|
||||
has_tool_calls = "tool_calls" in msg
|
||||
has_tool_call_id = "tool_call_id" in msg
|
||||
is_tool_message = role == "tool"
|
||||
|
||||
if has_tool_calls or has_tool_call_id or is_tool_message:
|
||||
# Preserve full message structure (tool_calls, tool_call_id, etc.)
|
||||
# Only strip fields that are purely internal (e.g. timestamp)
|
||||
clean_msg = {k: v for k, v in msg.items() if k != "timestamp"}
|
||||
agent_history.append(clean_msg)
|
||||
else:
|
||||
# Simple text message - just need role and content
|
||||
content = msg.get("content")
|
||||
if content:
|
||||
agent_history.append({"role": role, "content": content})
|
||||
content = msg.get("content")
|
||||
if role and content:
|
||||
agent_history.append({"role": role, "content": content})
|
||||
|
||||
result = agent.run_conversation(message, conversation_history=agent_history)
|
||||
result_holder[0] = result
|
||||
@@ -648,16 +572,13 @@ class GatewayRunner:
|
||||
|
||||
if pending:
|
||||
print(f"[gateway] 📨 Processing interrupted message: '{pending[:40]}...'")
|
||||
# Add an indicator to the response
|
||||
if response:
|
||||
response = response + "\n\n---\n_[Interrupted - processing your new message]_"
|
||||
|
||||
# Clear the adapter's interrupt event so the next _run_agent call
|
||||
# doesn't immediately re-trigger the interrupt before the new agent
|
||||
# even makes its first API call (this was causing an infinite loop).
|
||||
if adapter and hasattr(adapter, '_active_sessions') and source.chat_id in adapter._active_sessions:
|
||||
adapter._active_sessions[source.chat_id].clear()
|
||||
|
||||
# Don't send the interrupted response to the user — it's just noise
|
||||
# like "Operation interrupted." They already know they sent a new
|
||||
# message, so go straight to processing it.
|
||||
# Send the interrupted response first
|
||||
if adapter and response:
|
||||
await adapter.send(chat_id=source.chat_id, content=response)
|
||||
|
||||
# Now process the pending message with updated history
|
||||
updated_history = result.get("messages", history)
|
||||
@@ -691,13 +612,11 @@ class GatewayRunner:
|
||||
return response
|
||||
|
||||
|
||||
async def start_gateway(config: Optional[GatewayConfig] = None) -> bool:
|
||||
async def start_gateway(config: Optional[GatewayConfig] = None) -> None:
|
||||
"""
|
||||
Start the gateway and run until interrupted.
|
||||
|
||||
This is the main entry point for running the gateway.
|
||||
Returns True if the gateway ran successfully, False if it failed to start.
|
||||
A False return causes a non-zero exit code so systemd can auto-restart.
|
||||
"""
|
||||
runner = GatewayRunner(config)
|
||||
|
||||
@@ -716,11 +635,10 @@ async def start_gateway(config: Optional[GatewayConfig] = None) -> bool:
|
||||
# Start the gateway
|
||||
success = await runner.start()
|
||||
if not success:
|
||||
return False
|
||||
return
|
||||
|
||||
# Wait for shutdown
|
||||
await runner.wait_for_shutdown()
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
@@ -740,11 +658,8 @@ def main():
|
||||
data = json.load(f)
|
||||
config = GatewayConfig.from_dict(data)
|
||||
|
||||
# Run the gateway - exit with code 1 if no platforms connected,
|
||||
# so systemd Restart=on-failure will retry on transient errors (e.g. DNS)
|
||||
success = asyncio.run(start_gateway(config))
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
# Run the gateway
|
||||
asyncio.run(start_gateway(config))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -7,6 +7,40 @@ Usage: ./hermes [options]
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
Fire (google/python-fire) does not support POSIX-style short flags like `-p`.
|
||||
We translate the most common shorthands to their long equivalents so wrapper
|
||||
scripts can reliably use:
|
||||
- `-p "..."` -> `--prompt "..."` (no TUI/banner; print result and exit)
|
||||
- `-q "..."` -> `--query "..."` (single-shot with banner UX)
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
def _rewrite_short_flags(argv: list[str]) -> list[str]:
|
||||
rewritten: list[str] = []
|
||||
i = 0
|
||||
while i < len(argv):
|
||||
arg = argv[i]
|
||||
if arg == "-p":
|
||||
rewritten.append("--prompt")
|
||||
if i + 1 < len(argv):
|
||||
rewritten.append(argv[i + 1])
|
||||
i += 2
|
||||
continue
|
||||
if arg == "-q":
|
||||
rewritten.append("--query")
|
||||
if i + 1 < len(argv):
|
||||
rewritten.append(argv[i + 1])
|
||||
i += 2
|
||||
continue
|
||||
rewritten.append(arg)
|
||||
i += 1
|
||||
return rewritten
|
||||
|
||||
sys.argv = [sys.argv[0]] + _rewrite_short_flags(sys.argv[1:])
|
||||
|
||||
from cli import main
|
||||
import fire
|
||||
|
||||
fire.Fire(main)
|
||||
|
||||
@@ -0,0 +1,659 @@
|
||||
Metadata-Version: 2.4
|
||||
Name: hermes-agent
|
||||
Version: 0.1.0
|
||||
Summary: AI agent with advanced tool-calling and toolsets
|
||||
Author: Nous Research
|
||||
License: MIT
|
||||
Requires-Python: >=3.10
|
||||
Description-Content-Type: text/markdown
|
||||
Requires-Dist: openai
|
||||
Requires-Dist: python-dotenv
|
||||
Requires-Dist: fire
|
||||
Requires-Dist: httpx
|
||||
Requires-Dist: rich
|
||||
Requires-Dist: tenacity
|
||||
Requires-Dist: pyyaml
|
||||
Requires-Dist: prompt_toolkit
|
||||
Requires-Dist: requests
|
||||
Requires-Dist: jinja2
|
||||
Requires-Dist: pydantic>=2.0
|
||||
Requires-Dist: firecrawl-py
|
||||
Requires-Dist: fal-client
|
||||
Requires-Dist: litellm>=1.75.5
|
||||
Requires-Dist: typer
|
||||
Requires-Dist: platformdirs
|
||||
Provides-Extra: modal
|
||||
Requires-Dist: modal; extra == "modal"
|
||||
Requires-Dist: boto3; extra == "modal"
|
||||
Provides-Extra: dev
|
||||
Requires-Dist: pytest; extra == "dev"
|
||||
Requires-Dist: pytest-asyncio; extra == "dev"
|
||||
Provides-Extra: atropos
|
||||
Requires-Dist: atroposlib @ git+https://github.com/NousResearch/atropos.git ; extra == "atropos"
|
||||
Requires-Dist: aiohttp; extra == "atropos"
|
||||
Requires-Dist: fastapi; extra == "atropos"
|
||||
Requires-Dist: uvicorn; extra == "atropos"
|
||||
Requires-Dist: pyte; extra == "atropos"
|
||||
|
||||
# Hermes Agent
|
||||
|
||||
An AI agent with advanced tool-calling capabilities, featuring a flexible toolsets system for organizing and managing tools.
|
||||
|
||||
## Features
|
||||
|
||||
- **Interactive CLI**: Beautiful terminal interface with animated feedback, personalities, and session management
|
||||
- **Web Tools**: Search, extract content, and crawl websites
|
||||
- **Terminal Tools**: Execute commands via local, Docker, Singularity, Modal, or SSH backends
|
||||
- **Browser Tools**: Automate web browsers to navigate, click, type, and extract content
|
||||
- **Vision Tools**: Analyze images from URLs
|
||||
- **Reasoning Tools**: Advanced multi-model reasoning (Mixture of Agents)
|
||||
- **Creative Tools**: Generate images from text prompts
|
||||
- **Skills Tools**: On-demand knowledge documents with progressive disclosure
|
||||
- **Toolsets System**: Organize tools into logical groups for different scenarios
|
||||
- **Batch Processing**: Process datasets in parallel with checkpointing and statistics tracking
|
||||
- **Ephemeral System Prompts**: Guide model behavior without polluting training datasets
|
||||
|
||||
## Quick Start (CLI)
|
||||
|
||||
```bash
|
||||
# After setup (see below), just run:
|
||||
./hermes
|
||||
|
||||
# Or with options:
|
||||
./hermes --model "anthropic/claude-sonnet-4" --toolsets "web,terminal"
|
||||
```
|
||||
|
||||
The CLI provides:
|
||||
- Animated spinners during thinking and tool execution
|
||||
- Kawaii-style feedback messages
|
||||
- `/commands` for configuration, history, and session management
|
||||
- Customizable personalities (`/personality kawaii`, `/personality pirate`, etc.)
|
||||
- Persistent configuration via `cli-config.yaml`
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. Clone the Repository
|
||||
```bash
|
||||
# Clone with submodules (recommended)
|
||||
git clone --recurse-submodules https://github.com/NousResearch/Hermes-Agent.git
|
||||
cd Hermes-Agent
|
||||
|
||||
# Or if already cloned without submodules:
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
### 2. Install Dependencies
|
||||
```bash
|
||||
# Create and activate virtual environment (recommended)
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate # On Windows: venv\Scripts\activate
|
||||
|
||||
# Install Python packages
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Install mini-swe-agent for terminal tools
|
||||
pip install -e ./mini-swe-agent
|
||||
|
||||
# Install Node.js dependencies for browser tools (requires Node.js)
|
||||
npm install
|
||||
```
|
||||
|
||||
### 3. Configure Environment Variables
|
||||
```bash
|
||||
# Copy the example environment file
|
||||
cp .env.example .env
|
||||
|
||||
# Edit .env and add your API keys
|
||||
nano .env # or use your preferred editor
|
||||
```
|
||||
|
||||
**Required API Keys:**
|
||||
- `OPENROUTER_API_KEY` - LLM access via OpenRouter (get at: https://openrouter.ai/keys)
|
||||
- `FIRECRAWL_API_KEY` - Web tools (get at: https://firecrawl.dev/)
|
||||
- `NOUS_API_KEY` - Vision & reasoning tools (get at: https://inference-api.nousresearch.com/)
|
||||
- `FAL_KEY` - Image generation (get at: https://fal.ai/)
|
||||
|
||||
**Optional API Keys (for specific features):**
|
||||
- `BROWSERBASE_API_KEY` - Browser automation (get at: https://browserbase.com/)
|
||||
- `BROWSERBASE_PROJECT_ID` - From Browserbase dashboard
|
||||
- `MORPH_API_KEY` - For legacy Hecate terminal backend (get at: https://morph.so/)
|
||||
|
||||
### 4. Configure Terminal Backend
|
||||
|
||||
The terminal tool uses **mini-swe-agent** environments. Configure in `.env` or `cli-config.yaml`:
|
||||
|
||||
```bash
|
||||
# Backend: "local", "docker", "singularity", "modal", or "ssh"
|
||||
TERMINAL_ENV=local # Default: runs on host machine (no isolation)
|
||||
TERMINAL_ENV=ssh # Remote execution via SSH (agent code stays local)
|
||||
TERMINAL_ENV=singularity # Recommended for HPC: Apptainer/Singularity containers
|
||||
TERMINAL_ENV=docker # Isolated Docker containers
|
||||
TERMINAL_ENV=modal # Cloud execution via Modal
|
||||
|
||||
# Container image (for docker/singularity/modal backends)
|
||||
TERMINAL_DOCKER_IMAGE=python:3.11-slim
|
||||
TERMINAL_SINGULARITY_IMAGE=docker://python:3.11-slim
|
||||
TERMINAL_TIMEOUT=60
|
||||
|
||||
# SSH backend (for ssh)
|
||||
TERMINAL_SSH_HOST=my-server.example.com
|
||||
TERMINAL_SSH_USER=myuser
|
||||
TERMINAL_SSH_KEY=~/.ssh/id_rsa # Optional, uses ssh-agent if not set
|
||||
```
|
||||
|
||||
**Backend Requirements:**
|
||||
- **local**: No extra setup (runs directly on your machine, no isolation)
|
||||
- **ssh**: SSH access to remote machine (great for sandboxing - agent can't touch its own code)
|
||||
- **singularity**: Requires Apptainer or Singularity installed (common on HPC clusters, no root needed)
|
||||
- **docker**: Requires Docker installed and user in `docker` group
|
||||
- **modal**: Requires Modal account (see setup below)
|
||||
|
||||
### Singularity/Apptainer Setup (Recommended for HPC)
|
||||
|
||||
Singularity/Apptainer provides rootless container execution, ideal for HPC clusters:
|
||||
|
||||
```bash
|
||||
# 1. Verify Apptainer is installed
|
||||
apptainer --version # or: singularity --version
|
||||
|
||||
# 2. Set up cache directories (important for parallel workers)
|
||||
# Use /scratch if available (HPC), otherwise /tmp
|
||||
export APPTAINER_CACHEDIR=/scratch/$USER/.apptainer
|
||||
export APPTAINER_TMPDIR=/scratch/$USER/.apptainer/tmp
|
||||
mkdir -p "$APPTAINER_CACHEDIR" "$APPTAINER_TMPDIR"
|
||||
|
||||
# 3. Pre-build SIF image (recommended for parallel batch processing)
|
||||
# This avoids race conditions when multiple workers start simultaneously
|
||||
apptainer build $APPTAINER_CACHEDIR/python-nodejs.sif docker://nikolaik/python-nodejs:python3.11-nodejs20
|
||||
|
||||
# 4. Configure .env to use the local SIF
|
||||
TERMINAL_ENV=singularity
|
||||
TERMINAL_SINGULARITY_IMAGE=/scratch/$USER/.apptainer/python-nodejs.sif
|
||||
```
|
||||
|
||||
**Tip:** The batch scripts in `configs/` automatically handle SIF pre-building if `/scratch` is available.
|
||||
|
||||
### Modal Cloud Backend Setup
|
||||
|
||||
[Modal](https://modal.com) provides serverless cloud compute for running sandboxed environments at scale.
|
||||
|
||||
```bash
|
||||
# 1. Install Modal and dependencies
|
||||
pip install modal boto3
|
||||
|
||||
# 2. Authenticate with Modal (opens browser)
|
||||
modal setup
|
||||
|
||||
# 3. Set terminal backend to modal in .env
|
||||
TERMINAL_ENV=modal
|
||||
```
|
||||
|
||||
Modal uses CLI-based authentication (stored in `~/.modal/`), so no API key is needed in `.env`. After running `modal setup`, commands will automatically execute in Modal's cloud sandboxes.
|
||||
|
||||
### Browser Tools Setup
|
||||
|
||||
Browser tools enable the agent to navigate websites, fill forms, click buttons, and extract content. They use [agent-browser](https://github.com/vercel-labs/agent-browser) CLI with [Browserbase](https://browserbase.com) cloud execution.
|
||||
|
||||
```bash
|
||||
# 1. Install Node.js (if not already installed)
|
||||
# Use nvm (recommended) or your package manager
|
||||
|
||||
# 2. Install agent-browser CLI (choose one option):
|
||||
npm install -g agent-browser # Option A: Global install (recommended)
|
||||
npm install # Option B: Local install (uses npx fallback)
|
||||
|
||||
# 3. Get Browserbase credentials
|
||||
# Sign up at https://browserbase.com/ and get your:
|
||||
# - API Key (from Settings → API Keys)
|
||||
# - Project ID (from your project dashboard)
|
||||
|
||||
# 4. Add to your .env file:
|
||||
BROWSERBASE_API_KEY=your_api_key_here
|
||||
BROWSERBASE_PROJECT_ID=your_project_id_here
|
||||
```
|
||||
|
||||
**Available Browser Tools:**
|
||||
|
||||
| Tool | Description |
|
||||
|------|-------------|
|
||||
| `browser_navigate` | Navigate to a URL |
|
||||
| `browser_snapshot` | Get text-based page snapshot with element refs |
|
||||
| `browser_click` | Click an element by ref (e.g., `@e5`) |
|
||||
| `browser_type` | Type text into an input field |
|
||||
| `browser_scroll` | Scroll up or down |
|
||||
| `browser_back` | Go back in browser history |
|
||||
| `browser_press` | Press a keyboard key (Enter, Tab, etc.) |
|
||||
| `browser_close` | Close the browser session |
|
||||
| `browser_get_images` | Get list of images on the page |
|
||||
|
||||
**Example Usage:**
|
||||
```bash
|
||||
# Use browser tools with web search and vision
|
||||
python run_agent.py \
|
||||
--query "Go to amazon.com and find the price of the latest Kindle" \
|
||||
--enabled_toolsets=browser,web,vision
|
||||
|
||||
# Use browser-focused distribution
|
||||
python batch_runner.py \
|
||||
--dataset_file=browser_tasks.jsonl \
|
||||
--distribution=browser_use \
|
||||
--run_name=browser_run
|
||||
```
|
||||
|
||||
See `.env.example` for all available configuration options including debug settings.
|
||||
|
||||
### Skills Tools
|
||||
|
||||
Skills are on-demand knowledge documents the agent can load when needed. They follow a **progressive disclosure** pattern to minimize token usage:
|
||||
|
||||
```
|
||||
skills/
|
||||
├── mlops/ # Category folder
|
||||
│ ├── axolotl/ # Skill folder
|
||||
│ │ ├── SKILL.md # Main instructions (required)
|
||||
│ │ ├── references/ # Additional docs, API specs
|
||||
│ │ └── templates/ # Output formats, configs
|
||||
│ └── vllm/
|
||||
│ └── SKILL.md
|
||||
```
|
||||
|
||||
**Available Skills Tools:**
|
||||
|
||||
| Tool | Description |
|
||||
|------|-------------|
|
||||
| `skills_categories` | List available skill categories (~50 tokens) |
|
||||
| `skills_list` | List skills with name + description (~3k tokens for 40 skills) |
|
||||
| `skill_view` | Load full skill content, tags, and linked files |
|
||||
|
||||
**Example Usage:**
|
||||
```bash
|
||||
# Use skills tools
|
||||
python run_agent.py \
|
||||
--query "What skills do you have for fine-tuning? Show me the axolotl skill." \
|
||||
--enabled_toolsets=skills
|
||||
```
|
||||
|
||||
**Creating Skills:**
|
||||
|
||||
Skills use YAML frontmatter for metadata:
|
||||
```yaml
|
||||
---
|
||||
name: my-skill
|
||||
description: Brief description shown in skills_list
|
||||
tags: [tag1, tag2]
|
||||
related_skills: [other-skill]
|
||||
version: 1.0.0
|
||||
---
|
||||
# Skill Content
|
||||
|
||||
Instructions, examples, and guidelines here...
|
||||
```
|
||||
|
||||
Skills can include:
|
||||
- `references/` - Additional documentation, API specs, examples
|
||||
- `templates/` - Output formats, config files, boilerplate code
|
||||
- `scripts/` - Executable helpers (Python, shell scripts)
|
||||
|
||||
## Session Logging
|
||||
|
||||
Every conversation is automatically logged to `logs/` for debugging and inspection:
|
||||
|
||||
```
|
||||
logs/
|
||||
├── session_20260201_143052_a1b2c3.json
|
||||
├── session_20260201_150217_d4e5f6.json
|
||||
└── ...
|
||||
```
|
||||
|
||||
**Log Format:**
|
||||
```json
|
||||
{
|
||||
"session_id": "20260201_143052_a1b2c3",
|
||||
"model": "anthropic/claude-sonnet-4",
|
||||
"session_start": "2026-02-01T14:30:52.123456",
|
||||
"last_updated": "2026-02-01T14:35:12.789012",
|
||||
"message_count": 8,
|
||||
"conversations": [
|
||||
{"from": "system", "value": "..."},
|
||||
{"from": "human", "value": "..."},
|
||||
{"from": "gpt", "value": "..."},
|
||||
{"from": "tool", "value": "..."}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
- **Automatic**: Logs are created and updated automatically after each conversation turn
|
||||
- **Session ID in Banner**: The CLI displays the session ID in the welcome banner
|
||||
- **Trajectory Format**: Uses the same format as batch processing for consistency
|
||||
- **Git Ignored**: `logs/` is in `.gitignore` so logs aren't committed
|
||||
|
||||
## Interactive CLI
|
||||
|
||||
The CLI provides a rich interactive experience for working with the agent.
|
||||
|
||||
### Running the CLI
|
||||
|
||||
```bash
|
||||
# Basic usage
|
||||
./hermes
|
||||
|
||||
# With specific model
|
||||
./hermes --model "anthropic/claude-sonnet-4"
|
||||
|
||||
# With specific toolsets
|
||||
./hermes --toolsets "web,terminal,skills"
|
||||
```
|
||||
|
||||
### CLI Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/help` | Show available commands |
|
||||
| `/tools` | List available tools by toolset |
|
||||
| `/toolsets` | List available toolsets |
|
||||
| `/model [name]` | Show or change the current model |
|
||||
| `/prompt [text]` | View/set custom system prompt |
|
||||
| `/personality [name]` | Set a predefined personality |
|
||||
| `/clear` | Clear screen and reset conversation |
|
||||
| `/reset` | Reset conversation only |
|
||||
| `/history` | Show conversation history |
|
||||
| `/save` | Save current conversation to file |
|
||||
| `/config` | Show current configuration |
|
||||
| `/quit` | Exit the CLI |
|
||||
|
||||
### Configuration
|
||||
|
||||
Copy `cli-config.yaml.example` to `cli-config.yaml` and customize:
|
||||
|
||||
```yaml
|
||||
# Model settings
|
||||
model:
|
||||
default: "anthropic/claude-sonnet-4"
|
||||
|
||||
# Terminal backend (local, docker, singularity, modal, or ssh)
|
||||
terminal:
|
||||
env_type: "local"
|
||||
cwd: "." # Use current directory
|
||||
|
||||
# Or use SSH for remote execution (keeps agent code isolated)
|
||||
# terminal:
|
||||
# env_type: "ssh"
|
||||
# ssh_host: "my-server.example.com"
|
||||
# ssh_user: "myuser"
|
||||
# ssh_key: "~/.ssh/id_rsa"
|
||||
# cwd: "/home/myuser/project"
|
||||
|
||||
# Enable specific toolsets
|
||||
toolsets:
|
||||
- all # or: web, terminal, browser, vision, etc.
|
||||
|
||||
# Custom personalities (use with /personality command)
|
||||
agent:
|
||||
personalities:
|
||||
helpful: "You are a helpful assistant."
|
||||
kawaii: "You are a kawaii assistant! Use cute expressions..."
|
||||
```
|
||||
|
||||
### Personalities
|
||||
|
||||
Built-in personalities available via `/personality`:
|
||||
- `helpful`, `concise`, `technical`, `creative`, `teacher`
|
||||
- `kawaii`, `catgirl`, `pirate`, `shakespeare`, `surfer`
|
||||
- `noir`, `uwu`, `philosopher`, `hype`
|
||||
|
||||
## Toolsets System
|
||||
|
||||
The agent uses a toolsets system for organizing and managing tools. All tools must be part of a toolset to be accessible - individual tool selection is not supported. This ensures consistent and logical grouping of capabilities.
|
||||
|
||||
### Key Concepts
|
||||
|
||||
- **Toolsets**: Logical groups of tools for specific use cases (e.g., "research", "development", "debugging")
|
||||
- **Composition**: Toolsets can include other toolsets for powerful combinations
|
||||
- **Custom Toolsets**: Create your own toolsets at runtime or by editing `toolsets.py`
|
||||
- **Toolset-Only Access**: Tools are only accessible through toolsets, not individually
|
||||
|
||||
### Available Toolsets
|
||||
|
||||
See `toolsets.py` for the complete list of predefined toolsets including:
|
||||
- Basic toolsets (web, terminal, vision, creative, reasoning)
|
||||
- Composite toolsets (research, development, analysis, etc.)
|
||||
- Scenario-specific toolsets (debugging, documentation, API testing, etc.)
|
||||
- Special toolsets (safe mode without terminal, minimal, offline)
|
||||
|
||||
### Using Toolsets
|
||||
|
||||
```bash
|
||||
# Use a predefined toolset
|
||||
python run_agent.py --enabled_toolsets=research --query "Find latest AI papers"
|
||||
|
||||
# Combine multiple toolsets
|
||||
python run_agent.py --enabled_toolsets=web,vision --query "Analyze this website"
|
||||
|
||||
# Enable all toolsets explicitly (same as omitting the flag)
|
||||
python run_agent.py --enabled_toolsets=all --query "Do web research and run commands if helpful"
|
||||
|
||||
# Safe mode (no terminal access)
|
||||
python run_agent.py --enabled_toolsets=safe --query "Help without running commands"
|
||||
|
||||
# List all available toolsets and tools
|
||||
python run_agent.py --list_tools
|
||||
```
|
||||
|
||||
See `toolsets.py` for the complete list of available toolsets and how to create custom ones.
|
||||
|
||||
## Basic Usage
|
||||
|
||||
### Default (all tools enabled)
|
||||
```bash
|
||||
# Uses OpenRouter by default - just set OPENROUTER_API_KEY in .env
|
||||
python run_agent.py \
|
||||
--query "search up the latest docs on jit in python 3.13 and write me basic example that's not in their docs. profile its perf" \
|
||||
--max_turns 20 \
|
||||
--model anthropic/claude-sonnet-4-20250514
|
||||
```
|
||||
|
||||
### With specific toolset
|
||||
```bash
|
||||
python run_agent.py \
|
||||
--query "Debug this Python error" \
|
||||
--enabled_toolsets=debugging \
|
||||
--model anthropic/claude-sonnet-4-20250514
|
||||
```
|
||||
|
||||
### Python API
|
||||
```python
|
||||
from run_agent import AIAgent
|
||||
|
||||
# Uses OpenRouter by default (reads OPENROUTER_API_KEY from .env)
|
||||
agent = AIAgent(
|
||||
model="anthropic/claude-sonnet-4-20250514",
|
||||
enabled_toolsets=["research"]
|
||||
)
|
||||
response = agent.chat("Find information about quantum computing")
|
||||
|
||||
# Create custom toolset at runtime
|
||||
from toolsets import create_custom_toolset
|
||||
|
||||
create_custom_toolset(
|
||||
name="my_tools",
|
||||
description="My custom toolkit",
|
||||
tools=["web_search"],
|
||||
includes=["terminal", "vision"]
|
||||
)
|
||||
|
||||
agent = AIAgent(enabled_toolsets=["my_tools"])
|
||||
```
|
||||
|
||||
## Batch Processing
|
||||
|
||||
Process multiple prompts from a dataset in parallel with automatic checkpointing and statistics tracking:
|
||||
|
||||
```bash
|
||||
# Basic batch processing
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=20 \
|
||||
--run_name=my_run
|
||||
|
||||
# With specific distribution
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=20 \
|
||||
--run_name=image_run \
|
||||
--distribution=image_gen \
|
||||
--num_workers=4
|
||||
```
|
||||
|
||||
**Key Features:**
|
||||
- Parallel processing with configurable workers
|
||||
- Toolset distributions for varied data generation
|
||||
- Automatic checkpointing and resume capability
|
||||
- Combined output in `data/<run_name>/trajectories.jsonl`
|
||||
- Tool usage statistics and success rates
|
||||
|
||||
Use `--list_distributions` to see available toolset distributions for varied data generation.
|
||||
|
||||
### Trajectory Compression
|
||||
|
||||
Post-process trajectories to fit within token budgets for training:
|
||||
|
||||
```bash
|
||||
# Compress a directory of JSONL files
|
||||
python trajectory_compressor.py --input=data/my_run
|
||||
|
||||
# Compress a single JSONL file
|
||||
python trajectory_compressor.py --input=data/trajectories.jsonl
|
||||
|
||||
# Compress a 15% sample (useful for creating smaller training sets)
|
||||
python trajectory_compressor.py --input=data/trajectories.jsonl --sample_percent=15
|
||||
|
||||
# Custom output and token target
|
||||
python trajectory_compressor.py \
|
||||
--input=data/trajectories.jsonl \
|
||||
--output=data/compressed.jsonl \
|
||||
--target_max_tokens=16000
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Protects first turns (system, human, first GPT response, first tool call)
|
||||
- Protects last N turns (configurable)
|
||||
- Summarizes middle turns using LLM to fit target token budget
|
||||
- Supports both directory and single file input
|
||||
- Optional random sampling with `--sample_percent`
|
||||
- Configurable via `configs/trajectory_compression.yaml`
|
||||
|
||||
### Ephemeral System Prompts
|
||||
|
||||
The ephemeral system prompt feature allows you to guide the model's behavior during batch processing **without** saving that prompt to the training dataset trajectories. This is useful for:
|
||||
|
||||
- Guiding model behavior during data collection
|
||||
- Adding task-specific instructions
|
||||
- Keeping saved trajectories clean and focused on tool-calling format
|
||||
|
||||
**Example:**
|
||||
```bash
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=10 \
|
||||
--run_name=my_run \
|
||||
--ephemeral_system_prompt="You are a helpful assistant focused on image generation."
|
||||
```
|
||||
|
||||
The ephemeral prompt will influence the model's behavior during execution, but **only the standard tool-calling system prompt** will be saved in the trajectory files.
|
||||
|
||||
The ephemeral prompt influences model behavior during execution, but **only the standard tool-calling system prompt** is saved in trajectory files.
|
||||
|
||||
## Command Line Arguments
|
||||
|
||||
**Single Agent (`run_agent.py`):**
|
||||
- `--query`: The question or task for the agent
|
||||
- `--model`: Model to use (default: claude-opus-4-20250514)
|
||||
- `--api_key`: API key for authentication
|
||||
- `--base_url`: API endpoint URL
|
||||
- `--max_turns`: Maximum number of tool-calling iterations
|
||||
- `--enabled_toolsets`: Comma-separated list of toolsets to enable. Use `all` (or `*`) to enable everything. If omitted, all toolsets are enabled by default.
|
||||
- `--disabled_toolsets`: Comma-separated list of toolsets to disable
|
||||
- `--list_tools`: List all available toolsets and tools
|
||||
- `--save_trajectories`: Save conversation trajectories to JSONL files
|
||||
|
||||
**Batch Processing (`batch_runner.py`):**
|
||||
- `--dataset_file`: Path to JSONL file with prompts
|
||||
- `--batch_size`: Number of prompts per batch
|
||||
- `--run_name`: Name for this run (for output/checkpointing)
|
||||
- `--distribution`: Toolset distribution to use (default: "default")
|
||||
- `--num_workers`: Number of parallel workers (default: 4)
|
||||
- `--resume`: Resume from checkpoint if interrupted
|
||||
- `--ephemeral_system_prompt`: System prompt used during execution but NOT saved to trajectories
|
||||
- `--list_distributions`: List available toolset distributions
|
||||
|
||||
## Environment Variables
|
||||
|
||||
All environment variables can be configured in the `.env` file (copy from `.env.example`).
|
||||
|
||||
**LLM Provider (OpenRouter):**
|
||||
- `OPENROUTER_API_KEY`: Primary LLM access via OpenRouter (supports Claude, GPT-4, Gemini, etc.)
|
||||
- `LLM_MODEL`: Default model (e.g., `anthropic/claude-sonnet-4`, `openai/gpt-4o`)
|
||||
|
||||
**Tool API Keys:**
|
||||
- `FIRECRAWL_API_KEY`: Web tools (search, extract, crawl)
|
||||
- `NOUS_API_KEY`: Vision and reasoning tools
|
||||
- `FAL_KEY`: Image generation tools
|
||||
|
||||
**Terminal Tool Configuration (mini-swe-agent backend):**
|
||||
- `TERMINAL_ENV`: Backend type - `local`, `docker`, `singularity`, `modal`, or `ssh` (default: `local`)
|
||||
- `TERMINAL_DOCKER_IMAGE`: Docker image for docker backend (default: `python:3.11-slim`)
|
||||
- `TERMINAL_SINGULARITY_IMAGE`: Singularity/Apptainer image (can be `docker://...` URL or local `.sif` path)
|
||||
- `TERMINAL_TIMEOUT`: Command timeout in seconds (default: `60`)
|
||||
- `TERMINAL_LIFETIME_SECONDS`: Cleanup inactive environments after this time (default: `300`)
|
||||
- `TERMINAL_CWD`: Working directory inside containers (default: `/tmp`)
|
||||
- `TERMINAL_SCRATCH_DIR`: Custom scratch directory for sandbox storage (optional, auto-detects `/scratch`)
|
||||
- `SUDO_PASSWORD`: Enable sudo commands by piping password via `sudo -S` (works with all backends)
|
||||
- If unset in CLI mode, you'll be prompted interactively when sudo is needed (45s timeout)
|
||||
|
||||
**SSH Backend Configuration (for remote execution):**
|
||||
- `TERMINAL_SSH_HOST`: Remote server hostname or IP
|
||||
- `TERMINAL_SSH_USER`: SSH username
|
||||
- `TERMINAL_SSH_PORT`: SSH port (default: `22`)
|
||||
- `TERMINAL_SSH_KEY`: Path to SSH private key (optional, uses ssh-agent if not set)
|
||||
|
||||
**Browser Tool Configuration (agent-browser + Browserbase):**
|
||||
- `BROWSERBASE_API_KEY`: Browserbase API key for cloud browser execution
|
||||
- `BROWSERBASE_PROJECT_ID`: Browserbase project ID
|
||||
- `BROWSER_SESSION_TIMEOUT`: Session timeout in seconds (default: `300`)
|
||||
|
||||
**Legacy Hecate Terminal Backend (optional):**
|
||||
- `MORPH_API_KEY`: For Hecate/MorphCloud terminal backend
|
||||
- `HECATE_VM_LIFETIME_SECONDS`: VM lifetime (default: 300)
|
||||
- `HECATE_DEFAULT_SNAPSHOT_ID`: Default snapshot (default: snapshot_p5294qxt)
|
||||
|
||||
**Debug Options:**
|
||||
- `WEB_TOOLS_DEBUG`, `VISION_TOOLS_DEBUG`, `MOA_TOOLS_DEBUG`, `IMAGE_TOOLS_DEBUG`: Enable debug logging
|
||||
|
||||
## Key Files
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `hermes` | CLI launcher script (run with `./hermes`) |
|
||||
| `cli.py` | Interactive CLI implementation |
|
||||
| `cli-config.yaml` | CLI configuration (copy from `.example`) |
|
||||
| `run_agent.py` | Main agent runner - single query execution |
|
||||
| `batch_runner.py` | Parallel batch processing with checkpointing |
|
||||
| `model_tools.py` | Core tool definitions and handlers |
|
||||
| `toolsets.py` | Toolset definitions and composition |
|
||||
| `toolset_distributions.py` | Probability distributions for data generation |
|
||||
| `trajectory_compressor.py` | Post-process trajectories for training |
|
||||
| `tools/` | Individual tool implementations |
|
||||
| `tools/skills_tool.py` | Skills system with progressive disclosure |
|
||||
| `skills/` | On-demand knowledge documents |
|
||||
| `docs/` | Documentation |
|
||||
| `configs/` | Example batch run scripts |
|
||||
|
||||
# Atropos Integrations & RL Training
|
||||
|
||||
## Nomad Setup
|
||||
Follow this: https://developer.hashicorp.com/nomad/docs/deploy
|
||||
|
||||
## Atropos dependencies
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install -e '.[atropos]'
|
||||
@@ -0,0 +1,70 @@
|
||||
README.md
|
||||
atropos_compatible_agent.py
|
||||
batch_runner.py
|
||||
local_server.py
|
||||
model_tools.py
|
||||
pyproject.toml
|
||||
run_agent.py
|
||||
toolset_distributions.py
|
||||
toolsets.py
|
||||
trajectory_compressor.py
|
||||
atropos/__init__.py
|
||||
atropos/sandbox_server.py
|
||||
atropos/agent/__init__.py
|
||||
atropos/agent/atropos_agent.py
|
||||
atropos/api/__init__.py
|
||||
atropos/api/tool_executor_server.py
|
||||
atropos/api/tool_server.py
|
||||
atropos/backends/__init__.py
|
||||
atropos/backends/base.py
|
||||
atropos/backends/modal_backend.py
|
||||
atropos/backends/nomad_backend.py
|
||||
atropos/envs/__init__.py
|
||||
atropos/envs/agent_env.py
|
||||
atropos/envs/hermes_compat_test_env.py
|
||||
atropos/envs/sandbox_terminal_smoke_env.py
|
||||
atropos/envs/swe_smith_oracle_env.py
|
||||
atropos/envs/test_env.py
|
||||
atropos/envs/toolserver_smoke_env.py
|
||||
atropos/nomad/__init__.py
|
||||
atropos/nomad/client.py
|
||||
atropos/slots/__init__.py
|
||||
atropos/slots/executor.py
|
||||
atropos/slots/pool.py
|
||||
atropos/slots/slot.py
|
||||
atropos/terminal/__init__.py
|
||||
atropos/terminal/asciinema_stream.py
|
||||
atropos/tools/__init__.py
|
||||
atropos/tools/base.py
|
||||
atropos/tools/build_registry.py
|
||||
atropos/tools/hermes_external_tools.py
|
||||
atropos/tools/sandbox_stubs.py
|
||||
atropos/tools/terminal_stateful_tool.py
|
||||
atropos/tools/tmux_tool.py
|
||||
atropos/tools/tool_executor.py
|
||||
atropos/tools/toolset_resolver.py
|
||||
hermes_agent.egg-info/PKG-INFO
|
||||
hermes_agent.egg-info/SOURCES.txt
|
||||
hermes_agent.egg-info/dependency_links.txt
|
||||
hermes_agent.egg-info/entry_points.txt
|
||||
hermes_agent.egg-info/requires.txt
|
||||
hermes_agent.egg-info/top_level.txt
|
||||
tests/test_batch_runner.py
|
||||
tests/test_checkpoint_resumption.py
|
||||
tests/test_modal_integration.py
|
||||
tests/test_modal_stress.py
|
||||
tests/test_modal_terminal.py
|
||||
tests/test_nous_api_limits.py
|
||||
tests/test_nous_api_pattern.py
|
||||
tests/test_temperature_fix.py
|
||||
tests/test_tool_call_parsing.py
|
||||
tests/test_web_tools.py
|
||||
tools/__init__.py
|
||||
tools/browser_tool.py
|
||||
tools/image_generation_tool.py
|
||||
tools/mixture_of_agents_tool.py
|
||||
tools/skills_tool.py
|
||||
tools/terminal_hecate.py
|
||||
tools/terminal_tool.py
|
||||
tools/vision_tools.py
|
||||
tools/web_tools.py
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
[console_scripts]
|
||||
hermes-agent = run_agent:main
|
||||
hermes-atropos-sandbox-smoke = atropos.envs.sandbox_terminal_smoke_env:SandboxTerminalSmokeEnv.cli
|
||||
hermes-atropos-toolserver-smoke = atropos.envs.toolserver_smoke_env:ToolServerSmokeEnv.cli
|
||||
@@ -0,0 +1,31 @@
|
||||
openai
|
||||
python-dotenv
|
||||
fire
|
||||
httpx
|
||||
rich
|
||||
tenacity
|
||||
pyyaml
|
||||
prompt_toolkit
|
||||
requests
|
||||
jinja2
|
||||
pydantic>=2.0
|
||||
firecrawl-py
|
||||
fal-client
|
||||
litellm>=1.75.5
|
||||
typer
|
||||
platformdirs
|
||||
|
||||
[atropos]
|
||||
atroposlib @ git+https://github.com/NousResearch/atropos.git
|
||||
aiohttp
|
||||
fastapi
|
||||
uvicorn
|
||||
pyte
|
||||
|
||||
[dev]
|
||||
pytest
|
||||
pytest-asyncio
|
||||
|
||||
[modal]
|
||||
modal
|
||||
boto3
|
||||
@@ -0,0 +1,10 @@
|
||||
atropos
|
||||
atropos_compatible_agent
|
||||
batch_runner
|
||||
local_server
|
||||
model_tools
|
||||
run_agent
|
||||
tools
|
||||
toolset_distributions
|
||||
toolsets
|
||||
trajectory_compressor
|
||||
@@ -99,24 +99,6 @@ DEFAULT_CONFIG = {
|
||||
"personality": "kawaii",
|
||||
},
|
||||
|
||||
# Text-to-speech configuration
|
||||
"tts": {
|
||||
"provider": "edge", # "edge" (free) | "elevenlabs" (premium) | "openai"
|
||||
"edge": {
|
||||
"voice": "en-US-AriaNeural",
|
||||
# Popular: AriaNeural, JennyNeural, AndrewNeural, BrianNeural, SoniaNeural
|
||||
},
|
||||
"elevenlabs": {
|
||||
"voice_id": "pNInz6obpgDQGcFmaJgB", # Adam
|
||||
"model_id": "eleven_multilingual_v2",
|
||||
},
|
||||
"openai": {
|
||||
"model": "gpt-4o-mini-tts",
|
||||
"voice": "alloy",
|
||||
# Voices: alloy, echo, fable, onyx, nova, shimmer
|
||||
},
|
||||
},
|
||||
|
||||
# Permanently allowed dangerous command patterns (added via "always" approval)
|
||||
"command_allowlist": [],
|
||||
|
||||
@@ -220,13 +202,6 @@ OPTIONAL_ENV_VARS = {
|
||||
"url": None,
|
||||
"password": False,
|
||||
},
|
||||
# Text-to-speech (premium providers)
|
||||
"ELEVENLABS_API_KEY": {
|
||||
"description": "ElevenLabs API key for premium text-to-speech voices",
|
||||
"prompt": "ElevenLabs API key",
|
||||
"url": "https://elevenlabs.io/",
|
||||
"password": True,
|
||||
},
|
||||
# Terminal configuration
|
||||
"MESSAGING_CWD": {
|
||||
"description": "Working directory for terminal commands via messaging (Telegram/Discord/etc). CLI always uses current directory.",
|
||||
|
||||
@@ -360,11 +360,7 @@ def run_gateway(verbose: bool = False):
|
||||
print("└─────────────────────────────────────────────────────────┘")
|
||||
print()
|
||||
|
||||
# Exit with code 1 if gateway fails to connect any platform,
|
||||
# so systemd Restart=on-failure will retry on transient errors
|
||||
success = asyncio.run(start_gateway())
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
asyncio.run(start_gateway())
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@@ -186,11 +186,6 @@ def _print_setup_summary(config: dict, hermes_home):
|
||||
else:
|
||||
tool_status.append(("Image Generation", False, "FAL_KEY"))
|
||||
|
||||
# TTS (always available via Edge TTS; ElevenLabs/OpenAI are optional)
|
||||
tool_status.append(("Text-to-Speech (Edge TTS)", True, None))
|
||||
if get_env_value('ELEVENLABS_API_KEY'):
|
||||
tool_status.append(("Text-to-Speech (ElevenLabs)", True, None))
|
||||
|
||||
# Tinker + WandB (RL training)
|
||||
if get_env_value('TINKER_API_KEY') and get_env_value('WANDB_API_KEY'):
|
||||
tool_status.append(("RL Training (Tinker)", True, None))
|
||||
@@ -996,28 +991,6 @@ def run_setup_wizard(args):
|
||||
print_success(" Configured ✓")
|
||||
print()
|
||||
|
||||
# ElevenLabs - Premium TTS
|
||||
print_info("─" * 50)
|
||||
print(color(" Text-to-Speech - ElevenLabs (Premium)", Colors.CYAN))
|
||||
print_info(" Enables: Premium TTS voices (Edge TTS is free and works without a key)")
|
||||
print_info(" Use case: High-quality, customizable voice synthesis")
|
||||
if get_env_value('ELEVENLABS_API_KEY'):
|
||||
print_success(" Status: Configured ✓")
|
||||
if prompt_yes_no(" Update ElevenLabs API key?", False):
|
||||
api_key = prompt(" API key", password=True)
|
||||
if api_key:
|
||||
save_env_value("ELEVENLABS_API_KEY", api_key)
|
||||
print_success(" Updated")
|
||||
else:
|
||||
print_warning(" Status: Not configured (free Edge TTS will be used by default)")
|
||||
if prompt_yes_no(" Set up ElevenLabs?", False):
|
||||
print_info(" Get your API key at: https://elevenlabs.io/")
|
||||
api_key = prompt(" API key", password=True)
|
||||
if api_key:
|
||||
save_env_value("ELEVENLABS_API_KEY", api_key)
|
||||
print_success(" Configured ✓")
|
||||
print()
|
||||
|
||||
# Tinker + WandB - RL Training
|
||||
print_info("─" * 50)
|
||||
print(color(" RL Training (Tinker + WandB)", Colors.CYAN))
|
||||
|
||||
@@ -76,7 +76,6 @@ def show_status(args):
|
||||
"FAL": "FAL_KEY",
|
||||
"Tinker": "TINKER_API_KEY",
|
||||
"WandB": "WANDB_API_KEY",
|
||||
"ElevenLabs": "ELEVENLABS_API_KEY",
|
||||
}
|
||||
|
||||
for name, env_var in keys.items():
|
||||
|
||||
+353
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
Local OpenAI-compatible server implementation for Hermes-Agent (Atropos integration).
|
||||
|
||||
Extends the Atropos APIServer to work with local OpenAI-compatible APIs (e.g. vLLM, SGLang),
|
||||
providing tokens_and_logprobs_completion support via client-side tokenization.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import openai
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.completion import Completion
|
||||
|
||||
from atroposlib.envs.server_handling.server_baseline import (
|
||||
APIServer,
|
||||
APIServerConfig,
|
||||
ReasoningConfig,
|
||||
)
|
||||
|
||||
|
||||
class LocalServer(APIServer):
|
||||
"""
|
||||
OpenAI-compatible local server with tokens_and_logprobs support.
|
||||
|
||||
Uses an OpenAI-compatible API (typically at a /v1 endpoint) and handles
|
||||
token extraction via client-side tokenization.
|
||||
|
||||
Note: Many local servers don't return per-token logprobs in the standard API,
|
||||
so this implementation uses placeholder logprobs (0.0) for PoC purposes.
|
||||
For production training, use vLLM/SGLang servers that return real logprobs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: APIServerConfig,
|
||||
tokenizer: Optional[Any] = None,
|
||||
tokenizer_name: str = "gpt2",
|
||||
reasoning_config: Optional[ReasoningConfig] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the local server.
|
||||
|
||||
Args:
|
||||
config: Server configuration
|
||||
tokenizer: Pre-initialized tokenizer (optional)
|
||||
tokenizer_name: Name of tokenizer to load if tokenizer not provided
|
||||
reasoning_config: Optional reasoning configuration
|
||||
"""
|
||||
# Build the OpenAI client pointing to the server's /v1 endpoint
|
||||
base_url = config.base_url
|
||||
if base_url and not base_url.endswith("/v1"):
|
||||
base_url = f"{base_url.rstrip('/')}/v1"
|
||||
|
||||
self.openai = openai.AsyncClient(
|
||||
api_key=config.api_key or "local", # Local servers often ignore auth
|
||||
base_url=base_url,
|
||||
timeout=config.timeout,
|
||||
)
|
||||
|
||||
# Initialize tokenizer
|
||||
if tokenizer is not None:
|
||||
self.tokenizer = tokenizer
|
||||
else:
|
||||
try:
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ModuleNotFoundError(
|
||||
"Missing optional dependency 'transformers'. Pass a tokenizer instance to LocalServer, "
|
||||
"or install transformers to enable `tokenizer_name` auto-loading."
|
||||
) from exc
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
|
||||
# Add a simple chat template if the tokenizer doesn't have one
|
||||
# This is needed for ManagedServer's chat_completion to work
|
||||
if not hasattr(self.tokenizer, 'chat_template') or self.tokenizer.chat_template is None:
|
||||
# Simple ChatML-style template
|
||||
self.tokenizer.chat_template = (
|
||||
"{% for message in messages %}"
|
||||
"{% if message['role'] == 'system' %}<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
|
||||
"{% elif message['role'] == 'user' %}<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
|
||||
"{% elif message['role'] == 'assistant' %}<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
||||
)
|
||||
|
||||
super().__init__(config, reasoning_config=reasoning_config)
|
||||
# Local servers are treated as always-healthy unless a status task is enabled.
|
||||
self.server_healthy = True
|
||||
|
||||
@classmethod
|
||||
def from_env(
|
||||
cls,
|
||||
base_url: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
tokenizer_name: str = "gpt2",
|
||||
**kwargs,
|
||||
) -> "LocalServer":
|
||||
"""
|
||||
Create a LocalServer from environment variables (or explicit overrides).
|
||||
|
||||
Env vars (checked in order):
|
||||
- base URL: ATROPOS_SERVER_BASE_URL, OPENAI_BASE_URL, LOCAL_LLM_BASE_URL, LLM_BASE_URL
|
||||
- model: ATROPOS_SERVER_MODEL, LLM_MODEL, LOCAL_LLM_MODEL
|
||||
- api key: ATROPOS_SERVER_API_KEY, OPENAI_API_KEY, LOCAL_LLM_API_KEY, LLM_API_KEY
|
||||
"""
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
base_url = (
|
||||
base_url
|
||||
or os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LOCAL_LLM_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://localhost:11434"
|
||||
)
|
||||
model = (
|
||||
model
|
||||
or os.getenv("ATROPOS_SERVER_MODEL")
|
||||
or os.getenv("LLM_MODEL")
|
||||
or os.getenv("LOCAL_LLM_MODEL")
|
||||
or "hermes3:8b"
|
||||
)
|
||||
api_key = (
|
||||
api_key
|
||||
or os.getenv("ATROPOS_SERVER_API_KEY")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or os.getenv("LOCAL_LLM_API_KEY")
|
||||
or os.getenv("LLM_API_KEY")
|
||||
)
|
||||
|
||||
config = APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=base_url,
|
||||
api_key=api_key or "local",
|
||||
timeout=kwargs.get("timeout", 120),
|
||||
num_max_requests_at_once=kwargs.get("num_max_requests_at_once", 4),
|
||||
num_requests_for_eval=kwargs.get("num_requests_for_eval", 4),
|
||||
health_check=False, # Local dev servers often lack /health
|
||||
)
|
||||
|
||||
return cls(config, tokenizer_name=tokenizer_name)
|
||||
|
||||
async def check_server_status_task(self, chat_completion: bool = True):
|
||||
"""
|
||||
Check if the server is healthy.
|
||||
|
||||
For local development, we generally assume the server is healthy.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# Simple health check via a minimal completion
|
||||
if chat_completion:
|
||||
await self.openai.chat.completions.create(
|
||||
model=self.config.model_name,
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
max_tokens=1,
|
||||
)
|
||||
else:
|
||||
await self.openai.completions.create(
|
||||
model=self.config.model_name,
|
||||
prompt="hi",
|
||||
max_tokens=1,
|
||||
)
|
||||
self.server_healthy = True
|
||||
except Exception:
|
||||
self.server_healthy = False
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
|
||||
"""
|
||||
Wrapper for chat completion using an OpenAI-compatible API.
|
||||
"""
|
||||
assert kwargs.get("model") is not None, "Model is required!"
|
||||
assert kwargs.get("messages") is not None, "Messages are required!"
|
||||
|
||||
n = kwargs.get("n", 1)
|
||||
|
||||
# Some OpenAI-compatible servers don't support n > 1, so we make multiple requests.
|
||||
if n > 1:
|
||||
completion_list = await asyncio.gather(
|
||||
*[self.openai.chat.completions.create(**{**kwargs, "n": 1}) for _ in range(n)]
|
||||
)
|
||||
# Merge completions
|
||||
completions = completion_list[0]
|
||||
for c in completion_list[1:]:
|
||||
for choice in c.choices:
|
||||
choice.index = len(completions.choices)
|
||||
completions.choices.append(choice)
|
||||
return completions
|
||||
else:
|
||||
return await self.openai.chat.completions.create(**kwargs)
|
||||
|
||||
async def _completion_wrapper(self, **kwargs) -> Completion:
|
||||
"""
|
||||
Wrapper for completion using an OpenAI-compatible API.
|
||||
"""
|
||||
assert kwargs.get("model") is not None, "Model is required!"
|
||||
assert kwargs.get("prompt") is not None, "Prompt is required!"
|
||||
|
||||
n = kwargs.get("n", 1)
|
||||
|
||||
# Some OpenAI-compatible servers don't support n > 1.
|
||||
if n > 1:
|
||||
completion_list = await asyncio.gather(
|
||||
*[self.openai.completions.create(**{**kwargs, "n": 1}) for _ in range(n)]
|
||||
)
|
||||
completions = completion_list[0]
|
||||
for c in completion_list[1:]:
|
||||
for choice in c.choices:
|
||||
choice.index = len(completions.choices)
|
||||
completions.choices.append(choice)
|
||||
return completions
|
||||
else:
|
||||
return await self.openai.completions.create(**kwargs)
|
||||
|
||||
async def _tokens_and_logprobs_completion_wrapper(
|
||||
self, **kwargs
|
||||
) -> tuple[List[int], List[List[int]], List[List[float]], List[str]]:
|
||||
"""
|
||||
Wrapper for tokens and logprobs completion.
|
||||
|
||||
Returns:
|
||||
Tuple of (prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons)
|
||||
|
||||
Note: Many OpenAI-compatible local servers don't return per-token logprobs,
|
||||
so we use placeholder logprobs (0.0). For real training, use vLLM/SGLang.
|
||||
"""
|
||||
model = kwargs.get("model")
|
||||
assert model is not None, "Model is required!"
|
||||
|
||||
# Handle input_ids (from ManagedServer) or prompt
|
||||
if "input_ids" in kwargs:
|
||||
prompt_tokens = kwargs.pop("input_ids")
|
||||
prompt = self.tokenizer.decode(prompt_tokens)
|
||||
kwargs.pop("prompt", None)
|
||||
else:
|
||||
prompt = kwargs.pop("prompt", "")
|
||||
prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=True)
|
||||
|
||||
n = kwargs.pop("n", 1)
|
||||
max_tokens = kwargs.pop("max_tokens", 256)
|
||||
temperature = kwargs.pop("temperature", 0.7)
|
||||
stop = kwargs.pop("stop", None)
|
||||
|
||||
# Make completion requests
|
||||
completions = []
|
||||
for _ in range(n):
|
||||
try:
|
||||
response = await self.openai.completions.create(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
stop=stop,
|
||||
)
|
||||
completions.append(response)
|
||||
except Exception as e:
|
||||
# Fallback to chat completion if completion endpoint not supported
|
||||
warnings.warn(f"Completion API failed, trying chat: {e}")
|
||||
response = await self.openai.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
stop=stop,
|
||||
)
|
||||
# Convert to completion-like response
|
||||
completions.append(response)
|
||||
|
||||
output_tokens_list = []
|
||||
output_logprobs_list = []
|
||||
finish_reasons = []
|
||||
|
||||
for completion in completions:
|
||||
# Extract text from response
|
||||
if hasattr(completion.choices[0], "text"):
|
||||
# Completion API response
|
||||
text = completion.choices[0].text
|
||||
finish_reason = completion.choices[0].finish_reason or "stop"
|
||||
else:
|
||||
# Chat completion API response
|
||||
text = completion.choices[0].message.content or ""
|
||||
finish_reason = completion.choices[0].finish_reason or "stop"
|
||||
|
||||
# Tokenize output
|
||||
output_tokens = self.tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
# Placeholder logprobs (many local servers don't provide per-token logprobs).
|
||||
# In production, use vLLM/SGLang which return real logprobs
|
||||
output_logprobs = [0.0] * len(output_tokens)
|
||||
|
||||
output_tokens_list.append(output_tokens)
|
||||
output_logprobs_list.append(output_logprobs)
|
||||
finish_reasons.append(finish_reason)
|
||||
|
||||
return prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons
|
||||
|
||||
def managed_server(self, tokenizer=None, track_tree: bool = False):
|
||||
"""
|
||||
Create a ManagedServer context manager for this server.
|
||||
|
||||
Args:
|
||||
tokenizer: Optional tokenizer override
|
||||
track_tree: Whether to maintain tree structure for multi-turn
|
||||
|
||||
Returns:
|
||||
ManagedServer context manager
|
||||
"""
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
|
||||
return ManagedServerContext(
|
||||
self,
|
||||
tokenizer=tokenizer or self.tokenizer,
|
||||
track_tree=track_tree,
|
||||
)
|
||||
|
||||
|
||||
class ManagedServerContext:
|
||||
"""
|
||||
Context manager wrapper for ManagedServer.
|
||||
|
||||
Usage:
|
||||
async with server.managed_server(tokenizer=tokenizer) as managed:
|
||||
response = await managed.chat_completion(...)
|
||||
state = managed.get_state()
|
||||
"""
|
||||
|
||||
def __init__(self, server: LocalServer, tokenizer, track_tree: bool = False):
|
||||
self.server = server
|
||||
self.tokenizer = tokenizer
|
||||
self.track_tree = track_tree
|
||||
self.managed = None
|
||||
|
||||
async def __aenter__(self):
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
|
||||
self.managed = ManagedServer(
|
||||
self.server,
|
||||
tokenizer=self.tokenizer,
|
||||
track_tree=self.track_tree,
|
||||
)
|
||||
return self.managed
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.managed:
|
||||
self.managed.reset()
|
||||
return False
|
||||
@@ -0,0 +1,136 @@
|
||||
# Active Context
|
||||
|
||||
## Current Task: SWE Smith Oracle Env with Modal Backend
|
||||
|
||||
### Goal
|
||||
Run this command:
|
||||
```bash
|
||||
python environments/swe_smith_oracle_env.py process \
|
||||
--env.use_wandb false \
|
||||
--env.total_steps 2 \
|
||||
--env.group_size 1 \
|
||||
--env.max_items 2 \
|
||||
--env.tool_pool_mode modal \
|
||||
--env.modal_image python:3.11 \
|
||||
--env.modal_slots_per_sandbox 10 \
|
||||
--env.modal_min_sandboxes 1
|
||||
```
|
||||
|
||||
### What's Done
|
||||
1. ✅ **agent_loop.py** - Added `tool_handler` parameter
|
||||
- New param: `tool_handler=None` in `__init__`
|
||||
- When `self.tool_handler` is set, it's called INSTEAD of `handle_function_call()`
|
||||
- Signature: `async tool_handler(tool_name, args, task_id) -> str`
|
||||
- Shows `[sandbox]` instead of backend name in terminal preview
|
||||
|
||||
2. ✅ **Phase 2 ManagedServer + SGLang** - Fully working (previous session)
|
||||
|
||||
3. ✅ **hermes_base_env.py** - Sandbox routing in collect_trajectory() (THIS SESSION)
|
||||
- Refactored `collect_trajectory()` into:
|
||||
- `_use_sandbox_backend()` - checks if sandbox should be used
|
||||
- `_collect_trajectory_local()` - existing path (ToolContext + handle_function_call)
|
||||
- `_collect_trajectory_sandbox()` - NEW sandbox path with slot lifecycle
|
||||
- `_run_agent_loop()` - shared agent loop for Phase 1/2, accepts tool_handler
|
||||
- `_build_scored_item()` - shared scored item construction
|
||||
- Sandbox path:
|
||||
1. `backend.acquire(task_id)` → Slot
|
||||
2. `exec_tool` callable wrapping `backend.execute_batch([(slot, tool_name, args)])`
|
||||
3. `setup_trajectory_workspace(item, exec_tool=exec_tool)` → workspace_meta
|
||||
4. `sandbox_tool_handler` routes terminal→sandbox, other→local
|
||||
5. `_run_agent_loop(tool_handler=sandbox_tool_handler)`
|
||||
6. `verify_and_score_trajectory(item, result, exec_tool=exec_tool)`
|
||||
7. `backend.release(slot, reset_workspace=True)` in finally
|
||||
- Added `handle_function_call` import for non-terminal tool fallback
|
||||
|
||||
4. ✅ **swe_smith_oracle_env.py** - Sandbox hooks (THIS SESSION)
|
||||
- `setup_trajectory_workspace()` - bare repo cache + git worktree (ported from atropos/envs/swe_smith_oracle_env.py)
|
||||
- `verify_and_score_trajectory()` - install deps + run pytest in sandbox
|
||||
- `compute_reward()` retained for local (non-sandbox) path
|
||||
- Uses `exec_tool("bash", {"command": cmd}, timeout=600)` → `ExecutionResult`
|
||||
|
||||
5. ✅ **All tests pass**:
|
||||
- Syntax checks (ast.parse) on both files
|
||||
- Import checks (both modules import cleanly)
|
||||
- Method existence checks (all new methods present)
|
||||
- Signature checks (exec_tool, trajectory_id, workspace_meta params)
|
||||
- Backend integration (ModalSandboxConfig.from_agent_env_config, create_tool_backend)
|
||||
- `_use_sandbox_backend()` logic (True when modal+backend set, False otherwise)
|
||||
|
||||
6. ✅ **End-to-end test with Qwen 3 8B + Modal sandbox** (THIS SESSION)
|
||||
- RunPod endpoint: `0tx0ruuuo4f10c` (Qwen/Qwen3-8B via SGLang)
|
||||
- 5 terminal tool calls executed IN sandbox: `ls`, `git status`, `git log`, `cat parse.py`, `cat tests/`
|
||||
- In-sandbox verification: install deps + pytest → score=0.0 (model inspected but didn't fix)
|
||||
- Full token tracking with logprobs via Phase 2 ManagedServer
|
||||
- Key finding: Llama-3-8B template silently drops `tools=` param, Qwen 3 has full Hermes format support
|
||||
|
||||
### Current Task: Integrate Slot Pool Backend into tools/terminal_tool.py
|
||||
|
||||
#### Step 1: Add `_SlotPoolEnvironment` to `tools/terminal_tool.py`
|
||||
- New class alongside existing `_LocalEnvironment`, `_DockerEnvironment`, etc.
|
||||
- Routes through `atropos/backends/` (ModalToolBackend or NomadToolBackend)
|
||||
- N:M slot multiplexing: 5-10 sandboxes × 10 slots each = 50-100 concurrent
|
||||
- Singleton `_SlotPoolManager` (like `_ModalPoolManager`) manages backend lifecycle
|
||||
- `execute()` acquires slot → `backend.execute_batch([(slot, "bash", ...)])` → returns `{"output": ..., "returncode": ...}`
|
||||
- `cleanup()` releases slot back to pool
|
||||
|
||||
#### Step 2: Wire into `_create_environment()`
|
||||
- `TERMINAL_ENV=slot_pool` → `_SlotPoolEnvironment(...)`
|
||||
- Sub-config: `TERMINAL_SLOT_BACKEND=modal` or `TERMINAL_SLOT_BACKEND=nomad`
|
||||
- Reuse existing `TERMINAL_MODAL_*` and Nomad env vars for configuration
|
||||
|
||||
#### Step 3: Remove redundant `atropos/tools/` files
|
||||
- DELETE: `hermes_external_tools.py`, `build_registry.py`, `sandbox_stubs.py`, `toolset_resolver.py`
|
||||
- KEEP: `base.py` (ToolCall/ToolResult types), `tool_executor.py` (batched queue), `terminal_stateful_tool.py`, `tmux_tool.py`
|
||||
|
||||
#### Step 4: Clean up `atropos/envs/` and `atropos/agent/` (defer)
|
||||
- Remove `atropos/envs/agent_env.py` → replaced by `environments/hermes_base_env.py`
|
||||
- Remove `atropos/agent/atropos_agent.py` → replaced by `environments/agent_loop.py`
|
||||
|
||||
#### Later
|
||||
- Test with Tinker trainer (blocked on billing)
|
||||
- Add more environments (endless-terminals, terminalbench 2)
|
||||
|
||||
### Key Architecture Insight
|
||||
Two separate sandbox integration points:
|
||||
1. **`tools/terminal_tool.py` with `TERMINAL_ENV=slot_pool`** — for hermes CLI, batch_runner, any code using `handle_function_call("terminal", ...)`. Uses `_SlotPoolEnvironment` which wraps `atropos/backends/`.
|
||||
2. **`environments/hermes_base_env.py` with `tool_pool_mode=modal/nomad`** — for RL environments. Uses `_collect_trajectory_sandbox()` which directly acquires slots and creates `sandbox_tool_handler`.
|
||||
|
||||
Both use the same underlying `atropos/backends/` (ModalToolBackend, NomadToolBackend) with the same slot pool.
|
||||
|
||||
### Architecture Summary
|
||||
|
||||
```
|
||||
environments/hermes_base_env.py (HermesAgentBaseEnv)
|
||||
│
|
||||
├── tool_pool_mode="default" (existing path)
|
||||
│ └── collect_trajectory() → HermesAgentLoop(tool_handler=None)
|
||||
│ → handle_function_call() → hermes terminal tool (local)
|
||||
│
|
||||
└── tool_pool_mode="modal" or "nomad" (new path)
|
||||
└── collect_trajectory():
|
||||
1. slot = backend.acquire(task_id)
|
||||
2. exec_tool = lambda routing through backend.execute_batch
|
||||
3. setup_trajectory_workspace(item, exec_tool=exec_tool) [subclass hook]
|
||||
4. HermesAgentLoop(tool_handler=sandbox_tool_handler)
|
||||
→ terminal calls → backend.execute_batch(slot, "bash", ...)
|
||||
5. verify_and_score_trajectory(item, result, exec_tool=exec_tool) [subclass hook]
|
||||
6. backend.release(slot, reset_workspace=True)
|
||||
|
||||
atropos/backends/modal_backend.py (ModalToolBackend)
|
||||
└── acquire(trajectory_id) → Slot
|
||||
└── execute_batch([(slot, "bash", {"command": "..."})]) → [ExecutionResult]
|
||||
└── release(slot, reset_workspace=True)
|
||||
```
|
||||
|
||||
### Key Files to Modify
|
||||
1. `environments/hermes_base_env.py` - Add sandbox path in `collect_trajectory()`
|
||||
2. `environments/swe_smith_oracle_env.py` - Override `setup_trajectory_workspace()` and `verify_and_score_trajectory()` to use exec_tool
|
||||
|
||||
### Important Notes
|
||||
- `exec_tool` returns `ExecutionResult` (from `atropos/slots/executor.py`) with `.success`, `.output`, `.error`, `.metadata`
|
||||
- `tool_handler` returns JSON string (for agent loop message format)
|
||||
- These are DIFFERENT interfaces for different purposes:
|
||||
- `exec_tool`: used by env hooks (setup/verify) - returns structured result
|
||||
- `tool_handler`: used by agent loop - returns JSON string like hermes tools do
|
||||
- The ModalToolBackend.execute_batch calls _ModalSandboxWithSlots.execute which runs `sandbox.exec("bash", "-c", command)` on Modal
|
||||
- For the SWE env, the worktree setup pattern from `atropos/envs/swe_smith_oracle_env.py` should be reused (bare repo cache + worktree add)
|
||||
@@ -0,0 +1,55 @@
|
||||
# Product Context: Hermes-Agent
|
||||
|
||||
## Why This Project Exists
|
||||
|
||||
Hermes-Agent addresses several key challenges in the AI agent space:
|
||||
|
||||
1. **Unified Tool Interface** - Provides a clean, consistent interface for LLMs to use various tools (web, terminal, browser, vision, etc.) without requiring custom integration for each model provider.
|
||||
|
||||
2. **Training Data Generation** - Enables efficient generation of high-quality tool-calling trajectories for fine-tuning LLMs, with features like batch processing, checkpointing, and trajectory compression.
|
||||
|
||||
3. **Flexible Deployment** - Supports multiple execution environments (local, Docker, Singularity, Modal, SSH) to accommodate different security and isolation requirements.
|
||||
|
||||
4. **Developer Experience** - Offers a beautiful, interactive CLI with kawaii-style feedback that makes working with AI agents enjoyable.
|
||||
|
||||
## Problems It Solves
|
||||
|
||||
### For AI Researchers
|
||||
- **Data Generation at Scale**: Parallel batch processing with content-based checkpointing for fault tolerance
|
||||
- **Clean Trajectories**: Trajectory compression to fit token budgets while preserving important information
|
||||
- **Toolset Distributions**: Probability-based tool selection for varied training data
|
||||
|
||||
### For Developers
|
||||
- **Tool Orchestration**: Logical grouping of tools into toolsets (research, development, debugging, etc.)
|
||||
- **Session Persistence**: Conversation history and session logging for debugging
|
||||
- **Multi-Model Support**: Works with any OpenAI-compatible API (OpenRouter, local models, etc.)
|
||||
|
||||
### For MLOps
|
||||
- **Skills System**: On-demand knowledge documents for specific tools/frameworks (Axolotl, vLLM, TRL, etc.)
|
||||
- **Sandboxed Execution**: Terminal commands can run in isolated environments (Docker, Singularity, Modal)
|
||||
- **Configurable Backends**: Easy switching between local and cloud execution
|
||||
|
||||
## How It Should Work
|
||||
|
||||
### User Flow (CLI)
|
||||
1. User launches `./hermes`
|
||||
2. Beautiful welcome banner displays with caduceus logo, model info, and available tools
|
||||
3. User types a natural language request
|
||||
4. Agent processes request, potentially calling tools with animated feedback
|
||||
5. Agent responds with results, conversation continues
|
||||
6. Session is automatically logged for debugging
|
||||
|
||||
### User Flow (Batch Processing)
|
||||
1. User prepares JSONL file with prompts
|
||||
2. Runs `batch_runner.py` with distribution and worker count
|
||||
3. System processes prompts in parallel, saves checkpoints
|
||||
4. Completed trajectories saved to `data/<run_name>/trajectories.jsonl`
|
||||
5. Optional: compress trajectories with `trajectory_compressor.py`
|
||||
|
||||
## User Experience Goals
|
||||
|
||||
- **Delightful Interaction**: Kawaii ASCII faces, animated spinners, cute messages
|
||||
- **Informative Feedback**: Clear progress indication during tool execution
|
||||
- **Configurable Personalities**: From "helpful" to "pirate" to "Shakespeare"
|
||||
- **Easy Configuration**: YAML config file + environment variables + CLI flags
|
||||
- **Graceful Degradation**: Missing tools/APIs don't break the system, just disable features
|
||||
@@ -0,0 +1,134 @@
|
||||
# Progress
|
||||
|
||||
## Current Sprint: Phase 2 ManagedServer + SGLang Working (Feb 10, 2026)
|
||||
|
||||
### ✅ Phase 2 End-to-End Pipeline VERIFIED
|
||||
Full pipeline working: GSM8k env → collect_trajectory → ManagedServer → VLLMServer (SGLang patched) → tokens + logprobs + masks.
|
||||
|
||||
Test results:
|
||||
- 212 tokens with logprobs and masks from single trajectory
|
||||
- Reward: 1.0 (correct answer)
|
||||
- ScoredDataItem has all required fields: tokens, masks, scores, advantages, ref_logprobs, messages
|
||||
- RunPod SGLang endpoint (b9zmuyn1carwya) with Llama-3-8B-Instruct
|
||||
|
||||
### Consolidation Checklist
|
||||
- [x] Install atropos `tool_call_support` branch (PR #366)
|
||||
- [x] Create `environments/gsm8k_agent_env.py` using `HermesAgentBaseEnv`
|
||||
- [x] Create `environments/agent_loop.py` with proper OpenAI-spec tool calling
|
||||
- [x] Create `environments/tool_call_parsers/` with 13 parsers
|
||||
- [x] Create `environments/patches.py` for SGLang compatibility
|
||||
- [x] Add sandbox pool support to `HermesAgentBaseEnv`
|
||||
- [x] Test Phase 1 (OpenAI server type) with Nous API — WORKS
|
||||
- [x] Test Phase 2 (ManagedServer) with RunPod SGLang — WORKS
|
||||
- [x] Port SWE env to `HermesAgentBaseEnv` with multiplexed sandboxing
|
||||
- [x] End-to-end test: Qwen 3 8B + Modal sandbox + tool calls in sandbox + pytest verification
|
||||
- [x] Add `_SlotPoolEnvironment` to `tools/terminal_tool.py` (TERMINAL_ENV=slot_pool)
|
||||
- [x] Remove redundant `atropos/tools/` files (4 of 8)
|
||||
- [ ] Remove redundant `atropos/agent/` and `atropos/envs/agent_env.py` (deferred)
|
||||
- [ ] Test end-to-end with Tinker trainer (blocked on billing)
|
||||
|
||||
### ✅ End-to-End SWE + Modal Sandbox Verified (Feb 10, 2026)
|
||||
- Qwen 3 8B on RunPod SGLang (endpoint `0tx0ruuuo4f10c`)
|
||||
- Phase 2 ManagedServer with hermes tool call parser
|
||||
- 5 terminal commands executed in Modal sandbox: ls, git status, git log, cat parse.py, cat tests/
|
||||
- In-sandbox verification: install deps + pytest → score 0.0 (model inspected but didn't fix)
|
||||
- Full token tracking with logprobs via /generate endpoint
|
||||
- Key finding: Llama-3-8B template drops tools= silently; Qwen 3 has full Hermes tool format
|
||||
|
||||
## Completed Features
|
||||
|
||||
### ✅ Phase 2 ManagedServer + SGLang (Feb 10, 2026)
|
||||
- SGLang patch in `environments/patches.py` monkey-patches VLLMServer
|
||||
- Handles SGLang's different request/response format vs VLLM
|
||||
- Handles RunPod's double-JSON wrapping
|
||||
- Full chain verified: ManagedServer → VLLMServer → _tokens_and_logprobs_comp (retry) → patched wrapper → /generate endpoint
|
||||
- SequenceNode tracking: tokens, logprobs, masked_tokens all populated
|
||||
- **Key discovery**: The AttributeError from earlier was NOT in our current code — likely from a prior code state
|
||||
|
||||
### ✅ Phase 1 OpenAI Server Mode (Feb 9-10, 2026)
|
||||
- GSM8k env works with Nous API (OpenRouter-style endpoint)
|
||||
- Terminal tool calls properly dispatched
|
||||
- Tool call parsing handled natively by server (VLLM/SGLang /v1/chat/completions)
|
||||
- Reward computation verified (math_verify for robust LaTeX comparison)
|
||||
|
||||
### ✅ Sandbox Pool Integration (Feb 10, 2026)
|
||||
- Config fields added to `HermesAgentEnvConfig` for Nomad and Modal
|
||||
- `_start_sandbox_backend()` / `_stop_sandbox_backend()` lifecycle methods
|
||||
- Optional hooks: `setup_trajectory_workspace()`, `verify_and_score_trajectory()`
|
||||
- Integrated into `env_manager()` and `process_manager()` cleanup
|
||||
|
||||
### ✅ Tool Call Parsers (Feb 9-10, 2026)
|
||||
- 13 parsers: hermes, llama3_json, llama4_json, qwen, qwen3_coder, deepseek_v3, deepseek_v31, glm45, glm47, mistral, kimi_k2, longcat
|
||||
- Registry pattern: `get_parser("hermes")` returns parser instance
|
||||
- Each parser: `.parse(text) → (content, tool_calls)`
|
||||
- Used by ManagedServer in Phase 2 to extract structured tool_calls from raw completion
|
||||
|
||||
### ✅ Modal Backend Integration (Feb 8, 2026)
|
||||
- `ModalToolBackend` with slot-based multiplexing
|
||||
- Multi-profile support (CPU, GPU, high-memory)
|
||||
- Auto-scaling sandbox pool via Modal Sandboxes
|
||||
|
||||
### ✅ Main Branch Merge (Feb 9, 2026)
|
||||
- Merged 22,560 lines, 79 files, 5 conflicts resolved
|
||||
- New: hermes_cli/, file_operations, RL training tools, gateway, cron
|
||||
|
||||
### ✅ Tinker RL Training Setup (Feb 9, 2026)
|
||||
- tinker 0.12.0 + tinker-atropos installed
|
||||
- GSM8k agent config created
|
||||
- Pipeline verified: Tinker API connection works, all imports pass
|
||||
- **Blocked on billing** (Tinker 402 error)
|
||||
|
||||
### ✅ Singularity/Apptainer Sandbox (Feb 6, 2026)
|
||||
- Nomad raw_exec driver for HPC clusters
|
||||
- All sandbox operations tested and working
|
||||
|
||||
### ✅ Memory Bank (Feb 5, 2026)
|
||||
- Project documentation structure initialized
|
||||
|
||||
## What to KEEP vs REMOVE
|
||||
|
||||
### KEEP (valuable infrastructure):
|
||||
| Component | Location | Purpose |
|
||||
|-----------|----------|---------|
|
||||
| Modal backend | `atropos/backends/modal_backend.py` | Cloud sandbox pool |
|
||||
| Nomad backend | `atropos/backends/nomad_backend.py` | Docker/Singularity sandboxes |
|
||||
| Slot pool | `atropos/slots/` | Container multiplexing |
|
||||
| Nomad client | `atropos/nomad/` | Nomad API |
|
||||
| Sandbox server | `atropos/sandbox_server.py` | HTTP server in containers |
|
||||
| Dockerfile | `atropos/Dockerfile` | Container image |
|
||||
| Agent loop | `environments/agent_loop.py` | Proper OpenAI-spec tool calling |
|
||||
| Base env | `environments/hermes_base_env.py` | Phase 1/2 with parsers |
|
||||
| Tool parsers | `environments/tool_call_parsers/` | 13 model parsers |
|
||||
| SGLang patch | `environments/patches.py` | SGLang compatibility |
|
||||
|
||||
### REMOVE (redundant with environments/):
|
||||
| Component | Location | Replaced By |
|
||||
|-----------|----------|-------------|
|
||||
| ICL agent | `atropos/agent/atropos_agent.py` | `environments/agent_loop.py` |
|
||||
| AgentEnv | `atropos/envs/agent_env.py` | `environments/hermes_base_env.py` |
|
||||
| Tool registry | `atropos/tools/` | `model_tools.py` + `tools/` |
|
||||
| GSM8k ICL env | `tinker-atropos/.../gsm8k_agent.py` | `environments/gsm8k_agent_env.py` |
|
||||
|
||||
## Known Issues
|
||||
- Tinker billing (402 error) - user's payment didn't process
|
||||
- `bwrap_available: false` in Singularity containers
|
||||
- Llama-3-8B-Instruct doesn't reliably produce tool calls via Phase 2 (needs Hermes-format model)
|
||||
- Model answered GSM8k correctly but didn't actually USE the terminal tool (computed mentally)
|
||||
|
||||
## Evolution of Decisions
|
||||
|
||||
### Agent Architecture
|
||||
- **v1 (our branch)**: ICL-based agent with `<tool_call>` XML tags in system prompt
|
||||
- **v2 (Teknium's)**: Proper OpenAI-spec tool calling with `tools=` parameter
|
||||
- **Decision**: Adopt v2, consolidate into `environments/`, keep sandbox backends from v1
|
||||
|
||||
### Environment Organization
|
||||
- **Before**: Two parallel systems (`atropos/envs/` and `environments/`)
|
||||
- **After**: Single system in `environments/`, using `HermesAgentBaseEnv` as base class
|
||||
- Sandbox backends remain in `atropos/backends/` but integrate via terminal backend config
|
||||
|
||||
### Phase 2 SGLang Support
|
||||
- **Problem**: VLLMServer hardcoded for VLLM's /generate format, SGLang is different
|
||||
- **Solution**: Monkey-patch `_tokens_and_logprobs_completion_wrapper` in `environments/patches.py`
|
||||
- **Applied**: Automatically at import time via `apply_patches()` in `hermes_base_env.py`
|
||||
- **Handles**: SGLang format differences AND RunPod's double-JSON wrapping
|
||||
@@ -0,0 +1,44 @@
|
||||
# Project Brief: Hermes-Agent
|
||||
|
||||
## Overview
|
||||
Hermes-Agent is an AI agent harness for LLMs with advanced tool-calling capabilities, featuring a flexible toolsets system for organizing and managing tools. Named after Hermes, the Greek messenger god, it serves as a bridge between human intent and AI-powered task execution.
|
||||
|
||||
## Core Requirements
|
||||
|
||||
### Primary Goals
|
||||
1. **Interactive CLI Experience** - Beautiful terminal interface with animated feedback, personalities, and session management
|
||||
2. **Flexible Tool System** - Modular tools organized into logical toolsets for different use cases
|
||||
3. **Batch Processing** - Process multiple prompts in parallel with checkpointing and statistics
|
||||
4. **Multi-Backend Support** - Support for local, Docker, Singularity, Modal, and SSH terminal backends
|
||||
5. **Training Data Generation** - Save conversation trajectories in formats suitable for LLM fine-tuning
|
||||
|
||||
### Target Users
|
||||
- AI researchers generating training data
|
||||
- Developers needing an AI assistant with tool access
|
||||
- MLOps practitioners automating workflows
|
||||
- Anyone needing a powerful CLI-based AI agent
|
||||
|
||||
## Scope
|
||||
|
||||
### In Scope
|
||||
- Interactive CLI with rich formatting and kawaii-style feedback
|
||||
- Web tools (search, extract, crawl via Firecrawl)
|
||||
- Terminal tools (command execution across multiple backends)
|
||||
- Browser automation (via agent-browser + Browserbase)
|
||||
- Vision tools (image analysis)
|
||||
- Image generation (FLUX via FAL.ai)
|
||||
- Mixture-of-Agents reasoning
|
||||
- Skills system for on-demand knowledge
|
||||
- Batch processing with parallel workers
|
||||
- Trajectory compression for training
|
||||
|
||||
### Out of Scope (Current)
|
||||
- Proactive suggestions (agent only runs on request)
|
||||
- Clipboard integration (no local system access)
|
||||
- Real-time streaming of thinking/reasoning (deferred)
|
||||
|
||||
## Success Metrics
|
||||
- Clean, maintainable tool architecture
|
||||
- Reliable tool execution with proper error handling
|
||||
- Efficient context management for long conversations
|
||||
- High-quality trajectory data for training
|
||||
@@ -0,0 +1,267 @@
|
||||
# System Patterns: Hermes-Agent
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ CLI (cli.py) │
|
||||
│ - Rich welcome banner with caduceus │
|
||||
│ - prompt_toolkit for input with history │
|
||||
│ - Kawaii-style feedback and personalities │
|
||||
└────────────────────────────┬────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ AIAgent (run_agent.py) │
|
||||
│ - Conversation loop with tool calling │
|
||||
│ - KawaiiSpinner for animated feedback │
|
||||
│ - Retry logic with exponential backoff │
|
||||
│ - Session logging to logs/ directory │
|
||||
└────────────────────────────┬────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Tool Routing (model_tools.py) │
|
||||
│ - get_tool_definitions() - returns tools for API calls │
|
||||
│ - handle_function_call() - dispatches to tool handlers │
|
||||
│ - Toolset filtering (enabled/disabled) │
|
||||
└────────────────────────────┬────────────────────────────────────┘
|
||||
│
|
||||
┌─────────────────┼─────────────────┐
|
||||
▼ ▼ ▼
|
||||
┌───────────┐ ┌───────────┐ ┌───────────┐
|
||||
│ Web Tools │ │ Terminal │ │ Browser │
|
||||
│ (Firecrawl)│ │ (mini-swe)│ │(agent-brw)│
|
||||
└───────────┘ └───────────┘ └───────────┘
|
||||
│ │ │
|
||||
└─────────────────┼─────────────────┘
|
||||
▼
|
||||
┌───────────────┐
|
||||
│ Toolsets │
|
||||
│ (toolsets.py)│
|
||||
│ Composition │
|
||||
└───────────────┘
|
||||
```
|
||||
|
||||
## Key Design Patterns
|
||||
|
||||
### 1. Toolset Composition Pattern
|
||||
Toolsets can include other toolsets, allowing flexible composition:
|
||||
|
||||
```python
|
||||
TOOLSETS = {
|
||||
"web": {"tools": ["web_search", "web_extract"], "includes": []},
|
||||
"debugging": {"tools": ["terminal"], "includes": ["web"]},
|
||||
"full_stack": {"tools": [], "includes": ["web", "terminal", "vision", "browser"]}
|
||||
}
|
||||
```
|
||||
|
||||
Resolution is recursive with cycle detection.
|
||||
|
||||
### 2. Graceful Degradation Pattern
|
||||
Each tool module has a `check_*_requirements()` function:
|
||||
- Tools are only loaded if requirements are met
|
||||
- Missing API keys disable tools, not crash the system
|
||||
- Import errors are caught and tools marked unavailable
|
||||
|
||||
```python
|
||||
try:
|
||||
from tools.web_tools import web_search_tool, check_firecrawl_api_key
|
||||
except ModuleNotFoundError:
|
||||
web_search_tool = None
|
||||
def check_firecrawl_api_key(): return False
|
||||
```
|
||||
|
||||
### 3. Session Isolation Pattern (task_id)
|
||||
Stateful tools (terminal, browser) use `task_id` to isolate concurrent sessions:
|
||||
- Each batch worker gets unique task_id
|
||||
- VMs and browser sessions are tracked per task_id
|
||||
- Cleanup functions release resources: `cleanup_vm(task_id)`, `cleanup_browser(task_id)`
|
||||
|
||||
### 4. Trajectory Format Pattern
|
||||
Conversations are saved in ShareGPT format for training:
|
||||
|
||||
```json
|
||||
{"from": "system", "value": "System prompt with <tools>...</tools>"}
|
||||
{"from": "human", "value": "User message"}
|
||||
{"from": "gpt", "value": "<think>reasoning</think>\n<tool_call>{...}</tool_call>"}
|
||||
{"from": "tool", "value": "<tool_response>{...}</tool_response>"}
|
||||
{"from": "gpt", "value": "Final response"}
|
||||
```
|
||||
|
||||
### 5. Ephemeral System Prompt Pattern
|
||||
Guide model behavior during data collection without saving to trajectories:
|
||||
- `ephemeral_system_prompt` influences execution
|
||||
- Only standard tool-calling system prompt saved to trajectories
|
||||
- Keeps training data clean
|
||||
|
||||
### 6. Retry with Validation Pattern
|
||||
The agent validates responses before accepting:
|
||||
- Check tool names against `valid_tool_names` set
|
||||
- Validate JSON arguments can be parsed
|
||||
- Check for content after `<think>` blocks
|
||||
- Roll back to last valid state on persistent failures
|
||||
|
||||
## Component Relationships
|
||||
|
||||
### AIAgent Class
|
||||
- Central orchestrator for conversations
|
||||
- Manages conversation history
|
||||
- Calls OpenAI-compatible API
|
||||
- Routes tool calls to handlers
|
||||
- Provides animated feedback (KawaiiSpinner)
|
||||
|
||||
### Tool Modules (tools/*.py)
|
||||
- Self-contained tool implementations
|
||||
- Export: handler function + check function + schema
|
||||
- Return JSON strings (never raw dicts)
|
||||
- Accept optional `task_id` for stateful tools
|
||||
|
||||
### Toolsets System (toolsets.py)
|
||||
- Defines logical groupings of tools
|
||||
- Supports composition via `includes`
|
||||
- `resolve_toolset()` recursively resolves all tools
|
||||
- `validate_toolset()` checks if name is valid
|
||||
|
||||
### Model Tools (model_tools.py)
|
||||
- Aggregates all tool definitions
|
||||
- Routes function calls to correct handlers
|
||||
- Filters tools based on enabled/disabled toolsets
|
||||
- Bridge between agent and tool implementations
|
||||
|
||||
## Critical Implementation Paths
|
||||
|
||||
### Tool Execution Flow
|
||||
1. AIAgent receives tool_calls from API response
|
||||
2. Validates tool names against `valid_tool_names`
|
||||
3. Validates JSON arguments can be parsed
|
||||
4. Calls `handle_function_call()` with tool name, args, task_id
|
||||
5. `handle_function_call()` routes to appropriate handler
|
||||
6. Tool executes, returns JSON string
|
||||
7. Result added to conversation as tool message
|
||||
8. Loop continues until natural language response
|
||||
|
||||
### Configuration Loading Flow
|
||||
1. `cli.py` calls `load_cli_config()`
|
||||
2. Loads `cli-config.yaml`, merges with defaults
|
||||
3. Sets environment variables for terminal config
|
||||
4. `AIAgent` reads env vars when initializing terminal tool
|
||||
5. Terminal tool creates appropriate backend based on `TERMINAL_ENV`
|
||||
|
||||
## RL Training Architecture (Consolidated)
|
||||
|
||||
### Environment System (`environments/`)
|
||||
|
||||
The canonical way to build agentic RL environments in Hermes-Agent:
|
||||
|
||||
```
|
||||
environments/
|
||||
├── agent_loop.py ← HermesAgentLoop: OpenAI-spec tool calling
|
||||
├── hermes_base_env.py ← HermesAgentBaseEnv: base class for all envs
|
||||
├── tool_context.py ← ToolContext: reward function tool access
|
||||
├── tool_call_parsers/ ← 11+ model parsers (hermes, qwen, deepseek, etc.)
|
||||
├── terminal_test_env.py ← Example: file creation tasks
|
||||
├── hermes_swe_env.py ← SWE environment
|
||||
└── gsm8k_agent_env.py ← GSM8k with Python REPL (TODO)
|
||||
```
|
||||
|
||||
### Two-Phase Operation
|
||||
- **Phase 1 (OpenAI server)**: Native tool_calls from VLLM/SGLang/OpenRouter
|
||||
- Good for: SFT data gen, testing, evaluation
|
||||
- Server handles tool call parsing via `/v1/chat/completions`
|
||||
- **Phase 2 (ManagedServer)**: Client-side tool call parser + logprob tracking
|
||||
- Required for: RL training (exact token IDs + logprobs for GRPO/PPO)
|
||||
- Uses `/generate` endpoint for raw token output
|
||||
- Parser registry selects per-model parser (hermes, qwen, llama, etc.)
|
||||
- **Verified working** with RunPod SGLang endpoint (Feb 10, 2026)
|
||||
|
||||
### Phase 2 Call Chain (Verified)
|
||||
```
|
||||
collect_trajectory()
|
||||
→ ServerManager.managed_server(tokenizer, tool_call_parser)
|
||||
→ ManagedServer(server=VLLMServer)
|
||||
→ ManagedServer.chat_completion(messages, tools, n, max_tokens, temp)
|
||||
→ _convert_messages_to_prompt(messages, tools=tools) [apply_chat_template]
|
||||
→ _compute_input_ids(prompt, extending_node)
|
||||
→ VLLMServer.tokens_and_logprobs_completion(**kwargs) [public method]
|
||||
→ _tokens_and_logprobs_comp(stat_dict, **kwargs) [retry decorator, semaphore]
|
||||
→ _tokens_and_logprobs_completion_wrapper(**kwargs) [patched for SGLang]
|
||||
→ aiohttp POST to /generate
|
||||
→ Returns (prompt_tokens, [output_tokens], [output_logprobs], [finish_reasons])
|
||||
→ _create_sequence_node(...) [stores in current_nodes]
|
||||
→ tool_call_parser.parse(completion_text) [if parser configured]
|
||||
→ Returns ChatCompletion with tool_calls
|
||||
```
|
||||
|
||||
### SGLang Compatibility Patch (`environments/patches.py`)
|
||||
VLLMServer's `_tokens_and_logprobs_completion_wrapper` is monkey-patched to handle SGLang's
|
||||
different request/response format. Applied automatically at import time via `apply_patches()`.
|
||||
|
||||
```
|
||||
SGLang request: {"input_ids": [...], "sampling_params": {...}, "return_logprob": true}
|
||||
SGLang response: {"meta_info": {"output_token_logprobs": [[logprob, token_id, text], ...]}}
|
||||
|
||||
VLLM request: {"prompt": {"prompt_token_ids": [...]}, "logprobs": 0}
|
||||
VLLM response: {"logprobs": [[{token_id: logprob}]], "finish_reasons": [...]}
|
||||
```
|
||||
|
||||
Also handles RunPod serverless double-JSON wrapping (response body wrapped in quotes).
|
||||
|
||||
### Key Design: Proper Tool Calling (NOT ICL)
|
||||
```python
|
||||
# CORRECT: pass tools= to chat_completion()
|
||||
response = await server.chat_completion(
|
||||
messages=messages,
|
||||
tools=tool_schemas, # ← tokenizer.apply_chat_template(tools=...) formats these
|
||||
temperature=1.0,
|
||||
)
|
||||
# Response has response.choices[0].message.tool_calls (structured objects)
|
||||
|
||||
# WRONG (old approach): embed tools in system prompt as XML
|
||||
system_prompt = f"<tools>{json.dumps(tools)}</tools>" # ← ICL, not proper training format
|
||||
```
|
||||
|
||||
### Sandbox Backends (`atropos/backends/`)
|
||||
|
||||
Infrastructure for scaled sandbox execution, integrated into HermesAgentBaseEnv:
|
||||
|
||||
```
|
||||
ToolBackend (Protocol)
|
||||
├── NomadToolBackend → SlotPool → NomadClient + SandboxExecutor (HTTP)
|
||||
│ ├── Docker driver (default)
|
||||
│ └── Singularity driver (HPC)
|
||||
└── ModalToolBackend → _ModalSandboxPool → modal.Sandbox.exec() (direct)
|
||||
└── _ModalMultiProfileManager (multi-profile support)
|
||||
```
|
||||
|
||||
Two execution modes in HermesAgentBaseEnv (controlled by `tool_pool_mode` config):
|
||||
- `default` - Local tool execution via handle_function_call() + ToolContext
|
||||
- `modal` / `nomad` - Sandbox routing: slot acquire → setup workspace → agent loop → verify → release
|
||||
|
||||
Sandbox routing architecture:
|
||||
```
|
||||
collect_trajectory()
|
||||
├── tool_pool_mode="default" → _collect_trajectory_local()
|
||||
│ └── _run_agent_loop(tool_handler=None) → compute_reward(ctx)
|
||||
│
|
||||
└── tool_pool_mode="modal"/"nomad" → _collect_trajectory_sandbox()
|
||||
├── backend.acquire(task_id) → Slot
|
||||
├── exec_tool = backend.execute_batch wrapper → ExecutionResult
|
||||
├── setup_trajectory_workspace(item, exec_tool) [subclass hook]
|
||||
├── _run_agent_loop(tool_handler=sandbox_tool_handler)
|
||||
│ └── terminal → backend.execute_batch → JSON string
|
||||
│ └── other tools → handle_function_call (local)
|
||||
├── verify_and_score_trajectory(item, result, exec_tool) [subclass hook]
|
||||
└── backend.release(slot, reset_workspace=True) [finally]
|
||||
```
|
||||
|
||||
Key interfaces:
|
||||
- `exec_tool(tool_name, args, timeout)` → `ExecutionResult` (for env hooks)
|
||||
- `tool_handler(tool_name, args, task_id)` → JSON string (for agent loop)
|
||||
|
||||
### Training Pipeline (Tinker + Atropos)
|
||||
```
|
||||
Terminal 1: run-api (port 8000) ← Atropos Rollout API
|
||||
Terminal 2: launch_training.py (port 8001) ← Tinker Trainer + inference
|
||||
Terminal 3: environment.py serve ← Environment (rollouts)
|
||||
```
|
||||
@@ -0,0 +1,113 @@
|
||||
# Technical Context: Hermes-Agent
|
||||
|
||||
## Technologies Used
|
||||
|
||||
### Core Stack
|
||||
- **Python 3.11+** - Primary language
|
||||
- **OpenAI SDK** - For LLM API interactions (OpenAI-compatible)
|
||||
- **OpenRouter** - Default LLM provider (supports multiple models)
|
||||
- **Rich** - Terminal formatting and panels
|
||||
- **prompt_toolkit** - Interactive input with history
|
||||
- **Fire** - CLI argument parsing
|
||||
- **PyYAML** - Configuration files
|
||||
- **python-dotenv** - Environment variable management
|
||||
|
||||
### Tool Dependencies
|
||||
- **Firecrawl** - Web search and extraction (`FIRECRAWL_API_KEY`)
|
||||
- **mini-swe-agent** - Terminal tool backend (local/docker/singularity/modal/ssh)
|
||||
- **agent-browser** - Browser automation (npm package)
|
||||
- **Browserbase** - Cloud browser execution (`BROWSERBASE_API_KEY`)
|
||||
- **FAL.ai** - Image generation with FLUX (`FAL_KEY`)
|
||||
- **Nous API** - Vision and MoA tools (`NOUS_API_KEY`)
|
||||
|
||||
### Optional Dependencies
|
||||
- **Modal** - Cloud compute for sandboxed environments
|
||||
- **Singularity/Apptainer** - Rootless containers (HPC environments)
|
||||
- **Docker** - Container isolation
|
||||
|
||||
## Development Setup
|
||||
|
||||
### Quick Start
|
||||
```bash
|
||||
# Clone with submodules
|
||||
git clone --recurse-submodules https://github.com/NousResearch/Hermes-Agent.git
|
||||
cd Hermes-Agent
|
||||
|
||||
# Create virtual environment
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
pip install -e ./mini-swe-agent
|
||||
|
||||
# Install browser tools (optional)
|
||||
npm install
|
||||
|
||||
# Configure environment
|
||||
cp .env.example .env
|
||||
# Edit .env with your API keys
|
||||
```
|
||||
|
||||
### Key Configuration Files
|
||||
- `.env` - API keys and secrets
|
||||
- `cli-config.yaml` - CLI configuration (model, terminal, toolsets, personalities)
|
||||
- `configs/` - Batch run scripts and configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
**Required for Full Functionality:**
|
||||
- `OPENROUTER_API_KEY` - Primary LLM access
|
||||
- `FIRECRAWL_API_KEY` - Web tools
|
||||
- `NOUS_API_KEY` - Vision and reasoning tools
|
||||
- `FAL_KEY` - Image generation
|
||||
|
||||
**Terminal Backend:**
|
||||
- `TERMINAL_ENV` - Backend type: `local`, `docker`, `singularity`, `modal`, `ssh`
|
||||
- `TERMINAL_CWD` - Working directory
|
||||
- `TERMINAL_DOCKER_IMAGE` / `TERMINAL_SINGULARITY_IMAGE` - Container images
|
||||
- `TERMINAL_SSH_HOST/USER/KEY` - SSH backend config
|
||||
- `SUDO_PASSWORD` - Optional sudo support
|
||||
|
||||
**Browser:**
|
||||
- `BROWSERBASE_API_KEY` - Browser automation
|
||||
- `BROWSERBASE_PROJECT_ID` - Browserbase project
|
||||
|
||||
## Technical Constraints
|
||||
|
||||
1. **Context Window Limits** - Long tool outputs can exhaust context; trajectory compression helps
|
||||
2. **API Rate Limits** - OpenRouter and tool APIs have rate limits; exponential backoff implemented
|
||||
3. **Tool Availability** - Tools gracefully degrade if dependencies/keys missing
|
||||
4. **Async Compatibility** - Some tools are async, handled via `asyncio.run()` in sync context
|
||||
|
||||
## Dependency Graph
|
||||
|
||||
```
|
||||
tools/*.py → tools/__init__.py → model_tools.py → toolsets.py → toolset_distributions.py
|
||||
↑
|
||||
run_agent.py ──────────────────────────┘
|
||||
cli.py → run_agent.py (uses AIAgent with quiet_mode=True)
|
||||
batch_runner.py → run_agent.py + toolset_distributions.py
|
||||
```
|
||||
|
||||
## Tool Usage Patterns
|
||||
|
||||
### Adding a New Tool
|
||||
1. Create `tools/your_tool.py` with handler + requirements check
|
||||
2. Export in `tools/__init__.py`
|
||||
3. Register in `model_tools.py` (definitions + handler routing)
|
||||
4. Add to toolset in `toolsets.py`
|
||||
5. Optionally add to `toolset_distributions.py` for batch processing
|
||||
|
||||
### Tool Handler Pattern
|
||||
```python
|
||||
def your_tool(param: str, task_id: str = None) -> str:
|
||||
"""Execute tool and return JSON string result."""
|
||||
try:
|
||||
result = {"success": True, "data": "..."}
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
```
|
||||
|
||||
All tool handlers MUST return a JSON string, never raw dicts.
|
||||
+1
-1
Submodule mini-swe-agent updated: 07aa6a7385...ee36b3d4e5
@@ -0,0 +1,134 @@
|
||||
# Modal Sandbox Profiles Configuration
|
||||
# =====================================
|
||||
# This file defines different sandbox profiles for heterogeneous workloads.
|
||||
# Copy to modal_profiles.yaml and customize as needed.
|
||||
#
|
||||
# Usage:
|
||||
# terminal_tool("python train.py", profile="pytorch-gpu")
|
||||
# terminal_tool("npm test", profile="node")
|
||||
#
|
||||
# Each profile can specify:
|
||||
# - image: Docker image to use
|
||||
# - gpu: GPU type (null, "T4", "A10G", "A100", "H100")
|
||||
# - cpu: CPU cores (float)
|
||||
# - memory: Memory in MB
|
||||
# - min_pool: Minimum warm sandboxes (cost vs latency tradeoff)
|
||||
# - max_pool: Maximum sandboxes (hard cost cap)
|
||||
# - idle_timeout: Server-side auto-cleanup in seconds
|
||||
# - max_lifetime: Maximum sandbox lifetime in seconds
|
||||
# - scale_down_idle: Client-side scale-down threshold in seconds
|
||||
# - workdir: Working directory inside container
|
||||
# - secrets: List of Modal Secret names to inject (created via dashboard/CLI)
|
||||
# - env_vars: Dict of environment variables to pass directly
|
||||
# - use_dotenv: If true, loads local .env file into sandbox
|
||||
#
|
||||
# SECRETS SETUP:
|
||||
# Create secrets via Modal dashboard or CLI:
|
||||
# modal secret create huggingface-token HF_TOKEN=hf_xxx
|
||||
# modal secret create openai-key OPENAI_API_KEY=sk-xxx
|
||||
# Then reference by name in profile's secrets list.
|
||||
|
||||
# Default profile used when no profile specified
|
||||
default_profile: default
|
||||
|
||||
profiles:
|
||||
# Default Python environment - good for most tasks
|
||||
default:
|
||||
image: python:3.11
|
||||
gpu: null
|
||||
cpu: 1.0
|
||||
memory: 2048
|
||||
min_pool: 1 # Keep 1 warm for fast response
|
||||
max_pool: 5
|
||||
idle_timeout: 120 # Modal terminates if idle 2 min
|
||||
max_lifetime: 3600 # Max 1 hour
|
||||
scale_down_idle: 180
|
||||
workdir: /workspace
|
||||
secrets: [] # Add secret names here: ["my-api-keys"]
|
||||
env_vars: {} # Add env vars here: {DEBUG: "1"}
|
||||
use_dotenv: false # Set to true to load local .env
|
||||
|
||||
# PyTorch with GPU for ML training/inference
|
||||
pytorch-gpu:
|
||||
image: pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
|
||||
gpu: T4 # Options: T4, A10G, A100, H100
|
||||
cpu: 4.0
|
||||
memory: 16384 # 16GB
|
||||
min_pool: 0 # Don't keep GPU sandboxes warm (expensive!)
|
||||
max_pool: 2
|
||||
idle_timeout: 60 # Shorter idle timeout for GPU (cost)
|
||||
max_lifetime: 1800 # 30 min max for GPU tasks
|
||||
scale_down_idle: 60
|
||||
workdir: /workspace
|
||||
# ML-specific secrets
|
||||
secrets:
|
||||
- huggingface-token # HF_TOKEN env var
|
||||
- wandb-key # WANDB_API_KEY env var
|
||||
env_vars:
|
||||
CUDA_VISIBLE_DEVICES: "0"
|
||||
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
|
||||
|
||||
# High-end GPU for large models
|
||||
pytorch-a100:
|
||||
image: pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
|
||||
gpu: A100
|
||||
cpu: 8.0
|
||||
memory: 65536 # 64GB
|
||||
min_pool: 0
|
||||
max_pool: 1 # Only 1 at a time (very expensive)
|
||||
idle_timeout: 30
|
||||
max_lifetime: 3600
|
||||
scale_down_idle: 30
|
||||
workdir: /workspace
|
||||
|
||||
# Node.js for JavaScript/TypeScript tasks
|
||||
node:
|
||||
image: node:18
|
||||
gpu: null
|
||||
cpu: 1.0
|
||||
memory: 2048
|
||||
min_pool: 0 # Create on-demand
|
||||
max_pool: 3
|
||||
idle_timeout: 120
|
||||
max_lifetime: 3600
|
||||
scale_down_idle: 180
|
||||
workdir: /workspace
|
||||
|
||||
# High memory for data processing
|
||||
high-memory:
|
||||
image: python:3.11
|
||||
gpu: null
|
||||
cpu: 4.0
|
||||
memory: 32768 # 32GB
|
||||
min_pool: 0
|
||||
max_pool: 2
|
||||
idle_timeout: 120
|
||||
max_lifetime: 3600
|
||||
scale_down_idle: 180
|
||||
workdir: /workspace
|
||||
|
||||
# Rust development environment
|
||||
rust:
|
||||
image: rust:1.75
|
||||
gpu: null
|
||||
cpu: 2.0
|
||||
memory: 4096
|
||||
min_pool: 0
|
||||
max_pool: 2
|
||||
idle_timeout: 120
|
||||
max_lifetime: 3600
|
||||
scale_down_idle: 180
|
||||
workdir: /workspace
|
||||
|
||||
# Go development environment
|
||||
golang:
|
||||
image: golang:1.21
|
||||
gpu: null
|
||||
cpu: 2.0
|
||||
memory: 4096
|
||||
min_pool: 0
|
||||
max_pool: 2
|
||||
idle_timeout: 120
|
||||
max_lifetime: 3600
|
||||
scale_down_idle: 180
|
||||
workdir: /workspace
|
||||
+35
-100
@@ -41,7 +41,7 @@ from tools.terminal_hecate import terminal_hecate_tool, check_hecate_requirement
|
||||
from tools.vision_tools import vision_analyze_tool, check_vision_requirements
|
||||
from tools.mixture_of_agents_tool import mixture_of_agents_tool, check_moa_requirements
|
||||
from tools.image_generation_tool import image_generate_tool, check_image_generation_requirements
|
||||
from tools.skills_tool import skills_list, skill_view, check_skills_requirements, SKILLS_TOOL_DESCRIPTION
|
||||
from tools.skills_tool import skills_categories, skills_list, skill_view, check_skills_requirements, SKILLS_TOOL_DESCRIPTION
|
||||
# RL Training tools (Tinker-Atropos)
|
||||
from tools.rl_training_tool import (
|
||||
rl_list_environments,
|
||||
@@ -83,8 +83,6 @@ from tools.browser_tool import (
|
||||
check_browser_requirements,
|
||||
BROWSER_TOOL_SCHEMAS
|
||||
)
|
||||
# Text-to-speech tool (Edge TTS / ElevenLabs / OpenAI)
|
||||
from tools.tts_tool import text_to_speech_tool, check_tts_requirements
|
||||
from toolsets import (
|
||||
get_toolset, resolve_toolset, resolve_multiple_toolsets,
|
||||
get_all_toolsets, get_toolset_names, validate_toolset,
|
||||
@@ -145,7 +143,7 @@ TOOLSET_REQUIREMENTS = {
|
||||
"env_vars": [], # Just needs skills directory
|
||||
"check_fn": check_skills_requirements,
|
||||
"setup_url": None,
|
||||
"tools": ["skills_list", "skill_view"],
|
||||
"tools": ["skills_categories", "skills_list", "skill_view"],
|
||||
},
|
||||
"rl": {
|
||||
"name": "RL Training (Tinker-Atropos)",
|
||||
@@ -167,13 +165,6 @@ TOOLSET_REQUIREMENTS = {
|
||||
"setup_url": None,
|
||||
"tools": ["read_file", "write_file", "patch", "search"],
|
||||
},
|
||||
"tts": {
|
||||
"name": "Text-to-Speech",
|
||||
"env_vars": [], # Edge TTS needs no key; premium providers checked at runtime
|
||||
"check_fn": check_tts_requirements,
|
||||
"setup_url": None,
|
||||
"tools": ["text_to_speech"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -401,7 +392,7 @@ def get_image_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "image_generate",
|
||||
"description": "Generate high-quality images from text prompts using FLUX 2 Pro model with automatic 2x upscaling. Creates detailed, artistic images that are automatically upscaled for hi-rez results. Returns a single upscaled image URL. Display it using markdown: ",
|
||||
"description": "Generate high-quality images from text prompts using FLUX 2 Pro model with automatic 2x upscaling. Creates detailed, artistic images that are automatically upscaled for hi-rez results. Returns a single upscaled image URL that can be displayed using <img src=\"{URL}\"></img> tags.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -441,7 +432,24 @@ def get_skills_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"properties": {
|
||||
"category": {
|
||||
"type": "string",
|
||||
"description": "Optional category filter to narrow results"
|
||||
"description": "Optional category filter (from skills_categories)"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "skills_categories",
|
||||
"description": "List available skill categories. Call this first to discover what skill categories exist, then use skills_list(category) to see skills in a category.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"verbose": {
|
||||
"type": "boolean",
|
||||
"description": "If true, include skill counts per category. Default: false."
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
@@ -871,38 +879,6 @@ def get_file_tool_definitions() -> List[Dict[str, Any]]:
|
||||
]
|
||||
|
||||
|
||||
def get_tts_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get tool definitions for text-to-speech tools in OpenAI's expected format.
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of TTS tool definitions compatible with OpenAI API
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "text_to_speech",
|
||||
"description": "Convert text to speech audio. Returns a MEDIA: path that the platform delivers as a voice message. On Telegram it plays as a voice bubble, on Discord/WhatsApp as an audio attachment. In CLI mode, saves to ~/voice-memos/. Voice and provider are user-configured, not model-selected.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The text to convert to speech. Keep under 4000 characters."
|
||||
},
|
||||
"output_path": {
|
||||
"type": "string",
|
||||
"description": "Optional custom file path to save the audio. Defaults to ~/voice-memos/<timestamp>.mp3"
|
||||
}
|
||||
},
|
||||
"required": ["text"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def get_all_tool_names() -> List[str]:
|
||||
"""
|
||||
Get the names of all available tools across all toolsets.
|
||||
@@ -934,7 +910,7 @@ def get_all_tool_names() -> List[str]:
|
||||
|
||||
# Skills tools
|
||||
if check_skills_requirements():
|
||||
tool_names.extend(["skills_list", "skill_view"])
|
||||
tool_names.extend(["skills_categories", "skills_list", "skill_view"])
|
||||
|
||||
# Browser automation tools
|
||||
if check_browser_requirements():
|
||||
@@ -967,10 +943,6 @@ def get_all_tool_names() -> List[str]:
|
||||
"read_file", "write_file", "patch", "search"
|
||||
])
|
||||
|
||||
# Text-to-speech tools
|
||||
if check_tts_requirements():
|
||||
tool_names.extend(["text_to_speech"])
|
||||
|
||||
return tool_names
|
||||
|
||||
|
||||
@@ -985,6 +957,7 @@ TOOL_TO_TOOLSET_MAP = {
|
||||
"mixture_of_agents": "moa_tools",
|
||||
"image_generate": "image_tools",
|
||||
# Skills tools
|
||||
"skills_categories": "skills_tools",
|
||||
"skills_list": "skills_tools",
|
||||
"skill_view": "skills_tools",
|
||||
# Browser automation tools
|
||||
@@ -1012,8 +985,6 @@ TOOL_TO_TOOLSET_MAP = {
|
||||
"rl_stop_training": "rl_tools",
|
||||
"rl_get_results": "rl_tools",
|
||||
"rl_list_runs": "rl_tools",
|
||||
# Text-to-speech tools
|
||||
"text_to_speech": "tts_tools",
|
||||
# File manipulation tools
|
||||
"read_file": "file_tools",
|
||||
"write_file": "file_tools",
|
||||
@@ -1117,11 +1088,6 @@ def get_tool_definitions(
|
||||
for tool in get_file_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
# Text-to-speech tools
|
||||
if check_tts_requirements():
|
||||
for tool in get_tts_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
# Determine which tools to include based on toolsets
|
||||
tools_to_include = set()
|
||||
|
||||
@@ -1143,7 +1109,7 @@ def get_tool_definitions(
|
||||
"vision_tools": ["vision_analyze"],
|
||||
"moa_tools": ["mixture_of_agents"],
|
||||
"image_tools": ["image_generate"],
|
||||
"skills_tools": ["skills_list", "skill_view"],
|
||||
"skills_tools": ["skills_categories", "skills_list", "skill_view"],
|
||||
"browser_tools": [
|
||||
"browser_navigate", "browser_snapshot", "browser_click",
|
||||
"browser_type", "browser_scroll", "browser_back",
|
||||
@@ -1158,8 +1124,7 @@ def get_tool_definitions(
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_list_runs", "rl_test_inference"
|
||||
],
|
||||
"file_tools": ["read_file", "write_file", "patch", "search"],
|
||||
"tts_tools": ["text_to_speech"]
|
||||
"file_tools": ["read_file", "write_file", "patch", "search"]
|
||||
}
|
||||
legacy_tools = legacy_map.get(toolset_name, [])
|
||||
tools_to_include.update(legacy_tools)
|
||||
@@ -1197,7 +1162,7 @@ def get_tool_definitions(
|
||||
"vision_tools": ["vision_analyze"],
|
||||
"moa_tools": ["mixture_of_agents"],
|
||||
"image_tools": ["image_generate"],
|
||||
"skills_tools": ["skills_list", "skill_view"],
|
||||
"skills_tools": ["skills_categories", "skills_list", "skill_view"],
|
||||
"browser_tools": [
|
||||
"browser_navigate", "browser_snapshot", "browser_click",
|
||||
"browser_type", "browser_scroll", "browser_back",
|
||||
@@ -1212,8 +1177,7 @@ def get_tool_definitions(
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_list_runs", "rl_test_inference"
|
||||
],
|
||||
"file_tools": ["read_file", "write_file", "patch", "search"],
|
||||
"tts_tools": ["text_to_speech"]
|
||||
"file_tools": ["read_file", "write_file", "patch", "search"]
|
||||
}
|
||||
legacy_tools = legacy_map.get(toolset_name, [])
|
||||
tools_to_include.difference_update(legacy_tools)
|
||||
@@ -1427,7 +1391,11 @@ def handle_skills_function_call(function_name: str, function_args: Dict[str, Any
|
||||
Returns:
|
||||
str: Function result as JSON string
|
||||
"""
|
||||
if function_name == "skills_list":
|
||||
if function_name == "skills_categories":
|
||||
verbose = function_args.get("verbose", False)
|
||||
return skills_categories(verbose=verbose)
|
||||
|
||||
elif function_name == "skills_list":
|
||||
category = function_args.get("category")
|
||||
return skills_list(category=category)
|
||||
|
||||
@@ -1671,28 +1639,6 @@ def handle_file_function_call(
|
||||
return json.dumps({"error": f"Unknown file function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_tts_function_call(
|
||||
function_name: str,
|
||||
function_args: Dict[str, Any]
|
||||
) -> str:
|
||||
"""
|
||||
Handle function calls for text-to-speech tools.
|
||||
|
||||
Args:
|
||||
function_name (str): Name of the TTS function to call
|
||||
function_args (Dict): Arguments for the function
|
||||
|
||||
Returns:
|
||||
str: Function result as JSON string
|
||||
"""
|
||||
if function_name == "text_to_speech":
|
||||
text = function_args.get("text", "")
|
||||
output_path = function_args.get("output_path")
|
||||
return text_to_speech_tool(text=text, output_path=output_path)
|
||||
|
||||
return json.dumps({"error": f"Unknown TTS function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_function_call(
|
||||
function_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
@@ -1740,7 +1686,7 @@ def handle_function_call(
|
||||
return handle_image_function_call(function_name, function_args)
|
||||
|
||||
# Route skills tools
|
||||
elif function_name in ["skills_list", "skill_view"]:
|
||||
elif function_name in ["skills_categories", "skills_list", "skill_view"]:
|
||||
return handle_skills_function_call(function_name, function_args)
|
||||
|
||||
# Route browser automation tools
|
||||
@@ -1770,10 +1716,6 @@ def handle_function_call(
|
||||
elif function_name in ["read_file", "write_file", "patch", "search"]:
|
||||
return handle_file_function_call(function_name, function_args, task_id)
|
||||
|
||||
# Route text-to-speech tools
|
||||
elif function_name in ["text_to_speech"]:
|
||||
return handle_tts_function_call(function_name, function_args)
|
||||
|
||||
else:
|
||||
error_msg = f"Unknown function: {function_name}"
|
||||
print(f"❌ {error_msg}")
|
||||
@@ -1825,7 +1767,7 @@ def get_available_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||
},
|
||||
"skills_tools": {
|
||||
"available": check_skills_requirements(),
|
||||
"tools": ["skills_list", "skill_view"],
|
||||
"tools": ["skills_categories", "skills_list", "skill_view"],
|
||||
"description": "Access skill documents that provide specialized instructions, guidelines, or knowledge the agent can load on demand",
|
||||
"requirements": ["skills/ directory in repo root"]
|
||||
},
|
||||
@@ -1851,12 +1793,6 @@ def get_available_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||
"tools": ["read_file", "write_file", "patch", "search"],
|
||||
"description": "File manipulation tools: read/write files, search content/files, patch with fuzzy matching",
|
||||
"requirements": ["Terminal backend available (local/docker/ssh/singularity/modal)"]
|
||||
},
|
||||
"tts_tools": {
|
||||
"available": check_tts_requirements(),
|
||||
"tools": ["text_to_speech"],
|
||||
"description": "Text-to-speech: convert text to audio (Edge TTS free, ElevenLabs, OpenAI)",
|
||||
"requirements": ["edge-tts package (free) or ELEVENLABS_API_KEY or OPENAI_API_KEY"]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1878,8 +1814,7 @@ def check_toolset_requirements() -> Dict[str, bool]:
|
||||
"skills_tools": check_skills_requirements(),
|
||||
"browser_tools": check_browser_requirements(),
|
||||
"cronjob_tools": check_cronjob_requirements(),
|
||||
"file_tools": check_file_requirements(),
|
||||
"tts_tools": check_tts_requirements()
|
||||
"file_tools": check_file_requirements()
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
# Nomad Development Configuration (Hermes-Agent)
|
||||
# Run with: nomad agent -dev -config=nomad-dev.hcl
|
||||
#
|
||||
# This is intended for local development only.
|
||||
|
||||
client {
|
||||
enabled = true
|
||||
|
||||
options {
|
||||
# Enable Docker volume mounts for persistent slot workspaces
|
||||
"docker.volumes.enabled" = "true"
|
||||
}
|
||||
}
|
||||
|
||||
# Docker driver plugin configuration
|
||||
plugin "docker" {
|
||||
config {
|
||||
# CRITICAL: Enable volume mounts
|
||||
volumes {
|
||||
enabled = true
|
||||
}
|
||||
|
||||
# Allow privileged containers if needed
|
||||
allow_privileged = false
|
||||
|
||||
# Garbage collection settings
|
||||
gc {
|
||||
image = true
|
||||
# NOTE: For local dev we often rely on locally built images like `atropos-sandbox:local`.
|
||||
# A short image GC delay can delete these between runs, causing confusing "Failed to pull"
|
||||
# crash loops. Keep this comfortably long; tighten it for CI/production if needed.
|
||||
image_delay = "24h"
|
||||
container = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
# Nomad Configuration for Singularity/Apptainer Sandbox
|
||||
# Run with: nomad agent -dev -config=nomad-singularity.hcl
|
||||
#
|
||||
# This uses the raw_exec driver to run Apptainer containers.
|
||||
# Suitable for HPC environments where Docker cannot run without sudo.
|
||||
|
||||
client {
|
||||
enabled = true
|
||||
|
||||
options {
|
||||
# Enable raw_exec driver for Singularity/Apptainer
|
||||
"driver.raw_exec.enable" = "1"
|
||||
}
|
||||
}
|
||||
|
||||
# raw_exec driver plugin configuration
|
||||
plugin "raw_exec" {
|
||||
config {
|
||||
enabled = true
|
||||
}
|
||||
}
|
||||
|
||||
# Optional: If you have the nomad-driver-singularity plugin installed,
|
||||
# uncomment the following instead of using raw_exec:
|
||||
# plugin "singularity" {
|
||||
# config {
|
||||
# enabled = true
|
||||
# # Allow bind mounts
|
||||
# bind_paths = ["/tmp", "/var/tmp"]
|
||||
# }
|
||||
# }
|
||||
+29
-2
@@ -19,6 +19,7 @@ dependencies = [
|
||||
"rich",
|
||||
"tenacity",
|
||||
"pyyaml",
|
||||
"prompt_toolkit",
|
||||
"requests",
|
||||
"jinja2",
|
||||
"pydantic>=2.0",
|
||||
@@ -39,6 +40,20 @@ dev = ["pytest", "pytest-asyncio"]
|
||||
messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0", "aiohttp>=3.9.0"]
|
||||
cron = ["croniter"]
|
||||
cli = ["simple-term-menu"]
|
||||
# Install Atropos + Tinker training integration from source.
|
||||
# Uses tool_call_support branch for ManagedServer tool calling (PR #366).
|
||||
atropos = [
|
||||
"atroposlib @ git+https://github.com/NousResearch/atropos.git@tool_call_support",
|
||||
"tinker @ git+https://github.com/thinking-machines-lab/tinker.git",
|
||||
# Atropos integration runtime deps (kept optional for Hermes-only users)
|
||||
"aiohttp",
|
||||
"fastapi",
|
||||
"uvicorn",
|
||||
"pyte",
|
||||
"torch",
|
||||
"wandb",
|
||||
"math-verify",
|
||||
]
|
||||
all = [
|
||||
"hermes-agent[modal]",
|
||||
"hermes-agent[messaging]",
|
||||
@@ -50,9 +65,21 @@ all = [
|
||||
[project.scripts]
|
||||
hermes = "hermes_cli.main:main"
|
||||
hermes-agent = "run_agent:main"
|
||||
hermes-atropos-sandbox-smoke = "atropos.envs.sandbox_terminal_smoke_env:SandboxTerminalSmokeEnv.cli"
|
||||
hermes-atropos-toolserver-smoke = "atropos.envs.toolserver_smoke_env:ToolServerSmokeEnv.cli"
|
||||
|
||||
[tool.setuptools]
|
||||
py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajectory_compressor", "toolset_distributions", "cli"]
|
||||
py-modules = [
|
||||
"run_agent",
|
||||
"model_tools",
|
||||
"toolsets",
|
||||
"batch_runner",
|
||||
"trajectory_compressor",
|
||||
"toolset_distributions",
|
||||
"atropos_compatible_agent",
|
||||
"local_server",
|
||||
"cli",
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["tools", "hermes_cli", "gateway", "cron"]
|
||||
include = ["tools", "hermes_cli", "gateway", "cron", "atropos", "atropos.*"]
|
||||
|
||||
@@ -29,12 +29,6 @@ platformdirs
|
||||
# Optional: For Modal backend (cloud execution)
|
||||
# swe-rex[modal]>=1.4.0 # Includes modal + boto3 + swe-rex runtime
|
||||
|
||||
# Text-to-speech (Edge TTS is free, no API key needed)
|
||||
edge-tts
|
||||
|
||||
# Optional: Premium TTS providers
|
||||
# elevenlabs # Uncomment if using ElevenLabs TTS (needs ELEVENLABS_API_KEY)
|
||||
|
||||
# Optional: For cron expression parsing (cronjob scheduling)
|
||||
croniter
|
||||
|
||||
|
||||
+73
-571
@@ -20,7 +20,6 @@ Usage:
|
||||
response = agent.run_conversation("Tell me about the latest Python updates")
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -31,7 +30,6 @@ import threading
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from openai import OpenAI
|
||||
import fire
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
@@ -49,46 +47,11 @@ elif not os.getenv("HERMES_QUIET"):
|
||||
|
||||
# Import our tool system
|
||||
from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements
|
||||
from tools.terminal_tool import cleanup_vm, set_interrupt_event as _set_terminal_interrupt
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
from tools.browser_tool import cleanup_browser
|
||||
|
||||
import requests
|
||||
|
||||
# =============================================================================
|
||||
# Default Agent Identity & Platform Hints
|
||||
# =============================================================================
|
||||
|
||||
# The default identity prompt is prepended to every conversation so the agent
|
||||
# knows who it is and behaves consistently across platforms.
|
||||
DEFAULT_AGENT_IDENTITY = (
|
||||
"You are Hermes Agent, an intelligent AI assistant created by Nous Research. "
|
||||
"You are helpful, knowledgeable, and direct. You assist users with a wide "
|
||||
"range of tasks including answering questions, writing and editing code, "
|
||||
"analyzing information, creative work, and executing actions via your tools. "
|
||||
"You communicate clearly, admit uncertainty when appropriate, and prioritize "
|
||||
"being genuinely useful over being verbose unless otherwise directed below."
|
||||
)
|
||||
|
||||
# Platform-specific formatting hints appended to the system prompt.
|
||||
# These tell the agent how to format its output for the current interface.
|
||||
PLATFORM_HINTS = {
|
||||
"whatsapp": (
|
||||
"You are on a text messaging communication platform, WhatsApp. "
|
||||
"Please do not use markdown as it does not render."
|
||||
),
|
||||
"telegram": (
|
||||
"You are on a text messaging communication platform, Telegram. "
|
||||
"Please do not use markdown as it does not render."
|
||||
),
|
||||
"discord": (
|
||||
"You are in a Discord server or group chat communicating with your user."
|
||||
),
|
||||
"cli": (
|
||||
"You are a CLI AI Agent. Try not to use markdown but simple text "
|
||||
"renderable inside a terminal."
|
||||
),
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# Model Context Management
|
||||
# =============================================================================
|
||||
@@ -493,389 +456,18 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
return compressed
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Anthropic Prompt Caching (system_and_3 strategy)
|
||||
# =============================================================================
|
||||
# Reduces input token costs by ~75% on multi-turn conversations by caching
|
||||
# the conversation prefix. Uses 4 cache_control breakpoints (Anthropic max):
|
||||
# 1. System prompt (stable across all turns)
|
||||
# 2-4. Last 3 non-system messages (rolling window)
|
||||
#
|
||||
# Cached tokens are read at 0.1x input price. Cache writes cost 1.25x (5m TTL)
|
||||
# or 2x (1h TTL). Only applied to Claude models via OpenRouter.
|
||||
|
||||
def _apply_cache_marker(msg: dict, cache_marker: dict) -> None:
|
||||
"""
|
||||
Add cache_control to a single message, handling all format variations.
|
||||
|
||||
- tool messages: cache_control at message level (Anthropic API quirk)
|
||||
- string content: converted to multipart content array
|
||||
- list content: marker added to last item
|
||||
- None content (assistant with tool_calls): message level
|
||||
"""
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content")
|
||||
|
||||
if role == "tool":
|
||||
msg["cache_control"] = cache_marker
|
||||
return
|
||||
|
||||
if content is None:
|
||||
msg["cache_control"] = cache_marker
|
||||
return
|
||||
|
||||
if isinstance(content, str):
|
||||
msg["content"] = [{"type": "text", "text": content, "cache_control": cache_marker}]
|
||||
return
|
||||
|
||||
if isinstance(content, list) and content:
|
||||
last = content[-1]
|
||||
if isinstance(last, dict):
|
||||
last["cache_control"] = cache_marker
|
||||
|
||||
|
||||
def apply_anthropic_cache_control(
|
||||
api_messages: List[Dict[str, Any]],
|
||||
cache_ttl: str = "5m",
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Apply system_and_3 caching strategy to messages for Anthropic models.
|
||||
|
||||
Places up to 4 cache_control breakpoints:
|
||||
1. System prompt (index 0, stable across all turns)
|
||||
2-4. Last 3 non-system messages (rolling cache frontier)
|
||||
|
||||
Each breakpoint tells Anthropic "cache everything from the start up to here."
|
||||
Multiple breakpoints create a ladder of cached prefixes at different depths,
|
||||
which provides robust cache hits even when the most recent cache entry hasn't
|
||||
propagated yet.
|
||||
|
||||
Args:
|
||||
api_messages: Fully assembled message list (system prompt first).
|
||||
cache_ttl: "5m" (default, 1.25x write cost) or "1h" (2x write cost).
|
||||
|
||||
Returns:
|
||||
Deep copy of messages with cache_control breakpoints injected.
|
||||
"""
|
||||
messages = copy.deepcopy(api_messages)
|
||||
if not messages:
|
||||
return messages
|
||||
|
||||
marker = {"type": "ephemeral"}
|
||||
if cache_ttl == "1h":
|
||||
marker["ttl"] = "1h"
|
||||
|
||||
breakpoints_used = 0
|
||||
|
||||
# Breakpoint 1: System prompt (always stable, gives a guaranteed minimum hit)
|
||||
if messages[0].get("role") == "system":
|
||||
_apply_cache_marker(messages[0], marker)
|
||||
breakpoints_used += 1
|
||||
|
||||
# Breakpoints 2-4: Last 3 non-system messages (rolling window)
|
||||
remaining = 4 - breakpoints_used
|
||||
non_sys = [i for i in range(len(messages)) if messages[i].get("role") != "system"]
|
||||
for idx in non_sys[-remaining:]:
|
||||
_apply_cache_marker(messages[idx], marker)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Default System Prompt Components
|
||||
# =============================================================================
|
||||
|
||||
# Skills guidance - embeds a compact skill index in the system prompt so
|
||||
# the model can match skills at a glance without extra tool calls.
|
||||
def build_skills_system_prompt() -> str:
|
||||
"""
|
||||
Build a dynamic skills system prompt by scanning the skills/ directory.
|
||||
|
||||
Returns a prompt section that lists all skill categories (with descriptions
|
||||
from DESCRIPTION.md) and their skill names inline, so the model can
|
||||
immediately see if a relevant skill exists and load it with a single
|
||||
skill_view(name) call -- no discovery tool calls needed.
|
||||
|
||||
Returns:
|
||||
str: The skills system prompt section, or empty string if no skills found.
|
||||
"""
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
skills_dir = Path(__file__).parent / "skills"
|
||||
if not skills_dir.exists():
|
||||
return ""
|
||||
|
||||
# Scan for SKILL.md files grouped by category
|
||||
skills_by_category = {}
|
||||
for skill_file in skills_dir.rglob("SKILL.md"):
|
||||
rel_path = skill_file.relative_to(skills_dir)
|
||||
parts = rel_path.parts
|
||||
if len(parts) >= 2:
|
||||
category = parts[0]
|
||||
skill_name = parts[-2] # Folder containing SKILL.md
|
||||
else:
|
||||
category = "general"
|
||||
skill_name = skill_file.parent.name
|
||||
skills_by_category.setdefault(category, []).append(skill_name)
|
||||
|
||||
if not skills_by_category:
|
||||
return ""
|
||||
|
||||
# Load category descriptions from DESCRIPTION.md files (YAML frontmatter)
|
||||
category_descriptions = {}
|
||||
for category in skills_by_category:
|
||||
desc_file = skills_dir / category / "DESCRIPTION.md"
|
||||
if desc_file.exists():
|
||||
try:
|
||||
content = desc_file.read_text(encoding="utf-8")
|
||||
# Parse description from YAML frontmatter: ---\ndescription: ...\n---
|
||||
match = re.search(r"^---\s*\n.*?description:\s*(.+?)\s*\n.*?^---", content, re.MULTILINE | re.DOTALL)
|
||||
if match:
|
||||
category_descriptions[category] = match.group(1).strip()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Build compact index: category with description + skill names
|
||||
index_lines = []
|
||||
for category in sorted(skills_by_category.keys()):
|
||||
desc = category_descriptions.get(category, "")
|
||||
names = ", ".join(sorted(skills_by_category[category]))
|
||||
if desc:
|
||||
index_lines.append(f" {category}: {desc}")
|
||||
else:
|
||||
index_lines.append(f" {category}:")
|
||||
index_lines.append(f" skills: {names}")
|
||||
|
||||
return (
|
||||
"## Skills (mandatory)\n"
|
||||
"Before replying, scan the skills below. If one clearly matches your task, "
|
||||
"load it with skill_view(name) and follow its instructions.\n"
|
||||
"\n"
|
||||
"<available_skills>\n"
|
||||
+ "\n".join(index_lines) + "\n"
|
||||
"</available_skills>\n"
|
||||
"\n"
|
||||
"If none match, proceed normally without loading a skill."
|
||||
)
|
||||
# Skills guidance - instructs the model to check skills before technical tasks
|
||||
SKILLS_SYSTEM_PROMPT = """## Skills
|
||||
Before answering technical questions about tools, frameworks, or workflows:
|
||||
1. Check skills_categories to see if a relevant category exists
|
||||
2. If a category matches your task, use skills_list with that category
|
||||
3. If a skill matches, load it with skill_view and follow its instructions
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Context File Injection (SOUL.md, AGENTS.md, .cursorrules)
|
||||
# =============================================================================
|
||||
|
||||
# Maximum characters per context file before truncation
|
||||
CONTEXT_FILE_MAX_CHARS = 20_000
|
||||
# Truncation strategy: keep 70% from the head, 20% from the tail
|
||||
CONTEXT_TRUNCATE_HEAD_RATIO = 0.7
|
||||
CONTEXT_TRUNCATE_TAIL_RATIO = 0.2
|
||||
|
||||
|
||||
def _truncate_content(content: str, filename: str, max_chars: int = CONTEXT_FILE_MAX_CHARS) -> str:
|
||||
"""
|
||||
Truncate content if it exceeds max_chars using a head/tail strategy.
|
||||
|
||||
Keeps 70% from the start and 20% from the end, with a truncation
|
||||
marker in the middle so the model knows content was cut.
|
||||
"""
|
||||
if len(content) <= max_chars:
|
||||
return content
|
||||
|
||||
head_chars = int(max_chars * CONTEXT_TRUNCATE_HEAD_RATIO)
|
||||
tail_chars = int(max_chars * CONTEXT_TRUNCATE_TAIL_RATIO)
|
||||
head = content[:head_chars]
|
||||
tail = content[-tail_chars:]
|
||||
|
||||
marker = f"\n\n[...truncated {filename}: kept {head_chars}+{tail_chars} of {len(content)} chars. Use file tools to read the full file.]\n\n"
|
||||
return head + marker + tail
|
||||
|
||||
|
||||
def build_context_files_prompt(cwd: str = None) -> str:
|
||||
"""
|
||||
Discover and load context files (SOUL.md, AGENTS.md, .cursorrules)
|
||||
for injection into the system prompt.
|
||||
|
||||
Discovery rules:
|
||||
- AGENTS.md: Recursively search from cwd (only if top-level exists).
|
||||
Each file becomes a ## section with its relative path.
|
||||
- .cursorrules: Check cwd for .cursorrules file and .cursor/rules/*.mdc
|
||||
- SOUL.md: Check cwd first, then ~/.hermes/SOUL.md as global fallback
|
||||
|
||||
Args:
|
||||
cwd: Working directory to search from. Defaults to os.getcwd().
|
||||
|
||||
Returns:
|
||||
str: The context files prompt section, or empty string if none found.
|
||||
"""
|
||||
import os
|
||||
import glob as glob_mod
|
||||
from pathlib import Path
|
||||
|
||||
if cwd is None:
|
||||
cwd = os.getcwd()
|
||||
|
||||
cwd_path = Path(cwd).resolve()
|
||||
sections = []
|
||||
|
||||
# ----- AGENTS.md (hierarchical, recursive) -----
|
||||
top_level_agents = None
|
||||
for name in ["AGENTS.md", "agents.md"]:
|
||||
candidate = cwd_path / name
|
||||
if candidate.exists():
|
||||
top_level_agents = candidate
|
||||
break
|
||||
|
||||
if top_level_agents:
|
||||
# Recursively find all AGENTS.md files (case-insensitive)
|
||||
agents_files = []
|
||||
for root, dirs, files in os.walk(cwd_path):
|
||||
# Skip hidden directories and common non-project dirs
|
||||
dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ('node_modules', '__pycache__', 'venv', '.venv')]
|
||||
for f in files:
|
||||
if f.lower() == "agents.md":
|
||||
agents_files.append(Path(root) / f)
|
||||
|
||||
# Sort by path depth (top-level first, then deeper)
|
||||
agents_files.sort(key=lambda p: len(p.parts))
|
||||
|
||||
total_agents_content = ""
|
||||
for agents_path in agents_files:
|
||||
try:
|
||||
content = agents_path.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
rel_path = agents_path.relative_to(cwd_path)
|
||||
total_agents_content += f"## {rel_path}\n\n{content}\n\n"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if total_agents_content:
|
||||
total_agents_content = _truncate_content(total_agents_content, "AGENTS.md")
|
||||
sections.append(total_agents_content)
|
||||
|
||||
# ----- .cursorrules -----
|
||||
cursorrules_content = ""
|
||||
|
||||
# Check for .cursorrules file
|
||||
cursorrules_file = cwd_path / ".cursorrules"
|
||||
if cursorrules_file.exists():
|
||||
try:
|
||||
content = cursorrules_file.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
cursorrules_content += f"## .cursorrules\n\n{content}\n\n"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check for .cursor/rules/*.mdc files
|
||||
cursor_rules_dir = cwd_path / ".cursor" / "rules"
|
||||
if cursor_rules_dir.exists() and cursor_rules_dir.is_dir():
|
||||
mdc_files = sorted(cursor_rules_dir.glob("*.mdc"))
|
||||
for mdc_file in mdc_files:
|
||||
try:
|
||||
content = mdc_file.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
cursorrules_content += f"## .cursor/rules/{mdc_file.name}\n\n{content}\n\n"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if cursorrules_content:
|
||||
cursorrules_content = _truncate_content(cursorrules_content, ".cursorrules")
|
||||
sections.append(cursorrules_content)
|
||||
|
||||
# ----- SOUL.md (cwd first, then ~/.hermes/ fallback) -----
|
||||
soul_content = ""
|
||||
soul_path = None
|
||||
|
||||
for name in ["SOUL.md", "soul.md"]:
|
||||
candidate = cwd_path / name
|
||||
if candidate.exists():
|
||||
soul_path = candidate
|
||||
break
|
||||
|
||||
if not soul_path:
|
||||
# Global fallback
|
||||
global_soul = Path.home() / ".hermes" / "SOUL.md"
|
||||
if global_soul.exists():
|
||||
soul_path = global_soul
|
||||
|
||||
if soul_path:
|
||||
try:
|
||||
content = soul_path.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
content = _truncate_content(content, "SOUL.md")
|
||||
soul_content = f"## SOUL.md\n\nIf SOUL.md is present, embody its persona and tone. Avoid stiff, generic replies; follow its guidance unless higher-priority instructions override it.\n\n{content}"
|
||||
sections.append(soul_content)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ----- Assemble -----
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
return "# Project Context\n\nThe following project context files have been loaded and should be followed:\n\n" + "\n".join(sections)
|
||||
|
||||
|
||||
def _build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
|
||||
"""
|
||||
Build a short preview of a tool call's primary argument for display.
|
||||
|
||||
Returns a truncated string showing the most informative argument,
|
||||
or None if no meaningful preview is available.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool being called
|
||||
args: The tool call arguments dict
|
||||
max_len: Maximum preview length before truncation
|
||||
|
||||
Returns:
|
||||
str or None: Short preview string, or None
|
||||
"""
|
||||
# Map tool names to their primary argument key(s)
|
||||
primary_args = {
|
||||
"terminal": "command",
|
||||
"web_search": "query",
|
||||
"web_extract": "urls",
|
||||
"read_file": "path",
|
||||
"write_file": "path",
|
||||
"patch": "path",
|
||||
"search": "pattern",
|
||||
"browser_navigate": "url",
|
||||
"browser_click": "ref",
|
||||
"browser_type": "text",
|
||||
"image_generate": "prompt",
|
||||
"text_to_speech": "text",
|
||||
"vision_analyze": "question",
|
||||
"mixture_of_agents": "user_prompt",
|
||||
"skill_view": "name",
|
||||
"skills_list": "category",
|
||||
"schedule_cronjob": "name",
|
||||
}
|
||||
|
||||
key = primary_args.get(tool_name)
|
||||
if not key:
|
||||
# Try common arg names as fallback
|
||||
for fallback_key in ("query", "text", "command", "path", "name", "prompt"):
|
||||
if fallback_key in args:
|
||||
key = fallback_key
|
||||
break
|
||||
|
||||
if not key or key not in args:
|
||||
return None
|
||||
|
||||
value = args[key]
|
||||
|
||||
# Handle list values (e.g., urls)
|
||||
if isinstance(value, list):
|
||||
value = value[0] if value else ""
|
||||
|
||||
preview = str(value).strip()
|
||||
if not preview:
|
||||
return None
|
||||
|
||||
# Truncate
|
||||
if len(preview) > max_len:
|
||||
preview = preview[:max_len - 3] + "..."
|
||||
|
||||
return preview
|
||||
Skills contain vetted, up-to-date instructions for specific tools and workflows."""
|
||||
|
||||
|
||||
class KawaiiSpinner:
|
||||
@@ -1012,7 +604,6 @@ class AIAgent:
|
||||
max_tokens: int = None,
|
||||
reasoning_config: Dict[str, Any] = None,
|
||||
prefill_messages: List[Dict[str, Any]] = None,
|
||||
platform: str = None,
|
||||
):
|
||||
"""
|
||||
Initialize the AI Agent.
|
||||
@@ -1043,8 +634,6 @@ class AIAgent:
|
||||
prefill_messages (List[Dict]): Messages to prepend to conversation history as prefilled context.
|
||||
Useful for injecting a few-shot example or priming the model's response style.
|
||||
Example: [{"role": "user", "content": "Hi!"}, {"role": "assistant", "content": "Hello!"}]
|
||||
platform (str): The interface platform the user is on (e.g. "cli", "telegram", "discord", "whatsapp").
|
||||
Used to inject platform-specific formatting hints into the system prompt.
|
||||
"""
|
||||
self.model = model
|
||||
self.max_iterations = max_iterations
|
||||
@@ -1053,12 +642,9 @@ class AIAgent:
|
||||
self.verbose_logging = verbose_logging
|
||||
self.quiet_mode = quiet_mode
|
||||
self.ephemeral_system_prompt = ephemeral_system_prompt
|
||||
self.platform = platform # "cli", "telegram", "discord", "whatsapp", etc.
|
||||
self.log_prefix_chars = log_prefix_chars
|
||||
self.log_prefix = f"{log_prefix} " if log_prefix else ""
|
||||
# Store effective base URL for feature detection (prompt caching, reasoning, etc.)
|
||||
# When no base_url is provided, the client defaults to OpenRouter, so reflect that here.
|
||||
self.base_url = base_url or "https://openrouter.ai/api/v1"
|
||||
self.base_url = base_url or "" # Store for OpenRouter detection
|
||||
self.tool_progress_callback = tool_progress_callback
|
||||
self._last_reported_tool = None # Track for "new tool" mode
|
||||
|
||||
@@ -1081,14 +667,6 @@ class AIAgent:
|
||||
self.reasoning_config = reasoning_config # None = use default (xhigh for OpenRouter)
|
||||
self.prefill_messages = prefill_messages or [] # Prefilled conversation turns
|
||||
|
||||
# Anthropic prompt caching: auto-enabled for Claude models via OpenRouter.
|
||||
# Reduces input costs by ~75% on multi-turn conversations by caching the
|
||||
# conversation prefix. Uses system_and_3 strategy (4 breakpoints).
|
||||
is_openrouter = "openrouter" in self.base_url.lower()
|
||||
is_claude = "claude" in self.model.lower()
|
||||
self._use_prompt_caching = is_openrouter and is_claude
|
||||
self._cache_ttl = "5m" # Default 5-minute TTL (1.25x write cost)
|
||||
|
||||
# Configure logging
|
||||
if self.verbose_logging:
|
||||
logging.basicConfig(
|
||||
@@ -1194,10 +772,6 @@ class AIAgent:
|
||||
prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt
|
||||
print(f"🔒 Ephemeral system prompt: '{prompt_preview}' (not saved to trajectories)")
|
||||
|
||||
# Show prompt caching status
|
||||
if self._use_prompt_caching and not self.quiet_mode:
|
||||
print(f"💾 Prompt caching: ENABLED (Claude via OpenRouter, {self._cache_ttl} TTL)")
|
||||
|
||||
# Session logging setup - auto-save conversation trajectories for debugging
|
||||
self.session_start = datetime.now()
|
||||
if session_id:
|
||||
@@ -1376,6 +950,10 @@ class AIAgent:
|
||||
return f"{face} 🎨 creating '{prompt}'... {time_str}"
|
||||
|
||||
# Skills - use large pool for variety
|
||||
elif tool_name == "skills_categories":
|
||||
face = random.choice(self.KAWAII_SKILL)
|
||||
return f"{face} 📚 listing categories... {time_str}"
|
||||
|
||||
elif tool_name == "skills_list":
|
||||
category = args.get("category", "skills")
|
||||
face = random.choice(self.KAWAII_SKILL)
|
||||
@@ -1386,65 +964,19 @@ class AIAgent:
|
||||
face = random.choice(self.KAWAII_SKILL)
|
||||
return f"{face} 📖 loading {name}... {time_str}"
|
||||
|
||||
# File tools
|
||||
elif tool_name == "read_file":
|
||||
path = args.get("path", "file")
|
||||
if len(path) > 30:
|
||||
path = "..." + path[-27:]
|
||||
face = random.choice(self.KAWAII_READ)
|
||||
return f"{face} 📖 reading \"{path}\" {time_str}"
|
||||
|
||||
elif tool_name == "write_file":
|
||||
path = args.get("path", "file")
|
||||
if len(path) > 30:
|
||||
path = "..." + path[-27:]
|
||||
face = random.choice(self.KAWAII_CREATE)
|
||||
return f"{face} ✍️ writing \"{path}\" {time_str}"
|
||||
|
||||
elif tool_name == "patch":
|
||||
path = args.get("path", "file")
|
||||
if path and len(path) > 30:
|
||||
path = "..." + path[-27:]
|
||||
face = random.choice(self.KAWAII_CREATE)
|
||||
return f"{face} 🔧 patching \"{path}\" {time_str}"
|
||||
|
||||
elif tool_name == "search":
|
||||
pattern = args.get("pattern", "")
|
||||
if len(pattern) > 25:
|
||||
pattern = pattern[:22] + "..."
|
||||
face = random.choice(self.KAWAII_SEARCH)
|
||||
return f"{face} 🔎 searching \"{pattern}\" {time_str}"
|
||||
|
||||
# TTS
|
||||
elif tool_name == "text_to_speech":
|
||||
text = args.get("text", "")
|
||||
if len(text) > 25:
|
||||
text = text[:22] + "..."
|
||||
face = random.choice(self.KAWAII_CREATE)
|
||||
return f"{face} 🔊 speaking \"{text}\" {time_str}"
|
||||
|
||||
# Vision tools
|
||||
elif tool_name == "vision_analyze":
|
||||
question = args.get("question", "")
|
||||
if len(question) > 25:
|
||||
question = question[:22] + "..."
|
||||
face = random.choice(self.KAWAII_BROWSER)
|
||||
return f"{face} 👁️✨ analyzing \"{question}\" {time_str}"
|
||||
return f"{face} 👁️✨ analyzing image... {time_str}"
|
||||
|
||||
# Mixture of agents
|
||||
elif tool_name == "mixture_of_agents":
|
||||
prompt = args.get("user_prompt", "")
|
||||
if len(prompt) > 25:
|
||||
prompt = prompt[:22] + "..."
|
||||
face = random.choice(self.KAWAII_THINK)
|
||||
return f"{face} 🧠💭 deep thinking \"{prompt}\" {time_str}"
|
||||
return f"{face} 🧠💭 thinking REALLY hard... {time_str}"
|
||||
|
||||
# Default fallback - random generic kawaii with primary arg preview
|
||||
# Default fallback - random generic kawaii
|
||||
else:
|
||||
face = random.choice(self.KAWAII_GENERIC)
|
||||
preview = _build_tool_preview(tool_name, args)
|
||||
if preview:
|
||||
return f"{face} ⚡ {tool_name}... \"{preview}\" {time_str}"
|
||||
return f"{face} ⚡ {tool_name}... {time_str}"
|
||||
|
||||
def _has_content_after_think_block(self, content: str) -> bool:
|
||||
@@ -1913,9 +1445,6 @@ class AIAgent:
|
||||
Call this from another thread (e.g., input handler, message receiver)
|
||||
to gracefully stop the agent and process a new message.
|
||||
|
||||
Also signals long-running tool executions (e.g. terminal commands)
|
||||
to terminate early, so the agent can respond immediately.
|
||||
|
||||
Args:
|
||||
message: Optional new message that triggered the interrupt.
|
||||
If provided, the agent will include this in its response context.
|
||||
@@ -1932,8 +1461,6 @@ class AIAgent:
|
||||
"""
|
||||
self._interrupt_requested = True
|
||||
self._interrupt_message = message
|
||||
# Signal the terminal tool to kill any running subprocess immediately
|
||||
_set_terminal_interrupt(True)
|
||||
if not self.quiet_mode:
|
||||
print(f"\n⚡ Interrupt requested" + (f": '{message[:40]}...'" if message and len(message) > 40 else f": '{message}'" if message else ""))
|
||||
|
||||
@@ -1941,7 +1468,6 @@ class AIAgent:
|
||||
"""Clear any pending interrupt request."""
|
||||
self._interrupt_requested = False
|
||||
self._interrupt_message = None
|
||||
_set_terminal_interrupt(False)
|
||||
|
||||
@property
|
||||
def is_interrupted(self) -> bool:
|
||||
@@ -1994,46 +1520,20 @@ class AIAgent:
|
||||
if not self.quiet_mode:
|
||||
print(f"💬 Starting conversation: '{user_message[:60]}{'...' if len(user_message) > 60 else ''}'")
|
||||
|
||||
# ── Build the full system prompt ──
|
||||
# Layers (in order):
|
||||
# 1. Default agent identity (always present)
|
||||
# 2. User / gateway system prompt (if provided)
|
||||
# 3. Skills guidance (if skills tools are loaded)
|
||||
# 4. Context files (SOUL.md, AGENTS.md, .cursorrules)
|
||||
# 5. Current date & time
|
||||
# 6. Platform-specific formatting hint
|
||||
prompt_parts = [DEFAULT_AGENT_IDENTITY]
|
||||
|
||||
# Layer in the caller-supplied system prompt (explicit > ephemeral).
|
||||
caller_prompt = system_message if system_message is not None else self.ephemeral_system_prompt
|
||||
if caller_prompt:
|
||||
prompt_parts.append(caller_prompt)
|
||||
|
||||
# Auto-include skills guidance if skills tools are available.
|
||||
has_skills_tools = any(name in self.valid_tool_names for name in ['skills_list', 'skill_view'])
|
||||
skills_prompt = build_skills_system_prompt() if has_skills_tools else ""
|
||||
if skills_prompt:
|
||||
prompt_parts.append(skills_prompt)
|
||||
|
||||
# Auto-include context files (SOUL.md, AGENTS.md, .cursorrules).
|
||||
context_files_prompt = build_context_files_prompt()
|
||||
if context_files_prompt:
|
||||
prompt_parts.append(context_files_prompt)
|
||||
|
||||
# Current local date and time so the model is never confused about
|
||||
# what day/time it is (LLM training cutoffs can otherwise mislead it).
|
||||
now = datetime.now()
|
||||
prompt_parts.append(
|
||||
f"Current local date and time: {now.strftime('%A, %B %d, %Y %I:%M %p')}"
|
||||
)
|
||||
|
||||
# Platform-specific formatting hint (no markdown on WhatsApp, etc.).
|
||||
platform_key = (self.platform or "").lower().strip()
|
||||
if platform_key in PLATFORM_HINTS:
|
||||
prompt_parts.append(PLATFORM_HINTS[platform_key])
|
||||
|
||||
active_system_prompt = "\n\n".join(prompt_parts)
|
||||
|
||||
# Determine which system prompt to use for API calls (ephemeral)
|
||||
# Priority: explicit system_message > ephemeral_system_prompt > None
|
||||
base_system_prompt = system_message if system_message is not None else self.ephemeral_system_prompt
|
||||
|
||||
# Auto-include skills guidance if skills tools are available
|
||||
has_skills_tools = any(name in self.valid_tool_names for name in ['skills_list', 'skills_categories', 'skill_view'])
|
||||
if has_skills_tools:
|
||||
if base_system_prompt:
|
||||
active_system_prompt = f"{base_system_prompt}\n\n{SKILLS_SYSTEM_PROMPT}"
|
||||
else:
|
||||
active_system_prompt = SKILLS_SYSTEM_PROMPT
|
||||
else:
|
||||
active_system_prompt = base_system_prompt
|
||||
|
||||
# Main conversation loop
|
||||
api_call_count = 0
|
||||
final_response = None
|
||||
@@ -2080,13 +1580,16 @@ class AIAgent:
|
||||
if active_system_prompt:
|
||||
# Insert system message at the beginning
|
||||
api_messages = [{"role": "system", "content": active_system_prompt}] + api_messages
|
||||
|
||||
# Apply Anthropic prompt caching for Claude models via OpenRouter.
|
||||
# Auto-detected: if model name contains "claude" and base_url is OpenRouter,
|
||||
# inject cache_control breakpoints (system + last 3 messages) to reduce
|
||||
# input token costs by ~75% on multi-turn conversations.
|
||||
if self._use_prompt_caching:
|
||||
api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl)
|
||||
|
||||
if os.getenv("HERMES_DEBUG_OPENAI_REQUEST") == "1":
|
||||
meta = {
|
||||
"model": self.model,
|
||||
"base_url": self.base_url,
|
||||
"messages": api_messages,
|
||||
"tools": self.tools if self.tools else None,
|
||||
}
|
||||
print("\n=== HERMES_DEBUG_OPENAI_REQUEST ===", flush=True)
|
||||
print(json.dumps(meta, ensure_ascii=False, indent=2)[:200_000], flush=True)
|
||||
|
||||
# Calculate approximate request size for logging
|
||||
total_chars = sum(len(str(msg)) for msg in api_messages)
|
||||
@@ -2100,12 +1603,13 @@ class AIAgent:
|
||||
print(f"{self.log_prefix} 📊 Request size: {len(api_messages)} messages, ~{approx_tokens:,} tokens (~{total_chars:,} chars)")
|
||||
print(f"{self.log_prefix} 🔧 Available tools: {len(self.tools) if self.tools else 0}")
|
||||
else:
|
||||
# Animated thinking spinner in quiet mode
|
||||
face = random.choice(KawaiiSpinner.KAWAII_THINKING)
|
||||
verb = random.choice(KawaiiSpinner.THINKING_VERBS)
|
||||
spinner_type = random.choice(['brain', 'sparkle', 'pulse', 'moon', 'star'])
|
||||
thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type)
|
||||
thinking_spinner.start()
|
||||
# Animated thinking spinner in quiet mode (disable for wrappers/non-TTY usage)
|
||||
if os.getenv("HERMES_DISABLE_SPINNER") != "1":
|
||||
face = random.choice(KawaiiSpinner.KAWAII_THINKING)
|
||||
verb = random.choice(KawaiiSpinner.THINKING_VERBS)
|
||||
spinner_type = random.choice(['brain', 'sparkle', 'pulse', 'moon', 'star'])
|
||||
thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type)
|
||||
thinking_spinner.start()
|
||||
|
||||
# Log request details if verbose
|
||||
if self.verbose_logging:
|
||||
@@ -2165,6 +1669,14 @@ class AIAgent:
|
||||
api_kwargs["extra_body"] = extra_body
|
||||
|
||||
response = self.client.chat.completions.create(**api_kwargs)
|
||||
|
||||
if os.getenv("HERMES_DEBUG_OPENAI_RESPONSE") == "1":
|
||||
try:
|
||||
dumped = response.model_dump()
|
||||
except Exception:
|
||||
dumped = getattr(response, "__dict__", {"repr": repr(response)})
|
||||
print("\n=== HERMES_DEBUG_OPENAI_RESPONSE: ChatCompletion (raw) ===", flush=True)
|
||||
print(json.dumps(dumped, ensure_ascii=False, indent=2), flush=True)
|
||||
|
||||
api_duration = time.time() - api_start_time
|
||||
|
||||
@@ -2317,16 +1829,6 @@ class AIAgent:
|
||||
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"Token usage: prompt={usage_dict['prompt_tokens']:,}, completion={usage_dict['completion_tokens']:,}, total={usage_dict['total_tokens']:,}")
|
||||
|
||||
# Log cache hit stats when prompt caching is active
|
||||
if self._use_prompt_caching:
|
||||
details = getattr(response.usage, 'prompt_tokens_details', None)
|
||||
cached = getattr(details, 'cached_tokens', 0) or 0 if details else 0
|
||||
written = getattr(details, 'cache_write_tokens', 0) or 0 if details else 0
|
||||
prompt = usage_dict["prompt_tokens"]
|
||||
hit_pct = (cached / prompt * 100) if prompt > 0 else 0
|
||||
if not self.quiet_mode:
|
||||
print(f"{self.log_prefix} 💾 Cache: {cached:,}/{prompt:,} tokens ({hit_pct:.0f}% hit, {written:,} written)")
|
||||
|
||||
break # Success, exit retry loop
|
||||
|
||||
@@ -2640,8 +2142,12 @@ class AIAgent:
|
||||
# Fire progress callback if registered (for messaging platforms)
|
||||
if self.tool_progress_callback:
|
||||
try:
|
||||
# Build a short preview of the primary argument
|
||||
preview = _build_tool_preview(function_name, function_args)
|
||||
# Build preview for terminal commands
|
||||
if function_name == "terminal":
|
||||
cmd = function_args.get("command", "")
|
||||
preview = cmd[:50] + "..." if len(cmd) > 50 else cmd
|
||||
else:
|
||||
preview = None
|
||||
self.tool_progress_callback(function_name, preview)
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool progress callback error: {cb_err}")
|
||||
@@ -2649,7 +2155,7 @@ class AIAgent:
|
||||
tool_start_time = time.time()
|
||||
|
||||
# Execute the tool - with animated spinner in quiet mode
|
||||
if self.quiet_mode:
|
||||
if self.quiet_mode and os.getenv("HERMES_DISABLE_SPINNER") != "1":
|
||||
# Tool-specific spinner animations
|
||||
tool_spinners = {
|
||||
'web_search': ('arrows', ['🔍', '🌐', '📡', '🔎']),
|
||||
@@ -2663,6 +2169,7 @@ class AIAgent:
|
||||
'image_generate': ('sparkle', ['🎨', '✨', '🖼️', '🌟']),
|
||||
'skill_view': ('star', ['📚', '📖', '🎓', '✨']),
|
||||
'skills_list': ('pulse', ['📋', '📝', '📑', '📜']),
|
||||
'skills_categories': ('pulse', ['📂', '🗂️', '📁', '🏷️']),
|
||||
'moa_query': ('brain', ['🧠', '💭', '🤔', '💡']),
|
||||
'analyze_image': ('sparkle', ['👁️', '🔍', '📷', '✨']),
|
||||
}
|
||||
@@ -2678,6 +2185,9 @@ class AIAgent:
|
||||
tool_duration = time.time() - tool_start_time
|
||||
cute_msg = self._get_cute_tool_message(function_name, function_args, tool_duration)
|
||||
spinner.stop(cute_msg)
|
||||
elif self.quiet_mode:
|
||||
function_result = handle_function_call(function_name, function_args, effective_task_id)
|
||||
tool_duration = time.time() - tool_start_time
|
||||
else:
|
||||
function_result = handle_function_call(function_name, function_args, effective_task_id)
|
||||
tool_duration = time.time() - tool_start_time
|
||||
@@ -2700,21 +2210,6 @@ class AIAgent:
|
||||
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
|
||||
print(f" ✅ Tool {i} completed in {tool_duration:.2f}s - {response_preview}")
|
||||
|
||||
# Check for interrupt between tool calls - skip remaining
|
||||
# tools so the agent can respond to the user immediately
|
||||
if self._interrupt_requested and i < len(assistant_message.tool_calls):
|
||||
remaining = len(assistant_message.tool_calls) - i
|
||||
print(f"{self.log_prefix}⚡ Interrupt: skipping {remaining} remaining tool call(s)")
|
||||
# Add placeholder results for skipped tool calls so the
|
||||
# message sequence stays valid (assistant tool_calls need matching tool results)
|
||||
for skipped_tc in assistant_message.tool_calls[i:]:
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"content": "[Tool execution skipped - user sent a new message]",
|
||||
"tool_call_id": skipped_tc.id
|
||||
})
|
||||
break
|
||||
|
||||
# Delay between tool calls
|
||||
if self.tool_delay > 0 and i < len(assistant_message.tool_calls):
|
||||
time.sleep(self.tool_delay)
|
||||
@@ -3161,4 +2656,11 @@ def main(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
import fire # type: ignore
|
||||
except ModuleNotFoundError as exc:
|
||||
raise SystemExit(
|
||||
"Missing optional dependency 'fire'. Install hermes-agent with its CLI extras or add `fire` "
|
||||
f"to your environment. Original error: {exc}"
|
||||
) from exc
|
||||
fire.Fire(main)
|
||||
|
||||
@@ -262,25 +262,6 @@ function Test-Ripgrep {
|
||||
return $true # Don't fail - ripgrep is optional
|
||||
}
|
||||
|
||||
function Test-Ffmpeg {
|
||||
Write-Info "Checking ffmpeg (optional, for TTS voice messages)..."
|
||||
|
||||
if (Get-Command ffmpeg -ErrorAction SilentlyContinue) {
|
||||
$version = ffmpeg -version 2>&1 | Select-Object -First 1
|
||||
Write-Success "ffmpeg found"
|
||||
$script:HasFfmpeg = $true
|
||||
return $true
|
||||
}
|
||||
|
||||
Write-Warn "ffmpeg not found (TTS voice bubbles on Telegram will send as audio files instead)"
|
||||
Write-Info " Install with: winget install ffmpeg"
|
||||
Write-Info " Or: choco install ffmpeg"
|
||||
Write-Info " Or download from: https://ffmpeg.org/download.html"
|
||||
|
||||
$script:HasFfmpeg = $false
|
||||
return $true # Don't fail - ffmpeg is optional
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Installation
|
||||
# ============================================================================
|
||||
@@ -586,7 +567,6 @@ function Main {
|
||||
if (-not (Test-Git)) { exit 1 }
|
||||
Test-Node # Optional, doesn't fail
|
||||
Test-Ripgrep # Optional, doesn't fail
|
||||
Test-Ffmpeg # Optional, doesn't fail
|
||||
|
||||
Install-Repository
|
||||
Install-Venv
|
||||
|
||||
@@ -413,45 +413,6 @@ check_ripgrep() {
|
||||
# Don't exit - ripgrep is optional (grep fallback exists)
|
||||
}
|
||||
|
||||
check_ffmpeg() {
|
||||
log_info "Checking ffmpeg (optional, for TTS voice messages)..."
|
||||
|
||||
if command -v ffmpeg &> /dev/null; then
|
||||
local ffmpeg_version=$(ffmpeg -version 2>/dev/null | head -1 | awk '{print $3}')
|
||||
log_success "ffmpeg found: $ffmpeg_version"
|
||||
HAS_FFMPEG=true
|
||||
return
|
||||
fi
|
||||
|
||||
log_warn "ffmpeg not found (TTS voice bubbles on Telegram will send as audio files instead)"
|
||||
log_info "To install ffmpeg (optional):"
|
||||
|
||||
case "$OS" in
|
||||
linux)
|
||||
case "$DISTRO" in
|
||||
ubuntu|debian)
|
||||
log_info " sudo apt install ffmpeg"
|
||||
;;
|
||||
fedora)
|
||||
log_info " sudo dnf install ffmpeg"
|
||||
;;
|
||||
arch)
|
||||
log_info " sudo pacman -S ffmpeg"
|
||||
;;
|
||||
*)
|
||||
log_info " https://ffmpeg.org/download.html"
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
macos)
|
||||
log_info " brew install ffmpeg"
|
||||
;;
|
||||
esac
|
||||
|
||||
HAS_FFMPEG=false
|
||||
# Don't exit - ffmpeg is optional
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Installation
|
||||
# ============================================================================
|
||||
@@ -746,7 +707,6 @@ main() {
|
||||
check_git
|
||||
check_node
|
||||
check_ripgrep
|
||||
check_ffmpeg
|
||||
|
||||
clone_repo
|
||||
setup_venv
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Kill all running Modal apps (sandboxes, deployments, etc.)
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/kill_modal.sh # Stop swe-rex (the sandbox app)
|
||||
# bash scripts/kill_modal.sh --all # Stop ALL Modal apps
|
||||
|
||||
set -uo pipefail
|
||||
|
||||
echo "Fetching Modal app list..."
|
||||
APP_LIST=$(modal app list 2>/dev/null)
|
||||
|
||||
if [[ "${1:-}" == "--all" ]]; then
|
||||
echo "Stopping ALL Modal apps..."
|
||||
echo "$APP_LIST" | grep -oE 'ap-[A-Za-z0-9]+' | sort -u | while read app_id; do
|
||||
echo " Stopping $app_id"
|
||||
modal app stop "$app_id" 2>/dev/null || true
|
||||
done
|
||||
else
|
||||
echo "Stopping swe-rex sandboxes..."
|
||||
APPS=$(echo "$APP_LIST" | grep 'swe-rex' | grep -oE 'ap-[A-Za-z0-9]+' || true)
|
||||
if [[ -z "$APPS" ]]; then
|
||||
echo " No swe-rex apps found."
|
||||
else
|
||||
echo "$APPS" | while read app_id; do
|
||||
echo " Stopping $app_id"
|
||||
modal app stop "$app_id" 2>/dev/null || true
|
||||
done
|
||||
fi
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "Current swe-rex status:"
|
||||
modal app list 2>/dev/null | grep -E 'State|swe-rex' || echo " (none)"
|
||||
Executable
+62
@@ -0,0 +1,62 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Launch a local llama.cpp OpenAI-compatible server running GLM-4.7-Flash (GGUF).
|
||||
#
|
||||
# Requires:
|
||||
# - `llama-server` installed (e.g. `brew install llama.cpp`)
|
||||
#
|
||||
# Default settings are chosen to avoid clashing with Atropos sandbox_server
|
||||
# (which commonly uses port 8080 in local dev).
|
||||
#
|
||||
# Usage:
|
||||
# Hermes-Agent/scripts/launch_llama_cpp_glm47_flash.sh
|
||||
#
|
||||
# Override defaults:
|
||||
# LLAMA_CPP_HOST=127.0.0.1 LLAMA_CPP_PORT=8082 \
|
||||
# LLAMA_CPP_HF_REPO=ggml-org/GLM-4.7-Flash-GGUF \
|
||||
# LLAMA_CPP_HF_FILE=GLM-4.7-Flash-Q4_K.gguf \
|
||||
# Hermes-Agent/scripts/launch_llama_cpp_glm47_flash.sh
|
||||
|
||||
HOST="${LLAMA_CPP_HOST:-127.0.0.1}"
|
||||
PORT="${LLAMA_CPP_PORT:-8080}"
|
||||
HF_REPO="${LLAMA_CPP_HF_REPO:-ggml-org/GLM-4.7-Flash-GGUF}"
|
||||
HF_FILE="${LLAMA_CPP_HF_FILE:-GLM-4.7-Flash-Q4_K.gguf}"
|
||||
ALIAS="${LLAMA_CPP_ALIAS:-glm-4.7-flash}"
|
||||
|
||||
if ! command -v llama-server >/dev/null 2>&1; then
|
||||
echo "Error: llama-server not found in PATH."
|
||||
echo "Install via Homebrew: brew install llama.cpp"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Launching llama.cpp server..."
|
||||
echo " host: $HOST"
|
||||
echo " port: $PORT"
|
||||
echo " repo: $HF_REPO"
|
||||
echo " file: $HF_FILE"
|
||||
echo " alias: $ALIAS"
|
||||
echo
|
||||
echo "Suggested env vars for Hermes/Atropos integration:"
|
||||
echo " export ATROPOS_SERVER_BASE_URL=http://${HOST}:${PORT}"
|
||||
echo " export ATROPOS_SERVER_MODEL=${ALIAS}"
|
||||
echo " export ATROPOS_SERVER_API_KEY=local"
|
||||
echo
|
||||
|
||||
if command -v lsof >/dev/null 2>&1; then
|
||||
if lsof -nP -iTCP:"$PORT" -sTCP:LISTEN >/dev/null 2>&1; then
|
||||
echo "Error: port $PORT is already in use."
|
||||
echo "Pick a different port, e.g.:"
|
||||
echo " LLAMA_CPP_PORT=8082 Hermes-Agent/scripts/launch_llama_cpp_glm47_flash.sh"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
exec llama-server \
|
||||
--host "$HOST" \
|
||||
--port "$PORT" \
|
||||
--hf-repo "$HF_REPO" \
|
||||
--hf-file "$HF_FILE" \
|
||||
--alias "$ALIAS" \
|
||||
-c 32768 \
|
||||
-n -1
|
||||
Executable
+70
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Launch a local llama.cpp OpenAI-compatible server running Hermes 4.3 36B (GGUF).
|
||||
#
|
||||
# Requires:
|
||||
# - `llama-server` installed (e.g. `brew install llama.cpp`)
|
||||
#
|
||||
# Note: Port choice can conflict with other local dev servers. If 8080 is already
|
||||
# in use, override via `LLAMA_CPP_PORT=...`.
|
||||
#
|
||||
# Usage:
|
||||
# Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh
|
||||
#
|
||||
# Override defaults:
|
||||
# LLAMA_CPP_HOST=127.0.0.1 LLAMA_CPP_PORT=8082 \
|
||||
# LLAMA_CPP_HF_REPO=NousResearch/Hermes-4.3-36B-GGUF \
|
||||
# LLAMA_CPP_HF_FILE=hermes-4_3_36b-Q4_K_M.gguf \
|
||||
# LLAMA_CPP_ALIAS=hermes-4-36b \
|
||||
# LLAMA_CPP_PARALLEL=4 LLAMA_CPP_THREADS_HTTP=4 \
|
||||
# Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh
|
||||
|
||||
HOST="${LLAMA_CPP_HOST:-127.0.0.1}"
|
||||
PORT="${LLAMA_CPP_PORT:-8080}"
|
||||
HF_REPO="${LLAMA_CPP_HF_REPO:-NousResearch/Hermes-4.3-36B-GGUF}"
|
||||
HF_FILE="${LLAMA_CPP_HF_FILE:-hermes-4_3_36b-Q4_K_M.gguf}"
|
||||
ALIAS="${LLAMA_CPP_ALIAS:-hermes-4-36b}"
|
||||
PARALLEL="${LLAMA_CPP_PARALLEL:-4}"
|
||||
THREADS_HTTP="${LLAMA_CPP_THREADS_HTTP:-4}"
|
||||
|
||||
if ! command -v llama-server >/dev/null 2>&1; then
|
||||
echo "Error: llama-server not found in PATH."
|
||||
echo "Install via Homebrew: brew install llama.cpp"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Launching llama.cpp server..."
|
||||
echo " host: $HOST"
|
||||
echo " port: $PORT"
|
||||
echo " repo: $HF_REPO"
|
||||
echo " file: $HF_FILE"
|
||||
echo " alias: $ALIAS"
|
||||
echo " slots: $PARALLEL"
|
||||
echo
|
||||
echo "Suggested env vars for Hermes/Atropos integration:"
|
||||
echo " export ATROPOS_SERVER_BASE_URL=http://${HOST}:${PORT}"
|
||||
echo " export ATROPOS_SERVER_MODEL=${ALIAS}"
|
||||
echo " export ATROPOS_TOKENIZER_NAME=NousResearch/Hermes-4.3-36B"
|
||||
echo " export ATROPOS_SERVER_API_KEY=local"
|
||||
echo
|
||||
|
||||
if command -v lsof >/dev/null 2>&1; then
|
||||
if lsof -nP -iTCP:"$PORT" -sTCP:LISTEN >/dev/null 2>&1; then
|
||||
echo "Error: port $PORT is already in use."
|
||||
echo "Pick a different port, e.g.:"
|
||||
echo " LLAMA_CPP_PORT=8082 Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
exec llama-server \
|
||||
--host "$HOST" \
|
||||
--port "$PORT" \
|
||||
--hf-repo "$HF_REPO" \
|
||||
--hf-file "$HF_FILE" \
|
||||
--alias "$ALIAS" \
|
||||
--parallel "$PARALLEL" \
|
||||
--threads-http "$THREADS_HTTP" \
|
||||
-c 32768 \
|
||||
-n -1
|
||||
@@ -1,3 +0,0 @@
|
||||
---
|
||||
description: Diagram creation skills for generating visual diagrams, flowcharts, architecture diagrams, and illustrations using tools like Excalidraw.
|
||||
---
|
||||
@@ -1,191 +0,0 @@
|
||||
---
|
||||
name: excalidraw
|
||||
description: Create hand-drawn style diagrams using Excalidraw JSON format. Generate .excalidraw files for architecture diagrams, flowcharts, sequence diagrams, concept maps, and more. Files can be opened at excalidraw.com or uploaded for shareable links.
|
||||
version: 1.0.0
|
||||
author: Hermes Agent
|
||||
license: MIT
|
||||
tags: [Excalidraw, Diagrams, Flowcharts, Architecture, Visualization, JSON]
|
||||
dependencies: []
|
||||
related_skills: []
|
||||
---
|
||||
|
||||
# Excalidraw Diagram Skill
|
||||
|
||||
Create diagrams by writing standard Excalidraw element JSON and saving as `.excalidraw` files. These files can be drag-and-dropped onto [excalidraw.com](https://excalidraw.com) for viewing and editing. No accounts, no API keys, no rendering libraries -- just JSON.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. **Load this skill** (you already did)
|
||||
2. **Write the elements JSON** -- an array of Excalidraw element objects
|
||||
3. **Save the file** using `write_file` to create a `.excalidraw` file
|
||||
4. **Optionally upload** for a shareable link using `scripts/upload.py` via `terminal`
|
||||
|
||||
### Saving a Diagram
|
||||
|
||||
Wrap your elements array in the standard `.excalidraw` envelope and save with `write_file`:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "excalidraw",
|
||||
"version": 2,
|
||||
"source": "hermes-agent",
|
||||
"elements": [ ...your elements array here... ],
|
||||
"appState": {
|
||||
"viewBackgroundColor": "#ffffff"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Save to any path, e.g. `~/diagrams/my_diagram.excalidraw`.
|
||||
|
||||
### Uploading for a Shareable Link
|
||||
|
||||
Run the upload script (located in this skill's `scripts/` directory) via terminal:
|
||||
|
||||
```bash
|
||||
python skills/diagramming/excalidraw/scripts/upload.py ~/diagrams/my_diagram.excalidraw
|
||||
```
|
||||
|
||||
This uploads to excalidraw.com (no account needed) and prints a shareable URL. Requires the `cryptography` pip package (`pip install cryptography`).
|
||||
|
||||
---
|
||||
|
||||
## Element Format Reference
|
||||
|
||||
### Required Fields (all elements)
|
||||
`type`, `id` (unique string), `x`, `y`, `width`, `height`
|
||||
|
||||
### Defaults (skip these -- they're applied automatically)
|
||||
- `strokeColor`: `"#1e1e1e"`
|
||||
- `backgroundColor`: `"transparent"`
|
||||
- `fillStyle`: `"solid"`
|
||||
- `strokeWidth`: `2`
|
||||
- `roughness`: `1` (hand-drawn look)
|
||||
- `opacity`: `100`
|
||||
|
||||
Canvas background is white.
|
||||
|
||||
### Element Types
|
||||
|
||||
**Rectangle**:
|
||||
```json
|
||||
{ "type": "rectangle", "id": "r1", "x": 100, "y": 100, "width": 200, "height": 100 }
|
||||
```
|
||||
- `roundness: { "type": 3 }` for rounded corners
|
||||
- `backgroundColor: "#a5d8ff"`, `fillStyle: "solid"` for filled
|
||||
|
||||
**Ellipse**:
|
||||
```json
|
||||
{ "type": "ellipse", "id": "e1", "x": 100, "y": 100, "width": 150, "height": 150 }
|
||||
```
|
||||
|
||||
**Diamond**:
|
||||
```json
|
||||
{ "type": "diamond", "id": "d1", "x": 100, "y": 100, "width": 150, "height": 150 }
|
||||
```
|
||||
|
||||
**Labeled shape (container binding)** -- create a text element bound to the shape:
|
||||
|
||||
> **WARNING:** Do NOT use `"label": { "text": "..." }` on shapes. This is NOT a valid
|
||||
> Excalidraw property and will be silently ignored, producing blank shapes. You MUST
|
||||
> use the container binding approach below.
|
||||
|
||||
The shape needs `boundElements` listing the text, and the text needs `containerId` pointing back:
|
||||
```json
|
||||
{ "type": "rectangle", "id": "r1", "x": 100, "y": 100, "width": 200, "height": 80,
|
||||
"roundness": { "type": 3 }, "backgroundColor": "#a5d8ff", "fillStyle": "solid",
|
||||
"boundElements": [{ "id": "t_r1", "type": "text" }] },
|
||||
{ "type": "text", "id": "t_r1", "x": 105, "y": 110, "width": 190, "height": 25,
|
||||
"text": "Hello", "fontSize": 20, "fontFamily": 1, "strokeColor": "#1e1e1e",
|
||||
"textAlign": "center", "verticalAlign": "middle",
|
||||
"containerId": "r1", "originalText": "Hello", "autoResize": true }
|
||||
```
|
||||
- Works on rectangle, ellipse, diamond
|
||||
- Text is auto-centered by Excalidraw when `containerId` is set
|
||||
- The text `x`/`y`/`width`/`height` are approximate -- Excalidraw recalculates them on load
|
||||
- `originalText` should match `text`
|
||||
- Always include `fontFamily: 1` (Virgil/hand-drawn font)
|
||||
|
||||
**Labeled arrow** -- same container binding approach:
|
||||
```json
|
||||
{ "type": "arrow", "id": "a1", "x": 300, "y": 150, "width": 200, "height": 0,
|
||||
"points": [[0,0],[200,0]], "endArrowhead": "arrow",
|
||||
"boundElements": [{ "id": "t_a1", "type": "text" }] },
|
||||
{ "type": "text", "id": "t_a1", "x": 370, "y": 130, "width": 60, "height": 20,
|
||||
"text": "connects", "fontSize": 16, "fontFamily": 1, "strokeColor": "#1e1e1e",
|
||||
"textAlign": "center", "verticalAlign": "middle",
|
||||
"containerId": "a1", "originalText": "connects", "autoResize": true }
|
||||
```
|
||||
|
||||
**Standalone text** (titles and annotations only -- no container):
|
||||
```json
|
||||
{ "type": "text", "id": "t1", "x": 150, "y": 138, "text": "Hello", "fontSize": 20,
|
||||
"fontFamily": 1, "strokeColor": "#1e1e1e", "originalText": "Hello", "autoResize": true }
|
||||
```
|
||||
- `x` is the LEFT edge. To center at position `cx`: `x = cx - (text.length * fontSize * 0.5) / 2`
|
||||
- Do NOT rely on `textAlign` or `width` for positioning
|
||||
|
||||
**Arrow**:
|
||||
```json
|
||||
{ "type": "arrow", "id": "a1", "x": 300, "y": 150, "width": 200, "height": 0,
|
||||
"points": [[0,0],[200,0]], "endArrowhead": "arrow" }
|
||||
```
|
||||
- `points`: `[dx, dy]` offsets from element `x`, `y`
|
||||
- `endArrowhead`: `null` | `"arrow"` | `"bar"` | `"dot"` | `"triangle"`
|
||||
- `strokeStyle`: `"solid"` (default) | `"dashed"` | `"dotted"`
|
||||
|
||||
### Arrow Bindings (connect arrows to shapes)
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "arrow", "id": "a1", "x": 300, "y": 150, "width": 150, "height": 0,
|
||||
"points": [[0,0],[150,0]], "endArrowhead": "arrow",
|
||||
"startBinding": { "elementId": "r1", "fixedPoint": [1, 0.5] },
|
||||
"endBinding": { "elementId": "r2", "fixedPoint": [0, 0.5] }
|
||||
}
|
||||
```
|
||||
|
||||
`fixedPoint` coordinates: `top=[0.5,0]`, `bottom=[0.5,1]`, `left=[0,0.5]`, `right=[1,0.5]`
|
||||
|
||||
### Drawing Order (z-order)
|
||||
- Array order = z-order (first = back, last = front)
|
||||
- Emit progressively: background zones → shape → its bound text → its arrows → next shape
|
||||
- BAD: all rectangles, then all texts, then all arrows
|
||||
- GOOD: bg_zone → shape1 → text_for_shape1 → arrow1 → arrow_label_text → shape2 → text_for_shape2 → ...
|
||||
- Always place the bound text element immediately after its container shape
|
||||
|
||||
### Sizing Guidelines
|
||||
|
||||
**Font sizes:**
|
||||
- Minimum `fontSize`: **16** for body text, labels, descriptions
|
||||
- Minimum `fontSize`: **20** for titles and headings
|
||||
- Minimum `fontSize`: **14** for secondary annotations only (sparingly)
|
||||
- NEVER use `fontSize` below 14
|
||||
|
||||
**Element sizes:**
|
||||
- Minimum shape size: 120x60 for labeled rectangles/ellipses
|
||||
- Leave 20-30px gaps between elements minimum
|
||||
- Prefer fewer, larger elements over many tiny ones
|
||||
|
||||
### Color Palette
|
||||
|
||||
See `references/colors.md` for full color tables. Quick reference:
|
||||
|
||||
| Use | Fill Color | Hex |
|
||||
|-----|-----------|-----|
|
||||
| Primary / Input | Light Blue | `#a5d8ff` |
|
||||
| Success / Output | Light Green | `#b2f2bb` |
|
||||
| Warning / External | Light Orange | `#ffd8a8` |
|
||||
| Processing / Special | Light Purple | `#d0bfff` |
|
||||
| Error / Critical | Light Red | `#ffc9c9` |
|
||||
| Notes / Decisions | Light Yellow | `#fff3bf` |
|
||||
| Storage / Data | Light Teal | `#c3fae8` |
|
||||
|
||||
### Tips
|
||||
- Use the color palette consistently across the diagram
|
||||
- **Text contrast is CRITICAL** -- never use light gray on white backgrounds. Minimum text color on white: `#757575`
|
||||
- Do NOT use emoji in text -- they don't render in Excalidraw's font
|
||||
- For dark mode diagrams, see `references/dark-mode.md`
|
||||
- For larger examples, see `references/examples.md`
|
||||
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
# Excalidraw Color Palette
|
||||
|
||||
Use these colors consistently across diagrams.
|
||||
|
||||
## Primary Colors (for strokes, arrows, and accents)
|
||||
|
||||
| Name | Hex | Use |
|
||||
|------|-----|-----|
|
||||
| Blue | `#4a9eed` | Primary actions, links, data series 1 |
|
||||
| Amber | `#f59e0b` | Warnings, highlights, data series 2 |
|
||||
| Green | `#22c55e` | Success, positive, data series 3 |
|
||||
| Red | `#ef4444` | Errors, negative, data series 4 |
|
||||
| Purple | `#8b5cf6` | Accents, special items, data series 5 |
|
||||
| Pink | `#ec4899` | Decorative, data series 6 |
|
||||
| Cyan | `#06b6d4` | Info, secondary, data series 7 |
|
||||
| Lime | `#84cc16` | Extra, data series 8 |
|
||||
|
||||
## Pastel Fills (for shape backgrounds)
|
||||
|
||||
| Color | Hex | Good For |
|
||||
|-------|-----|----------|
|
||||
| Light Blue | `#a5d8ff` | Input, sources, primary nodes |
|
||||
| Light Green | `#b2f2bb` | Success, output, completed |
|
||||
| Light Orange | `#ffd8a8` | Warning, pending, external |
|
||||
| Light Purple | `#d0bfff` | Processing, middleware, special |
|
||||
| Light Red | `#ffc9c9` | Error, critical, alerts |
|
||||
| Light Yellow | `#fff3bf` | Notes, decisions, planning |
|
||||
| Light Teal | `#c3fae8` | Storage, data, memory |
|
||||
| Light Pink | `#eebefa` | Analytics, metrics |
|
||||
|
||||
## Background Zones (use with opacity: 30-35 for layered diagrams)
|
||||
|
||||
| Color | Hex | Good For |
|
||||
|-------|-----|----------|
|
||||
| Blue zone | `#dbe4ff` | UI / frontend layer |
|
||||
| Purple zone | `#e5dbff` | Logic / agent layer |
|
||||
| Green zone | `#d3f9d8` | Data / tool layer |
|
||||
|
||||
## Text Contrast Rules
|
||||
|
||||
- **On white backgrounds**: minimum text color is `#757575`. Default `#1e1e1e` is best.
|
||||
- **Colored text on light fills**: use dark variants (`#15803d` not `#22c55e`, `#2563eb` not `#4a9eed`)
|
||||
- **White text**: only on dark backgrounds (`#9a5030` not `#c4795b`)
|
||||
- **Never**: light gray (`#b0b0b0`, `#999`) on white -- unreadable
|
||||
@@ -1,68 +0,0 @@
|
||||
# Excalidraw Dark Mode Diagrams
|
||||
|
||||
To create a dark-themed diagram, use a massive dark background rectangle as the **first element** in the array. Make it large enough to cover any viewport:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "rectangle", "id": "darkbg",
|
||||
"x": -4000, "y": -3000, "width": 10000, "height": 7500,
|
||||
"backgroundColor": "#1e1e2e", "fillStyle": "solid",
|
||||
"strokeColor": "transparent", "strokeWidth": 0
|
||||
}
|
||||
```
|
||||
|
||||
Then use the following color palettes for elements on the dark background.
|
||||
|
||||
## Text Colors (on dark)
|
||||
|
||||
| Color | Hex | Use |
|
||||
|-------|-----|-----|
|
||||
| White | `#e5e5e5` | Primary text, titles |
|
||||
| Muted | `#a0a0a0` | Secondary text, annotations |
|
||||
| NEVER | `#555` or darker | Invisible on dark bg! |
|
||||
|
||||
## Shape Fills (on dark)
|
||||
|
||||
| Color | Hex | Good For |
|
||||
|-------|-----|----------|
|
||||
| Dark Blue | `#1e3a5f` | Primary nodes |
|
||||
| Dark Green | `#1a4d2e` | Success, output |
|
||||
| Dark Purple | `#2d1b69` | Processing, special |
|
||||
| Dark Orange | `#5c3d1a` | Warning, pending |
|
||||
| Dark Red | `#5c1a1a` | Error, critical |
|
||||
| Dark Teal | `#1a4d4d` | Storage, data |
|
||||
|
||||
## Stroke and Arrow Colors (on dark)
|
||||
|
||||
Use the standard Primary Colors from the main color palette -- they're bright enough on dark backgrounds:
|
||||
- Blue `#4a9eed`, Amber `#f59e0b`, Green `#22c55e`, Red `#ef4444`, Purple `#8b5cf6`
|
||||
|
||||
For subtle shape borders, use `#555555`.
|
||||
|
||||
## Example: Dark mode labeled rectangle
|
||||
|
||||
Use container binding (NOT the `"label"` property, which doesn't work). On dark backgrounds, set text `strokeColor` to `"#e5e5e5"` so it's visible:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"type": "rectangle", "id": "r1",
|
||||
"x": 100, "y": 100, "width": 200, "height": 80,
|
||||
"backgroundColor": "#1e3a5f", "fillStyle": "solid",
|
||||
"strokeColor": "#4a9eed", "strokeWidth": 2,
|
||||
"roundness": { "type": 3 },
|
||||
"boundElements": [{ "id": "t_r1", "type": "text" }]
|
||||
},
|
||||
{
|
||||
"type": "text", "id": "t_r1",
|
||||
"x": 105, "y": 120, "width": 190, "height": 25,
|
||||
"text": "Dark Node", "fontSize": 20, "fontFamily": 1,
|
||||
"strokeColor": "#e5e5e5",
|
||||
"textAlign": "center", "verticalAlign": "middle",
|
||||
"containerId": "r1", "originalText": "Dark Node", "autoResize": true
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Note: For standalone text elements on dark backgrounds, always set `"strokeColor": "#e5e5e5"` explicitly. The default `#1e1e1e` is invisible on dark.
|
||||
|
||||
@@ -1,141 +0,0 @@
|
||||
# Excalidraw Diagram Examples
|
||||
|
||||
Complete, copy-pasteable examples. Wrap each in the `.excalidraw` envelope before saving:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "excalidraw",
|
||||
"version": 2,
|
||||
"source": "hermes-agent",
|
||||
"elements": [ ...elements from examples below... ],
|
||||
"appState": { "viewBackgroundColor": "#ffffff" }
|
||||
}
|
||||
```
|
||||
|
||||
> **IMPORTANT:** All text labels on shapes and arrows use container binding (`containerId` + `boundElements`).
|
||||
> Do NOT use the non-existent `"label"` property -- it will be silently ignored, producing blank shapes.
|
||||
|
||||
---
|
||||
|
||||
## Example 1: Two Connected Labeled Boxes
|
||||
|
||||
A minimal flowchart with two boxes and an arrow between them.
|
||||
|
||||
```json
|
||||
[
|
||||
{ "type": "text", "id": "title", "x": 280, "y": 30, "text": "Simple Flow", "fontSize": 28, "fontFamily": 1, "strokeColor": "#1e1e1e", "originalText": "Simple Flow", "autoResize": true },
|
||||
{ "type": "rectangle", "id": "b1", "x": 100, "y": 100, "width": 200, "height": 100, "roundness": { "type": 3 }, "backgroundColor": "#a5d8ff", "fillStyle": "solid", "boundElements": [{ "id": "t_b1", "type": "text" }, { "id": "a1", "type": "arrow" }] },
|
||||
{ "type": "text", "id": "t_b1", "x": 105, "y": 130, "width": 190, "height": 25, "text": "Start", "fontSize": 20, "fontFamily": 1, "strokeColor": "#1e1e1e", "textAlign": "center", "verticalAlign": "middle", "containerId": "b1", "originalText": "Start", "autoResize": true },
|
||||
{ "type": "rectangle", "id": "b2", "x": 450, "y": 100, "width": 200, "height": 100, "roundness": { "type": 3 }, "backgroundColor": "#b2f2bb", "fillStyle": "solid", "boundElements": [{ "id": "t_b2", "type": "text" }, { "id": "a1", "type": "arrow" }] },
|
||||
{ "type": "text", "id": "t_b2", "x": 455, "y": 130, "width": 190, "height": 25, "text": "End", "fontSize": 20, "fontFamily": 1, "strokeColor": "#1e1e1e", "textAlign": "center", "verticalAlign": "middle", "containerId": "b2", "originalText": "End", "autoResize": true },
|
||||
{ "type": "arrow", "id": "a1", "x": 300, "y": 150, "width": 150, "height": 0, "points": [[0,0],[150,0]], "endArrowhead": "arrow", "startBinding": { "elementId": "b1", "fixedPoint": [1, 0.5] }, "endBinding": { "elementId": "b2", "fixedPoint": [0, 0.5] } }
|
||||
]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Example 2: Photosynthesis Process Diagram
|
||||
|
||||
A larger diagram with background zones, multiple nodes, and directional arrows showing inputs/outputs.
|
||||
|
||||
```json
|
||||
[
|
||||
{"type":"text","id":"ti","x":280,"y":10,"text":"Photosynthesis","fontSize":28,"fontFamily":1,"strokeColor":"#1e1e1e","originalText":"Photosynthesis","autoResize":true},
|
||||
{"type":"text","id":"fo","x":245,"y":48,"text":"6CO2 + 6H2O --> C6H12O6 + 6O2","fontSize":16,"fontFamily":1,"strokeColor":"#757575","originalText":"6CO2 + 6H2O --> C6H12O6 + 6O2","autoResize":true},
|
||||
{"type":"rectangle","id":"lf","x":150,"y":90,"width":520,"height":380,"backgroundColor":"#d3f9d8","fillStyle":"solid","roundness":{"type":3},"strokeColor":"#22c55e","strokeWidth":1,"opacity":35},
|
||||
{"type":"text","id":"lfl","x":170,"y":96,"text":"Inside the Leaf","fontSize":16,"fontFamily":1,"strokeColor":"#15803d","originalText":"Inside the Leaf","autoResize":true},
|
||||
|
||||
{"type":"rectangle","id":"lr","x":190,"y":190,"width":160,"height":70,"backgroundColor":"#fff3bf","fillStyle":"solid","roundness":{"type":3},"strokeColor":"#f59e0b","boundElements":[{"id":"t_lr","type":"text"},{"id":"a1","type":"arrow"},{"id":"a2","type":"arrow"},{"id":"a3","type":"arrow"},{"id":"a5","type":"arrow"}]},
|
||||
{"type":"text","id":"t_lr","x":195,"y":205,"width":150,"height":20,"text":"Light Reactions","fontSize":16,"fontFamily":1,"strokeColor":"#1e1e1e","textAlign":"center","verticalAlign":"middle","containerId":"lr","originalText":"Light Reactions","autoResize":true},
|
||||
|
||||
{"type":"arrow","id":"a1","x":350,"y":225,"width":120,"height":0,"points":[[0,0],[120,0]],"strokeColor":"#1e1e1e","strokeWidth":2,"endArrowhead":"arrow","boundElements":[{"id":"t_a1","type":"text"}]},
|
||||
{"type":"text","id":"t_a1","x":390,"y":205,"width":40,"height":20,"text":"ATP","fontSize":14,"fontFamily":1,"strokeColor":"#1e1e1e","textAlign":"center","verticalAlign":"middle","containerId":"a1","originalText":"ATP","autoResize":true},
|
||||
|
||||
{"type":"rectangle","id":"cc","x":470,"y":190,"width":160,"height":70,"backgroundColor":"#d0bfff","fillStyle":"solid","roundness":{"type":3},"strokeColor":"#8b5cf6","boundElements":[{"id":"t_cc","type":"text"},{"id":"a1","type":"arrow"},{"id":"a4","type":"arrow"},{"id":"a6","type":"arrow"}]},
|
||||
{"type":"text","id":"t_cc","x":475,"y":205,"width":150,"height":20,"text":"Calvin Cycle","fontSize":16,"fontFamily":1,"strokeColor":"#1e1e1e","textAlign":"center","verticalAlign":"middle","containerId":"cc","originalText":"Calvin Cycle","autoResize":true},
|
||||
|
||||
{"type":"rectangle","id":"sl","x":10,"y":200,"width":120,"height":50,"backgroundColor":"#fff3bf","fillStyle":"solid","roundness":{"type":3},"strokeColor":"#f59e0b","boundElements":[{"id":"t_sl","type":"text"},{"id":"a2","type":"arrow"}]},
|
||||
{"type":"text","id":"t_sl","x":15,"y":210,"width":110,"height":20,"text":"Sunlight","fontSize":16,"fontFamily":1,"strokeColor":"#1e1e1e","textAlign":"center","verticalAlign":"middle","containerId":"sl","originalText":"Sunlight","autoResize":true},
|
||||
|
||||
{"type":"arrow","id":"a2","x":130,"y":225,"width":60,"height":0,"points":[[0,0],[60,0]],"strokeColor":"#f59e0b","strokeWidth":2,"endArrowhead":"arrow"},
|
||||
|
||||
{"type":"rectangle","id":"wa","x":200,"y":360,"width":140,"height":50,"backgroundColor":"#a5d8ff","fillStyle":"solid","roundness":{"type":3},"strokeColor":"#4a9eed","boundElements":[{"id":"t_wa","type":"text"},{"id":"a3","type":"arrow"}]},
|
||||
{"type":"text","id":"t_wa","x":205,"y":370,"width":130,"height":20,"text":"Water (H2O)","fontSize":16,"fontFamily":1,"strokeColor":"#1e1e1e","textAlign":"center","verticalAlign":"middle","containerId":"wa","originalText":"Water (H2O)","autoResize":true},
|
||||
|
||||
{"type":"arrow","id":"a3","x":270,"y":360,"width":0,"height":-100,"points":[[0,0],[0,-100]],"strokeColor":"#4a9eed","strokeWidth":2,"endArrowhead":"arrow"},
|
||||
|
||||
{"type":"rectangle","id":"co","x":480,"y":360,"width":130,"height":50,"backgroundColor":"#ffd8a8","fillStyle":"solid","roundness":{"type":3},"strokeColor":"#f59e0b","boundElements":[{"id":"t_co","type":"text"},{"id":"a4","type":"arrow"}]},
|
||||
{"type":"text","id":"t_co","x":485,"y":370,"width":120,"height":20,"text":"CO2","fontSize":16,"fontFamily":1,"strokeColor":"#1e1e1e","textAlign":"center","verticalAlign":"middle","containerId":"co","originalText":"CO2","autoResize":true},
|
||||
|
||||
{"type":"arrow","id":"a4","x":545,"y":360,"width":0,"height":-100,"points":[[0,0],[0,-100]],"strokeColor":"#f59e0b","strokeWidth":2,"endArrowhead":"arrow"},
|
||||
|
||||
{"type":"rectangle","id":"ox","x":540,"y":100,"width":100,"height":40,"backgroundColor":"#ffc9c9","fillStyle":"solid","roundness":{"type":3},"strokeColor":"#ef4444","boundElements":[{"id":"t_ox","type":"text"},{"id":"a5","type":"arrow"}]},
|
||||
{"type":"text","id":"t_ox","x":545,"y":105,"width":90,"height":20,"text":"O2","fontSize":16,"fontFamily":1,"strokeColor":"#1e1e1e","textAlign":"center","verticalAlign":"middle","containerId":"ox","originalText":"O2","autoResize":true},
|
||||
|
||||
{"type":"arrow","id":"a5","x":310,"y":190,"width":230,"height":-50,"points":[[0,0],[230,-50]],"strokeColor":"#ef4444","strokeWidth":2,"endArrowhead":"arrow"},
|
||||
|
||||
{"type":"rectangle","id":"gl","x":690,"y":195,"width":120,"height":60,"backgroundColor":"#c3fae8","fillStyle":"solid","roundness":{"type":3},"strokeColor":"#22c55e","boundElements":[{"id":"t_gl","type":"text"},{"id":"a6","type":"arrow"}]},
|
||||
{"type":"text","id":"t_gl","x":695,"y":210,"width":110,"height":25,"text":"Glucose","fontSize":18,"fontFamily":1,"strokeColor":"#1e1e1e","textAlign":"center","verticalAlign":"middle","containerId":"gl","originalText":"Glucose","autoResize":true},
|
||||
|
||||
{"type":"arrow","id":"a6","x":630,"y":225,"width":60,"height":0,"points":[[0,0],[60,0]],"strokeColor":"#22c55e","strokeWidth":2,"endArrowhead":"arrow"},
|
||||
|
||||
{"type":"ellipse","id":"sun","x":30,"y":110,"width":50,"height":50,"backgroundColor":"#fff3bf","fillStyle":"solid","strokeColor":"#f59e0b","strokeWidth":2},
|
||||
{"type":"arrow","id":"r1","x":55,"y":108,"width":0,"height":-14,"points":[[0,0],[0,-14]],"strokeColor":"#f59e0b","strokeWidth":2,"endArrowhead":null,"startArrowhead":null},
|
||||
{"type":"arrow","id":"r2","x":55,"y":162,"width":0,"height":14,"points":[[0,0],[0,14]],"strokeColor":"#f59e0b","strokeWidth":2,"endArrowhead":null,"startArrowhead":null},
|
||||
{"type":"arrow","id":"r3","x":28,"y":135,"width":-14,"height":0,"points":[[0,0],[-14,0]],"strokeColor":"#f59e0b","strokeWidth":2,"endArrowhead":null,"startArrowhead":null},
|
||||
{"type":"arrow","id":"r4","x":82,"y":135,"width":14,"height":0,"points":[[0,0],[14,0]],"strokeColor":"#f59e0b","strokeWidth":2,"endArrowhead":null,"startArrowhead":null}
|
||||
]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Example 3: Sequence Diagram (UML-style)
|
||||
|
||||
Demonstrates a sequence diagram with actors, dashed lifelines, and message arrows.
|
||||
|
||||
```json
|
||||
[
|
||||
{"type":"text","id":"title","x":200,"y":15,"text":"MCP Apps -- Sequence Flow","fontSize":24,"fontFamily":1,"strokeColor":"#1e1e1e","originalText":"MCP Apps -- Sequence Flow","autoResize":true},
|
||||
|
||||
{"type":"rectangle","id":"uHead","x":60,"y":60,"width":100,"height":40,"backgroundColor":"#a5d8ff","fillStyle":"solid","roundness":{"type":3},"strokeColor":"#4a9eed","strokeWidth":2,"boundElements":[{"id":"t_uHead","type":"text"}]},
|
||||
{"type":"text","id":"t_uHead","x":65,"y":65,"width":90,"height":20,"text":"User","fontSize":16,"fontFamily":1,"strokeColor":"#1e1e1e","textAlign":"center","verticalAlign":"middle","containerId":"uHead","originalText":"User","autoResize":true},
|
||||
|
||||
{"type":"arrow","id":"uLine","x":110,"y":100,"width":0,"height":400,"points":[[0,0],[0,400]],"strokeColor":"#b0b0b0","strokeWidth":1,"strokeStyle":"dashed","endArrowhead":null},
|
||||
|
||||
{"type":"rectangle","id":"aHead","x":230,"y":60,"width":100,"height":40,"backgroundColor":"#d0bfff","fillStyle":"solid","roundness":{"type":3},"strokeColor":"#8b5cf6","strokeWidth":2,"boundElements":[{"id":"t_aHead","type":"text"}]},
|
||||
{"type":"text","id":"t_aHead","x":235,"y":65,"width":90,"height":20,"text":"Agent","fontSize":16,"fontFamily":1,"strokeColor":"#1e1e1e","textAlign":"center","verticalAlign":"middle","containerId":"aHead","originalText":"Agent","autoResize":true},
|
||||
|
||||
{"type":"arrow","id":"aLine","x":280,"y":100,"width":0,"height":400,"points":[[0,0],[0,400]],"strokeColor":"#b0b0b0","strokeWidth":1,"strokeStyle":"dashed","endArrowhead":null},
|
||||
|
||||
{"type":"rectangle","id":"sHead","x":420,"y":60,"width":130,"height":40,"backgroundColor":"#ffd8a8","fillStyle":"solid","roundness":{"type":3},"strokeColor":"#f59e0b","strokeWidth":2,"boundElements":[{"id":"t_sHead","type":"text"}]},
|
||||
{"type":"text","id":"t_sHead","x":425,"y":65,"width":120,"height":20,"text":"Server","fontSize":16,"fontFamily":1,"strokeColor":"#1e1e1e","textAlign":"center","verticalAlign":"middle","containerId":"sHead","originalText":"Server","autoResize":true},
|
||||
|
||||
{"type":"arrow","id":"sLine","x":485,"y":100,"width":0,"height":400,"points":[[0,0],[0,400]],"strokeColor":"#b0b0b0","strokeWidth":1,"strokeStyle":"dashed","endArrowhead":null},
|
||||
|
||||
{"type":"arrow","id":"m1","x":110,"y":150,"width":170,"height":0,"points":[[0,0],[170,0]],"strokeColor":"#1e1e1e","strokeWidth":2,"endArrowhead":"arrow","boundElements":[{"id":"t_m1","type":"text"}]},
|
||||
{"type":"text","id":"t_m1","x":165,"y":130,"width":60,"height":20,"text":"request","fontSize":14,"fontFamily":1,"strokeColor":"#1e1e1e","textAlign":"center","verticalAlign":"middle","containerId":"m1","originalText":"request","autoResize":true},
|
||||
|
||||
{"type":"arrow","id":"m2","x":280,"y":200,"width":205,"height":0,"points":[[0,0],[205,0]],"strokeColor":"#8b5cf6","strokeWidth":2,"endArrowhead":"arrow","boundElements":[{"id":"t_m2","type":"text"}]},
|
||||
{"type":"text","id":"t_m2","x":352,"y":180,"width":60,"height":20,"text":"tools/call","fontSize":14,"fontFamily":1,"strokeColor":"#1e1e1e","textAlign":"center","verticalAlign":"middle","containerId":"m2","originalText":"tools/call","autoResize":true},
|
||||
|
||||
{"type":"arrow","id":"m3","x":485,"y":260,"width":-205,"height":0,"points":[[0,0],[-205,0]],"strokeColor":"#f59e0b","strokeWidth":2,"endArrowhead":"arrow","strokeStyle":"dashed","boundElements":[{"id":"t_m3","type":"text"}]},
|
||||
{"type":"text","id":"t_m3","x":352,"y":240,"width":60,"height":20,"text":"result","fontSize":14,"fontFamily":1,"strokeColor":"#1e1e1e","textAlign":"center","verticalAlign":"middle","containerId":"m3","originalText":"result","autoResize":true},
|
||||
|
||||
{"type":"arrow","id":"m4","x":280,"y":320,"width":-170,"height":0,"points":[[0,0],[-170,0]],"strokeColor":"#8b5cf6","strokeWidth":2,"endArrowhead":"arrow","strokeStyle":"dashed","boundElements":[{"id":"t_m4","type":"text"}]},
|
||||
{"type":"text","id":"t_m4","x":165,"y":300,"width":60,"height":20,"text":"response","fontSize":14,"fontFamily":1,"strokeColor":"#1e1e1e","textAlign":"center","verticalAlign":"middle","containerId":"m4","originalText":"response","autoResize":true}
|
||||
]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Common Mistakes to Avoid
|
||||
|
||||
- **Do NOT use `"label"` property** -- this is the #1 mistake. It is NOT part of the Excalidraw file format and will be silently ignored, producing blank shapes with no visible text. Always use container binding (`containerId` + `boundElements`) as shown in the examples above.
|
||||
- **Every bound text needs both sides linked** -- the shape needs `boundElements: [{"id": "t_xxx", "type": "text"}]` AND the text needs `containerId: "shape_id"`. If either is missing, the binding won't work.
|
||||
- **Include `originalText` and `autoResize: true`** on all text elements -- Excalidraw uses these for proper text reflow.
|
||||
- **Include `fontFamily: 1`** on all text elements -- without it, text may not render with the expected hand-drawn font.
|
||||
- **Elements overlap when y-coordinates are close** -- always check that text, boxes, and labels don't stack on top of each other
|
||||
- **Arrow labels need space** -- long labels like "ATP + NADPH" overflow short arrows. Keep labels short or make arrows wider
|
||||
- **Center titles relative to the diagram** -- estimate total width and center the title text over it
|
||||
- **Draw decorations LAST** -- cute illustrations (sun, stars, icons) should appear at the end of the array so they're drawn on top
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user