Compare commits
34 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 735723803f | |||
| 1472cc302d | |||
| 9c200abdb1 | |||
| 9dc27880cd | |||
| 3b9c53e6db | |||
| 05dd31131f | |||
| 36ea883d45 | |||
| 6be8cdeeca | |||
| 0bc914b00c | |||
| 411e7f8ff4 | |||
| eb2e6b73fe | |||
| 664acf7426 | |||
| fd1c3da305 | |||
| 4d619bcd21 | |||
| beac2ee06a | |||
| 487487406d | |||
| 87464821d8 | |||
| 661d8f4d6c | |||
| bf13a848ef | |||
| 88286f6da3 | |||
| 5b82190460 | |||
| ea7aa0b0d4 | |||
| 7130fa50cb | |||
| 5a9c98a771 | |||
| 6cb4fe948a | |||
| 30221d8c20 | |||
| b5b1fef20a | |||
| 16fb41f9cc | |||
| 4939130485 | |||
| 8dccd6569e | |||
| db348dc467 | |||
| 88722e230d | |||
| 68fb0efe0e | |||
| e38c274f8d |
+115
@@ -0,0 +1,115 @@
|
||||
# Cline's Memory Bank
|
||||
|
||||
I am Cline, an expert software engineer with a unique characteristic: my memory resets completely between sessions. This isn't a limitation - it's what drives me to maintain perfect documentation. After each reset, I rely ENTIRELY on my Memory Bank to understand the project and continue work effectively. I MUST read ALL memory bank files at the start of EVERY task - this is not optional.
|
||||
|
||||
## Memory Bank Structure
|
||||
|
||||
The Memory Bank consists of core files and optional context files, all in Markdown format. Files build upon each other in a clear hierarchy:
|
||||
|
||||
flowchart TD
|
||||
PB[projectbrief.md] --> PC[productContext.md]
|
||||
PB --> SP[systemPatterns.md]
|
||||
PB --> TC[techContext.md]
|
||||
|
||||
PC --> AC[activeContext.md]
|
||||
SP --> AC
|
||||
TC --> AC
|
||||
|
||||
AC --> P[progress.md]
|
||||
|
||||
### Core Files (Required)
|
||||
1. `projectbrief.md`
|
||||
- Foundation document that shapes all other files
|
||||
- Created at project start if it doesn't exist
|
||||
- Defines core requirements and goals
|
||||
- Source of truth for project scope
|
||||
|
||||
2. `productContext.md`
|
||||
- Why this project exists
|
||||
- Problems it solves
|
||||
- How it should work
|
||||
- User experience goals
|
||||
|
||||
3. `activeContext.md`
|
||||
- Current work focus
|
||||
- Recent changes
|
||||
- Next steps
|
||||
- Active decisions and considerations
|
||||
- Important patterns and preferences
|
||||
- Learnings and project insights
|
||||
|
||||
4. `systemPatterns.md`
|
||||
- System architecture
|
||||
- Key technical decisions
|
||||
- Design patterns in use
|
||||
- Component relationships
|
||||
- Critical implementation paths
|
||||
|
||||
5. `techContext.md`
|
||||
- Technologies used
|
||||
- Development setup
|
||||
- Technical constraints
|
||||
- Dependencies
|
||||
- Tool usage patterns
|
||||
|
||||
6. `progress.md`
|
||||
- What works
|
||||
- What's left to build
|
||||
- Current status
|
||||
- Known issues
|
||||
- Evolution of project decisions
|
||||
|
||||
### Additional Context
|
||||
Create additional files/folders within memory-bank/ when they help organize:
|
||||
- Complex feature documentation
|
||||
- Integration specifications
|
||||
- API documentation
|
||||
- Testing strategies
|
||||
- Deployment procedures
|
||||
|
||||
## Core Workflows
|
||||
|
||||
### Plan Mode
|
||||
flowchart TD
|
||||
Start[Start] --> ReadFiles[Read Memory Bank]
|
||||
ReadFiles --> CheckFiles{Files Complete?}
|
||||
|
||||
CheckFiles -->|No| Plan[Create Plan]
|
||||
Plan --> Document[Document in Chat]
|
||||
|
||||
CheckFiles -->|Yes| Verify[Verify Context]
|
||||
Verify --> Strategy[Develop Strategy]
|
||||
Strategy --> Present[Present Approach]
|
||||
|
||||
### Act Mode
|
||||
flowchart TD
|
||||
Start[Start] --> Context[Check Memory Bank]
|
||||
Context --> Update[Update Documentation]
|
||||
Update --> Execute[Execute Task]
|
||||
Execute --> Document[Document Changes]
|
||||
|
||||
## Documentation Updates
|
||||
|
||||
Memory Bank updates occur when:
|
||||
1. Discovering new project patterns
|
||||
2. After implementing significant changes
|
||||
3. When user requests with **update memory bank** (MUST review ALL files)
|
||||
4. When context needs clarification
|
||||
|
||||
flowchart TD
|
||||
Start[Update Process]
|
||||
|
||||
subgraph Process
|
||||
P1[Review ALL Files]
|
||||
P2[Document Current State]
|
||||
P3[Clarify Next Steps]
|
||||
P4[Document Insights & Patterns]
|
||||
|
||||
P1 --> P2 --> P3 --> P4
|
||||
end
|
||||
|
||||
Start --> Process
|
||||
|
||||
Note: When triggered by **update memory bank**, I MUST review every memory bank file, even if some don't require updates. Focus particularly on activeContext.md and progress.md as they track current state.
|
||||
|
||||
REMEMBER: After every memory reset, I begin completely fresh. The Memory Bank is my only link to previous work. It must be maintained with precision and clarity, as my effectiveness depends entirely on its accuracy.
|
||||
+139
-9
@@ -1,12 +1,68 @@
|
||||
# Hermes Agent Environment Configuration
|
||||
# Copy this file to .env and fill in your API keys
|
||||
|
||||
# =============================================================================
|
||||
# CORE SETTINGS
|
||||
# =============================================================================
|
||||
# Agent backend:
|
||||
# - openai : default Hermes-Agent loop (OpenAI function-calling via OpenAI SDK)
|
||||
# - atropos : Atroposlib ServerManager/ManagedServer-backed loop (training/env integration)
|
||||
HERMES_BACKEND=openai
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LOCAL / SELF-HOSTED OPENAI-COMPATIBLE ENDPOINTS (vLLM, SGLang, llama.cpp, etc.)
|
||||
# =============================================================================
|
||||
# For local development (matches the Atropos test env defaults):
|
||||
# ATROPOS_SERVER_BASE_URL=http://127.0.0.1:8080
|
||||
# ATROPOS_SERVER_MODEL=hermes-4-36b
|
||||
# For hosted inference (Nous Research inference API):
|
||||
ATROPOS_SERVER_BASE_URL=
|
||||
ATROPOS_SERVER_MODEL=
|
||||
ATROPOS_TOKENIZER_NAME=
|
||||
# Set this to your Nous API key (Bearer token).
|
||||
ATROPOS_SERVER_API_KEY=
|
||||
|
||||
# Debugging (prints to stdout; use with care)
|
||||
# HERMES_DEBUG_ATROPOS_REQUEST=1
|
||||
# HERMES_DEBUG_ATROPOS_RESPONSE=1
|
||||
# HERMES_DEBUG_OPENAI_REQUEST=1
|
||||
# HERMES_DEBUG_OPENAI_RESPONSE=1
|
||||
|
||||
# =============================================================================
|
||||
# LOCAL / SELF-HOSTED OPENAI-COMPATIBLE ENDPOINTS (vLLM, SGLang, llama.cpp, etc.)
|
||||
# =============================================================================
|
||||
# If you set ATROPOS_SERVER_BASE_URL or OPENAI_BASE_URL, Hermes will use it instead
|
||||
# of OpenRouter.
|
||||
#
|
||||
# Local server convenience (base URL without /v1):
|
||||
# llama.cpp example (see `Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh`):
|
||||
# ATROPOS_SERVER_BASE_URL=http://127.0.0.1:8080
|
||||
# ATROPOS_SERVER_MODEL=hermes-4-36b
|
||||
# ATROPOS_TOKENIZER_NAME=NousResearch/Hermes-4.3-36B
|
||||
# ATROPOS_SERVER_API_KEY=local
|
||||
#
|
||||
# Hosted Nous inference API:
|
||||
# ATROPOS_SERVER_BASE_URL=https://inference-api.nousresearch.com
|
||||
# ATROPOS_SERVER_MODEL=Hermes-4.3-36B
|
||||
# ATROPOS_TOKENIZER_NAME=NousResearch/Hermes-4.3-36B
|
||||
# ATROPOS_SERVER_API_KEY=sk-... (Bearer token)
|
||||
#
|
||||
# If you plan to run GRPO-style group sampling (e.g. `--env.group_size 4`) against
|
||||
# llama.cpp, start the server with at least that many slots, e.g.:
|
||||
# LLAMA_CPP_PARALLEL=4 Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh
|
||||
#
|
||||
# Generic OpenAI-compatible (base URL should include /v1):
|
||||
# OPENAI_BASE_URL=http://127.0.0.1:8080/v1
|
||||
# OPENAI_API_KEY=local
|
||||
|
||||
# =============================================================================
|
||||
# LLM PROVIDER (OpenRouter)
|
||||
# =============================================================================
|
||||
# OpenRouter provides access to many models through one API
|
||||
# All LLM calls go through OpenRouter - no direct provider keys needed
|
||||
# Get your key at: https://openrouter.ai/keys
|
||||
OPENROUTER_BASE_URL=https://openrouter.ai/api/v1
|
||||
OPENROUTER_API_KEY=
|
||||
|
||||
# Default model to use (OpenRouter format: provider/model)
|
||||
@@ -42,10 +98,9 @@ TERMINAL_ENV=local
|
||||
|
||||
|
||||
# Container images (for singularity/docker/modal backends)
|
||||
TERMINAL_DOCKER_IMAGE=nikolaik/python-nodejs:python3.11-nodejs20
|
||||
TERMINAL_SINGULARITY_IMAGE=docker://nikolaik/python-nodejs:python3.11-nodejs20
|
||||
TERMINAL_MODAL_IMAGE=nikolaik/python-nodejs:python3.11-nodejs20
|
||||
|
||||
TERMINAL_DOCKER_IMAGE=python:3.11
|
||||
TERMINAL_SINGULARITY_IMAGE=docker://python:3.11
|
||||
TERMINAL_MODAL_IMAGE=python:3.11
|
||||
|
||||
# Working directory for terminal commands
|
||||
# For CLI: "." means current directory (resolved automatically from config.yaml)
|
||||
@@ -93,12 +148,87 @@ TERMINAL_LIFETIME_SECONDS=300
|
||||
# SUDO_PASSWORD=your_password_here
|
||||
|
||||
# =============================================================================
|
||||
# MODAL CLOUD BACKEND (Optional - for TERMINAL_ENV=modal)
|
||||
# MODAL CLOUD BACKEND (for TERMINAL_ENV=modal)
|
||||
# =============================================================================
|
||||
# Modal uses CLI authentication, not environment variables.
|
||||
# Run: pip install modal && modal setup
|
||||
# This will authenticate via browser and store credentials locally.
|
||||
# No API key needed in .env - Modal handles auth automatically.
|
||||
# Modal provides cloud sandboxes with per-second billing and auto-scaling.
|
||||
# This implementation uses a warm pool of sandboxes for cost efficiency.
|
||||
#
|
||||
# SETUP:
|
||||
# pip install modal && modal setup
|
||||
# (Authenticates via browser, stores credentials locally)
|
||||
#
|
||||
# FEATURES:
|
||||
# - Auto-scaling warm sandbox pool (no cold start after first use)
|
||||
# - Named sandbox recovery (reconnects after restart)
|
||||
# - Profile-based heterogeneous environments (CPU, GPU, different images)
|
||||
# - Server-side idle_timeout protection against orphaned sandboxes
|
||||
|
||||
# Modal app name (groups all sandboxes, used for recovery)
|
||||
TERMINAL_MODAL_APP_NAME=hermes-sandbox
|
||||
|
||||
# Default profile when none specified
|
||||
TERMINAL_MODAL_DEFAULT_PROFILE=default
|
||||
|
||||
# Profile config file (optional - YAML format, see modal_profiles.yaml)
|
||||
# TERMINAL_MODAL_PROFILES_FILE=modal_profiles.yaml
|
||||
|
||||
# --- Default Profile Settings (used if no YAML file) ---
|
||||
# These apply when no profile is specified or for the "default" profile
|
||||
TERMINAL_MODAL_IMAGE=python:3.11
|
||||
TERMINAL_MODAL_MIN_POOL=1
|
||||
TERMINAL_MODAL_MAX_POOL=5
|
||||
TERMINAL_MODAL_IDLE_TIMEOUT=120
|
||||
TERMINAL_MODAL_MAX_LIFETIME=3600
|
||||
TERMINAL_MODAL_SCALE_DOWN_IDLE=180
|
||||
|
||||
# --- Custom Profile Example: pytorch-gpu ---
|
||||
# Uncomment to enable a GPU profile for ML tasks
|
||||
# Usage: terminal_tool("python train.py", profile="pytorch-gpu")
|
||||
#
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_IMAGE=pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_GPU=T4
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_MEMORY=16384
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_MIN_POOL=0
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_MAX_POOL=2
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_IDLE_TIMEOUT=60
|
||||
|
||||
# --- Custom Profile Example: node ---
|
||||
# Uncomment to enable a Node.js profile
|
||||
# Usage: terminal_tool("npm test", profile="node")
|
||||
#
|
||||
# TERMINAL_MODAL_PROFILE_node_IMAGE=node:18
|
||||
# TERMINAL_MODAL_PROFILE_node_MIN_POOL=0
|
||||
# TERMINAL_MODAL_PROFILE_node_MAX_POOL=3
|
||||
|
||||
# =============================================================================
|
||||
# MODAL SECRETS (Secure credential injection)
|
||||
# =============================================================================
|
||||
# Modal Secrets allow you to securely pass API keys, passwords, and other
|
||||
# sensitive data to your sandboxes without exposing them in code or logs.
|
||||
#
|
||||
# SETUP SECRETS:
|
||||
# 1. Via Dashboard: https://modal.com/secrets
|
||||
# 2. Via CLI: modal secret create my-secret KEY1=value1 KEY2=value2
|
||||
# 3. Via CLI with env: modal secret create my-secret API_KEY="$API_KEY"
|
||||
#
|
||||
# LIST SECRETS:
|
||||
# modal secret list
|
||||
#
|
||||
# DELETE SECRETS:
|
||||
# modal secret delete my-secret
|
||||
|
||||
# Global secrets applied to ALL profiles (comma-separated secret names)
|
||||
# These secrets must be created on Modal dashboard or via CLI first
|
||||
# TERMINAL_MODAL_SECRETS=my-api-keys,database-creds
|
||||
|
||||
# Per-profile secrets (comma-separated secret names)
|
||||
# TERMINAL_MODAL_PROFILE_pytorch_gpu_SECRETS=huggingface-token,wandb-key
|
||||
|
||||
# Per-profile environment variables (semicolon-separated KEY=VALUE pairs)
|
||||
# TERMINAL_MODAL_PROFILE_default_ENV_VARS=DEBUG=1;LOG_LEVEL=info
|
||||
|
||||
# Load local .env file into sandbox (useful for development)
|
||||
# TERMINAL_MODAL_PROFILE_default_USE_DOTENV=true
|
||||
|
||||
# =============================================================================
|
||||
# BROWSER TOOL CONFIGURATION (agent-browser + Browserbase)
|
||||
|
||||
+20
@@ -46,3 +46,23 @@ testlogs
|
||||
|
||||
# CLI config (may contain sensitive SSH paths)
|
||||
cli-config.yaml
|
||||
|
||||
.DS_Store
|
||||
|
||||
# artifacts
|
||||
*.jsonl
|
||||
*.html
|
||||
*.json
|
||||
*.log
|
||||
*.csv
|
||||
|
||||
# Singularity/Apptainer images (large binary files)
|
||||
*.sif
|
||||
|
||||
# Test files
|
||||
test_singularity_*.py
|
||||
test_*.py
|
||||
!tests/test_*.py
|
||||
|
||||
# Nomad data
|
||||
/tmp/NomadClient*/
|
||||
|
||||
@@ -1,142 +0,0 @@
|
||||
# Project Notes
|
||||
|
||||
*Maintained by Hermes — last updated June 2025*
|
||||
|
||||
---
|
||||
|
||||
## 1. Kandinsky (Multimodal Transformer)
|
||||
- **Repo:** https://github.com/samherring99/kandinsky
|
||||
- **Local path:** `~/Desktop/Projects/kandinsky`
|
||||
- **Description:** An anything-to-anything transformer combining text, image, and audio modalities. Trains on Pokemon BLIP captions paired with Gen 1 Pokemon audio cries. Uses audio tokenization adapted from nanoGPT.
|
||||
- **Status:** Early POC. Training code exists (`model.py`) and dataset creation (`create_dataset.py`) works. Audio heads are producing the same sound — unclear if it's a training issue or data issue.
|
||||
- **TODO:**
|
||||
- Debug why audio heads produce identical output
|
||||
- Investigate if model needs more training time
|
||||
- Design a data pipeline for better/more training data
|
||||
- General repo cleanup (requirements.txt, proper CLI, etc.)
|
||||
|
||||
---
|
||||
|
||||
## 2. NightwingGameSim (LLM → GameBoy ROM Generator)
|
||||
- **Repo:** https://github.com/samherring99/NightwingGameSim
|
||||
- **Local path:** `~/Desktop/Projects/NightwingGameSim`
|
||||
- **Description:** AI-powered pipeline that turns natural language prompts into playable GameBoy ROM files. Generates C code, compiles with GBDK, outputs `.gb` files. Supports Claude API, local Llama, and RAG backends.
|
||||
- **Status:** Functional — generation pipeline works end-to-end with Claude 4 system prompt. Has tests, docs, examples, and retry logic.
|
||||
- **TODO:**
|
||||
- Harden the repo, clean up structure
|
||||
- Build a better testing pipeline
|
||||
- Come up with better prompt ideas / examples
|
||||
|
||||
---
|
||||
|
||||
## 3. ContentBasedMIR (Music Information Retrieval)
|
||||
- **Repo:** https://github.com/samherring99/ContentBasedMIR
|
||||
- **Local path:** `~/Desktop/Projects/ContentBasedMIR`
|
||||
- **Description:** Music similarity analysis using Spotify API track data. Extracts 54 audio features per song and visualizes similarity matrices for music recommendation.
|
||||
- **Status:** Early stage. Can download Spotify track analysis data and plot similarity matrices. Needs significant expansion.
|
||||
- **TODO:**
|
||||
- Expand analysis pipeline with more features
|
||||
- Integrate with text message data for personalized recommendations
|
||||
- Build out visualization and exploration tools
|
||||
- General modernization (dependencies, structure)
|
||||
|
||||
---
|
||||
|
||||
## 4. MessageRetrieval (iMessage RAG/SQL)
|
||||
- **Repo:** https://github.com/samherring99/MessageRetrieval
|
||||
- **Local path:** `~/Desktop/Projects/MessageRetrieval`
|
||||
- **Description:** Natural language querying over iMessage data using SQL generation (text2SQL) instead of vector embeddings. Uses LLM-as-Judge pattern for scoring and ranking retrieved messages.
|
||||
- **Status:** Has initial text2SQL pipeline and summarization tool. Recently worked on with Claude Code. Needs testing.
|
||||
- **TODO:**
|
||||
- Test out the recent Claude Code work
|
||||
- Build "iMessage Jarvis" — answer questions about texts
|
||||
- Improve SQL generation prompts and accuracy
|
||||
- Better error handling and UX
|
||||
|
||||
---
|
||||
|
||||
## 5. Grailed Embedding Search
|
||||
- **Repo:** https://github.com/samherring99/grailed-embedding-search
|
||||
- **Local path:** `~/Desktop/Projects/grailed-embedding-search`
|
||||
- **Description:** Semantic similarity search over Grailed fashion listings using CLIP embeddings and FAISS. Search by image URL or text description to find visually similar products.
|
||||
- **Status:** Functional core pipeline. CLIP ViT-B/32 embeds product cover photos into 512-dim vectors, indexed with FAISS cosine similarity. Has CLI, batch embedding, persistent index save/load, and logging.
|
||||
- **Recent work (June 2025):**
|
||||
- PR #1 — Initial cleanup: docstrings, type hints, `.gitignore`, `requirements.txt`, README rewrite
|
||||
- PR #2 — Feature improvements: persistent FAISS save/load, batch embedding, CLI (`cli.py`), proper logging throughout, lazy Grailed client, `fetch_details` toggle
|
||||
- **TODO:**
|
||||
- Embedding cache (avoid re-embedding known product URLs)
|
||||
- Async/threaded image downloads for faster batch indexing
|
||||
- Search result visualization (matplotlib grid of cover photos)
|
||||
- Filter by category, designer, price range before search
|
||||
- Web UI (Gradio or Streamlit)
|
||||
|
||||
---
|
||||
|
||||
## 6. NightwingNBA (Sports Analytics)
|
||||
- **Repo:** https://github.com/samherring99/NightwingNBA
|
||||
- **Local path:** `~/Desktop/Projects/NightwingNBA`
|
||||
- **Description:** NBA game prediction system. Builds a database of game data, trains a PyTorch model, and makes daily predictions. Has full pipeline: build DB → write data → train → predict.
|
||||
- **Status:** Functional pipeline exists. Has database building, training, prediction, and daily update scripts.
|
||||
- **TODO:**
|
||||
- Explore and potentially revive
|
||||
- Update data sources if stale
|
||||
- Improve model accuracy
|
||||
- Add visualization/reporting
|
||||
|
||||
---
|
||||
|
||||
## 7. Stable Audio Sample Explorer
|
||||
- **Repo:** https://github.com/samherring99/stable-audio-sample-explorer
|
||||
- **Local path:** `~/Desktop/Projects/stable-audio-sample-explorer`
|
||||
- **Description:** Tool for exploring audio samples generated by Stable Audio.
|
||||
- **Status:** 🪦 **Dead** — no active work needed per Sam.
|
||||
|
||||
---
|
||||
|
||||
## 8. NightwingArt (Art Tools)
|
||||
- **Repo:** https://github.com/samherring99/NightwingArt
|
||||
- **Local path:** `~/Desktop/Projects/NightwingArt`
|
||||
- **Description:** Collection of art tooling scripts — video editing, clip splicing with beat matching, damage effects, and general image manipulation.
|
||||
- **Status:** Maintenance mode. Tools exist for various effects. Work happens as-needed.
|
||||
- **TODO:**
|
||||
- Add tools as needed for new art projects
|
||||
|
||||
---
|
||||
|
||||
## 9. Claude-based VST Building ⚠️ *Needs new repo*
|
||||
- **Description:** Generate VST audio plugins for DAWs from English language prompts. LLM-powered audio plugin creation.
|
||||
- **Status:** Concept only — no repo exists yet.
|
||||
- **TODO:**
|
||||
- Create repo
|
||||
- Research VST SDK / JUCE framework
|
||||
- Design prompt → code → compile pipeline
|
||||
|
||||
---
|
||||
|
||||
## 10. Government Auction Site Scraper ⚠️ *Needs new repo*
|
||||
- **Description:** Tool that monitors and scrapes government auction sites in San Francisco for deals.
|
||||
- **Status:** Concept only — no repo exists yet.
|
||||
- **TODO:**
|
||||
- Create repo
|
||||
- Research SF government auction sites and their structure
|
||||
- Build scraper + notification system
|
||||
|
||||
---
|
||||
|
||||
## Priority Assessment
|
||||
|
||||
| Project | Activity Level | Suggested Priority |
|
||||
|---------|---------------|-------------------|
|
||||
| NightwingGameSim | Active | 🔴 High |
|
||||
| MessageRetrieval | Active | 🔴 High |
|
||||
| Kandinsky | Active | 🟡 Medium |
|
||||
| ContentBasedMIR | Exploratory | 🟡 Medium |
|
||||
| Grailed Embedding Search | Early | 🟡 Medium |
|
||||
| NightwingNBA | Dormant | 🟢 Low |
|
||||
| NightwingArt | As-needed | 🟢 Low |
|
||||
| VST Builder | Concept | 🔵 Future |
|
||||
| Gov Auction Scraper | Concept | 🔵 Future |
|
||||
| Stable Audio Explorer | Dead | ⚫ None |
|
||||
|
||||
|
||||
|
||||
@@ -37,9 +37,8 @@ All your settings are stored in `~/.hermes/` for easy access:
|
||||
|
||||
```
|
||||
~/.hermes/
|
||||
├── config.yaml # Settings (model, terminal, TTS, compression, etc.)
|
||||
├── config.yaml # Settings (model, terminal, compression, etc.)
|
||||
├── .env # API keys and secrets
|
||||
├── SOUL.md # Optional: global persona (agent embodies this personality)
|
||||
├── cron/ # Scheduled jobs
|
||||
├── sessions/ # Gateway sessions
|
||||
└── logs/ # Logs
|
||||
@@ -77,8 +76,6 @@ You need at least one LLM provider:
|
||||
| Web scraping | [Firecrawl](https://firecrawl.dev/) | `FIRECRAWL_API_KEY` |
|
||||
| Browser automation | [Browserbase](https://browserbase.com/) | `BROWSERBASE_API_KEY`, `BROWSERBASE_PROJECT_ID` |
|
||||
| Image generation | [FAL](https://fal.ai/) | `FAL_KEY` |
|
||||
| Premium TTS voices | [ElevenLabs](https://elevenlabs.io/) | `ELEVENLABS_API_KEY` |
|
||||
| OpenAI TTS voices | [OpenAI](https://platform.openai.com/api-keys) | `OPENAI_API_KEY` |
|
||||
| RL Training | [Tinker](https://tinker-console.thinkingmachines.ai/) + [WandB](https://wandb.ai/) | `TINKER_API_KEY`, `WANDB_API_KEY` |
|
||||
| Messaging | Telegram, Discord | `TELEGRAM_BOT_TOKEN`, `DISCORD_BOT_TOKEN` |
|
||||
|
||||
@@ -131,58 +128,7 @@ hermes --toolsets "web,terminal"
|
||||
hermes --list-tools
|
||||
```
|
||||
|
||||
**Available toolsets:** `web`, `terminal`, `browser`, `vision`, `creative`, `reasoning`, `skills`, `tts`, `cronjob`, and more.
|
||||
|
||||
### 🔊 Text-to-Speech
|
||||
|
||||
Convert text to speech with three providers:
|
||||
|
||||
| Provider | Quality | Cost | API Key |
|
||||
|----------|---------|------|---------|
|
||||
| **Edge TTS** (default) | Good | Free | None needed |
|
||||
| **ElevenLabs** | Excellent | Paid | `ELEVENLABS_API_KEY` |
|
||||
| **OpenAI TTS** | Good | Paid | `OPENAI_API_KEY` |
|
||||
|
||||
On Telegram, audio plays as native voice bubbles. On Discord/WhatsApp, sent as audio files. In CLI mode, saved to `~/voice-memos/`.
|
||||
|
||||
**Configure in `~/.hermes/config.yaml`:**
|
||||
```yaml
|
||||
tts:
|
||||
provider: "edge" # "edge" | "elevenlabs" | "openai"
|
||||
edge:
|
||||
voice: "en-US-AriaNeural" # 322 voices, 74 languages
|
||||
elevenlabs:
|
||||
voice_id: "pNInz6obpgDQGcFmaJgB" # Adam
|
||||
model_id: "eleven_multilingual_v2"
|
||||
openai:
|
||||
model: "gpt-4o-mini-tts"
|
||||
voice: "alloy" # alloy, echo, fable, onyx, nova, shimmer
|
||||
```
|
||||
|
||||
> **Note:** Telegram voice bubbles require `ffmpeg` for Opus conversion (Edge TTS only outputs MP3). Install with `apt install ffmpeg` or `brew install ffmpeg`. Without ffmpeg, audio is sent as a file instead of a voice bubble.
|
||||
|
||||
### 📄 Context Files (SOUL.md, AGENTS.md, .cursorrules)
|
||||
|
||||
Drop these files in your project directory and the agent automatically picks them up:
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `AGENTS.md` | Project-specific instructions, coding conventions, tool usage guidelines |
|
||||
| `SOUL.md` | Persona definition -- the agent embodies this personality and tone |
|
||||
| `.cursorrules` | Cursor IDE rules (also detected) |
|
||||
| `.cursor/rules/*.mdc` | Cursor rule files (also detected) |
|
||||
|
||||
- **AGENTS.md** is hierarchical: if subdirectories also have `AGENTS.md`, all are combined (like Codex/Cline).
|
||||
- **SOUL.md** checks cwd first, then `~/.hermes/SOUL.md` as a global fallback.
|
||||
- All context files are capped at 20,000 characters with smart truncation.
|
||||
|
||||
### 🛡️ Exec Approval (Messaging Platforms)
|
||||
|
||||
When the agent tries to run a potentially dangerous command (rm -rf, chmod 777, etc.) on Telegram/Discord/WhatsApp, instead of blocking it silently, it asks the user for approval:
|
||||
|
||||
> ⚠️ This command is potentially dangerous (recursive delete). Reply "yes" to approve.
|
||||
|
||||
Reply "yes"/"y" to approve or "no"/"n" to deny. In CLI mode, the existing interactive approval prompt (once/session/always/deny) is preserved.
|
||||
**Available toolsets:** `web`, `terminal`, `browser`, `vision`, `creative`, `reasoning`, `skills`, `cronjob`, and more.
|
||||
|
||||
### 🖥️ Terminal Backend
|
||||
|
||||
@@ -1049,6 +995,137 @@ All variables go in `~/.hermes/.env`. Run `hermes config set VAR value` to set t
|
||||
|
||||
---
|
||||
|
||||
## RL Training with Tinker
|
||||
|
||||
Hermes-Agent includes an RL training integration with [Tinker](https://thinkingmachines.ai/tinker/) (Thinking Machines) and [Atropos](https://github.com/NousResearch/atropos) for training language models with reinforcement learning from agent trajectories.
|
||||
|
||||
### Prerequisites
|
||||
|
||||
1. **Install with Atropos extras** (includes Tinker SDK, atroposlib, torch, wandb):
|
||||
```bash
|
||||
pip install -e ".[atropos]"
|
||||
```
|
||||
|
||||
2. **Initialize the tinker-atropos submodule**:
|
||||
```bash
|
||||
git submodule update --init
|
||||
pip install -e ./tinker-atropos
|
||||
```
|
||||
|
||||
3. **Get API keys**:
|
||||
- `TINKER_API_KEY` from [Tinker Console](https://tinker-console.thinkingmachines.ai/keys) (requires billing setup)
|
||||
- `WANDB_API_KEY` from [Weights & Biases](https://wandb.ai/settings) (for metrics tracking)
|
||||
|
||||
4. **Add keys to your `.env` file**:
|
||||
```bash
|
||||
# Add to .env or ~/.hermes/.env
|
||||
TINKER_API_KEY=your_tinker_key
|
||||
WANDB_API_KEY=your_wandb_key
|
||||
```
|
||||
|
||||
### Architecture
|
||||
|
||||
The RL training pipeline uses three processes that communicate over HTTP:
|
||||
|
||||
```
|
||||
┌──────────────────────┐ ┌─────────────────────┐ ┌────────────────────────┐
|
||||
│ Atropos Rollout API │ │ Tinker Trainer │ │ Environment │
|
||||
│ (port 8000) │◄──│ (port 8001) │◄──│ (worker) │
|
||||
│ │ │ │ │ │
|
||||
│ • Collects batches │ │ • LoRA training │ │ • Generates prompts │
|
||||
│ • Coordinates env │ │ • Inference server │ │ • Calls inference API │
|
||||
│ and trainer │ │ • Weight updates │ │ • Scores responses │
|
||||
│ │ │ • WandB logging │ │ • Sends scored batches │
|
||||
└──────────────────────┘ └─────────────────────┘ └────────────────────────┘
|
||||
```
|
||||
|
||||
### Quick Start: GSM8k Agent Training
|
||||
|
||||
This example trains a model on math problems using a Python REPL tool — the model learns to write and execute Python code to solve math:
|
||||
|
||||
```bash
|
||||
# Terminal 1: Start Atropos Rollout API
|
||||
cd tinker-atropos
|
||||
source ../.venv/bin/activate
|
||||
set -a && source ../.env && set +a
|
||||
run-api
|
||||
|
||||
# Terminal 2: Start Tinker Trainer + Inference Server
|
||||
cd tinker-atropos
|
||||
source ../.venv/bin/activate
|
||||
set -a && source ../.env && set +a
|
||||
python launch_training.py --config configs/gsm8k_agent.yaml
|
||||
|
||||
# Terminal 3: Start GSM8k Agent Environment
|
||||
cd tinker-atropos
|
||||
source ../.venv/bin/activate
|
||||
set -a && source ../.env && set +a
|
||||
python tinker_atropos/environments/gsm8k_agent.py serve --config configs/gsm8k_agent.yaml
|
||||
```
|
||||
|
||||
### Available Environments
|
||||
|
||||
| Environment | File | Description |
|
||||
|------------|------|-------------|
|
||||
| `gsm8k` | `gsm8k_tinker.py` | Standard GSM8k math (no tools) |
|
||||
| `gsm8k_agent` | `gsm8k_agent.py` | GSM8k with Python REPL tool calling |
|
||||
|
||||
### Configuration
|
||||
|
||||
Configs are YAML files in `tinker-atropos/configs/` with three sections:
|
||||
|
||||
```yaml
|
||||
env: # Atropos environment settings
|
||||
group_size: 4 # Parallel rollouts per problem
|
||||
batch_size: 16 # Training batch size
|
||||
tokenizer_name: "Qwen/Qwen3-4B-Instruct-2507"
|
||||
max_token_length: 2048 # Max generation length
|
||||
total_steps: 20 # Training steps
|
||||
|
||||
openai: # Inference server (served by Tinker trainer)
|
||||
- model_name: "Qwen/Qwen3-4B-Instruct-2507"
|
||||
base_url: "http://localhost:8001/v1"
|
||||
|
||||
tinker: # Tinker training parameters
|
||||
lora_rank: 16 # LoRA rank (lower = faster, less capacity)
|
||||
learning_rate: 0.00005 # Learning rate
|
||||
max_token_trainer_length: 4096 # Max tokens for training
|
||||
wandb_project: "hermes-agent-rl"
|
||||
```
|
||||
|
||||
### RL CLI (Agent-Driven Training)
|
||||
|
||||
For interactive training management via the Hermes agent:
|
||||
|
||||
```bash
|
||||
# Interactive mode - let the agent manage training
|
||||
python rl_cli.py --interactive
|
||||
|
||||
# List available environments
|
||||
python rl_cli.py --list-environments
|
||||
|
||||
# Direct task
|
||||
python rl_cli.py "Train a model on GSM8k with tool use"
|
||||
```
|
||||
|
||||
### Sandbox Backends for Agent Environments
|
||||
|
||||
For agent environments that need isolated tool execution (e.g., SWE tasks), Hermes-Agent supports multiple sandbox backends:
|
||||
|
||||
| Backend | Use Case | Command |
|
||||
|---------|----------|---------|
|
||||
| **Nomad + Docker** | Default, local development | `--env.tool_pool_mode nomad` |
|
||||
| **Nomad + Singularity** | HPC clusters without Docker | `--env.tool_pool_mode nomad --env.driver singularity` |
|
||||
| **Modal** | Cloud-based, auto-scaling | `--env.tool_pool_mode modal` |
|
||||
|
||||
See [docs/MODAL_BACKEND.md](docs/MODAL_BACKEND.md) for Modal backend details.
|
||||
|
||||
### Cost
|
||||
|
||||
Check the [Tinker Rate Card](https://tinker-console.thinkingmachines.ai/rate-card) for available models and pricing.
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
```bash
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
# Dockerfile for atropos-agent sandbox server
|
||||
# Runs inside Nomad containers to handle tool execution
|
||||
# Includes bubblewrap for namespace-based slot isolation
|
||||
|
||||
FROM python:3.11-slim
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
# Bubblewrap for namespace isolation
|
||||
bubblewrap \
|
||||
# `script` for PTY allocation (used for stable tmux+asciinema startup)
|
||||
util-linux \
|
||||
# Git for SWE-style tasks (cloning repos)
|
||||
git \
|
||||
# tmux for stateful terminal sessions (Phase 4.7+)
|
||||
tmux \
|
||||
# Common tools agents might need
|
||||
curl \
|
||||
wget \
|
||||
jq \
|
||||
# Cleanup
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Python dependencies (sandbox server + optional terminal recording)
|
||||
RUN pip install --no-cache-dir aiohttp asciinema
|
||||
|
||||
# Copy the sandbox server
|
||||
COPY sandbox_server.py /app/sandbox_server.py
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Create data directory for slot workspaces
|
||||
RUN mkdir -p /data
|
||||
|
||||
# Verify bubblewrap is installed and working
|
||||
RUN bwrap --version
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
# Default command - can be overridden by Nomad job spec
|
||||
CMD ["python", "sandbox_server.py", "--port", "8080", "--slots", "10", "--data-dir", "/data"]
|
||||
@@ -0,0 +1,46 @@
|
||||
"""
|
||||
Atropos integration for Hermes-Agent.
|
||||
|
||||
This package is intentionally optional: Hermes-Agent should work without Atropos.
|
||||
If you import anything from `atropos.*` without having `atroposlib` installed,
|
||||
we raise a clear error with install instructions.
|
||||
|
||||
Install (recommended, from repo checkout):
|
||||
uv sync --extra atropos
|
||||
|
||||
Or (pip / editable):
|
||||
pip install -e '.[atropos]'
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def _require_atroposlib() -> None:
|
||||
try:
|
||||
import atroposlib # noqa: F401
|
||||
except ModuleNotFoundError as exc: # pragma: no cover
|
||||
raise ModuleNotFoundError(
|
||||
"Hermes-Agent Atropos integration requires `atroposlib`, but it is not installed.\n"
|
||||
"Install it with:\n"
|
||||
" uv sync --extra atropos\n"
|
||||
"or:\n"
|
||||
" pip install -e '.[atropos]'\n"
|
||||
) from exc
|
||||
|
||||
|
||||
_require_atroposlib()
|
||||
|
||||
# Re-export the most commonly used pieces for convenience.
|
||||
from .agent import AgentConfig, AgentResult, AgentStep, AtroposAgent, SequenceData # noqa: E402
|
||||
from .envs import AgentEnv, AgentEnvConfig # noqa: E402
|
||||
|
||||
__all__ = [
|
||||
"AtroposAgent",
|
||||
"AgentConfig",
|
||||
"AgentResult",
|
||||
"AgentStep",
|
||||
"SequenceData",
|
||||
"AgentEnv",
|
||||
"AgentEnvConfig",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Agent abstractions for atropos-agent.
|
||||
|
||||
Provides the core AtroposAgent class for running ReACT-style agent loops.
|
||||
"""
|
||||
|
||||
from .atropos_agent import AgentConfig, AgentResult, AgentStep, AtroposAgent, SequenceData
|
||||
|
||||
__all__ = [
|
||||
"AtroposAgent",
|
||||
"AgentConfig",
|
||||
"AgentResult",
|
||||
"AgentStep",
|
||||
"SequenceData",
|
||||
]
|
||||
@@ -0,0 +1,850 @@
|
||||
"""
|
||||
ReACT-style agent implementation for atropos-agent.
|
||||
|
||||
This module provides the core AtroposAgent class that implements a basic
|
||||
Reason-Act-Observe loop with tool calling capabilities.
|
||||
|
||||
Uses ManagedServer from atroposlib for automatic token/logprob tracking,
|
||||
making trajectories ready for RL training.
|
||||
|
||||
The agent uses Hermes-style XML tags for tool calls:
|
||||
- <think>...</think> for reasoning
|
||||
- <tool_call>{"name": "...", "arguments": {...}}</tool_call> for actions
|
||||
- <tool_response>...</tool_response> for observations
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from uuid import uuid4
|
||||
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Optional, Union
|
||||
|
||||
from dotenv import load_dotenv
|
||||
import httpx
|
||||
|
||||
from ..tools import ToolCall, ToolRegistry, ToolResult
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# Default system prompt with tool calling instructions.
|
||||
AGENT_SYSTEM_PROMPT = """You are a deep thinking AI. You MUST enclose your internal reasoning inside <think>...</think> tags.
|
||||
|
||||
You are a function calling AI model.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags.
|
||||
You must call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.
|
||||
You can ONLY respond without a tool call if you are totally certain you have the final answer to the user's question or task
|
||||
After calling & executing a function, you will be provided with function results within <tool_response></tool_response> XML tags.
|
||||
|
||||
Here are the available tools:
|
||||
<tools>
|
||||
{tools_json}
|
||||
</tools>
|
||||
|
||||
Use the following JSON schema for each tool call you will make:
|
||||
{"title": "FunctionCall", "type": "object", "properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"]}
|
||||
|
||||
## REQUIRED TOOL FORMAT
|
||||
|
||||
When you decide to call a tool, your assistant message MUST be:
|
||||
1) exactly one <think>...</think> block, followed by
|
||||
2) one or more <tool_call>...</tool_call> blocks,
|
||||
and NOTHING else in that message.
|
||||
|
||||
If you need to explain anything, put it inside <think>. Do NOT write natural language outside <think> or <tool_call>.
|
||||
|
||||
For each function call return a JSON object with function name and arguments within <tool_call></tool_call> XML tags as follows:
|
||||
<tool_call>
|
||||
{"name": "<function-name>", "arguments": {"arg1": "value1"}}
|
||||
</tool_call>
|
||||
|
||||
Each <tool_call> must be on its own and contain ONLY the JSON object (no extra text).
|
||||
The JSON inside <tool_call> MUST be valid JSON with double quotes.
|
||||
|
||||
Do NOT output <tool_response> in an assistant message.
|
||||
|
||||
After you receive tool results, you may either call more tools (same required format) or provide the final answer.
|
||||
When providing the final answer, do NOT include any <tool_call> blocks.
|
||||
|
||||
## TERMINAL TOOL NOTES
|
||||
|
||||
- Commands execute under POSIX `/bin/sh` (not bash).
|
||||
- Each tool call runs in a fresh shell: environment changes (like `cd` or venv activation) do not persist across tool calls.
|
||||
- Avoid bash-only features like `source`, `[[ ... ]]`, or process substitution.
|
||||
- Prefer explicit venv usage:
|
||||
- `python -m venv .venv && . .venv/bin/activate && python -m pip install -e .` (POSIX `.` activation), or
|
||||
- `.venv/bin/python -m pip install -e .` (no activation required).
|
||||
|
||||
## ICL (examples)
|
||||
|
||||
User: Show the current directory.
|
||||
Assistant:
|
||||
<think>I should run pwd.</think>
|
||||
<tool_call>
|
||||
{"name": "terminal", "arguments": {"command": "pwd"}}
|
||||
</tool_call>
|
||||
User: <tool_response>{"success": true, "output": "/tmp\\n"}</tool_response>
|
||||
Assistant: /tmp
|
||||
|
||||
User: List files, then count them.
|
||||
Assistant:
|
||||
<think>I should count files.</think>
|
||||
<tool_call>
|
||||
{"name": "terminal", "arguments": {"command": "ls -1 | wc -l"}}
|
||||
</tool_call>
|
||||
User: <tool_response>{"success": true, "output": "3\\n"}</tool_response>
|
||||
Assistant: 3
|
||||
|
||||
User: Run pwd, then print ok (two tool calls).
|
||||
Assistant:
|
||||
<think>I should run two commands.</think>
|
||||
<tool_call>
|
||||
{"name": "terminal", "arguments": {"command": "pwd"}}
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
{"name": "terminal", "arguments": {"command": "echo ok"}}
|
||||
</tool_call>
|
||||
User: <tool_response>{"success": true, "output": "/tmp\\n"}</tool_response>
|
||||
User: <tool_response>{"success": true, "output": "ok\\n"}</tool_response>
|
||||
Assistant: ok
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentConfig:
|
||||
"""Configuration for the AtroposAgent."""
|
||||
|
||||
# Generation parameters
|
||||
temperature: Optional[float] = 0.7
|
||||
# Default to "let the backend decide" (important for tool-tag completions that may be longer).
|
||||
max_tokens: Optional[int] = None
|
||||
|
||||
# Agent behavior
|
||||
max_steps: int = 50
|
||||
system_prompt: Optional[str] = None
|
||||
tool_delay_s: float = 0.0
|
||||
|
||||
# Working directory for tools
|
||||
working_dir: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SequenceData:
|
||||
"""Token/logprob data from a single completion."""
|
||||
|
||||
full_text: str
|
||||
tokens: List[int]
|
||||
masked_tokens: List[int] # -100 for prompt, actual IDs for completion
|
||||
logprobs: List[float] # 1.0 for prompt, actual values for completion
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
@classmethod
|
||||
def from_sequence_node(cls, node) -> "SequenceData":
|
||||
"""Create from a ManagedServer SequenceNode."""
|
||||
return cls(
|
||||
full_text=node.full_text,
|
||||
tokens=node.tokens,
|
||||
masked_tokens=node.masked_tokens,
|
||||
logprobs=node.logprobs,
|
||||
metadata=getattr(node, "metadata", None),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentStep:
|
||||
"""A single step in the agent's trajectory."""
|
||||
|
||||
step_number: int
|
||||
assistant_message: str
|
||||
tool_calls: List[ToolCall] = field(default_factory=list)
|
||||
tool_results: List[ToolResult] = field(default_factory=list)
|
||||
sequence_data: Optional[SequenceData] = None # Token data from this step
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
return len(self.tool_calls) > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentResult:
|
||||
"""Result of running an agent trajectory."""
|
||||
|
||||
success: bool
|
||||
final_response: str
|
||||
steps: List[AgentStep] = field(default_factory=list)
|
||||
total_tokens: int = 0
|
||||
error: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Full trajectory token data for RL training
|
||||
trajectory_data: Optional[SequenceData] = None
|
||||
|
||||
@property
|
||||
def num_steps(self) -> int:
|
||||
return len(self.steps)
|
||||
|
||||
@property
|
||||
def total_tool_calls(self) -> int:
|
||||
return sum(len(step.tool_calls) for step in self.steps)
|
||||
|
||||
def to_messages(self) -> List[Dict[str, str]]:
|
||||
"""Convert trajectory to messages format for logging."""
|
||||
messages = []
|
||||
for step in self.steps:
|
||||
messages.append({"role": "assistant", "content": step.assistant_message})
|
||||
if step.tool_results:
|
||||
# Combine all tool responses
|
||||
responses = "\n".join(r.to_xml() for r in step.tool_results)
|
||||
messages.append({"role": "user", "content": responses})
|
||||
return messages
|
||||
|
||||
def to_scored_data(self, score: float) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Convert to format suitable for ScoredDataGroup.
|
||||
|
||||
Args:
|
||||
score: The score for this trajectory
|
||||
|
||||
Returns:
|
||||
Dict with tokens, masks, scores suitable for training, or None if no data
|
||||
"""
|
||||
if self.trajectory_data is None:
|
||||
return None
|
||||
|
||||
return {
|
||||
"tokens": self.trajectory_data.tokens,
|
||||
"masks": self.trajectory_data.masked_tokens,
|
||||
"scores": score,
|
||||
"logprobs": self.trajectory_data.logprobs,
|
||||
}
|
||||
|
||||
|
||||
class AtroposAgent:
|
||||
"""
|
||||
A ReACT-style agent that uses LLMs with tool calling.
|
||||
|
||||
This implementation wraps ManagedServer for automatic token/logprob tracking,
|
||||
making trajectories ready for RL training.
|
||||
|
||||
Example:
|
||||
# `server` may be an Atropos `ServerManager` (recommended) or a single `APIServer`.
|
||||
# In practice, environments usually construct this via `BaseEnv`.
|
||||
server = ...
|
||||
tools = ToolRegistry()
|
||||
tools.register(BashTool())
|
||||
|
||||
agent = AtroposAgent(server=server, tools=tools)
|
||||
result = await agent.run("List the files in the current directory")
|
||||
|
||||
# Access token data for training
|
||||
if result.trajectory_data:
|
||||
print(f"Tokens: {result.trajectory_data.tokens}")
|
||||
print(f"Masked: {result.trajectory_data.masked_tokens}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server, # ServerManager or APIServer
|
||||
tools: Optional[ToolRegistry] = None,
|
||||
config: Optional[AgentConfig] = None,
|
||||
tokenizer: Optional[Any] = None,
|
||||
execute_tool: Optional[Callable[[ToolCall], Awaitable[ToolResult]]] = None,
|
||||
):
|
||||
self.server = server
|
||||
self.tools = tools or ToolRegistry()
|
||||
self.config = config or AgentConfig()
|
||||
self.tokenizer = tokenizer or getattr(server, "tokenizer", None)
|
||||
self.execute_tool = execute_tool or self.tools.execute
|
||||
|
||||
@asynccontextmanager
|
||||
async def _managed(self) -> AsyncGenerator[Any, None]:
|
||||
"""
|
||||
Yield a ManagedServer-like object.
|
||||
|
||||
- If `self.server` is a ServerManager, use its `managed_server()` context manager.
|
||||
- If `self.server` is a single APIServer, wrap it in `ManagedServer` directly.
|
||||
"""
|
||||
if os.getenv("ATROPOS_BYPASS_MANAGED_SERVER") == "1":
|
||||
yield _DirectChatCompletionClient(server=self.server)
|
||||
return
|
||||
if hasattr(self.server, "managed_server"):
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
yield managed
|
||||
else:
|
||||
managed = ManagedServer(server=self.server, tokenizer=self.tokenizer)
|
||||
try:
|
||||
yield managed
|
||||
finally:
|
||||
managed.reset()
|
||||
|
||||
def _build_system_prompt(self) -> str:
|
||||
"""Build the system prompt with tool descriptions."""
|
||||
if self.config.system_prompt:
|
||||
return self.config.system_prompt
|
||||
|
||||
tools_json = self.tools.get_prompt_tool_definitions_json()
|
||||
# Avoid `str.format()` here because the prompt contains many literal `{}` braces
|
||||
# in JSON examples; we only want to substitute the single `{tools_json}` token.
|
||||
return AGENT_SYSTEM_PROMPT.replace("{tools_json}", tools_json)
|
||||
|
||||
def _infer_server_model_for_debug(self) -> Optional[str]:
|
||||
"""
|
||||
Best-effort inference of the configured model name for debug payload saving.
|
||||
|
||||
ManagedServer/server_manager typically injects `model` internally, so `chat_kwargs`
|
||||
may not contain it. For replaying saved payloads via curl, it's useful to persist it.
|
||||
"""
|
||||
servers = getattr(self.server, "servers", None)
|
||||
if isinstance(servers, list) and servers:
|
||||
s0 = servers[0]
|
||||
cfg = getattr(s0, "config", None)
|
||||
model = getattr(cfg, "model_name", None) or getattr(s0, "model_name", None)
|
||||
if isinstance(model, str) and model:
|
||||
return model
|
||||
model = getattr(self.server, "model_name", None) or getattr(self.server, "model", None)
|
||||
if isinstance(model, str) and model:
|
||||
return model
|
||||
return None
|
||||
|
||||
def _infer_server_base_url_for_debug(self) -> Optional[str]:
|
||||
"""
|
||||
Best-effort inference of the configured base_url for debug logging.
|
||||
|
||||
This is helpful when diagnosing hangs / retries at the transport layer.
|
||||
"""
|
||||
servers = getattr(self.server, "servers", None)
|
||||
if isinstance(servers, list) and servers:
|
||||
s0 = servers[0]
|
||||
cfg = getattr(s0, "config", None)
|
||||
base_url = getattr(cfg, "base_url", None) or getattr(s0, "base_url", None)
|
||||
if isinstance(base_url, str) and base_url:
|
||||
return base_url
|
||||
base_url = getattr(self.server, "base_url", None)
|
||||
if isinstance(base_url, str) and base_url:
|
||||
return base_url
|
||||
return None
|
||||
|
||||
def _extract_response_metadata(self, response: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
Extract lightweight, JSON-serializable metadata from an OpenAI-style response.
|
||||
|
||||
This is useful for debugging training runs, especially when ManagedServer state
|
||||
tracking is unavailable (e.g. OpenAI-compatible chat endpoints).
|
||||
"""
|
||||
meta: Dict[str, Any] = {}
|
||||
try:
|
||||
rid = getattr(response, "id", None)
|
||||
if isinstance(rid, str) and rid:
|
||||
meta["id"] = rid
|
||||
model = getattr(response, "model", None)
|
||||
if isinstance(model, str) and model:
|
||||
meta["model"] = model
|
||||
created = getattr(response, "created", None)
|
||||
if isinstance(created, int):
|
||||
meta["created"] = created
|
||||
system_fingerprint = getattr(response, "system_fingerprint", None)
|
||||
if isinstance(system_fingerprint, str) and system_fingerprint:
|
||||
meta["system_fingerprint"] = system_fingerprint
|
||||
|
||||
choices = getattr(response, "choices", None)
|
||||
if isinstance(choices, list) and choices:
|
||||
fr = getattr(choices[0], "finish_reason", None)
|
||||
if isinstance(fr, str) and fr:
|
||||
meta["finish_reason"] = fr
|
||||
|
||||
usage = getattr(response, "usage", None)
|
||||
if usage is not None:
|
||||
if hasattr(usage, "model_dump"):
|
||||
meta["usage"] = usage.model_dump()
|
||||
elif isinstance(usage, dict):
|
||||
meta["usage"] = usage
|
||||
except Exception:
|
||||
pass
|
||||
return meta
|
||||
|
||||
def _debug_dump_request(self, *, step_num: int, chat_kwargs: Dict[str, Any]) -> None:
|
||||
if os.getenv("ATROPOS_DEBUG_AGENT_REQUEST") != "1":
|
||||
return
|
||||
try:
|
||||
# Avoid dumping megabytes by default; messages can be huge.
|
||||
meta = {
|
||||
"step": step_num,
|
||||
"base_url": self._infer_server_base_url_for_debug(),
|
||||
"model": chat_kwargs.get("model") or self._infer_server_model_for_debug(),
|
||||
"chat_kwargs_keys": sorted(list(chat_kwargs.keys())),
|
||||
"n": chat_kwargs.get("n"),
|
||||
"max_tokens": chat_kwargs.get("max_tokens"),
|
||||
"temperature": chat_kwargs.get("temperature"),
|
||||
"num_messages": len(chat_kwargs.get("messages") or []),
|
||||
}
|
||||
print("\n=== ATROPOS_DEBUG_AGENT_REQUEST ===", flush=True)
|
||||
print(meta, flush=True)
|
||||
|
||||
if os.getenv("ATROPOS_DEBUG_AGENT_REQUEST_FULL") == "1":
|
||||
payload = dict(chat_kwargs)
|
||||
# Make the payload more legible and less huge.
|
||||
try:
|
||||
dumped = json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
except Exception:
|
||||
dumped = repr(payload)
|
||||
print("\n=== ATROPOS_DEBUG_AGENT_REQUEST_FULL ===", flush=True)
|
||||
print(dumped[:200_000], flush=True)
|
||||
|
||||
# Optional: save the FULL request payload to disk (no truncation).
|
||||
save_dir = os.getenv("ATROPOS_DEBUG_AGENT_REQUEST_SAVE_DIR")
|
||||
if save_dir:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
payload: Dict[str, Any] = dict(chat_kwargs)
|
||||
if "model" not in payload:
|
||||
model = self._infer_server_model_for_debug()
|
||||
if model:
|
||||
payload["model"] = model
|
||||
# Use a unique filename so parallel trajectories don't clobber each other.
|
||||
fname = os.path.join(
|
||||
save_dir,
|
||||
f"atropos_agent_request_step{step_num}_{int(time.time()*1000)}_{os.getpid()}_{uuid4().hex}.json",
|
||||
)
|
||||
with open(fname, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, ensure_ascii=False, indent=2)
|
||||
print(f"[AtroposAgent] saved request payload: {fname}", flush=True)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def _debug_dump_response(self, *, step_num: int, response: Any) -> None:
|
||||
if os.getenv("ATROPOS_DEBUG_AGENT_RESPONSE") != "1":
|
||||
return
|
||||
print("\n=== ATROPOS_DEBUG_AGENT_RESPONSE ===", flush=True)
|
||||
print({"step": step_num, "type": type(response).__name__}, flush=True)
|
||||
try:
|
||||
dumped = response.model_dump() # openai pydantic model
|
||||
except Exception:
|
||||
dumped = getattr(response, "__dict__", {"repr": repr(response)})
|
||||
# Keep the dump bounded; we only need enough to see the assistant message content.
|
||||
text = str(dumped)
|
||||
print(text[:200_000], flush=True)
|
||||
|
||||
async def _chat_completion_with_debug(
|
||||
self, *, managed: Any, step_num: int, chat_kwargs: Dict[str, Any]
|
||||
) -> Any:
|
||||
"""
|
||||
Call `managed.chat_completion()` with optional timeout + richer failure logging.
|
||||
|
||||
Debug env vars:
|
||||
- `ATROPOS_AGENT_CHAT_TIMEOUT_S`: if set, wraps the await in `asyncio.wait_for`.
|
||||
- `ATROPOS_DEBUG_AGENT_WAIT_EVERY_S`: if set, prints a heartbeat while waiting.
|
||||
"""
|
||||
# Hard guardrail: never allow a single chat completion to block for too long.
|
||||
# This is essential for RL data-gen stability; long hangs should be treated as failures (score=0).
|
||||
timeout_s_raw = os.getenv("ATROPOS_AGENT_CHAT_TIMEOUT_S")
|
||||
timeout_s_default = 240.0
|
||||
timeout_s = float(timeout_s_raw) if timeout_s_raw else timeout_s_default
|
||||
timeout_s = min(timeout_s, 240.0)
|
||||
|
||||
wait_every_raw = os.getenv("ATROPOS_DEBUG_AGENT_WAIT_EVERY_S")
|
||||
wait_every_s = float(wait_every_raw) if wait_every_raw else None
|
||||
|
||||
async def _await_call() -> Any:
|
||||
if not wait_every_s or wait_every_s <= 0:
|
||||
return await managed.chat_completion(**chat_kwargs)
|
||||
|
||||
# Heartbeat mode: wait in chunks without cancelling the underlying request.
|
||||
# NOTE: do NOT use `asyncio.wait_for(task, timeout=...)` here, because a timeout
|
||||
# will cancel the task and surface as `CancelledError` on the next loop.
|
||||
task = asyncio.create_task(managed.chat_completion(**chat_kwargs))
|
||||
t0 = time.perf_counter()
|
||||
try:
|
||||
while True:
|
||||
done, _pending = await asyncio.wait({task}, timeout=wait_every_s)
|
||||
if task in done:
|
||||
return task.result()
|
||||
|
||||
waited = time.perf_counter() - t0
|
||||
print(
|
||||
f"[AtroposAgent] step={step_num} still waiting for chat_completion... ({waited:.1f}s)",
|
||||
flush=True,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
task.cancel()
|
||||
raise
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(_await_call(), timeout=timeout_s)
|
||||
except asyncio.TimeoutError as e:
|
||||
print("\n=== ATROPOS_DEBUG_AGENT_CHAT_TIMEOUT ===", flush=True)
|
||||
print({"step": step_num, "timeout_s": timeout_s}, flush=True)
|
||||
raise RuntimeError(f"chat_completion timed out after {timeout_s:.1f}s") from e
|
||||
except asyncio.CancelledError:
|
||||
# Treat cancellation as a hard failure rather than crashing the whole env run.
|
||||
# (Atropos/BaseEnv may cancel tasks during shutdown or retries.)
|
||||
raise RuntimeError("chat_completion cancelled") from None
|
||||
except Exception as e:
|
||||
detail: Dict[str, Any] = {
|
||||
"step": step_num,
|
||||
"exc_type": type(e).__name__,
|
||||
"exc_str": str(e),
|
||||
}
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
try:
|
||||
detail["status_code"] = e.response.status_code
|
||||
detail["response_text"] = e.response.text[:20_000]
|
||||
except Exception:
|
||||
pass
|
||||
elif isinstance(e, httpx.RequestError):
|
||||
detail["request"] = repr(getattr(e, "request", None))
|
||||
|
||||
print("\n=== ATROPOS_DEBUG_AGENT_CHAT_FAILURE ===", flush=True)
|
||||
print(detail, flush=True)
|
||||
raise
|
||||
|
||||
async def run(
|
||||
self,
|
||||
task: str,
|
||||
initial_messages: Optional[List[Dict[str, str]]] = None,
|
||||
) -> AgentResult:
|
||||
"""
|
||||
Run the agent on a task using ManagedServer for token tracking.
|
||||
|
||||
Args:
|
||||
task: The task/prompt for the agent
|
||||
initial_messages: Optional additional context messages
|
||||
|
||||
Returns:
|
||||
AgentResult with the trajectory, final response, and token data
|
||||
"""
|
||||
messages = [
|
||||
{"role": "system", "content": self._build_system_prompt()},
|
||||
]
|
||||
|
||||
if initial_messages:
|
||||
messages.extend(initial_messages)
|
||||
|
||||
messages.append({"role": "user", "content": task})
|
||||
|
||||
steps = []
|
||||
final_response = ""
|
||||
final_node = None
|
||||
final_prompt_messages: Optional[List[Dict[str, str]]] = None
|
||||
last_node = None
|
||||
last_prompt_messages: Optional[List[Dict[str, str]]] = None
|
||||
last_response_text: str = ""
|
||||
|
||||
# Use ManagedServer for automatic token tracking
|
||||
async with self._managed() as managed:
|
||||
for step_num in range(self.config.max_steps):
|
||||
# ReACT loop iteration here, just call -> tools -> observe until done (no tools called)
|
||||
try:
|
||||
# Keep a copy of the prompt messages used for this completion.
|
||||
# Useful for reconstructing tokens/masks when state tracking is unavailable.
|
||||
prompt_messages = list(messages)
|
||||
chat_kwargs: Dict[str, Any] = {"messages": messages, "n": 1}
|
||||
if self.config.max_tokens is not None:
|
||||
chat_kwargs["max_tokens"] = self.config.max_tokens
|
||||
if self.config.temperature is not None:
|
||||
chat_kwargs["temperature"] = self.config.temperature
|
||||
|
||||
t_req = time.perf_counter()
|
||||
print(
|
||||
f"[AtroposAgent] step={step_num+1} chat_completion start "
|
||||
f"(messages={len(messages)}, max_tokens={self.config.max_tokens}, temp={self.config.temperature})",
|
||||
flush=True,
|
||||
)
|
||||
self._debug_dump_request(step_num=step_num + 1, chat_kwargs=chat_kwargs)
|
||||
response = await self._chat_completion_with_debug(
|
||||
managed=managed, step_num=step_num + 1, chat_kwargs=chat_kwargs
|
||||
)
|
||||
self._debug_dump_response(step_num=step_num + 1, response=response)
|
||||
response_meta = self._extract_response_metadata(response)
|
||||
print(
|
||||
f"[AtroposAgent] step={step_num+1} chat_completion done in {time.perf_counter() - t_req:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
current_node = None
|
||||
if hasattr(managed, "get_state"):
|
||||
state = managed.get_state()
|
||||
nodes = state.get("nodes", [])
|
||||
current_node = nodes[-1] if nodes else None
|
||||
|
||||
except Exception as e:
|
||||
return AgentResult(
|
||||
success=False,
|
||||
final_response="",
|
||||
steps=steps,
|
||||
error=f"Generation error: {str(e)}",
|
||||
)
|
||||
|
||||
msg = response.choices[0].message
|
||||
# Some OpenAI-compatible servers populate `message.reasoning` and leave `content=""`.
|
||||
response_text = (msg.content or "") or (getattr(msg, "reasoning", None) or "")
|
||||
tool_calls = ToolCall.parse_from_text(response_text)
|
||||
last_node = current_node
|
||||
last_prompt_messages = prompt_messages
|
||||
last_response_text = response_text
|
||||
|
||||
step_sequence_data = SequenceData.from_sequence_node(current_node) if current_node else None
|
||||
if step_sequence_data is None:
|
||||
if response_meta:
|
||||
# We still want metadata for debugging even if token/logprob state tracking is unavailable.
|
||||
step_sequence_data = SequenceData(
|
||||
full_text=response_text,
|
||||
tokens=[],
|
||||
masked_tokens=[],
|
||||
logprobs=[],
|
||||
metadata=response_meta,
|
||||
)
|
||||
else:
|
||||
merged = dict(response_meta)
|
||||
node_meta = step_sequence_data.metadata
|
||||
if isinstance(node_meta, dict):
|
||||
merged.update(node_meta)
|
||||
step_sequence_data.metadata = merged or step_sequence_data.metadata
|
||||
|
||||
step = AgentStep(
|
||||
step_number=step_num + 1,
|
||||
assistant_message=response_text,
|
||||
tool_calls=tool_calls,
|
||||
sequence_data=step_sequence_data,
|
||||
)
|
||||
|
||||
if not tool_calls:
|
||||
steps.append(step)
|
||||
final_response = response_text
|
||||
final_node = current_node
|
||||
final_prompt_messages = prompt_messages
|
||||
break
|
||||
|
||||
messages.append({"role": "assistant", "content": response_text})
|
||||
|
||||
tool_responses = []
|
||||
for call in tool_calls:
|
||||
result = await self.execute_tool(call)
|
||||
step.tool_results.append(result)
|
||||
tool_responses.append(result.to_xml())
|
||||
if self.config.tool_delay_s > 0:
|
||||
await asyncio.sleep(self.config.tool_delay_s)
|
||||
|
||||
steps.append(step)
|
||||
|
||||
responses_text = "\n".join(tool_responses)
|
||||
# Tool observations are represented as user content with Hermes-style tags.
|
||||
# This is compatible with most OpenAI-compatible chat APIs and ensures
|
||||
# tokenizers/chat templates include tool outputs during training.
|
||||
messages.append({"role": "user", "content": responses_text})
|
||||
|
||||
else:
|
||||
# Reached max steps without completing
|
||||
# Return a failure result but include the last observed completion so callers can
|
||||
# record the trajectory (score=0) without triggering retries.
|
||||
final_response = last_response_text or final_response
|
||||
final_node = last_node
|
||||
final_prompt_messages = last_prompt_messages
|
||||
trajectory_data = None
|
||||
if final_node:
|
||||
trajectory_data = SequenceData.from_sequence_node(final_node)
|
||||
elif final_prompt_messages is not None and self.tokenizer is not None:
|
||||
if hasattr(self.tokenizer, "apply_chat_template"):
|
||||
prompt_text = self.tokenizer.apply_chat_template(
|
||||
final_prompt_messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=False)
|
||||
else:
|
||||
prompt_text = "\n".join([f"{m['role']}: {m['content']}" for m in final_prompt_messages])
|
||||
prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=True)
|
||||
output_tokens = self.tokenizer.encode(final_response, add_special_tokens=False)
|
||||
tokens = prompt_tokens + output_tokens
|
||||
masked_tokens = ([-100] * len(prompt_tokens)) + output_tokens
|
||||
logprobs = ([1.0] * len(prompt_tokens)) + ([0.0] * len(output_tokens))
|
||||
trajectory_data = SequenceData(
|
||||
full_text=f"{prompt_text}{final_response}",
|
||||
tokens=tokens,
|
||||
masked_tokens=masked_tokens,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
# Preserve response metadata (if any) even on failure trajectories.
|
||||
try:
|
||||
if trajectory_data is not None and steps:
|
||||
last_step = steps[-1]
|
||||
if last_step.sequence_data and isinstance(last_step.sequence_data.metadata, dict):
|
||||
trajectory_data.metadata = dict(last_step.sequence_data.metadata)
|
||||
except Exception:
|
||||
pass
|
||||
return AgentResult(
|
||||
success=False,
|
||||
final_response=final_response,
|
||||
steps=steps,
|
||||
error=f"Reached maximum steps ({self.config.max_steps})",
|
||||
trajectory_data=trajectory_data,
|
||||
)
|
||||
|
||||
# Build result with trajectory data
|
||||
trajectory_data = None
|
||||
if final_node:
|
||||
trajectory_data = SequenceData.from_sequence_node(final_node)
|
||||
elif final_prompt_messages is not None and self.tokenizer is not None:
|
||||
if hasattr(self.tokenizer, "apply_chat_template"):
|
||||
prompt_text = self.tokenizer.apply_chat_template(
|
||||
final_prompt_messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=False)
|
||||
else:
|
||||
prompt_text = "\n".join([f"{m['role']}: {m['content']}" for m in final_prompt_messages])
|
||||
prompt_tokens = self.tokenizer.encode(prompt_text, add_special_tokens=True)
|
||||
output_tokens = self.tokenizer.encode(final_response, add_special_tokens=False)
|
||||
tokens = prompt_tokens + output_tokens
|
||||
masked_tokens = ([-100] * len(prompt_tokens)) + output_tokens
|
||||
logprobs = ([1.0] * len(prompt_tokens)) + ([0.0] * len(output_tokens))
|
||||
trajectory_data = SequenceData(
|
||||
full_text=f"{prompt_text}{final_response}",
|
||||
tokens=tokens,
|
||||
masked_tokens=masked_tokens,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
|
||||
# Ensure trajectory_data carries the most recent metadata we observed (if any).
|
||||
try:
|
||||
if trajectory_data is not None and steps:
|
||||
last_step = steps[-1]
|
||||
if last_step.sequence_data and isinstance(last_step.sequence_data.metadata, dict):
|
||||
trajectory_data.metadata = dict(last_step.sequence_data.metadata)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return AgentResult(
|
||||
success=True,
|
||||
final_response=final_response,
|
||||
steps=steps,
|
||||
trajectory_data=trajectory_data,
|
||||
)
|
||||
|
||||
async def run_single_turn(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
execute_tools: bool = True,
|
||||
) -> tuple[str, List[ToolResult], Optional[SequenceData]]:
|
||||
"""
|
||||
Run a single turn of the agent (one LLM call + tool execution).
|
||||
|
||||
This is useful for integration with BaseEnv where you want more
|
||||
control over the loop.
|
||||
|
||||
Args:
|
||||
messages: The conversation history
|
||||
execute_tools: Whether to execute parsed tool calls
|
||||
|
||||
Returns:
|
||||
Tuple of (response_text, tool_results, sequence_data)
|
||||
"""
|
||||
async with self._managed() as managed:
|
||||
chat_kwargs: Dict[str, Any] = {"messages": messages, "n": 1}
|
||||
if self.config.max_tokens is not None:
|
||||
chat_kwargs["max_tokens"] = self.config.max_tokens
|
||||
if self.config.temperature is not None:
|
||||
chat_kwargs["temperature"] = self.config.temperature
|
||||
|
||||
self._debug_dump_request(step_num=1, chat_kwargs=chat_kwargs)
|
||||
response = await self._chat_completion_with_debug(managed=managed, step_num=1, chat_kwargs=chat_kwargs)
|
||||
self._debug_dump_response(step_num=1, response=response)
|
||||
|
||||
current_node = None
|
||||
if hasattr(managed, "get_state"):
|
||||
state = managed.get_state()
|
||||
nodes = state.get("nodes", [])
|
||||
current_node = nodes[-1] if nodes else None
|
||||
|
||||
msg = response.choices[0].message
|
||||
response_text = (msg.content or "") or (getattr(msg, "reasoning", None) or "")
|
||||
tool_results = []
|
||||
|
||||
if execute_tools:
|
||||
tool_calls = ToolCall.parse_from_text(response_text)
|
||||
for call in tool_calls:
|
||||
result = await self.execute_tool(call)
|
||||
tool_results.append(result)
|
||||
|
||||
sequence_data = SequenceData.from_sequence_node(current_node) if current_node else None
|
||||
|
||||
return response_text, tool_results, sequence_data
|
||||
|
||||
|
||||
class _DirectChatCompletionClient:
|
||||
"""
|
||||
Minimal stand-in for ManagedServer that calls the OpenAI-compatible endpoint directly.
|
||||
|
||||
This is for isolating issues where `ManagedServer.chat_completion()` hangs or misbehaves.
|
||||
It intentionally does NOT do token/logprob tracking.
|
||||
"""
|
||||
|
||||
def __init__(self, server: Any):
|
||||
self._server = server
|
||||
|
||||
def _server_config(self) -> tuple[str, str, str]:
|
||||
# ServerManager case: first configured server.
|
||||
servers = getattr(self._server, "servers", None)
|
||||
if isinstance(servers, list) and servers:
|
||||
s0 = servers[0]
|
||||
cfg = getattr(s0, "config", None)
|
||||
base_url = getattr(cfg, "base_url", None) or getattr(s0, "base_url", None)
|
||||
api_key = getattr(cfg, "api_key", None) or getattr(s0, "api_key", None)
|
||||
model = getattr(cfg, "model_name", None) or getattr(s0, "model_name", None)
|
||||
if isinstance(base_url, str) and isinstance(api_key, str) and isinstance(model, str):
|
||||
return base_url.rstrip("/"), api_key, model
|
||||
|
||||
# APIServer-like fallback.
|
||||
base_url = getattr(self._server, "base_url", None)
|
||||
api_key = getattr(self._server, "api_key", None)
|
||||
model = getattr(self._server, "model_name", None) or getattr(self._server, "model", None)
|
||||
if isinstance(base_url, str) and isinstance(api_key, str) and isinstance(model, str):
|
||||
return base_url.rstrip("/"), api_key, model
|
||||
|
||||
raise RuntimeError("Unable to resolve server base_url/api_key/model for direct chat completion")
|
||||
|
||||
async def chat_completion(self, *, messages: List[Dict[str, str]], n: int = 1, **kwargs: Any) -> Any:
|
||||
base_url, api_key, model = self._server_config()
|
||||
url = f"{base_url}/chat/completions"
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"n": n,
|
||||
}
|
||||
# Pass through common generation kwargs.
|
||||
for k in ("max_tokens", "temperature", "top_p", "presence_penalty", "frequency_penalty", "stop"):
|
||||
if k in kwargs and kwargs[k] is not None:
|
||||
payload[k] = kwargs[k]
|
||||
|
||||
timeout_s = float(os.getenv("ATROPOS_DIRECT_REQUEST_TIMEOUT_S") or "120")
|
||||
print(f"[AtroposAgent] DIRECT chat_completion POST {url} (timeout={timeout_s}s)", flush=True)
|
||||
async with httpx.AsyncClient(timeout=timeout_s) as client:
|
||||
resp = await client.post(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
|
||||
json=payload,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# Return a very small object compatible with the code paths that read
|
||||
# `response.choices[0].message.content`.
|
||||
class _Msg:
|
||||
def __init__(self, d: Dict[str, Any]):
|
||||
self.content = d.get("content")
|
||||
self.reasoning = d.get("reasoning")
|
||||
|
||||
class _Choice:
|
||||
def __init__(self, d: Dict[str, Any]):
|
||||
self.message = _Msg(d.get("message") or {})
|
||||
|
||||
class _Resp:
|
||||
def __init__(self, d: Dict[str, Any]):
|
||||
self._d = d
|
||||
self.choices = [_Choice(c) for c in (d.get("choices") or [])]
|
||||
|
||||
def model_dump(self) -> Dict[str, Any]:
|
||||
return self._d
|
||||
|
||||
return _Resp(data)
|
||||
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
FastAPI services for atropos-agent.
|
||||
|
||||
- tool_executor_server: queued/batched sandbox tool execution (Phase 4)
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
Tool Executor API (Phase 4)
|
||||
|
||||
This service provides a queued, batched execution layer on top of a ToolBackend.
|
||||
It mirrors the stateful FastAPI + app.state pattern used in:
|
||||
atropos/atroposlib/api/server.py
|
||||
|
||||
Run (dev):
|
||||
uv run uvicorn atropos_agent.api.tool_executor_server:app --host 0.0.0.0 --port 9001
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, Header, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..backends.nomad_backend import NomadBackendConfig, NomadToolBackend
|
||||
from ..tools import ToolRegistry, build_tool_registry
|
||||
from ..tools.base import (
|
||||
ArtifactArchiveRequestPayload,
|
||||
ArtifactArchiveResponsePayload,
|
||||
ArtifactListRequestPayload,
|
||||
ArtifactListResponsePayload,
|
||||
ArtifactReadRequestPayload,
|
||||
ArtifactReadResponsePayload,
|
||||
ToolExecutorExecuteRequest,
|
||||
ToolExecutorReleaseRequest,
|
||||
ToolResultPayload,
|
||||
)
|
||||
from ..tools.tool_executor import ToolExecutor, ToolExecutorConfig
|
||||
|
||||
|
||||
class ToolExecutorServerConfig(BaseModel):
|
||||
nomad_address: str = Field(default="http://localhost:4646")
|
||||
job_id: str = Field(default="atropos-sandbox-tool-executor")
|
||||
image: str = Field(default="atropos-sandbox:local")
|
||||
slots_per_container: int = Field(default=10)
|
||||
min_containers: int = Field(default=1)
|
||||
max_containers: int = Field(default=10)
|
||||
privileged: bool = Field(default=False)
|
||||
acquire_timeout_s: float = Field(default=30.0)
|
||||
|
||||
batch_window_ms: int = Field(default=20)
|
||||
max_batch_size: int = Field(default=200)
|
||||
allow_network: bool = Field(default=True)
|
||||
|
||||
tool_server_url: Optional[str] = Field(default=None)
|
||||
tool_server_token: Optional[str] = Field(default=None)
|
||||
|
||||
token: Optional[str] = Field(default=None, description="Bearer token required for requests (optional in dev).")
|
||||
|
||||
purge_job_on_shutdown: bool = Field(default=True)
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "ToolExecutorServerConfig":
|
||||
# In dev, prefer loading secrets/config from the repo-local `.env` (not committed).
|
||||
try:
|
||||
from dotenv import load_dotenv # type: ignore
|
||||
except Exception: # pragma: no cover
|
||||
load_dotenv = None # type: ignore[assignment]
|
||||
if load_dotenv is not None:
|
||||
env_path = Path(__file__).resolve().parents[2] / ".env"
|
||||
if env_path.exists():
|
||||
load_dotenv(dotenv_path=env_path)
|
||||
|
||||
def _get_bool(name: str, default: bool) -> bool:
|
||||
raw = os.getenv(name)
|
||||
if raw is None:
|
||||
return default
|
||||
return raw.strip().lower() in {"1", "true", "yes", "y", "on"}
|
||||
|
||||
return cls(
|
||||
nomad_address=os.getenv("TOOL_EXECUTOR_NOMAD_ADDRESS", "http://localhost:4646"),
|
||||
job_id=os.getenv("TOOL_EXECUTOR_JOB_ID", "atropos-sandbox-tool-executor"),
|
||||
image=os.getenv("TOOL_EXECUTOR_IMAGE", "atropos-sandbox:local"),
|
||||
slots_per_container=int(os.getenv("TOOL_EXECUTOR_SLOTS", "10")),
|
||||
min_containers=int(os.getenv("TOOL_EXECUTOR_MIN_CONTAINERS", "1")),
|
||||
max_containers=int(os.getenv("TOOL_EXECUTOR_MAX_CONTAINERS", "10")),
|
||||
privileged=_get_bool("TOOL_EXECUTOR_PRIVILEGED", False),
|
||||
acquire_timeout_s=float(os.getenv("TOOL_EXECUTOR_ACQUIRE_TIMEOUT_S", "30.0")),
|
||||
batch_window_ms=int(os.getenv("TOOL_EXECUTOR_BATCH_WINDOW_MS", "20")),
|
||||
max_batch_size=int(os.getenv("TOOL_EXECUTOR_MAX_BATCH_SIZE", "200")),
|
||||
allow_network=_get_bool("TOOL_EXECUTOR_ALLOW_NETWORK", True),
|
||||
tool_server_url=os.getenv("TOOL_EXECUTOR_TOOL_SERVER_URL") or None,
|
||||
tool_server_token=os.getenv("TOOL_EXECUTOR_TOOL_SERVER_TOKEN") or None,
|
||||
token=os.getenv("TOOL_EXECUTOR_TOKEN") or None,
|
||||
purge_job_on_shutdown=_get_bool("TOOL_EXECUTOR_PURGE_JOB_ON_SHUTDOWN", True),
|
||||
)
|
||||
|
||||
|
||||
app = FastAPI(title="Atropos-Agent Tool Executor")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root() -> Dict[str, str]:
|
||||
return {"message": "Atropos-Agent Tool Executor"}
|
||||
|
||||
|
||||
def _check_auth(cfg: ToolExecutorServerConfig, authorization: Optional[str]) -> None:
|
||||
if not cfg.token:
|
||||
return
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Authorization header")
|
||||
if not authorization.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Authorization header")
|
||||
token = authorization.split(" ", 1)[1].strip()
|
||||
if token != cfg.token:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def _startup() -> None:
|
||||
cfg = ToolExecutorServerConfig.from_env()
|
||||
|
||||
# Default to Atropos "full" tool surface: sandbox + external (if tool_server_url provided).
|
||||
tools: ToolRegistry = build_tool_registry(
|
||||
enabled_toolsets=["full"],
|
||||
disabled_toolsets=None,
|
||||
tool_server_url=cfg.tool_server_url,
|
||||
)
|
||||
|
||||
backend = NomadToolBackend(
|
||||
NomadBackendConfig(
|
||||
nomad_address=cfg.nomad_address,
|
||||
sandbox_job_id=cfg.job_id,
|
||||
sandbox_image=cfg.image,
|
||||
slots_per_container=cfg.slots_per_container,
|
||||
min_containers=cfg.min_containers,
|
||||
max_containers=cfg.max_containers,
|
||||
privileged=cfg.privileged,
|
||||
acquire_timeout_s=cfg.acquire_timeout_s,
|
||||
purge_job_on_start=False,
|
||||
)
|
||||
)
|
||||
await backend.start()
|
||||
|
||||
executor = ToolExecutor(
|
||||
backend=backend,
|
||||
tools=tools,
|
||||
config=ToolExecutorConfig(
|
||||
batch_window_ms=cfg.batch_window_ms,
|
||||
max_batch_size=cfg.max_batch_size,
|
||||
allow_network=cfg.allow_network,
|
||||
tool_server_url=cfg.tool_server_url,
|
||||
tool_server_token=cfg.tool_server_token,
|
||||
),
|
||||
)
|
||||
await executor.start()
|
||||
|
||||
app.state.cfg = cfg
|
||||
app.state.backend = backend
|
||||
app.state.executor = executor
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def _shutdown() -> None:
|
||||
executor: Optional[ToolExecutor] = getattr(app.state, "executor", None)
|
||||
backend: Optional[NomadToolBackend] = getattr(app.state, "backend", None)
|
||||
cfg: Optional[ToolExecutorServerConfig] = getattr(app.state, "cfg", None)
|
||||
|
||||
if executor is not None:
|
||||
await executor.close()
|
||||
|
||||
if backend is not None:
|
||||
await backend.stop(purge=bool(cfg.purge_job_on_shutdown) if cfg else False)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> Dict[str, Any]:
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.get("/status")
|
||||
async def status_endpoint() -> Dict[str, Any]:
|
||||
executor: ToolExecutor = app.state.executor
|
||||
backend: NomadToolBackend = app.state.backend
|
||||
|
||||
return {
|
||||
"queue_size": executor.queue_size(),
|
||||
"total_requests": executor.total_requests,
|
||||
"total_errors": executor.total_errors,
|
||||
"pool": backend.get_stats(),
|
||||
}
|
||||
|
||||
|
||||
@app.post("/execute", response_model=ToolResultPayload)
|
||||
async def execute_tool(
|
||||
req: ToolExecutorExecuteRequest,
|
||||
authorization: Optional[str] = Header(default=None),
|
||||
status_code: int = status.HTTP_200_OK, # noqa: B008
|
||||
) -> ToolResultPayload:
|
||||
cfg: ToolExecutorServerConfig = app.state.cfg
|
||||
_check_auth(cfg, authorization)
|
||||
|
||||
executor: ToolExecutor = app.state.executor
|
||||
result = await executor.execute(
|
||||
trajectory_id=req.trajectory_id,
|
||||
call=req.tool.to_tool_call(),
|
||||
timeout_s=req.timeout_s,
|
||||
)
|
||||
return ToolResultPayload.from_tool_result(result)
|
||||
|
||||
|
||||
@app.post("/release")
|
||||
async def release_trajectory(
|
||||
req: ToolExecutorReleaseRequest,
|
||||
authorization: Optional[str] = Header(default=None),
|
||||
) -> Dict[str, Any]:
|
||||
cfg: ToolExecutorServerConfig = app.state.cfg
|
||||
_check_auth(cfg, authorization)
|
||||
|
||||
executor: ToolExecutor = app.state.executor
|
||||
await executor.release_trajectory(req.trajectory_id, reset_workspace=req.reset_workspace)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.post("/artifacts/read", response_model=ArtifactReadResponsePayload)
|
||||
async def artifacts_read(
|
||||
req: ArtifactReadRequestPayload,
|
||||
authorization: Optional[str] = Header(default=None),
|
||||
) -> ArtifactReadResponsePayload:
|
||||
cfg: ToolExecutorServerConfig = app.state.cfg
|
||||
_check_auth(cfg, authorization)
|
||||
|
||||
executor: ToolExecutor = app.state.executor
|
||||
return await executor.read_artifact(req)
|
||||
|
||||
|
||||
@app.post("/artifacts/list", response_model=ArtifactListResponsePayload)
|
||||
async def artifacts_list(
|
||||
req: ArtifactListRequestPayload,
|
||||
authorization: Optional[str] = Header(default=None),
|
||||
) -> ArtifactListResponsePayload:
|
||||
cfg: ToolExecutorServerConfig = app.state.cfg
|
||||
_check_auth(cfg, authorization)
|
||||
|
||||
executor: ToolExecutor = app.state.executor
|
||||
return await executor.list_artifacts(req)
|
||||
|
||||
|
||||
@app.post("/artifacts/archive", response_model=ArtifactArchiveResponsePayload)
|
||||
async def artifacts_archive(
|
||||
req: ArtifactArchiveRequestPayload,
|
||||
authorization: Optional[str] = Header(default=None),
|
||||
) -> ArtifactArchiveResponsePayload:
|
||||
cfg: ToolExecutorServerConfig = app.state.cfg
|
||||
_check_auth(cfg, authorization)
|
||||
|
||||
executor: ToolExecutor = app.state.executor
|
||||
return await executor.archive_artifacts(req)
|
||||
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
External ToolServer (Phase 4.5+).
|
||||
|
||||
This server executes tools that must NOT run inside the sandbox, typically
|
||||
because they require credentials or access to external services.
|
||||
|
||||
Run (dev):
|
||||
uv run uvicorn atropos_agent.api.tool_server:app --host 0.0.0.0 --port 9002
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, Header, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..tools import ToolRegistry, build_tool_registry
|
||||
from ..tools.base import ToolResultPayload, ToolServerExecuteRequest
|
||||
|
||||
|
||||
class ToolServerConfig(BaseModel):
|
||||
token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Bearer token required for requests (optional in dev).",
|
||||
)
|
||||
max_concurrency: int = Field(default=16, ge=1, description="Max concurrent tool executions.")
|
||||
|
||||
@classmethod
|
||||
def from_env(cls) -> "ToolServerConfig":
|
||||
# In dev, prefer loading secrets from the repo-local `.env` (not committed).
|
||||
try:
|
||||
from dotenv import load_dotenv # type: ignore
|
||||
except Exception: # pragma: no cover
|
||||
load_dotenv = None # type: ignore[assignment]
|
||||
if load_dotenv is not None:
|
||||
env_path = Path(__file__).resolve().parents[2] / ".env"
|
||||
if env_path.exists():
|
||||
load_dotenv(dotenv_path=env_path)
|
||||
|
||||
token = os.getenv("TOOL_SERVER_TOKEN") or None
|
||||
max_concurrency = int(os.getenv("TOOL_SERVER_MAX_CONCURRENCY", "16"))
|
||||
return cls(token=token, max_concurrency=max_concurrency)
|
||||
|
||||
|
||||
app = FastAPI(title="Atropos-Agent Tool Server")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def root() -> Dict[str, str]:
|
||||
return {"message": "Atropos-Agent Tool Server"}
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def _startup() -> None:
|
||||
cfg = ToolServerConfig.from_env()
|
||||
|
||||
# External-only registry. It will only include tools that are enabled by toolsets and
|
||||
# whose Hermes requirements/keys are satisfied in this process.
|
||||
tools: ToolRegistry = build_tool_registry(
|
||||
enabled_toolsets=["all"],
|
||||
disabled_toolsets=["terminal", "sandbox", "filesystem", "terminal_stateful", "default"],
|
||||
tool_server_url="enabled",
|
||||
)
|
||||
|
||||
app.state.cfg = cfg
|
||||
app.state.tools = tools
|
||||
app.state.semaphore = asyncio.Semaphore(cfg.max_concurrency)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> Dict[str, Any]:
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.get("/tools")
|
||||
async def list_tools() -> Dict[str, Any]:
|
||||
tools: ToolRegistry = app.state.tools
|
||||
return {"tools": [s.to_dict() for s in tools.get_schemas()]}
|
||||
|
||||
|
||||
def _check_auth(cfg: ToolServerConfig, authorization: Optional[str]) -> None:
|
||||
if not cfg.token:
|
||||
return
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing Authorization header")
|
||||
if not authorization.lower().startswith("bearer "):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Authorization header")
|
||||
token = authorization.split(" ", 1)[1].strip()
|
||||
if token != cfg.token:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid token")
|
||||
|
||||
|
||||
@app.post("/execute", response_model=ToolResultPayload)
|
||||
async def execute_tool(
|
||||
req: ToolServerExecuteRequest,
|
||||
authorization: Optional[str] = Header(default=None),
|
||||
) -> ToolResultPayload:
|
||||
cfg: ToolServerConfig = app.state.cfg
|
||||
_check_auth(cfg, authorization)
|
||||
|
||||
tools: ToolRegistry = app.state.tools
|
||||
sem: asyncio.Semaphore = app.state.semaphore
|
||||
|
||||
tool = tools.get(req.tool.name)
|
||||
if tool is None:
|
||||
return ToolResultPayload(
|
||||
success=False,
|
||||
error=f"Unknown tool: {req.tool.name}",
|
||||
uniq_id=req.tool.uniq_id,
|
||||
)
|
||||
|
||||
async with sem:
|
||||
try:
|
||||
kwargs = dict(req.tool.arguments)
|
||||
sig = inspect.signature(tool.execute).parameters
|
||||
# Some tools can benefit from extra context.
|
||||
if req.trajectory_id and "trajectory_id" in sig:
|
||||
kwargs["trajectory_id"] = req.trajectory_id
|
||||
if req.slot_id and "slot_id" in sig:
|
||||
kwargs["slot_id"] = req.slot_id
|
||||
if req.container_addr and "container_addr" in sig:
|
||||
kwargs["container_addr"] = req.container_addr
|
||||
if "task_id" in sig:
|
||||
kwargs["task_id"] = req.trajectory_id
|
||||
result = await tool.execute(**kwargs)
|
||||
except Exception as e:
|
||||
return ToolResultPayload(
|
||||
success=False,
|
||||
error=f"Tool execution error: {e}",
|
||||
uniq_id=req.tool.uniq_id,
|
||||
)
|
||||
|
||||
if result.uniq_id is None:
|
||||
result.uniq_id = req.tool.uniq_id
|
||||
return ToolResultPayload.from_tool_result(result)
|
||||
@@ -0,0 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .base import ToolBackend
|
||||
from .modal_backend import ModalSandboxConfig, ModalToolBackend
|
||||
from .nomad_backend import NomadBackendConfig, NomadToolBackend
|
||||
|
||||
|
||||
def create_tool_backend(cfg: Any) -> ToolBackend:
|
||||
mode = str(getattr(cfg, "tool_pool_mode", "nomad")).strip().lower()
|
||||
if mode == "nomad":
|
||||
return NomadToolBackend(NomadBackendConfig.from_agent_env_config(cfg))
|
||||
if mode == "modal":
|
||||
return ModalToolBackend(ModalSandboxConfig.from_agent_env_config(cfg))
|
||||
raise ValueError(f"Unknown tool_pool_mode: {mode}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ToolBackend",
|
||||
"create_tool_backend",
|
||||
"NomadBackendConfig",
|
||||
"NomadToolBackend",
|
||||
"ModalSandboxConfig",
|
||||
"ModalToolBackend",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
Backend interfaces for AgentEnv tool execution.
|
||||
|
||||
The goal of this module is to decouple ToolExecutor / AgentEnv from any single
|
||||
execution backend (Nomad/Docker today; Modal later).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Protocol, Tuple
|
||||
|
||||
from ..slots.executor import ExecutionResult
|
||||
from ..slots.slot import Slot
|
||||
|
||||
|
||||
class ToolBackend(Protocol):
|
||||
"""
|
||||
Minimal interface required by ToolExecutor.
|
||||
|
||||
Backends provide:
|
||||
- lifecycle (start/stop)
|
||||
- slot acquisition/release (workspace affinity)
|
||||
- batched tool execution across slots
|
||||
- optional artifact helpers (for env verification / demos)
|
||||
"""
|
||||
|
||||
@property
|
||||
def default_timeout_s(self) -> Optional[float]:
|
||||
"""Default sandbox execution timeout in seconds (if any)."""
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the backend (provision workers/containers, health checks, etc)."""
|
||||
|
||||
async def stop(self, *, purge: bool = False) -> None:
|
||||
"""Stop the backend and optionally purge remote resources."""
|
||||
|
||||
async def acquire(self, trajectory_id: Optional[str] = None) -> Slot:
|
||||
"""Acquire a slot for a trajectory (workspace affinity)."""
|
||||
|
||||
async def release(self, slot: Slot, *, reset_workspace: bool = False) -> None:
|
||||
"""Release a slot back to the pool."""
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
requests: List[Tuple[Slot, str, Dict[str, Any]]],
|
||||
*,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> List[ExecutionResult]:
|
||||
"""Execute a batch of sandbox tool calls and return results in order."""
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Optional artifact helpers (supported by the Nomad sandbox-server today)
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
async def read_artifact(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str,
|
||||
*,
|
||||
encoding: str = "text",
|
||||
max_bytes: Optional[int] = None,
|
||||
include_sha256: bool = False,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def list_artifacts(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str = ".",
|
||||
*,
|
||||
recursive: bool = False,
|
||||
max_entries: Optional[int] = None,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def archive_artifacts(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str = ".",
|
||||
*,
|
||||
archive_format: str = "tar.gz",
|
||||
max_bytes: Optional[int] = None,
|
||||
max_entries: Optional[int] = None,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Nomad/Docker tool backend.
|
||||
|
||||
This backend is the current default for AgentEnv: it provisions a Nomad job
|
||||
running `sandbox_server.py` and multiplexes stateless slots inside each container.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from ..slots import Slot, SlotPool, SlotPoolConfig
|
||||
from ..slots.executor import ExecutionResult
|
||||
from .base import ToolBackend
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NomadBackendConfig:
|
||||
nomad_address: str
|
||||
sandbox_job_id: str
|
||||
sandbox_image: str
|
||||
slots_per_container: int
|
||||
min_containers: int
|
||||
max_containers: int
|
||||
privileged: bool
|
||||
acquire_timeout_s: float
|
||||
purge_job_on_start: bool
|
||||
# Driver selection: "docker" or "singularity"
|
||||
driver: str = "docker"
|
||||
# Path to .sif file for singularity driver (required if driver="singularity")
|
||||
singularity_image: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_agent_env_config(cls, cfg: Any) -> "NomadBackendConfig":
|
||||
return cls(
|
||||
nomad_address=str(getattr(cfg, "nomad_address")),
|
||||
sandbox_job_id=str(getattr(cfg, "sandbox_job_id")),
|
||||
sandbox_image=str(getattr(cfg, "sandbox_image")),
|
||||
slots_per_container=int(getattr(cfg, "slots_per_container")),
|
||||
min_containers=int(getattr(cfg, "min_containers")),
|
||||
max_containers=int(getattr(cfg, "max_containers")),
|
||||
privileged=bool(getattr(cfg, "privileged")),
|
||||
acquire_timeout_s=float(getattr(cfg, "acquire_timeout_s")),
|
||||
purge_job_on_start=bool(getattr(cfg, "purge_job_on_start", False)),
|
||||
driver=str(getattr(cfg, "driver", "docker")),
|
||||
singularity_image=getattr(cfg, "singularity_image", None),
|
||||
)
|
||||
|
||||
|
||||
class NomadToolBackend(ToolBackend):
|
||||
def __init__(self, config: NomadBackendConfig):
|
||||
self.config = config
|
||||
self.pool = SlotPool(
|
||||
SlotPoolConfig(
|
||||
nomad_address=config.nomad_address,
|
||||
job_id=config.sandbox_job_id,
|
||||
image=config.sandbox_image,
|
||||
slots_per_container=config.slots_per_container,
|
||||
min_containers=config.min_containers,
|
||||
max_containers=config.max_containers,
|
||||
privileged=config.privileged,
|
||||
acquire_timeout=config.acquire_timeout_s,
|
||||
purge_job_on_start=bool(config.purge_job_on_start),
|
||||
driver=config.driver,
|
||||
singularity_image=config.singularity_image,
|
||||
)
|
||||
)
|
||||
|
||||
@property
|
||||
def default_timeout_s(self) -> Optional[float]:
|
||||
t = getattr(self.pool.executor, "timeout", None)
|
||||
total = getattr(t, "total", None)
|
||||
try:
|
||||
return float(total) if total is not None else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def start(self) -> None:
|
||||
await self.pool.start()
|
||||
|
||||
async def stop(self, *, purge: bool = False) -> None:
|
||||
await self.pool.stop(purge_job=purge)
|
||||
|
||||
async def acquire(self, trajectory_id: Optional[str] = None) -> Slot:
|
||||
return await self.pool.acquire(trajectory_id)
|
||||
|
||||
async def release(self, slot: Slot, *, reset_workspace: bool = False) -> None:
|
||||
await self.pool.release(slot, reset_workspace=reset_workspace)
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
requests: List[Tuple[Slot, str, Dict[str, Any]]],
|
||||
*,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> List[ExecutionResult]:
|
||||
return await self.pool.execute_batch(requests, timeout=timeout_s)
|
||||
|
||||
async def read_artifact(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str,
|
||||
*,
|
||||
encoding: str = "text",
|
||||
max_bytes: Optional[int] = None,
|
||||
include_sha256: bool = False,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
return await self.pool.executor.read_artifact(
|
||||
slot,
|
||||
path,
|
||||
encoding=encoding,
|
||||
max_bytes=max_bytes,
|
||||
include_sha256=include_sha256,
|
||||
timeout=timeout_s,
|
||||
)
|
||||
|
||||
async def list_artifacts(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str = ".",
|
||||
*,
|
||||
recursive: bool = False,
|
||||
max_entries: Optional[int] = None,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
return await self.pool.executor.list_artifacts(
|
||||
slot,
|
||||
path,
|
||||
recursive=recursive,
|
||||
max_entries=max_entries,
|
||||
timeout=timeout_s,
|
||||
)
|
||||
|
||||
async def archive_artifacts(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str = ".",
|
||||
*,
|
||||
archive_format: str = "tar.gz",
|
||||
max_bytes: Optional[int] = None,
|
||||
max_entries: Optional[int] = None,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
return await self.pool.executor.archive_artifacts(
|
||||
slot,
|
||||
path,
|
||||
archive_format=archive_format,
|
||||
max_bytes=max_bytes,
|
||||
max_entries=max_entries,
|
||||
timeout=timeout_s,
|
||||
)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
return self.pool.get_stats()
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
Environment implementations for atropos-agent.
|
||||
"""
|
||||
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
# NOTE: Additional example envs exist as modules (e.g. `test_env`, `swe_smith_oracle_env`),
|
||||
# but are intentionally not imported here to avoid pulling heavy optional deps at import time.
|
||||
|
||||
__all__ = ["AgentEnv", "AgentEnvConfig"]
|
||||
@@ -0,0 +1,537 @@
|
||||
"""
|
||||
AgentEnv - Atropos BaseEnv extension for agent/tool-call workloads.
|
||||
|
||||
AgentEnv is responsible for starting the sandbox tool execution backend and
|
||||
providing helpers for running agent trajectories with queued/batched tool calls.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Tuple, TypeVar
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, Item, ScoredDataGroup, ScoredDataItem
|
||||
from atroposlib.envs.server_handling.server_baseline import AsyncSemWithAdaptiveWeight
|
||||
|
||||
from ..agent import AgentConfig, AgentResult, AtroposAgent
|
||||
from ..backends import ToolBackend, create_tool_backend
|
||||
from ..tools import ToolRegistry, build_tool_registry
|
||||
from ..tools.tool_executor import ToolExecutor, ToolExecutorConfig
|
||||
|
||||
# Main BaseEnv child classes. Child class THESE to get agent+tooling functionality easily.
|
||||
|
||||
class AgentEnvConfig(BaseEnvConfig):
|
||||
tool_pool_mode: str = Field(default="nomad", description="Tool execution backend ('nomad' or 'modal')")
|
||||
|
||||
allow_network: bool = Field(
|
||||
default=True,
|
||||
description="Whether sandbox bash commands may access the network (env policy).",
|
||||
)
|
||||
require_sandbox: bool = Field(
|
||||
default=False,
|
||||
description="Fail closed if bubblewrap sandboxing is unavailable/unusable for stateless sandbox tools.",
|
||||
)
|
||||
require_stateful_sandbox: bool = Field(
|
||||
default=False,
|
||||
description="Fail closed if bubblewrap/PID isolation is unavailable for stateful terminal tools (tmux).",
|
||||
)
|
||||
tool_batch_window_ms: int = Field(default=20, description="ToolExecutor batching window (ms)")
|
||||
tool_max_batch_size: int = Field(default=200, description="ToolExecutor maximum batch size")
|
||||
|
||||
# nomad mode settings. TODO: Add Modal support, split this into own config
|
||||
nomad_address: str = Field(default="http://localhost:4646", description="Nomad API address")
|
||||
sandbox_job_id: str = Field(default="atropos-sandbox-agent-env", description="Nomad job id for sandbox containers")
|
||||
sandbox_image: str = Field(default="atropos-sandbox:local", description="Docker image for sandbox containers")
|
||||
slots_per_container: int = Field(default=10, description="Nomad mode: slots per container")
|
||||
min_containers: int = Field(default=1, description="Nomad mode: minimum containers")
|
||||
max_containers: int = Field(default=10, description="Nomad mode: maximum containers")
|
||||
privileged: bool = Field(default=False, description="Nomad mode: run container privileged")
|
||||
acquire_timeout_s: float = Field(default=30.0, description="Slot acquisition timeout (seconds)")
|
||||
purge_job_on_start: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Nomad mode: stop/purge the sandbox job on startup. This is helpful in local dev and training runs "
|
||||
"to recover from previous crashes that leave the job in a restart backoff state."
|
||||
),
|
||||
)
|
||||
purge_job_on_shutdown: bool = Field(default=True, description="Nomad mode: stop/purge job on shutdown")
|
||||
|
||||
# Nomad driver selection (docker or singularity)
|
||||
driver: str = Field(
|
||||
default="docker",
|
||||
description="Nomad task driver: 'docker' (default) or 'singularity' (for HPC without sudo Docker)",
|
||||
)
|
||||
singularity_image: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Path to .sif file for Singularity driver (required if driver='singularity')",
|
||||
)
|
||||
|
||||
# Modal mode settings
|
||||
modal_app_name: str = Field(default="atropos-sandbox", description="Modal app name prefix")
|
||||
modal_image: str = Field(default="python:3.11", description="Modal: container image")
|
||||
modal_gpu: Optional[str] = Field(default=None, description="Modal: GPU type (None, 'T4', 'A10G', 'A100', 'H100')")
|
||||
modal_cpu: float = Field(default=1.0, description="Modal: CPU cores")
|
||||
modal_memory: int = Field(default=2048, description="Modal: memory in MB")
|
||||
modal_slots_per_sandbox: int = Field(default=10, description="Modal: slots per sandbox")
|
||||
modal_min_sandboxes: int = Field(default=1, description="Modal: minimum sandboxes")
|
||||
modal_max_sandboxes: int = Field(default=5, description="Modal: maximum sandboxes")
|
||||
modal_idle_timeout: int = Field(default=120, description="Modal: server-side idle timeout (seconds)")
|
||||
modal_max_lifetime: int = Field(default=3600, description="Modal: max sandbox lifetime (seconds)")
|
||||
modal_acquire_timeout: float = Field(default=60.0, description="Modal: slot acquisition timeout (seconds)")
|
||||
modal_execution_timeout: float = Field(default=30.0, description="Modal: default command execution timeout (seconds)")
|
||||
modal_secrets: str = Field(default="", description="Modal: comma-separated list of Modal Secret names")
|
||||
modal_env_vars: str = Field(default="", description="Modal: semicolon-separated KEY=VALUE pairs for env vars")
|
||||
modal_workspace_base: str = Field(default="/data", description="Modal: workspace base directory in sandbox")
|
||||
|
||||
# basic agent defaults
|
||||
agent_max_steps: int = Field(default=50, description="Max ReACT steps per trajectory")
|
||||
agent_temperature: float = Field(default=0.7, description="Sampling temperature")
|
||||
agent_max_tokens: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Max tokens per model response (default: let backend decide)",
|
||||
)
|
||||
agent_tool_delay_s: float = Field(default=0.0, description="Delay between tool calls (seconds)")
|
||||
|
||||
# tool selection
|
||||
enabled_toolsets: List[str] = Field(
|
||||
default_factory=lambda: ["default"],
|
||||
description="Toolsets to enable (Hermes-style grouping).",
|
||||
)
|
||||
disabled_toolsets: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Toolsets to disable (applied after enabled_toolsets).",
|
||||
)
|
||||
|
||||
# external ToolServer routing (Phase 4.5+)
|
||||
tool_server_url: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Base URL for external ToolServer (enables external tools).",
|
||||
)
|
||||
tool_server_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Bearer token for ToolServer auth (optional in dev).",
|
||||
)
|
||||
|
||||
AgentEnvConfigT = TypeVar("AgentEnvConfigT", bound="AgentEnvConfig")
|
||||
|
||||
|
||||
class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
|
||||
env_config_cls = AgentEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AgentEnvConfigT,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.config: AgentEnvConfigT = config
|
||||
|
||||
self.tools: ToolRegistry = self.build_tools()
|
||||
|
||||
self._backend: Optional[ToolBackend] = None
|
||||
self._tool_executor: Optional[ToolExecutor] = None
|
||||
self._tool_server_inprocess: bool = False
|
||||
self._trajectory_workspace_meta: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def build_tools(self) -> ToolRegistry:
|
||||
"""Wraps original Hermes-Agent ToolRegistry for atropos AgentEnv use.
|
||||
See Hermes-Agent docs for toolsets and available tools etc.
|
||||
"""
|
||||
return build_tool_registry(
|
||||
enabled_toolsets=self.config.enabled_toolsets or ["default"],
|
||||
disabled_toolsets=self.config.disabled_toolsets or None,
|
||||
tool_server_url=self.config.tool_server_url,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def build_task(self, item: Item) -> str:
|
||||
"""Return the user-facing task string for the agent."""
|
||||
|
||||
@abstractmethod
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
"""Return a scalar score for this trajectory."""
|
||||
|
||||
async def setup_trajectory_workspace(
|
||||
self,
|
||||
item: Item,
|
||||
*,
|
||||
trajectory_id: str,
|
||||
exec_tool: Callable[["ToolCall"], Awaitable["ToolResult"]],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Optional hook: prepare the sandbox workspace before the agent starts.
|
||||
|
||||
Examples:
|
||||
- clone a repo and checkout a commit
|
||||
- write fixture files (e.g. images) for external-tool demos
|
||||
- pre-install dependencies
|
||||
|
||||
Default: no-op.
|
||||
"""
|
||||
_ = (item, trajectory_id, exec_tool)
|
||||
return {}
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
final_response: str,
|
||||
*,
|
||||
trajectory_id: str,
|
||||
exec_tool: Callable[["ToolCall"], Awaitable["ToolResult"]],
|
||||
agent_result: Optional[AgentResult] = None,
|
||||
workspace_meta: Optional[Dict[str, Any]] = None,
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
"""
|
||||
Optional hook: run in-sandbox verification before scoring.
|
||||
|
||||
Many agent envs need to execute verification inside the same trajectory
|
||||
workspace (e.g. pytest) before releasing/resetting the slot.
|
||||
|
||||
Default: calls `score_trajectory()` and returns empty metadata.
|
||||
"""
|
||||
_ = (trajectory_id, exec_tool, agent_result, workspace_meta) # default ignores in-workspace verification
|
||||
score = await self.score_trajectory(item, final_response)
|
||||
return score, {}
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
return AgentConfig(
|
||||
max_steps=self.config.agent_max_steps,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.agent_max_tokens,
|
||||
tool_delay_s=self.config.agent_tool_delay_s,
|
||||
)
|
||||
|
||||
async def setup(self) -> None:
|
||||
print(f"[AgentEnv] setup(): starting tool backend ({self.config.tool_pool_mode})", flush=True)
|
||||
await self._start_tool_backend()
|
||||
print("[AgentEnv] setup(): configuring server concurrency", flush=True)
|
||||
self._configure_server_concurrency()
|
||||
print("[AgentEnv] setup(): running env-specific setup_agent_env()", flush=True)
|
||||
await self.setup_agent_env()
|
||||
print("[AgentEnv] setup(): done", flush=True)
|
||||
|
||||
def _configure_server_concurrency(self) -> None:
|
||||
"""
|
||||
Ensure the LLM server concurrency isn't accidentally capped below `group_size`.
|
||||
|
||||
In `BaseEnv process` mode, groups are collected concurrently and if the underlying
|
||||
ServerManager/OpenAIServer semaphore is left at 1, we serialize inference even
|
||||
when `--env.group_size` is > 1.
|
||||
"""
|
||||
desired = int(getattr(self.config, "group_size", 1) or 1)
|
||||
if desired <= 1:
|
||||
return
|
||||
|
||||
servers = getattr(self.server, "servers", None)
|
||||
if not isinstance(servers, list) or not servers:
|
||||
return
|
||||
|
||||
for s in servers:
|
||||
sem = getattr(s, "sem", None)
|
||||
eval_sem = getattr(s, "eval_sem", None)
|
||||
# Only increase; never shrink.
|
||||
if sem is not None and getattr(sem, "max_val", 0) < desired:
|
||||
s.sem = AsyncSemWithAdaptiveWeight(desired)
|
||||
if hasattr(s, "config") and hasattr(s.config, "num_max_requests_at_once"):
|
||||
s.config.num_max_requests_at_once = desired
|
||||
if eval_sem is not None and getattr(eval_sem, "max_val", 0) < desired:
|
||||
s.eval_sem = AsyncSemWithAdaptiveWeight(desired)
|
||||
if hasattr(s, "config") and hasattr(s.config, "num_requests_for_eval"):
|
||||
s.config.num_requests_for_eval = desired
|
||||
|
||||
@abstractmethod
|
||||
async def setup_agent_env(self) -> None:
|
||||
"""Subclass hook for env-specific setup."""
|
||||
|
||||
async def evaluate(self, *args, **kwargs): # noqa: ARG002
|
||||
"""
|
||||
Default eval hook (no-op).
|
||||
|
||||
Atropos BaseEnv requires an `evaluate()` implementation. Many agent envs
|
||||
won't have a meaningful evaluation path during early PoC work; they can
|
||||
override this when needed.
|
||||
"""
|
||||
return {}
|
||||
|
||||
async def env_manager(self):
|
||||
try:
|
||||
return await super().env_manager()
|
||||
finally:
|
||||
await self.shutdown_tool_backend()
|
||||
|
||||
async def process_manager(self):
|
||||
try:
|
||||
return await super().process_manager()
|
||||
finally:
|
||||
await self.shutdown_tool_backend()
|
||||
|
||||
async def _start_tool_backend(self) -> None:
|
||||
if self._tool_executor is not None:
|
||||
return
|
||||
|
||||
tool_server_url = self.config.tool_server_url
|
||||
tool_server_client = None
|
||||
if tool_server_url == "inprocess":
|
||||
import httpx
|
||||
from ..api.tool_server import app as tool_server_app
|
||||
|
||||
await tool_server_app.router.startup()
|
||||
tool_server_client = httpx.AsyncClient(
|
||||
transport=httpx.ASGITransport(app=tool_server_app),
|
||||
base_url="http://toolserver",
|
||||
)
|
||||
tool_server_url = "http://toolserver"
|
||||
self._tool_server_inprocess = True
|
||||
|
||||
backend = create_tool_backend(self.config)
|
||||
await backend.start()
|
||||
|
||||
executor = ToolExecutor(
|
||||
backend=backend,
|
||||
tools=self.tools,
|
||||
config=ToolExecutorConfig(
|
||||
batch_window_ms=self.config.tool_batch_window_ms,
|
||||
max_batch_size=self.config.tool_max_batch_size,
|
||||
allow_network=self.config.allow_network,
|
||||
require_sandbox=self.config.require_sandbox,
|
||||
require_stateful_sandbox=self.config.require_stateful_sandbox,
|
||||
tool_server_url=tool_server_url,
|
||||
tool_server_token=self.config.tool_server_token,
|
||||
),
|
||||
)
|
||||
await executor.start()
|
||||
if tool_server_client is not None:
|
||||
executor._tool_server_client = tool_server_client # type: ignore[attr-defined]
|
||||
|
||||
self._backend = backend
|
||||
self._tool_executor = executor
|
||||
|
||||
async def shutdown_tool_backend(self) -> None:
|
||||
executor = self._tool_executor
|
||||
backend = self._backend
|
||||
inprocess_tool_server = self._tool_server_inprocess
|
||||
self._tool_executor = None
|
||||
self._backend = None
|
||||
self._tool_server_inprocess = False
|
||||
|
||||
if executor is not None:
|
||||
await executor.close()
|
||||
if backend is not None:
|
||||
await backend.stop(purge=bool(self.config.purge_job_on_shutdown))
|
||||
if inprocess_tool_server:
|
||||
from ..api.tool_server import app as tool_server_app
|
||||
|
||||
await tool_server_app.router.shutdown()
|
||||
|
||||
async def collect_trajectory(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[ScoredDataItem], List[Item]]:
|
||||
if self._tool_executor is None:
|
||||
raise RuntimeError("Tool backend not started")
|
||||
|
||||
trajectory_id = str(uuid.uuid4())
|
||||
t0 = time.perf_counter()
|
||||
print(f"[AgentEnv] collect_trajectory(): tid={trajectory_id} start", flush=True)
|
||||
task = self.build_task(item)
|
||||
agent_config = self.build_agent_config(item)
|
||||
if os.getenv("ATROPOS_DEBUG_PRINT_TASK") == "1":
|
||||
print(f"Starting trajectory {trajectory_id} with task: {task}", flush=True)
|
||||
else:
|
||||
# Avoid printing the full task prompt by default (can be huge/noisy).
|
||||
one_line = " ".join(str(task).splitlines()).strip()
|
||||
preview = one_line[:240] + ("…" if len(one_line) > 240 else "")
|
||||
print(f"Starting trajectory {trajectory_id} (task preview): {preview}", flush=True)
|
||||
|
||||
async def _exec(call):
|
||||
return await self._tool_executor.execute(trajectory_id, call)
|
||||
|
||||
agent = AtroposAgent(
|
||||
server=self.server,
|
||||
tokenizer=self.tokenizer,
|
||||
tools=self.tools,
|
||||
config=agent_config,
|
||||
execute_tool=_exec,
|
||||
)
|
||||
|
||||
try:
|
||||
print(f"[AgentEnv] tid={trajectory_id} setup_trajectory_workspace() start", flush=True)
|
||||
workspace_meta = await self.setup_trajectory_workspace(item, trajectory_id=trajectory_id, exec_tool=_exec)
|
||||
if not isinstance(workspace_meta, dict):
|
||||
workspace_meta = {}
|
||||
self._trajectory_workspace_meta[trajectory_id] = workspace_meta
|
||||
print(
|
||||
f"[AgentEnv] tid={trajectory_id} setup_trajectory_workspace() done in {time.perf_counter() - t0:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
print(f"[AgentEnv] tid={trajectory_id} agent.run() start", flush=True)
|
||||
result = await agent.run(task)
|
||||
print(
|
||||
f"[AgentEnv] tid={trajectory_id} agent.run() done in {time.perf_counter() - t0:.2f}s "
|
||||
f"success={result.success} tool_calls={result.total_tool_calls}",
|
||||
flush=True,
|
||||
)
|
||||
if not result.success or result.trajectory_data is None:
|
||||
# Do not trigger BaseEnv retries for agent failures.
|
||||
# Record the trajectory with score 0.0 so training/eval can see the failure mode.
|
||||
messages = [{"role": "system", "content": agent._build_system_prompt()}] # noqa: SLF001
|
||||
messages.append({"role": "user", "content": task})
|
||||
for step in result.steps:
|
||||
messages.append({"role": "assistant", "content": step.assistant_message})
|
||||
if step.tool_results:
|
||||
tool_text = "\n".join(r.to_xml() for r in step.tool_results)
|
||||
messages.append({"role": "user", "content": tool_text})
|
||||
|
||||
scored: ScoredDataItem = {
|
||||
"tokens": (result.trajectory_data.tokens if result.trajectory_data else []),
|
||||
"masks": (result.trajectory_data.masked_tokens if result.trajectory_data else []),
|
||||
"scores": 0.0,
|
||||
}
|
||||
if result.trajectory_data is not None:
|
||||
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
|
||||
if getattr(result.trajectory_data, "metadata", None):
|
||||
scored["overrides"] = {"managed_metadata": result.trajectory_data.metadata}
|
||||
if self.config.include_messages:
|
||||
# Record a final failure marker as a user-side tool_response-like block so it survives templates.
|
||||
import json
|
||||
|
||||
err = result.error or "agent_failed"
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"<tool_response>{json.dumps({'success': False, 'error': err})}</tool_response>",
|
||||
}
|
||||
)
|
||||
scored["messages"] = messages
|
||||
return scored, []
|
||||
|
||||
print(f"[AgentEnv] tid={trajectory_id} verify_and_score_trajectory() start", flush=True)
|
||||
score, score_metadata = await self.verify_and_score_trajectory(
|
||||
item,
|
||||
result.final_response,
|
||||
trajectory_id=trajectory_id,
|
||||
exec_tool=_exec,
|
||||
agent_result=result,
|
||||
workspace_meta=workspace_meta,
|
||||
)
|
||||
print(
|
||||
f"[AgentEnv] tid={trajectory_id} verify_and_score_trajectory() done in {time.perf_counter() - t0:.2f}s "
|
||||
f"score={score}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
messages = [{"role": "system", "content": agent._build_system_prompt()}] # noqa: SLF001
|
||||
messages.append({"role": "user", "content": task})
|
||||
for step in result.steps:
|
||||
messages.append({"role": "assistant", "content": step.assistant_message})
|
||||
if step.tool_results:
|
||||
tool_text = "\n".join(r.to_xml() for r in step.tool_results)
|
||||
messages.append({"role": "user", "content": tool_text})
|
||||
|
||||
# Optional: allow env verification to attach additional messages (e.g. install logs).
|
||||
if self.config.include_messages and isinstance(score_metadata, dict):
|
||||
extra = score_metadata.get("verification_messages")
|
||||
if isinstance(extra, list):
|
||||
for m in extra:
|
||||
if isinstance(m, dict) and isinstance(m.get("role"), str) and isinstance(m.get("content"), str):
|
||||
messages.append({"role": m["role"], "content": m["content"]})
|
||||
|
||||
scored: ScoredDataItem = {
|
||||
"tokens": result.trajectory_data.tokens,
|
||||
"masks": result.trajectory_data.masked_tokens,
|
||||
"scores": score,
|
||||
}
|
||||
# Atroposlib expects policy logprobs at the *group* level under `inference_logprobs`.
|
||||
# We stash per-item values here and lift them into the group in `collect_trajectories()`.
|
||||
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
|
||||
if getattr(result.trajectory_data, "metadata", None):
|
||||
scored["overrides"] = {"managed_metadata": result.trajectory_data.metadata}
|
||||
if self.config.include_messages:
|
||||
scored["messages"] = messages
|
||||
|
||||
return scored, []
|
||||
finally:
|
||||
self._trajectory_workspace_meta.pop(trajectory_id, None)
|
||||
print(f"[AgentEnv] tid={trajectory_id} release_trajectory(reset_workspace=True)", flush=True)
|
||||
await self._tool_executor.release_trajectory(trajectory_id, reset_workspace=True)
|
||||
print(f"[AgentEnv] collect_trajectory(): tid={trajectory_id} done in {time.perf_counter() - t0:.2f}s", flush=True)
|
||||
|
||||
async def collect_trajectories(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[ScoredDataGroup], List[Item]]:
|
||||
tasks = [self.collect_trajectory(item) for _ in range(self.config.group_size)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
backlog: List[Item] = []
|
||||
items: List[ScoredDataItem] = []
|
||||
for scored, b in results:
|
||||
backlog.extend(b)
|
||||
if scored is not None:
|
||||
items.append(scored)
|
||||
|
||||
if len(items) != self.config.group_size:
|
||||
return None, backlog
|
||||
|
||||
group: ScoredDataGroup = ScoredDataGroup(
|
||||
tokens=[],
|
||||
masks=[],
|
||||
scores=[],
|
||||
advantages=[],
|
||||
ref_logprobs=[],
|
||||
messages=[] if self.config.include_messages else None,
|
||||
inference_logprobs=[],
|
||||
group_overrides={},
|
||||
overrides=[],
|
||||
images=[],
|
||||
generation_params=None,
|
||||
)
|
||||
|
||||
for it in items:
|
||||
group["tokens"].append(it["tokens"])
|
||||
group["masks"].append(it["masks"])
|
||||
group["scores"].append(it["scores"])
|
||||
# policy logprobs (for PPO/GRPO training) if present
|
||||
lp = it.get("inference_logprobs") # type: ignore[typeddict-item]
|
||||
if lp is not None:
|
||||
group["inference_logprobs"].append(lp)
|
||||
group["overrides"].append(it.get("overrides") or {}) # type: ignore[typeddict-item]
|
||||
if group.get("messages") is not None and it.get("messages") is not None:
|
||||
group["messages"].append(it["messages"])
|
||||
|
||||
return group, backlog
|
||||
|
||||
async def run_agent(self, task: str, *, trajectory_id: Optional[str] = None) -> Tuple[str, Dict[str, Any]]:
|
||||
"""
|
||||
Run the AtroposAgent on a single task and return (final_response, debug).
|
||||
|
||||
This is a helper intended for simple environments and tests.
|
||||
"""
|
||||
if self._tool_executor is None:
|
||||
raise RuntimeError("Tool backend not started")
|
||||
|
||||
tid = trajectory_id or str(uuid.uuid4())
|
||||
|
||||
async def _exec(call):
|
||||
return await self._tool_executor.execute(tid, call)
|
||||
|
||||
agent = AtroposAgent(
|
||||
server=self.server,
|
||||
tokenizer=self.tokenizer,
|
||||
tools=self.tools,
|
||||
config=AgentConfig(
|
||||
max_steps=self.config.agent_max_steps,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.agent_max_tokens,
|
||||
),
|
||||
execute_tool=_exec,
|
||||
)
|
||||
result = await agent.run(task)
|
||||
await self._tool_executor.release_trajectory(tid, reset_workspace=True)
|
||||
return result.final_response, {"success": result.success, "error": result.error, "tool_calls": result.total_tool_calls}
|
||||
@@ -0,0 +1,873 @@
|
||||
"""
|
||||
Endless Terminals Environment for Hermes-Agent + Atropos RL.
|
||||
|
||||
Runs terminal tasks from the Endless Terminals dataset.
|
||||
Supports three modes:
|
||||
1. Local directory: tasks from a local folder of task_* dirs (default)
|
||||
2. HuggingFace dataset: tasks from a HF dataset
|
||||
3. Procedural: generate tasks on-the-fly via LLM (requires vLLM)
|
||||
|
||||
Each task provides a Dockerfile that defines the initial environment.
|
||||
The agent solves the task using terminal commands inside a Docker container.
|
||||
Scoring is done by running pytest on `test_final_state.py` in the container.
|
||||
|
||||
Run (standalone process mode):
|
||||
python -m atropos.envs.endless_terminals_env process \
|
||||
--env.use_wandb false \
|
||||
--env.total_steps 100 \
|
||||
--env.group_size 4
|
||||
|
||||
Run (Tinker serve mode):
|
||||
# Terminal 1: run-api
|
||||
# Terminal 2: python launch_training.py --config configs/endless_terminals.yaml
|
||||
# Terminal 3:
|
||||
TINKER_CONFIG=configs/endless_terminals.yaml \
|
||||
ENDLESS_TERMINALS_DIR=/path/to/endless-terminals \
|
||||
python -m atropos.envs.endless_terminals_env serve
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, Item
|
||||
|
||||
from ..agent import AgentConfig
|
||||
from ..backends.docker_direct_backend import (
|
||||
DockerDirectBackend,
|
||||
build_docker_image,
|
||||
docker_image_exists,
|
||||
)
|
||||
from ..tools import ToolCall
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tinker integration
|
||||
# ---------------------------------------------------------------------------
|
||||
# When TINKER_CONFIG is set, we load model/training params from the Tinker YAML.
|
||||
# Custom env fields (ENDLESS_TERMINALS_DIR, etc.) are always read from env vars.
|
||||
TINKER_CONFIG = os.getenv("TINKER_CONFIG", "")
|
||||
|
||||
|
||||
def _load_tinker_config():
|
||||
"""Load TinkerAtroposConfig if available, else return None."""
|
||||
if not TINKER_CONFIG:
|
||||
return None
|
||||
config_path = Path(TINKER_CONFIG)
|
||||
if not config_path.exists():
|
||||
print(f"[EndlessTerminalsEnv] TINKER_CONFIG={TINKER_CONFIG} not found, ignoring", flush=True)
|
||||
return None
|
||||
try:
|
||||
from tinker_atropos.config import TinkerAtroposConfig
|
||||
config = TinkerAtroposConfig.from_yaml(config_path)
|
||||
print(f"[EndlessTerminalsEnv] Loaded Tinker config from {config_path}", flush=True)
|
||||
return config
|
||||
except ImportError:
|
||||
print("[EndlessTerminalsEnv] tinker_atropos not installed, ignoring TINKER_CONFIG", flush=True)
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"[EndlessTerminalsEnv] Error loading Tinker config: {e}", flush=True)
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class EndlessTerminalsEnvConfig(AgentEnvConfig):
|
||||
"""Configuration for Endless Terminals environment."""
|
||||
|
||||
# ---- Local directory mode (primary) ----
|
||||
use_local_dir: bool = Field(
|
||||
default=True,
|
||||
description="Load tasks from a local directory of task_* folders.",
|
||||
)
|
||||
local_tasks_dir: str = Field(
|
||||
default="",
|
||||
description="Path to directory containing task_* folders. Required if use_local_dir=True.",
|
||||
)
|
||||
prebuild_images: bool = Field(
|
||||
default=False,
|
||||
description="Pre-build ALL Docker images during setup (slow but avoids build-during-training).",
|
||||
)
|
||||
max_concurrent_builds: int = Field(
|
||||
default=4,
|
||||
description="Max parallel Docker image builds during pre-build.",
|
||||
)
|
||||
|
||||
# ---- HuggingFace dataset mode ----
|
||||
use_dataset: bool = Field(
|
||||
default=False,
|
||||
description="Load tasks from HuggingFace dataset.",
|
||||
)
|
||||
dataset_name: str = Field(
|
||||
default="obiwan96/endless-terminals-train",
|
||||
description="HuggingFace dataset name (if use_dataset=True)",
|
||||
)
|
||||
dataset_split: str = Field(default="train")
|
||||
dataset_cache_dir: str = Field(default="~/.cache/huggingface/datasets")
|
||||
tasks_base_dir: str = Field(
|
||||
default="",
|
||||
description="Base directory containing task_* folders (for dataset mode path resolution).",
|
||||
)
|
||||
|
||||
# ---- Procedural generation mode ----
|
||||
task_gen_model: str = Field(default="Qwen/Qwen3-32B")
|
||||
task_gen_temperature: float = Field(default=1.0)
|
||||
task_gen_max_tokens: int = Field(default=2048)
|
||||
|
||||
# ---- Container / scoring ----
|
||||
container_build_timeout_s: float = Field(default=600.0, description="Docker build timeout")
|
||||
test_timeout_s: int = Field(default=120, description="Test execution timeout (seconds)")
|
||||
keep_failed_tasks: bool = Field(default=False)
|
||||
|
||||
# ---- Agent defaults ----
|
||||
agent_max_steps: int = Field(default=32)
|
||||
agent_temperature: float = Field(default=0.7)
|
||||
|
||||
# ---- Docker image prefix ----
|
||||
docker_image_prefix: str = Field(
|
||||
default="endless-terminals",
|
||||
description="Docker image name prefix for built task images.",
|
||||
)
|
||||
|
||||
# ---- Server defaults ----
|
||||
server_base_url: str = Field(default="http://127.0.0.1:8080")
|
||||
server_model: str = Field(default="hermes-4-36b")
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Env
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class EndlessTerminalsEnv(AgentEnv[EndlessTerminalsEnvConfig]):
|
||||
"""
|
||||
Endless Terminals environment.
|
||||
|
||||
Each task:
|
||||
1. Has a Dockerfile defining the initial container state
|
||||
2. Has an instruction.md describing what the agent should do
|
||||
3. Has tests/test_final_state.py to verify completion
|
||||
|
||||
Flow per trajectory:
|
||||
1. get_next_item() → picks a task
|
||||
2. setup_trajectory_workspace() → builds Docker image, registers with backend
|
||||
3. Agent solves task via terminal commands (docker exec in the container)
|
||||
4. verify_and_score_trajectory() → runs pytest in container, returns binary reward
|
||||
"""
|
||||
|
||||
name = "endless_terminals_env"
|
||||
env_config_cls = EndlessTerminalsEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: EndlessTerminalsEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._iteration = 0
|
||||
|
||||
# Local dir mode
|
||||
self._local_tasks: List[Dict[str, Any]] = []
|
||||
self._local_task_indices: List[int] = []
|
||||
self._local_current_index = 0
|
||||
|
||||
# Eval split (held-out tasks)
|
||||
self._eval_tasks: List[Dict[str, Any]] = []
|
||||
|
||||
# Training metrics
|
||||
self._train_scores_buffer: List[float] = []
|
||||
self._eval_metrics: List[tuple] = []
|
||||
|
||||
# HF dataset mode
|
||||
self._dataset = None
|
||||
self._dataset_indices: List[int] = []
|
||||
self._dataset_current_index = 0
|
||||
|
||||
# Docker image cache: task_name -> image_tag
|
||||
self._image_cache: Dict[str, str] = {}
|
||||
self._build_lock = asyncio.Lock()
|
||||
|
||||
# ---- Config init (CLI) ----
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[EndlessTerminalsEnvConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Initialize config.
|
||||
|
||||
Two modes:
|
||||
1. Tinker mode: TINKER_CONFIG env var points to a Tinker YAML.
|
||||
Model, training params, and server config come from the YAML.
|
||||
2. Standalone mode: Everything from env vars (ATROPOS_SERVER_*, etc.)
|
||||
|
||||
In both modes, Endless Terminals-specific fields (ENDLESS_TERMINALS_DIR,
|
||||
PREBUILD_IMAGES, etc.) are always read from env vars.
|
||||
"""
|
||||
tinker_cfg = _load_tinker_config()
|
||||
|
||||
# ── Endless Terminals-specific fields (always from env vars) ──
|
||||
local_tasks_dir = os.getenv("ENDLESS_TERMINALS_DIR", "")
|
||||
use_local_dir = bool(local_tasks_dir)
|
||||
|
||||
if tinker_cfg is not None:
|
||||
# ── Tinker mode ─────────────────────────────────────────
|
||||
print("[EndlessTerminalsEnv] Using Tinker config", flush=True)
|
||||
|
||||
env_config = EndlessTerminalsEnvConfig(
|
||||
# Standard Atropos fields from Tinker YAML
|
||||
tokenizer_name=tinker_cfg.base_model,
|
||||
group_size=tinker_cfg.group_size,
|
||||
use_wandb=tinker_cfg.use_wandb,
|
||||
rollout_server_url=tinker_cfg.atropos_api_url,
|
||||
total_steps=tinker_cfg.num_steps,
|
||||
batch_size=tinker_cfg.batch_size,
|
||||
steps_per_eval=tinker_cfg.steps_per_eval,
|
||||
max_token_length=tinker_cfg.max_token_env_length,
|
||||
max_num_workers=tinker_cfg.max_num_workers,
|
||||
max_batches_offpolicy=tinker_cfg.max_batches_offpolicy,
|
||||
ensure_scores_are_not_same=tinker_cfg.ensure_scores_are_not_same,
|
||||
wandb_name=f"{tinker_cfg.wandb_run_name}-env",
|
||||
include_messages=True,
|
||||
|
||||
# Tooling: terminal only
|
||||
enabled_toolsets=["terminal"],
|
||||
disabled_toolsets=[],
|
||||
|
||||
# Agent config
|
||||
agent_max_steps=int(os.getenv("AGENT_MAX_STEPS", "32")),
|
||||
agent_temperature=float(os.getenv("AGENT_TEMPERATURE", "0.7")),
|
||||
|
||||
# Docker-direct backend (no Nomad needed)
|
||||
tool_pool_mode="docker_direct",
|
||||
sandbox_image="ubuntu:22.04",
|
||||
purge_job_on_start=False,
|
||||
purge_job_on_shutdown=False,
|
||||
|
||||
# Endless Terminals fields
|
||||
use_local_dir=use_local_dir,
|
||||
local_tasks_dir=local_tasks_dir,
|
||||
prebuild_images=os.getenv("PREBUILD_IMAGES", "false").lower() == "true",
|
||||
use_dataset=os.getenv("USE_DATASET", "false").lower() == "true",
|
||||
dataset_name=os.getenv("ENDLESS_DATASET", "obiwan96/endless-terminals-train"),
|
||||
container_build_timeout_s=float(os.getenv("CONTAINER_BUILD_TIMEOUT", "600")),
|
||||
test_timeout_s=int(os.getenv("TEST_TIMEOUT", "120")),
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=tinker_cfg.base_model,
|
||||
base_url=tinker_cfg.inference_api_url + "/v1",
|
||||
api_key="x",
|
||||
server_type="sglang",
|
||||
num_requests_for_eval=tinker_cfg.num_requests_for_eval,
|
||||
timeout=600, # Longer timeout for multi-step agent trajectories
|
||||
),
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
else:
|
||||
# ── Standalone mode (env vars) ──────────────────────────
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = (
|
||||
os.getenv("ATROPOS_SERVER_API_KEY")
|
||||
or os.getenv("NOUS_API_KEY")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or "local"
|
||||
)
|
||||
|
||||
env_config = EndlessTerminalsEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=int(os.getenv("ATROPOS_GROUP_SIZE", "4")),
|
||||
use_wandb=os.getenv("USE_WANDB", "false").lower() == "true",
|
||||
include_messages=True,
|
||||
total_steps=int(os.getenv("ATROPOS_TOTAL_STEPS", "1000")),
|
||||
batch_size=int(os.getenv("ATROPOS_BATCH_SIZE", "32")),
|
||||
server_base_url=base_url,
|
||||
server_model=model,
|
||||
|
||||
# Tooling
|
||||
enabled_toolsets=["terminal"],
|
||||
disabled_toolsets=[],
|
||||
|
||||
# Agent
|
||||
agent_max_steps=int(os.getenv("AGENT_MAX_STEPS", "32")),
|
||||
agent_temperature=float(os.getenv("AGENT_TEMPERATURE", "0.7")),
|
||||
|
||||
# Docker-direct backend
|
||||
tool_pool_mode="docker_direct",
|
||||
sandbox_image="ubuntu:22.04",
|
||||
purge_job_on_start=False,
|
||||
purge_job_on_shutdown=False,
|
||||
|
||||
# Endless Terminals fields
|
||||
use_local_dir=use_local_dir,
|
||||
local_tasks_dir=local_tasks_dir,
|
||||
prebuild_images=os.getenv("PREBUILD_IMAGES", "false").lower() == "true",
|
||||
use_dataset=os.getenv("USE_DATASET", "false").lower() == "true",
|
||||
dataset_name=os.getenv("ENDLESS_DATASET", "obiwan96/endless-terminals-train"),
|
||||
task_gen_model=os.getenv("TASK_GEN_MODEL", "Qwen/Qwen3-32B"),
|
||||
container_build_timeout_s=float(os.getenv("CONTAINER_BUILD_TIMEOUT", "600")),
|
||||
test_timeout_s=int(os.getenv("TEST_TIMEOUT", "120")),
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=f"{base_url.rstrip('/')}/v1",
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=int(os.getenv("MAX_CONCURRENT_REQUESTS", "4")),
|
||||
num_requests_for_eval=int(os.getenv("MAX_EVAL_REQUESTS", "4")),
|
||||
timeout=300,
|
||||
)
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
# ---- Setup ----
|
||||
|
||||
async def setup_agent_env(self) -> None:
|
||||
"""Env-specific setup: scan tasks and optionally pre-build images."""
|
||||
if self.config.use_local_dir:
|
||||
await self._setup_local_dir()
|
||||
elif self.config.use_dataset:
|
||||
await self._setup_hf_dataset()
|
||||
else:
|
||||
print("[EndlessTerminalsEnv] Using procedural task generation", flush=True)
|
||||
|
||||
async def _setup_local_dir(self) -> None:
|
||||
"""Scan local directory for task_* folders."""
|
||||
tasks_dir = Path(self.config.local_tasks_dir).expanduser().resolve()
|
||||
if not tasks_dir.is_dir():
|
||||
raise RuntimeError(f"local_tasks_dir does not exist: {tasks_dir}")
|
||||
|
||||
print(f"[EndlessTerminalsEnv] Scanning {tasks_dir} for tasks...", flush=True)
|
||||
|
||||
tasks = []
|
||||
for entry in sorted(tasks_dir.iterdir()):
|
||||
if not entry.is_dir() or not entry.name.startswith("task_"):
|
||||
continue
|
||||
|
||||
# Validate required files
|
||||
dockerfile = entry / "environment" / "Dockerfile"
|
||||
instruction = entry / "instruction.md"
|
||||
test_final = entry / "tests" / "test_final_state.py"
|
||||
|
||||
if not dockerfile.exists():
|
||||
continue
|
||||
if not instruction.exists():
|
||||
continue
|
||||
if not test_final.exists():
|
||||
continue
|
||||
|
||||
# Read task metadata
|
||||
task_json_path = entry / "environment" / "task.json"
|
||||
description = instruction.read_text(encoding="utf-8").strip()
|
||||
|
||||
truth = ""
|
||||
if task_json_path.exists():
|
||||
try:
|
||||
task_json = json.loads(task_json_path.read_text(encoding="utf-8"))
|
||||
# task.json may have a richer description; prefer instruction.md
|
||||
truth = task_json.get("truth", "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
tasks.append({
|
||||
"task_name": entry.name,
|
||||
"task_dir": str(entry),
|
||||
"dockerfile": str(dockerfile),
|
||||
"description": description,
|
||||
"truth": truth,
|
||||
"test_final": str(test_final),
|
||||
})
|
||||
|
||||
if not tasks:
|
||||
raise RuntimeError(f"No valid task_* directories found in {tasks_dir}")
|
||||
|
||||
# Split into train and eval (hold out ~5% for eval, min 10, max 50)
|
||||
random.shuffle(tasks)
|
||||
eval_count = max(10, min(50, len(tasks) // 20))
|
||||
eval_count = min(eval_count, len(tasks) // 2) # Never more than half
|
||||
|
||||
self._eval_tasks = tasks[:eval_count]
|
||||
self._local_tasks = tasks[eval_count:]
|
||||
self._local_task_indices = list(range(len(self._local_tasks)))
|
||||
random.shuffle(self._local_task_indices)
|
||||
self._local_current_index = 0
|
||||
|
||||
print(
|
||||
f"[EndlessTerminalsEnv] Found {len(tasks)} valid tasks "
|
||||
f"({len(self._local_tasks)} train, {len(self._eval_tasks)} eval)",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Optionally pre-build all Docker images
|
||||
if self.config.prebuild_images:
|
||||
await self._prebuild_images()
|
||||
|
||||
async def _prebuild_images(self) -> None:
|
||||
"""Pre-build Docker images for all tasks."""
|
||||
print(f"[EndlessTerminalsEnv] Pre-building Docker images...", flush=True)
|
||||
sem = asyncio.Semaphore(self.config.max_concurrent_builds)
|
||||
built = 0
|
||||
skipped = 0
|
||||
failed = 0
|
||||
|
||||
async def _build_one(task: Dict[str, Any]) -> None:
|
||||
nonlocal built, skipped, failed
|
||||
image_tag = self._image_tag_for_task(task["task_name"])
|
||||
|
||||
if docker_image_exists(image_tag):
|
||||
self._image_cache[task["task_name"]] = image_tag
|
||||
skipped += 1
|
||||
return
|
||||
|
||||
async with sem:
|
||||
ok = await build_docker_image(
|
||||
task["dockerfile"], image_tag,
|
||||
timeout_s=self.config.container_build_timeout_s,
|
||||
)
|
||||
if ok:
|
||||
self._image_cache[task["task_name"]] = image_tag
|
||||
built += 1
|
||||
else:
|
||||
failed += 1
|
||||
|
||||
await asyncio.gather(*[_build_one(t) for t in self._local_tasks])
|
||||
print(
|
||||
f"[EndlessTerminalsEnv] Pre-build: {built} built, {skipped} cached, {failed} failed",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
async def _setup_hf_dataset(self) -> None:
|
||||
"""Load HuggingFace dataset."""
|
||||
print(f"[EndlessTerminalsEnv] Loading dataset: {self.config.dataset_name}", flush=True)
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
self._dataset = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: load_dataset(
|
||||
self.config.dataset_name,
|
||||
split=self.config.dataset_split,
|
||||
cache_dir=os.path.expanduser(self.config.dataset_cache_dir),
|
||||
),
|
||||
)
|
||||
self._dataset_indices = list(range(len(self._dataset)))
|
||||
random.shuffle(self._dataset_indices)
|
||||
self._dataset_current_index = 0
|
||||
print(f"[EndlessTerminalsEnv] Loaded {len(self._dataset)} tasks from dataset", flush=True)
|
||||
except Exception as e:
|
||||
print(f"[EndlessTerminalsEnv] ERROR loading dataset: {e}", flush=True)
|
||||
raise
|
||||
|
||||
# ---- Image helpers ----
|
||||
|
||||
def _image_tag_for_task(self, task_name: str) -> str:
|
||||
return f"{self.config.docker_image_prefix}:{task_name}"
|
||||
|
||||
async def _ensure_image(self, task: Dict[str, Any]) -> str:
|
||||
"""Ensure the Docker image for a task is built. Returns image tag."""
|
||||
task_name = task["task_name"]
|
||||
image_tag = self._image_tag_for_task(task_name)
|
||||
|
||||
# Fast path: already cached
|
||||
if task_name in self._image_cache:
|
||||
return self._image_cache[task_name]
|
||||
|
||||
async with self._build_lock:
|
||||
# Double-check after acquiring lock
|
||||
if task_name in self._image_cache:
|
||||
return self._image_cache[task_name]
|
||||
|
||||
# Check if image exists in Docker
|
||||
if docker_image_exists(image_tag):
|
||||
self._image_cache[task_name] = image_tag
|
||||
return image_tag
|
||||
|
||||
# Build it
|
||||
print(f"[EndlessTerminalsEnv] Building image {image_tag}...", flush=True)
|
||||
ok = await build_docker_image(
|
||||
task["dockerfile"], image_tag,
|
||||
timeout_s=self.config.container_build_timeout_s,
|
||||
)
|
||||
if not ok:
|
||||
raise RuntimeError(f"Failed to build Docker image for {task_name}")
|
||||
|
||||
self._image_cache[task_name] = image_tag
|
||||
return image_tag
|
||||
|
||||
# ---- Item generation ----
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
self._iteration += 1
|
||||
|
||||
if self.config.use_local_dir and self._local_tasks:
|
||||
return self._get_next_local_item()
|
||||
elif self.config.use_dataset and self._dataset is not None:
|
||||
return self._get_next_dataset_item()
|
||||
else:
|
||||
return self._get_fallback_item()
|
||||
|
||||
def _get_next_local_item(self) -> Item:
|
||||
"""Pick the next task from local directories."""
|
||||
idx = self._local_task_indices[self._local_current_index]
|
||||
task = self._local_tasks[idx]
|
||||
|
||||
self._local_current_index += 1
|
||||
if self._local_current_index >= len(self._local_task_indices):
|
||||
random.shuffle(self._local_task_indices)
|
||||
self._local_current_index = 0
|
||||
print("[EndlessTerminalsEnv] Reshuffled local tasks (epoch complete)", flush=True)
|
||||
|
||||
return {
|
||||
"task_id": f"local_{self._iteration:06d}_{task['task_name']}",
|
||||
"task_name": task["task_name"],
|
||||
"description": task["description"],
|
||||
"truth": task.get("truth", ""),
|
||||
"task_dir": task["task_dir"],
|
||||
"dockerfile": task["dockerfile"],
|
||||
"test_final": task["test_final"],
|
||||
"from_local_dir": True,
|
||||
}
|
||||
|
||||
def _get_next_dataset_item(self) -> Item:
|
||||
"""Pick the next task from HuggingFace dataset."""
|
||||
idx = self._dataset_indices[self._dataset_current_index]
|
||||
task = self._dataset[idx]
|
||||
|
||||
self._dataset_current_index += 1
|
||||
if self._dataset_current_index >= len(self._dataset_indices):
|
||||
random.shuffle(self._dataset_indices)
|
||||
self._dataset_current_index = 0
|
||||
print("[EndlessTerminalsEnv] Reshuffled dataset (epoch complete)", flush=True)
|
||||
|
||||
# Resolve task directory
|
||||
task_dir = task.get("extra_info", {}).get("task_dir") or task.get("reward_spec", {}).get("ground_truth", "")
|
||||
if self.config.tasks_base_dir:
|
||||
task_name = Path(task_dir).name
|
||||
task_dir = str(Path(self.config.tasks_base_dir) / task_name)
|
||||
|
||||
task_dir_path = Path(task_dir)
|
||||
return {
|
||||
"task_id": f"dataset_{self._iteration:06d}_{task_dir_path.name}",
|
||||
"task_name": task_dir_path.name,
|
||||
"description": task.get("description", ""),
|
||||
"task_dir": task_dir,
|
||||
"dockerfile": str(task_dir_path / "environment" / "Dockerfile"),
|
||||
"test_final": str(task_dir_path / "tests" / "test_final_state.py"),
|
||||
"from_dataset": True,
|
||||
}
|
||||
|
||||
def _get_fallback_item(self) -> Item:
|
||||
return {
|
||||
"task_id": f"fallback_{self._iteration:06d}",
|
||||
"task_name": "fallback",
|
||||
"description": (
|
||||
"Create a file named 'hello.txt' in /home/user/ containing "
|
||||
"the text 'Hello, World!' on a single line."
|
||||
),
|
||||
"task_dir": "",
|
||||
"dockerfile": "",
|
||||
"test_final": "",
|
||||
}
|
||||
|
||||
# ---- AgentEnv hooks ----
|
||||
|
||||
def build_task(self, item: Item) -> str:
|
||||
"""Return the task prompt for the agent."""
|
||||
return str(item.get("description", ""))
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig:
|
||||
return AgentConfig(
|
||||
max_steps=self.config.agent_max_steps,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.agent_max_tokens,
|
||||
tool_delay_s=self.config.agent_tool_delay_s,
|
||||
)
|
||||
|
||||
async def setup_trajectory_workspace(
|
||||
self,
|
||||
item: Item,
|
||||
*,
|
||||
trajectory_id: str,
|
||||
exec_tool,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build the Docker image for this task and register it with the backend.
|
||||
|
||||
The DockerDirectBackend will start a container from this image when the
|
||||
agent makes its first tool call (lazy acquisition via ToolExecutor).
|
||||
"""
|
||||
task_name = item.get("task_name", "unknown")
|
||||
dockerfile = item.get("dockerfile", "")
|
||||
|
||||
if not dockerfile or not Path(dockerfile).exists():
|
||||
print(f"[EndlessTerminalsEnv] WARNING: No Dockerfile for {task_name}", flush=True)
|
||||
return {"image": "ubuntu:22.04"}
|
||||
|
||||
# Build/get Docker image
|
||||
image_tag = await self._ensure_image({
|
||||
"task_name": task_name,
|
||||
"dockerfile": dockerfile,
|
||||
})
|
||||
|
||||
# Register image with the DockerDirect backend
|
||||
if isinstance(self._backend, DockerDirectBackend):
|
||||
self._backend.register_image(trajectory_id, image_tag)
|
||||
|
||||
return {"image": image_tag, "task_name": task_name}
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
"""Not used — scoring happens in verify_and_score_trajectory."""
|
||||
return 0.0
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
final_response: str,
|
||||
*,
|
||||
trajectory_id: str,
|
||||
exec_tool,
|
||||
agent_result=None,
|
||||
workspace_meta=None,
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
"""
|
||||
Run test_final_state.py inside the container and return binary reward.
|
||||
"""
|
||||
task_id = item.get("task_id", "unknown")
|
||||
test_final = item.get("test_final", "")
|
||||
|
||||
if not test_final or not Path(test_final).exists():
|
||||
print(f"[EndlessTerminalsEnv] No test file for {task_id}", flush=True)
|
||||
return 0.0, {"error": "No test file"}
|
||||
|
||||
print(f"[EndlessTerminalsEnv] Scoring {task_id}...", flush=True)
|
||||
|
||||
try:
|
||||
# Read the test file and base64-encode it for safe transfer
|
||||
test_content = Path(test_final).read_text(encoding="utf-8")
|
||||
encoded = base64.b64encode(test_content.encode("utf-8")).decode("ascii")
|
||||
|
||||
# Write test file into the container and run pytest
|
||||
# We write to /tmp to avoid interfering with the agent's workspace
|
||||
# Use printf + heredoc to avoid quoting issues with single quotes in base64
|
||||
verify_cmd = (
|
||||
f"printf '%s' '{encoded}' | base64 -d > /tmp/_test_final_state.py && "
|
||||
f"cd /home/user && "
|
||||
f"python3 -m pytest /tmp/_test_final_state.py -v --tb=short 2>&1; "
|
||||
f"echo \"EXIT_CODE=$?\""
|
||||
)
|
||||
|
||||
result = await exec_tool(ToolCall(
|
||||
name="terminal",
|
||||
arguments={"command": verify_cmd},
|
||||
))
|
||||
|
||||
output = result.output if hasattr(result, "output") else str(result)
|
||||
|
||||
# Check if pytest passed
|
||||
# Look for EXIT_CODE=0 at the end (most reliable)
|
||||
success = "EXIT_CODE=0" in output
|
||||
|
||||
score = 1.0 if success else 0.0
|
||||
|
||||
metadata = {
|
||||
"task_id": task_id,
|
||||
"success": success,
|
||||
"test_output": output[-2000:] if len(output) > 2000 else output,
|
||||
"total_tool_calls": agent_result.total_tool_calls if agent_result else 0,
|
||||
}
|
||||
|
||||
self._train_scores_buffer.append(score)
|
||||
print(f"[EndlessTerminalsEnv] {task_id} → score={score}", flush=True)
|
||||
return score, metadata
|
||||
|
||||
except Exception as e:
|
||||
print(f"[EndlessTerminalsEnv] Error scoring {task_id}: {e}", flush=True)
|
||||
return 0.0, {"error": str(e)}
|
||||
|
||||
# ---- WandB logging ----
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log training metrics to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
# Training pass rate since last log
|
||||
if self._train_scores_buffer:
|
||||
wandb_metrics["train/percent_correct"] = (
|
||||
sum(self._train_scores_buffer) / len(self._train_scores_buffer)
|
||||
)
|
||||
wandb_metrics["train/num_trajectories"] = len(self._train_scores_buffer)
|
||||
self._train_scores_buffer = []
|
||||
|
||||
# Eval metrics (populated by evaluate())
|
||||
for key, value in self._eval_metrics:
|
||||
wandb_metrics[key] = value
|
||||
self._eval_metrics = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
# ---- Evaluation ----
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""
|
||||
Run the agent on held-out eval tasks and report pass rate.
|
||||
|
||||
Each eval task: build Docker container → run agent (temp=0) → pytest → score.
|
||||
This is expensive (full agent trajectories), so we only eval a subset.
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
if not self._eval_tasks:
|
||||
return {}
|
||||
|
||||
start_time = _time.time()
|
||||
eval_sample_size = min(len(self._eval_tasks), 20)
|
||||
eval_subset = random.sample(self._eval_tasks, eval_sample_size)
|
||||
|
||||
print(
|
||||
f"[EndlessTerminalsEnv] Running evaluation on {eval_sample_size} tasks...",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
scores = []
|
||||
samples = []
|
||||
|
||||
for task_info in eval_subset:
|
||||
task_name = task_info["task_name"]
|
||||
description = task_info["description"]
|
||||
|
||||
try:
|
||||
# Build Docker image
|
||||
image_tag = await self._ensure_image(task_info)
|
||||
|
||||
# Run agent with temp=0 for deterministic eval
|
||||
eval_tid = f"eval_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Register image with backend
|
||||
if isinstance(self._backend, DockerDirectBackend):
|
||||
self._backend.register_image(eval_tid, image_tag)
|
||||
|
||||
async def _exec(call, _tid=eval_tid):
|
||||
return await self._tool_executor.execute(_tid, call)
|
||||
|
||||
from ..agent import AtroposAgent as _AtroposAgent
|
||||
|
||||
agent = _AtroposAgent(
|
||||
server=self.server,
|
||||
tokenizer=self.tokenizer,
|
||||
tools=self.tools,
|
||||
config=AgentConfig(
|
||||
max_steps=self.config.agent_max_steps,
|
||||
temperature=0.0, # Deterministic for eval
|
||||
max_tokens=self.config.agent_max_tokens,
|
||||
),
|
||||
execute_tool=_exec,
|
||||
)
|
||||
|
||||
result = await agent.run(description)
|
||||
|
||||
# Score: run pytest in the container
|
||||
score = 0.0
|
||||
test_final = task_info.get("test_final", "")
|
||||
if result.success and test_final and Path(test_final).exists():
|
||||
test_content = Path(test_final).read_text(encoding="utf-8")
|
||||
encoded = base64.b64encode(test_content.encode("utf-8")).decode("ascii")
|
||||
verify_cmd = (
|
||||
f"printf '%s' '{encoded}' | base64 -d > /tmp/_test_final_state.py && "
|
||||
f"cd /home/user && "
|
||||
f"python3 -m pytest /tmp/_test_final_state.py -v --tb=short 2>&1; "
|
||||
f'echo "EXIT_CODE=$?"'
|
||||
)
|
||||
test_result = await _exec(ToolCall(
|
||||
name="terminal",
|
||||
arguments={"command": verify_cmd},
|
||||
))
|
||||
test_output = test_result.output if hasattr(test_result, "output") else ""
|
||||
if "EXIT_CODE=0" in test_output:
|
||||
score = 1.0
|
||||
|
||||
scores.append(score)
|
||||
samples.append({
|
||||
"task": task_name,
|
||||
"score": score,
|
||||
"tool_calls": result.total_tool_calls,
|
||||
"success": result.success,
|
||||
})
|
||||
|
||||
# Cleanup
|
||||
await self._tool_executor.release_trajectory(eval_tid, reset_workspace=True)
|
||||
|
||||
print(f" [eval] {task_name} → {score}", flush=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f" [eval] {task_name} → ERROR: {e}", flush=True)
|
||||
scores.append(0.0)
|
||||
samples.append({"task": task_name, "score": 0.0, "error": str(e)})
|
||||
|
||||
end_time = _time.time()
|
||||
|
||||
percent_correct = sum(scores) / len(scores) if scores else 0.0
|
||||
|
||||
print(
|
||||
f"[EndlessTerminalsEnv] Eval: {percent_correct:.1%} pass rate "
|
||||
f"({sum(scores):.0f}/{len(scores)}) in {end_time - start_time:.0f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Store for wandb_log to pick up
|
||||
self._eval_metrics.append(("eval/percent_correct", percent_correct))
|
||||
self._eval_metrics.append(("eval/num_tasks", len(scores)))
|
||||
self._eval_metrics.append(("eval/duration_s", end_time - start_time))
|
||||
|
||||
# Log via atroposlib
|
||||
eval_metrics = {
|
||||
"eval/percent_correct": percent_correct,
|
||||
"eval/num_tasks": len(scores),
|
||||
}
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
generation_parameters={
|
||||
"temperature": 0.0,
|
||||
"max_tokens": self.config.agent_max_tokens,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
EndlessTerminalsEnv.cli()
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,171 @@
|
||||
"""
|
||||
Hermes-Agent + Atropos (Nomad sandbox) compatibility smoke environment.
|
||||
|
||||
This environment is intended to validate, end-to-end:
|
||||
BaseEnv.process -> AgentEnv -> ToolExecutor (batched) -> Nomad SlotPool -> sandbox_server
|
||||
|
||||
It forces the model to use a sandbox tool by asking it to run a command that
|
||||
generates a high-entropy token inside the sandbox, then repeat it exactly.
|
||||
|
||||
Run (process mode):
|
||||
uv run python -m atropos.envs.hermes_compat_test_env process --env.use_wandb false --env.total_steps 2 --env.group_size 1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, Item
|
||||
|
||||
from ..agent import AgentConfig, AgentResult
|
||||
from ..tools import ToolCall
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def _forced_tool_item() -> Item:
|
||||
# Use double quotes in the shell command and show JSON escaping explicitly.
|
||||
# This avoids invalid JSON escapes like `\\'` (not valid JSON) that some models produce.
|
||||
cmd = 'python -c "import secrets; print(secrets.token_hex(16))"'
|
||||
return {
|
||||
"command": cmd,
|
||||
"prompt": (
|
||||
"You are acting as an agent inside a sandboxed environment.\n"
|
||||
"You MUST use the terminal tool to execute commands.\n"
|
||||
"Run this exact command:\n"
|
||||
f"{cmd}\n"
|
||||
"When you call the tool, use valid JSON inside <tool_call>. Example:\n"
|
||||
'<tool_call>{"name": "terminal", "arguments": {"command": '
|
||||
'"python -c \\\\"import secrets; print(secrets.token_hex(16))\\\\""}}'
|
||||
"</tool_call>\n"
|
||||
"Then respond with EXACTLY what it printed (the hex token) and nothing else.\n"
|
||||
"Do not guess. Do not explain."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class HermesCompatTestEnvConfig(AgentEnvConfig):
|
||||
server_base_url: str = Field(
|
||||
default="http://127.0.0.1:8080",
|
||||
description="Base URL for an OpenAI-compatible chat server (without /v1).",
|
||||
)
|
||||
server_model: str = Field(default="hermes-4-36b", description="Model name")
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||
|
||||
|
||||
class HermesCompatTestEnv(AgentEnv[HermesCompatTestEnvConfig]):
|
||||
name = "hermes_compat_test_env"
|
||||
env_config_cls = HermesCompatTestEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: HermesCompatTestEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._iter = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[HermesCompatTestEnvConfig, List[APIServerConfig]]:
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||
|
||||
env_config = HermesCompatTestEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
include_messages=True,
|
||||
ensure_scores_are_not_same=False,
|
||||
total_steps=2,
|
||||
batch_size=1,
|
||||
server_base_url=base_url,
|
||||
server_model=model,
|
||||
# Tooling: sandbox-only terminal.
|
||||
enabled_toolsets=["terminal"],
|
||||
disabled_toolsets=[],
|
||||
# Default to Nomad sandboxing; users can override via --env.* args.
|
||||
sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
|
||||
# In local dev it's common for a previous crash to leave the job in backoff.
|
||||
purge_job_on_start=True,
|
||||
purge_job_on_shutdown=True,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=f"{base_url.rstrip('/')}/v1",
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=1,
|
||||
num_requests_for_eval=1,
|
||||
timeout=120,
|
||||
)
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup_agent_env(self) -> None:
|
||||
return None
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
self._iter += 1
|
||||
return _forced_tool_item()
|
||||
|
||||
def build_task(self, item: Item) -> str:
|
||||
return str(item.get("prompt") or "")
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
# Avoid imposing max_tokens by default; tool-tag responses can be long for some models.
|
||||
return AgentConfig(
|
||||
max_steps=min(8, int(self.config.agent_max_steps)),
|
||||
temperature=0.2,
|
||||
max_tokens=None,
|
||||
)
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
# Scoring happens in verify_and_score_trajectory so we can inspect tool results.
|
||||
_ = (item, final_response)
|
||||
return 0.0
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
final_response: str,
|
||||
*,
|
||||
trajectory_id: str, # noqa: ARG002
|
||||
exec_tool, # noqa: ARG002
|
||||
agent_result: AgentResult | None = None,
|
||||
workspace_meta: Dict[str, Any] | None = None, # noqa: ARG002
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
if agent_result is None:
|
||||
return 0.0, {"error": "Missing agent_result"}
|
||||
|
||||
observed: str = ""
|
||||
tool_ok = False
|
||||
for step in agent_result.steps:
|
||||
for res in step.tool_results:
|
||||
if not res.success:
|
||||
return 0.0, {"error": res.error, "output": res.output}
|
||||
out = (res.output or "").strip()
|
||||
if out:
|
||||
observed = out.splitlines()[-1].strip()
|
||||
tool_ok = True
|
||||
|
||||
final = (final_response or "").strip()
|
||||
score = 1.0 if tool_ok and agent_result.total_tool_calls > 0 and observed and final == observed else 0.0
|
||||
return score, {"observed": observed, "tool_calls": agent_result.total_tool_calls, "command": item.get("command")}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
HermesCompatTestEnv.cli()
|
||||
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
Nomad sandbox terminal smoke environment (training-oriented).
|
||||
|
||||
Validates, end-to-end:
|
||||
BaseEnv.process -> AgentEnv -> ToolExecutor (batched) -> Nomad SlotPool -> sandbox_server
|
||||
|
||||
It forces the model to use a sandbox tool by asking it to run a command that
|
||||
generates a high-entropy token inside the sandbox, then repeat it exactly.
|
||||
|
||||
Run (process mode):
|
||||
uv run python -m atropos.envs.sandbox_terminal_smoke_env process --env.use_wandb false --env.total_steps 2 --env.group_size 1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, Item
|
||||
|
||||
from ..agent import AgentConfig, AgentResult
|
||||
from ..tools import ToolCall
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
STRICT_TOOLCALL_SYSTEM_PROMPT = None
|
||||
|
||||
|
||||
def _forced_tool_item() -> Item:
|
||||
# Use double quotes in the shell command and show JSON escaping explicitly.
|
||||
# This avoids invalid JSON escapes like `\\'` (not valid JSON) that some models produce.
|
||||
cmd = 'python -c "import secrets; print(secrets.token_hex(16))"'
|
||||
return {
|
||||
"command": cmd,
|
||||
"prompt": (
|
||||
"You MUST use the terminal tool.\n"
|
||||
"Run this exact command:\n"
|
||||
f"{cmd}\n"
|
||||
"When you call the tool, use valid JSON inside <tool_call>. Example:\n"
|
||||
'<tool_call>{"name": "terminal", "arguments": {"command": '
|
||||
'"python -c \\\\"import secrets; print(secrets.token_hex(16))\\\\""}}'
|
||||
"</tool_call>\n"
|
||||
"Then respond with EXACTLY what it printed (the hex token) and nothing else.\n"
|
||||
"Do not guess. Do not explain."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class SandboxTerminalSmokeEnvConfig(AgentEnvConfig):
|
||||
server_base_url: str = Field(
|
||||
default="http://127.0.0.1:8080",
|
||||
description="Base URL for an OpenAI-compatible chat server (without /v1).",
|
||||
)
|
||||
server_model: str = Field(default="hermes-4-36b", description="Model name")
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||
|
||||
|
||||
class SandboxTerminalSmokeEnv(AgentEnv[SandboxTerminalSmokeEnvConfig]):
|
||||
name = "sandbox_terminal_smoke_env"
|
||||
env_config_cls = SandboxTerminalSmokeEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SandboxTerminalSmokeEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._iter = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[SandboxTerminalSmokeEnvConfig, List[APIServerConfig]]:
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||
|
||||
env_config = SandboxTerminalSmokeEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
include_messages=True,
|
||||
ensure_scores_are_not_same=False,
|
||||
total_steps=2,
|
||||
batch_size=1,
|
||||
server_base_url=base_url,
|
||||
server_model=model,
|
||||
# Tooling: sandbox-only terminal.
|
||||
enabled_toolsets=["terminal"],
|
||||
disabled_toolsets=[],
|
||||
# Default to Nomad sandboxing; users can override via --env.* args.
|
||||
sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
|
||||
purge_job_on_start=True,
|
||||
purge_job_on_shutdown=True,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=f"{base_url.rstrip('/')}/v1",
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=1,
|
||||
num_requests_for_eval=1,
|
||||
timeout=120,
|
||||
)
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup_agent_env(self) -> None:
|
||||
return None
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
self._iter += 1
|
||||
return _forced_tool_item()
|
||||
|
||||
def build_task(self, item: Item) -> str:
|
||||
return str(item.get("prompt") or "")
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
# Avoid imposing max_tokens by default; tool-tag responses can be long for some models.
|
||||
return AgentConfig(
|
||||
max_steps=min(8, int(self.config.agent_max_steps)),
|
||||
temperature=0.2,
|
||||
max_tokens=None,
|
||||
system_prompt=STRICT_TOOLCALL_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
# Scoring happens in verify_and_score_trajectory so we can inspect tool results.
|
||||
_ = (item, final_response)
|
||||
return 0.0
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
final_response: str,
|
||||
*,
|
||||
trajectory_id: str, # noqa: ARG002
|
||||
exec_tool, # noqa: ARG002
|
||||
agent_result: AgentResult | None = None,
|
||||
workspace_meta: Dict[str, Any] | None = None, # noqa: ARG002
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
if agent_result is None:
|
||||
return 0.0, {"error": "Missing agent_result"}
|
||||
|
||||
observed: str = ""
|
||||
tool_ok = False
|
||||
for step in agent_result.steps:
|
||||
for res in step.tool_results:
|
||||
if not res.success:
|
||||
return 0.0, {"error": res.error, "output": res.output}
|
||||
out = (res.output or "").strip()
|
||||
if out:
|
||||
observed = out.splitlines()[-1].strip()
|
||||
tool_ok = True
|
||||
|
||||
final = (final_response or "").strip()
|
||||
score = 1.0 if tool_ok and agent_result.total_tool_calls > 0 and observed and final == observed else 0.0
|
||||
return score, {"observed": observed, "tool_calls": agent_result.total_tool_calls, "command": item.get("command")}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
SandboxTerminalSmokeEnv.cli()
|
||||
@@ -0,0 +1,418 @@
|
||||
"""
|
||||
SWE-smith-oracle environment.
|
||||
|
||||
This environment is intentionally minimal:
|
||||
- prepares a sandbox workspace by cloning a public GitHub repo at `base_commit`
|
||||
- runs an AtroposAgent tool loop to apply a fix
|
||||
- verifies by running pytest nodeids from the dataset (reward = pass/fail)
|
||||
- Python only (no multi-language support currently, need to properly bauild & add to dropbox)
|
||||
- TODO: Get the other nonpython sandboxes up and running, then add a config knob to switch between them per row
|
||||
- oh and add to dockerhub
|
||||
|
||||
Dataset: NousResearch/SWE-smith-oracle (train; does NOT use SWE-bench eval set).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, Item
|
||||
|
||||
from ..agent import AgentConfig
|
||||
from ..tools import ToolCall
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
|
||||
class SweSmithOracleEnvConfig(AgentEnvConfig):
|
||||
dataset_name: str = Field(default="NousResearch/SWE-smith-oracle")
|
||||
dataset_split: str = Field(default="train")
|
||||
max_items: int = Field(default=0, description="0 = no limit")
|
||||
shuffle: bool = Field(default=True)
|
||||
seed: int = Field(default=0)
|
||||
|
||||
python_only: bool = Field(default=True, description="Filter to Python-evaluable rows")
|
||||
score_include_fail_to_pass: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"If true (default), score tests on PASS_TO_PASS ∪ FAIL_TO_PASS. "
|
||||
"Disable to only run PASS_TO_PASS (faster but weaker signal)."
|
||||
),
|
||||
)
|
||||
|
||||
prompt_mode: str = Field(
|
||||
default="problem_statement",
|
||||
description="Task prompt content: 'problem_statement' (fast) or 'problem_statement+text' (slower, includes dataset 'text').",
|
||||
)
|
||||
|
||||
repo_base_url: str = Field(default="https://github.com", description="Base URL for repo cloning")
|
||||
install_timeout_s: float = Field(default=600.0)
|
||||
test_timeout_s: float = Field(default=600.0)
|
||||
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||
|
||||
|
||||
class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
|
||||
"""
|
||||
SWE-smith-oracle AgentEnv.
|
||||
|
||||
This is designed for benchmarking multiplexed slot execution vs naive container-per-trajectory.
|
||||
"""
|
||||
|
||||
name = "swe_smith_oracle_env"
|
||||
env_config_cls = SweSmithOracleEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SweSmithOracleEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._dataset = None
|
||||
self._indices: List[int] = []
|
||||
self._cursor = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[SweSmithOracleEnvConfig, List[APIServerConfig]]:
|
||||
# Defaults for running the env via CLI in offline `process` mode.
|
||||
# Override via env vars or `--env.*` flags as needed.
|
||||
base_url_raw = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
base_url = base_url_raw.rstrip("/")
|
||||
if not base_url.endswith("/v1"):
|
||||
base_url = f"{base_url}/v1"
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||
|
||||
env_config = SweSmithOracleEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1,
|
||||
batch_size=1,
|
||||
steps_per_eval=1,
|
||||
max_token_length=8192,
|
||||
inference_weight=1.0,
|
||||
wandb_name="swe_smith_oracle",
|
||||
enabled_toolsets=["terminal"],
|
||||
disabled_toolsets=[],
|
||||
sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
|
||||
purge_job_on_start=True,
|
||||
purge_job_on_shutdown=True,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=1,
|
||||
num_requests_for_eval=1,
|
||||
timeout=int(os.getenv("ATROPOS_SERVER_TIMEOUT_S") or "300"),
|
||||
),
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup_agent_env(self) -> None:
|
||||
from datasets import load_dataset
|
||||
|
||||
t0 = time.perf_counter()
|
||||
print(
|
||||
f"[SweSmithOracleEnv] loading dataset {self.config.dataset_name}:{self.config.dataset_split} "
|
||||
f"(python_only={self.config.python_only}, max_items={self.config.max_items or 'all'})",
|
||||
flush=True,
|
||||
)
|
||||
ds = load_dataset(self.config.dataset_name, split=self.config.dataset_split)
|
||||
self._dataset = ds
|
||||
|
||||
indices: List[int] = []
|
||||
for idx in range(len(ds)):
|
||||
row = ds[idx]
|
||||
if self.config.python_only and not self._is_python_row(row):
|
||||
continue
|
||||
indices.append(idx)
|
||||
|
||||
if self.config.shuffle:
|
||||
rnd = random.Random(self.config.seed)
|
||||
rnd.shuffle(indices)
|
||||
|
||||
if self.config.max_items and self.config.max_items > 0:
|
||||
indices = indices[: self.config.max_items]
|
||||
|
||||
self._indices = indices
|
||||
self._cursor = 0
|
||||
|
||||
print(
|
||||
f"[SweSmithOracleEnv] loaded {len(self._indices)} items from {self.config.dataset_name}:{self.config.dataset_split} "
|
||||
f"in {time.perf_counter() - t0:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
def _is_python_row(self, row: Dict[str, Any]) -> bool:
|
||||
nodeids = row.get("PASS_TO_PASS")
|
||||
if not isinstance(nodeids, list) or not nodeids:
|
||||
return False
|
||||
for nid in nodeids:
|
||||
if not isinstance(nid, str) or ".py::" not in nid:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
print(f"[SweSmithOracleEnv] get_next_item() cursor={self._cursor}/{len(self._indices)}", flush=True)
|
||||
if not self._dataset or not self._indices:
|
||||
raise RuntimeError("Dataset not initialized (did setup() run?)")
|
||||
if self._cursor >= len(self._indices):
|
||||
self._cursor = 0
|
||||
idx = self._indices[self._cursor]
|
||||
self._cursor += 1
|
||||
return dict(self._dataset[idx])
|
||||
|
||||
def _repo_name(self, item: Item) -> str:
|
||||
repo = item.get("repo") or ""
|
||||
if isinstance(repo, str) and "/" in repo:
|
||||
return repo.split("/")[-1]
|
||||
return "repo"
|
||||
|
||||
def build_task(self, item: Item) -> str:
|
||||
repo = item.get("repo") or ""
|
||||
base_commit = item.get("base_commit") or ""
|
||||
problem = str(item.get("problem_statement") or "")
|
||||
context = str(item.get("text") or "")
|
||||
|
||||
nodeids = self._tests_for_item(item)
|
||||
tests_list = "\n".join(f"- {t}" for t in nodeids)
|
||||
|
||||
repo_dir = self._repo_name(item)
|
||||
|
||||
tests_block = (
|
||||
"Run these tests to verify:\n"
|
||||
f"{tests_list}\n\n"
|
||||
"When done, briefly describe what you changed and confirm tests pass."
|
||||
)
|
||||
|
||||
prompt_mode = (self.config.prompt_mode or "problem_statement").strip().lower()
|
||||
if prompt_mode not in {"problem_statement", "problem_statement+text"}:
|
||||
raise ValueError(
|
||||
f"Invalid prompt_mode={self.config.prompt_mode!r}. "
|
||||
"Expected 'problem_statement' or 'problem_statement+text'."
|
||||
)
|
||||
|
||||
context_block = ""
|
||||
if prompt_mode == "problem_statement+text" and context:
|
||||
# Note: We intentionally do NOT truncate/cap here. This mode is for debugging / richer prompts and can be slow.
|
||||
context_block = f"\nAdditional context:\n{context}\n"
|
||||
|
||||
return (
|
||||
"You are a senior software engineer. Fix the repository so the specified tests pass.\n\n"
|
||||
f"Repository: {repo} (checked out at base_commit={base_commit})\n"
|
||||
f"Workspace path: ./{repo_dir}\n\n"
|
||||
"Constraints:\n"
|
||||
"- You MUST use the terminal tool to inspect, edit, and verify the repository. Do not respond with a patch file.\n"
|
||||
f"- Start by inspecting the repo (e.g. `ls`, `cd ./{repo_dir}`, `git status`).\n"
|
||||
"- Use a workspace-local virtualenv (e.g. inside the repo at ./.venv) to avoid cross-run contamination.\n"
|
||||
"- Use non-interactive commands only.\n\n"
|
||||
"- Terminal commands run under POSIX /bin/sh and each tool call runs in a fresh shell (no persisted env vars).\n"
|
||||
" Avoid bash-only `source`; prefer `. .venv/bin/activate` or `.venv/bin/python ...`.\n\n"
|
||||
"Problem statement:\n"
|
||||
f"{problem}\n\n"
|
||||
f"{context_block}\n"
|
||||
f"{tests_block}"
|
||||
)
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
# SWE tasks are longer than the simple test env.
|
||||
return AgentConfig(
|
||||
max_steps=self.config.agent_max_steps,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.agent_max_tokens,
|
||||
tool_delay_s=self.config.agent_tool_delay_s,
|
||||
)
|
||||
|
||||
async def setup_trajectory_workspace(self, item: Item, *, trajectory_id: str, exec_tool) -> Dict[str, Any]:
|
||||
t0 = time.perf_counter()
|
||||
repo = item.get("repo")
|
||||
base_commit = item.get("base_commit")
|
||||
instance_id = item.get("instance_id") or item.get("id") or item.get("problem_id")
|
||||
if not isinstance(repo, str) or not isinstance(base_commit, str):
|
||||
raise RuntimeError("Invalid dataset row: missing repo/base_commit")
|
||||
|
||||
repo_dir = self._repo_name(item)
|
||||
clone_url = f"{self.config.repo_base_url.rstrip('/')}/{repo}.git"
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): "
|
||||
f"repo={repo} base_commit={base_commit} instance_id={instance_id} dir=./{repo_dir}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Repo setup strategy:
|
||||
# - Maintain a shared, per-container bare repo cache under /data/repo_cache
|
||||
# - For each trajectory, create an isolated git worktree under the slot workspace
|
||||
# This avoids cloning/fetching full repos per trajectory and is crucial for multiplexing.
|
||||
|
||||
def _repo_cache_slug(repo_name: str) -> str:
|
||||
return repo_name.replace("/", "__")
|
||||
|
||||
repo_slug = _repo_cache_slug(repo)
|
||||
cache_root = "/data/repo_cache"
|
||||
bare_repo = f"{cache_root}/{repo_slug}.git"
|
||||
lock_file = f"{cache_root}/.locks/{repo_slug}.lock"
|
||||
|
||||
# Use flock to serialize operations that mutate the shared bare repo (fetch/worktree).
|
||||
# util-linux (flock) is included in the sandbox image.
|
||||
worktree_cmd = (
|
||||
"set -e; "
|
||||
f"rm -rf {repo_dir}; "
|
||||
f"mkdir -p {cache_root}/.locks; "
|
||||
f": > {lock_file}; "
|
||||
f"flock -x {lock_file} sh -lc '"
|
||||
f"set -e; "
|
||||
"export GIT_TERMINAL_PROMPT=0; "
|
||||
"export GIT_LFS_SKIP_SMUDGE=1; "
|
||||
f"if [ ! -d \"{bare_repo}\" ]; then "
|
||||
f" git init --bare \"{bare_repo}\"; "
|
||||
f" git -C \"{bare_repo}\" remote add origin \"{clone_url}\"; "
|
||||
"fi; "
|
||||
f"git -C \"{bare_repo}\" remote set-url origin \"{clone_url}\"; "
|
||||
f"git -C \"{bare_repo}\" worktree prune || true; "
|
||||
f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
|
||||
f" git -C \"{bare_repo}\" fetch --depth 1 origin \"{base_commit}\" || true; "
|
||||
"fi; "
|
||||
f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
|
||||
f" git -C \"{bare_repo}\" fetch --prune origin; "
|
||||
"fi; "
|
||||
f"git --git-dir=\"{bare_repo}\" worktree add --detach \"{repo_dir}\" \"{base_commit}\"; "
|
||||
"'"
|
||||
)
|
||||
|
||||
print(f"[SweSmithOracleEnv] tid={trajectory_id} preparing worktree from repo cache", flush=True)
|
||||
res = await exec_tool(
|
||||
ToolCall(
|
||||
name="terminal",
|
||||
arguments={"command": worktree_cmd, "timeout": self.config.install_timeout_s},
|
||||
)
|
||||
)
|
||||
if not res.success:
|
||||
raise RuntimeError(
|
||||
"git worktree setup failed "
|
||||
f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): {res.error}\n{res.output}"
|
||||
)
|
||||
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): worktree ready in {time.perf_counter() - t0:.2f}s",
|
||||
flush=True,
|
||||
)
|
||||
return {"repo_dir": repo_dir, "base_commit": base_commit}
|
||||
|
||||
def _tests_for_item(self, item: Item) -> List[str]:
|
||||
tests: List[str] = []
|
||||
if self.config.score_include_fail_to_pass:
|
||||
for key in ("PASS_TO_PASS", "FAIL_TO_PASS"):
|
||||
nodeids = item.get(key)
|
||||
if isinstance(nodeids, list):
|
||||
tests.extend([n for n in nodeids if isinstance(n, str)])
|
||||
else:
|
||||
nodeids = item.get("PASS_TO_PASS")
|
||||
if isinstance(nodeids, list):
|
||||
tests.extend([n for n in nodeids if isinstance(n, str)])
|
||||
# Stable order for reproducibility.
|
||||
return sorted(dict.fromkeys(tests))
|
||||
|
||||
def _chunk_nodeids(self, nodeids: List[str], max_per_chunk: int = 50) -> List[List[str]]:
|
||||
chunks: List[List[str]] = []
|
||||
for i in range(0, len(nodeids), max_per_chunk):
|
||||
chunks.append(nodeids[i : i + max_per_chunk])
|
||||
return chunks
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
final_response: str, # noqa: ARG002
|
||||
*,
|
||||
trajectory_id: str,
|
||||
exec_tool,
|
||||
agent_result=None,
|
||||
workspace_meta: Optional[Dict[str, Any]] = None,
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
_ = trajectory_id
|
||||
repo_dir = self._repo_name(item)
|
||||
|
||||
# Training correctness: do not reward trajectories that never actually used tools.
|
||||
if agent_result is not None and getattr(agent_result, "total_tool_calls", 0) <= 0:
|
||||
print(
|
||||
f"[SweSmithOracleEnv] tid={trajectory_id} verify (dataset_tests): no tool calls; score=0.0",
|
||||
flush=True,
|
||||
)
|
||||
return 0.0, {
|
||||
"verification_mode": "dataset_tests",
|
||||
"error": "No tool calls were made by the agent",
|
||||
}
|
||||
|
||||
nodeids = self._tests_for_item(item)
|
||||
if not nodeids:
|
||||
return 0.0, {"error": "No tests provided"}
|
||||
|
||||
print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (dataset_tests): ensuring venv + deps", flush=True)
|
||||
setup_cmd = (
|
||||
f"cd {repo_dir} && "
|
||||
"python -m venv .venv && "
|
||||
". .venv/bin/activate && "
|
||||
"python -m pip install -U pip setuptools wheel && "
|
||||
"python -m pip install -e . && "
|
||||
"python -m pip install pytest"
|
||||
)
|
||||
setup_res = await exec_tool(
|
||||
ToolCall(name="terminal", arguments={"command": setup_cmd, "timeout": self.config.install_timeout_s})
|
||||
)
|
||||
verification_messages = [{"role": "user", "content": setup_res.to_xml()}]
|
||||
if not setup_res.success:
|
||||
return 0.0, {
|
||||
"verification_mode": "dataset_tests",
|
||||
"phase": "install",
|
||||
"error": setup_res.error,
|
||||
"output": setup_res.output,
|
||||
"verification_messages": verification_messages,
|
||||
}
|
||||
|
||||
chunks = self._chunk_nodeids(nodeids, max_per_chunk=50)
|
||||
for chunk_idx, chunk in enumerate(chunks):
|
||||
joined = " ".join(chunk)
|
||||
cmd = f"cd {repo_dir} && . .venv/bin/activate && python -m pytest -q {joined}"
|
||||
res = await exec_tool(
|
||||
ToolCall(
|
||||
name="terminal",
|
||||
arguments={"command": cmd, "timeout": self.config.test_timeout_s},
|
||||
)
|
||||
)
|
||||
verification_messages.append({"role": "user", "content": res.to_xml()})
|
||||
if not res.success:
|
||||
return 0.0, {
|
||||
"verification_mode": "dataset_tests",
|
||||
"phase": "pytest",
|
||||
"failed_chunk": chunk_idx,
|
||||
"error": res.error,
|
||||
"output": res.output,
|
||||
"verification_messages": verification_messages,
|
||||
}
|
||||
|
||||
return 1.0, {"verification_mode": "dataset_tests", "passed": True, "verification_messages": verification_messages}
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
# Not used; scoring happens in verify_and_score_trajectory.
|
||||
_ = (item, final_response)
|
||||
return 0.0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
SweSmithOracleEnv.cli()
|
||||
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Simple test environment for validating the atropos-agent setup.
|
||||
|
||||
This environment uses a local OpenAI-compatible server for LLM testing to verify:
|
||||
- BaseEnv extension works correctly
|
||||
- API communication via OpenAI-compatible endpoint
|
||||
- Basic trajectory collection
|
||||
|
||||
This is a minimal environment for testing, not production use.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
APIServerConfig,
|
||||
Item,
|
||||
)
|
||||
|
||||
from ..agent import AgentConfig
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# Simple test prompts for validation
|
||||
TEST_PROMPTS = [
|
||||
{
|
||||
"prompt": "What is 2 + 2? Answer with just the number.",
|
||||
"expected": "4",
|
||||
},
|
||||
{
|
||||
"prompt": "What is the capital of France? Answer with just the city name.",
|
||||
"expected": "Paris",
|
||||
},
|
||||
{
|
||||
"prompt": "What color is the sky on a clear day? Answer with just the color.",
|
||||
"expected": "Blue",
|
||||
},
|
||||
{
|
||||
"prompt": "How many days are in a week? Answer with just the number.",
|
||||
"expected": "7",
|
||||
},
|
||||
{
|
||||
"prompt": "What is 10 * 5? Answer with just the number.",
|
||||
"expected": "50",
|
||||
},
|
||||
]
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
"You are a helpful assistant. Answer questions concisely and directly. "
|
||||
"When asked for a simple answer, provide just that answer without explanation."
|
||||
)
|
||||
|
||||
|
||||
class SimpleTestEnvConfig(AgentEnvConfig):
|
||||
"""Configuration for the simple test environment."""
|
||||
|
||||
server_base_url: str = Field(
|
||||
default="http://127.0.0.1:8080",
|
||||
description="Base URL for an OpenAI-compatible server (without /v1)",
|
||||
)
|
||||
server_model: str = Field(
|
||||
default="hermes-4-36b",
|
||||
description="Model name",
|
||||
)
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||
|
||||
|
||||
class SimpleTestEnv(AgentEnv[SimpleTestEnvConfig]):
|
||||
"""
|
||||
A simple test environment to validate the atropos-agent setup.
|
||||
|
||||
Uses a local OpenAI-compatible LLM endpoint with basic question-answering tasks.
|
||||
Scoring is based on whether the response contains the expected answer.
|
||||
"""
|
||||
|
||||
name = "simple_test_env"
|
||||
env_config_cls = SimpleTestEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SimpleTestEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self.iter = 0
|
||||
self.test_prompts = TEST_PROMPTS
|
||||
self.percent_correct_buffer: List[float] = []
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[SimpleTestEnvConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Initialize configuration with local server settings from environment variables.
|
||||
"""
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||
|
||||
env_config = SimpleTestEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=4,
|
||||
use_wandb=False, # Disable wandb for simple testing
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=10,
|
||||
batch_size=16,
|
||||
steps_per_eval=5,
|
||||
max_token_length=2048,
|
||||
inference_weight=1.0,
|
||||
wandb_name="simple_test",
|
||||
server_base_url=base_url,
|
||||
server_model=model,
|
||||
)
|
||||
|
||||
# OpenAI-compatible servers typically expose chat completions at /v1.
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=f"{base_url}/v1",
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=4,
|
||||
num_requests_for_eval=8,
|
||||
timeout=120, # Local models may be slower
|
||||
),
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup_agent_env(self):
|
||||
"""Setup the environment - load test data."""
|
||||
print(f"SimpleTestEnv setup complete. {len(self.test_prompts)} test prompts loaded.")
|
||||
print(f"Using server at: {self.config.server_base_url}")
|
||||
print(f"Model: {self.config.server_model}")
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
"""Get the next test prompt."""
|
||||
item = self.test_prompts[self.iter % len(self.test_prompts)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def build_task(self, item: Item) -> str:
|
||||
return item["prompt"]
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
return AgentConfig(
|
||||
max_steps=5,
|
||||
temperature=0.7,
|
||||
max_tokens=256,
|
||||
system_prompt=SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
expected = item["expected"].lower()
|
||||
response_lower = (final_response or "").lower()
|
||||
score = 1.0 if expected in response_lower else 0.0
|
||||
self.percent_correct_buffer.append(score)
|
||||
return score
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
"""
|
||||
Simple evaluation - run through all test prompts once.
|
||||
"""
|
||||
correct = 0
|
||||
total = len(self.test_prompts)
|
||||
|
||||
for item in self.test_prompts:
|
||||
messages = [
|
||||
{"role": "system", "content": SYSTEM_PROMPT},
|
||||
{"role": "user", "content": item["prompt"]},
|
||||
]
|
||||
|
||||
response = await self.server.chat_completion(
|
||||
messages=messages,
|
||||
n=1,
|
||||
max_tokens=256,
|
||||
temperature=0.0, # Greedy for eval
|
||||
split="eval",
|
||||
)
|
||||
|
||||
response_text = response.choices[0].message.content or ""
|
||||
expected = item["expected"].lower()
|
||||
|
||||
if expected in response_text.lower():
|
||||
correct += 1
|
||||
|
||||
accuracy = correct / total
|
||||
print(f"Evaluation: {correct}/{total} = {accuracy:.2%} accuracy")
|
||||
return {"eval_accuracy": accuracy}
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log metrics (simplified for testing)."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
if self.percent_correct_buffer:
|
||||
avg_correct = sum(self.percent_correct_buffer) / len(self.percent_correct_buffer)
|
||||
wandb_metrics["train/percent_correct"] = avg_correct
|
||||
print(f"Train accuracy: {avg_correct:.2%}")
|
||||
self.percent_correct_buffer = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Allow running as CLI
|
||||
SimpleTestEnv.cli()
|
||||
@@ -0,0 +1,165 @@
|
||||
"""
|
||||
ToolServer routing smoke environment.
|
||||
|
||||
Validates that:
|
||||
- sandbox tools run through Nomad SlotPool (terminal -> bash in sandbox)
|
||||
- external tools run through ToolServer (skills_list)
|
||||
|
||||
This env uses ToolServer in-process by default (`tool_server_url="inprocess"`),
|
||||
so it is self-contained for local testing.
|
||||
|
||||
Run:
|
||||
uv run python -m atropos.envs.toolserver_smoke_env process --env.use_wandb false --env.total_steps 1 --env.group_size 1
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import APIServerConfig, Item
|
||||
|
||||
from ..agent import AgentConfig, AgentResult
|
||||
from .agent_env import AgentEnv, AgentEnvConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class ToolServerSmokeEnvConfig(AgentEnvConfig):
|
||||
server_base_url: str = Field(
|
||||
default="http://127.0.0.1:8080",
|
||||
description="Base URL for an OpenAI-compatible chat server (without /v1).",
|
||||
)
|
||||
server_model: str = Field(default="hermes-4-36b", description="Model name")
|
||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||
|
||||
|
||||
class ToolServerSmokeEnv(AgentEnv[ToolServerSmokeEnvConfig]):
|
||||
name = "toolserver_smoke_env"
|
||||
env_config_cls = ToolServerSmokeEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ToolServerSmokeEnvConfig,
|
||||
server_configs: List[APIServerConfig],
|
||||
slurm: bool = False,
|
||||
testing: bool = False,
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
self._iter = 0
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[ToolServerSmokeEnvConfig, List[APIServerConfig]]:
|
||||
base_url = (
|
||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://127.0.0.1:8080"
|
||||
)
|
||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||
|
||||
env_config = ToolServerSmokeEnvConfig(
|
||||
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
|
||||
group_size=1,
|
||||
use_wandb=False,
|
||||
include_messages=True,
|
||||
ensure_scores_are_not_same=False,
|
||||
total_steps=1,
|
||||
batch_size=1,
|
||||
server_base_url=base_url,
|
||||
server_model=model,
|
||||
enabled_toolsets=["terminal", "skills"],
|
||||
disabled_toolsets=[],
|
||||
# Self-contained ToolServer for local smoke.
|
||||
tool_server_url="inprocess",
|
||||
sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
|
||||
purge_job_on_start=True,
|
||||
purge_job_on_shutdown=True,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=f"{base_url.rstrip('/')}/v1",
|
||||
api_key=api_key,
|
||||
num_max_requests_at_once=1,
|
||||
num_requests_for_eval=1,
|
||||
timeout=120,
|
||||
)
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup_agent_env(self) -> None:
|
||||
return None
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
self._iter += 1
|
||||
return {
|
||||
"prompt": (
|
||||
"You MUST call exactly one tool per assistant message.\n"
|
||||
"\n"
|
||||
"Step 1) Call the skills_list tool (no arguments), then stop.\n"
|
||||
"Step 2) After you receive the tool response, call the terminal tool to run:\n"
|
||||
"python -c \"print('ok')\"\n"
|
||||
"Step 3) After you receive the terminal tool response, answer with just: ok\n"
|
||||
"\n"
|
||||
"Tool call format requirements:\n"
|
||||
"- Every tool call MUST be a complete XML block with a closing tag.\n"
|
||||
"- Do NOT emit a second <tool_call> in the same assistant message.\n"
|
||||
"\n"
|
||||
"Example:\n"
|
||||
"<tool_call>{\"name\": \"skills_list\", \"arguments\": {}}</tool_call>\n"
|
||||
"Do not include anything else in your final answer."
|
||||
)
|
||||
}
|
||||
|
||||
def build_task(self, item: Item) -> str:
|
||||
return str(item.get("prompt") or "")
|
||||
|
||||
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
|
||||
return AgentConfig(
|
||||
max_steps=min(10, int(self.config.agent_max_steps)),
|
||||
temperature=0.2,
|
||||
max_tokens=None,
|
||||
)
|
||||
|
||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||
_ = (item, final_response)
|
||||
return 0.0
|
||||
|
||||
async def verify_and_score_trajectory(
|
||||
self,
|
||||
item: Item,
|
||||
final_response: str,
|
||||
*,
|
||||
trajectory_id: str, # noqa: ARG002
|
||||
exec_tool, # noqa: ARG002
|
||||
agent_result: AgentResult | None = None,
|
||||
workspace_meta: Dict[str, Any] | None = None, # noqa: ARG002
|
||||
) -> tuple[float, Dict[str, Any]]:
|
||||
if agent_result is None:
|
||||
return 0.0, {"error": "Missing agent_result"}
|
||||
|
||||
called = {c.name for s in agent_result.steps for c in s.tool_calls}
|
||||
need = {"skills_list", "terminal"}
|
||||
if not need.issubset(called):
|
||||
return 0.0, {"error": f"Missing tool calls: {sorted(need - called)}", "called": sorted(called)}
|
||||
|
||||
terminal_ok = False
|
||||
for step in agent_result.steps:
|
||||
for call, res in zip(step.tool_calls, step.tool_results):
|
||||
if call.name != "terminal":
|
||||
continue
|
||||
if res.success and (res.output or "").strip().splitlines()[-1].strip() == "ok":
|
||||
terminal_ok = True
|
||||
|
||||
score = 1.0 if terminal_ok and (final_response or "").strip() == "ok" else 0.0
|
||||
return score, {"called": sorted(called), "final": (final_response or "").strip()}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ToolServerSmokeEnv.cli()
|
||||
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Nomad integration for atropos-agent.
|
||||
|
||||
Provides:
|
||||
- NomadClient: Client for Nomad HTTP API
|
||||
- Job templates for sandbox containers
|
||||
"""
|
||||
|
||||
from .client import NomadClient
|
||||
|
||||
__all__ = ["NomadClient"]
|
||||
@@ -0,0 +1,500 @@
|
||||
"""
|
||||
Nomad API Client for atropos-agent.
|
||||
|
||||
Provides a simple async client for interacting with the Nomad HTTP API:
|
||||
- Submit/stop jobs
|
||||
- Query allocations
|
||||
- Get allocation addresses
|
||||
- Scale jobs up/down
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
|
||||
class AllocationStatus(Enum):
|
||||
"""Nomad allocation status."""
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETE = "complete"
|
||||
FAILED = "failed"
|
||||
LOST = "lost"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Allocation:
|
||||
"""Information about a Nomad allocation."""
|
||||
id: str
|
||||
job_id: str
|
||||
task_group: str
|
||||
node_id: str
|
||||
status: AllocationStatus
|
||||
# Network info for reaching the allocation
|
||||
address: Optional[str] = None
|
||||
port: Optional[int] = None
|
||||
|
||||
@property
|
||||
def http_address(self) -> Optional[str]:
|
||||
"""Get full HTTP address for the allocation."""
|
||||
if self.address and self.port:
|
||||
return f"http://{self.address}:{self.port}"
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class JobStatus:
|
||||
"""Status of a Nomad job."""
|
||||
id: str
|
||||
name: str
|
||||
status: str
|
||||
allocations: List[Allocation] = field(default_factory=list)
|
||||
count: int = 0 # Number of task groups
|
||||
|
||||
|
||||
class NomadClient:
|
||||
"""
|
||||
Async client for Nomad HTTP API.
|
||||
|
||||
Usage:
|
||||
client = NomadClient(address="http://localhost:4646")
|
||||
|
||||
# Submit a job
|
||||
await client.submit_job(job_spec)
|
||||
|
||||
# Get allocations
|
||||
allocs = await client.get_job_allocations("sandbox-python")
|
||||
|
||||
# Scale job
|
||||
await client.scale_job("sandbox-python", count=5)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
address: str = "http://localhost:4646",
|
||||
token: Optional[str] = None,
|
||||
timeout: float = 30.0,
|
||||
):
|
||||
self.address = address.rstrip("/")
|
||||
self.token = token or os.environ.get("NOMAD_TOKEN")
|
||||
self.timeout = aiohttp.ClientTimeout(total=timeout)
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create HTTP session."""
|
||||
if self._session is None or self._session.closed:
|
||||
headers = {}
|
||||
if self.token:
|
||||
headers["X-Nomad-Token"] = self.token
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=self.timeout,
|
||||
headers=headers,
|
||||
)
|
||||
return self._session
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP session."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
||||
async def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Make an HTTP request to Nomad API."""
|
||||
session = await self._get_session()
|
||||
url = f"{self.address}{path}"
|
||||
|
||||
try:
|
||||
async with session.request(method, url, json=data) as response:
|
||||
if response.status == 404:
|
||||
return {"error": "not_found", "status": 404}
|
||||
|
||||
text = await response.text()
|
||||
if not text:
|
||||
return {"status": response.status}
|
||||
|
||||
try:
|
||||
result = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return {"text": text, "status": response.status}
|
||||
|
||||
if response.status >= 400:
|
||||
return {"error": result, "status": response.status}
|
||||
|
||||
return result if isinstance(result, dict) else {"data": result, "status": response.status}
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
return {"error": str(e), "status": 0}
|
||||
|
||||
# Job Operations
|
||||
|
||||
async def submit_job(self, job_spec: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Submit a job to Nomad.
|
||||
|
||||
Args:
|
||||
job_spec: Job specification dict (HCL converted to JSON)
|
||||
|
||||
Returns:
|
||||
Response with EvalID if successful
|
||||
"""
|
||||
return await self._request("POST", "/v1/jobs", {"Job": job_spec})
|
||||
|
||||
async def stop_job(self, job_id: str, purge: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Stop (and optionally purge) a job.
|
||||
|
||||
Args:
|
||||
job_id: Job identifier
|
||||
purge: If True, completely remove the job
|
||||
"""
|
||||
path = f"/v1/job/{job_id}"
|
||||
if purge:
|
||||
path += "?purge=true"
|
||||
return await self._request("DELETE", path)
|
||||
|
||||
async def get_job(self, job_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get job details."""
|
||||
result = await self._request("GET", f"/v1/job/{job_id}")
|
||||
if "error" in result and result.get("status") == 404:
|
||||
return None
|
||||
return result
|
||||
|
||||
async def get_job_status(self, job_id: str) -> Optional[JobStatus]:
|
||||
"""Get job status with allocations."""
|
||||
job = await self.get_job(job_id)
|
||||
if not job:
|
||||
return None
|
||||
|
||||
allocs = await self.get_job_allocations(job_id)
|
||||
|
||||
# Get count from task groups
|
||||
count = 0
|
||||
task_groups = job.get("TaskGroups", [])
|
||||
for tg in task_groups:
|
||||
count += tg.get("Count", 1)
|
||||
|
||||
return JobStatus(
|
||||
id=job_id,
|
||||
name=job.get("Name", job_id),
|
||||
status=job.get("Status", "unknown"),
|
||||
allocations=allocs,
|
||||
count=count,
|
||||
)
|
||||
|
||||
# Allocation Operations
|
||||
|
||||
async def get_job_allocations(self, job_id: str) -> List[Allocation]:
|
||||
"""Get all allocations for a job."""
|
||||
result = await self._request("GET", f"/v1/job/{job_id}/allocations")
|
||||
|
||||
if "error" in result:
|
||||
return []
|
||||
|
||||
allocs_data = result.get("data", result) if isinstance(result, dict) else result
|
||||
if not isinstance(allocs_data, list):
|
||||
return []
|
||||
|
||||
allocations = []
|
||||
for alloc_data in allocs_data:
|
||||
# Parse allocation info
|
||||
alloc_id = alloc_data.get("ID", "")
|
||||
status_str = alloc_data.get("ClientStatus", "unknown")
|
||||
|
||||
try:
|
||||
status = AllocationStatus(status_str)
|
||||
except ValueError:
|
||||
status = AllocationStatus.PENDING
|
||||
|
||||
# Get network info - need to fetch detailed allocation for this
|
||||
address = None
|
||||
port = None
|
||||
|
||||
# First try the summary data
|
||||
resources = alloc_data.get("AllocatedResources") or {}
|
||||
shared = resources.get("Shared") or {}
|
||||
networks = shared.get("Networks") or []
|
||||
|
||||
# If no networks in summary, fetch detailed allocation
|
||||
if not networks and alloc_id:
|
||||
detailed = await self.get_allocation(alloc_id)
|
||||
if detailed:
|
||||
resources = detailed.get("AllocatedResources") or {}
|
||||
shared = resources.get("Shared") or {}
|
||||
networks = shared.get("Networks") or []
|
||||
|
||||
if networks:
|
||||
network = networks[0]
|
||||
address = network.get("IP")
|
||||
# Look for dynamic ports OR reserved ports (Singularity/raw_exec uses reserved)
|
||||
dyn_ports = network.get("DynamicPorts") or []
|
||||
reserved_ports = network.get("ReservedPorts") or []
|
||||
for dp in dyn_ports + reserved_ports:
|
||||
if dp.get("Label") == "http":
|
||||
port = dp.get("Value")
|
||||
break
|
||||
|
||||
allocations.append(Allocation(
|
||||
id=alloc_id,
|
||||
job_id=job_id,
|
||||
task_group=alloc_data.get("TaskGroup", ""),
|
||||
node_id=alloc_data.get("NodeID", ""),
|
||||
status=status,
|
||||
address=address,
|
||||
port=port,
|
||||
))
|
||||
|
||||
return allocations
|
||||
|
||||
async def get_allocation(self, alloc_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get detailed allocation info."""
|
||||
result = await self._request("GET", f"/v1/allocation/{alloc_id}")
|
||||
if "error" in result and result.get("status") == 404:
|
||||
return None
|
||||
return result
|
||||
|
||||
# Scaling Operations
|
||||
|
||||
async def scale_job(self, job_id: str, count: int, task_group: str = "sandbox") -> Dict[str, Any]:
|
||||
"""
|
||||
Scale a job's task group to specified count.
|
||||
|
||||
Args:
|
||||
job_id: Job identifier
|
||||
count: Desired number of allocations
|
||||
task_group: Name of task group to scale
|
||||
"""
|
||||
payload = {
|
||||
"Count": count,
|
||||
"Target": {
|
||||
"Group": task_group,
|
||||
},
|
||||
}
|
||||
return await self._request("POST", f"/v1/job/{job_id}/scale", payload)
|
||||
|
||||
async def get_job_scale_status(self, job_id: str) -> Dict[str, int]:
|
||||
"""
|
||||
Get current scale status for a job.
|
||||
|
||||
Returns:
|
||||
Dict mapping task group name to count
|
||||
"""
|
||||
result = await self._request("GET", f"/v1/job/{job_id}/scale")
|
||||
|
||||
if "error" in result:
|
||||
return {}
|
||||
|
||||
task_groups = result.get("TaskGroups", {})
|
||||
return {
|
||||
name: info.get("Running", 0)
|
||||
for name, info in task_groups.items()
|
||||
}
|
||||
|
||||
# Health Check
|
||||
|
||||
async def is_healthy(self) -> bool:
|
||||
"""Check if Nomad is reachable and healthy."""
|
||||
try:
|
||||
result = await self._request("GET", "/v1/status/leader")
|
||||
return "error" not in result
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_leader(self) -> Optional[str]:
|
||||
"""Get current Nomad leader address."""
|
||||
result = await self._request("GET", "/v1/status/leader")
|
||||
if isinstance(result, dict) and "data" in result:
|
||||
return result["data"]
|
||||
return None
|
||||
|
||||
|
||||
def load_job_template(
|
||||
template_name: str = "sandbox",
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Load and configure a job template.
|
||||
|
||||
Args:
|
||||
template_name: Name of template (e.g., "sandbox")
|
||||
**kwargs: Template variables to substitute
|
||||
|
||||
Returns:
|
||||
Job specification dict ready for Nomad API
|
||||
"""
|
||||
# Default job template for sandbox container
|
||||
if template_name == "sandbox":
|
||||
return create_sandbox_job(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown template: {template_name}")
|
||||
|
||||
|
||||
def create_sandbox_job(
|
||||
job_id: str = "atropos-sandbox",
|
||||
image: str = "atropos-sandbox:local", # Use :local tag to avoid registry pull
|
||||
count: int = 1,
|
||||
slots_per_container: int = 10,
|
||||
privileged: bool = False,
|
||||
cpu: int = 500,
|
||||
memory: int = 512,
|
||||
port: int = 8080,
|
||||
datacenter: str = "dc1",
|
||||
driver: str = "docker", # "docker" or "singularity"
|
||||
singularity_image: str = None, # Path to .sif file for singularity driver
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a sandbox job specification.
|
||||
|
||||
This job runs the sandbox_server.py inside a container,
|
||||
with the specified number of slots for agent workspaces.
|
||||
|
||||
Args:
|
||||
job_id: Unique job identifier
|
||||
image: Docker image to use (for docker driver)
|
||||
count: Number of container instances
|
||||
slots_per_container: Number of slots per container
|
||||
privileged: Run container in privileged mode (recommended for bubblewrap)
|
||||
cpu: CPU allocation in MHz
|
||||
memory: Memory allocation in MB
|
||||
port: HTTP port for sandbox server
|
||||
datacenter: Nomad datacenter
|
||||
driver: Container driver - "docker" or "singularity"
|
||||
singularity_image: Path to .sif file (required if driver="singularity")
|
||||
|
||||
Returns:
|
||||
Job specification dict
|
||||
"""
|
||||
# Build task config based on driver
|
||||
if driver == "singularity":
|
||||
if not singularity_image:
|
||||
raise ValueError("singularity_image path required when driver='singularity'")
|
||||
|
||||
# Use raw_exec driver to run apptainer via shell for variable expansion
|
||||
# The container binds the allocation directory for workspace persistence
|
||||
# For raw_exec, we use static port since Nomad's dynamic port mapping doesn't
|
||||
# work the same as Docker - the process runs directly on the host.
|
||||
shell_cmd = (
|
||||
f'apptainer run '
|
||||
f'--bind "$NOMAD_ALLOC_DIR/data:/data" '
|
||||
f'--pwd /app '
|
||||
f'--env PYTHONUNBUFFERED=1 '
|
||||
f'{singularity_image} '
|
||||
f'python sandbox_server.py '
|
||||
f'--port {port} '
|
||||
f'--slots {slots_per_container} '
|
||||
f'--data-dir /data'
|
||||
)
|
||||
task_config = {
|
||||
"command": "/bin/sh",
|
||||
"args": ["-c", shell_cmd],
|
||||
}
|
||||
task_driver = "raw_exec"
|
||||
else:
|
||||
# Docker driver (default)
|
||||
task_config = {
|
||||
"image": image,
|
||||
"force_pull": False, # Use local image, don't try to pull
|
||||
"ports": ["http"],
|
||||
"privileged": privileged,
|
||||
"command": "python",
|
||||
"args": [
|
||||
"sandbox_server.py",
|
||||
"--port", str(port),
|
||||
"--slots", str(slots_per_container),
|
||||
"--data-dir", "/data",
|
||||
],
|
||||
# Note: On Linux, you can mount persistent storage:
|
||||
# "volumes": ["${NOMAD_ALLOC_DIR}/data:/data"],
|
||||
# On macOS/Docker Desktop, skip volumes for PoC
|
||||
# (container /data is ephemeral but works for testing)
|
||||
}
|
||||
task_driver = "docker"
|
||||
|
||||
# For Singularity/raw_exec, use static ports since the process runs directly on host.
|
||||
# For Docker, use dynamic ports with port mapping.
|
||||
if driver == "singularity":
|
||||
network_config = {
|
||||
"Mode": "host",
|
||||
"ReservedPorts": [
|
||||
{
|
||||
"Label": "http",
|
||||
"Value": port,
|
||||
}
|
||||
],
|
||||
}
|
||||
else:
|
||||
network_config = {
|
||||
"Mode": "host",
|
||||
"DynamicPorts": [
|
||||
{
|
||||
"Label": "http",
|
||||
"To": port,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
return {
|
||||
"ID": job_id,
|
||||
"Name": job_id,
|
||||
"Type": "service",
|
||||
"Datacenters": [datacenter],
|
||||
"TaskGroups": [
|
||||
{
|
||||
"Name": "sandbox",
|
||||
"Count": count,
|
||||
# Speed up deployments and avoid Consul checks. Without this, Nomad may
|
||||
# keep an "active deployment" around for the default MinHealthyTime,
|
||||
# which blocks immediate scaling under load.
|
||||
"Update": {
|
||||
"HealthCheck": "task_states",
|
||||
"MinHealthyTime": 0,
|
||||
},
|
||||
"Networks": [network_config],
|
||||
"Tasks": [
|
||||
{
|
||||
"Name": "sandbox-server",
|
||||
"Driver": task_driver,
|
||||
"Config": task_config,
|
||||
"Env": {
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"NOMAD_ALLOC_DIR": "${NOMAD_ALLOC_DIR}",
|
||||
},
|
||||
"Resources": {
|
||||
"CPU": cpu,
|
||||
"MemoryMB": memory,
|
||||
},
|
||||
# Note: Services with Checks require Consul, which we skip for the PoC
|
||||
}
|
||||
],
|
||||
"RestartPolicy": {
|
||||
"Attempts": 3,
|
||||
"Interval": 300_000_000_000, # 5 minutes
|
||||
"Delay": 10_000_000_000, # 10 seconds
|
||||
"Mode": "delay",
|
||||
},
|
||||
"ReschedulePolicy": {
|
||||
"Attempts": 5,
|
||||
"Interval": 3600_000_000_000, # 1 hour
|
||||
"Delay": 30_000_000_000, # 30 seconds
|
||||
"DelayFunction": "exponential",
|
||||
"MaxDelay": 300_000_000_000, # 5 minutes
|
||||
"Unlimited": False,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,20 @@
|
||||
"""
|
||||
Slot-based multiplexing for atropos-agent.
|
||||
|
||||
Provides:
|
||||
- Slot: Isolated workspace for a single trajectory
|
||||
- SlotPool: Manages slots across Nomad allocations
|
||||
- SandboxExecutor: Executes tools in sandbox containers
|
||||
"""
|
||||
|
||||
from .executor import SandboxExecutor
|
||||
from .pool import SlotPool, SlotPoolConfig
|
||||
from .slot import Slot, SlotState
|
||||
|
||||
__all__ = [
|
||||
"Slot",
|
||||
"SlotState",
|
||||
"SlotPool",
|
||||
"SlotPoolConfig",
|
||||
"SandboxExecutor",
|
||||
]
|
||||
@@ -0,0 +1,457 @@
|
||||
"""
|
||||
SandboxExecutor - HTTP client for sandbox container communication.
|
||||
|
||||
Sends tool execution requests to sandbox_server.py running inside Nomad containers.
|
||||
Supports single and batch execution for efficiency.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .slot import Slot, SlotState
|
||||
from ..tools.base import ToolCall, ToolResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionRequest:
|
||||
"""Request to execute a tool in a slot."""
|
||||
slot: Slot
|
||||
tool_name: str
|
||||
args: Dict[str, Any]
|
||||
execution_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
timeout: float = 30.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionResult:
|
||||
"""Result from sandbox execution."""
|
||||
success: bool
|
||||
output: str = ""
|
||||
error: str = ""
|
||||
execution_id: str = ""
|
||||
slot_id: str = ""
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_tool_result(self) -> ToolResult:
|
||||
"""Convert to ToolResult for agent consumption."""
|
||||
return ToolResult(
|
||||
success=self.success,
|
||||
output=self.output,
|
||||
error=self.error,
|
||||
metadata=self.metadata,
|
||||
uniq_id=self.execution_id,
|
||||
)
|
||||
|
||||
|
||||
class SandboxExecutor:
|
||||
"""
|
||||
HTTP client for executing tools in sandbox containers.
|
||||
|
||||
Communicates with sandbox_server.py running inside Nomad allocations.
|
||||
Supports both single execution and batched parallel execution.
|
||||
|
||||
Usage:
|
||||
executor = SandboxExecutor()
|
||||
|
||||
# Single execution
|
||||
result = await executor.execute(slot, "bash", {"command": "ls"})
|
||||
|
||||
# Batch execution
|
||||
results = await executor.execute_batch([
|
||||
(slot1, "bash", {"command": "ls"}),
|
||||
(slot2, "write_file", {"path": "test.txt", "content": "hello"}),
|
||||
])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timeout: float = 30.0,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
):
|
||||
self.timeout = aiohttp.ClientTimeout(total=timeout)
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
async def _get_session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create HTTP session."""
|
||||
if self._session is None or self._session.closed:
|
||||
self._session = aiohttp.ClientSession(timeout=self.timeout)
|
||||
return self._session
|
||||
|
||||
async def close(self):
|
||||
"""Close HTTP session."""
|
||||
if self._session and not self._session.closed:
|
||||
await self._session.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.close()
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
slot: Slot,
|
||||
tool_name: str,
|
||||
args: Dict[str, Any],
|
||||
timeout: Optional[float] = None,
|
||||
) -> ExecutionResult:
|
||||
"""
|
||||
Execute a tool in a slot's workspace.
|
||||
|
||||
Args:
|
||||
slot: Slot to execute in
|
||||
tool_name: Name of tool (bash, read_file, write_file)
|
||||
args: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
ExecutionResult with output or error
|
||||
"""
|
||||
execution_id = str(uuid.uuid4())
|
||||
exec_timeout = timeout or self.timeout.total or 30.0
|
||||
|
||||
# Mark slot as executing
|
||||
original_state = slot.state
|
||||
try:
|
||||
if slot.state == SlotState.ACQUIRED:
|
||||
slot.start_execution(execution_id)
|
||||
|
||||
result = await self._send_execute_request(
|
||||
container_addr=slot.container_addr,
|
||||
slot_id=slot.slot_id,
|
||||
tool_name=tool_name,
|
||||
args=args,
|
||||
execution_id=execution_id,
|
||||
timeout=exec_timeout,
|
||||
)
|
||||
result.slot_id = slot.slot_id
|
||||
return result
|
||||
|
||||
finally:
|
||||
# Restore slot state
|
||||
if slot.state == SlotState.EXECUTING:
|
||||
slot.end_execution()
|
||||
|
||||
async def _send_execute_request(
|
||||
self,
|
||||
container_addr: str,
|
||||
slot_id: str,
|
||||
tool_name: str,
|
||||
args: Dict[str, Any],
|
||||
execution_id: str,
|
||||
timeout: float,
|
||||
) -> ExecutionResult:
|
||||
"""Send execution request to sandbox server with retry logic."""
|
||||
session = await self._get_session()
|
||||
url = f"{container_addr}/execute"
|
||||
|
||||
payload = {
|
||||
"slot_id": slot_id,
|
||||
"tool": tool_name,
|
||||
"args": args,
|
||||
"execution_id": execution_id,
|
||||
"timeout": timeout,
|
||||
}
|
||||
|
||||
last_error = None
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
async with session.post(url, json=payload) as response:
|
||||
data = await response.json()
|
||||
|
||||
return ExecutionResult(
|
||||
success=data.get("success", False),
|
||||
output=data.get("output", ""),
|
||||
error=data.get("error", ""),
|
||||
execution_id=data.get("execution_id", execution_id),
|
||||
metadata=data.get("metadata", {}),
|
||||
)
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
last_error = str(e)
|
||||
if attempt < self.max_retries - 1:
|
||||
await asyncio.sleep(self.retry_delay * (attempt + 1))
|
||||
continue
|
||||
except asyncio.TimeoutError:
|
||||
last_error = f"Request timed out after {timeout}s"
|
||||
break
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
break
|
||||
|
||||
return ExecutionResult(
|
||||
success=False,
|
||||
error=f"Failed after {self.max_retries} attempts: {last_error}",
|
||||
execution_id=execution_id,
|
||||
)
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
requests: List[Tuple[Slot, str, Dict[str, Any]]],
|
||||
timeout: Optional[float] = None,
|
||||
) -> List[ExecutionResult]:
|
||||
"""
|
||||
Execute multiple tools in parallel across slots.
|
||||
|
||||
This is the key optimization - we batch tool calls to maximize
|
||||
container utilization while agents are waiting for LLM responses.
|
||||
|
||||
Args:
|
||||
requests: List of (slot, tool_name, args) tuples
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
List of ExecutionResults in same order as requests
|
||||
"""
|
||||
if not requests:
|
||||
return []
|
||||
|
||||
# Group requests by container address for batch API
|
||||
by_container: Dict[str, List[Tuple[int, Slot, str, Dict[str, Any], str]]] = {}
|
||||
|
||||
for idx, (slot, tool_name, args) in enumerate(requests):
|
||||
execution_id = str(uuid.uuid4())
|
||||
container = slot.container_addr
|
||||
|
||||
if container not in by_container:
|
||||
by_container[container] = []
|
||||
by_container[container].append((idx, slot, tool_name, args, execution_id))
|
||||
|
||||
# Mark slots as executing
|
||||
if slot.state == SlotState.ACQUIRED:
|
||||
slot.start_execution(execution_id)
|
||||
|
||||
# Execute batches in parallel
|
||||
exec_timeout = timeout or self.timeout.total or 30.0
|
||||
batch_tasks = []
|
||||
|
||||
for container_addr, batch_requests in by_container.items():
|
||||
task = self._send_batch_request(
|
||||
container_addr=container_addr,
|
||||
batch_requests=batch_requests,
|
||||
timeout=exec_timeout,
|
||||
)
|
||||
batch_tasks.append(task)
|
||||
|
||||
# Gather all batch results
|
||||
batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
|
||||
# Collect results in original order
|
||||
results: List[Optional[ExecutionResult]] = [None] * len(requests)
|
||||
|
||||
for batch_result in batch_results:
|
||||
if isinstance(batch_result, Exception):
|
||||
# Mark all in this batch as failed
|
||||
continue
|
||||
|
||||
for idx, result in batch_result:
|
||||
results[idx] = result
|
||||
|
||||
# Fill in any missing results
|
||||
for idx, result in enumerate(results):
|
||||
if result is None:
|
||||
slot, tool_name, args = requests[idx]
|
||||
results[idx] = ExecutionResult(
|
||||
success=False,
|
||||
error="Batch execution failed",
|
||||
slot_id=slot.slot_id,
|
||||
)
|
||||
|
||||
# End execution on all slots
|
||||
for slot, _, _ in requests:
|
||||
if slot.state == SlotState.EXECUTING:
|
||||
slot.end_execution()
|
||||
|
||||
return results # type: ignore
|
||||
|
||||
async def _send_batch_request(
|
||||
self,
|
||||
container_addr: str,
|
||||
batch_requests: List[Tuple[int, Slot, str, Dict[str, Any], str]],
|
||||
timeout: float,
|
||||
) -> List[Tuple[int, ExecutionResult]]:
|
||||
"""Send batch execution request to a single container."""
|
||||
session = await self._get_session()
|
||||
url = f"{container_addr}/batch"
|
||||
|
||||
# Build batch payload
|
||||
payload = [
|
||||
{
|
||||
"slot_id": slot.slot_id,
|
||||
"tool": tool_name,
|
||||
"args": args,
|
||||
"execution_id": execution_id,
|
||||
"timeout": timeout,
|
||||
}
|
||||
for _, slot, tool_name, args, execution_id in batch_requests
|
||||
]
|
||||
|
||||
try:
|
||||
async with session.post(url, json=payload) as response:
|
||||
data = await response.json()
|
||||
|
||||
if not isinstance(data, list):
|
||||
raise ValueError(f"Expected list response, got {type(data)}")
|
||||
|
||||
results = []
|
||||
for i, (idx, slot, _, _, execution_id) in enumerate(batch_requests):
|
||||
if i < len(data):
|
||||
item = data[i]
|
||||
result = ExecutionResult(
|
||||
success=item.get("success", False),
|
||||
output=item.get("output", ""),
|
||||
error=item.get("error", ""),
|
||||
execution_id=item.get("execution_id", execution_id),
|
||||
slot_id=slot.slot_id,
|
||||
metadata=item.get("metadata", {}),
|
||||
)
|
||||
else:
|
||||
result = ExecutionResult(
|
||||
success=False,
|
||||
error="Missing result in batch response",
|
||||
execution_id=execution_id,
|
||||
slot_id=slot.slot_id,
|
||||
)
|
||||
results.append((idx, result))
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
# Return error for all requests in batch
|
||||
return [
|
||||
(idx, ExecutionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
execution_id=execution_id,
|
||||
slot_id=slot.slot_id,
|
||||
))
|
||||
for idx, slot, _, _, execution_id in batch_requests
|
||||
]
|
||||
|
||||
async def reset_slot(self, slot: Slot) -> ExecutionResult:
|
||||
"""
|
||||
Reset a slot's workspace (delete all files).
|
||||
|
||||
Useful when reusing a slot for a new trajectory.
|
||||
"""
|
||||
session = await self._get_session()
|
||||
url = f"{slot.container_addr}/reset"
|
||||
|
||||
try:
|
||||
async with session.post(url, json={"slot_id": slot.slot_id}) as response:
|
||||
data = await response.json()
|
||||
return ExecutionResult(
|
||||
success=data.get("success", False),
|
||||
output=data.get("output", ""),
|
||||
error=data.get("error", ""),
|
||||
slot_id=slot.slot_id,
|
||||
)
|
||||
except Exception as e:
|
||||
return ExecutionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
slot_id=slot.slot_id,
|
||||
)
|
||||
|
||||
async def health_check(self, container_addr: str) -> bool:
|
||||
"""Check if a sandbox container is healthy."""
|
||||
session = await self._get_session()
|
||||
url = f"{container_addr}/health"
|
||||
|
||||
try:
|
||||
async with session.get(url) as response:
|
||||
data = await response.json()
|
||||
return data.get("status") == "ok"
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_container_status(
|
||||
self,
|
||||
container_addr: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Get status info from a sandbox container."""
|
||||
session = await self._get_session()
|
||||
url = f"{container_addr}/health"
|
||||
|
||||
try:
|
||||
async with session.get(url) as response:
|
||||
return await response.json()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Artifact helpers (optional)
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def _post_json(
|
||||
self,
|
||||
url: str,
|
||||
payload: Dict[str, Any],
|
||||
timeout: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
session = await self._get_session()
|
||||
try:
|
||||
async with session.post(url, json=payload, timeout=timeout) as response:
|
||||
data = await response.json()
|
||||
if isinstance(data, dict):
|
||||
data.setdefault("http_status", response.status)
|
||||
return data
|
||||
return {"success": False, "error": f"Unexpected response type: {type(data)}", "http_status": response.status}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def read_artifact(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str,
|
||||
*,
|
||||
encoding: str = "text",
|
||||
max_bytes: Optional[int] = None,
|
||||
include_sha256: bool = False,
|
||||
timeout: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{slot.container_addr}/artifacts/read"
|
||||
payload: Dict[str, Any] = {"slot_id": slot.slot_id, "path": path, "encoding": encoding, "include_sha256": include_sha256}
|
||||
if max_bytes is not None:
|
||||
payload["max_bytes"] = max_bytes
|
||||
return await self._post_json(url, payload, timeout=timeout)
|
||||
|
||||
async def list_artifacts(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str = ".",
|
||||
*,
|
||||
recursive: bool = False,
|
||||
max_entries: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{slot.container_addr}/artifacts/list"
|
||||
payload: Dict[str, Any] = {"slot_id": slot.slot_id, "path": path, "recursive": recursive}
|
||||
if max_entries is not None:
|
||||
payload["max_entries"] = max_entries
|
||||
return await self._post_json(url, payload, timeout=timeout)
|
||||
|
||||
async def archive_artifacts(
|
||||
self,
|
||||
slot: Slot,
|
||||
path: str = ".",
|
||||
*,
|
||||
archive_format: str = "tar.gz",
|
||||
max_bytes: Optional[int] = None,
|
||||
max_entries: Optional[int] = None,
|
||||
timeout: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{slot.container_addr}/artifacts/archive"
|
||||
payload: Dict[str, Any] = {"slot_id": slot.slot_id, "path": path, "format": archive_format}
|
||||
if max_bytes is not None:
|
||||
payload["max_bytes"] = max_bytes
|
||||
if max_entries is not None:
|
||||
payload["max_entries"] = max_entries
|
||||
return await self._post_json(url, payload, timeout=timeout)
|
||||
@@ -0,0 +1,659 @@
|
||||
"""
|
||||
SlotPool - Manages slots across Nomad allocations.
|
||||
|
||||
The SlotPool is the core abstraction for slot-based multiplexing:
|
||||
- Tracks available/acquired slots across containers
|
||||
- Handles slot acquisition and release
|
||||
- Auto-scales Nomad job count based on demand
|
||||
- Provides batched tool execution
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from ..nomad.client import (
|
||||
Allocation,
|
||||
AllocationStatus,
|
||||
NomadClient,
|
||||
create_sandbox_job,
|
||||
)
|
||||
from .executor import ExecutionResult, SandboxExecutor
|
||||
from .slot import Slot, SlotState, create_slots_for_allocation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlotPoolConfig:
|
||||
"""Configuration for SlotPool."""
|
||||
|
||||
# Nomad settings
|
||||
nomad_address: str = "http://localhost:4646"
|
||||
job_id: str = "atropos-sandbox"
|
||||
datacenter: str = "dc1"
|
||||
|
||||
# Container settings
|
||||
image: str = "atropos-sandbox:local" # Use :local tag to avoid registry pull
|
||||
slots_per_container: int = 10
|
||||
privileged: bool = False
|
||||
cpu: int = 500 # MHz
|
||||
memory: int = 512 # MB
|
||||
|
||||
# Driver selection: "docker" or "singularity"
|
||||
driver: str = "docker"
|
||||
# Path to .sif file for singularity driver (required if driver="singularity")
|
||||
singularity_image: Optional[str] = None
|
||||
|
||||
# Scaling settings
|
||||
min_containers: int = 1
|
||||
max_containers: int = 10
|
||||
|
||||
# Timeouts
|
||||
acquire_timeout: float = 30.0 # Seconds between acquire polls (also triggers scale-up attempts)
|
||||
health_check_interval: float = 30.0 # Seconds between health checks
|
||||
scale_cooldown: float = 60.0 # Seconds between scale operations
|
||||
|
||||
# Job lifecycle
|
||||
purge_job_on_start: bool = False # Purge any pre-existing job before starting (local dev/training friendly)
|
||||
|
||||
# Local Docker image convenience (macOS/Nomad dev mode)
|
||||
auto_build_local_image: bool = True # If image endswith :local and is missing, build it from the bundled Dockerfile.
|
||||
dockerfile_path: Optional[str] = None # Override Dockerfile path (default: Hermes-Agent/atropos/Dockerfile).
|
||||
docker_build_context: Optional[str] = None # Override build context (default: Hermes-Agent/atropos).
|
||||
|
||||
|
||||
class SlotPool:
|
||||
"""
|
||||
Manages a pool of slots across Nomad allocations.
|
||||
|
||||
The SlotPool:
|
||||
- Deploys sandbox containers to Nomad
|
||||
- Tracks slots across all running containers
|
||||
- Handles slot acquisition/release
|
||||
- Auto-scales based on demand
|
||||
- Provides batched execution via SandboxExecutor
|
||||
|
||||
Usage:
|
||||
config = SlotPoolConfig(
|
||||
nomad_address="http://localhost:4646",
|
||||
job_id="my-sandbox",
|
||||
slots_per_container=10,
|
||||
)
|
||||
|
||||
pool = SlotPool(config)
|
||||
await pool.start()
|
||||
|
||||
# Acquire a slot
|
||||
slot = await pool.acquire()
|
||||
|
||||
# Execute tool
|
||||
result = await pool.execute(slot, "bash", {"command": "ls"})
|
||||
|
||||
# Release slot
|
||||
await pool.release(slot)
|
||||
|
||||
# Shutdown
|
||||
await pool.stop()
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[SlotPoolConfig] = None):
|
||||
self.config = config or SlotPoolConfig()
|
||||
|
||||
# Nomad client
|
||||
self.nomad = NomadClient(address=self.config.nomad_address)
|
||||
|
||||
# Sandbox executor for tool execution
|
||||
self.executor = SandboxExecutor()
|
||||
|
||||
# Slot tracking
|
||||
self._slots: Dict[str, Slot] = {} # slot_key -> Slot
|
||||
self._available_queue: asyncio.Queue[str] = asyncio.Queue()
|
||||
self._lock = asyncio.Lock()
|
||||
self._scale_lock = asyncio.Lock()
|
||||
|
||||
# State
|
||||
self._started = False
|
||||
self._health_task: Optional[asyncio.Task] = None
|
||||
self._scale_task: Optional[asyncio.Task] = None
|
||||
self._last_scale_time = 0.0
|
||||
|
||||
def _default_dockerfile_path(self) -> Path:
|
||||
# Hermes-Agent/atropos/Dockerfile lives next to this module in source checkouts.
|
||||
return Path(__file__).resolve().parents[1] / "Dockerfile"
|
||||
|
||||
def _default_build_context(self) -> Path:
|
||||
return Path(__file__).resolve().parents[1]
|
||||
|
||||
def _docker_image_exists(self, image: str) -> bool:
|
||||
try:
|
||||
proc = subprocess.run(
|
||||
["docker", "image", "inspect", image],
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
check=False,
|
||||
env={**os.environ, "DOCKER_CLI_HINTS": "false"},
|
||||
)
|
||||
return proc.returncode == 0
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
|
||||
def _try_build_local_image(self, image: str) -> None:
|
||||
dockerfile = Path(self.config.dockerfile_path) if self.config.dockerfile_path else self._default_dockerfile_path()
|
||||
context = Path(self.config.docker_build_context) if self.config.docker_build_context else self._default_build_context()
|
||||
|
||||
if not dockerfile.exists():
|
||||
raise RuntimeError(
|
||||
f"Sandbox Dockerfile not found at {dockerfile}. "
|
||||
"Build the sandbox image manually or set --env.purge_job_on_start false and provide a non-local image."
|
||||
)
|
||||
if not context.exists():
|
||||
raise RuntimeError(f"Docker build context not found at {context}")
|
||||
|
||||
# Prefer buildx+--load to ensure the image ends up in the local daemon (required by Nomad's docker driver).
|
||||
buildx_cmd = [
|
||||
"docker",
|
||||
"buildx",
|
||||
"build",
|
||||
"--load",
|
||||
"-t",
|
||||
image,
|
||||
"-f",
|
||||
str(dockerfile),
|
||||
str(context),
|
||||
]
|
||||
proc = subprocess.run(buildx_cmd, check=False, env={**os.environ, "DOCKER_CLI_HINTS": "false"})
|
||||
if proc.returncode == 0:
|
||||
return
|
||||
|
||||
# Fallback to classic docker build if buildx isn't available.
|
||||
build_cmd = ["docker", "build", "-t", image, "-f", str(dockerfile), str(context)]
|
||||
proc2 = subprocess.run(build_cmd, check=False, env={**os.environ, "DOCKER_CLI_HINTS": "false"})
|
||||
if proc2.returncode != 0:
|
||||
raise RuntimeError(
|
||||
f"Failed to build local sandbox image {image}. "
|
||||
f"Tried: {' '.join(buildx_cmd)} and {' '.join(build_cmd)}"
|
||||
)
|
||||
|
||||
def _ensure_local_image(self) -> None:
|
||||
image = (self.config.image or "").strip()
|
||||
if not image.endswith(":local"):
|
||||
return
|
||||
if not self.config.auto_build_local_image:
|
||||
return
|
||||
|
||||
if self._docker_image_exists(image):
|
||||
return
|
||||
|
||||
logger.info(f"Local sandbox image {image} not found; building it now...")
|
||||
self._try_build_local_image(image)
|
||||
|
||||
def _slot_key(self, alloc_id: str, slot_id: str) -> str:
|
||||
"""Generate unique key for a slot."""
|
||||
return f"{alloc_id}:{slot_id}"
|
||||
|
||||
@property
|
||||
def total_slots(self) -> int:
|
||||
"""Total number of slots in pool."""
|
||||
return len(self._slots)
|
||||
|
||||
@property
|
||||
def available_slots(self) -> int:
|
||||
"""Number of available slots."""
|
||||
return sum(1 for s in self._slots.values() if s.is_available)
|
||||
|
||||
@property
|
||||
def acquired_slots(self) -> int:
|
||||
"""Number of acquired slots."""
|
||||
return sum(1 for s in self._slots.values() if s.is_acquired)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Start the slot pool.
|
||||
|
||||
- Checks if Nomad is healthy
|
||||
- Deploys sandbox job if not running
|
||||
- Discovers existing allocations
|
||||
- Starts health check background task
|
||||
"""
|
||||
if self._started:
|
||||
return
|
||||
|
||||
logger.info(f"Starting SlotPool (job_id={self.config.job_id})")
|
||||
|
||||
try:
|
||||
# Make sure local sandbox images exist before Nomad tries to pull them.
|
||||
# This is a common footgun in macOS dev mode with :local tags.
|
||||
self._ensure_local_image()
|
||||
|
||||
# Check Nomad health
|
||||
if not await self.nomad.is_healthy():
|
||||
raise RuntimeError(f"Nomad is not reachable at {self.config.nomad_address}")
|
||||
|
||||
if self.config.purge_job_on_start:
|
||||
logger.info(f"Purging any existing Nomad job: {self.config.job_id}")
|
||||
await self.nomad.stop_job(self.config.job_id, purge=True)
|
||||
|
||||
# Check if job exists (after optional purge)
|
||||
job = await self.nomad.get_job(self.config.job_id)
|
||||
|
||||
if job is None:
|
||||
# Deploy new job
|
||||
logger.info(f"Deploying sandbox job: {self.config.job_id} (driver={self.config.driver})")
|
||||
job_spec = create_sandbox_job(
|
||||
job_id=self.config.job_id,
|
||||
image=self.config.image,
|
||||
count=self.config.min_containers,
|
||||
slots_per_container=self.config.slots_per_container,
|
||||
privileged=self.config.privileged,
|
||||
cpu=self.config.cpu,
|
||||
memory=self.config.memory,
|
||||
datacenter=self.config.datacenter,
|
||||
driver=self.config.driver,
|
||||
singularity_image=self.config.singularity_image,
|
||||
)
|
||||
result = await self.nomad.submit_job(job_spec)
|
||||
if "error" in result:
|
||||
raise RuntimeError(f"Failed to submit job: {result}")
|
||||
|
||||
# Wait for allocations to be running (even if the job already existed).
|
||||
await self._wait_for_healthy_allocations(self.config.min_containers)
|
||||
|
||||
# Discover existing allocations and slots
|
||||
await self._refresh_slots()
|
||||
|
||||
# Start health check task
|
||||
self._health_task = asyncio.create_task(self._health_check_loop())
|
||||
|
||||
self._started = True
|
||||
logger.info(f"SlotPool started: {self.total_slots} slots available")
|
||||
except Exception:
|
||||
# Ensure aiohttp sessions are not leaked if we fail to start.
|
||||
await self.stop(purge_job=False)
|
||||
raise
|
||||
|
||||
async def stop(self, purge_job: bool = False) -> None:
|
||||
"""
|
||||
Stop the slot pool.
|
||||
|
||||
Args:
|
||||
purge_job: If True, also stop the Nomad job
|
||||
"""
|
||||
logger.info("Stopping SlotPool")
|
||||
|
||||
# Cancel health check task
|
||||
if self._health_task:
|
||||
self._health_task.cancel()
|
||||
try:
|
||||
await self._health_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
self._health_task = None
|
||||
|
||||
if self._scale_task:
|
||||
self._scale_task.cancel()
|
||||
try:
|
||||
await self._scale_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
self._scale_task = None
|
||||
|
||||
# Optionally stop the job (do this even if start() never completed).
|
||||
if purge_job:
|
||||
logger.info(f"Stopping Nomad job: {self.config.job_id}")
|
||||
await self.nomad.stop_job(self.config.job_id, purge=True)
|
||||
|
||||
# Close connections
|
||||
await self.executor.close()
|
||||
await self.nomad.close()
|
||||
|
||||
self._started = False
|
||||
self._slots.clear()
|
||||
|
||||
# Clear the queue
|
||||
while not self._available_queue.empty():
|
||||
try:
|
||||
self._available_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
async def acquire(self, trajectory_id: Optional[str] = None) -> Slot:
|
||||
"""
|
||||
Acquire an available slot.
|
||||
|
||||
If no slots are available, waits up to acquire_timeout seconds.
|
||||
If still no slots, attempts to scale up.
|
||||
|
||||
Args:
|
||||
trajectory_id: Optional ID of trajectory acquiring the slot
|
||||
|
||||
Returns:
|
||||
Acquired Slot
|
||||
|
||||
Raises:
|
||||
asyncio.TimeoutError: If no slot becomes available
|
||||
"""
|
||||
if not self._started:
|
||||
raise RuntimeError("SlotPool not started")
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Try to get an available slot
|
||||
slot_key = await asyncio.wait_for(
|
||||
self._available_queue.get(),
|
||||
timeout=self.config.acquire_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# Try to scale up, but keep waiting even if scaling isn't possible.
|
||||
# In practice, slots may become available shortly (e.g. contention),
|
||||
# and scaling may be temporarily blocked by Nomad deployments.
|
||||
await self._try_scale_up()
|
||||
continue
|
||||
|
||||
slot = self._slots.get(slot_key)
|
||||
if slot is None:
|
||||
# Slot was removed; discard stale queue entry and retry.
|
||||
continue
|
||||
|
||||
try:
|
||||
slot.acquire(trajectory_id)
|
||||
except RuntimeError:
|
||||
# Slot isn't actually available (e.g. duplicate queue entry); retry.
|
||||
continue
|
||||
|
||||
logger.debug(f"Acquired slot {slot.slot_id} (alloc={slot.alloc_id[:8]})")
|
||||
return slot
|
||||
|
||||
async def release(self, slot: Slot, reset_workspace: bool = False) -> None:
|
||||
"""
|
||||
Release a slot back to the pool.
|
||||
|
||||
Args:
|
||||
slot: Slot to release
|
||||
reset_workspace: If True, clear the workspace files
|
||||
"""
|
||||
slot_key = self._slot_key(slot.alloc_id, slot.slot_id)
|
||||
|
||||
if slot_key not in self._slots:
|
||||
logger.warning(f"Releasing unknown slot: {slot_key}")
|
||||
return
|
||||
|
||||
# Optionally reset workspace
|
||||
if reset_workspace:
|
||||
await self.executor.reset_slot(slot)
|
||||
|
||||
slot.release()
|
||||
await self._available_queue.put(slot_key)
|
||||
|
||||
logger.debug(f"Released slot {slot.slot_id}")
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
slot: Slot,
|
||||
tool_name: str,
|
||||
args: Dict[str, Any],
|
||||
timeout: Optional[float] = None,
|
||||
) -> ExecutionResult:
|
||||
"""
|
||||
Execute a tool in a slot's workspace.
|
||||
|
||||
Args:
|
||||
slot: Slot to execute in
|
||||
tool_name: Name of tool (bash, read_file, write_file)
|
||||
args: Tool arguments
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
ExecutionResult
|
||||
"""
|
||||
return await self.executor.execute(slot, tool_name, args, timeout)
|
||||
|
||||
async def execute_batch(
|
||||
self,
|
||||
requests: List[Tuple[Slot, str, Dict[str, Any]]],
|
||||
timeout: Optional[float] = None,
|
||||
) -> List[ExecutionResult]:
|
||||
"""
|
||||
Execute multiple tools in parallel.
|
||||
|
||||
This is the key optimization - batch execution across multiple slots
|
||||
maximizes container utilization.
|
||||
|
||||
Args:
|
||||
requests: List of (slot, tool_name, args) tuples
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
List of ExecutionResults in same order
|
||||
"""
|
||||
return await self.executor.execute_batch(requests, timeout)
|
||||
|
||||
async def _refresh_slots(self) -> None:
|
||||
"""Refresh slot inventory from Nomad allocations."""
|
||||
async with self._lock:
|
||||
allocs = await self.nomad.get_job_allocations(self.config.job_id)
|
||||
|
||||
# Track which slots we've seen
|
||||
seen_keys = set()
|
||||
|
||||
for alloc in allocs:
|
||||
if alloc.status != AllocationStatus.RUNNING:
|
||||
continue
|
||||
|
||||
if not alloc.http_address:
|
||||
continue
|
||||
|
||||
# Check container health
|
||||
healthy = await self.executor.health_check(alloc.http_address)
|
||||
if not healthy:
|
||||
continue
|
||||
|
||||
# Create slots for this allocation
|
||||
for i in range(self.config.slots_per_container):
|
||||
slot_id = f"slot_{i}"
|
||||
slot_key = self._slot_key(alloc.id, slot_id)
|
||||
seen_keys.add(slot_key)
|
||||
|
||||
if slot_key not in self._slots:
|
||||
# New slot
|
||||
slot = Slot(
|
||||
slot_id=slot_id,
|
||||
alloc_id=alloc.id,
|
||||
container_addr=alloc.http_address,
|
||||
)
|
||||
self._slots[slot_key] = slot
|
||||
await self._available_queue.put(slot_key)
|
||||
logger.debug(f"Added slot: {slot_key}")
|
||||
|
||||
# Remove slots from dead allocations
|
||||
for slot_key in list(self._slots.keys()):
|
||||
if slot_key not in seen_keys:
|
||||
slot = self._slots.pop(slot_key)
|
||||
logger.debug(f"Removed slot: {slot_key}")
|
||||
|
||||
async def _wait_for_healthy_allocations(
|
||||
self,
|
||||
min_count: int,
|
||||
timeout: float = 120.0
|
||||
) -> None:
|
||||
"""Wait for allocations to become healthy."""
|
||||
import time
|
||||
start = time.time()
|
||||
|
||||
def _summarize_alloc_detail(detail: Dict[str, Any]) -> str:
|
||||
task_states = detail.get("TaskStates") or {}
|
||||
parts: List[str] = []
|
||||
if isinstance(task_states, dict):
|
||||
for task_name, st in task_states.items():
|
||||
events = (st or {}).get("Events") or []
|
||||
if isinstance(events, list) and events:
|
||||
# Include a few recent events; the latest can be a generic restart message
|
||||
# while the true root cause is slightly earlier (e.g. image pull failure).
|
||||
recent = events[-3:]
|
||||
msgs: List[str] = []
|
||||
for ev in recent:
|
||||
desc = ev.get("DisplayMessage") or ev.get("Message") or ev.get("Type") or ""
|
||||
if desc:
|
||||
msgs.append(desc)
|
||||
if msgs:
|
||||
parts.append(f"{task_name}: " + " | ".join(msgs))
|
||||
return "; ".join(parts)
|
||||
|
||||
def _alloc_events_lower(detail: Dict[str, Any]) -> str:
|
||||
task_states = detail.get("TaskStates") or {}
|
||||
texts: List[str] = []
|
||||
if isinstance(task_states, dict):
|
||||
for _task_name, st in task_states.items():
|
||||
events = (st or {}).get("Events") or []
|
||||
if isinstance(events, list):
|
||||
for ev in events[-10:]:
|
||||
desc = ev.get("DisplayMessage") or ev.get("Message") or ev.get("Type") or ""
|
||||
if desc:
|
||||
texts.append(desc)
|
||||
return " ".join(texts).lower()
|
||||
|
||||
while time.time() - start < timeout:
|
||||
allocs = await self.nomad.get_job_allocations(self.config.job_id)
|
||||
|
||||
healthy_count = 0
|
||||
for alloc in allocs:
|
||||
if alloc.status == AllocationStatus.RUNNING and alloc.http_address:
|
||||
if await self.executor.health_check(alloc.http_address):
|
||||
healthy_count += 1
|
||||
|
||||
# Fast-fail on obvious driver/image errors to avoid waiting out the full timeout.
|
||||
if alloc.id:
|
||||
detail = await self.nomad.get_allocation(alloc.id)
|
||||
if isinstance(detail, dict):
|
||||
summary = _summarize_alloc_detail(detail)
|
||||
lowered = _alloc_events_lower(detail) or summary.lower()
|
||||
if "failed to pull" in lowered or "pull access denied" in lowered:
|
||||
raise RuntimeError(
|
||||
"Nomad allocation failed to start due to a Docker image pull error. "
|
||||
f"Allocation {alloc.id[:8]}: {summary}\n"
|
||||
"If you're using a local image tag (e.g. `atropos-sandbox:local`) on macOS, "
|
||||
"make sure the image is loaded into Docker, e.g.:\n"
|
||||
" docker buildx build --load -t atropos-sandbox:local -f Hermes-Agent/atropos/Dockerfile Hermes-Agent/atropos"
|
||||
)
|
||||
if "exceeded allowed attempts" in lowered:
|
||||
raise RuntimeError(
|
||||
"Nomad allocation is crash-looping and has entered restart backoff. "
|
||||
f"Allocation {alloc.id[:8]}: {summary}\n"
|
||||
"Inspect logs with:\n"
|
||||
f" nomad alloc logs -stderr -task sandbox-server {alloc.id}\n"
|
||||
"Common causes include: missing local Docker image tag, container entrypoint error, "
|
||||
"or sandbox-server startup failure."
|
||||
)
|
||||
|
||||
if healthy_count >= min_count:
|
||||
return
|
||||
|
||||
await asyncio.sleep(2.0)
|
||||
|
||||
# Timed out: include allocation status detail to help debugging.
|
||||
allocs = await self.nomad.get_job_allocations(self.config.job_id)
|
||||
alloc_lines: List[str] = []
|
||||
for alloc in allocs[:10]:
|
||||
addr = alloc.http_address or "-"
|
||||
line = f"{alloc.id[:8]} status={alloc.status.value} http={addr}"
|
||||
detail = await self.nomad.get_allocation(alloc.id)
|
||||
if isinstance(detail, dict):
|
||||
summary = _summarize_alloc_detail(detail)
|
||||
if summary:
|
||||
line += f" detail={summary}"
|
||||
alloc_lines.append(line)
|
||||
|
||||
hint = (
|
||||
"Timed out waiting for healthy sandbox allocations.\n"
|
||||
f"Job: {self.config.job_id}, desired_healthy: {min_count}\n"
|
||||
"Allocations:\n - " + "\n - ".join(alloc_lines)
|
||||
)
|
||||
raise RuntimeError(hint)
|
||||
|
||||
async def _try_scale_up(self) -> bool:
|
||||
"""Attempt to scale up the job."""
|
||||
import time
|
||||
|
||||
async with self._scale_lock:
|
||||
# Check cooldown
|
||||
if time.time() - self._last_scale_time < self.config.scale_cooldown:
|
||||
return False
|
||||
|
||||
# Check max containers
|
||||
status = await self.nomad.get_job_status(self.config.job_id)
|
||||
if status is None:
|
||||
return False
|
||||
|
||||
current_count = status.count
|
||||
if current_count >= self.config.max_containers:
|
||||
logger.warning(f"Cannot scale up: already at max ({self.config.max_containers})")
|
||||
return False
|
||||
|
||||
# Scale up
|
||||
new_count = min(current_count + 1, self.config.max_containers)
|
||||
logger.info(f"Scaling up from {current_count} to {new_count} containers")
|
||||
|
||||
scale_resp = await self.nomad.scale_job(
|
||||
self.config.job_id,
|
||||
count=new_count,
|
||||
task_group="sandbox",
|
||||
)
|
||||
|
||||
# Nomad may return non-JSON errors (e.g. plain text) with a status field.
|
||||
if isinstance(scale_resp, dict) and scale_resp.get("status", 200) >= 400:
|
||||
logger.warning(f"Scale request rejected: {scale_resp}")
|
||||
self._last_scale_time = time.time()
|
||||
return False
|
||||
|
||||
self._last_scale_time = time.time()
|
||||
|
||||
# Wait for new allocation in the background so contended acquires can still
|
||||
# make progress (e.g. by grabbing slots released by other trajectories).
|
||||
if self._scale_task is None or self._scale_task.done():
|
||||
self._scale_task = asyncio.create_task(self._wait_for_scale(new_count))
|
||||
|
||||
return True
|
||||
|
||||
async def _wait_for_scale(self, desired_count: int) -> None:
|
||||
try:
|
||||
await self._wait_for_healthy_allocations(desired_count, timeout=60.0)
|
||||
await self._refresh_slots()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to scale up: {e}")
|
||||
|
||||
async def _health_check_loop(self) -> None:
|
||||
"""Background task to monitor container health."""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.config.health_check_interval)
|
||||
await self._refresh_slots()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Health check error: {e}")
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get pool statistics."""
|
||||
slots_by_state = {}
|
||||
for slot in self._slots.values():
|
||||
state = slot.state.value
|
||||
slots_by_state[state] = slots_by_state.get(state, 0) + 1
|
||||
|
||||
container_count = len({s.alloc_id for s in self._slots.values()}) if self._slots else 0
|
||||
|
||||
return {
|
||||
"total_slots": self.total_slots,
|
||||
"available_slots": self.available_slots,
|
||||
"acquired_slots": self.acquired_slots,
|
||||
"containers": container_count,
|
||||
"slots_by_state": slots_by_state,
|
||||
"started": self._started,
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
Slot abstraction for atropos-agent.
|
||||
|
||||
A Slot represents an isolated workspace for a single agent trajectory.
|
||||
Slots are hosted on Nomad allocations and provide workspace isolation
|
||||
via filesystem directories.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
import uuid
|
||||
|
||||
|
||||
class SlotState(Enum):
|
||||
"""State of a slot in the pool."""
|
||||
AVAILABLE = "available" # Ready to be acquired
|
||||
ACQUIRED = "acquired" # Assigned to a trajectory
|
||||
EXECUTING = "executing" # Currently executing a tool
|
||||
RELEASING = "releasing" # Being released back to pool
|
||||
ERROR = "error" # In error state
|
||||
|
||||
|
||||
@dataclass
|
||||
class Slot:
|
||||
"""
|
||||
An isolated workspace for a single agent trajectory.
|
||||
|
||||
Slots are the unit of scheduling - each trajectory runs in its own slot,
|
||||
with an isolated workspace directory. Multiple slots share a container.
|
||||
|
||||
Attributes:
|
||||
slot_id: Unique identifier for this slot (e.g., "slot_0")
|
||||
alloc_id: Nomad allocation ID hosting this slot
|
||||
container_addr: HTTP address of the sandbox server (e.g., "http://10.0.0.1:8080")
|
||||
workspace_dir: Path to workspace in container (e.g., "/data/slot_0")
|
||||
state: Current state of the slot
|
||||
trajectory_id: ID of trajectory currently using this slot (if acquired)
|
||||
metadata: Additional metadata
|
||||
"""
|
||||
slot_id: str
|
||||
alloc_id: str
|
||||
container_addr: str
|
||||
workspace_dir: str = ""
|
||||
state: SlotState = SlotState.AVAILABLE
|
||||
trajectory_id: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Set default workspace_dir if not provided."""
|
||||
if not self.workspace_dir:
|
||||
self.workspace_dir = f"/data/{self.slot_id}"
|
||||
|
||||
@property
|
||||
def is_available(self) -> bool:
|
||||
"""Check if slot is available for acquisition."""
|
||||
return self.state == SlotState.AVAILABLE
|
||||
|
||||
@property
|
||||
def is_acquired(self) -> bool:
|
||||
"""Check if slot is currently acquired."""
|
||||
return self.state in (SlotState.ACQUIRED, SlotState.EXECUTING)
|
||||
|
||||
def acquire(self, trajectory_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Mark slot as acquired by a trajectory.
|
||||
|
||||
Args:
|
||||
trajectory_id: Optional ID of acquiring trajectory
|
||||
"""
|
||||
if not self.is_available:
|
||||
raise RuntimeError(f"Cannot acquire slot {self.slot_id}: state is {self.state}")
|
||||
|
||||
self.state = SlotState.ACQUIRED
|
||||
self.trajectory_id = trajectory_id or str(uuid.uuid4())
|
||||
|
||||
def start_execution(self, execution_id: Optional[str] = None) -> None:
|
||||
"""Mark slot as executing."""
|
||||
if self.state != SlotState.ACQUIRED:
|
||||
raise RuntimeError(f"Cannot start execution on slot {self.slot_id}: state is {self.state}")
|
||||
|
||||
self.state = SlotState.EXECUTING
|
||||
if execution_id:
|
||||
self.metadata["current_execution_id"] = execution_id
|
||||
|
||||
def end_execution(self) -> None:
|
||||
"""Mark execution as complete, return to acquired state."""
|
||||
if self.state != SlotState.EXECUTING:
|
||||
raise RuntimeError(f"Cannot end execution on slot {self.slot_id}: state is {self.state}")
|
||||
|
||||
self.state = SlotState.ACQUIRED
|
||||
self.metadata.pop("current_execution_id", None)
|
||||
|
||||
def release(self) -> None:
|
||||
"""Release slot back to available state."""
|
||||
self.state = SlotState.AVAILABLE
|
||||
self.trajectory_id = None
|
||||
self.metadata.pop("current_execution_id", None)
|
||||
|
||||
def mark_error(self, error: str) -> None:
|
||||
"""Mark slot as in error state."""
|
||||
self.state = SlotState.ERROR
|
||||
self.metadata["error"] = error
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for serialization."""
|
||||
return {
|
||||
"slot_id": self.slot_id,
|
||||
"alloc_id": self.alloc_id,
|
||||
"container_addr": self.container_addr,
|
||||
"workspace_dir": self.workspace_dir,
|
||||
"state": self.state.value,
|
||||
"trajectory_id": self.trajectory_id,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "Slot":
|
||||
"""Create from dictionary."""
|
||||
return cls(
|
||||
slot_id=data["slot_id"],
|
||||
alloc_id=data["alloc_id"],
|
||||
container_addr=data["container_addr"],
|
||||
workspace_dir=data.get("workspace_dir", ""),
|
||||
state=SlotState(data.get("state", "available")),
|
||||
trajectory_id=data.get("trajectory_id"),
|
||||
metadata=data.get("metadata", {}),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Slot({self.slot_id}, state={self.state.value}, alloc={self.alloc_id[:8]}...)"
|
||||
|
||||
|
||||
def create_slots_for_allocation(
|
||||
alloc_id: str,
|
||||
container_addr: str,
|
||||
num_slots: int = 10,
|
||||
) -> list["Slot"]:
|
||||
"""
|
||||
Create slots for a Nomad allocation.
|
||||
|
||||
Args:
|
||||
alloc_id: Nomad allocation ID
|
||||
container_addr: HTTP address of sandbox server
|
||||
num_slots: Number of slots to create
|
||||
|
||||
Returns:
|
||||
List of Slot objects
|
||||
"""
|
||||
slots = []
|
||||
for i in range(num_slots):
|
||||
slot_id = f"slot_{i}"
|
||||
slots.append(Slot(
|
||||
slot_id=slot_id,
|
||||
alloc_id=alloc_id,
|
||||
container_addr=container_addr,
|
||||
workspace_dir=f"/data/{slot_id}",
|
||||
))
|
||||
return slots
|
||||
@@ -0,0 +1,2 @@
|
||||
"""Terminal helpers for stateful sandbox interactions."""
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import pyte
|
||||
|
||||
|
||||
class AsciinemaStreamDecoder:
|
||||
def __init__(self, *, default_width: int = 80, default_height: int = 24) -> None:
|
||||
self._default_width = max(1, int(default_width))
|
||||
self._default_height = max(1, int(default_height))
|
||||
self._buffer = ""
|
||||
self._has_header = False
|
||||
self.width = self._default_width
|
||||
self.height = self._default_height
|
||||
self._screen = pyte.Screen(self.width, self.height)
|
||||
self._stream = pyte.Stream(self._screen)
|
||||
|
||||
def reset(self) -> None:
|
||||
self._buffer = ""
|
||||
self._has_header = False
|
||||
self.width = self._default_width
|
||||
self.height = self._default_height
|
||||
self._screen = pyte.Screen(self.width, self.height)
|
||||
self._stream = pyte.Stream(self._screen)
|
||||
|
||||
def feed(self, chunk: str | bytes) -> None:
|
||||
if not chunk:
|
||||
return
|
||||
if isinstance(chunk, bytes):
|
||||
chunk = chunk.decode("utf-8", errors="replace")
|
||||
self._buffer += chunk
|
||||
while True:
|
||||
line, sep, rest = self._buffer.partition("\n")
|
||||
if not sep:
|
||||
break
|
||||
self._buffer = rest
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
parsed = self._parse_json_line(line)
|
||||
if parsed is None:
|
||||
continue
|
||||
if not self._has_header:
|
||||
if isinstance(parsed, dict):
|
||||
self._init_from_header(parsed)
|
||||
continue
|
||||
if isinstance(parsed, list):
|
||||
self._has_header = True
|
||||
self._apply_event(parsed)
|
||||
continue
|
||||
continue
|
||||
if isinstance(parsed, list):
|
||||
self._apply_event(parsed)
|
||||
|
||||
def render(self) -> str:
|
||||
return "\n".join(self._screen.display)
|
||||
|
||||
def _parse_json_line(self, line: str) -> Any | None:
|
||||
try:
|
||||
return json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
|
||||
def _init_from_header(self, header: dict[str, Any]) -> None:
|
||||
width = _coerce_int(
|
||||
header.get("width") or header.get("columns") or header.get("cols"),
|
||||
self._default_width,
|
||||
)
|
||||
height = _coerce_int(
|
||||
header.get("height") or header.get("rows") or header.get("lines"),
|
||||
self._default_height,
|
||||
)
|
||||
self.width = max(1, width)
|
||||
self.height = max(1, height)
|
||||
self._screen = pyte.Screen(self.width, self.height)
|
||||
self._stream = pyte.Stream(self._screen)
|
||||
self._has_header = True
|
||||
|
||||
def _apply_event(self, event: list[Any]) -> None:
|
||||
if len(event) < 2:
|
||||
return
|
||||
event_type = event[1]
|
||||
payload = event[2] if len(event) > 2 else ""
|
||||
if event_type == "o":
|
||||
if isinstance(payload, str):
|
||||
self._stream.feed(payload)
|
||||
elif event_type == "r":
|
||||
width, height = _parse_resize(payload)
|
||||
if width and height:
|
||||
self.width = width
|
||||
self.height = height
|
||||
self._screen.resize(width, height)
|
||||
|
||||
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return int(default)
|
||||
|
||||
|
||||
def _parse_resize(payload: Any) -> tuple[int, int]:
|
||||
if isinstance(payload, str) and "x" in payload:
|
||||
left, right = payload.lower().split("x", 1)
|
||||
return _coerce_int(left, 0), _coerce_int(right, 0)
|
||||
if isinstance(payload, dict):
|
||||
width = _coerce_int(payload.get("width") or payload.get("columns") or payload.get("cols"), 0)
|
||||
height = _coerce_int(payload.get("height") or payload.get("rows") or payload.get("lines"), 0)
|
||||
return width, height
|
||||
if isinstance(payload, list) and len(payload) >= 2:
|
||||
return _coerce_int(payload[0], 0), _coerce_int(payload[1], 0)
|
||||
return 0, 0
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Tool abstractions for atropos-agent.
|
||||
|
||||
Provides base Tool class and common tool implementations.
|
||||
"""
|
||||
|
||||
from .base import Tool, ToolCall, ToolRegistry, ToolResult, ToolSchema
|
||||
from .build_registry import build_tool_registry
|
||||
from .sandbox_stubs import BashTool, ReadFileTool, TerminalTool, WriteFileTool
|
||||
from .terminal_stateful_tool import TerminalStatefulTool
|
||||
from .tmux_tool import TmuxTool
|
||||
|
||||
__all__ = [
|
||||
"Tool",
|
||||
"ToolCall",
|
||||
"ToolRegistry",
|
||||
"ToolResult",
|
||||
"ToolSchema",
|
||||
"BashTool",
|
||||
"ReadFileTool",
|
||||
"WriteFileTool",
|
||||
"TerminalTool",
|
||||
"TerminalStatefulTool",
|
||||
"TmuxTool",
|
||||
"build_tool_registry",
|
||||
]
|
||||
@@ -0,0 +1,423 @@
|
||||
"""
|
||||
Base Tool abstraction for atropos-agent.
|
||||
|
||||
Tools follow a simple pattern:
|
||||
1. Define schema (name, description, parameters)
|
||||
2. Implement execute() method
|
||||
3. Return ToolResult with output/error
|
||||
|
||||
Tool calls use Hermes-style XML tags:
|
||||
<tool_call>{"name": "bash", "arguments": {"command": "ls"}}</tool_call>
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolSchema:
|
||||
"""JSON Schema for a tool's parameters."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
parameters: Dict[str, Any] = field(default_factory=dict)
|
||||
required: List[str] = field(default_factory=list)
|
||||
external: bool = False # Whether the tool must be executed via an external ToolServer (secret proxy) and not inside the sandbox.
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to OpenAI-compatible function schema."""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": self.parameters,
|
||||
"required": self.required,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def to_prompt_description(self) -> str:
|
||||
"""Convert to human-readable description for system prompt."""
|
||||
params_desc = []
|
||||
for name, spec in self.parameters.items():
|
||||
req = "(required)" if name in self.required else "(optional)"
|
||||
desc = spec.get("description", "")
|
||||
param_type = spec.get("type", "string")
|
||||
params_desc.append(f" - {name} ({param_type}) {req}: {desc}")
|
||||
|
||||
params_str = "\n".join(params_desc) if params_desc else " (no parameters)"
|
||||
return f"**{self.name}**: {self.description}\nParameters:\n{params_str}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
"""A parsed tool call from model output."""
|
||||
|
||||
name: str
|
||||
arguments: Dict[str, Any]
|
||||
raw_text: str = "" # Original XML/JSON text
|
||||
uniq_id: str = field(default_factory=lambda: str(uuid.uuid4())) # Unique tool-call id for traceability/reconstruction.
|
||||
|
||||
@classmethod
|
||||
def parse_from_text(cls, text: str) -> List["ToolCall"]:
|
||||
"""
|
||||
Extract tool calls from text using Hermes-style XML tags.
|
||||
|
||||
Supported formats (STRICT: requires well-formed closing tags):
|
||||
- Hermes JSON wrapper:
|
||||
<tool_call>{"name": "...", "arguments": {...}}</tool_call>
|
||||
- GLM/llama.cpp style:
|
||||
<tool_call>terminal{"command":"ls -la"}</tool_call>
|
||||
"""
|
||||
calls: List["ToolCall"] = []
|
||||
|
||||
if not text:
|
||||
return calls
|
||||
|
||||
def _append_from_payload(*, name: str, arguments: Dict[str, Any], raw: str, uniq_id: Optional[str] = None) -> None:
|
||||
if not isinstance(name, str) or not name:
|
||||
return
|
||||
if not isinstance(arguments, dict):
|
||||
return
|
||||
calls.append(
|
||||
cls(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
raw_text=raw,
|
||||
uniq_id=uniq_id or str(uuid.uuid4()),
|
||||
)
|
||||
)
|
||||
|
||||
# STRICT parsing: only accept well-formed <tool_call>...</tool_call> blocks.
|
||||
pattern = r"<tool_call>\s*(.*?)\s*</tool_call>"
|
||||
for inner in re.findall(pattern, text, re.DOTALL):
|
||||
cleaned = (inner or "").strip()
|
||||
if not cleaned:
|
||||
continue
|
||||
|
||||
# Hermes JSON wrapper.
|
||||
if cleaned.startswith("{"):
|
||||
try:
|
||||
data = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
uniq_id = data.get("uniq_id") or data.get("id") or None
|
||||
_append_from_payload(
|
||||
name=data.get("name", ""),
|
||||
arguments=data.get("arguments", {}),
|
||||
raw=inner,
|
||||
uniq_id=uniq_id,
|
||||
)
|
||||
continue
|
||||
|
||||
# GLM/llama.cpp style: terminal{...}
|
||||
m = re.match(r"^\s*([A-Za-z0-9_.:\\-]+)\s*(\{.*\})\s*$", cleaned, re.DOTALL)
|
||||
if not m:
|
||||
continue
|
||||
name = m.group(1)
|
||||
args_text = m.group(2)
|
||||
try:
|
||||
args = json.loads(args_text)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
_append_from_payload(name=name, arguments=args, raw=inner)
|
||||
|
||||
return calls
|
||||
|
||||
@classmethod
|
||||
def has_tool_call(cls, text: str) -> bool:
|
||||
"""Check if text contains any tool calls."""
|
||||
return bool(re.search(r"<tool_call>", text))
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolResult:
|
||||
"""Result from executing a tool."""
|
||||
|
||||
success: bool
|
||||
output: str = ""
|
||||
error: str = ""
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
uniq_id: Optional[str] = None # Should match ToolCall.uniq_id for async execution tracking.
|
||||
|
||||
def to_xml(self) -> str:
|
||||
"""Format as XML for including in conversation."""
|
||||
data = {
|
||||
"success": self.success,
|
||||
"output": self.output,
|
||||
}
|
||||
if self.uniq_id:
|
||||
data["uniq_id"] = self.uniq_id
|
||||
if self.error:
|
||||
data["error"] = self.error
|
||||
if self.metadata:
|
||||
data["metadata"] = self.metadata
|
||||
return f"<tool_response>{json.dumps(data)}</tool_response>"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary."""
|
||||
return {
|
||||
"success": self.success,
|
||||
"output": self.output,
|
||||
"error": self.error,
|
||||
"metadata": self.metadata,
|
||||
"uniq_id": self.uniq_id,
|
||||
}
|
||||
|
||||
|
||||
class Tool(ABC):
|
||||
"""
|
||||
Abstract base class for tools.
|
||||
|
||||
Subclasses must implement:
|
||||
- schema: ToolSchema describing the tool
|
||||
- execute(): async method that performs the tool action
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def schema(self) -> ToolSchema:
|
||||
"""Return the tool's schema."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Tool name (from schema)."""
|
||||
return self.schema.name
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""
|
||||
Execute the tool with given arguments.
|
||||
|
||||
Args:
|
||||
**kwargs: Tool-specific arguments
|
||||
|
||||
Returns:
|
||||
ToolResult with success/failure and output
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_available(self) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Return whether this tool should be exposed/executable in the current process.
|
||||
|
||||
Tools that depend on optional binaries/services/env vars can override this
|
||||
to avoid advertising a tool that will fail at runtime.
|
||||
"""
|
||||
return True, None
|
||||
|
||||
async def __call__(self, **kwargs) -> ToolResult:
|
||||
"""Allow calling tool instance directly."""
|
||||
return await self.execute(**kwargs)
|
||||
|
||||
# Note: This is only wrapping declarations for the external ToolServer (for execution on external process tools), and tools preinstalled in envs
|
||||
class ToolRegistry:
|
||||
"""Registry of available tools."""
|
||||
|
||||
def __init__(self):
|
||||
self._tools: Dict[str, Tool] = {}
|
||||
|
||||
def register(self, tool: Tool) -> None:
|
||||
"""Register a tool."""
|
||||
self._tools[tool.name] = tool
|
||||
|
||||
def get(self, name: str) -> Optional[Tool]:
|
||||
"""Get a tool by name."""
|
||||
return self._tools.get(name)
|
||||
|
||||
def list_tools(self) -> List[Tool]:
|
||||
"""List all registered tools."""
|
||||
return list(self._tools.values())
|
||||
|
||||
def get_schemas(self) -> List[ToolSchema]:
|
||||
"""Get schemas for all registered tools."""
|
||||
return [tool.schema for tool in self._tools.values()]
|
||||
|
||||
def get_prompt_description(self) -> str:
|
||||
"""Generate tool descriptions for system prompt."""
|
||||
descriptions = [tool.schema.to_prompt_description() for tool in self._tools.values()]
|
||||
return "\n\n".join(descriptions)
|
||||
|
||||
def get_prompt_tool_definitions_json(self) -> str:
|
||||
"""
|
||||
Return a Hermes-style JSON list of tool definitions for use inside a `<tools>...</tools>` block.
|
||||
|
||||
Hermes trajectories historically use a simplified schema list:
|
||||
[{"name": ..., "description": ..., "parameters": {...}, "required": null}, ...]
|
||||
"""
|
||||
formatted: List[Dict[str, Any]] = []
|
||||
for tool in self._tools.values():
|
||||
fn = tool.schema.to_dict().get("function", {})
|
||||
formatted.append(
|
||||
{
|
||||
"name": fn.get("name", tool.name),
|
||||
"description": fn.get("description", ""),
|
||||
"parameters": fn.get("parameters", {}),
|
||||
# Keep parity with Hermes saved trajectories (required is typically null there).
|
||||
"required": None,
|
||||
}
|
||||
)
|
||||
return json.dumps(formatted, ensure_ascii=False)
|
||||
|
||||
async def execute(self, call: ToolCall) -> ToolResult:
|
||||
"""Execute a tool call."""
|
||||
tool = self.get(call.name)
|
||||
if tool is None:
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=f"Unknown tool: {call.name}",
|
||||
uniq_id=call.uniq_id,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await tool.execute(**call.arguments)
|
||||
if result.uniq_id is None:
|
||||
result.uniq_id = call.uniq_id
|
||||
return result
|
||||
except Exception as e:
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=f"Tool execution error: {str(e)}",
|
||||
uniq_id=call.uniq_id,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# FastAPI / transport models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ToolCallPayload(BaseModel):
|
||||
name: str
|
||||
arguments: Dict[str, Any] = Field(default_factory=dict)
|
||||
uniq_id: str
|
||||
|
||||
@classmethod
|
||||
def from_tool_call(cls, call: ToolCall) -> "ToolCallPayload":
|
||||
return cls(name=call.name, arguments=call.arguments, uniq_id=call.uniq_id)
|
||||
|
||||
def to_tool_call(self) -> ToolCall:
|
||||
return ToolCall(name=self.name, arguments=self.arguments, uniq_id=self.uniq_id)
|
||||
|
||||
|
||||
class ToolResultPayload(BaseModel):
|
||||
success: bool
|
||||
output: str = ""
|
||||
error: str = ""
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
uniq_id: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_tool_result(cls, result: ToolResult) -> "ToolResultPayload":
|
||||
return cls(
|
||||
success=result.success,
|
||||
output=result.output,
|
||||
error=result.error,
|
||||
metadata=result.metadata,
|
||||
uniq_id=result.uniq_id,
|
||||
)
|
||||
|
||||
def to_tool_result(self) -> ToolResult:
|
||||
return ToolResult(
|
||||
success=self.success,
|
||||
output=self.output,
|
||||
error=self.error,
|
||||
metadata=self.metadata,
|
||||
uniq_id=self.uniq_id,
|
||||
)
|
||||
|
||||
|
||||
class ToolExecutorExecuteRequest(BaseModel):
|
||||
trajectory_id: str
|
||||
tool: ToolCallPayload
|
||||
timeout_s: Optional[float] = None
|
||||
|
||||
|
||||
class ToolExecutorReleaseRequest(BaseModel):
|
||||
trajectory_id: str
|
||||
reset_workspace: bool = False
|
||||
|
||||
|
||||
class ToolServerExecuteRequest(BaseModel):
|
||||
trajectory_id: Optional[str] = None
|
||||
tool: ToolCallPayload
|
||||
timeout_s: Optional[float] = None
|
||||
# Optional sandbox context for tools that need workspace artifacts.
|
||||
# This is set by ToolExecutor and is NOT model-controlled.
|
||||
slot_id: Optional[str] = None
|
||||
container_addr: Optional[str] = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Artifact transport models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ArtifactReadRequestPayload(BaseModel):
|
||||
trajectory_id: str
|
||||
path: str
|
||||
encoding: Literal["text", "base64"] = "text"
|
||||
max_bytes: Optional[int] = None
|
||||
include_sha256: bool = False
|
||||
|
||||
|
||||
class ArtifactReadResponsePayload(BaseModel):
|
||||
success: bool
|
||||
content: str = ""
|
||||
error: str = ""
|
||||
encoding: str = "text"
|
||||
truncated: bool = False
|
||||
bytes: int = 0
|
||||
file_size: Optional[int] = None
|
||||
path: str = ""
|
||||
mime: Optional[str] = None
|
||||
sha256: Optional[str] = None
|
||||
|
||||
|
||||
class ArtifactListRequestPayload(BaseModel):
|
||||
trajectory_id: str
|
||||
path: str = "."
|
||||
recursive: bool = False
|
||||
max_entries: Optional[int] = None
|
||||
|
||||
|
||||
class ArtifactListEntryPayload(BaseModel):
|
||||
path: str
|
||||
is_dir: bool
|
||||
size: int
|
||||
mtime: float
|
||||
|
||||
|
||||
class ArtifactListResponsePayload(BaseModel):
|
||||
success: bool
|
||||
entries: List[ArtifactListEntryPayload] = Field(default_factory=list)
|
||||
truncated: bool = False
|
||||
error: str = ""
|
||||
|
||||
|
||||
class ArtifactArchiveRequestPayload(BaseModel):
|
||||
trajectory_id: str
|
||||
path: str = "."
|
||||
format: Literal["tar.gz", "tgz"] = "tar.gz"
|
||||
max_bytes: Optional[int] = None
|
||||
max_entries: Optional[int] = None
|
||||
|
||||
|
||||
class ArtifactArchiveResponsePayload(BaseModel):
|
||||
success: bool
|
||||
content: str = ""
|
||||
error: str = ""
|
||||
encoding: str = "base64"
|
||||
format: str = "tar.gz"
|
||||
bytes: int = 0
|
||||
entry_count: int = 0
|
||||
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Unified tool registry builder for Hermes-Agent Atropos integration.
|
||||
|
||||
This composes:
|
||||
- sandbox tool stubs (terminal/bash/read_file/write_file + stateful terminal/tmux)
|
||||
- Hermes external tools (web/vision/image/moa/skills/browser), executed via ToolServer
|
||||
|
||||
ToolExecutor only needs the schema + `external` routing bit; ToolServer executes
|
||||
the external tools via Hermes' existing implementations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from .base import ToolRegistry
|
||||
from .hermes_external_tools import build_external_tools
|
||||
from .sandbox_stubs import BashTool, ReadFileTool, TerminalTool, WriteFileTool
|
||||
from .terminal_stateful_tool import TerminalStatefulTool
|
||||
from .tmux_tool import TmuxTool
|
||||
from .toolset_resolver import resolve_multiple_toolsets
|
||||
|
||||
|
||||
def build_tool_registry(
|
||||
*,
|
||||
enabled_toolsets: Optional[List[str]] = None,
|
||||
disabled_toolsets: Optional[List[str]] = None,
|
||||
tool_server_url: Optional[str] = None,
|
||||
) -> ToolRegistry:
|
||||
"""
|
||||
Build a ToolRegistry for AgentEnv / ToolExecutor / ToolServer.
|
||||
|
||||
If `tool_server_url` is not provided, external tools will be omitted so we do
|
||||
not advertise tools that cannot execute.
|
||||
"""
|
||||
enabled_toolsets = enabled_toolsets or ["default"]
|
||||
|
||||
# Resolve tool names using Hermes toolsets plus Atropos additions.
|
||||
selected = set(resolve_multiple_toolsets(enabled_toolsets))
|
||||
if disabled_toolsets:
|
||||
selected -= set(resolve_multiple_toolsets(disabled_toolsets))
|
||||
|
||||
reg = ToolRegistry()
|
||||
|
||||
# Always register sandbox tools if selected.
|
||||
sandbox_by_name = {
|
||||
"terminal": TerminalTool(),
|
||||
"bash": BashTool(),
|
||||
"read_file": ReadFileTool(),
|
||||
"write_file": WriteFileTool(),
|
||||
"terminal_stateful": TerminalStatefulTool(),
|
||||
"tmux": TmuxTool(),
|
||||
}
|
||||
for name, tool in sandbox_by_name.items():
|
||||
if name in selected:
|
||||
reg.register(tool)
|
||||
|
||||
# External tools: only include when ToolServer is configured.
|
||||
if tool_server_url:
|
||||
for tool in build_external_tools(selected_tool_names=selected):
|
||||
if tool.name in selected:
|
||||
reg.register(tool)
|
||||
|
||||
return reg
|
||||
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Hermes external tool adapter for Atropos ToolServer.
|
||||
|
||||
These tools reuse Hermes-Agent's existing tool runner (`model_tools.handle_function_call`)
|
||||
so we don't duplicate external tool implementations.
|
||||
|
||||
Important:
|
||||
- These are marked `external=True` and should be executed ONLY by ToolServer.
|
||||
- We run `handle_function_call` in a worker thread because the Hermes implementation
|
||||
uses `asyncio.run()` internally for some async tools (web_extract, vision, MoA, etc).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import model_tools
|
||||
|
||||
from .base import Tool, ToolResult, ToolSchema
|
||||
|
||||
|
||||
def _schema_from_openai_tool_dict(tool: Dict[str, Any], *, external: bool) -> ToolSchema:
|
||||
fn = tool.get("function") or {}
|
||||
name = str(fn.get("name") or "")
|
||||
description = str(fn.get("description") or "")
|
||||
params = fn.get("parameters") or {}
|
||||
properties = params.get("properties") or {}
|
||||
required = params.get("required") or []
|
||||
if not isinstance(required, list):
|
||||
required = []
|
||||
return ToolSchema(
|
||||
name=name,
|
||||
description=description,
|
||||
parameters=dict(properties),
|
||||
required=[str(x) for x in required if isinstance(x, (str, int))],
|
||||
external=external,
|
||||
)
|
||||
|
||||
|
||||
class HermesExternalTool(Tool):
|
||||
def __init__(self, schema: ToolSchema):
|
||||
self._schema = schema
|
||||
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
return self._schema
|
||||
|
||||
async def execute(self, task_id: Optional[str] = None, **kwargs: Any) -> ToolResult:
|
||||
# `model_tools.handle_function_call` returns a JSON string (success or error).
|
||||
# Run in a thread because some Hermes tool handlers call `asyncio.run()`.
|
||||
raw = await asyncio.to_thread(model_tools.handle_function_call, self.name, kwargs, task_id)
|
||||
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except Exception:
|
||||
# Keep as plain string.
|
||||
return ToolResult(success=True, output=str(raw))
|
||||
|
||||
if isinstance(parsed, dict) and parsed.get("error"):
|
||||
return ToolResult(success=False, error=str(parsed.get("error")), output="")
|
||||
|
||||
return ToolResult(success=True, output=json.dumps(parsed, ensure_ascii=False))
|
||||
|
||||
|
||||
def build_external_tools(
|
||||
*,
|
||||
selected_tool_names: Optional[set[str]] = None,
|
||||
) -> List[HermesExternalTool]:
|
||||
"""
|
||||
Build external tool wrappers from Hermes tool declarations.
|
||||
|
||||
Filters out sandbox-oriented tools (e.g. `terminal`) since those should run
|
||||
inside the sandbox via ToolExecutor.
|
||||
"""
|
||||
# IMPORTANT: Hermes' `model_tools.get_tool_definitions()` only understands Hermes toolsets.
|
||||
# Atropos envs add extra toolsets (filesystem/sandbox/stateful). To avoid noisy "Unknown toolset"
|
||||
# prints and accidental filtering, we fetch ALL Hermes tool definitions here and filter by name.
|
||||
tools = model_tools.get_tool_definitions(enabled_toolsets=None, disabled_toolsets=None, quiet_mode=True)
|
||||
|
||||
wrappers: List[HermesExternalTool] = []
|
||||
for t in tools:
|
||||
schema = _schema_from_openai_tool_dict(t, external=True)
|
||||
if schema.name in {"terminal"}:
|
||||
continue
|
||||
if selected_tool_names is not None and schema.name not in selected_tool_names:
|
||||
continue
|
||||
wrappers.append(HermesExternalTool(schema))
|
||||
return wrappers
|
||||
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Sandbox tool stubs for Atropos ToolExecutor.
|
||||
|
||||
These tools are executed inside the sandbox containers via:
|
||||
ToolExecutor -> SlotPool -> sandbox_server.py
|
||||
|
||||
They intentionally do NOT execute anything on the host process. If they are
|
||||
called directly (outside ToolExecutor), they return a clear error.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from .base import Tool, ToolResult, ToolSchema
|
||||
|
||||
|
||||
class TerminalTool(Tool):
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
return ToolSchema(
|
||||
name="terminal",
|
||||
description=(
|
||||
"Execute a command inside the sandbox slot workspace and return stdout/stderr. "
|
||||
"Filesystem persists within a trajectory slot. Background processes are not supported "
|
||||
"in stateless mode. Commands run under POSIX /bin/sh and each tool call runs in a fresh "
|
||||
"shell (no persisted env vars). Avoid bash-only syntax like `source`; prefer `. .venv/bin/activate` "
|
||||
"or invoke `.venv/bin/python ...` directly."
|
||||
),
|
||||
parameters={
|
||||
"command": {"type": "string", "description": "The command to execute"},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Command timeout in seconds (optional).",
|
||||
"minimum": 1,
|
||||
},
|
||||
"background": {
|
||||
"type": "boolean",
|
||||
"description": "Not supported in sandbox terminal (always false).",
|
||||
"default": False,
|
||||
},
|
||||
},
|
||||
required=["command"],
|
||||
external=False,
|
||||
)
|
||||
|
||||
async def execute(self, **_kwargs) -> ToolResult:
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error="terminal must be executed via ToolExecutor inside the sandbox",
|
||||
)
|
||||
|
||||
|
||||
class BashTool(Tool):
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
return ToolSchema(
|
||||
name="bash",
|
||||
description="Execute a bash command inside the sandbox slot workspace.",
|
||||
parameters={"command": {"type": "string", "description": "The bash command to execute"}},
|
||||
required=["command"],
|
||||
external=False,
|
||||
)
|
||||
|
||||
async def execute(self, **_kwargs) -> ToolResult:
|
||||
return ToolResult(success=False, error="bash must be executed via ToolExecutor inside the sandbox")
|
||||
|
||||
|
||||
class ReadFileTool(Tool):
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
return ToolSchema(
|
||||
name="read_file",
|
||||
description="Read a file from the sandbox slot workspace.",
|
||||
parameters={"path": {"type": "string", "description": "Path to the file"}},
|
||||
required=["path"],
|
||||
external=False,
|
||||
)
|
||||
|
||||
async def execute(self, **_kwargs) -> ToolResult:
|
||||
return ToolResult(success=False, error="read_file must be executed via ToolExecutor inside the sandbox")
|
||||
|
||||
|
||||
class WriteFileTool(Tool):
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
return ToolSchema(
|
||||
name="write_file",
|
||||
description="Write a file into the sandbox slot workspace.",
|
||||
parameters={
|
||||
"path": {"type": "string", "description": "Path to the file"},
|
||||
"content": {"type": "string", "description": "File content"},
|
||||
},
|
||||
required=["path", "content"],
|
||||
external=False,
|
||||
)
|
||||
|
||||
async def execute(self, **_kwargs) -> ToolResult:
|
||||
return ToolResult(success=False, error="write_file must be executed via ToolExecutor inside the sandbox")
|
||||
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Stateful terminal tool schema.
|
||||
|
||||
This is a sandbox tool that routes to the sandbox server as `bash_stateful`
|
||||
via ToolExecutor mapping. It exists to expose an explicit, opt-in terminal
|
||||
primitive suitable for stateful workflows (e.g. tmux sessions / TUIs).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from .base import Tool, ToolResult, ToolSchema
|
||||
|
||||
|
||||
class TerminalStatefulTool(Tool):
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
return ToolSchema(
|
||||
name="terminal_stateful",
|
||||
description=(
|
||||
"Execute a command in the sandbox, allowing stateful/background processes to persist "
|
||||
"across tool calls within the same trajectory slot (e.g. tmux sessions). "
|
||||
"Use sparingly; output is still non-interactive."
|
||||
),
|
||||
parameters={
|
||||
"command": {"type": "string", "description": "The command to execute"},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Command timeout in seconds (optional).",
|
||||
"minimum": 1,
|
||||
},
|
||||
},
|
||||
required=["command"],
|
||||
)
|
||||
|
||||
def is_available(self) -> tuple[bool, str | None]:
|
||||
return True, None
|
||||
|
||||
async def execute(self, command: str, timeout: Optional[int] = None) -> ToolResult:
|
||||
_ = (command, timeout)
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error="terminal_stateful must be executed via ToolExecutor inside the sandbox",
|
||||
)
|
||||
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
tmux tool schema (sandbox).
|
||||
|
||||
This is a sandbox tool that provides basic tmux session control suitable for
|
||||
TUI-style terminal interactions:
|
||||
- send keys (arrow keys, enter, etc.)
|
||||
- capture the current screen buffer
|
||||
|
||||
Execution is routed by ToolExecutor to the sandbox server's `tmux` backend.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .base import Tool, ToolResult, ToolSchema
|
||||
|
||||
|
||||
class TmuxTool(Tool):
|
||||
@property
|
||||
def schema(self) -> ToolSchema:
|
||||
return ToolSchema(
|
||||
name="tmux",
|
||||
description=(
|
||||
"Control a per-trajectory tmux session inside the sandbox (stateful terminal). "
|
||||
"Use this for TUI-style interactions: send keys and capture the current screen."
|
||||
),
|
||||
parameters={
|
||||
"action": {
|
||||
"type": "string",
|
||||
"description": "Action to perform: start | send_keys | stream | stop.",
|
||||
"enum": ["start", "send_keys", "stream", "stop", "capture"],
|
||||
},
|
||||
"keys": {
|
||||
"description": "Keys to send (string or list of strings) when action=send_keys.",
|
||||
},
|
||||
"block": {
|
||||
"type": "boolean",
|
||||
"description": "If true, wait for shell command completion (only valid at a shell prompt).",
|
||||
"default": False,
|
||||
},
|
||||
"min_wait_s": {
|
||||
"type": "number",
|
||||
"description": "For non-blocking send_keys, sleep this long after sending keys (seconds).",
|
||||
"default": 0.0,
|
||||
},
|
||||
"max_wait_s": {
|
||||
"type": "number",
|
||||
"description": "For blocking send_keys, max time to wait for completion (seconds).",
|
||||
},
|
||||
"capture_entire": {
|
||||
"type": "boolean",
|
||||
"description": "Deprecated. Streaming is preferred.",
|
||||
"default": False,
|
||||
},
|
||||
"max_bytes": {
|
||||
"type": "integer",
|
||||
"description": "Max bytes to return per stream call.",
|
||||
},
|
||||
"reset": {
|
||||
"type": "boolean",
|
||||
"description": "If true, reset stream offset to the beginning of the asciinema recording.",
|
||||
"default": False,
|
||||
},
|
||||
"pane_width": {
|
||||
"type": "integer",
|
||||
"description": "Pane width for action=start (columns).",
|
||||
"minimum": 20,
|
||||
},
|
||||
"pane_height": {
|
||||
"type": "integer",
|
||||
"description": "Pane height for action=start (rows).",
|
||||
"minimum": 10,
|
||||
},
|
||||
},
|
||||
required=["action"],
|
||||
)
|
||||
|
||||
def is_available(self) -> tuple[bool, str | None]:
|
||||
return True, None
|
||||
|
||||
async def execute(self, **kwargs: Dict[str, Any]) -> ToolResult:
|
||||
# This tool is intended to be executed via ToolExecutor -> sandbox server.
|
||||
# We keep a safe fallback for non-sandbox contexts.
|
||||
action = str(kwargs.get("action") or "").strip()
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=f"tmux tool must be executed in the sandbox (got action={action!r})",
|
||||
)
|
||||
@@ -0,0 +1,500 @@
|
||||
"""
|
||||
ToolExecutor - queued, batched tool dispatch for multiplexed agent trajectories.
|
||||
|
||||
This component is responsible for:
|
||||
- Maintaining trajectory -> Slot affinity (workspace continuity)
|
||||
- Batching sandbox tool calls across trajectories to maximize container utilization
|
||||
- Routing external tools (ToolSchema.external=True) to a ToolServer (Phase 4.5)
|
||||
|
||||
For now, only sandbox tools are executed:
|
||||
- bash
|
||||
- read_file
|
||||
- write_file
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .base import (
|
||||
ArtifactArchiveRequestPayload,
|
||||
ArtifactArchiveResponsePayload,
|
||||
ArtifactListRequestPayload,
|
||||
ArtifactListResponsePayload,
|
||||
ArtifactReadRequestPayload,
|
||||
ArtifactReadResponsePayload,
|
||||
ToolCall,
|
||||
ToolCallPayload,
|
||||
ToolRegistry,
|
||||
ToolResult,
|
||||
ToolResultPayload,
|
||||
ToolServerExecuteRequest,
|
||||
)
|
||||
from ..backends.base import ToolBackend
|
||||
from ..slots import Slot
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolExecutorConfig:
|
||||
batch_window_ms: int = 20
|
||||
max_batch_size: int = 200
|
||||
allow_network: bool = True
|
||||
require_sandbox: bool = False
|
||||
require_stateful_sandbox: bool = False
|
||||
tool_server_url: Optional[str] = None
|
||||
tool_server_token: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _QueuedToolRequest:
|
||||
trajectory_id: str
|
||||
call: ToolCall
|
||||
timeout_s: Optional[float]
|
||||
future: asyncio.Future
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
backend: ToolBackend,
|
||||
tools: ToolRegistry,
|
||||
config: Optional[ToolExecutorConfig] = None,
|
||||
) -> None:
|
||||
self.backend = backend
|
||||
self.tools = tools
|
||||
self.config = config or ToolExecutorConfig()
|
||||
|
||||
self._queue: asyncio.Queue[Optional[_QueuedToolRequest]] = asyncio.Queue()
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._stopping = asyncio.Event()
|
||||
|
||||
self._slots_lock = asyncio.Lock()
|
||||
self._slot_by_trajectory: Dict[str, Slot] = {}
|
||||
|
||||
self._tool_server_client: Optional[httpx.AsyncClient] = None
|
||||
self._tool_server_lock = asyncio.Lock()
|
||||
|
||||
# lightweight stats for status endpoints
|
||||
self.total_requests: int = 0
|
||||
self.total_errors: int = 0
|
||||
self.latencies_s: List[float] = []
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._task is None:
|
||||
self._task = asyncio.create_task(self._run_loop())
|
||||
|
||||
def queue_size(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
async def close(self) -> None:
|
||||
self._stopping.set()
|
||||
await self._queue.put(None)
|
||||
if self._task:
|
||||
await self._task
|
||||
self._task = None
|
||||
|
||||
client = self._tool_server_client
|
||||
self._tool_server_client = None
|
||||
if client is not None:
|
||||
await client.aclose()
|
||||
|
||||
# Best-effort release any remaining slots.
|
||||
async with self._slots_lock:
|
||||
slots = list(self._slot_by_trajectory.items())
|
||||
self._slot_by_trajectory.clear()
|
||||
|
||||
for _, slot in slots:
|
||||
try:
|
||||
await self.backend.release(slot, reset_workspace=False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
trajectory_id: str,
|
||||
call: ToolCall,
|
||||
timeout_s: Optional[float] = None,
|
||||
) -> ToolResult:
|
||||
if self._task is None:
|
||||
raise RuntimeError("ToolExecutor not started (call start() first)")
|
||||
|
||||
# Allow tool args to suggest a timeout (Hermes-compatible terminal tool),
|
||||
# but never let the model choose "infinite" timeouts.
|
||||
if timeout_s is None:
|
||||
raw_timeout = call.arguments.get("timeout")
|
||||
if isinstance(raw_timeout, (int, float)):
|
||||
timeout_s = float(raw_timeout)
|
||||
if timeout_s is not None:
|
||||
timeout_s = max(1.0, min(float(timeout_s), 600.0))
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
fut: asyncio.Future = loop.create_future()
|
||||
started = time.perf_counter()
|
||||
await self._queue.put(_QueuedToolRequest(trajectory_id=trajectory_id, call=call, timeout_s=timeout_s, future=fut))
|
||||
try:
|
||||
result: ToolResult = await fut
|
||||
return result
|
||||
finally:
|
||||
self.latencies_s.append(time.perf_counter() - started)
|
||||
|
||||
async def release_trajectory(self, trajectory_id: str, reset_workspace: bool = False) -> None:
|
||||
async with self._slots_lock:
|
||||
slot = self._slot_by_trajectory.pop(trajectory_id, None)
|
||||
|
||||
if slot is not None:
|
||||
await self.backend.release(slot, reset_workspace=reset_workspace)
|
||||
|
||||
async def _get_slot_if_present(self, trajectory_id: str) -> Optional[Slot]:
|
||||
async with self._slots_lock:
|
||||
return self._slot_by_trajectory.get(trajectory_id)
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Artifact helpers (optional)
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
async def read_artifact(self, req: ArtifactReadRequestPayload) -> ArtifactReadResponsePayload:
|
||||
slot = await self._get_slot_if_present(req.trajectory_id)
|
||||
if slot is None:
|
||||
return ArtifactReadResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)")
|
||||
data = await self.backend.read_artifact(
|
||||
slot,
|
||||
req.path,
|
||||
encoding=req.encoding,
|
||||
max_bytes=req.max_bytes,
|
||||
include_sha256=req.include_sha256,
|
||||
)
|
||||
if isinstance(data, dict):
|
||||
data = dict(data)
|
||||
data.pop("http_status", None)
|
||||
try:
|
||||
return ArtifactReadResponsePayload(**(data or {}))
|
||||
except Exception as e:
|
||||
return ArtifactReadResponsePayload(success=False, error=f"Invalid artifact read response: {e}")
|
||||
|
||||
async def list_artifacts(self, req: ArtifactListRequestPayload) -> ArtifactListResponsePayload:
|
||||
slot = await self._get_slot_if_present(req.trajectory_id)
|
||||
if slot is None:
|
||||
return ArtifactListResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)")
|
||||
data = await self.backend.list_artifacts(
|
||||
slot,
|
||||
req.path,
|
||||
recursive=req.recursive,
|
||||
max_entries=req.max_entries,
|
||||
)
|
||||
if isinstance(data, dict):
|
||||
data = dict(data)
|
||||
data.pop("http_status", None)
|
||||
try:
|
||||
return ArtifactListResponsePayload(**(data or {}))
|
||||
except Exception as e:
|
||||
return ArtifactListResponsePayload(success=False, error=f"Invalid artifact list response: {e}")
|
||||
|
||||
async def archive_artifacts(self, req: ArtifactArchiveRequestPayload) -> ArtifactArchiveResponsePayload:
|
||||
slot = await self._get_slot_if_present(req.trajectory_id)
|
||||
if slot is None:
|
||||
return ArtifactArchiveResponsePayload(success=False, error="No active slot for trajectory (run a sandbox tool first)")
|
||||
data = await self.backend.archive_artifacts(
|
||||
slot,
|
||||
req.path,
|
||||
archive_format=req.format,
|
||||
max_bytes=req.max_bytes,
|
||||
max_entries=req.max_entries,
|
||||
)
|
||||
if isinstance(data, dict):
|
||||
data = dict(data)
|
||||
data.pop("http_status", None)
|
||||
try:
|
||||
return ArtifactArchiveResponsePayload(**(data or {}))
|
||||
except Exception as e:
|
||||
return ArtifactArchiveResponsePayload(success=False, error=f"Invalid artifact archive response: {e}")
|
||||
|
||||
async def _get_or_acquire_slot(self, trajectory_id: str) -> Slot:
|
||||
async with self._slots_lock:
|
||||
existing = self._slot_by_trajectory.get(trajectory_id)
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
slot = await self.backend.acquire(trajectory_id)
|
||||
|
||||
async with self._slots_lock:
|
||||
existing = self._slot_by_trajectory.get(trajectory_id)
|
||||
if existing is not None:
|
||||
# Another coroutine won the race; return its slot.
|
||||
await self.backend.release(slot, reset_workspace=False)
|
||||
return existing
|
||||
self._slot_by_trajectory[trajectory_id] = slot
|
||||
return slot
|
||||
|
||||
async def _run_loop(self) -> None:
|
||||
pending: List[_QueuedToolRequest] = []
|
||||
deadline: Optional[float] = None
|
||||
|
||||
batch_window_s = max(0.0, self.config.batch_window_ms / 1000.0)
|
||||
max_batch = max(1, self.config.max_batch_size)
|
||||
|
||||
while True:
|
||||
if self._stopping.is_set() and self._queue.empty() and not pending:
|
||||
break
|
||||
|
||||
timeout = None
|
||||
if pending and deadline is not None:
|
||||
timeout = max(0.0, deadline - time.perf_counter())
|
||||
|
||||
try:
|
||||
item = await asyncio.wait_for(self._queue.get(), timeout=timeout)
|
||||
if item is None:
|
||||
continue
|
||||
pending.append(item)
|
||||
if len(pending) == 1:
|
||||
deadline = time.perf_counter() + batch_window_s
|
||||
if len(pending) < max_batch:
|
||||
continue
|
||||
except asyncio.TimeoutError:
|
||||
# batch window elapsed
|
||||
pass
|
||||
|
||||
if not pending:
|
||||
deadline = None
|
||||
continue
|
||||
|
||||
batch = pending
|
||||
pending = []
|
||||
deadline = None
|
||||
|
||||
await self._execute_batch(batch)
|
||||
|
||||
async def _get_tool_server_client(self) -> httpx.AsyncClient:
|
||||
url = self.config.tool_server_url
|
||||
if not url:
|
||||
raise RuntimeError("ToolServer not configured")
|
||||
|
||||
if self._tool_server_client is not None:
|
||||
return self._tool_server_client
|
||||
|
||||
async with self._tool_server_lock:
|
||||
if self._tool_server_client is None:
|
||||
self._tool_server_client = httpx.AsyncClient(base_url=url.rstrip("/"))
|
||||
return self._tool_server_client
|
||||
|
||||
def _tool_server_headers(self) -> Dict[str, str]:
|
||||
token = self.config.tool_server_token
|
||||
if not token:
|
||||
return {}
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
async def _execute_external(self, req: _QueuedToolRequest) -> ToolResult:
|
||||
client = await self._get_tool_server_client()
|
||||
slot_id: Optional[str] = None
|
||||
container_addr: Optional[str] = None
|
||||
slot = await self._get_slot_if_present(req.trajectory_id)
|
||||
if slot is not None:
|
||||
slot_id = slot.slot_id
|
||||
container_addr = slot.container_addr
|
||||
|
||||
payload = ToolServerExecuteRequest(
|
||||
trajectory_id=req.trajectory_id,
|
||||
tool=ToolCallPayload.from_tool_call(req.call),
|
||||
timeout_s=req.timeout_s,
|
||||
slot_id=slot_id,
|
||||
container_addr=container_addr,
|
||||
)
|
||||
|
||||
try:
|
||||
resp = await client.post(
|
||||
"/execute",
|
||||
json=payload.model_dump(),
|
||||
headers=self._tool_server_headers(),
|
||||
timeout=req.timeout_s,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
parsed = ToolResultPayload(**data)
|
||||
result = parsed.to_tool_result()
|
||||
if result.uniq_id is None:
|
||||
result.uniq_id = req.call.uniq_id
|
||||
return result
|
||||
except Exception as e:
|
||||
return ToolResult(
|
||||
success=False,
|
||||
error=f"External tool failed: {e}",
|
||||
uniq_id=req.call.uniq_id,
|
||||
)
|
||||
|
||||
async def _execute_batch(self, batch: List[_QueuedToolRequest]) -> None:
|
||||
# Resolve tool schemas once per request and separate sandbox/external/unknown.
|
||||
sandbox_items: List[_QueuedToolRequest] = []
|
||||
external_items: List[_QueuedToolRequest] = []
|
||||
unknown_items: List[_QueuedToolRequest] = []
|
||||
|
||||
for it in batch:
|
||||
tool = self.tools.get(it.call.name)
|
||||
if tool is None:
|
||||
unknown_items.append(it)
|
||||
continue
|
||||
|
||||
schema = tool.schema
|
||||
if not schema.external:
|
||||
sandbox_items.append(it)
|
||||
else:
|
||||
external_items.append(it)
|
||||
|
||||
for it in unknown_items:
|
||||
self.total_requests += 1
|
||||
self.total_errors += 1
|
||||
if not it.future.done():
|
||||
it.future.set_result(
|
||||
ToolResult(
|
||||
success=False,
|
||||
error=f"Unknown tool: {it.call.name}",
|
||||
uniq_id=it.call.uniq_id,
|
||||
)
|
||||
)
|
||||
|
||||
if external_items:
|
||||
if not self.config.tool_server_url:
|
||||
for it in external_items:
|
||||
self.total_requests += 1
|
||||
self.total_errors += 1
|
||||
if not it.future.done():
|
||||
it.future.set_result(
|
||||
ToolResult(
|
||||
success=False,
|
||||
error=f"External tool not available (ToolServer not configured): {it.call.name}",
|
||||
uniq_id=it.call.uniq_id,
|
||||
)
|
||||
)
|
||||
else:
|
||||
results = await asyncio.gather(*[self._execute_external(it) for it in external_items])
|
||||
for it, res in zip(external_items, results):
|
||||
self.total_requests += 1
|
||||
if not getattr(res, "success", False):
|
||||
self.total_errors += 1
|
||||
if not it.future.done():
|
||||
it.future.set_result(res)
|
||||
|
||||
if not sandbox_items:
|
||||
return
|
||||
|
||||
# Acquire slots for the distinct trajectories in this batch.
|
||||
try:
|
||||
traj_ids = list({it.trajectory_id for it in sandbox_items})
|
||||
slots = await asyncio.gather(*[self._get_or_acquire_slot(tid) for tid in traj_ids])
|
||||
slot_by_traj = dict(zip(traj_ids, slots))
|
||||
except Exception as e:
|
||||
for it in sandbox_items:
|
||||
self.total_requests += 1
|
||||
self.total_errors += 1
|
||||
if not it.future.done():
|
||||
it.future.set_result(
|
||||
ToolResult(
|
||||
success=False,
|
||||
error=f"Failed to acquire slot: {e}",
|
||||
uniq_id=it.call.uniq_id,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Group by timeout so we don't accidentally make short timeouts wait on long ones.
|
||||
by_timeout: Dict[float, List[_QueuedToolRequest]] = {}
|
||||
default_timeout = self.backend.default_timeout_s
|
||||
|
||||
for it in sandbox_items:
|
||||
t = it.timeout_s
|
||||
if t is None:
|
||||
t = default_timeout
|
||||
if t is None:
|
||||
t = 30.0
|
||||
by_timeout.setdefault(float(t), []).append(it)
|
||||
|
||||
for timeout_s, items in by_timeout.items():
|
||||
requests = []
|
||||
dispatched: List[_QueuedToolRequest] = []
|
||||
for it in items:
|
||||
slot = slot_by_traj[it.trajectory_id]
|
||||
tool_name = it.call.name
|
||||
args = dict(it.call.arguments)
|
||||
|
||||
# Hermes compatibility: treat `terminal` as an alias of sandbox `bash`.
|
||||
if tool_name == "terminal":
|
||||
if args.get("background"):
|
||||
self.total_requests += 1
|
||||
self.total_errors += 1
|
||||
if not it.future.done():
|
||||
it.future.set_result(
|
||||
ToolResult(
|
||||
success=False,
|
||||
error="terminal background execution is not supported in sandbox",
|
||||
uniq_id=it.call.uniq_id,
|
||||
)
|
||||
)
|
||||
continue
|
||||
tool_name = "bash"
|
||||
# `timeout` is handled at the ToolExecutor level, not passed to the sandbox tool args.
|
||||
args.pop("timeout", None)
|
||||
elif tool_name == "terminal_stateful":
|
||||
tool_name = "bash_stateful"
|
||||
args.pop("timeout", None)
|
||||
elif tool_name == "tmux":
|
||||
# `tmux` is a sandbox tool backed by the stateful session manager.
|
||||
# Network policy is env-controlled.
|
||||
args.pop("allow_network", None)
|
||||
|
||||
if tool_name == "bash":
|
||||
# Network policy is set by the environment/executor, not by the model.
|
||||
args.pop("allow_network", None)
|
||||
args.pop("require_sandbox", None)
|
||||
args["allow_network"] = bool(self.config.allow_network)
|
||||
args["require_sandbox"] = bool(self.config.require_sandbox)
|
||||
# `timeout` is handled at the ToolExecutor level, not passed to the sandbox tool args.
|
||||
args.pop("timeout", None)
|
||||
elif tool_name == "bash_stateful":
|
||||
# Network policy is set by the environment/executor, not by the model.
|
||||
args.pop("allow_network", None)
|
||||
args.pop("require_sandbox", None)
|
||||
args.pop("require_stateful_sandbox", None)
|
||||
args["allow_network"] = bool(self.config.allow_network)
|
||||
args["require_stateful_sandbox"] = bool(self.config.require_stateful_sandbox)
|
||||
args.pop("timeout", None)
|
||||
elif tool_name == "tmux":
|
||||
# Network policy applies to the underlying stateful session.
|
||||
args.pop("allow_network", None)
|
||||
args.pop("require_sandbox", None)
|
||||
args.pop("require_stateful_sandbox", None)
|
||||
args["allow_network"] = bool(self.config.allow_network)
|
||||
args["require_stateful_sandbox"] = bool(self.config.require_stateful_sandbox)
|
||||
|
||||
requests.append((slot, tool_name, args))
|
||||
dispatched.append(it)
|
||||
|
||||
results = None
|
||||
try:
|
||||
if not dispatched:
|
||||
continue
|
||||
results = await self.backend.execute_batch(requests, timeout_s=timeout_s)
|
||||
except Exception as e:
|
||||
for it in items:
|
||||
self.total_requests += 1
|
||||
self.total_errors += 1
|
||||
if not it.future.done():
|
||||
it.future.set_result(
|
||||
ToolResult(
|
||||
success=False,
|
||||
error=f"Batch execution failed: {e}",
|
||||
uniq_id=it.call.uniq_id,
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
for it, res in zip(dispatched, results):
|
||||
self.total_requests += 1
|
||||
if not getattr(res, "success", False):
|
||||
self.total_errors += 1
|
||||
tool_result = res.to_tool_result()
|
||||
tool_result.uniq_id = it.call.uniq_id
|
||||
if not it.future.done():
|
||||
it.future.set_result(tool_result)
|
||||
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
Toolset resolution for Hermes-Agent Atropos integration.
|
||||
|
||||
We primarily reuse Hermes-Agent toolsets (`toolsets.py`), but Atropos training/envs
|
||||
need a few extra sandbox-oriented toolsets that Hermes doesn't expose by default
|
||||
(e.g. filesystem + stateful terminal).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
import toolsets as hermes_toolsets
|
||||
|
||||
|
||||
ATROPOS_TOOLSETS: Dict[str, Dict[str, Any]] = {
|
||||
"filesystem": {
|
||||
"description": "Read/write files in the sandbox workspace.",
|
||||
"tools": ["read_file", "write_file"],
|
||||
"includes": [],
|
||||
},
|
||||
"terminal_stateful": {
|
||||
"description": "Stateful terminal execution (tmux/TUI support) inside the sandbox.",
|
||||
"tools": ["terminal_stateful", "tmux"],
|
||||
"includes": [],
|
||||
},
|
||||
"sandbox": {
|
||||
"description": "Sandbox tools (terminal + filesystem).",
|
||||
"tools": [],
|
||||
"includes": ["terminal", "filesystem"],
|
||||
},
|
||||
"default": {
|
||||
"description": "Default toolset for Atropos AgentEnv tasks.",
|
||||
"tools": [],
|
||||
"includes": ["sandbox"],
|
||||
},
|
||||
"full": {
|
||||
"description": "All Hermes tools plus Atropos sandbox additions.",
|
||||
"tools": [],
|
||||
"includes": ["all", "filesystem", "sandbox", "terminal_stateful"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def validate_toolset(name: str) -> bool:
|
||||
if name in {"all", "*"}:
|
||||
return True
|
||||
return hermes_toolsets.validate_toolset(name) or name in ATROPOS_TOOLSETS
|
||||
|
||||
|
||||
def resolve_toolset(name: str, visited: Optional[Set[str]] = None) -> List[str]:
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
if name in {"all", "*"}:
|
||||
# Union Hermes + Atropos toolsets.
|
||||
all_tools: Set[str] = set()
|
||||
for tname in hermes_toolsets.get_toolset_names():
|
||||
all_tools.update(resolve_toolset(tname, visited=set()))
|
||||
for tname, spec in ATROPOS_TOOLSETS.items():
|
||||
# Avoid recursion: some Atropos toolsets (e.g. "full") include "all".
|
||||
if tname == "full" or "all" in (spec.get("includes") or []):
|
||||
continue
|
||||
all_tools.update(resolve_toolset(tname, visited=set()))
|
||||
return sorted(all_tools)
|
||||
|
||||
if name in ATROPOS_TOOLSETS:
|
||||
if name in visited:
|
||||
return []
|
||||
visited.add(name)
|
||||
spec = ATROPOS_TOOLSETS[name]
|
||||
tools: Set[str] = set(spec.get("tools", []))
|
||||
for inc in spec.get("includes", []):
|
||||
tools.update(resolve_toolset(inc, visited=set(visited)))
|
||||
return sorted(tools)
|
||||
|
||||
# Fall back to Hermes toolsets.
|
||||
# IMPORTANT: do not pre-add `name` to `visited` here; Hermes' resolver uses
|
||||
# `visited` for its own cycle detection and will treat the presence of `name`
|
||||
# as a circular dependency.
|
||||
return sorted(hermes_toolsets.resolve_toolset(name, visited=set(visited)))
|
||||
|
||||
|
||||
def resolve_multiple_toolsets(names: List[str]) -> List[str]:
|
||||
tools: Set[str] = set()
|
||||
for name in names:
|
||||
tools.update(resolve_toolset(name, visited=set()))
|
||||
return sorted(tools)
|
||||
@@ -0,0 +1,415 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Atropos-compatible Hermes agent runner.
|
||||
|
||||
This is a minimal subclass of Hermes-Agent's `AIAgent` that swaps the OpenAI
|
||||
function-calling backend for Atroposlib's `ManagedServer`/`ServerManager` backend
|
||||
and uses Hermes-style XML tool tags:
|
||||
|
||||
- <tool_call>{"name": "...", "arguments": {...}}</tool_call>
|
||||
- <tool_response>{...}</tool_response>
|
||||
|
||||
Tool observations are appended as `role="user"` messages containing one or more
|
||||
`<tool_response>` blocks so they survive common chat templates during tokenization.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import warnings
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
from model_tools import cleanup_vm, handle_function_call
|
||||
from run_agent import AIAgent
|
||||
|
||||
_TOOL_CALL_RE = re.compile(r"<tool_call>\\s*(.*?)\\s*</tool_call>", re.DOTALL)
|
||||
|
||||
|
||||
ATROPOS_TOOL_SYSTEM_PROMPT = """You are a helpful AI assistant with access to tools.
|
||||
|
||||
## Available Tools
|
||||
<tools>
|
||||
{tool_descriptions}
|
||||
</tools>
|
||||
|
||||
## How to Use Tools
|
||||
To call a tool, output:
|
||||
<tool_call>{{"name": "tool_name", "arguments": {{"arg1": "value1"}}}}</tool_call>
|
||||
|
||||
You may include optional reasoning in <think>...</think> before tool calls.
|
||||
|
||||
After each tool call, you will receive tool results as:
|
||||
<tool_response>{{...}}</tool_response>
|
||||
|
||||
Continue until finished, then provide a final response with no <tool_call> blocks.
|
||||
"""
|
||||
|
||||
|
||||
class AtroposAIAgent(AIAgent):
|
||||
"""
|
||||
Hermes `AIAgent` variant that uses Atroposlib ServerManager/ManagedServer.
|
||||
|
||||
Notes:
|
||||
- The default Hermes `AIAgent` remains unchanged; this class is opt-in.
|
||||
- The underlying server must expose `managed_server(tokenizer=...)` OR be a single
|
||||
APIServer-compatible object usable by Atroposlib's `ManagedServer`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
server: Any,
|
||||
tokenizer: Any = None,
|
||||
model: str = "local",
|
||||
max_iterations: int = 10,
|
||||
tool_delay: float = 0.0,
|
||||
enabled_toolsets: Optional[List[str]] = None,
|
||||
disabled_toolsets: Optional[List[str]] = None,
|
||||
save_trajectories: bool = False,
|
||||
verbose_logging: bool = False,
|
||||
quiet_mode: bool = False,
|
||||
ephemeral_system_prompt: Optional[str] = None,
|
||||
log_prefix_chars: int = 100,
|
||||
log_prefix: str = "",
|
||||
session_id: Optional[str] = None,
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
):
|
||||
# Call parent init mainly to reuse tool selection + trajectory saving utilities.
|
||||
super().__init__(
|
||||
base_url="http://unused",
|
||||
api_key="dummy-key",
|
||||
model=model,
|
||||
max_iterations=max_iterations,
|
||||
tool_delay=tool_delay,
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
disabled_toolsets=disabled_toolsets,
|
||||
save_trajectories=save_trajectories,
|
||||
verbose_logging=verbose_logging,
|
||||
quiet_mode=quiet_mode,
|
||||
ephemeral_system_prompt=ephemeral_system_prompt,
|
||||
log_prefix_chars=log_prefix_chars,
|
||||
log_prefix=log_prefix,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
self.server = server
|
||||
self.tokenizer = tokenizer
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
@asynccontextmanager
|
||||
async def _managed(self) -> AsyncGenerator[Any, None]:
|
||||
if hasattr(self.server, "managed_server"):
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r"Using OpenAIServer with managed_server does not allow for state tracking",
|
||||
category=UserWarning,
|
||||
)
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
yield managed
|
||||
return
|
||||
|
||||
# Fall back to directly wrapping a single server object.
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
|
||||
managed = ManagedServer(server=self.server, tokenizer=self.tokenizer)
|
||||
try:
|
||||
yield managed
|
||||
finally:
|
||||
managed.reset()
|
||||
|
||||
def _tool_descriptions_text(self) -> str:
|
||||
if not self.tools:
|
||||
return "(no tools available)"
|
||||
|
||||
parts: List[str] = []
|
||||
for tool in self.tools:
|
||||
fn = (tool or {}).get("function", {})
|
||||
name = fn.get("name", "")
|
||||
desc = (fn.get("description") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
if desc:
|
||||
parts.append(f"- {name}: {desc}")
|
||||
else:
|
||||
parts.append(f"- {name}")
|
||||
return "\n".join(parts) if parts else "(no tools available)"
|
||||
|
||||
def _build_system_prompt(self, system_message: Optional[str]) -> Optional[str]:
|
||||
tool_prompt = ATROPOS_TOOL_SYSTEM_PROMPT.format(
|
||||
tool_descriptions=self._tool_descriptions_text()
|
||||
)
|
||||
|
||||
parts: List[str] = []
|
||||
if system_message:
|
||||
parts.append(system_message)
|
||||
if self.ephemeral_system_prompt:
|
||||
parts.append(self.ephemeral_system_prompt)
|
||||
parts.append(tool_prompt)
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def _parse_tool_calls(self, content: str) -> Tuple[List[Tuple[str, Dict[str, Any]]], List[str]]:
|
||||
"""
|
||||
Returns:
|
||||
(calls, errors)
|
||||
"""
|
||||
calls: List[Tuple[str, Dict[str, Any]]] = []
|
||||
errors: List[str] = []
|
||||
|
||||
for raw in _TOOL_CALL_RE.findall(content or ""):
|
||||
try:
|
||||
payload = json.loads(raw)
|
||||
except json.JSONDecodeError as exc:
|
||||
errors.append(f"Invalid JSON inside <tool_call>: {exc}")
|
||||
continue
|
||||
|
||||
name = payload.get("name")
|
||||
args = payload.get("arguments", {})
|
||||
if not isinstance(name, str) or not name:
|
||||
errors.append("Tool call missing 'name' string")
|
||||
continue
|
||||
if not isinstance(args, dict):
|
||||
errors.append("Tool call 'arguments' must be an object")
|
||||
continue
|
||||
|
||||
calls.append((name, args))
|
||||
|
||||
return calls, errors
|
||||
|
||||
async def run_conversation_async(
|
||||
self,
|
||||
user_message: str,
|
||||
system_message: Optional[str] = None,
|
||||
conversation_history: Optional[List[Dict[str, Any]]] = None,
|
||||
task_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
import uuid
|
||||
|
||||
effective_task_id = task_id or str(uuid.uuid4())
|
||||
|
||||
messages: List[Dict[str, Any]] = conversation_history.copy() if conversation_history else []
|
||||
messages.append({"role": "user", "content": user_message})
|
||||
|
||||
active_system_prompt = self._build_system_prompt(system_message)
|
||||
|
||||
api_call_count = 0
|
||||
final_response: Optional[str] = None
|
||||
managed_state: Optional[Dict[str, Any]] = None
|
||||
completed = False
|
||||
|
||||
try:
|
||||
async with self._managed() as managed:
|
||||
while api_call_count < self.max_iterations:
|
||||
api_call_count += 1
|
||||
|
||||
api_messages = messages.copy()
|
||||
if active_system_prompt:
|
||||
api_messages = [{"role": "system", "content": active_system_prompt}] + api_messages
|
||||
|
||||
chat_kwargs: Dict[str, Any] = {"messages": api_messages, "n": 1}
|
||||
if self.max_tokens is not None:
|
||||
chat_kwargs["max_tokens"] = self.max_tokens
|
||||
if self.temperature is not None:
|
||||
chat_kwargs["temperature"] = self.temperature
|
||||
|
||||
# Prefer OpenAI tool calling when supported by the backend:
|
||||
# - Many providers normalize Hermes-style <tool_call> tags into tool_calls when `tools` is provided.
|
||||
# - ManagedServer (atroposlib) does prompt->completion conversion and does not support `tools`.
|
||||
# Only pass `tools` when we're calling an OpenAI-compatible chat endpoint directly.
|
||||
tool_schemas = self.tools if self.tools else None
|
||||
managed_cls = type(managed).__name__
|
||||
if tool_schemas and managed_cls != "ManagedServer":
|
||||
chat_kwargs["tools"] = tool_schemas
|
||||
|
||||
if os.getenv("HERMES_DEBUG_ATROPOS_REQUEST") == "1":
|
||||
meta = {
|
||||
"managed_type": managed_cls,
|
||||
"model": getattr(getattr(managed, "config", None), "model_name", self.model),
|
||||
"base_url": getattr(getattr(managed, "config", None), "base_url", None),
|
||||
"kwargs": chat_kwargs,
|
||||
}
|
||||
# Avoid dumping megabytes of data accidentally.
|
||||
# (Messages can be large; this is still "full" but bounded.)
|
||||
print("\n=== HERMES_DEBUG_ATROPOS_REQUEST ===", flush=True)
|
||||
print(json.dumps(meta, ensure_ascii=False, indent=2)[:200_000], flush=True)
|
||||
|
||||
response = await managed.chat_completion(**chat_kwargs)
|
||||
|
||||
if os.getenv("HERMES_DEBUG_ATROPOS_RESPONSE") == "1":
|
||||
try:
|
||||
dumped = response.model_dump() # openai pydantic model
|
||||
except Exception:
|
||||
dumped = getattr(response, "__dict__", {"repr": repr(response)})
|
||||
print("\n=== HERMES_DEBUG_ATROPOS_RESPONSE: ChatCompletion (raw) ===", flush=True)
|
||||
print(json.dumps(dumped, ensure_ascii=False, indent=2), flush=True)
|
||||
|
||||
if hasattr(managed, "get_state"):
|
||||
managed_state = managed.get_state()
|
||||
|
||||
msg = response.choices[0].message
|
||||
assistant_content = (msg.content or "")
|
||||
msg_reasoning = getattr(msg, "reasoning", None)
|
||||
|
||||
# Use tool_calls if the backend provides them (preferred).
|
||||
structured_tool_calls = getattr(msg, "tool_calls", None)
|
||||
|
||||
# If the backend emits content="" but includes useful text in reasoning,
|
||||
# use it for parsing *only if needed* (e.g. tool tags).
|
||||
if assistant_content == "" and isinstance(msg_reasoning, str) and msg_reasoning:
|
||||
if os.getenv("HERMES_DEBUG_ATROPOS_RESPONSE") == "1":
|
||||
print("\n=== HERMES_DEBUG_ATROPOS_RESPONSE: message.reasoning present (content empty) ===", flush=True)
|
||||
print(msg_reasoning, flush=True)
|
||||
|
||||
assistant_msg: Dict[str, Any] = {"role": "assistant", "content": assistant_content}
|
||||
if structured_tool_calls:
|
||||
# Preserve tool_calls so the next request is consistent with OpenAI protocol.
|
||||
try:
|
||||
assistant_msg["tool_calls"] = [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": tc.type,
|
||||
"function": {"name": tc.function.name, "arguments": tc.function.arguments},
|
||||
}
|
||||
for tc in structured_tool_calls
|
||||
]
|
||||
except Exception:
|
||||
# Best-effort; keep conversation moving.
|
||||
pass
|
||||
messages.append(assistant_msg)
|
||||
|
||||
# Mode A: OpenAI tool calling (preferred when supported)
|
||||
if structured_tool_calls:
|
||||
for tc in structured_tool_calls:
|
||||
tool_start = time.time()
|
||||
try:
|
||||
tool_args = json.loads(tc.function.arguments or "{}")
|
||||
except Exception:
|
||||
tool_args = {}
|
||||
tool_result = handle_function_call(tc.function.name, tool_args, effective_task_id)
|
||||
tool_duration = time.time() - tool_start
|
||||
|
||||
# Keep the raw tool result as tool content (OpenAI protocol expects role=tool).
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tc.id,
|
||||
"content": tool_result,
|
||||
}
|
||||
)
|
||||
|
||||
if self.tool_delay and self.tool_delay > 0:
|
||||
await asyncio.sleep(self.tool_delay)
|
||||
|
||||
# Continue loop after tool execution.
|
||||
continue
|
||||
|
||||
# Mode B: Hermes XML tool tags in assistant text (fallback).
|
||||
parse_source = assistant_content or (msg_reasoning or "")
|
||||
tool_calls, parse_errors = self._parse_tool_calls(parse_source)
|
||||
|
||||
if parse_errors and not tool_calls:
|
||||
# Ask the model to retry with valid tool JSON.
|
||||
err_text = "; ".join(parse_errors[:3])
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
f"<tool_response>{json.dumps({'error': err_text}, ensure_ascii=False)}</tool_response>\n"
|
||||
"The previous <tool_call> blocks were invalid. Please output valid JSON inside <tool_call>."
|
||||
),
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if not tool_calls:
|
||||
# No tool calls: treat as final answer.
|
||||
final_response = (assistant_content or "").strip()
|
||||
completed = True
|
||||
break
|
||||
|
||||
tool_responses: List[str] = []
|
||||
for tool_name, tool_args in tool_calls:
|
||||
tool_start = time.time()
|
||||
tool_result = handle_function_call(tool_name, tool_args, effective_task_id)
|
||||
tool_duration = time.time() - tool_start
|
||||
|
||||
try:
|
||||
parsed = json.loads(tool_result)
|
||||
payload: Any = parsed
|
||||
except Exception:
|
||||
payload = tool_result
|
||||
|
||||
tool_payload = {
|
||||
"name": tool_name,
|
||||
"duration_s": round(tool_duration, 3),
|
||||
"result": payload,
|
||||
}
|
||||
tool_responses.append(
|
||||
f"<tool_response>{json.dumps(tool_payload, ensure_ascii=False)}</tool_response>"
|
||||
)
|
||||
|
||||
if self.tool_delay and self.tool_delay > 0:
|
||||
await asyncio.sleep(self.tool_delay)
|
||||
|
||||
messages.append({"role": "user", "content": "\n".join(tool_responses)})
|
||||
|
||||
if final_response is None:
|
||||
final_response = "I've reached the maximum number of iterations."
|
||||
|
||||
finally:
|
||||
try:
|
||||
cleanup_vm(effective_task_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Save trajectory using Hermes formatting (optional).
|
||||
self._save_trajectory(messages, user_message, completed=completed)
|
||||
|
||||
return {
|
||||
"final_response": final_response,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": completed,
|
||||
"managed_state": managed_state,
|
||||
"system_prompt": active_system_prompt,
|
||||
"task_id": effective_task_id,
|
||||
}
|
||||
|
||||
def run_conversation(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
Sync wrapper for convenience.
|
||||
|
||||
If called from within a running event loop (e.g. prompt_toolkit), this
|
||||
runs the async conversation in a dedicated thread to avoid nested loops.
|
||||
"""
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(self.run_conversation_async(*args, **kwargs))
|
||||
|
||||
import queue
|
||||
import threading
|
||||
|
||||
out: "queue.Queue[object]" = queue.Queue(maxsize=1)
|
||||
|
||||
def runner() -> None:
|
||||
try:
|
||||
out.put(asyncio.run(self.run_conversation_async(*args, **kwargs)))
|
||||
except BaseException as exc: # noqa: BLE001
|
||||
out.put(exc)
|
||||
|
||||
thread = threading.Thread(target=runner, daemon=True)
|
||||
thread.start()
|
||||
|
||||
result = out.get()
|
||||
if isinstance(result, BaseException):
|
||||
raise result
|
||||
return result # type: ignore[return-value]
|
||||
@@ -28,13 +28,18 @@ os.environ["HERMES_QUIET"] = "1" # Our own modules
|
||||
import yaml
|
||||
|
||||
# prompt_toolkit for fixed input area TUI
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.history import FileHistory
|
||||
from prompt_toolkit.styles import Style as PTStyle
|
||||
from prompt_toolkit.formatted_text import HTML
|
||||
from prompt_toolkit.patch_stdout import patch_stdout
|
||||
from prompt_toolkit.application import Application
|
||||
from prompt_toolkit.application import Application, get_app
|
||||
from prompt_toolkit.buffer import Buffer
|
||||
from prompt_toolkit.layout import Layout, HSplit, Window, FormattedTextControl
|
||||
from prompt_toolkit.layout.processors import BeforeInput
|
||||
from prompt_toolkit.widgets import TextArea
|
||||
from prompt_toolkit.key_binding import KeyBindings
|
||||
import asyncio
|
||||
import threading
|
||||
import queue
|
||||
|
||||
@@ -493,8 +498,6 @@ COMMANDS = {
|
||||
"/clear": "Clear screen and reset conversation (fresh start)",
|
||||
"/history": "Show conversation history",
|
||||
"/reset": "Reset conversation only (keep screen)",
|
||||
"/retry": "Retry the last message (resend to agent)",
|
||||
"/undo": "Remove the last user/assistant exchange",
|
||||
"/save": "Save the current conversation",
|
||||
"/config": "Show current configuration",
|
||||
"/cron": "Manage scheduled tasks (list, add, remove)",
|
||||
@@ -505,11 +508,7 @@ COMMANDS = {
|
||||
|
||||
def save_config_value(key_path: str, value: any) -> bool:
|
||||
"""
|
||||
Save a value to the active config file at the specified key path.
|
||||
|
||||
Respects the same lookup order as load_cli_config():
|
||||
1. ~/.hermes/config.yaml (user config - preferred, used if it exists)
|
||||
2. ./cli-config.yaml (project config - fallback)
|
||||
Save a value to cli-config.yaml at the specified key path.
|
||||
|
||||
Args:
|
||||
key_path: Dot-separated path like "agent.system_prompt"
|
||||
@@ -518,15 +517,9 @@ def save_config_value(key_path: str, value: any) -> bool:
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
# Use the same precedence as load_cli_config: user config first, then project config
|
||||
user_config_path = Path.home() / '.hermes' / 'config.yaml'
|
||||
project_config_path = Path(__file__).parent / 'cli-config.yaml'
|
||||
config_path = user_config_path if user_config_path.exists() else project_config_path
|
||||
config_path = Path(__file__).parent / 'cli-config.yaml'
|
||||
|
||||
try:
|
||||
# Ensure parent directory exists (for ~/.hermes/config.yaml on first use)
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load existing config
|
||||
if config_path.exists():
|
||||
with open(config_path, 'r') as f:
|
||||
@@ -638,8 +631,26 @@ class HermesCLI:
|
||||
short_uuid = uuid.uuid4().hex[:6]
|
||||
self.session_id = f"{timestamp_str}_{short_uuid}"
|
||||
|
||||
# History file for persistent input recall across sessions
|
||||
self._history_file = Path.home() / ".hermes_history"
|
||||
# Setup prompt_toolkit session with history
|
||||
self._setup_prompt_session()
|
||||
|
||||
def _setup_prompt_session(self):
|
||||
"""Setup prompt_toolkit session with history and styling."""
|
||||
history_file = Path.home() / ".hermes_history"
|
||||
|
||||
# Custom style for the prompt
|
||||
self.prompt_style = PTStyle.from_dict({
|
||||
'prompt': '#FFD700 bold',
|
||||
'input': '#FFF8DC',
|
||||
})
|
||||
|
||||
# Create prompt session with file history
|
||||
# Note: multiline disabled - Enter submits, use \ at end of line for continuation
|
||||
self.prompt_session = PromptSession(
|
||||
history=FileHistory(str(history_file)),
|
||||
style=self.prompt_style,
|
||||
enable_history_search=True,
|
||||
)
|
||||
|
||||
def _init_agent(self) -> bool:
|
||||
"""
|
||||
@@ -662,7 +673,6 @@ class HermesCLI:
|
||||
quiet_mode=True, # Suppress verbose output for clean CLI
|
||||
ephemeral_system_prompt=self.system_prompt if self.system_prompt else None,
|
||||
session_id=self.session_id, # Pass CLI's session ID to agent
|
||||
platform="cli", # CLI interface — agent uses terminal-friendly formatting
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -921,67 +931,6 @@ class HermesCLI:
|
||||
except Exception as e:
|
||||
print(f"(x_x) Failed to save: {e}")
|
||||
|
||||
def retry_last(self):
|
||||
"""Retry the last user message by removing the last exchange and re-sending.
|
||||
|
||||
Removes the last assistant response (and any tool-call messages) and
|
||||
the last user message, then re-sends that user message to the agent.
|
||||
Returns the message to re-send, or None if there's nothing to retry.
|
||||
"""
|
||||
if not self.conversation_history:
|
||||
print("(._.) No messages to retry.")
|
||||
return None
|
||||
|
||||
# Walk backwards to find the last user message
|
||||
last_user_idx = None
|
||||
for i in range(len(self.conversation_history) - 1, -1, -1):
|
||||
if self.conversation_history[i].get("role") == "user":
|
||||
last_user_idx = i
|
||||
break
|
||||
|
||||
if last_user_idx is None:
|
||||
print("(._.) No user message found to retry.")
|
||||
return None
|
||||
|
||||
# Extract the message text and remove everything from that point forward
|
||||
last_message = self.conversation_history[last_user_idx].get("content", "")
|
||||
self.conversation_history = self.conversation_history[:last_user_idx]
|
||||
|
||||
print(f"(^_^)b Retrying: \"{last_message[:60]}{'...' if len(last_message) > 60 else ''}\"")
|
||||
return last_message
|
||||
|
||||
def undo_last(self):
|
||||
"""Remove the last user/assistant exchange from conversation history.
|
||||
|
||||
Walks backwards and removes all messages from the last user message
|
||||
onward (including assistant responses, tool calls, etc.).
|
||||
"""
|
||||
if not self.conversation_history:
|
||||
print("(._.) No messages to undo.")
|
||||
return
|
||||
|
||||
# Walk backwards to find the last user message
|
||||
last_user_idx = None
|
||||
for i in range(len(self.conversation_history) - 1, -1, -1):
|
||||
if self.conversation_history[i].get("role") == "user":
|
||||
last_user_idx = i
|
||||
break
|
||||
|
||||
if last_user_idx is None:
|
||||
print("(._.) No user message found to undo.")
|
||||
return
|
||||
|
||||
# Count how many messages we're removing
|
||||
removed_count = len(self.conversation_history) - last_user_idx
|
||||
removed_msg = self.conversation_history[last_user_idx].get("content", "")
|
||||
|
||||
# Truncate history to before the last user message
|
||||
self.conversation_history = self.conversation_history[:last_user_idx]
|
||||
|
||||
print(f"(^_^)b Undid {removed_count} message(s). Removed: \"{removed_msg[:60]}{'...' if len(removed_msg) > 60 else ''}\"")
|
||||
remaining = len(self.conversation_history)
|
||||
print(f" {remaining} message(s) remaining in history.")
|
||||
|
||||
def _handle_prompt_command(self, cmd: str):
|
||||
"""Handle the /prompt command to view or set system prompt."""
|
||||
parts = cmd.split(maxsplit=1)
|
||||
@@ -1319,13 +1268,6 @@ class HermesCLI:
|
||||
elif cmd_lower.startswith("/personality"):
|
||||
# Use original case (handler lowercases the personality name itself)
|
||||
self._handle_personality_command(cmd_original)
|
||||
elif cmd_lower == "/retry":
|
||||
retry_msg = self.retry_last()
|
||||
if retry_msg and hasattr(self, '_pending_input'):
|
||||
# Re-queue the message so process_loop sends it to the agent
|
||||
self._pending_input.put(retry_msg)
|
||||
elif cmd_lower == "/undo":
|
||||
self.undo_last()
|
||||
elif cmd_lower == "/save":
|
||||
self.save_conversation()
|
||||
elif cmd_lower.startswith("/cron"):
|
||||
@@ -1360,9 +1302,8 @@ class HermesCLI:
|
||||
# Add user message to history
|
||||
self.conversation_history.append({"role": "user", "content": message})
|
||||
|
||||
# Visual separator after user input (adapt to terminal width, capped for readability)
|
||||
term_width = min(self.console.width, 120)
|
||||
print("─" * term_width, flush=True)
|
||||
# Visual separator after user input
|
||||
print("─" * 60, flush=True)
|
||||
|
||||
try:
|
||||
# Run the conversation with interrupt monitoring
|
||||
@@ -1420,20 +1361,14 @@ class HermesCLI:
|
||||
|
||||
if response:
|
||||
# Use simple print for compatibility with prompt_toolkit's patch_stdout
|
||||
# Adapt box width to terminal (cap at 120 for readability)
|
||||
box_width = min(self.console.width, 120)
|
||||
inner = box_width - 2 # account for border chars ╭/╰ and ╮/╯
|
||||
label = "⚕ Hermes"
|
||||
padding = inner - len(label) - 1 # -1 for the leading space
|
||||
|
||||
print()
|
||||
print("╭" + "─" * inner + "╮")
|
||||
print("│ " + label + " " * max(padding, 0) + "│")
|
||||
print("╰" + "─" * inner + "╯")
|
||||
print("╭" + "─" * 58 + "╮")
|
||||
print("│ ⚕ Hermes" + " " * 49 + "│")
|
||||
print("╰" + "─" * 58 + "╯")
|
||||
print()
|
||||
print(response)
|
||||
print()
|
||||
print("─" * box_width)
|
||||
print("─" * 60)
|
||||
|
||||
# If we have a pending message from interrupt, re-queue it for process_loop
|
||||
# instead of recursing (avoids unbounded recursion from rapid interrupts)
|
||||
@@ -1447,6 +1382,37 @@ class HermesCLI:
|
||||
print(f"Error: {e}")
|
||||
return None
|
||||
|
||||
def get_input(self) -> Optional[str]:
|
||||
"""
|
||||
Get user input using prompt_toolkit.
|
||||
|
||||
Enter submits. For multiline, end line with \\ to continue.
|
||||
|
||||
Returns:
|
||||
The user's input, or None if EOF/interrupt
|
||||
"""
|
||||
try:
|
||||
# Get first line
|
||||
line = self.prompt_session.prompt(
|
||||
HTML('<prompt>❯ </prompt>'),
|
||||
style=self.prompt_style,
|
||||
)
|
||||
|
||||
# Handle multi-line input (lines ending with \)
|
||||
lines = [line]
|
||||
while line.endswith("\\"):
|
||||
lines[-1] = line[:-1] # Remove trailing backslash
|
||||
line = self.prompt_session.prompt(
|
||||
HTML('<prompt> </prompt>'), # Continuation prompt
|
||||
style=self.prompt_style,
|
||||
)
|
||||
lines.append(line)
|
||||
|
||||
return "\n".join(lines).strip()
|
||||
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
return None
|
||||
|
||||
def run(self):
|
||||
"""Run the interactive CLI loop with persistent input at bottom."""
|
||||
self.show_banner()
|
||||
@@ -1460,6 +1426,9 @@ class HermesCLI:
|
||||
self._should_exit = False
|
||||
self._last_ctrl_c_time = 0 # Track double Ctrl+C for force exit
|
||||
|
||||
# Create a persistent input area using prompt_toolkit Application
|
||||
input_buffer = Buffer()
|
||||
|
||||
# Key bindings for the input area
|
||||
kb = KeyBindings()
|
||||
|
||||
@@ -1517,14 +1486,13 @@ class HermesCLI:
|
||||
self._should_exit = True
|
||||
event.app.exit()
|
||||
|
||||
# Create the input area widget with persistent history across sessions
|
||||
# Create the input area widget
|
||||
input_area = TextArea(
|
||||
height=1,
|
||||
prompt='❯ ',
|
||||
style='class:input-area',
|
||||
multiline=False,
|
||||
wrap_lines=False,
|
||||
history=FileHistory(str(self._history_file)),
|
||||
)
|
||||
|
||||
# Create a status line that shows when agent is working
|
||||
@@ -1577,7 +1545,6 @@ class HermesCLI:
|
||||
|
||||
# Check for commands
|
||||
if user_input.startswith("/"):
|
||||
print(f"\n⚙️ {user_input}")
|
||||
if not self.process_command(user_input):
|
||||
self._should_exit = True
|
||||
# Schedule app exit
|
||||
@@ -1589,9 +1556,6 @@ class HermesCLI:
|
||||
self._agent_running = True
|
||||
app.invalidate() # Refresh status line
|
||||
|
||||
# Echo the user's input so it stays visible in scrollback
|
||||
print(f"\n💬 You: {user_input}")
|
||||
|
||||
try:
|
||||
self.chat(user_input)
|
||||
finally:
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
# Endless Terminals Environment Configuration
|
||||
#
|
||||
# Two modes:
|
||||
# 1. Dataset mode (default): Load pre-generated tasks from HuggingFace
|
||||
# 2. Procedural mode: Generate tasks on-demand via LLM
|
||||
#
|
||||
# Usage:
|
||||
# python -m atropos.envs.endless_terminals_env process \
|
||||
# --config configs/endless_terminals.yaml
|
||||
|
||||
# Environment settings
|
||||
env:
|
||||
# Dataset mode (primary - recommended)
|
||||
use_dataset: true # Load from HuggingFace (fast, no vLLM needed)
|
||||
dataset_name: "obiwan96/endless-terminals-train"
|
||||
dataset_split: "train"
|
||||
dataset_cache_dir: "~/.cache/huggingface/datasets"
|
||||
tasks_base_dir: "" # Set to dir containing task_* folders if not using default paths
|
||||
# Example: "/path/to/endless-terminals-train"
|
||||
|
||||
# Task generation (fallback if use_dataset=false)
|
||||
task_gen_model: "Qwen/Qwen3-32B" # Only needed if use_dataset=false
|
||||
task_gen_temperature: 1.0
|
||||
task_gen_max_tokens: 2048
|
||||
|
||||
# Container settings
|
||||
base_container_image: "ubuntu:22.04"
|
||||
container_timeout_s: 180
|
||||
test_timeout_s: 60
|
||||
|
||||
# Workspace
|
||||
workspace_dir: "/tmp/endless_terminals_workspace"
|
||||
keep_failed_tasks: false # Set true to debug failed tasks
|
||||
|
||||
# Agent config (increased for long traces)
|
||||
agent_max_steps: 32
|
||||
agent_temperature: 0.7
|
||||
agent_max_tokens: null # Let backend decide
|
||||
|
||||
# Tooling: terminal only
|
||||
enabled_toolsets: ["terminal"]
|
||||
disabled_toolsets: []
|
||||
|
||||
# Training settings
|
||||
group_size: 4 # Parallel trajectory collection
|
||||
batch_size: 32
|
||||
total_steps: 1000 # Total training episodes
|
||||
use_wandb: false # Enable for experiment tracking
|
||||
include_messages: true
|
||||
|
||||
# Tool execution backend (nomad or modal)
|
||||
tool_pool_mode: "nomad"
|
||||
|
||||
# Nomad settings (if using nomad)
|
||||
nomad_address: "http://localhost:4646"
|
||||
sandbox_job_id: "atropos-sandbox-endless"
|
||||
sandbox_image: "atropos-sandbox:local"
|
||||
slots_per_container: 10
|
||||
min_containers: 1
|
||||
max_containers: 10
|
||||
privileged: false
|
||||
acquire_timeout_s: 30.0
|
||||
purge_job_on_start: true
|
||||
purge_job_on_shutdown: true
|
||||
|
||||
# Modal settings (if using modal instead)
|
||||
# modal_app_name: "atropos-endless"
|
||||
# modal_image: "python:3.11"
|
||||
# modal_slots_per_sandbox: 10
|
||||
# modal_min_sandboxes: 1
|
||||
# modal_max_sandboxes: 5
|
||||
|
||||
# Server config
|
||||
server_base_url: "http://127.0.0.1:8080"
|
||||
server_model: "hermes-4-36b"
|
||||
tokenizer_name: "NousResearch/Hermes-4.3-36B"
|
||||
|
||||
# Server configs are auto-generated from env vars and env.server_* settings
|
||||
# Override via environment variables:
|
||||
# ATROPOS_SERVER_BASE_URL
|
||||
# ATROPOS_SERVER_MODEL
|
||||
# ATROPOS_SERVER_API_KEY
|
||||
# ATROPOS_TOKENIZER_NAME
|
||||
@@ -0,0 +1,224 @@
|
||||
# Modal Backend
|
||||
|
||||
Hermes Agent uses [Modal](https://modal.com) for scalable, isolated cloud execution environments. There are two Modal integrations:
|
||||
|
||||
1. **Terminal Tool** (`tools/terminal_tool.py`) - For CLI/agent command execution
|
||||
2. **Atropos Backend** (`atropos/backends/modal_backend.py`) - For batch RL training workloads
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
## Terminal Tool (CLI/Agent)
|
||||
|
||||
The terminal tool provides a simple interface for executing commands in Modal sandboxes.
|
||||
|
||||
### Configuration
|
||||
|
||||
Set environment variables:
|
||||
|
||||
```bash
|
||||
export TERMINAL_ENV=modal
|
||||
export TERMINAL_MODAL_IMAGE=python:3.11
|
||||
export TERMINAL_MODAL_APP_NAME=hermes-sandbox
|
||||
```
|
||||
|
||||
Or use a YAML config file (`modal_profiles.yaml`):
|
||||
|
||||
```yaml
|
||||
profiles:
|
||||
default:
|
||||
image: python:3.11
|
||||
cpu: 1.0
|
||||
memory: 2048
|
||||
min_pool: 1
|
||||
max_pool: 5
|
||||
idle_timeout: 120
|
||||
|
||||
gpu:
|
||||
image: pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
|
||||
gpu: T4
|
||||
memory: 16384
|
||||
min_pool: 0
|
||||
max_pool: 2
|
||||
```
|
||||
|
||||
### Features
|
||||
|
||||
| Feature | Description |
|
||||
|---------|-------------|
|
||||
| **Sandbox Pool** | Pre-warmed sandboxes for low latency |
|
||||
| **Auto-scaling** | Grows/shrinks pool based on demand |
|
||||
| **Idle Timeout** | Sandboxes auto-terminate when unused |
|
||||
| **Profile Selection** | Different configs for different workloads |
|
||||
| **Credential Injection** | `modal.Secret` integration |
|
||||
|
||||
### Usage
|
||||
|
||||
```python
|
||||
from tools.terminal_tool import terminal_tool
|
||||
|
||||
# Simple command
|
||||
output = terminal_tool("echo hello", task_id="my-task")
|
||||
|
||||
# With profile selection
|
||||
output = terminal_tool("python train.py", task_id="training", profile="gpu")
|
||||
|
||||
# Cleanup when done
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
cleanup_vm("my-task")
|
||||
```
|
||||
|
||||
### Architecture
|
||||
|
||||
```
|
||||
_ModalPoolManager (singleton)
|
||||
├── "default" pool → [sandbox-0, sandbox-1, ...]
|
||||
└── "gpu" pool → [sandbox-0, ...]
|
||||
|
||||
Each pool:
|
||||
- Maintains min_pool warm sandboxes
|
||||
- Scales up to max_pool on demand
|
||||
- Background thread scales down idle sandboxes
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Atropos Backend (RL Training)
|
||||
|
||||
The Atropos backend is designed for high-throughput batch execution during reinforcement learning training.
|
||||
|
||||
### Key Concept: Slot-based Multiplexing
|
||||
|
||||
Instead of one sandbox per trajectory, multiple trajectories share sandboxes via **slots**:
|
||||
|
||||
```
|
||||
Sandbox (1 container)
|
||||
├── Slot 0 → Trajectory A (workspace: /data/slot_0)
|
||||
├── Slot 1 → Trajectory B (workspace: /data/slot_1)
|
||||
└── Slot 2 → Trajectory C (workspace: /data/slot_2)
|
||||
```
|
||||
|
||||
**Benefits**:
|
||||
- Fewer containers = lower cost
|
||||
- Shared warm-up time
|
||||
- Better GPU utilization
|
||||
|
||||
### Configuration
|
||||
|
||||
```python
|
||||
from atropos.backends.modal_backend import ModalSandboxConfig, ModalToolBackend
|
||||
|
||||
config = ModalSandboxConfig(
|
||||
name="default",
|
||||
image="python:3.11",
|
||||
cpu=1.0,
|
||||
memory=2048,
|
||||
slots_per_sandbox=10, # 10 trajectories per container
|
||||
min_sandboxes=1,
|
||||
max_sandboxes=5,
|
||||
)
|
||||
|
||||
backend = ModalToolBackend(config.with_app_name("my-training"))
|
||||
```
|
||||
|
||||
### Multi-Profile Support
|
||||
|
||||
Different trajectory types can request different resources:
|
||||
|
||||
```python
|
||||
backend = ModalToolBackend.with_profiles(
|
||||
app_name="rl-training",
|
||||
profiles={
|
||||
"default": ModalSandboxConfig(
|
||||
name="default",
|
||||
cpu=1.0,
|
||||
memory=2048,
|
||||
),
|
||||
"pytorch-gpu": ModalSandboxConfig(
|
||||
name="pytorch-gpu",
|
||||
image="pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime",
|
||||
gpu="T4",
|
||||
memory=16384,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# CPU task
|
||||
slot1 = await backend.acquire("traj-1", profile="default")
|
||||
|
||||
# GPU task
|
||||
slot2 = await backend.acquire("traj-2", profile="pytorch-gpu")
|
||||
```
|
||||
|
||||
### Batched Execution
|
||||
|
||||
The key optimization - execute many commands in parallel:
|
||||
|
||||
```python
|
||||
# Acquire slots for multiple trajectories
|
||||
slots = [await backend.acquire(f"traj-{i}") for i in range(50)]
|
||||
|
||||
# Execute batch across all slots in parallel
|
||||
results = await backend.execute_batch([
|
||||
(slot, "bash", {"command": "python step.py"})
|
||||
for slot in slots
|
||||
])
|
||||
|
||||
# Release slots
|
||||
for slot in slots:
|
||||
await backend.release(slot)
|
||||
```
|
||||
|
||||
### Architecture
|
||||
|
||||
```
|
||||
ModalToolBackend
|
||||
└── _ModalMultiProfileManager
|
||||
├── "default" → _ModalSandboxPool
|
||||
│ ├── Sandbox 0 (slots 0-9)
|
||||
│ └── Sandbox 1 (slots 0-9)
|
||||
│
|
||||
└── "pytorch-gpu" → _ModalSandboxPool
|
||||
└── Sandbox 0 (slots 0-9)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Credentials
|
||||
|
||||
Inject secrets securely using Modal's secret management:
|
||||
|
||||
```bash
|
||||
# Create secret in Modal dashboard or CLI
|
||||
modal secret create my-api-key API_KEY=sk-xxx
|
||||
```
|
||||
|
||||
```python
|
||||
# Reference in config
|
||||
config = ModalSandboxConfig(
|
||||
secrets=["my-api-key"], # Modal secret names
|
||||
env_vars={"DEBUG": "1"}, # Additional env vars
|
||||
)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Modal package not installed"
|
||||
```bash
|
||||
pip install modal
|
||||
modal token new # Authenticate
|
||||
```
|
||||
|
||||
### "Sandbox creation failed"
|
||||
- Check Modal dashboard for quota limits
|
||||
- Verify image exists and is accessible
|
||||
- Check secret names are correct
|
||||
|
||||
### Shutdown errors
|
||||
These are harmless warnings during Python interpreter shutdown:
|
||||
```
|
||||
[Modal] Error terminating ...: cannot schedule new futures after interpreter shutdown
|
||||
```
|
||||
|
||||
The sandboxes will auto-terminate via Modal's idle_timeout anyway.
|
||||
@@ -1,330 +0,0 @@
|
||||
# Hermes-Agent Atropos Environments
|
||||
|
||||
This directory contains the integration layer between **hermes-agent's** tool-calling capabilities and the **Atropos** RL training framework. It provides everything needed to run agentic LLMs through multi-turn tool-calling loops, score their output with arbitrary reward functions, and feed results into Atropos for training or evaluation.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
Atropos Framework
|
||||
┌───────────────────────┐
|
||||
│ BaseEnv │ (atroposlib)
|
||||
│ - Server management │
|
||||
│ - Worker scheduling │
|
||||
│ - Wandb logging │
|
||||
│ - CLI (serve/process/ │
|
||||
│ evaluate) │
|
||||
└───────────┬───────────┘
|
||||
│ inherits
|
||||
┌───────────┴───────────┐
|
||||
│ HermesAgentBaseEnv │ hermes_base_env.py
|
||||
│ - Terminal backend │
|
||||
│ - Tool resolution │
|
||||
│ - Agent loop │
|
||||
│ - ToolContext │
|
||||
│ - Async patches │
|
||||
└───────────┬───────────┘
|
||||
│ inherits
|
||||
┌─────────────────┼─────────────────┐
|
||||
│ │ │
|
||||
TerminalTestEnv HermesSweEnv TerminalBench2EvalEnv
|
||||
(stack testing) (SWE training) (TB2 benchmark eval)
|
||||
```
|
||||
|
||||
### Inheritance Chain
|
||||
|
||||
**BaseEnv** (from `atroposlib`) is the Atropos base class. It provides:
|
||||
- Server management (OpenAI-compatible API servers, VLLM, SGLang)
|
||||
- Worker scheduling for parallel rollouts
|
||||
- Wandb integration for metrics and rollout logging
|
||||
- CLI interface with three subcommands: `serve`, `process`, `evaluate`
|
||||
- `evaluate_log()` for saving eval results to JSON + samples.jsonl
|
||||
|
||||
**HermesAgentBaseEnv** (`hermes_base_env.py`) extends BaseEnv with hermes-agent specifics:
|
||||
- Sets `os.environ["TERMINAL_ENV"]` to configure the terminal backend (local, docker, modal, ssh, singularity)
|
||||
- Resolves hermes-agent toolsets via `_resolve_tools_for_group()` (calls `get_tool_definitions()` from `model_tools.py`)
|
||||
- Implements `collect_trajectory()` which runs the full agent loop and computes rewards
|
||||
- Supports two-phase operation (Phase 1: OpenAI server, Phase 2: VLLM ManagedServer)
|
||||
- Applies monkey patches for async-safe tool operation at import time
|
||||
|
||||
Concrete environments inherit from `HermesAgentBaseEnv` and implement:
|
||||
- `setup()` -- Load dataset, initialize state
|
||||
- `get_next_item()` -- Return the next item for rollout
|
||||
- `format_prompt()` -- Convert a dataset item into the user message
|
||||
- `compute_reward()` -- Score the rollout using ToolContext
|
||||
- `evaluate()` -- Periodic evaluation logic
|
||||
|
||||
## Core Components
|
||||
|
||||
### Agent Loop (`agent_loop.py`)
|
||||
|
||||
`HermesAgentLoop` is the reusable multi-turn agent engine. It runs the same pattern as hermes-agent's `run_agent.py`:
|
||||
|
||||
1. Send messages + tools to the API via `server.chat_completion()`
|
||||
2. If the response contains `tool_calls`, execute each one via `handle_function_call()` from `model_tools.py`
|
||||
3. Append tool results to the conversation and go back to step 1
|
||||
4. If the response has no tool_calls, the agent is done
|
||||
|
||||
Tool calls are executed in a thread pool (`run_in_executor`) so backends that use `asyncio.run()` internally (Modal, Docker) don't deadlock inside Atropos's event loop.
|
||||
|
||||
Returns an `AgentResult` containing the full conversation history, turn count, reasoning content per turn, tool errors, and optional ManagedServer state (for Phase 2).
|
||||
|
||||
### Tool Context (`tool_context.py`)
|
||||
|
||||
`ToolContext` is a per-rollout handle that gives reward/verification functions direct access to **all** hermes-agent tools, scoped to the rollout's `task_id`. The same `task_id` means the terminal/browser session is the SAME one the model used during its rollout -- all state (files, processes, browser tabs) is preserved.
|
||||
|
||||
```python
|
||||
async def compute_reward(self, item, result, ctx: ToolContext):
|
||||
# Run tests in the model's terminal sandbox
|
||||
test = ctx.terminal("pytest -v")
|
||||
if test["exit_code"] == 0:
|
||||
return 1.0
|
||||
|
||||
# Check if a file was created
|
||||
content = ctx.read_file("/workspace/solution.py")
|
||||
if content.get("content"):
|
||||
return 0.5
|
||||
|
||||
# Download files locally for verification (binary-safe)
|
||||
ctx.download_file("/remote/output.bin", "/local/output.bin")
|
||||
|
||||
return 0.0
|
||||
```
|
||||
|
||||
Available methods:
|
||||
- **Terminal**: `terminal(command, timeout)` -- run shell commands
|
||||
- **Files**: `read_file(path)`, `write_file(path, content)`, `search(query, path)`
|
||||
- **Transfers**: `upload_file()`, `upload_dir()`, `download_file()`, `download_dir()` -- binary-safe file transfers between host and sandbox
|
||||
- **Web**: `web_search(query)`, `web_extract(urls)`
|
||||
- **Browser**: `browser_navigate(url)`, `browser_snapshot()`
|
||||
- **Generic**: `call_tool(name, args)` -- call any hermes-agent tool by name
|
||||
- **Cleanup**: `cleanup()` -- release all resources (called automatically after `compute_reward`)
|
||||
|
||||
### Patches (`patches.py`)
|
||||
|
||||
**Problem**: Some hermes-agent tools use `asyncio.run()` internally (e.g., mini-swe-agent's Modal backend via SWE-ReX). This crashes when called from inside Atropos's event loop because `asyncio.run()` cannot be nested.
|
||||
|
||||
**Solution**: `patches.py` monkey-patches `SwerexModalEnvironment` to use a dedicated background thread (`_AsyncWorker`) with its own event loop. The calling code sees the same sync interface, but internally the async work happens on a separate thread that doesn't conflict with Atropos's loop.
|
||||
|
||||
What gets patched:
|
||||
- `SwerexModalEnvironment.__init__` -- creates Modal deployment on a background thread
|
||||
- `SwerexModalEnvironment.execute` -- runs commands on the same background thread
|
||||
- `SwerexModalEnvironment.stop` -- stops deployment on the background thread
|
||||
|
||||
The patches are:
|
||||
- **Idempotent** -- calling `apply_patches()` multiple times is safe
|
||||
- **Transparent** -- same interface and behavior, only the internal async execution changes
|
||||
- **Universal** -- works identically in normal CLI use (no running event loop)
|
||||
|
||||
Applied automatically at import time by `hermes_base_env.py`.
|
||||
|
||||
### Tool Call Parsers (`tool_call_parsers/`)
|
||||
|
||||
Client-side parsers that extract structured `tool_calls` from raw model output text. Used in **Phase 2** (VLLM server type) where ManagedServer's `/generate` endpoint returns raw text without tool call parsing.
|
||||
|
||||
Each parser is a standalone reimplementation of the corresponding VLLM parser's `extract_tool_calls()` logic. No VLLM dependency -- only standard library (`re`, `json`, `uuid`) and `openai` types.
|
||||
|
||||
Available parsers:
|
||||
- `hermes` -- Hermes/ChatML `<tool_call>` XML format
|
||||
- `mistral` -- Mistral `[TOOL_CALLS]` format
|
||||
- `llama3_json` -- Llama 3 JSON tool calling
|
||||
- `qwen` -- Qwen tool calling format
|
||||
- `qwen3_coder` -- Qwen3 Coder format
|
||||
- `deepseek_v3` -- DeepSeek V3 format
|
||||
- `deepseek_v3_1` -- DeepSeek V3.1 format
|
||||
- `kimi_k2` -- Kimi K2 format
|
||||
- `longcat` -- Longcat format
|
||||
- `glm45` / `glm47` -- GLM model formats
|
||||
|
||||
Usage:
|
||||
```python
|
||||
from environments.tool_call_parsers import get_parser
|
||||
|
||||
parser = get_parser("hermes")
|
||||
content, tool_calls = parser.parse(raw_model_output)
|
||||
```
|
||||
|
||||
In Phase 1 (OpenAI server type), these parsers are not needed -- the server handles tool call parsing natively.
|
||||
|
||||
## Two-Phase Operation
|
||||
|
||||
### Phase 1: OpenAI Server (Evaluation / SFT Data Generation)
|
||||
|
||||
Uses `server.chat_completion()` with `tools=` parameter. The server (VLLM, SGLang, OpenRouter, OpenAI) handles tool call parsing natively. Returns `ChatCompletion` objects with structured `tool_calls`.
|
||||
|
||||
- Good for: evaluation, SFT data generation, testing
|
||||
- Run with: `serve` (with `run-api`), `process`, or `evaluate` subcommands
|
||||
- Placeholder tokens are created for the Atropos pipeline
|
||||
|
||||
### Phase 2: VLLM ManagedServer (Full RL Training)
|
||||
|
||||
Uses ManagedServer for exact token IDs + logprobs via `/generate`. Client-side tool call parser (from `tool_call_parsers/`) reconstructs structured `tool_calls` from raw output.
|
||||
|
||||
- Good for: full RL training with GRPO/PPO
|
||||
- Run with: `serve` subcommand
|
||||
- Real tokens, masks, and logprobs flow through the pipeline
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
environments/
|
||||
├── README.md # This file
|
||||
├── __init__.py # Package exports
|
||||
├── hermes_base_env.py # Abstract base (HermesAgentBaseEnv)
|
||||
├── agent_loop.py # Multi-turn agent engine (HermesAgentLoop)
|
||||
├── tool_context.py # Per-rollout tool access for reward functions
|
||||
├── patches.py # Async-safety patches for Modal backend
|
||||
│
|
||||
├── tool_call_parsers/ # Phase 2 client-side parsers
|
||||
│ ├── __init__.py # Registry + base class
|
||||
│ ├── hermes_parser.py
|
||||
│ ├── mistral_parser.py
|
||||
│ ├── llama_parser.py
|
||||
│ ├── qwen_parser.py
|
||||
│ ├── qwen3_coder_parser.py
|
||||
│ ├── deepseek_v3_parser.py
|
||||
│ ├── deepseek_v3_1_parser.py
|
||||
│ ├── kimi_k2_parser.py
|
||||
│ ├── longcat_parser.py
|
||||
│ ├── glm45_parser.py
|
||||
│ └── glm47_parser.py
|
||||
│
|
||||
├── terminal_test_env/ # Stack validation environment
|
||||
│ └── terminal_test_env.py
|
||||
│
|
||||
├── hermes_swe_env/ # SWE-bench style training environment
|
||||
│ └── hermes_swe_env.py
|
||||
│
|
||||
└── benchmarks/ # Evaluation benchmarks
|
||||
└── terminalbench_2/
|
||||
└── terminalbench2_env.py
|
||||
```
|
||||
|
||||
## Concrete Environments
|
||||
|
||||
### TerminalTestEnv (`terminal_test_env/`)
|
||||
|
||||
A self-contained environment with inline tasks (no external dataset needed) for validating the full stack end-to-end. Each task asks the model to create a file at a known path, and the verifier checks the content matches.
|
||||
|
||||
```bash
|
||||
# Serve mode (needs run-api)
|
||||
run-api
|
||||
python environments/terminal_test_env/terminal_test_env.py serve
|
||||
|
||||
# Process mode (no run-api, saves to JSONL)
|
||||
python environments/terminal_test_env/terminal_test_env.py process \
|
||||
--env.data_path_to_save_groups terminal_test_output.jsonl
|
||||
```
|
||||
|
||||
### HermesSweEnv (`hermes_swe_env/`)
|
||||
|
||||
SWE-bench style training environment. The model gets a coding task, uses terminal + file + web tools to solve it, and the reward function runs tests in the same Modal sandbox.
|
||||
|
||||
```bash
|
||||
python environments/hermes_swe_env/hermes_swe_env.py serve \
|
||||
--openai.model_name YourModel \
|
||||
--env.dataset_name bigcode/humanevalpack \
|
||||
--env.terminal_backend modal
|
||||
```
|
||||
|
||||
### TerminalBench2EvalEnv (`benchmarks/terminalbench_2/`)
|
||||
|
||||
**Eval-only** environment for the Terminal-Bench 2.0 benchmark (89 tasks). Each task gets a pre-built Docker Hub image, a natural language instruction, and a test suite. The agent uses terminal + file tools to solve the task, then the test suite verifies correctness.
|
||||
|
||||
Follows the standard Atropos eval pattern (like GPQA, MMLU, etc.):
|
||||
- Run via `evaluate` subcommand (no `run-api` needed)
|
||||
- `setup()` loads the dataset, `evaluate()` runs all tasks
|
||||
- `rollout_and_score_eval()` handles per-task agent loop + test verification
|
||||
- Downloads verifier output locally for reliable reward checking (Harbor pattern)
|
||||
|
||||
```bash
|
||||
# Run full benchmark
|
||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
--openai.model_name anthropic/claude-opus-4.6
|
||||
|
||||
# Run subset of tasks
|
||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
--openai.model_name anthropic/claude-opus-4.6 \
|
||||
--env.task_filter fix-git,git-multibranch
|
||||
|
||||
# Skip specific tasks
|
||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
--openai.model_name anthropic/claude-opus-4.6 \
|
||||
--env.skip_tasks heavy-task,slow-task
|
||||
```
|
||||
|
||||
## Creating a New Environment
|
||||
|
||||
### Training Environment
|
||||
|
||||
1. Create a new directory under `environments/`
|
||||
2. Create your env file inheriting from `HermesAgentBaseEnv`
|
||||
3. Implement the four abstract methods + `evaluate()`
|
||||
|
||||
```python
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
|
||||
class MyEnvConfig(HermesAgentEnvConfig):
|
||||
pass # Add custom fields as needed
|
||||
|
||||
class MyEnv(HermesAgentBaseEnv):
|
||||
name = "my-env"
|
||||
env_config_cls = MyEnvConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls):
|
||||
env_config = MyEnvConfig(
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
terminal_backend="modal",
|
||||
# ... other config
|
||||
)
|
||||
server_configs = [APIServerConfig(...)]
|
||||
return env_config, server_configs
|
||||
|
||||
async def setup(self):
|
||||
self.dataset = load_dataset(...)
|
||||
self.iter = 0
|
||||
|
||||
async def get_next_item(self):
|
||||
item = self.dataset[self.iter % len(self.dataset)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def format_prompt(self, item):
|
||||
return item["instruction"]
|
||||
|
||||
async def compute_reward(self, item, result, ctx):
|
||||
# ctx gives you full tool access to the rollout's sandbox
|
||||
test = ctx.terminal("pytest -v")
|
||||
return 1.0 if test["exit_code"] == 0 else 0.0
|
||||
|
||||
async def evaluate(self, *args, **kwargs):
|
||||
# Periodic evaluation logic
|
||||
...
|
||||
|
||||
if __name__ == "__main__":
|
||||
MyEnv.cli()
|
||||
```
|
||||
|
||||
### Eval-Only Environment (Benchmark)
|
||||
|
||||
For eval benchmarks, follow the pattern in `terminalbench2_env.py`:
|
||||
1. Create under `environments/benchmarks/your-benchmark/`
|
||||
2. Inherit from `HermesAgentBaseEnv`
|
||||
3. Set eval-only config: `eval_handling=STOP_TRAIN`, `steps_per_eval=1`, `total_steps=1`
|
||||
4. Stub the training methods (`collect_trajectories`, `score`)
|
||||
5. Implement `rollout_and_score_eval()` and `evaluate()`
|
||||
6. Run with `evaluate` subcommand
|
||||
|
||||
## Key Config Fields
|
||||
|
||||
| Field | Description | Default |
|
||||
|-------|-------------|---------|
|
||||
| `enabled_toolsets` | Which hermes toolsets to enable | `None` (all) |
|
||||
| `disabled_toolsets` | Toolsets to disable | `None` |
|
||||
| `distribution` | Probabilistic toolset distribution name | `None` |
|
||||
| `max_agent_turns` | Max LLM calls per rollout | `30` |
|
||||
| `agent_temperature` | Sampling temperature | `1.0` |
|
||||
| `terminal_backend` | `local`, `docker`, `modal`, `ssh`, `singularity` | `local` |
|
||||
| `system_prompt` | System message for the agent | `None` |
|
||||
| `tool_call_parser` | Parser name for Phase 2 | `hermes` |
|
||||
| `eval_handling` | `STOP_TRAIN`, `LIMIT_TRAIN`, `NONE` | `STOP_TRAIN` |
|
||||
@@ -4,18 +4,15 @@ Hermes-Agent Atropos Environments
|
||||
Provides a layered integration between hermes-agent's tool-calling capabilities
|
||||
and the Atropos RL training framework.
|
||||
|
||||
Core layers:
|
||||
Layers:
|
||||
- agent_loop: Reusable multi-turn agent loop with standard OpenAI-spec tool calling
|
||||
- tool_context: Per-rollout tool access handle for reward/verification functions
|
||||
- hermes_base_env: Abstract base environment (BaseEnv subclass) for Atropos
|
||||
- tool_call_parsers: Client-side tool call parser registry for Phase 2 (VLLM /generate)
|
||||
|
||||
Concrete environments:
|
||||
- terminal_test_env/: Simple file-creation tasks for testing the stack
|
||||
- hermes_swe_env/: SWE-bench style tasks with Modal sandboxes
|
||||
|
||||
Benchmarks (eval-only):
|
||||
- benchmarks/terminalbench_2/: Terminal-Bench 2.0 evaluation
|
||||
- terminal_test_env: Simple file-creation tasks for testing the stack
|
||||
- hermes_swe_env: SWE-bench style tasks with Modal sandboxes
|
||||
"""
|
||||
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
|
||||
+47
-263
@@ -15,7 +15,6 @@ import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
@@ -25,22 +24,7 @@ from model_tools import handle_function_call
|
||||
# Thread pool for running sync tool calls that internally use asyncio.run()
|
||||
# (e.g., mini-swe-agent's modal/docker backends). Running them in a separate
|
||||
# thread gives them a clean event loop so they don't deadlock inside Atropos's loop.
|
||||
# Size must be large enough for concurrent eval tasks (e.g., 89 TB2 tasks all
|
||||
# making tool calls). Too small = thread pool starvation, tasks queue for minutes.
|
||||
# Resized at runtime by HermesAgentBaseEnv.__init__ via resize_tool_pool().
|
||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=128)
|
||||
|
||||
|
||||
def resize_tool_pool(max_workers: int):
|
||||
"""
|
||||
Replace the global tool executor with a new one of the given size.
|
||||
|
||||
Called by HermesAgentBaseEnv.__init__ based on config.tool_pool_size.
|
||||
Safe to call before any tasks are submitted.
|
||||
"""
|
||||
global _tool_executor
|
||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
|
||||
logger.info("Tool thread pool resized to %d workers", max_workers)
|
||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=8)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -73,12 +57,6 @@ class AgentResult:
|
||||
# Tool errors encountered during the loop
|
||||
tool_errors: List[ToolError] = field(default_factory=list)
|
||||
|
||||
# Tool-call metrics (debugging / optional reward shaping)
|
||||
tool_calls_attempted: int = 0
|
||||
tool_calls_schema_valid: int = 0
|
||||
tool_calls_executed_ok: int = 0
|
||||
tool_calls_exec_error: int = 0
|
||||
|
||||
|
||||
def _extract_reasoning_from_message(message) -> Optional[str]:
|
||||
"""
|
||||
@@ -141,9 +119,6 @@ class HermesAgentLoop:
|
||||
task_id: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
max_tokens: Optional[int] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
tool_handler=None,
|
||||
max_context_tokens: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the agent loop.
|
||||
@@ -157,16 +132,6 @@ class HermesAgentLoop:
|
||||
task_id: Unique ID for terminal/browser session isolation
|
||||
temperature: Sampling temperature for generation
|
||||
max_tokens: Max tokens per generation (None for server default)
|
||||
extra_body: Extra parameters passed to the OpenAI client's create() call.
|
||||
Used for OpenRouter provider preferences, transforms, etc.
|
||||
e.g. {"provider": {"ignore": ["DeepInfra"]}}
|
||||
tool_handler: Optional async callable(tool_name, args, task_id) -> str.
|
||||
When provided, used INSTEAD of handle_function_call() for
|
||||
tool dispatch. This allows sandbox backends (Modal, Nomad)
|
||||
to route tool calls through their slot-based execution.
|
||||
max_context_tokens: Maximum prompt tokens before truncation.
|
||||
If None, no truncation is applied.
|
||||
Recommended: set to max_model_len - max_tokens - 512 (safety margin).
|
||||
"""
|
||||
self.server = server
|
||||
self.tool_schemas = tool_schemas
|
||||
@@ -175,124 +140,6 @@ class HermesAgentLoop:
|
||||
self.task_id = task_id or str(uuid.uuid4())
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.extra_body = extra_body
|
||||
self.tool_handler = tool_handler
|
||||
self.max_context_tokens = max_context_tokens
|
||||
|
||||
def _truncate_context(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Truncate conversation history to fit within max_context_tokens.
|
||||
|
||||
Strategy:
|
||||
- Keep system message (index 0) and initial user message (index 1) always
|
||||
- Keep last 6 messages (recent context) always
|
||||
- For everything in between, progressively truncate tool result content
|
||||
- If still too long, drop oldest middle messages entirely
|
||||
|
||||
Uses rough char/4 token estimate (fast, no tokenizer needed).
|
||||
|
||||
NOTE: This function mutates the provided list (it may pop/replace entries).
|
||||
Call it on a copy when you want to preserve the full trajectory.
|
||||
"""
|
||||
if self.max_context_tokens is None:
|
||||
return messages
|
||||
|
||||
def estimate_tokens(msgs):
|
||||
total = 0
|
||||
for m in msgs:
|
||||
content = m.get("content", "") or ""
|
||||
total += len(content) // 4 + 10 # ~4 chars per token + overhead
|
||||
if "tool_calls" in m:
|
||||
total += 50 * len(m["tool_calls"]) # tool call overhead
|
||||
return total
|
||||
|
||||
if estimate_tokens(messages) <= self.max_context_tokens:
|
||||
return messages
|
||||
|
||||
protect_head = 2
|
||||
protect_tail = max(0, min(6, len(messages) - protect_head))
|
||||
middle_start = protect_head
|
||||
middle_end = len(messages) - protect_tail
|
||||
|
||||
# Phase 1: truncate tool outputs in the middle
|
||||
if middle_start < middle_end:
|
||||
for i in range(middle_start, middle_end):
|
||||
if messages[i].get("role") == "tool":
|
||||
content = messages[i].get("content", "") or ""
|
||||
if len(content) > 200:
|
||||
messages[i] = dict(messages[i])
|
||||
messages[i]["content"] = content[:100] + "\n...[truncated]...\n" + content[-50:]
|
||||
|
||||
if estimate_tokens(messages) <= self.max_context_tokens:
|
||||
return messages
|
||||
|
||||
# Phase 2: drop oldest middle messages (try to keep assistant+tool pairs)
|
||||
while middle_start < middle_end and estimate_tokens(messages) > self.max_context_tokens:
|
||||
msg = messages[middle_start]
|
||||
messages.pop(middle_start)
|
||||
middle_end -= 1
|
||||
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
tool_ids = {
|
||||
tc.get("id") or tc.get("tool_call_id", "")
|
||||
for tc in msg.get("tool_calls", [])
|
||||
if isinstance(tc, dict)
|
||||
}
|
||||
i = middle_start
|
||||
while i < middle_end:
|
||||
if messages[i].get("role") == "tool" and messages[i].get("tool_call_id", "") in tool_ids:
|
||||
messages.pop(i)
|
||||
middle_end -= 1
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return messages
|
||||
|
||||
def _normalize_tool_args(self, tool_name: str, tool_args_raw: str) -> (Dict[str, Any], bool):
|
||||
"""Normalize tool arguments into a dict.
|
||||
|
||||
Returns: (args_dict, schema_valid)
|
||||
|
||||
schema_valid is True only when arguments decode directly into a dict
|
||||
(no double-decoding and no coercion/wrapping required).
|
||||
|
||||
Goal: keep environments robust (never crash on args format drift) while
|
||||
still allowing reward functions to penalize malformed formats if desired.
|
||||
"""
|
||||
try:
|
||||
decoded = json.loads(tool_args_raw)
|
||||
except json.JSONDecodeError:
|
||||
# Not JSON at all — treat as a plain string
|
||||
if tool_name == "terminal":
|
||||
return {"command": tool_args_raw}, False
|
||||
return {"input": tool_args_raw}, False
|
||||
|
||||
if isinstance(decoded, dict):
|
||||
if tool_name == "terminal":
|
||||
cmd = decoded.get("command")
|
||||
if isinstance(cmd, str) and cmd.strip():
|
||||
return decoded, True
|
||||
if isinstance(decoded.get("input"), str):
|
||||
return {"command": decoded.get("input")}, False
|
||||
return decoded, False
|
||||
return decoded, True
|
||||
|
||||
if isinstance(decoded, str):
|
||||
s = decoded.strip()
|
||||
if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
|
||||
try:
|
||||
decoded2 = json.loads(s)
|
||||
except json.JSONDecodeError:
|
||||
decoded2 = None
|
||||
if isinstance(decoded2, dict):
|
||||
return decoded2, False
|
||||
|
||||
if tool_name == "terminal":
|
||||
return {"command": decoded}, False
|
||||
return {"input": decoded}, False
|
||||
|
||||
if tool_name == "terminal":
|
||||
return {"command": str(decoded)}, False
|
||||
return {"input": decoded}, False
|
||||
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
|
||||
"""
|
||||
@@ -308,22 +155,10 @@ class HermesAgentLoop:
|
||||
reasoning_per_turn = []
|
||||
tool_errors: List[ToolError] = []
|
||||
|
||||
tool_calls_attempted = 0
|
||||
tool_calls_schema_valid = 0
|
||||
tool_calls_executed_ok = 0
|
||||
tool_calls_exec_error = 0
|
||||
|
||||
import time as _time
|
||||
|
||||
for turn in range(self.max_turns):
|
||||
turn_start = _time.monotonic()
|
||||
|
||||
# Truncate prompt view on a copy (preserve full trajectory in `messages`)
|
||||
prompt_messages = self._truncate_context(list(messages))
|
||||
|
||||
# Build the chat_completion kwargs
|
||||
chat_kwargs = {
|
||||
"messages": prompt_messages,
|
||||
"messages": messages,
|
||||
"n": 1,
|
||||
"temperature": self.temperature,
|
||||
}
|
||||
@@ -336,18 +171,11 @@ class HermesAgentLoop:
|
||||
if self.max_tokens is not None:
|
||||
chat_kwargs["max_tokens"] = self.max_tokens
|
||||
|
||||
# Inject extra_body for provider-specific params (e.g., OpenRouter
|
||||
# provider preferences like banned/preferred providers, transforms)
|
||||
if self.extra_body:
|
||||
chat_kwargs["extra_body"] = self.extra_body
|
||||
|
||||
# Make the API call -- standard OpenAI spec
|
||||
api_start = _time.monotonic()
|
||||
try:
|
||||
response = await self.server.chat_completion(**chat_kwargs)
|
||||
except Exception as e:
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
logger.error("API call failed on turn %d (%.1fs): %s", turn + 1, api_elapsed, e)
|
||||
logger.error("API call failed on turn %d: %s", turn + 1, e)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
@@ -355,16 +183,10 @@ class HermesAgentLoop:
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
tool_calls_attempted=tool_calls_attempted,
|
||||
tool_calls_schema_valid=tool_calls_schema_valid,
|
||||
tool_calls_executed_ok=tool_calls_executed_ok,
|
||||
tool_calls_exec_error=tool_calls_exec_error,
|
||||
)
|
||||
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
|
||||
if not response or not response.choices:
|
||||
logger.warning("Empty response on turn %d (api=%.1fs)", turn + 1, api_elapsed)
|
||||
logger.warning("Empty response on turn %d", turn + 1)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
@@ -372,10 +194,6 @@ class HermesAgentLoop:
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
tool_calls_attempted=tool_calls_attempted,
|
||||
tool_calls_schema_valid=tool_calls_schema_valid,
|
||||
tool_calls_executed_ok=tool_calls_executed_ok,
|
||||
tool_calls_exec_error=tool_calls_exec_error,
|
||||
)
|
||||
|
||||
assistant_msg = response.choices[0].message
|
||||
@@ -418,7 +236,6 @@ class HermesAgentLoop:
|
||||
|
||||
# Validate tool name
|
||||
if tool_name not in self.valid_tool_names:
|
||||
tool_calls_exec_error += 1
|
||||
tool_result = json.dumps(
|
||||
{
|
||||
"error": f"Unknown tool '{tool_name}'. "
|
||||
@@ -436,47 +253,34 @@ class HermesAgentLoop:
|
||||
tool_name, turn + 1,
|
||||
)
|
||||
else:
|
||||
tool_calls_attempted += 1
|
||||
args, schema_valid = self._normalize_tool_args(tool_name, tool_args_raw)
|
||||
if schema_valid:
|
||||
tool_calls_schema_valid += 1
|
||||
# Parse arguments and dispatch
|
||||
try:
|
||||
args = json.loads(tool_args_raw)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
logger.warning(
|
||||
"Invalid JSON in tool call arguments for '%s': %s",
|
||||
tool_name, tool_args_raw[:200],
|
||||
)
|
||||
|
||||
try:
|
||||
if tool_name == "terminal":
|
||||
import os
|
||||
backend = os.getenv("TERMINAL_ENV", "local")
|
||||
cmd_preview = str(args.get("command", ""))[:80]
|
||||
logger.info(
|
||||
"[%s] $ %s", self.task_id[:8], cmd_preview,
|
||||
)
|
||||
cmd_preview = args.get("command", "")[:80]
|
||||
print(f" 🖥️ [{backend}] $ {cmd_preview}")
|
||||
|
||||
tool_submit_time = _time.monotonic()
|
||||
|
||||
if self.tool_handler:
|
||||
tool_result = await self.tool_handler(tool_name, args, self.task_id)
|
||||
else:
|
||||
# Run tool calls in a thread pool so backends that use
|
||||
# asyncio.run() internally (modal, docker) get a clean
|
||||
# event loop instead of deadlocking inside Atropos's loop.
|
||||
loop = asyncio.get_event_loop()
|
||||
tool_result = await loop.run_in_executor(
|
||||
_tool_executor,
|
||||
lambda: handle_function_call(
|
||||
tool_name, args, task_id=self.task_id
|
||||
),
|
||||
)
|
||||
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
|
||||
# Log slow tools and thread pool stats for debugging
|
||||
pool_active = _tool_executor._work_queue.qsize()
|
||||
if tool_elapsed > 30:
|
||||
logger.warning(
|
||||
"[%s] turn %d: %s took %.1fs (pool queue=%d)",
|
||||
self.task_id[:8], turn + 1, tool_name,
|
||||
tool_elapsed, pool_active,
|
||||
)
|
||||
# Run tool calls in a thread pool so backends that use
|
||||
# asyncio.run() internally (modal, docker) get a clean
|
||||
# event loop instead of deadlocking inside Atropos's loop.
|
||||
loop = asyncio.get_event_loop()
|
||||
tool_result = await loop.run_in_executor(
|
||||
_tool_executor,
|
||||
lambda: handle_function_call(
|
||||
tool_name, args, task_id=self.task_id
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
tool_calls_exec_error += 1
|
||||
tool_result = json.dumps(
|
||||
{"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
|
||||
)
|
||||
@@ -490,31 +294,22 @@ class HermesAgentLoop:
|
||||
"Tool '%s' execution failed on turn %d: %s",
|
||||
tool_name, turn + 1, e,
|
||||
)
|
||||
else:
|
||||
tool_err = False
|
||||
try:
|
||||
result_data = json.loads(tool_result)
|
||||
if isinstance(result_data, dict):
|
||||
err = result_data.get("error")
|
||||
if err:
|
||||
tool_err = True
|
||||
|
||||
exit_code = result_data.get("exit_code")
|
||||
if exit_code is not None and isinstance(exit_code, int) and exit_code < 0:
|
||||
tool_err = True
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=str(err) if err else "nonzero exit_code",
|
||||
tool_result=tool_result[:500],
|
||||
))
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
if tool_err:
|
||||
tool_calls_exec_error += 1
|
||||
else:
|
||||
tool_calls_executed_ok += 1
|
||||
# Also check if the tool returned an error in its JSON result
|
||||
try:
|
||||
result_data = json.loads(tool_result)
|
||||
if isinstance(result_data, dict):
|
||||
err = result_data.get("error")
|
||||
exit_code = result_data.get("exit_code")
|
||||
if err and exit_code and exit_code < 0:
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=str(err),
|
||||
tool_result=tool_result[:500],
|
||||
))
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Add tool response to conversation
|
||||
messages.append(
|
||||
@@ -525,11 +320,10 @@ class HermesAgentLoop:
|
||||
}
|
||||
)
|
||||
|
||||
turn_elapsed = _time.monotonic() - turn_start
|
||||
logger.info(
|
||||
"[%s] turn %d: api=%.1fs, %d tools, turn_total=%.1fs",
|
||||
self.task_id[:8], turn + 1, api_elapsed,
|
||||
len(assistant_msg.tool_calls), turn_elapsed,
|
||||
logger.debug(
|
||||
"Turn %d: %d tool calls executed",
|
||||
turn + 1,
|
||||
len(assistant_msg.tool_calls),
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -542,10 +336,8 @@ class HermesAgentLoop:
|
||||
msg_dict["reasoning_content"] = reasoning
|
||||
messages.append(msg_dict)
|
||||
|
||||
turn_elapsed = _time.monotonic() - turn_start
|
||||
logger.info(
|
||||
"[%s] turn %d: api=%.1fs, no tools (finished), turn_total=%.1fs",
|
||||
self.task_id[:8], turn + 1, api_elapsed, turn_elapsed,
|
||||
logger.debug(
|
||||
"Turn %d: model finished naturally (no tool calls)", turn + 1
|
||||
)
|
||||
|
||||
return AgentResult(
|
||||
@@ -555,10 +347,6 @@ class HermesAgentLoop:
|
||||
finished_naturally=True,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
tool_calls_attempted=tool_calls_attempted,
|
||||
tool_calls_schema_valid=tool_calls_schema_valid,
|
||||
tool_calls_executed_ok=tool_calls_executed_ok,
|
||||
tool_calls_exec_error=tool_calls_exec_error,
|
||||
)
|
||||
|
||||
# Hit max turns without the model stopping
|
||||
@@ -570,10 +358,6 @@ class HermesAgentLoop:
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
tool_calls_attempted=tool_calls_attempted,
|
||||
tool_calls_schema_valid=tool_calls_schema_valid,
|
||||
tool_calls_executed_ok=tool_calls_executed_ok,
|
||||
tool_calls_exec_error=tool_calls_exec_error,
|
||||
)
|
||||
|
||||
def _get_managed_state(self) -> Optional[Dict[str, Any]]:
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
# Terminal-Bench 2.0 Evaluation -- Default Configuration
|
||||
#
|
||||
# Eval-only environment for the TB2 benchmark (89 terminal tasks).
|
||||
# Uses Modal terminal backend for per-task cloud-isolated sandboxes
|
||||
# and OpenRouter for inference.
|
||||
#
|
||||
# Usage:
|
||||
# python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
# --config environments/benchmarks/terminalbench_2/default.yaml
|
||||
#
|
||||
# # Override model:
|
||||
# python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
# --config environments/benchmarks/terminalbench_2/default.yaml \
|
||||
# --openai.model_name anthropic/claude-sonnet-4
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file"]
|
||||
max_agent_turns: 60
|
||||
max_token_length: 32000
|
||||
agent_temperature: 0.8
|
||||
terminal_backend: "modal"
|
||||
terminal_timeout: 300 # 5 min per command (builds, pip install)
|
||||
tool_pool_size: 128 # thread pool for 89 parallel tasks
|
||||
dataset_name: "NousResearch/terminal-bench-2"
|
||||
test_timeout: 600
|
||||
task_timeout: 1800 # 30 min wall-clock per task, auto-FAIL if exceeded
|
||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
use_wandb: true
|
||||
wandb_name: "terminal-bench-2"
|
||||
ensure_scores_are_not_same: false
|
||||
data_dir_to_save_evals: "environments/benchmarks/evals/terminal-bench-2"
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-opus-4.6"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
@@ -1,32 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Terminal-Bench 2.0 Evaluation
|
||||
#
|
||||
# Run from repo root:
|
||||
# bash environments/benchmarks/terminalbench_2/run_eval.sh
|
||||
#
|
||||
# Override model:
|
||||
# bash environments/benchmarks/terminalbench_2/run_eval.sh \
|
||||
# --openai.model_name anthropic/claude-sonnet-4
|
||||
#
|
||||
# Run a subset:
|
||||
# bash environments/benchmarks/terminalbench_2/run_eval.sh \
|
||||
# --env.task_filter fix-git,git-multibranch
|
||||
|
||||
mkdir -p logs evals/terminal-bench-2
|
||||
LOG_FILE="logs/terminalbench2_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "Terminal-Bench 2.0 Evaluation"
|
||||
echo "Log: $LOG_FILE"
|
||||
echo ""
|
||||
|
||||
export TERMINAL_ENV=modal
|
||||
export TERMINAL_TIMEOUT=300
|
||||
|
||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
--config environments/benchmarks/terminalbench_2/default.yaml \
|
||||
"$@" \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo ""
|
||||
echo "Log saved to: $LOG_FILE"
|
||||
@@ -1,904 +0,0 @@
|
||||
"""
|
||||
TerminalBench2Env -- Terminal-Bench 2.0 Evaluation Environment
|
||||
|
||||
Evaluates agentic LLMs on challenging terminal tasks from Terminal-Bench 2.0.
|
||||
Each task provides a unique Docker environment (pre-built on Docker Hub), a natural
|
||||
language instruction, and a test suite for verification. The agent uses terminal +
|
||||
file tools to complete the task, then the test suite runs inside the same sandbox.
|
||||
|
||||
This is an eval-only environment (not a training environment). It is designed to
|
||||
be run via the `evaluate` subcommand:
|
||||
|
||||
python environments/terminalbench2_env.py evaluate \\
|
||||
--env.dataset_name NousResearch/terminal-bench-2
|
||||
|
||||
The evaluate flow:
|
||||
1. setup() -- Loads the TB2 dataset from HuggingFace
|
||||
2. evaluate() -- Iterates over all tasks, running each through:
|
||||
a. rollout_and_score_eval() -- Per-task agent loop + test verification
|
||||
- Resolves Docker image (pre-built Hub image or Dockerfile fallback)
|
||||
- Registers per-task Modal sandbox via register_task_env_overrides()
|
||||
- Runs the HermesAgentLoop (terminal + file tools)
|
||||
- Uploads test suite and runs test.sh in the same sandbox
|
||||
- Returns binary pass/fail result
|
||||
b. Aggregates per-task, per-category, and overall pass rates
|
||||
c. Logs results via evaluate_log() and wandb
|
||||
|
||||
Key features:
|
||||
- Per-task Modal sandboxes using pre-built Docker Hub images
|
||||
- Binary reward: 1.0 if all tests pass, 0.0 otherwise
|
||||
- Concurrency-controlled parallel evaluation via asyncio.Semaphore
|
||||
- Per-task, per-category, and aggregate pass rate tracking
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tarfile
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path for imports
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import EvalHandlingEnum
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
from tools.terminal_tool import (
|
||||
register_task_env_overrides,
|
||||
clear_task_env_overrides,
|
||||
cleanup_vm,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Configuration
|
||||
# =============================================================================
|
||||
|
||||
class TerminalBench2EvalConfig(HermesAgentEnvConfig):
|
||||
"""
|
||||
Configuration for the Terminal-Bench 2.0 evaluation environment.
|
||||
|
||||
Extends HermesAgentEnvConfig with TB2-specific settings for dataset loading,
|
||||
test execution, task filtering, and eval concurrency.
|
||||
"""
|
||||
|
||||
# --- Dataset ---
|
||||
dataset_name: str = Field(
|
||||
default="NousResearch/terminal-bench-2",
|
||||
description="HuggingFace dataset containing TB2 tasks.",
|
||||
)
|
||||
|
||||
# --- Test execution ---
|
||||
test_timeout: int = Field(
|
||||
default=180,
|
||||
description="Timeout in seconds for running the test suite after agent completes.",
|
||||
)
|
||||
|
||||
# --- Image strategy ---
|
||||
force_build: bool = Field(
|
||||
default=False,
|
||||
description="If True, always build from Dockerfile (ignore docker_image). "
|
||||
"Useful for testing custom Dockerfiles.",
|
||||
)
|
||||
|
||||
# --- Task filtering (comma-separated from CLI) ---
|
||||
task_filter: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Comma-separated task names to run (e.g., 'fix-git,git-multibranch'). "
|
||||
"If not set, all tasks are run.",
|
||||
)
|
||||
skip_tasks: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Comma-separated task names to skip on top of the default skip list.",
|
||||
)
|
||||
|
||||
# --- Per-task wall-clock timeout ---
|
||||
task_timeout: int = Field(
|
||||
default=1800,
|
||||
description="Maximum wall-clock seconds per task (agent loop + verification). "
|
||||
"Tasks exceeding this are scored as FAIL. Default 30 minutes.",
|
||||
)
|
||||
|
||||
|
||||
# Tasks that cannot run properly on Modal and are excluded from scoring.
|
||||
MODAL_INCOMPATIBLE_TASKS = {
|
||||
"qemu-startup", # Needs KVM/hardware virtualization
|
||||
"qemu-alpine-ssh", # Needs KVM/hardware virtualization
|
||||
"crack-7z-hash", # Password brute-force -- too slow for cloud sandbox timeouts
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tar extraction helper
|
||||
# =============================================================================
|
||||
|
||||
def _extract_base64_tar(b64_data: str, target_dir: Path):
|
||||
"""Extract a base64-encoded tar.gz archive into target_dir."""
|
||||
if not b64_data:
|
||||
return
|
||||
raw = base64.b64decode(b64_data)
|
||||
buf = io.BytesIO(raw)
|
||||
with tarfile.open(fileobj=buf, mode="r:gz") as tar:
|
||||
tar.extractall(path=str(target_dir))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main Environment
|
||||
# =============================================================================
|
||||
|
||||
class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
Terminal-Bench 2.0 evaluation environment (eval-only, no training).
|
||||
|
||||
Inherits from HermesAgentBaseEnv for:
|
||||
- Terminal backend setup (os.environ["TERMINAL_ENV"])
|
||||
- Tool resolution via _resolve_tools_for_group()
|
||||
- Monkey patches for async-safe tool operation
|
||||
- Wandb trajectory formatting
|
||||
|
||||
The evaluate flow (triggered by `environment.py evaluate`):
|
||||
1. setup() -- Load dataset from HuggingFace
|
||||
2. evaluate() -- Run all tasks through rollout_and_score_eval()
|
||||
|
||||
Each task in rollout_and_score_eval():
|
||||
1. Resolve Docker image (pre-built Hub image or Dockerfile fallback)
|
||||
2. Register per-task Modal sandbox override
|
||||
3. Run HermesAgentLoop with terminal + file tools
|
||||
4. Upload test suite and execute test.sh in the same sandbox
|
||||
5. Check /logs/verifier/reward.txt for pass/fail
|
||||
6. Clean up sandbox, overrides, and temp files
|
||||
"""
|
||||
|
||||
name = "terminal-bench-2"
|
||||
env_config_cls = TerminalBench2EvalConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[TerminalBench2EvalConfig, List[APIServerConfig]]:
|
||||
"""
|
||||
Default configuration for Terminal-Bench 2.0 evaluation.
|
||||
|
||||
Uses eval-only settings:
|
||||
- eval_handling=STOP_TRAIN so the eval flow runs cleanly
|
||||
- steps_per_eval=1, total_steps=1 so eval triggers immediately
|
||||
- group_size=1 (one rollout per group, each task is expensive)
|
||||
|
||||
Uses Modal terminal backend (cloud-isolated sandbox per task) and
|
||||
OpenRouter with Claude for inference.
|
||||
"""
|
||||
env_config = TerminalBench2EvalConfig(
|
||||
# Terminal + file tools only (the agent interacts via shell commands)
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
|
||||
# Agent settings -- TB2 tasks are complex, need many turns
|
||||
max_agent_turns=60,
|
||||
max_token_length=16000,
|
||||
agent_temperature=0.6,
|
||||
system_prompt=None,
|
||||
|
||||
# Modal backend for per-task cloud-isolated sandboxes
|
||||
terminal_backend="modal",
|
||||
terminal_timeout=300, # 5 min per command (builds, pip install, etc.)
|
||||
|
||||
# Test execution timeout (TB2 test scripts can install deps like pytest)
|
||||
test_timeout=180,
|
||||
|
||||
# 89 tasks run in parallel, each needs a thread for tool calls
|
||||
tool_pool_size=128,
|
||||
|
||||
# --- Eval-only Atropos settings ---
|
||||
# These settings make the env work as an eval-only environment:
|
||||
# - STOP_TRAIN: pauses training during eval (standard for eval envs)
|
||||
# - steps_per_eval=1, total_steps=1: eval triggers immediately
|
||||
# - group_size=1: one rollout per group (each task is expensive)
|
||||
eval_handling=EvalHandlingEnum.STOP_TRAIN,
|
||||
group_size=1,
|
||||
steps_per_eval=1,
|
||||
total_steps=1,
|
||||
|
||||
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
use_wandb=True,
|
||||
wandb_name="terminal-bench-2",
|
||||
ensure_scores_are_not_same=False, # Binary rewards may all be 0 or 1
|
||||
)
|
||||
|
||||
# OpenRouter with Claude -- API key loaded from .env
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-sonnet-4",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
# =========================================================================
|
||||
# Setup -- load dataset
|
||||
# =========================================================================
|
||||
|
||||
async def setup(self):
|
||||
"""Load the Terminal-Bench 2.0 dataset from HuggingFace."""
|
||||
from datasets import load_dataset
|
||||
|
||||
# Auto-set terminal_lifetime to task_timeout + 120s so sandboxes
|
||||
# never get killed during an active task, but still get cleaned up
|
||||
# promptly after the task times out.
|
||||
lifetime = self.config.task_timeout + 120
|
||||
self.config.terminal_lifetime = lifetime
|
||||
os.environ["TERMINAL_LIFETIME_SECONDS"] = str(lifetime)
|
||||
print(f" Terminal lifetime auto-set to {lifetime}s (task_timeout + 120s)")
|
||||
|
||||
print(f"Loading TB2 dataset from: {self.config.dataset_name}")
|
||||
ds = load_dataset(self.config.dataset_name, split="train")
|
||||
|
||||
# Apply task filters (comma-separated strings from CLI)
|
||||
tasks = list(ds)
|
||||
if self.config.task_filter:
|
||||
allowed = {name.strip() for name in self.config.task_filter.split(",")}
|
||||
tasks = [t for t in tasks if t["task_name"] in allowed]
|
||||
print(f" Filtered to {len(tasks)} tasks: {sorted(allowed)}")
|
||||
|
||||
# Skip tasks incompatible with the current backend (e.g., QEMU on Modal)
|
||||
# plus any user-specified skip_tasks
|
||||
skip = set(MODAL_INCOMPATIBLE_TASKS) if self.config.terminal_backend == "modal" else set()
|
||||
if self.config.skip_tasks:
|
||||
skip |= {name.strip() for name in self.config.skip_tasks.split(",")}
|
||||
if skip:
|
||||
before = len(tasks)
|
||||
tasks = [t for t in tasks if t["task_name"] not in skip]
|
||||
skipped = before - len(tasks)
|
||||
if skipped > 0:
|
||||
print(f" Skipped {skipped} incompatible tasks: {sorted(skip & {t['task_name'] for t in ds})}")
|
||||
|
||||
self.all_eval_items = tasks
|
||||
self.iter = 0
|
||||
|
||||
# Build category index for per-category metrics
|
||||
self.category_index: Dict[str, List[int]] = defaultdict(list)
|
||||
for i, task in enumerate(self.all_eval_items):
|
||||
self.category_index[task.get("category", "unknown")].append(i)
|
||||
|
||||
# Reward tracking for wandb logging
|
||||
self.eval_metrics: List[Tuple[str, float]] = []
|
||||
|
||||
# Streaming JSONL writer -- saves each task's full conversation
|
||||
# immediately on completion so data is preserved even on Ctrl+C.
|
||||
# Timestamped filename so each run produces a unique file.
|
||||
import datetime
|
||||
log_dir = os.path.join(os.path.dirname(__file__), "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
run_ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
self._streaming_path = os.path.join(log_dir, f"samples_{run_ts}.jsonl")
|
||||
self._streaming_file = open(self._streaming_path, "w")
|
||||
self._streaming_lock = __import__("threading").Lock()
|
||||
print(f" Streaming results to: {self._streaming_path}")
|
||||
|
||||
print(f"TB2 ready: {len(self.all_eval_items)} tasks across {len(self.category_index)} categories")
|
||||
for cat, indices in sorted(self.category_index.items()):
|
||||
print(f" {cat}: {len(indices)} tasks")
|
||||
|
||||
def _save_result(self, result: Dict[str, Any]):
|
||||
"""Write a single task result to the streaming JSONL file immediately."""
|
||||
if not hasattr(self, "_streaming_file") or self._streaming_file.closed:
|
||||
return
|
||||
with self._streaming_lock:
|
||||
self._streaming_file.write(json.dumps(result, ensure_ascii=False, default=str) + "\n")
|
||||
self._streaming_file.flush()
|
||||
|
||||
# =========================================================================
|
||||
# Training pipeline stubs -- NOT used in eval-only mode
|
||||
# =========================================================================
|
||||
# These satisfy the abstract method requirements from HermesAgentBaseEnv.
|
||||
# The evaluate subcommand calls setup() -> evaluate() directly, bypassing
|
||||
# the training pipeline entirely.
|
||||
|
||||
async def get_next_item(self):
|
||||
"""Return next item (stub -- not used in eval-only mode)."""
|
||||
item = self.all_eval_items[self.iter % len(self.all_eval_items)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def format_prompt(self, item: Dict[str, Any]) -> str:
|
||||
"""Return the task's instruction as the user prompt."""
|
||||
return item["instruction"]
|
||||
|
||||
async def compute_reward(self, item, result, ctx) -> float:
|
||||
"""Compute reward (stub -- actual verification is in rollout_and_score_eval)."""
|
||||
return 0.0
|
||||
|
||||
async def collect_trajectories(self, item):
|
||||
"""Collect trajectories (stub -- not used in eval-only mode)."""
|
||||
return None, []
|
||||
|
||||
async def score(self, rollout_group_data):
|
||||
"""Score rollouts (stub -- not used in eval-only mode)."""
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Docker image resolution
|
||||
# =========================================================================
|
||||
|
||||
def _resolve_task_image(
|
||||
self, item: Dict[str, Any], task_name: str
|
||||
) -> Tuple[str, Optional[Path]]:
|
||||
"""
|
||||
Resolve the Docker image for a task, with fallback to Dockerfile.
|
||||
|
||||
Strategy (mirrors Harbor's approach):
|
||||
1. If force_build=True, always build from Dockerfile in environment_tar
|
||||
2. If docker_image is available, use the pre-built Docker Hub image (fast)
|
||||
3. Otherwise, extract Dockerfile from environment_tar and build (slow)
|
||||
|
||||
Returns:
|
||||
(modal_image, temp_dir) -- modal_image is a Docker Hub name or a
|
||||
Dockerfile path. temp_dir is set if we extracted files that need
|
||||
cleanup later.
|
||||
"""
|
||||
docker_image = item.get("docker_image", "")
|
||||
environment_tar = item.get("environment_tar", "")
|
||||
|
||||
# Fast path: use pre-built Docker Hub image
|
||||
if docker_image and not self.config.force_build:
|
||||
logger.info("Task %s: using pre-built image %s", task_name, docker_image)
|
||||
return docker_image, None
|
||||
|
||||
# Slow path: extract Dockerfile from environment_tar and build
|
||||
if environment_tar:
|
||||
task_dir = Path(tempfile.mkdtemp(prefix=f"tb2-{task_name}-"))
|
||||
_extract_base64_tar(environment_tar, task_dir)
|
||||
dockerfile_path = task_dir / "Dockerfile"
|
||||
if dockerfile_path.exists():
|
||||
logger.info(
|
||||
"Task %s: building from Dockerfile (force_build=%s, docker_image=%s)",
|
||||
task_name, self.config.force_build, bool(docker_image),
|
||||
)
|
||||
return str(dockerfile_path), task_dir
|
||||
|
||||
# Neither available -- fall back to Hub image if force_build was True
|
||||
if docker_image:
|
||||
logger.warning(
|
||||
"Task %s: force_build=True but no environment_tar, "
|
||||
"falling back to docker_image %s", task_name, docker_image,
|
||||
)
|
||||
return docker_image, None
|
||||
|
||||
return "", None
|
||||
|
||||
# =========================================================================
|
||||
# Per-task evaluation -- agent loop + test verification
|
||||
# =========================================================================
|
||||
|
||||
async def rollout_and_score_eval(self, eval_item: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Evaluate a single TB2 task: run the agent loop, then verify with tests.
|
||||
|
||||
This is the core evaluation method. For each task it:
|
||||
1. Resolves the Docker image and registers the Modal sandbox override
|
||||
2. Runs HermesAgentLoop with terminal + file tools
|
||||
3. Uploads the test suite into the sandbox
|
||||
4. Executes test.sh and checks the result
|
||||
5. Cleans up the sandbox and temp files
|
||||
|
||||
Args:
|
||||
eval_item: A single TB2 task dict from the dataset
|
||||
|
||||
Returns:
|
||||
Dict with 'passed' (bool), 'reward' (float), 'task_name' (str),
|
||||
'category' (str), and optional debug info
|
||||
"""
|
||||
task_name = eval_item.get("task_name", "unknown")
|
||||
category = eval_item.get("category", "unknown")
|
||||
task_id = str(uuid.uuid4())
|
||||
task_dir = None # Set if we extract a Dockerfile (needs cleanup)
|
||||
|
||||
from tqdm import tqdm
|
||||
tqdm.write(f" [START] {task_name} (task_id={task_id[:8]})")
|
||||
task_start = time.time()
|
||||
|
||||
try:
|
||||
# --- 1. Resolve Docker image ---
|
||||
modal_image, task_dir = self._resolve_task_image(eval_item, task_name)
|
||||
if not modal_image:
|
||||
logger.error("Task %s: no docker_image or environment_tar, skipping", task_name)
|
||||
return {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
"error": "no_image",
|
||||
}
|
||||
|
||||
# --- 2. Register per-task Modal image override ---
|
||||
register_task_env_overrides(task_id, {"modal_image": modal_image})
|
||||
logger.info(
|
||||
"Task %s: registered image override for task_id %s",
|
||||
task_name, task_id[:8],
|
||||
)
|
||||
|
||||
# --- 3. Resolve tools and build messages ---
|
||||
tools, valid_names = self._resolve_tools_for_group()
|
||||
|
||||
messages: List[Dict[str, Any]] = []
|
||||
if self.config.system_prompt:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
messages.append({"role": "user", "content": self.format_prompt(eval_item)})
|
||||
|
||||
# --- 4. Run agent loop ---
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
# --- 5. Verify -- run test suite in the agent's sandbox ---
|
||||
# Skip verification if the agent produced no meaningful output
|
||||
only_system_and_user = all(
|
||||
msg.get("role") in ("system", "user") for msg in result.messages
|
||||
)
|
||||
if result.turns_used == 0 or only_system_and_user:
|
||||
logger.warning(
|
||||
"Task %s: agent produced no output (turns=%d). Reward=0.",
|
||||
task_name, result.turns_used,
|
||||
)
|
||||
reward = 0.0
|
||||
else:
|
||||
# Run tests in a thread so the blocking ctx.terminal() calls
|
||||
# don't freeze the entire event loop (which would stall all
|
||||
# other tasks, tqdm updates, and timeout timers).
|
||||
ctx = ToolContext(task_id)
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
reward = await loop.run_in_executor(
|
||||
None, # default thread pool
|
||||
self._run_tests, eval_item, ctx, task_name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Task %s: test verification failed: %s", task_name, e)
|
||||
reward = 0.0
|
||||
finally:
|
||||
ctx.cleanup()
|
||||
|
||||
passed = reward == 1.0
|
||||
status = "PASS" if passed else "FAIL"
|
||||
elapsed = time.time() - task_start
|
||||
tqdm.write(f" [{status}] {task_name} (turns={result.turns_used}, {elapsed:.0f}s)")
|
||||
logger.info(
|
||||
"Task %s: reward=%.1f, turns=%d, finished=%s",
|
||||
task_name, reward, result.turns_used, result.finished_naturally,
|
||||
)
|
||||
|
||||
out = {
|
||||
"passed": passed,
|
||||
"reward": reward,
|
||||
"task_name": task_name,
|
||||
"category": category,
|
||||
"turns_used": result.turns_used,
|
||||
"finished_naturally": result.finished_naturally,
|
||||
"messages": result.messages,
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.time() - task_start
|
||||
logger.error("Task %s: rollout failed: %s", task_name, e, exc_info=True)
|
||||
tqdm.write(f" [ERROR] {task_name}: {e} ({elapsed:.0f}s)")
|
||||
out = {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
"error": str(e),
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
finally:
|
||||
# --- Cleanup: clear overrides, sandbox, and temp files ---
|
||||
clear_task_env_overrides(task_id)
|
||||
try:
|
||||
cleanup_vm(task_id)
|
||||
except Exception as e:
|
||||
logger.debug("VM cleanup for %s: %s", task_id[:8], e)
|
||||
if task_dir and task_dir.exists():
|
||||
shutil.rmtree(task_dir, ignore_errors=True)
|
||||
|
||||
def _run_tests(
|
||||
self, item: Dict[str, Any], ctx: ToolContext, task_name: str
|
||||
) -> float:
|
||||
"""
|
||||
Upload and execute the test suite in the agent's sandbox, then
|
||||
download the verifier output locally to read the reward.
|
||||
|
||||
Follows Harbor's verification pattern:
|
||||
1. Upload tests/ directory into the sandbox
|
||||
2. Execute test.sh inside the sandbox
|
||||
3. Download /logs/verifier/ directory to a local temp dir
|
||||
4. Read reward.txt locally with native Python I/O
|
||||
|
||||
Downloading locally avoids issues with the file_read tool on
|
||||
the Modal VM and matches how Harbor handles verification.
|
||||
|
||||
TB2 test scripts (test.sh) typically:
|
||||
1. Install pytest via uv/pip
|
||||
2. Run pytest against the test files in /tests/
|
||||
3. Write results to /logs/verifier/reward.txt
|
||||
|
||||
Args:
|
||||
item: The TB2 task dict (contains tests_tar, test_sh)
|
||||
ctx: ToolContext scoped to this task's sandbox
|
||||
task_name: For logging
|
||||
|
||||
Returns:
|
||||
1.0 if tests pass, 0.0 otherwise
|
||||
"""
|
||||
tests_tar = item.get("tests_tar", "")
|
||||
test_sh = item.get("test_sh", "")
|
||||
|
||||
if not test_sh:
|
||||
logger.warning("Task %s: no test_sh content, reward=0", task_name)
|
||||
return 0.0
|
||||
|
||||
# Create required directories in the sandbox
|
||||
ctx.terminal("mkdir -p /tests /logs/verifier")
|
||||
|
||||
# Upload test files into the sandbox (binary-safe via base64)
|
||||
if tests_tar:
|
||||
tests_temp = Path(tempfile.mkdtemp(prefix=f"tb2-tests-{task_name}-"))
|
||||
try:
|
||||
_extract_base64_tar(tests_tar, tests_temp)
|
||||
ctx.upload_dir(str(tests_temp), "/tests")
|
||||
except Exception as e:
|
||||
logger.warning("Task %s: failed to upload test files: %s", task_name, e)
|
||||
finally:
|
||||
shutil.rmtree(tests_temp, ignore_errors=True)
|
||||
|
||||
# Write the test runner script (test.sh)
|
||||
ctx.write_file("/tests/test.sh", test_sh)
|
||||
ctx.terminal("chmod +x /tests/test.sh")
|
||||
|
||||
# Execute the test suite
|
||||
logger.info(
|
||||
"Task %s: running test suite (timeout=%ds)",
|
||||
task_name, self.config.test_timeout,
|
||||
)
|
||||
test_result = ctx.terminal(
|
||||
"bash /tests/test.sh",
|
||||
timeout=self.config.test_timeout,
|
||||
)
|
||||
|
||||
exit_code = test_result.get("exit_code", -1)
|
||||
output = test_result.get("output", "")
|
||||
|
||||
# Download the verifier output directory locally, then read reward.txt
|
||||
# with native Python I/O. This avoids issues with file_read on the
|
||||
# Modal VM and matches Harbor's verification pattern.
|
||||
reward = 0.0
|
||||
local_verifier_dir = Path(tempfile.mkdtemp(prefix=f"tb2-verifier-{task_name}-"))
|
||||
try:
|
||||
ctx.download_dir("/logs/verifier", str(local_verifier_dir))
|
||||
|
||||
reward_file = local_verifier_dir / "reward.txt"
|
||||
if reward_file.exists() and reward_file.stat().st_size > 0:
|
||||
content = reward_file.read_text().strip()
|
||||
if content == "1":
|
||||
reward = 1.0
|
||||
elif content == "0":
|
||||
reward = 0.0
|
||||
else:
|
||||
# Unexpected content -- try parsing as float
|
||||
try:
|
||||
reward = float(content)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Task %s: reward.txt content unexpected (%r), "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, content, exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
else:
|
||||
# reward.txt not written -- fall back to exit code
|
||||
logger.warning(
|
||||
"Task %s: reward.txt not found after download, "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Task %s: failed to download verifier dir: %s, "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, e, exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
finally:
|
||||
shutil.rmtree(local_verifier_dir, ignore_errors=True)
|
||||
|
||||
# Log test output for debugging failures
|
||||
if reward == 0.0:
|
||||
output_preview = output[-500:] if output else "(no output)"
|
||||
logger.info(
|
||||
"Task %s: FAIL (exit_code=%d)\n%s",
|
||||
task_name, exit_code, output_preview,
|
||||
)
|
||||
|
||||
return reward
|
||||
|
||||
# =========================================================================
|
||||
# Evaluate -- main entry point for the eval subcommand
|
||||
# =========================================================================
|
||||
|
||||
async def _eval_with_timeout(self, item: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Wrap rollout_and_score_eval with a per-task wall-clock timeout.
|
||||
|
||||
If the task exceeds task_timeout seconds, it's automatically scored
|
||||
as FAIL. This prevents any single task from hanging indefinitely.
|
||||
"""
|
||||
task_name = item.get("task_name", "unknown")
|
||||
category = item.get("category", "unknown")
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self.rollout_and_score_eval(item),
|
||||
timeout=self.config.task_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
from tqdm import tqdm
|
||||
elapsed = self.config.task_timeout
|
||||
tqdm.write(f" [TIMEOUT] {task_name} (exceeded {elapsed}s wall-clock limit)")
|
||||
logger.error("Task %s: wall-clock timeout after %ds", task_name, elapsed)
|
||||
out = {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
"error": f"timeout ({elapsed}s)",
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
async def evaluate(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
Run Terminal-Bench 2.0 evaluation over all tasks.
|
||||
|
||||
This is the main entry point when invoked via:
|
||||
python environments/terminalbench2_env.py evaluate
|
||||
|
||||
Runs all tasks through rollout_and_score_eval() via asyncio.gather()
|
||||
(same pattern as GPQA and other Atropos eval envs). Each task is
|
||||
wrapped with a wall-clock timeout so hung tasks auto-fail.
|
||||
|
||||
Suppresses noisy Modal/terminal output (HERMES_QUIET) so the tqdm
|
||||
bar stays visible.
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Route all logging through tqdm.write() so the progress bar stays
|
||||
# pinned at the bottom while log lines scroll above it.
|
||||
from tqdm import tqdm
|
||||
|
||||
class _TqdmHandler(logging.Handler):
|
||||
def emit(self, record):
|
||||
try:
|
||||
tqdm.write(self.format(record))
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
handler = _TqdmHandler()
|
||||
handler.setFormatter(logging.Formatter(
|
||||
"%(asctime)s [%(name)s] %(levelname)s: %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
))
|
||||
root = logging.getLogger()
|
||||
root.handlers = [handler] # Replace any existing handlers
|
||||
root.setLevel(logging.INFO)
|
||||
|
||||
# Silence noisy third-party loggers that flood the output
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING) # Every HTTP request
|
||||
logging.getLogger("openai").setLevel(logging.WARNING) # OpenAI client retries
|
||||
logging.getLogger("rex-deploy").setLevel(logging.WARNING) # Swerex deployment
|
||||
logging.getLogger("rex_image_builder").setLevel(logging.WARNING) # Image builds
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("Starting Terminal-Bench 2.0 Evaluation")
|
||||
print(f"{'='*60}")
|
||||
print(f" Dataset: {self.config.dataset_name}")
|
||||
print(f" Total tasks: {len(self.all_eval_items)}")
|
||||
print(f" Max agent turns: {self.config.max_agent_turns}")
|
||||
print(f" Task timeout: {self.config.task_timeout}s")
|
||||
print(f" Terminal backend: {self.config.terminal_backend}")
|
||||
print(f" Tool thread pool: {self.config.tool_pool_size}")
|
||||
print(f" Terminal timeout: {self.config.terminal_timeout}s/cmd")
|
||||
print(f" Terminal lifetime: {self.config.terminal_lifetime}s (auto: task_timeout + 120)")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Fire all tasks with wall-clock timeout, track live accuracy on the bar
|
||||
total_tasks = len(self.all_eval_items)
|
||||
eval_tasks = [
|
||||
asyncio.ensure_future(self._eval_with_timeout(item))
|
||||
for item in self.all_eval_items
|
||||
]
|
||||
|
||||
results = []
|
||||
passed_count = 0
|
||||
pbar = tqdm(total=total_tasks, desc="Evaluating TB2", dynamic_ncols=True)
|
||||
try:
|
||||
for coro in asyncio.as_completed(eval_tasks):
|
||||
result = await coro
|
||||
results.append(result)
|
||||
if result and result.get("passed"):
|
||||
passed_count += 1
|
||||
done = len(results)
|
||||
pct = (passed_count / done * 100) if done else 0
|
||||
pbar.set_postfix_str(f"pass={passed_count}/{done} ({pct:.1f}%)")
|
||||
pbar.update(1)
|
||||
except (KeyboardInterrupt, asyncio.CancelledError):
|
||||
pbar.close()
|
||||
print(f"\n\nInterrupted! Cleaning up {len(eval_tasks)} tasks...")
|
||||
# Cancel all pending tasks
|
||||
for task in eval_tasks:
|
||||
task.cancel()
|
||||
# Let cancellations propagate (finally blocks run cleanup_vm)
|
||||
await asyncio.gather(*eval_tasks, return_exceptions=True)
|
||||
# Belt-and-suspenders: clean up any remaining sandboxes
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
cleanup_all_environments()
|
||||
print("All sandboxes cleaned up.")
|
||||
return
|
||||
finally:
|
||||
pbar.close()
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
# Filter out None results (shouldn't happen, but be safe)
|
||||
valid_results = [r for r in results if r is not None]
|
||||
|
||||
if not valid_results:
|
||||
print("Warning: No valid evaluation results obtained")
|
||||
return
|
||||
|
||||
# ---- Compute metrics ----
|
||||
total = len(valid_results)
|
||||
passed = sum(1 for r in valid_results if r.get("passed"))
|
||||
overall_pass_rate = passed / total if total > 0 else 0.0
|
||||
|
||||
# Per-category breakdown
|
||||
cat_results: Dict[str, List[Dict]] = defaultdict(list)
|
||||
for r in valid_results:
|
||||
cat_results[r.get("category", "unknown")].append(r)
|
||||
|
||||
# Build metrics dict
|
||||
eval_metrics = {
|
||||
"eval/pass_rate": overall_pass_rate,
|
||||
"eval/total_tasks": total,
|
||||
"eval/passed_tasks": passed,
|
||||
"eval/evaluation_time_seconds": end_time - start_time,
|
||||
}
|
||||
|
||||
# Per-category metrics
|
||||
for category, cat_items in sorted(cat_results.items()):
|
||||
cat_passed = sum(1 for r in cat_items if r.get("passed"))
|
||||
cat_total = len(cat_items)
|
||||
cat_pass_rate = cat_passed / cat_total if cat_total > 0 else 0.0
|
||||
cat_key = category.replace(" ", "_").replace("-", "_").lower()
|
||||
eval_metrics[f"eval/pass_rate_{cat_key}"] = cat_pass_rate
|
||||
|
||||
# Store metrics for wandb_log
|
||||
self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]
|
||||
|
||||
# ---- Print summary ----
|
||||
print(f"\n{'='*60}")
|
||||
print("Terminal-Bench 2.0 Evaluation Results")
|
||||
print(f"{'='*60}")
|
||||
print(f"Overall Pass Rate: {overall_pass_rate:.4f} ({passed}/{total})")
|
||||
print(f"Evaluation Time: {end_time - start_time:.1f} seconds")
|
||||
|
||||
print("\nCategory Breakdown:")
|
||||
for category, cat_items in sorted(cat_results.items()):
|
||||
cat_passed = sum(1 for r in cat_items if r.get("passed"))
|
||||
cat_total = len(cat_items)
|
||||
cat_rate = cat_passed / cat_total if cat_total > 0 else 0.0
|
||||
print(f" {category}: {cat_rate:.1%} ({cat_passed}/{cat_total})")
|
||||
|
||||
# Print individual task results
|
||||
print("\nTask Results:")
|
||||
for r in sorted(valid_results, key=lambda x: x.get("task_name", "")):
|
||||
status = "PASS" if r.get("passed") else "FAIL"
|
||||
turns = r.get("turns_used", "?")
|
||||
error = r.get("error", "")
|
||||
extra = f" (error: {error})" if error else ""
|
||||
print(f" [{status}] {r['task_name']} (turns={turns}){extra}")
|
||||
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Build sample records for evaluate_log (includes full conversations)
|
||||
samples = [
|
||||
{
|
||||
"task_name": r.get("task_name"),
|
||||
"category": r.get("category"),
|
||||
"passed": r.get("passed"),
|
||||
"reward": r.get("reward"),
|
||||
"turns_used": r.get("turns_used"),
|
||||
"error": r.get("error"),
|
||||
"messages": r.get("messages"),
|
||||
}
|
||||
for r in valid_results
|
||||
]
|
||||
|
||||
# Log evaluation results
|
||||
try:
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
generation_parameters={
|
||||
"temperature": self.config.agent_temperature,
|
||||
"max_tokens": self.config.max_token_length,
|
||||
"max_agent_turns": self.config.max_agent_turns,
|
||||
"terminal_backend": self.config.terminal_backend,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error logging evaluation results: {e}")
|
||||
|
||||
# Close streaming file
|
||||
if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
|
||||
self._streaming_file.close()
|
||||
print(f" Live results saved to: {self._streaming_path}")
|
||||
|
||||
# Kill all remaining sandboxes. Timed-out tasks leave orphaned thread
|
||||
# pool workers still executing commands -- cleanup_all stops them.
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
print("\nCleaning up all sandboxes...")
|
||||
cleanup_all_environments()
|
||||
|
||||
# Shut down the tool thread pool so orphaned workers from timed-out
|
||||
# tasks are killed immediately instead of retrying against dead
|
||||
# sandboxes and spamming the console with TimeoutError warnings.
|
||||
from environments.agent_loop import _tool_executor
|
||||
_tool_executor.shutdown(wait=False, cancel_futures=True)
|
||||
print("Done.")
|
||||
|
||||
# =========================================================================
|
||||
# Wandb logging
|
||||
# =========================================================================
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log TB2-specific metrics to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
|
||||
# Add stored eval metrics
|
||||
for metric_name, metric_value in self.eval_metrics:
|
||||
wandb_metrics[metric_name] = metric_value
|
||||
self.eval_metrics = []
|
||||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TerminalBench2EvalEnv.cli()
|
||||
@@ -4,8 +4,7 @@
|
||||
# Uses terminal + file + web toolsets.
|
||||
#
|
||||
# Usage:
|
||||
# python environments/hermes_swe_env/hermes_swe_env.py serve \
|
||||
# --config environments/hermes_swe_env/default.yaml
|
||||
# python environments/hermes_swe_env.py serve --config environments/configs/swe_default.yaml
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file", "web"]
|
||||
+3
-2
@@ -6,8 +6,9 @@
|
||||
#
|
||||
# Usage:
|
||||
# run-api
|
||||
# python environments/terminal_test_env/terminal_test_env.py serve \
|
||||
# --config environments/terminal_test_env/default.yaml
|
||||
# python environments/terminal_test_env.py serve
|
||||
# # Or with config file:
|
||||
# python environments/terminal_test_env.py serve --config environments/configs/terminal_test_default.yaml
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file"]
|
||||
@@ -117,18 +117,6 @@ class HermesAgentEnvConfig(BaseEnvConfig):
|
||||
description="Terminal backend: 'local', 'docker', 'modal', 'ssh', 'singularity'. "
|
||||
"Modal recommended for production RL (cloud isolation per rollout).",
|
||||
)
|
||||
terminal_timeout: int = Field(
|
||||
default=120,
|
||||
description="Per-command timeout in seconds for terminal tool calls. "
|
||||
"Commands exceeding this are killed. Increase for tasks with long-running "
|
||||
"commands (compilation, pip install, etc.).",
|
||||
)
|
||||
terminal_lifetime: int = Field(
|
||||
default=3600,
|
||||
description="Sandbox inactivity lifetime in seconds. The cleanup thread kills "
|
||||
"sandboxes that have been idle longer than this. Must be longer than "
|
||||
"the longest gap between tool calls (e.g., waiting for LLM response).",
|
||||
)
|
||||
|
||||
# --- Dataset ---
|
||||
dataset_name: Optional[str] = Field(
|
||||
@@ -144,14 +132,6 @@ class HermesAgentEnvConfig(BaseEnvConfig):
|
||||
description="Which field in the dataset contains the prompt.",
|
||||
)
|
||||
|
||||
# --- Thread pool ---
|
||||
tool_pool_size: int = Field(
|
||||
default=128,
|
||||
description="Thread pool size for tool execution. Each concurrent task needs a "
|
||||
"thread for tool calls. Must be large enough for parallel evaluation. "
|
||||
"Too small = thread pool starvation.",
|
||||
)
|
||||
|
||||
# --- Phase 2: Tool call parsing ---
|
||||
tool_call_parser: str = Field(
|
||||
default="hermes",
|
||||
@@ -160,22 +140,6 @@ class HermesAgentEnvConfig(BaseEnvConfig):
|
||||
"Options: hermes, mistral, llama3_json, qwen, deepseek_v3, etc.",
|
||||
)
|
||||
|
||||
# --- Provider-specific parameters ---
|
||||
# Passed as extra_body to the OpenAI client's chat.completions.create() call.
|
||||
# Useful for OpenRouter provider preferences, transforms, route settings, etc.
|
||||
# Example YAML:
|
||||
# extra_body:
|
||||
# provider:
|
||||
# ignore: ["DeepInfra", "Fireworks"]
|
||||
# order: ["Together"]
|
||||
# transforms: ["middle-out"]
|
||||
extra_body: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Extra body parameters passed to the OpenAI client's "
|
||||
"chat.completions.create(). Used for OpenRouter provider preferences, "
|
||||
"transforms, and other provider-specific settings.",
|
||||
)
|
||||
|
||||
|
||||
class HermesAgentBaseEnv(BaseEnv):
|
||||
"""
|
||||
@@ -211,23 +175,10 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
):
|
||||
super().__init__(config, server_configs, slurm, testing)
|
||||
|
||||
# Set terminal environment variables so hermes tools pick them up.
|
||||
# These can all be overridden per-environment via config fields instead
|
||||
# of requiring users to set shell env vars.
|
||||
# Set terminal backend environment variable so hermes tools pick it up
|
||||
if config.terminal_backend:
|
||||
os.environ["TERMINAL_ENV"] = config.terminal_backend
|
||||
os.environ["TERMINAL_TIMEOUT"] = str(config.terminal_timeout)
|
||||
os.environ["TERMINAL_LIFETIME_SECONDS"] = str(config.terminal_lifetime)
|
||||
print(
|
||||
f"🖥️ Terminal: backend={config.terminal_backend}, "
|
||||
f"timeout={config.terminal_timeout}s, lifetime={config.terminal_lifetime}s"
|
||||
)
|
||||
|
||||
# Resize the agent loop's thread pool for tool execution.
|
||||
# This must be large enough for the number of concurrent tasks
|
||||
# (e.g., 89 parallel TB2 eval tasks each need a thread for tool calls).
|
||||
from environments.agent_loop import resize_tool_pool
|
||||
resize_tool_pool(config.tool_pool_size)
|
||||
print(f"🖥️ Terminal backend: {config.terminal_backend}")
|
||||
|
||||
# Current group's resolved tools (set in collect_trajectories)
|
||||
self._current_group_tools: Optional[Tuple[List[Dict], Set[str]]] = None
|
||||
@@ -478,7 +429,6 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
tokenizer=self.tokenizer,
|
||||
tool_call_parser=tc_parser,
|
||||
) as managed:
|
||||
_max_ctx = self.config.max_token_length if (self.config.max_token_length and self.config.max_token_length > 0) else None
|
||||
agent = HermesAgentLoop(
|
||||
server=managed,
|
||||
tool_schemas=tools,
|
||||
@@ -487,8 +437,6 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
max_context_tokens=_max_ctx,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
except NotImplementedError:
|
||||
@@ -497,7 +445,6 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
"ManagedServer not available (OpenAI server?). "
|
||||
"Falling back to direct server mode."
|
||||
)
|
||||
_max_ctx = self.config.max_token_length if (self.config.max_token_length and self.config.max_token_length > 0) else None
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
@@ -506,13 +453,10 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
max_context_tokens=_max_ctx,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
else:
|
||||
# Phase 1: OpenAI server -- native tool_calls, placeholder tokens
|
||||
_max_ctx = self.config.max_token_length if (self.config.max_token_length and self.config.max_token_length > 0) else None
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
@@ -521,8 +465,6 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
max_context_tokens=_max_ctx,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path for imports
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
+1
-1
@@ -36,7 +36,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
# Ensure repo root is on sys.path for imports
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
||||
_repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
@@ -49,22 +49,15 @@ class HermesToolCallParser(ToolCallParser):
|
||||
continue
|
||||
|
||||
tc_data = json.loads(raw_json)
|
||||
# Handle arguments: could be dict or already a JSON string
|
||||
raw_args = tc_data.get("arguments", {})
|
||||
if isinstance(raw_args, str):
|
||||
# Already a string — pass through as-is.
|
||||
# It may be a JSON string ("{...}") or a plain string ("ls").
|
||||
args_str = raw_args
|
||||
else:
|
||||
# Dict — serialize to JSON
|
||||
args_str = json.dumps(raw_args, ensure_ascii=False)
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc_data["name"],
|
||||
arguments=args_str,
|
||||
arguments=json.dumps(
|
||||
tc_data.get("arguments", {}), ensure_ascii=False
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -129,14 +129,11 @@ class ToolContext:
|
||||
|
||||
def write_file(self, path: str, content: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Write a TEXT file in the rollout's filesystem.
|
||||
|
||||
Uses a shell heredoc under the hood, so this is only safe for text content.
|
||||
For binary files (images, compiled artifacts, etc.), use upload_file() instead.
|
||||
Write a file in the rollout's filesystem.
|
||||
|
||||
Args:
|
||||
path: File path to write
|
||||
content: Text content to write
|
||||
content: Content to write
|
||||
|
||||
Returns:
|
||||
Dict with success status or error
|
||||
@@ -149,177 +146,6 @@ class ToolContext:
|
||||
except json.JSONDecodeError:
|
||||
return {"error": result}
|
||||
|
||||
def upload_file(self, local_path: str, remote_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Upload a local file to the rollout's sandbox (binary-safe).
|
||||
|
||||
Unlike write_file() which passes content through a shell heredoc (text-only),
|
||||
this method base64-encodes the file and decodes it inside the sandbox.
|
||||
Safe for any file type: binaries, images, archives, etc.
|
||||
|
||||
For large files (>1MB), the content is split into chunks to avoid
|
||||
hitting shell command-length limits.
|
||||
|
||||
Args:
|
||||
local_path: Path to a local file on the host
|
||||
remote_path: Destination path inside the sandbox
|
||||
|
||||
Returns:
|
||||
Dict with 'exit_code' and 'output'
|
||||
"""
|
||||
import base64
|
||||
from pathlib import Path as _Path
|
||||
|
||||
local = _Path(local_path)
|
||||
if not local.exists():
|
||||
return {"exit_code": -1, "output": f"Local file not found: {local_path}"}
|
||||
|
||||
raw = local.read_bytes()
|
||||
b64 = base64.b64encode(raw).decode("ascii")
|
||||
|
||||
# Ensure parent directory exists in the sandbox
|
||||
parent = str(_Path(remote_path).parent)
|
||||
if parent not in (".", "/"):
|
||||
self.terminal(f"mkdir -p {parent}", timeout=10)
|
||||
|
||||
# For small files, single command is fine
|
||||
chunk_size = 60_000 # ~60KB per chunk (well within shell limits)
|
||||
if len(b64) <= chunk_size:
|
||||
result = self.terminal(
|
||||
f"printf '%s' '{b64}' | base64 -d > {remote_path}",
|
||||
timeout=30,
|
||||
)
|
||||
else:
|
||||
# For larger files, write base64 in chunks then decode
|
||||
tmp_b64 = "/tmp/_hermes_upload.b64"
|
||||
self.terminal(f": > {tmp_b64}", timeout=5) # truncate
|
||||
for i in range(0, len(b64), chunk_size):
|
||||
chunk = b64[i : i + chunk_size]
|
||||
self.terminal(f"printf '%s' '{chunk}' >> {tmp_b64}", timeout=15)
|
||||
result = self.terminal(
|
||||
f"base64 -d {tmp_b64} > {remote_path} && rm -f {tmp_b64}",
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def upload_dir(self, local_dir: str, remote_dir: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Upload an entire local directory to the rollout's sandbox (binary-safe).
|
||||
|
||||
Recursively uploads all files, preserving directory structure.
|
||||
|
||||
Args:
|
||||
local_dir: Path to a local directory on the host
|
||||
remote_dir: Destination directory inside the sandbox
|
||||
|
||||
Returns:
|
||||
List of results, one per file uploaded
|
||||
"""
|
||||
from pathlib import Path as _Path
|
||||
|
||||
local = _Path(local_dir)
|
||||
if not local.exists() or not local.is_dir():
|
||||
return [{"exit_code": -1, "output": f"Local directory not found: {local_dir}"}]
|
||||
|
||||
results = []
|
||||
for file_path in sorted(local.rglob("*")):
|
||||
if file_path.is_file():
|
||||
relative = file_path.relative_to(local)
|
||||
target = f"{remote_dir}/{relative}"
|
||||
results.append(self.upload_file(str(file_path), target))
|
||||
return results
|
||||
|
||||
def download_file(self, remote_path: str, local_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Download a file from the rollout's sandbox to the host (binary-safe).
|
||||
|
||||
The inverse of upload_file(). Base64-encodes the file inside the sandbox,
|
||||
reads the encoded data through the terminal, and decodes it locally.
|
||||
Safe for any file type.
|
||||
|
||||
Args:
|
||||
remote_path: Path to the file inside the sandbox
|
||||
local_path: Destination path on the host
|
||||
|
||||
Returns:
|
||||
Dict with 'success' (bool) and 'bytes' (int) or 'error' (str)
|
||||
"""
|
||||
import base64
|
||||
from pathlib import Path as _Path
|
||||
|
||||
# Base64-encode the file inside the sandbox and capture output
|
||||
result = self.terminal(
|
||||
f"base64 {remote_path} 2>/dev/null",
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
if result.get("exit_code", -1) != 0:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Failed to read remote file: {result.get('output', '')}",
|
||||
}
|
||||
|
||||
b64_data = result.get("output", "").strip()
|
||||
if not b64_data:
|
||||
return {"success": False, "error": f"Remote file is empty or missing: {remote_path}"}
|
||||
|
||||
try:
|
||||
raw = base64.b64decode(b64_data)
|
||||
except Exception as e:
|
||||
return {"success": False, "error": f"Base64 decode failed: {e}"}
|
||||
|
||||
# Write to local host filesystem
|
||||
local = _Path(local_path)
|
||||
local.parent.mkdir(parents=True, exist_ok=True)
|
||||
local.write_bytes(raw)
|
||||
|
||||
return {"success": True, "bytes": len(raw)}
|
||||
|
||||
def download_dir(self, remote_dir: str, local_dir: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Download a directory from the rollout's sandbox to the host (binary-safe).
|
||||
|
||||
Lists all files in the remote directory, then downloads each one.
|
||||
Preserves directory structure.
|
||||
|
||||
Args:
|
||||
remote_dir: Path to the directory inside the sandbox
|
||||
local_dir: Destination directory on the host
|
||||
|
||||
Returns:
|
||||
List of results, one per file downloaded
|
||||
"""
|
||||
from pathlib import Path as _Path
|
||||
|
||||
# List files in the remote directory
|
||||
ls_result = self.terminal(
|
||||
f"find {remote_dir} -type f 2>/dev/null",
|
||||
timeout=15,
|
||||
)
|
||||
|
||||
if ls_result.get("exit_code", -1) != 0:
|
||||
return [{"success": False, "error": f"Failed to list remote dir: {remote_dir}"}]
|
||||
|
||||
file_list = ls_result.get("output", "").strip()
|
||||
if not file_list:
|
||||
return [{"success": False, "error": f"Remote directory is empty or missing: {remote_dir}"}]
|
||||
|
||||
results = []
|
||||
for remote_file in file_list.splitlines():
|
||||
remote_file = remote_file.strip()
|
||||
if not remote_file:
|
||||
continue
|
||||
# Compute the relative path to preserve directory structure
|
||||
if remote_file.startswith(remote_dir):
|
||||
relative = remote_file[len(remote_dir):].lstrip("/")
|
||||
else:
|
||||
relative = _Path(remote_file).name
|
||||
local_file = str(_Path(local_dir) / relative)
|
||||
results.append(self.download_file(remote_file, local_file))
|
||||
|
||||
return results
|
||||
|
||||
def search(self, query: str, path: str = ".") -> Dict[str, Any]:
|
||||
"""
|
||||
Search for text in the rollout's filesystem.
|
||||
|
||||
+15
-166
@@ -6,11 +6,10 @@ and implement the required methods.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any, Callable, Awaitable, Tuple
|
||||
from typing import Dict, List, Optional, Any, Callable, Awaitable
|
||||
from enum import Enum
|
||||
|
||||
import sys
|
||||
@@ -178,123 +177,6 @@ class BasePlatformAdapter(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send an image natively via the platform API.
|
||||
|
||||
Override in subclasses to send images as proper attachments
|
||||
instead of plain-text URLs. Default falls back to sending the
|
||||
URL as a text message.
|
||||
"""
|
||||
# Fallback: send URL as text (subclasses override for native images)
|
||||
text = f"{caption}\n{image_url}" if caption else image_url
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
||||
@staticmethod
|
||||
def extract_images(content: str) -> Tuple[List[Tuple[str, str]], str]:
|
||||
"""
|
||||
Extract image URLs from markdown and HTML image tags in a response.
|
||||
|
||||
Finds patterns like:
|
||||
- 
|
||||
- <img src="https://example.com/image.png">
|
||||
- <img src="https://example.com/image.png"></img>
|
||||
|
||||
Args:
|
||||
content: The response text to scan.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of (url, alt_text) pairs, cleaned content with image tags removed).
|
||||
"""
|
||||
images = []
|
||||
cleaned = content
|
||||
|
||||
# Match markdown images: 
|
||||
md_pattern = r'!\[([^\]]*)\]\((https?://[^\s\)]+)\)'
|
||||
for match in re.finditer(md_pattern, content):
|
||||
alt_text = match.group(1)
|
||||
url = match.group(2)
|
||||
# Only extract URLs that look like actual images
|
||||
if any(url.lower().endswith(ext) or ext in url.lower() for ext in
|
||||
['.png', '.jpg', '.jpeg', '.gif', '.webp', 'fal.media', 'fal-cdn', 'replicate.delivery']):
|
||||
images.append((url, alt_text))
|
||||
|
||||
# Match HTML img tags: <img src="url"> or <img src="url"></img> or <img src="url"/>
|
||||
html_pattern = r'<img\s+src=["\']?(https?://[^\s"\'<>]+)["\']?\s*/?>\s*(?:</img>)?'
|
||||
for match in re.finditer(html_pattern, content):
|
||||
url = match.group(1)
|
||||
images.append((url, ""))
|
||||
|
||||
# Remove matched image tags from content if we found images
|
||||
if images:
|
||||
cleaned = re.sub(md_pattern, '', cleaned)
|
||||
cleaned = re.sub(html_pattern, '', cleaned)
|
||||
# Clean up leftover blank lines
|
||||
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip()
|
||||
|
||||
return images, cleaned
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send an audio file as a native voice message via the platform API.
|
||||
|
||||
Override in subclasses to send audio as voice bubbles (Telegram)
|
||||
or file attachments (Discord). Default falls back to sending the
|
||||
file path as text.
|
||||
"""
|
||||
text = f"🔊 Audio: {audio_path}"
|
||||
if caption:
|
||||
text = f"{caption}\n{text}"
|
||||
return await self.send(chat_id=chat_id, content=text, reply_to=reply_to)
|
||||
|
||||
@staticmethod
|
||||
def extract_media(content: str) -> Tuple[List[Tuple[str, bool]], str]:
|
||||
"""
|
||||
Extract MEDIA:<path> tags and [[audio_as_voice]] directives from response text.
|
||||
|
||||
The TTS tool returns responses like:
|
||||
[[audio_as_voice]]
|
||||
MEDIA:/path/to/audio.ogg
|
||||
|
||||
Args:
|
||||
content: The response text to scan.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of (path, is_voice) pairs, cleaned content with tags removed).
|
||||
"""
|
||||
media = []
|
||||
cleaned = content
|
||||
|
||||
# Check for [[audio_as_voice]] directive
|
||||
has_voice_tag = "[[audio_as_voice]]" in content
|
||||
cleaned = cleaned.replace("[[audio_as_voice]]", "")
|
||||
|
||||
# Extract MEDIA:<path> tags (path may contain spaces)
|
||||
media_pattern = r'MEDIA:(\S+)'
|
||||
for match in re.finditer(media_pattern, content):
|
||||
path = match.group(1).strip()
|
||||
if path:
|
||||
media.append((path, has_voice_tag))
|
||||
|
||||
# Remove MEDIA tags from content
|
||||
if media:
|
||||
cleaned = re.sub(media_pattern, '', cleaned)
|
||||
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned).strip()
|
||||
|
||||
return media, cleaned
|
||||
|
||||
async def _keep_typing(self, chat_id: str, interval: float = 2.0) -> None:
|
||||
"""
|
||||
Continuously send typing indicator until cancelled.
|
||||
@@ -349,56 +231,23 @@ class BasePlatformAdapter(ABC):
|
||||
|
||||
# Send response if any
|
||||
if response:
|
||||
# Extract MEDIA:<path> tags (from TTS tool) before other processing
|
||||
media_files, response = self.extract_media(response)
|
||||
result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=response,
|
||||
reply_to=event.message_id
|
||||
)
|
||||
|
||||
# Extract image URLs and send them as native platform attachments
|
||||
images, text_content = self.extract_images(response)
|
||||
|
||||
# Send the text portion first (if any remains after extractions)
|
||||
if text_content:
|
||||
result = await self.send(
|
||||
# Log send failures (don't raise - user already saw tool progress)
|
||||
if not result.success:
|
||||
print(f"[{self.name}] Failed to send response: {result.error}")
|
||||
# Try sending without markdown as fallback
|
||||
fallback_result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=text_content,
|
||||
content=f"(Response formatting failed, plain text:)\n\n{response[:3500]}",
|
||||
reply_to=event.message_id
|
||||
)
|
||||
|
||||
# Log send failures (don't raise - user already saw tool progress)
|
||||
if not result.success:
|
||||
print(f"[{self.name}] Failed to send response: {result.error}")
|
||||
# Try sending without markdown as fallback
|
||||
fallback_result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=f"(Response formatting failed, plain text:)\n\n{text_content[:3500]}",
|
||||
reply_to=event.message_id
|
||||
)
|
||||
if not fallback_result.success:
|
||||
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
||||
|
||||
# Send extracted images as native attachments
|
||||
for image_url, alt_text in images:
|
||||
try:
|
||||
img_result = await self.send_image(
|
||||
chat_id=event.source.chat_id,
|
||||
image_url=image_url,
|
||||
caption=alt_text if alt_text else None,
|
||||
)
|
||||
if not img_result.success:
|
||||
print(f"[{self.name}] Failed to send image: {img_result.error}")
|
||||
except Exception as img_err:
|
||||
print(f"[{self.name}] Error sending image: {img_err}")
|
||||
|
||||
# Send extracted audio/voice files as native attachments
|
||||
for audio_path, is_voice in media_files:
|
||||
try:
|
||||
voice_result = await self.send_voice(
|
||||
chat_id=event.source.chat_id,
|
||||
audio_path=audio_path,
|
||||
)
|
||||
if not voice_result.success:
|
||||
print(f"[{self.name}] Failed to send voice: {voice_result.error}")
|
||||
except Exception as voice_err:
|
||||
print(f"[{self.name}] Error sending voice: {voice_err}")
|
||||
if not fallback_result.success:
|
||||
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
||||
|
||||
# Check if there's a pending message that was queued during our processing
|
||||
if session_key in self._pending_messages:
|
||||
@@ -437,7 +286,7 @@ class BasePlatformAdapter(ABC):
|
||||
|
||||
def get_pending_message(self, session_key: str) -> Optional[MessageEvent]:
|
||||
"""Get and clear any pending message for a session."""
|
||||
return self._pending_messages.pop(session_key, None)
|
||||
return self._pending_messages.get(session_key)
|
||||
|
||||
def build_source(
|
||||
self,
|
||||
|
||||
@@ -8,7 +8,6 @@ Uses discord.py library for:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
try:
|
||||
@@ -174,99 +173,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send audio as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import io
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
if not os.path.exists(audio_path):
|
||||
return SendResult(success=False, error=f"Audio file not found: {audio_path}")
|
||||
|
||||
# Determine filename from path
|
||||
filename = os.path.basename(audio_path)
|
||||
|
||||
with open(audio_path, "rb") as f:
|
||||
file = discord.File(io.BytesIO(f.read()), filename=filename)
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send audio: {e}")
|
||||
return await super().send_voice(chat_id, audio_path, caption, reply_to)
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send an image natively as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
# Download the image and send as a Discord file attachment
|
||||
# (Discord renders attachments inline, unlike plain URLs)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"Failed to download image: HTTP {resp.status}")
|
||||
|
||||
image_data = await resp.read()
|
||||
|
||||
# Determine filename from URL or content type
|
||||
content_type = resp.headers.get("content-type", "image/png")
|
||||
ext = "png"
|
||||
if "jpeg" in content_type or "jpg" in content_type:
|
||||
ext = "jpg"
|
||||
elif "gif" in content_type:
|
||||
ext = "gif"
|
||||
elif "webp" in content_type:
|
||||
ext = "webp"
|
||||
|
||||
import io
|
||||
file = discord.File(io.BytesIO(image_data), filename=f"image.{ext}")
|
||||
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
except ImportError:
|
||||
print(f"[{self.name}] aiohttp not installed, falling back to URL. Run: pip install aiohttp")
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send image attachment, falling back to URL: {e}")
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Send typing indicator."""
|
||||
if self._client:
|
||||
@@ -326,36 +232,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
|
||||
async def _handle_message(self, message: DiscordMessage) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
# In server channels (not DMs), require the bot to be @mentioned
|
||||
# UNLESS the channel is in the free-response list.
|
||||
#
|
||||
# Config:
|
||||
# DISCORD_FREE_RESPONSE_CHANNELS: Comma-separated channel IDs where the
|
||||
# bot responds to every message without needing a mention.
|
||||
# DISCORD_REQUIRE_MENTION: Set to "false" to disable mention requirement
|
||||
# globally (all channels become free-response). Default: "true".
|
||||
|
||||
if not isinstance(message.channel, discord.DMChannel):
|
||||
# Check if this channel is in the free-response list
|
||||
free_channels_raw = os.getenv("DISCORD_FREE_RESPONSE_CHANNELS", "")
|
||||
free_channels = {ch.strip() for ch in free_channels_raw.split(",") if ch.strip()}
|
||||
channel_id = str(message.channel.id)
|
||||
|
||||
# Global override: if DISCORD_REQUIRE_MENTION=false, all channels are free
|
||||
require_mention = os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
|
||||
|
||||
is_free_channel = channel_id in free_channels
|
||||
|
||||
if require_mention and not is_free_channel:
|
||||
# Must be @mentioned to respond
|
||||
if self._client.user not in message.mentions:
|
||||
return # Silently ignore messages that don't mention the bot
|
||||
|
||||
# Strip the bot mention from the message text so the agent sees clean input
|
||||
if self._client.user and self._client.user in message.mentions:
|
||||
message.content = message.content.replace(f"<@{self._client.user.id}>", "").strip()
|
||||
message.content = message.content.replace(f"<@!{self._client.user.id}>", "").strip()
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
if message.content.startswith("/"):
|
||||
|
||||
@@ -174,69 +174,6 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
audio_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send audio as a native Telegram voice message or audio file."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import os
|
||||
if not os.path.exists(audio_path):
|
||||
return SendResult(success=False, error=f"Audio file not found: {audio_path}")
|
||||
|
||||
with open(audio_path, "rb") as audio_file:
|
||||
# .ogg files -> send as voice (round playable bubble)
|
||||
if audio_path.endswith(".ogg") or audio_path.endswith(".opus"):
|
||||
msg = await self._bot.send_voice(
|
||||
chat_id=int(chat_id),
|
||||
voice=audio_file,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
)
|
||||
else:
|
||||
# .mp3 and others -> send as audio file
|
||||
msg = await self._bot.send_audio(
|
||||
chat_id=int(chat_id),
|
||||
audio=audio_file,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send voice/audio: {e}")
|
||||
return await super().send_voice(chat_id, audio_path, caption, reply_to)
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send an image natively as a Telegram photo."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
# Telegram can send photos directly from URLs
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
photo=image_url,
|
||||
caption=caption[:1024] if caption else None, # Telegram caption limit
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send photo, falling back to URL: {e}")
|
||||
# Fallback: send as text link
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Send typing indicator."""
|
||||
if self._bot:
|
||||
|
||||
+18
-103
@@ -35,9 +35,6 @@ load_dotenv()
|
||||
# Gateway runs in quiet mode - suppress debug output and use cwd directly (no temp dirs)
|
||||
os.environ["HERMES_QUIET"] = "1"
|
||||
|
||||
# Enable interactive exec approval for dangerous commands on messaging platforms
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
|
||||
# Set terminal working directory for messaging platforms
|
||||
# Uses MESSAGING_CWD if set, otherwise defaults to home directory
|
||||
# This is separate from CLI which uses the directory where `hermes` is run
|
||||
@@ -80,10 +77,6 @@ class GatewayRunner:
|
||||
# Key: session_key, Value: AIAgent instance
|
||||
self._running_agents: Dict[str, Any] = {}
|
||||
self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt
|
||||
|
||||
# Track pending exec approvals per session
|
||||
# Key: session_key, Value: {"command": str, "pattern_key": str}
|
||||
self._pending_approvals: Dict[str, Dict[str, str]] = {}
|
||||
|
||||
async def start(self) -> bool:
|
||||
"""
|
||||
@@ -253,25 +246,6 @@ class GatewayRunner:
|
||||
if command == "stop":
|
||||
return await self._handle_stop_command(event)
|
||||
|
||||
# Check for pending exec approval responses
|
||||
session_key_preview = f"agent:main:{source.platform.value}:{source.chat_type}:{source.chat_id}" if source.chat_type != "dm" else f"agent:main:{source.platform.value}:dm"
|
||||
if session_key_preview in self._pending_approvals:
|
||||
user_text = event.text.strip().lower()
|
||||
if user_text in ("yes", "y", "approve", "ok", "go", "do it"):
|
||||
approval = self._pending_approvals.pop(session_key_preview)
|
||||
cmd = approval["command"]
|
||||
pattern_key = approval.get("pattern_key", "")
|
||||
print(f"[gateway] ✅ User approved dangerous command: {cmd[:60]}...")
|
||||
# Approve for session and re-run via terminal_tool with force=True
|
||||
from tools.terminal_tool import terminal_tool, _session_approved_patterns
|
||||
_session_approved_patterns.add(pattern_key)
|
||||
result = terminal_tool(command=cmd, force=True)
|
||||
return f"✅ Command approved and executed.\n\n```\n{result[:3500]}\n```"
|
||||
elif user_text in ("no", "n", "deny", "cancel", "nope"):
|
||||
self._pending_approvals.pop(session_key_preview)
|
||||
return "❌ Command denied."
|
||||
# If it's not clearly an approval/denial, fall through to normal processing
|
||||
|
||||
# Get or create session
|
||||
session_entry = self.session_store.get_or_create_session(source)
|
||||
session_key = session_entry.session_key
|
||||
@@ -308,17 +282,6 @@ class GatewayRunner:
|
||||
session_key=session_key
|
||||
)
|
||||
|
||||
# Check if the agent encountered a dangerous command needing approval
|
||||
# The terminal tool stores the last pending approval globally
|
||||
try:
|
||||
from tools.terminal_tool import _last_pending_approval
|
||||
if _last_pending_approval:
|
||||
self._pending_approvals[session_key] = _last_pending_approval.copy()
|
||||
# Clear the global so it doesn't leak to other sessions
|
||||
_last_pending_approval.clear()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Append to transcript
|
||||
self.session_store.append_to_transcript(
|
||||
session_entry.session_id,
|
||||
@@ -455,35 +418,23 @@ class GatewayRunner:
|
||||
return
|
||||
last_tool[0] = tool_name
|
||||
|
||||
# Build progress message with primary argument preview
|
||||
# Build progress message
|
||||
tool_emojis = {
|
||||
"terminal": "💻",
|
||||
"web_search": "🔍",
|
||||
"web_extract": "📄",
|
||||
"read_file": "📖",
|
||||
"write_file": "✍️",
|
||||
"patch": "🔧",
|
||||
"search": "🔎",
|
||||
"list_directory": "📂",
|
||||
"image_generate": "🎨",
|
||||
"text_to_speech": "🔊",
|
||||
"browser_navigate": "🌐",
|
||||
"browser_click": "👆",
|
||||
"browser_type": "⌨️",
|
||||
"browser_snapshot": "📸",
|
||||
"moa_query": "🧠",
|
||||
"mixture_of_agents": "🧠",
|
||||
"vision_analyze": "👁️",
|
||||
"skill_view": "📚",
|
||||
"skills_list": "📋",
|
||||
}
|
||||
emoji = tool_emojis.get(tool_name, "⚙️")
|
||||
|
||||
if preview:
|
||||
# Truncate preview to keep messages clean
|
||||
if len(preview) > 40:
|
||||
preview = preview[:37] + "..."
|
||||
msg = f"{emoji} {tool_name}... \"{preview}\""
|
||||
if tool_name == "terminal" and preview:
|
||||
msg = f"{emoji} `{preview}`..."
|
||||
else:
|
||||
msg = f"{emoji} {tool_name}..."
|
||||
|
||||
@@ -529,10 +480,6 @@ class GatewayRunner:
|
||||
# Read from env var or use default (same as CLI)
|
||||
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "60"))
|
||||
|
||||
# Map platform enum to the platform hint key the agent understands.
|
||||
# Platform.LOCAL ("local") maps to "cli"; others pass through as-is.
|
||||
platform_key = "cli" if source.platform == Platform.LOCAL else source.platform.value
|
||||
|
||||
agent = AIAgent(
|
||||
model=os.getenv("HERMES_MODEL", "anthropic/claude-opus-4.6"),
|
||||
max_iterations=max_iterations,
|
||||
@@ -541,42 +488,19 @@ class GatewayRunner:
|
||||
ephemeral_system_prompt=context_prompt,
|
||||
session_id=session_id,
|
||||
tool_progress_callback=progress_callback if tool_progress_enabled else None,
|
||||
platform=platform_key, # Tells the agent which interface to format for
|
||||
)
|
||||
|
||||
# Store agent reference for interrupt support
|
||||
agent_holder[0] = agent
|
||||
|
||||
# Convert history to agent format.
|
||||
# Two cases:
|
||||
# 1. Normal path (from transcript): simple {role, content, timestamp} dicts
|
||||
# - Strip timestamps, keep role+content
|
||||
# 2. Interrupt path (from agent result["messages"]): full agent messages
|
||||
# that may include tool_calls, tool_call_id, reasoning, etc.
|
||||
# - These must be passed through intact so the API sees valid
|
||||
# assistant→tool sequences (dropping tool_calls causes 500 errors)
|
||||
# Convert transcript history to agent format
|
||||
# Transcript has timestamps; agent expects {"role": ..., "content": ...}
|
||||
agent_history = []
|
||||
for msg in history:
|
||||
role = msg.get("role")
|
||||
if not role:
|
||||
continue
|
||||
|
||||
# Check if this is a rich agent message (has tool_calls or tool_call_id)
|
||||
# If so, pass it through with full structure intact
|
||||
has_tool_calls = "tool_calls" in msg
|
||||
has_tool_call_id = "tool_call_id" in msg
|
||||
is_tool_message = role == "tool"
|
||||
|
||||
if has_tool_calls or has_tool_call_id or is_tool_message:
|
||||
# Preserve full message structure (tool_calls, tool_call_id, etc.)
|
||||
# Only strip fields that are purely internal (e.g. timestamp)
|
||||
clean_msg = {k: v for k, v in msg.items() if k != "timestamp"}
|
||||
agent_history.append(clean_msg)
|
||||
else:
|
||||
# Simple text message - just need role and content
|
||||
content = msg.get("content")
|
||||
if content:
|
||||
agent_history.append({"role": role, "content": content})
|
||||
content = msg.get("content")
|
||||
if role and content:
|
||||
agent_history.append({"role": role, "content": content})
|
||||
|
||||
result = agent.run_conversation(message, conversation_history=agent_history)
|
||||
result_holder[0] = result
|
||||
@@ -648,16 +572,13 @@ class GatewayRunner:
|
||||
|
||||
if pending:
|
||||
print(f"[gateway] 📨 Processing interrupted message: '{pending[:40]}...'")
|
||||
# Add an indicator to the response
|
||||
if response:
|
||||
response = response + "\n\n---\n_[Interrupted - processing your new message]_"
|
||||
|
||||
# Clear the adapter's interrupt event so the next _run_agent call
|
||||
# doesn't immediately re-trigger the interrupt before the new agent
|
||||
# even makes its first API call (this was causing an infinite loop).
|
||||
if adapter and hasattr(adapter, '_active_sessions') and source.chat_id in adapter._active_sessions:
|
||||
adapter._active_sessions[source.chat_id].clear()
|
||||
|
||||
# Don't send the interrupted response to the user — it's just noise
|
||||
# like "Operation interrupted." They already know they sent a new
|
||||
# message, so go straight to processing it.
|
||||
# Send the interrupted response first
|
||||
if adapter and response:
|
||||
await adapter.send(chat_id=source.chat_id, content=response)
|
||||
|
||||
# Now process the pending message with updated history
|
||||
updated_history = result.get("messages", history)
|
||||
@@ -691,13 +612,11 @@ class GatewayRunner:
|
||||
return response
|
||||
|
||||
|
||||
async def start_gateway(config: Optional[GatewayConfig] = None) -> bool:
|
||||
async def start_gateway(config: Optional[GatewayConfig] = None) -> None:
|
||||
"""
|
||||
Start the gateway and run until interrupted.
|
||||
|
||||
This is the main entry point for running the gateway.
|
||||
Returns True if the gateway ran successfully, False if it failed to start.
|
||||
A False return causes a non-zero exit code so systemd can auto-restart.
|
||||
"""
|
||||
runner = GatewayRunner(config)
|
||||
|
||||
@@ -716,11 +635,10 @@ async def start_gateway(config: Optional[GatewayConfig] = None) -> bool:
|
||||
# Start the gateway
|
||||
success = await runner.start()
|
||||
if not success:
|
||||
return False
|
||||
return
|
||||
|
||||
# Wait for shutdown
|
||||
await runner.wait_for_shutdown()
|
||||
return True
|
||||
|
||||
|
||||
def main():
|
||||
@@ -740,11 +658,8 @@ def main():
|
||||
data = json.load(f)
|
||||
config = GatewayConfig.from_dict(data)
|
||||
|
||||
# Run the gateway - exit with code 1 if no platforms connected,
|
||||
# so systemd Restart=on-failure will retry on transient errors (e.g. DNS)
|
||||
success = asyncio.run(start_gateway(config))
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
# Run the gateway
|
||||
asyncio.run(start_gateway(config))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -7,6 +7,40 @@ Usage: ./hermes [options]
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
Fire (google/python-fire) does not support POSIX-style short flags like `-p`.
|
||||
We translate the most common shorthands to their long equivalents so wrapper
|
||||
scripts can reliably use:
|
||||
- `-p "..."` -> `--prompt "..."` (no TUI/banner; print result and exit)
|
||||
- `-q "..."` -> `--query "..."` (single-shot with banner UX)
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
def _rewrite_short_flags(argv: list[str]) -> list[str]:
|
||||
rewritten: list[str] = []
|
||||
i = 0
|
||||
while i < len(argv):
|
||||
arg = argv[i]
|
||||
if arg == "-p":
|
||||
rewritten.append("--prompt")
|
||||
if i + 1 < len(argv):
|
||||
rewritten.append(argv[i + 1])
|
||||
i += 2
|
||||
continue
|
||||
if arg == "-q":
|
||||
rewritten.append("--query")
|
||||
if i + 1 < len(argv):
|
||||
rewritten.append(argv[i + 1])
|
||||
i += 2
|
||||
continue
|
||||
rewritten.append(arg)
|
||||
i += 1
|
||||
return rewritten
|
||||
|
||||
sys.argv = [sys.argv[0]] + _rewrite_short_flags(sys.argv[1:])
|
||||
|
||||
from cli import main
|
||||
import fire
|
||||
|
||||
fire.Fire(main)
|
||||
|
||||
@@ -0,0 +1,659 @@
|
||||
Metadata-Version: 2.4
|
||||
Name: hermes-agent
|
||||
Version: 0.1.0
|
||||
Summary: AI agent with advanced tool-calling and toolsets
|
||||
Author: Nous Research
|
||||
License: MIT
|
||||
Requires-Python: >=3.10
|
||||
Description-Content-Type: text/markdown
|
||||
Requires-Dist: openai
|
||||
Requires-Dist: python-dotenv
|
||||
Requires-Dist: fire
|
||||
Requires-Dist: httpx
|
||||
Requires-Dist: rich
|
||||
Requires-Dist: tenacity
|
||||
Requires-Dist: pyyaml
|
||||
Requires-Dist: prompt_toolkit
|
||||
Requires-Dist: requests
|
||||
Requires-Dist: jinja2
|
||||
Requires-Dist: pydantic>=2.0
|
||||
Requires-Dist: firecrawl-py
|
||||
Requires-Dist: fal-client
|
||||
Requires-Dist: litellm>=1.75.5
|
||||
Requires-Dist: typer
|
||||
Requires-Dist: platformdirs
|
||||
Provides-Extra: modal
|
||||
Requires-Dist: modal; extra == "modal"
|
||||
Requires-Dist: boto3; extra == "modal"
|
||||
Provides-Extra: dev
|
||||
Requires-Dist: pytest; extra == "dev"
|
||||
Requires-Dist: pytest-asyncio; extra == "dev"
|
||||
Provides-Extra: atropos
|
||||
Requires-Dist: atroposlib @ git+https://github.com/NousResearch/atropos.git ; extra == "atropos"
|
||||
Requires-Dist: aiohttp; extra == "atropos"
|
||||
Requires-Dist: fastapi; extra == "atropos"
|
||||
Requires-Dist: uvicorn; extra == "atropos"
|
||||
Requires-Dist: pyte; extra == "atropos"
|
||||
|
||||
# Hermes Agent
|
||||
|
||||
An AI agent with advanced tool-calling capabilities, featuring a flexible toolsets system for organizing and managing tools.
|
||||
|
||||
## Features
|
||||
|
||||
- **Interactive CLI**: Beautiful terminal interface with animated feedback, personalities, and session management
|
||||
- **Web Tools**: Search, extract content, and crawl websites
|
||||
- **Terminal Tools**: Execute commands via local, Docker, Singularity, Modal, or SSH backends
|
||||
- **Browser Tools**: Automate web browsers to navigate, click, type, and extract content
|
||||
- **Vision Tools**: Analyze images from URLs
|
||||
- **Reasoning Tools**: Advanced multi-model reasoning (Mixture of Agents)
|
||||
- **Creative Tools**: Generate images from text prompts
|
||||
- **Skills Tools**: On-demand knowledge documents with progressive disclosure
|
||||
- **Toolsets System**: Organize tools into logical groups for different scenarios
|
||||
- **Batch Processing**: Process datasets in parallel with checkpointing and statistics tracking
|
||||
- **Ephemeral System Prompts**: Guide model behavior without polluting training datasets
|
||||
|
||||
## Quick Start (CLI)
|
||||
|
||||
```bash
|
||||
# After setup (see below), just run:
|
||||
./hermes
|
||||
|
||||
# Or with options:
|
||||
./hermes --model "anthropic/claude-sonnet-4" --toolsets "web,terminal"
|
||||
```
|
||||
|
||||
The CLI provides:
|
||||
- Animated spinners during thinking and tool execution
|
||||
- Kawaii-style feedback messages
|
||||
- `/commands` for configuration, history, and session management
|
||||
- Customizable personalities (`/personality kawaii`, `/personality pirate`, etc.)
|
||||
- Persistent configuration via `cli-config.yaml`
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. Clone the Repository
|
||||
```bash
|
||||
# Clone with submodules (recommended)
|
||||
git clone --recurse-submodules https://github.com/NousResearch/Hermes-Agent.git
|
||||
cd Hermes-Agent
|
||||
|
||||
# Or if already cloned without submodules:
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
### 2. Install Dependencies
|
||||
```bash
|
||||
# Create and activate virtual environment (recommended)
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate # On Windows: venv\Scripts\activate
|
||||
|
||||
# Install Python packages
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Install mini-swe-agent for terminal tools
|
||||
pip install -e ./mini-swe-agent
|
||||
|
||||
# Install Node.js dependencies for browser tools (requires Node.js)
|
||||
npm install
|
||||
```
|
||||
|
||||
### 3. Configure Environment Variables
|
||||
```bash
|
||||
# Copy the example environment file
|
||||
cp .env.example .env
|
||||
|
||||
# Edit .env and add your API keys
|
||||
nano .env # or use your preferred editor
|
||||
```
|
||||
|
||||
**Required API Keys:**
|
||||
- `OPENROUTER_API_KEY` - LLM access via OpenRouter (get at: https://openrouter.ai/keys)
|
||||
- `FIRECRAWL_API_KEY` - Web tools (get at: https://firecrawl.dev/)
|
||||
- `NOUS_API_KEY` - Vision & reasoning tools (get at: https://inference-api.nousresearch.com/)
|
||||
- `FAL_KEY` - Image generation (get at: https://fal.ai/)
|
||||
|
||||
**Optional API Keys (for specific features):**
|
||||
- `BROWSERBASE_API_KEY` - Browser automation (get at: https://browserbase.com/)
|
||||
- `BROWSERBASE_PROJECT_ID` - From Browserbase dashboard
|
||||
- `MORPH_API_KEY` - For legacy Hecate terminal backend (get at: https://morph.so/)
|
||||
|
||||
### 4. Configure Terminal Backend
|
||||
|
||||
The terminal tool uses **mini-swe-agent** environments. Configure in `.env` or `cli-config.yaml`:
|
||||
|
||||
```bash
|
||||
# Backend: "local", "docker", "singularity", "modal", or "ssh"
|
||||
TERMINAL_ENV=local # Default: runs on host machine (no isolation)
|
||||
TERMINAL_ENV=ssh # Remote execution via SSH (agent code stays local)
|
||||
TERMINAL_ENV=singularity # Recommended for HPC: Apptainer/Singularity containers
|
||||
TERMINAL_ENV=docker # Isolated Docker containers
|
||||
TERMINAL_ENV=modal # Cloud execution via Modal
|
||||
|
||||
# Container image (for docker/singularity/modal backends)
|
||||
TERMINAL_DOCKER_IMAGE=python:3.11-slim
|
||||
TERMINAL_SINGULARITY_IMAGE=docker://python:3.11-slim
|
||||
TERMINAL_TIMEOUT=60
|
||||
|
||||
# SSH backend (for ssh)
|
||||
TERMINAL_SSH_HOST=my-server.example.com
|
||||
TERMINAL_SSH_USER=myuser
|
||||
TERMINAL_SSH_KEY=~/.ssh/id_rsa # Optional, uses ssh-agent if not set
|
||||
```
|
||||
|
||||
**Backend Requirements:**
|
||||
- **local**: No extra setup (runs directly on your machine, no isolation)
|
||||
- **ssh**: SSH access to remote machine (great for sandboxing - agent can't touch its own code)
|
||||
- **singularity**: Requires Apptainer or Singularity installed (common on HPC clusters, no root needed)
|
||||
- **docker**: Requires Docker installed and user in `docker` group
|
||||
- **modal**: Requires Modal account (see setup below)
|
||||
|
||||
### Singularity/Apptainer Setup (Recommended for HPC)
|
||||
|
||||
Singularity/Apptainer provides rootless container execution, ideal for HPC clusters:
|
||||
|
||||
```bash
|
||||
# 1. Verify Apptainer is installed
|
||||
apptainer --version # or: singularity --version
|
||||
|
||||
# 2. Set up cache directories (important for parallel workers)
|
||||
# Use /scratch if available (HPC), otherwise /tmp
|
||||
export APPTAINER_CACHEDIR=/scratch/$USER/.apptainer
|
||||
export APPTAINER_TMPDIR=/scratch/$USER/.apptainer/tmp
|
||||
mkdir -p "$APPTAINER_CACHEDIR" "$APPTAINER_TMPDIR"
|
||||
|
||||
# 3. Pre-build SIF image (recommended for parallel batch processing)
|
||||
# This avoids race conditions when multiple workers start simultaneously
|
||||
apptainer build $APPTAINER_CACHEDIR/python-nodejs.sif docker://nikolaik/python-nodejs:python3.11-nodejs20
|
||||
|
||||
# 4. Configure .env to use the local SIF
|
||||
TERMINAL_ENV=singularity
|
||||
TERMINAL_SINGULARITY_IMAGE=/scratch/$USER/.apptainer/python-nodejs.sif
|
||||
```
|
||||
|
||||
**Tip:** The batch scripts in `configs/` automatically handle SIF pre-building if `/scratch` is available.
|
||||
|
||||
### Modal Cloud Backend Setup
|
||||
|
||||
[Modal](https://modal.com) provides serverless cloud compute for running sandboxed environments at scale.
|
||||
|
||||
```bash
|
||||
# 1. Install Modal and dependencies
|
||||
pip install modal boto3
|
||||
|
||||
# 2. Authenticate with Modal (opens browser)
|
||||
modal setup
|
||||
|
||||
# 3. Set terminal backend to modal in .env
|
||||
TERMINAL_ENV=modal
|
||||
```
|
||||
|
||||
Modal uses CLI-based authentication (stored in `~/.modal/`), so no API key is needed in `.env`. After running `modal setup`, commands will automatically execute in Modal's cloud sandboxes.
|
||||
|
||||
### Browser Tools Setup
|
||||
|
||||
Browser tools enable the agent to navigate websites, fill forms, click buttons, and extract content. They use [agent-browser](https://github.com/vercel-labs/agent-browser) CLI with [Browserbase](https://browserbase.com) cloud execution.
|
||||
|
||||
```bash
|
||||
# 1. Install Node.js (if not already installed)
|
||||
# Use nvm (recommended) or your package manager
|
||||
|
||||
# 2. Install agent-browser CLI (choose one option):
|
||||
npm install -g agent-browser # Option A: Global install (recommended)
|
||||
npm install # Option B: Local install (uses npx fallback)
|
||||
|
||||
# 3. Get Browserbase credentials
|
||||
# Sign up at https://browserbase.com/ and get your:
|
||||
# - API Key (from Settings → API Keys)
|
||||
# - Project ID (from your project dashboard)
|
||||
|
||||
# 4. Add to your .env file:
|
||||
BROWSERBASE_API_KEY=your_api_key_here
|
||||
BROWSERBASE_PROJECT_ID=your_project_id_here
|
||||
```
|
||||
|
||||
**Available Browser Tools:**
|
||||
|
||||
| Tool | Description |
|
||||
|------|-------------|
|
||||
| `browser_navigate` | Navigate to a URL |
|
||||
| `browser_snapshot` | Get text-based page snapshot with element refs |
|
||||
| `browser_click` | Click an element by ref (e.g., `@e5`) |
|
||||
| `browser_type` | Type text into an input field |
|
||||
| `browser_scroll` | Scroll up or down |
|
||||
| `browser_back` | Go back in browser history |
|
||||
| `browser_press` | Press a keyboard key (Enter, Tab, etc.) |
|
||||
| `browser_close` | Close the browser session |
|
||||
| `browser_get_images` | Get list of images on the page |
|
||||
|
||||
**Example Usage:**
|
||||
```bash
|
||||
# Use browser tools with web search and vision
|
||||
python run_agent.py \
|
||||
--query "Go to amazon.com and find the price of the latest Kindle" \
|
||||
--enabled_toolsets=browser,web,vision
|
||||
|
||||
# Use browser-focused distribution
|
||||
python batch_runner.py \
|
||||
--dataset_file=browser_tasks.jsonl \
|
||||
--distribution=browser_use \
|
||||
--run_name=browser_run
|
||||
```
|
||||
|
||||
See `.env.example` for all available configuration options including debug settings.
|
||||
|
||||
### Skills Tools
|
||||
|
||||
Skills are on-demand knowledge documents the agent can load when needed. They follow a **progressive disclosure** pattern to minimize token usage:
|
||||
|
||||
```
|
||||
skills/
|
||||
├── mlops/ # Category folder
|
||||
│ ├── axolotl/ # Skill folder
|
||||
│ │ ├── SKILL.md # Main instructions (required)
|
||||
│ │ ├── references/ # Additional docs, API specs
|
||||
│ │ └── templates/ # Output formats, configs
|
||||
│ └── vllm/
|
||||
│ └── SKILL.md
|
||||
```
|
||||
|
||||
**Available Skills Tools:**
|
||||
|
||||
| Tool | Description |
|
||||
|------|-------------|
|
||||
| `skills_categories` | List available skill categories (~50 tokens) |
|
||||
| `skills_list` | List skills with name + description (~3k tokens for 40 skills) |
|
||||
| `skill_view` | Load full skill content, tags, and linked files |
|
||||
|
||||
**Example Usage:**
|
||||
```bash
|
||||
# Use skills tools
|
||||
python run_agent.py \
|
||||
--query "What skills do you have for fine-tuning? Show me the axolotl skill." \
|
||||
--enabled_toolsets=skills
|
||||
```
|
||||
|
||||
**Creating Skills:**
|
||||
|
||||
Skills use YAML frontmatter for metadata:
|
||||
```yaml
|
||||
---
|
||||
name: my-skill
|
||||
description: Brief description shown in skills_list
|
||||
tags: [tag1, tag2]
|
||||
related_skills: [other-skill]
|
||||
version: 1.0.0
|
||||
---
|
||||
# Skill Content
|
||||
|
||||
Instructions, examples, and guidelines here...
|
||||
```
|
||||
|
||||
Skills can include:
|
||||
- `references/` - Additional documentation, API specs, examples
|
||||
- `templates/` - Output formats, config files, boilerplate code
|
||||
- `scripts/` - Executable helpers (Python, shell scripts)
|
||||
|
||||
## Session Logging
|
||||
|
||||
Every conversation is automatically logged to `logs/` for debugging and inspection:
|
||||
|
||||
```
|
||||
logs/
|
||||
├── session_20260201_143052_a1b2c3.json
|
||||
├── session_20260201_150217_d4e5f6.json
|
||||
└── ...
|
||||
```
|
||||
|
||||
**Log Format:**
|
||||
```json
|
||||
{
|
||||
"session_id": "20260201_143052_a1b2c3",
|
||||
"model": "anthropic/claude-sonnet-4",
|
||||
"session_start": "2026-02-01T14:30:52.123456",
|
||||
"last_updated": "2026-02-01T14:35:12.789012",
|
||||
"message_count": 8,
|
||||
"conversations": [
|
||||
{"from": "system", "value": "..."},
|
||||
{"from": "human", "value": "..."},
|
||||
{"from": "gpt", "value": "..."},
|
||||
{"from": "tool", "value": "..."}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
- **Automatic**: Logs are created and updated automatically after each conversation turn
|
||||
- **Session ID in Banner**: The CLI displays the session ID in the welcome banner
|
||||
- **Trajectory Format**: Uses the same format as batch processing for consistency
|
||||
- **Git Ignored**: `logs/` is in `.gitignore` so logs aren't committed
|
||||
|
||||
## Interactive CLI
|
||||
|
||||
The CLI provides a rich interactive experience for working with the agent.
|
||||
|
||||
### Running the CLI
|
||||
|
||||
```bash
|
||||
# Basic usage
|
||||
./hermes
|
||||
|
||||
# With specific model
|
||||
./hermes --model "anthropic/claude-sonnet-4"
|
||||
|
||||
# With specific toolsets
|
||||
./hermes --toolsets "web,terminal,skills"
|
||||
```
|
||||
|
||||
### CLI Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/help` | Show available commands |
|
||||
| `/tools` | List available tools by toolset |
|
||||
| `/toolsets` | List available toolsets |
|
||||
| `/model [name]` | Show or change the current model |
|
||||
| `/prompt [text]` | View/set custom system prompt |
|
||||
| `/personality [name]` | Set a predefined personality |
|
||||
| `/clear` | Clear screen and reset conversation |
|
||||
| `/reset` | Reset conversation only |
|
||||
| `/history` | Show conversation history |
|
||||
| `/save` | Save current conversation to file |
|
||||
| `/config` | Show current configuration |
|
||||
| `/quit` | Exit the CLI |
|
||||
|
||||
### Configuration
|
||||
|
||||
Copy `cli-config.yaml.example` to `cli-config.yaml` and customize:
|
||||
|
||||
```yaml
|
||||
# Model settings
|
||||
model:
|
||||
default: "anthropic/claude-sonnet-4"
|
||||
|
||||
# Terminal backend (local, docker, singularity, modal, or ssh)
|
||||
terminal:
|
||||
env_type: "local"
|
||||
cwd: "." # Use current directory
|
||||
|
||||
# Or use SSH for remote execution (keeps agent code isolated)
|
||||
# terminal:
|
||||
# env_type: "ssh"
|
||||
# ssh_host: "my-server.example.com"
|
||||
# ssh_user: "myuser"
|
||||
# ssh_key: "~/.ssh/id_rsa"
|
||||
# cwd: "/home/myuser/project"
|
||||
|
||||
# Enable specific toolsets
|
||||
toolsets:
|
||||
- all # or: web, terminal, browser, vision, etc.
|
||||
|
||||
# Custom personalities (use with /personality command)
|
||||
agent:
|
||||
personalities:
|
||||
helpful: "You are a helpful assistant."
|
||||
kawaii: "You are a kawaii assistant! Use cute expressions..."
|
||||
```
|
||||
|
||||
### Personalities
|
||||
|
||||
Built-in personalities available via `/personality`:
|
||||
- `helpful`, `concise`, `technical`, `creative`, `teacher`
|
||||
- `kawaii`, `catgirl`, `pirate`, `shakespeare`, `surfer`
|
||||
- `noir`, `uwu`, `philosopher`, `hype`
|
||||
|
||||
## Toolsets System
|
||||
|
||||
The agent uses a toolsets system for organizing and managing tools. All tools must be part of a toolset to be accessible - individual tool selection is not supported. This ensures consistent and logical grouping of capabilities.
|
||||
|
||||
### Key Concepts
|
||||
|
||||
- **Toolsets**: Logical groups of tools for specific use cases (e.g., "research", "development", "debugging")
|
||||
- **Composition**: Toolsets can include other toolsets for powerful combinations
|
||||
- **Custom Toolsets**: Create your own toolsets at runtime or by editing `toolsets.py`
|
||||
- **Toolset-Only Access**: Tools are only accessible through toolsets, not individually
|
||||
|
||||
### Available Toolsets
|
||||
|
||||
See `toolsets.py` for the complete list of predefined toolsets including:
|
||||
- Basic toolsets (web, terminal, vision, creative, reasoning)
|
||||
- Composite toolsets (research, development, analysis, etc.)
|
||||
- Scenario-specific toolsets (debugging, documentation, API testing, etc.)
|
||||
- Special toolsets (safe mode without terminal, minimal, offline)
|
||||
|
||||
### Using Toolsets
|
||||
|
||||
```bash
|
||||
# Use a predefined toolset
|
||||
python run_agent.py --enabled_toolsets=research --query "Find latest AI papers"
|
||||
|
||||
# Combine multiple toolsets
|
||||
python run_agent.py --enabled_toolsets=web,vision --query "Analyze this website"
|
||||
|
||||
# Enable all toolsets explicitly (same as omitting the flag)
|
||||
python run_agent.py --enabled_toolsets=all --query "Do web research and run commands if helpful"
|
||||
|
||||
# Safe mode (no terminal access)
|
||||
python run_agent.py --enabled_toolsets=safe --query "Help without running commands"
|
||||
|
||||
# List all available toolsets and tools
|
||||
python run_agent.py --list_tools
|
||||
```
|
||||
|
||||
See `toolsets.py` for the complete list of available toolsets and how to create custom ones.
|
||||
|
||||
## Basic Usage
|
||||
|
||||
### Default (all tools enabled)
|
||||
```bash
|
||||
# Uses OpenRouter by default - just set OPENROUTER_API_KEY in .env
|
||||
python run_agent.py \
|
||||
--query "search up the latest docs on jit in python 3.13 and write me basic example that's not in their docs. profile its perf" \
|
||||
--max_turns 20 \
|
||||
--model anthropic/claude-sonnet-4-20250514
|
||||
```
|
||||
|
||||
### With specific toolset
|
||||
```bash
|
||||
python run_agent.py \
|
||||
--query "Debug this Python error" \
|
||||
--enabled_toolsets=debugging \
|
||||
--model anthropic/claude-sonnet-4-20250514
|
||||
```
|
||||
|
||||
### Python API
|
||||
```python
|
||||
from run_agent import AIAgent
|
||||
|
||||
# Uses OpenRouter by default (reads OPENROUTER_API_KEY from .env)
|
||||
agent = AIAgent(
|
||||
model="anthropic/claude-sonnet-4-20250514",
|
||||
enabled_toolsets=["research"]
|
||||
)
|
||||
response = agent.chat("Find information about quantum computing")
|
||||
|
||||
# Create custom toolset at runtime
|
||||
from toolsets import create_custom_toolset
|
||||
|
||||
create_custom_toolset(
|
||||
name="my_tools",
|
||||
description="My custom toolkit",
|
||||
tools=["web_search"],
|
||||
includes=["terminal", "vision"]
|
||||
)
|
||||
|
||||
agent = AIAgent(enabled_toolsets=["my_tools"])
|
||||
```
|
||||
|
||||
## Batch Processing
|
||||
|
||||
Process multiple prompts from a dataset in parallel with automatic checkpointing and statistics tracking:
|
||||
|
||||
```bash
|
||||
# Basic batch processing
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=20 \
|
||||
--run_name=my_run
|
||||
|
||||
# With specific distribution
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=20 \
|
||||
--run_name=image_run \
|
||||
--distribution=image_gen \
|
||||
--num_workers=4
|
||||
```
|
||||
|
||||
**Key Features:**
|
||||
- Parallel processing with configurable workers
|
||||
- Toolset distributions for varied data generation
|
||||
- Automatic checkpointing and resume capability
|
||||
- Combined output in `data/<run_name>/trajectories.jsonl`
|
||||
- Tool usage statistics and success rates
|
||||
|
||||
Use `--list_distributions` to see available toolset distributions for varied data generation.
|
||||
|
||||
### Trajectory Compression
|
||||
|
||||
Post-process trajectories to fit within token budgets for training:
|
||||
|
||||
```bash
|
||||
# Compress a directory of JSONL files
|
||||
python trajectory_compressor.py --input=data/my_run
|
||||
|
||||
# Compress a single JSONL file
|
||||
python trajectory_compressor.py --input=data/trajectories.jsonl
|
||||
|
||||
# Compress a 15% sample (useful for creating smaller training sets)
|
||||
python trajectory_compressor.py --input=data/trajectories.jsonl --sample_percent=15
|
||||
|
||||
# Custom output and token target
|
||||
python trajectory_compressor.py \
|
||||
--input=data/trajectories.jsonl \
|
||||
--output=data/compressed.jsonl \
|
||||
--target_max_tokens=16000
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Protects first turns (system, human, first GPT response, first tool call)
|
||||
- Protects last N turns (configurable)
|
||||
- Summarizes middle turns using LLM to fit target token budget
|
||||
- Supports both directory and single file input
|
||||
- Optional random sampling with `--sample_percent`
|
||||
- Configurable via `configs/trajectory_compression.yaml`
|
||||
|
||||
### Ephemeral System Prompts
|
||||
|
||||
The ephemeral system prompt feature allows you to guide the model's behavior during batch processing **without** saving that prompt to the training dataset trajectories. This is useful for:
|
||||
|
||||
- Guiding model behavior during data collection
|
||||
- Adding task-specific instructions
|
||||
- Keeping saved trajectories clean and focused on tool-calling format
|
||||
|
||||
**Example:**
|
||||
```bash
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=10 \
|
||||
--run_name=my_run \
|
||||
--ephemeral_system_prompt="You are a helpful assistant focused on image generation."
|
||||
```
|
||||
|
||||
The ephemeral prompt will influence the model's behavior during execution, but **only the standard tool-calling system prompt** will be saved in the trajectory files.
|
||||
|
||||
The ephemeral prompt influences model behavior during execution, but **only the standard tool-calling system prompt** is saved in trajectory files.
|
||||
|
||||
## Command Line Arguments
|
||||
|
||||
**Single Agent (`run_agent.py`):**
|
||||
- `--query`: The question or task for the agent
|
||||
- `--model`: Model to use (default: claude-opus-4-20250514)
|
||||
- `--api_key`: API key for authentication
|
||||
- `--base_url`: API endpoint URL
|
||||
- `--max_turns`: Maximum number of tool-calling iterations
|
||||
- `--enabled_toolsets`: Comma-separated list of toolsets to enable. Use `all` (or `*`) to enable everything. If omitted, all toolsets are enabled by default.
|
||||
- `--disabled_toolsets`: Comma-separated list of toolsets to disable
|
||||
- `--list_tools`: List all available toolsets and tools
|
||||
- `--save_trajectories`: Save conversation trajectories to JSONL files
|
||||
|
||||
**Batch Processing (`batch_runner.py`):**
|
||||
- `--dataset_file`: Path to JSONL file with prompts
|
||||
- `--batch_size`: Number of prompts per batch
|
||||
- `--run_name`: Name for this run (for output/checkpointing)
|
||||
- `--distribution`: Toolset distribution to use (default: "default")
|
||||
- `--num_workers`: Number of parallel workers (default: 4)
|
||||
- `--resume`: Resume from checkpoint if interrupted
|
||||
- `--ephemeral_system_prompt`: System prompt used during execution but NOT saved to trajectories
|
||||
- `--list_distributions`: List available toolset distributions
|
||||
|
||||
## Environment Variables
|
||||
|
||||
All environment variables can be configured in the `.env` file (copy from `.env.example`).
|
||||
|
||||
**LLM Provider (OpenRouter):**
|
||||
- `OPENROUTER_API_KEY`: Primary LLM access via OpenRouter (supports Claude, GPT-4, Gemini, etc.)
|
||||
- `LLM_MODEL`: Default model (e.g., `anthropic/claude-sonnet-4`, `openai/gpt-4o`)
|
||||
|
||||
**Tool API Keys:**
|
||||
- `FIRECRAWL_API_KEY`: Web tools (search, extract, crawl)
|
||||
- `NOUS_API_KEY`: Vision and reasoning tools
|
||||
- `FAL_KEY`: Image generation tools
|
||||
|
||||
**Terminal Tool Configuration (mini-swe-agent backend):**
|
||||
- `TERMINAL_ENV`: Backend type - `local`, `docker`, `singularity`, `modal`, or `ssh` (default: `local`)
|
||||
- `TERMINAL_DOCKER_IMAGE`: Docker image for docker backend (default: `python:3.11-slim`)
|
||||
- `TERMINAL_SINGULARITY_IMAGE`: Singularity/Apptainer image (can be `docker://...` URL or local `.sif` path)
|
||||
- `TERMINAL_TIMEOUT`: Command timeout in seconds (default: `60`)
|
||||
- `TERMINAL_LIFETIME_SECONDS`: Cleanup inactive environments after this time (default: `300`)
|
||||
- `TERMINAL_CWD`: Working directory inside containers (default: `/tmp`)
|
||||
- `TERMINAL_SCRATCH_DIR`: Custom scratch directory for sandbox storage (optional, auto-detects `/scratch`)
|
||||
- `SUDO_PASSWORD`: Enable sudo commands by piping password via `sudo -S` (works with all backends)
|
||||
- If unset in CLI mode, you'll be prompted interactively when sudo is needed (45s timeout)
|
||||
|
||||
**SSH Backend Configuration (for remote execution):**
|
||||
- `TERMINAL_SSH_HOST`: Remote server hostname or IP
|
||||
- `TERMINAL_SSH_USER`: SSH username
|
||||
- `TERMINAL_SSH_PORT`: SSH port (default: `22`)
|
||||
- `TERMINAL_SSH_KEY`: Path to SSH private key (optional, uses ssh-agent if not set)
|
||||
|
||||
**Browser Tool Configuration (agent-browser + Browserbase):**
|
||||
- `BROWSERBASE_API_KEY`: Browserbase API key for cloud browser execution
|
||||
- `BROWSERBASE_PROJECT_ID`: Browserbase project ID
|
||||
- `BROWSER_SESSION_TIMEOUT`: Session timeout in seconds (default: `300`)
|
||||
|
||||
**Legacy Hecate Terminal Backend (optional):**
|
||||
- `MORPH_API_KEY`: For Hecate/MorphCloud terminal backend
|
||||
- `HECATE_VM_LIFETIME_SECONDS`: VM lifetime (default: 300)
|
||||
- `HECATE_DEFAULT_SNAPSHOT_ID`: Default snapshot (default: snapshot_p5294qxt)
|
||||
|
||||
**Debug Options:**
|
||||
- `WEB_TOOLS_DEBUG`, `VISION_TOOLS_DEBUG`, `MOA_TOOLS_DEBUG`, `IMAGE_TOOLS_DEBUG`: Enable debug logging
|
||||
|
||||
## Key Files
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `hermes` | CLI launcher script (run with `./hermes`) |
|
||||
| `cli.py` | Interactive CLI implementation |
|
||||
| `cli-config.yaml` | CLI configuration (copy from `.example`) |
|
||||
| `run_agent.py` | Main agent runner - single query execution |
|
||||
| `batch_runner.py` | Parallel batch processing with checkpointing |
|
||||
| `model_tools.py` | Core tool definitions and handlers |
|
||||
| `toolsets.py` | Toolset definitions and composition |
|
||||
| `toolset_distributions.py` | Probability distributions for data generation |
|
||||
| `trajectory_compressor.py` | Post-process trajectories for training |
|
||||
| `tools/` | Individual tool implementations |
|
||||
| `tools/skills_tool.py` | Skills system with progressive disclosure |
|
||||
| `skills/` | On-demand knowledge documents |
|
||||
| `docs/` | Documentation |
|
||||
| `configs/` | Example batch run scripts |
|
||||
|
||||
# Atropos Integrations & RL Training
|
||||
|
||||
## Nomad Setup
|
||||
Follow this: https://developer.hashicorp.com/nomad/docs/deploy
|
||||
|
||||
## Atropos dependencies
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install -e '.[atropos]'
|
||||
@@ -0,0 +1,70 @@
|
||||
README.md
|
||||
atropos_compatible_agent.py
|
||||
batch_runner.py
|
||||
local_server.py
|
||||
model_tools.py
|
||||
pyproject.toml
|
||||
run_agent.py
|
||||
toolset_distributions.py
|
||||
toolsets.py
|
||||
trajectory_compressor.py
|
||||
atropos/__init__.py
|
||||
atropos/sandbox_server.py
|
||||
atropos/agent/__init__.py
|
||||
atropos/agent/atropos_agent.py
|
||||
atropos/api/__init__.py
|
||||
atropos/api/tool_executor_server.py
|
||||
atropos/api/tool_server.py
|
||||
atropos/backends/__init__.py
|
||||
atropos/backends/base.py
|
||||
atropos/backends/modal_backend.py
|
||||
atropos/backends/nomad_backend.py
|
||||
atropos/envs/__init__.py
|
||||
atropos/envs/agent_env.py
|
||||
atropos/envs/hermes_compat_test_env.py
|
||||
atropos/envs/sandbox_terminal_smoke_env.py
|
||||
atropos/envs/swe_smith_oracle_env.py
|
||||
atropos/envs/test_env.py
|
||||
atropos/envs/toolserver_smoke_env.py
|
||||
atropos/nomad/__init__.py
|
||||
atropos/nomad/client.py
|
||||
atropos/slots/__init__.py
|
||||
atropos/slots/executor.py
|
||||
atropos/slots/pool.py
|
||||
atropos/slots/slot.py
|
||||
atropos/terminal/__init__.py
|
||||
atropos/terminal/asciinema_stream.py
|
||||
atropos/tools/__init__.py
|
||||
atropos/tools/base.py
|
||||
atropos/tools/build_registry.py
|
||||
atropos/tools/hermes_external_tools.py
|
||||
atropos/tools/sandbox_stubs.py
|
||||
atropos/tools/terminal_stateful_tool.py
|
||||
atropos/tools/tmux_tool.py
|
||||
atropos/tools/tool_executor.py
|
||||
atropos/tools/toolset_resolver.py
|
||||
hermes_agent.egg-info/PKG-INFO
|
||||
hermes_agent.egg-info/SOURCES.txt
|
||||
hermes_agent.egg-info/dependency_links.txt
|
||||
hermes_agent.egg-info/entry_points.txt
|
||||
hermes_agent.egg-info/requires.txt
|
||||
hermes_agent.egg-info/top_level.txt
|
||||
tests/test_batch_runner.py
|
||||
tests/test_checkpoint_resumption.py
|
||||
tests/test_modal_integration.py
|
||||
tests/test_modal_stress.py
|
||||
tests/test_modal_terminal.py
|
||||
tests/test_nous_api_limits.py
|
||||
tests/test_nous_api_pattern.py
|
||||
tests/test_temperature_fix.py
|
||||
tests/test_tool_call_parsing.py
|
||||
tests/test_web_tools.py
|
||||
tools/__init__.py
|
||||
tools/browser_tool.py
|
||||
tools/image_generation_tool.py
|
||||
tools/mixture_of_agents_tool.py
|
||||
tools/skills_tool.py
|
||||
tools/terminal_hecate.py
|
||||
tools/terminal_tool.py
|
||||
tools/vision_tools.py
|
||||
tools/web_tools.py
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
[console_scripts]
|
||||
hermes-agent = run_agent:main
|
||||
hermes-atropos-sandbox-smoke = atropos.envs.sandbox_terminal_smoke_env:SandboxTerminalSmokeEnv.cli
|
||||
hermes-atropos-toolserver-smoke = atropos.envs.toolserver_smoke_env:ToolServerSmokeEnv.cli
|
||||
@@ -0,0 +1,31 @@
|
||||
openai
|
||||
python-dotenv
|
||||
fire
|
||||
httpx
|
||||
rich
|
||||
tenacity
|
||||
pyyaml
|
||||
prompt_toolkit
|
||||
requests
|
||||
jinja2
|
||||
pydantic>=2.0
|
||||
firecrawl-py
|
||||
fal-client
|
||||
litellm>=1.75.5
|
||||
typer
|
||||
platformdirs
|
||||
|
||||
[atropos]
|
||||
atroposlib @ git+https://github.com/NousResearch/atropos.git
|
||||
aiohttp
|
||||
fastapi
|
||||
uvicorn
|
||||
pyte
|
||||
|
||||
[dev]
|
||||
pytest
|
||||
pytest-asyncio
|
||||
|
||||
[modal]
|
||||
modal
|
||||
boto3
|
||||
@@ -0,0 +1,10 @@
|
||||
atropos
|
||||
atropos_compatible_agent
|
||||
batch_runner
|
||||
local_server
|
||||
model_tools
|
||||
run_agent
|
||||
tools
|
||||
toolset_distributions
|
||||
toolsets
|
||||
trajectory_compressor
|
||||
@@ -99,24 +99,6 @@ DEFAULT_CONFIG = {
|
||||
"personality": "kawaii",
|
||||
},
|
||||
|
||||
# Text-to-speech configuration
|
||||
"tts": {
|
||||
"provider": "edge", # "edge" (free) | "elevenlabs" (premium) | "openai"
|
||||
"edge": {
|
||||
"voice": "en-US-AriaNeural",
|
||||
# Popular: AriaNeural, JennyNeural, AndrewNeural, BrianNeural, SoniaNeural
|
||||
},
|
||||
"elevenlabs": {
|
||||
"voice_id": "pNInz6obpgDQGcFmaJgB", # Adam
|
||||
"model_id": "eleven_multilingual_v2",
|
||||
},
|
||||
"openai": {
|
||||
"model": "gpt-4o-mini-tts",
|
||||
"voice": "alloy",
|
||||
# Voices: alloy, echo, fable, onyx, nova, shimmer
|
||||
},
|
||||
},
|
||||
|
||||
# Permanently allowed dangerous command patterns (added via "always" approval)
|
||||
"command_allowlist": [],
|
||||
|
||||
@@ -220,13 +202,6 @@ OPTIONAL_ENV_VARS = {
|
||||
"url": None,
|
||||
"password": False,
|
||||
},
|
||||
# Text-to-speech (premium providers)
|
||||
"ELEVENLABS_API_KEY": {
|
||||
"description": "ElevenLabs API key for premium text-to-speech voices",
|
||||
"prompt": "ElevenLabs API key",
|
||||
"url": "https://elevenlabs.io/",
|
||||
"password": True,
|
||||
},
|
||||
# Terminal configuration
|
||||
"MESSAGING_CWD": {
|
||||
"description": "Working directory for terminal commands via messaging (Telegram/Discord/etc). CLI always uses current directory.",
|
||||
|
||||
@@ -360,11 +360,7 @@ def run_gateway(verbose: bool = False):
|
||||
print("└─────────────────────────────────────────────────────────┘")
|
||||
print()
|
||||
|
||||
# Exit with code 1 if gateway fails to connect any platform,
|
||||
# so systemd Restart=on-failure will retry on transient errors
|
||||
success = asyncio.run(start_gateway())
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
asyncio.run(start_gateway())
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@@ -186,11 +186,6 @@ def _print_setup_summary(config: dict, hermes_home):
|
||||
else:
|
||||
tool_status.append(("Image Generation", False, "FAL_KEY"))
|
||||
|
||||
# TTS (always available via Edge TTS; ElevenLabs/OpenAI are optional)
|
||||
tool_status.append(("Text-to-Speech (Edge TTS)", True, None))
|
||||
if get_env_value('ELEVENLABS_API_KEY'):
|
||||
tool_status.append(("Text-to-Speech (ElevenLabs)", True, None))
|
||||
|
||||
# Tinker + WandB (RL training)
|
||||
if get_env_value('TINKER_API_KEY') and get_env_value('WANDB_API_KEY'):
|
||||
tool_status.append(("RL Training (Tinker)", True, None))
|
||||
@@ -996,28 +991,6 @@ def run_setup_wizard(args):
|
||||
print_success(" Configured ✓")
|
||||
print()
|
||||
|
||||
# ElevenLabs - Premium TTS
|
||||
print_info("─" * 50)
|
||||
print(color(" Text-to-Speech - ElevenLabs (Premium)", Colors.CYAN))
|
||||
print_info(" Enables: Premium TTS voices (Edge TTS is free and works without a key)")
|
||||
print_info(" Use case: High-quality, customizable voice synthesis")
|
||||
if get_env_value('ELEVENLABS_API_KEY'):
|
||||
print_success(" Status: Configured ✓")
|
||||
if prompt_yes_no(" Update ElevenLabs API key?", False):
|
||||
api_key = prompt(" API key", password=True)
|
||||
if api_key:
|
||||
save_env_value("ELEVENLABS_API_KEY", api_key)
|
||||
print_success(" Updated")
|
||||
else:
|
||||
print_warning(" Status: Not configured (free Edge TTS will be used by default)")
|
||||
if prompt_yes_no(" Set up ElevenLabs?", False):
|
||||
print_info(" Get your API key at: https://elevenlabs.io/")
|
||||
api_key = prompt(" API key", password=True)
|
||||
if api_key:
|
||||
save_env_value("ELEVENLABS_API_KEY", api_key)
|
||||
print_success(" Configured ✓")
|
||||
print()
|
||||
|
||||
# Tinker + WandB - RL Training
|
||||
print_info("─" * 50)
|
||||
print(color(" RL Training (Tinker + WandB)", Colors.CYAN))
|
||||
|
||||
@@ -76,7 +76,6 @@ def show_status(args):
|
||||
"FAL": "FAL_KEY",
|
||||
"Tinker": "TINKER_API_KEY",
|
||||
"WandB": "WANDB_API_KEY",
|
||||
"ElevenLabs": "ELEVENLABS_API_KEY",
|
||||
}
|
||||
|
||||
for name, env_var in keys.items():
|
||||
|
||||
+353
@@ -0,0 +1,353 @@
|
||||
"""
|
||||
Local OpenAI-compatible server implementation for Hermes-Agent (Atropos integration).
|
||||
|
||||
Extends the Atropos APIServer to work with local OpenAI-compatible APIs (e.g. vLLM, SGLang),
|
||||
providing tokens_and_logprobs_completion support via client-side tokenization.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import warnings
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import openai
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.completion import Completion
|
||||
|
||||
from atroposlib.envs.server_handling.server_baseline import (
|
||||
APIServer,
|
||||
APIServerConfig,
|
||||
ReasoningConfig,
|
||||
)
|
||||
|
||||
|
||||
class LocalServer(APIServer):
|
||||
"""
|
||||
OpenAI-compatible local server with tokens_and_logprobs support.
|
||||
|
||||
Uses an OpenAI-compatible API (typically at a /v1 endpoint) and handles
|
||||
token extraction via client-side tokenization.
|
||||
|
||||
Note: Many local servers don't return per-token logprobs in the standard API,
|
||||
so this implementation uses placeholder logprobs (0.0) for PoC purposes.
|
||||
For production training, use vLLM/SGLang servers that return real logprobs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: APIServerConfig,
|
||||
tokenizer: Optional[Any] = None,
|
||||
tokenizer_name: str = "gpt2",
|
||||
reasoning_config: Optional[ReasoningConfig] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the local server.
|
||||
|
||||
Args:
|
||||
config: Server configuration
|
||||
tokenizer: Pre-initialized tokenizer (optional)
|
||||
tokenizer_name: Name of tokenizer to load if tokenizer not provided
|
||||
reasoning_config: Optional reasoning configuration
|
||||
"""
|
||||
# Build the OpenAI client pointing to the server's /v1 endpoint
|
||||
base_url = config.base_url
|
||||
if base_url and not base_url.endswith("/v1"):
|
||||
base_url = f"{base_url.rstrip('/')}/v1"
|
||||
|
||||
self.openai = openai.AsyncClient(
|
||||
api_key=config.api_key or "local", # Local servers often ignore auth
|
||||
base_url=base_url,
|
||||
timeout=config.timeout,
|
||||
)
|
||||
|
||||
# Initialize tokenizer
|
||||
if tokenizer is not None:
|
||||
self.tokenizer = tokenizer
|
||||
else:
|
||||
try:
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ModuleNotFoundError(
|
||||
"Missing optional dependency 'transformers'. Pass a tokenizer instance to LocalServer, "
|
||||
"or install transformers to enable `tokenizer_name` auto-loading."
|
||||
) from exc
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
|
||||
# Add a simple chat template if the tokenizer doesn't have one
|
||||
# This is needed for ManagedServer's chat_completion to work
|
||||
if not hasattr(self.tokenizer, 'chat_template') or self.tokenizer.chat_template is None:
|
||||
# Simple ChatML-style template
|
||||
self.tokenizer.chat_template = (
|
||||
"{% for message in messages %}"
|
||||
"{% if message['role'] == 'system' %}<|im_start|>system\n{{ message['content'] }}<|im_end|>\n"
|
||||
"{% elif message['role'] == 'user' %}<|im_start|>user\n{{ message['content'] }}<|im_end|>\n"
|
||||
"{% elif message['role'] == 'assistant' %}<|im_start|>assistant\n{{ message['content'] }}<|im_end|>\n"
|
||||
"{% endif %}"
|
||||
"{% endfor %}"
|
||||
"{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
|
||||
)
|
||||
|
||||
super().__init__(config, reasoning_config=reasoning_config)
|
||||
# Local servers are treated as always-healthy unless a status task is enabled.
|
||||
self.server_healthy = True
|
||||
|
||||
@classmethod
|
||||
def from_env(
|
||||
cls,
|
||||
base_url: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
tokenizer_name: str = "gpt2",
|
||||
**kwargs,
|
||||
) -> "LocalServer":
|
||||
"""
|
||||
Create a LocalServer from environment variables (or explicit overrides).
|
||||
|
||||
Env vars (checked in order):
|
||||
- base URL: ATROPOS_SERVER_BASE_URL, OPENAI_BASE_URL, LOCAL_LLM_BASE_URL, LLM_BASE_URL
|
||||
- model: ATROPOS_SERVER_MODEL, LLM_MODEL, LOCAL_LLM_MODEL
|
||||
- api key: ATROPOS_SERVER_API_KEY, OPENAI_API_KEY, LOCAL_LLM_API_KEY, LLM_API_KEY
|
||||
"""
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
base_url = (
|
||||
base_url
|
||||
or os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||
or os.getenv("OPENAI_BASE_URL")
|
||||
or os.getenv("LOCAL_LLM_BASE_URL")
|
||||
or os.getenv("LLM_BASE_URL")
|
||||
or "http://localhost:11434"
|
||||
)
|
||||
model = (
|
||||
model
|
||||
or os.getenv("ATROPOS_SERVER_MODEL")
|
||||
or os.getenv("LLM_MODEL")
|
||||
or os.getenv("LOCAL_LLM_MODEL")
|
||||
or "hermes3:8b"
|
||||
)
|
||||
api_key = (
|
||||
api_key
|
||||
or os.getenv("ATROPOS_SERVER_API_KEY")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or os.getenv("LOCAL_LLM_API_KEY")
|
||||
or os.getenv("LLM_API_KEY")
|
||||
)
|
||||
|
||||
config = APIServerConfig(
|
||||
model_name=model,
|
||||
base_url=base_url,
|
||||
api_key=api_key or "local",
|
||||
timeout=kwargs.get("timeout", 120),
|
||||
num_max_requests_at_once=kwargs.get("num_max_requests_at_once", 4),
|
||||
num_requests_for_eval=kwargs.get("num_requests_for_eval", 4),
|
||||
health_check=False, # Local dev servers often lack /health
|
||||
)
|
||||
|
||||
return cls(config, tokenizer_name=tokenizer_name)
|
||||
|
||||
async def check_server_status_task(self, chat_completion: bool = True):
|
||||
"""
|
||||
Check if the server is healthy.
|
||||
|
||||
For local development, we generally assume the server is healthy.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
# Simple health check via a minimal completion
|
||||
if chat_completion:
|
||||
await self.openai.chat.completions.create(
|
||||
model=self.config.model_name,
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
max_tokens=1,
|
||||
)
|
||||
else:
|
||||
await self.openai.completions.create(
|
||||
model=self.config.model_name,
|
||||
prompt="hi",
|
||||
max_tokens=1,
|
||||
)
|
||||
self.server_healthy = True
|
||||
except Exception:
|
||||
self.server_healthy = False
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
|
||||
"""
|
||||
Wrapper for chat completion using an OpenAI-compatible API.
|
||||
"""
|
||||
assert kwargs.get("model") is not None, "Model is required!"
|
||||
assert kwargs.get("messages") is not None, "Messages are required!"
|
||||
|
||||
n = kwargs.get("n", 1)
|
||||
|
||||
# Some OpenAI-compatible servers don't support n > 1, so we make multiple requests.
|
||||
if n > 1:
|
||||
completion_list = await asyncio.gather(
|
||||
*[self.openai.chat.completions.create(**{**kwargs, "n": 1}) for _ in range(n)]
|
||||
)
|
||||
# Merge completions
|
||||
completions = completion_list[0]
|
||||
for c in completion_list[1:]:
|
||||
for choice in c.choices:
|
||||
choice.index = len(completions.choices)
|
||||
completions.choices.append(choice)
|
||||
return completions
|
||||
else:
|
||||
return await self.openai.chat.completions.create(**kwargs)
|
||||
|
||||
async def _completion_wrapper(self, **kwargs) -> Completion:
|
||||
"""
|
||||
Wrapper for completion using an OpenAI-compatible API.
|
||||
"""
|
||||
assert kwargs.get("model") is not None, "Model is required!"
|
||||
assert kwargs.get("prompt") is not None, "Prompt is required!"
|
||||
|
||||
n = kwargs.get("n", 1)
|
||||
|
||||
# Some OpenAI-compatible servers don't support n > 1.
|
||||
if n > 1:
|
||||
completion_list = await asyncio.gather(
|
||||
*[self.openai.completions.create(**{**kwargs, "n": 1}) for _ in range(n)]
|
||||
)
|
||||
completions = completion_list[0]
|
||||
for c in completion_list[1:]:
|
||||
for choice in c.choices:
|
||||
choice.index = len(completions.choices)
|
||||
completions.choices.append(choice)
|
||||
return completions
|
||||
else:
|
||||
return await self.openai.completions.create(**kwargs)
|
||||
|
||||
async def _tokens_and_logprobs_completion_wrapper(
|
||||
self, **kwargs
|
||||
) -> tuple[List[int], List[List[int]], List[List[float]], List[str]]:
|
||||
"""
|
||||
Wrapper for tokens and logprobs completion.
|
||||
|
||||
Returns:
|
||||
Tuple of (prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons)
|
||||
|
||||
Note: Many OpenAI-compatible local servers don't return per-token logprobs,
|
||||
so we use placeholder logprobs (0.0). For real training, use vLLM/SGLang.
|
||||
"""
|
||||
model = kwargs.get("model")
|
||||
assert model is not None, "Model is required!"
|
||||
|
||||
# Handle input_ids (from ManagedServer) or prompt
|
||||
if "input_ids" in kwargs:
|
||||
prompt_tokens = kwargs.pop("input_ids")
|
||||
prompt = self.tokenizer.decode(prompt_tokens)
|
||||
kwargs.pop("prompt", None)
|
||||
else:
|
||||
prompt = kwargs.pop("prompt", "")
|
||||
prompt_tokens = self.tokenizer.encode(prompt, add_special_tokens=True)
|
||||
|
||||
n = kwargs.pop("n", 1)
|
||||
max_tokens = kwargs.pop("max_tokens", 256)
|
||||
temperature = kwargs.pop("temperature", 0.7)
|
||||
stop = kwargs.pop("stop", None)
|
||||
|
||||
# Make completion requests
|
||||
completions = []
|
||||
for _ in range(n):
|
||||
try:
|
||||
response = await self.openai.completions.create(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
stop=stop,
|
||||
)
|
||||
completions.append(response)
|
||||
except Exception as e:
|
||||
# Fallback to chat completion if completion endpoint not supported
|
||||
warnings.warn(f"Completion API failed, trying chat: {e}")
|
||||
response = await self.openai.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
stop=stop,
|
||||
)
|
||||
# Convert to completion-like response
|
||||
completions.append(response)
|
||||
|
||||
output_tokens_list = []
|
||||
output_logprobs_list = []
|
||||
finish_reasons = []
|
||||
|
||||
for completion in completions:
|
||||
# Extract text from response
|
||||
if hasattr(completion.choices[0], "text"):
|
||||
# Completion API response
|
||||
text = completion.choices[0].text
|
||||
finish_reason = completion.choices[0].finish_reason or "stop"
|
||||
else:
|
||||
# Chat completion API response
|
||||
text = completion.choices[0].message.content or ""
|
||||
finish_reason = completion.choices[0].finish_reason or "stop"
|
||||
|
||||
# Tokenize output
|
||||
output_tokens = self.tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
# Placeholder logprobs (many local servers don't provide per-token logprobs).
|
||||
# In production, use vLLM/SGLang which return real logprobs
|
||||
output_logprobs = [0.0] * len(output_tokens)
|
||||
|
||||
output_tokens_list.append(output_tokens)
|
||||
output_logprobs_list.append(output_logprobs)
|
||||
finish_reasons.append(finish_reason)
|
||||
|
||||
return prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons
|
||||
|
||||
def managed_server(self, tokenizer=None, track_tree: bool = False):
|
||||
"""
|
||||
Create a ManagedServer context manager for this server.
|
||||
|
||||
Args:
|
||||
tokenizer: Optional tokenizer override
|
||||
track_tree: Whether to maintain tree structure for multi-turn
|
||||
|
||||
Returns:
|
||||
ManagedServer context manager
|
||||
"""
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
|
||||
return ManagedServerContext(
|
||||
self,
|
||||
tokenizer=tokenizer or self.tokenizer,
|
||||
track_tree=track_tree,
|
||||
)
|
||||
|
||||
|
||||
class ManagedServerContext:
|
||||
"""
|
||||
Context manager wrapper for ManagedServer.
|
||||
|
||||
Usage:
|
||||
async with server.managed_server(tokenizer=tokenizer) as managed:
|
||||
response = await managed.chat_completion(...)
|
||||
state = managed.get_state()
|
||||
"""
|
||||
|
||||
def __init__(self, server: LocalServer, tokenizer, track_tree: bool = False):
|
||||
self.server = server
|
||||
self.tokenizer = tokenizer
|
||||
self.track_tree = track_tree
|
||||
self.managed = None
|
||||
|
||||
async def __aenter__(self):
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
|
||||
self.managed = ManagedServer(
|
||||
self.server,
|
||||
tokenizer=self.tokenizer,
|
||||
track_tree=self.track_tree,
|
||||
)
|
||||
return self.managed
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if self.managed:
|
||||
self.managed.reset()
|
||||
return False
|
||||
@@ -0,0 +1,61 @@
|
||||
# Active Context
|
||||
|
||||
## Current Focus
|
||||
Tinker RL training integration - pipeline fully wired up, waiting on Tinker billing to test.
|
||||
|
||||
## Recently Completed (Feb 9, 2026)
|
||||
|
||||
### Tinker RL Training Integration
|
||||
Created a complete agent training pipeline using Tinker (Thinking Machines) + Atropos:
|
||||
|
||||
**New Files Created:**
|
||||
1. `tinker-atropos/tinker_atropos/environments/gsm8k_agent.py` - Agent GSM8k environment with:
|
||||
- Python REPL tool calling (Hermes-style `<tool_call>` format)
|
||||
- Multi-step agent loop within `collect_trajectories()`
|
||||
- Math answer verification via `math_verify`
|
||||
- Subprocess-based Python execution
|
||||
- WandB metrics (percent_correct, tool_use_rate)
|
||||
2. `tinker-atropos/configs/gsm8k_agent.yaml` - Config for Qwen3-4B-Instruct training
|
||||
|
||||
**Dependencies Updated:**
|
||||
- `pyproject.toml` `[atropos]` extra now includes: tinker SDK, torch, wandb, math-verify
|
||||
- Installed: tinker 0.12.0, tinker-atropos 0.1.0, torch (CPU)
|
||||
|
||||
**README Updated:**
|
||||
- Added comprehensive "RL Training with Tinker" section with architecture diagram, quick start, config docs
|
||||
- Added TINKER_API_KEY and WANDB_API_KEY to optional keys table
|
||||
|
||||
**Verified Working:**
|
||||
- Tinker SDK connection ✅
|
||||
- All imports (tinker, tinker_atropos, trainer, environment) ✅
|
||||
- Python REPL execution + tool call parsing ✅
|
||||
- Math verification ✅
|
||||
- Atropos run-api (port 8000) ✅
|
||||
- Tinker trainer starts, loads config, creates inference server (port 8001) ✅
|
||||
|
||||
**Blocked:** Tinker billing (402 error) - user's payment didn't process (possibly regional card issue)
|
||||
|
||||
### Main Branch Merge (Feb 9, 2026)
|
||||
Merged `origin/main` into `atropos-integrations` - 22,560 lines, 79 files, 5 conflicts resolved.
|
||||
|
||||
### Modal Backend (Feb 8, 2026)
|
||||
Merged modal-integration branch, working with Modal Sandboxes.
|
||||
|
||||
### Singularity/Apptainer (Feb 6, 2026)
|
||||
Completed and tested.
|
||||
|
||||
## Architecture: Training Pipeline
|
||||
|
||||
```
|
||||
Terminal 1: run-api (port 8000) - Atropos Rollout API
|
||||
Terminal 2: launch_training.py (port 8001) - Tinker Trainer + FastAPI inference
|
||||
Terminal 3: gsm8k_agent.py serve - Environment (generates trajectories)
|
||||
```
|
||||
|
||||
The agent env gets math problems → model calls Python REPL tool → scores answer → sends to Atropos → Tinker does LoRA training → updates sampling weights → repeat.
|
||||
|
||||
## Next Steps
|
||||
- [ ] Resolve Tinker billing to test full training loop
|
||||
- [ ] Run GSM8k agent training for ~20 steps (proof of concept)
|
||||
- [ ] Monitor WandB for reward improvement
|
||||
- [ ] Graduate to more complex agent envs (SWE tasks with Modal backend)
|
||||
@@ -0,0 +1,55 @@
|
||||
# Product Context: Hermes-Agent
|
||||
|
||||
## Why This Project Exists
|
||||
|
||||
Hermes-Agent addresses several key challenges in the AI agent space:
|
||||
|
||||
1. **Unified Tool Interface** - Provides a clean, consistent interface for LLMs to use various tools (web, terminal, browser, vision, etc.) without requiring custom integration for each model provider.
|
||||
|
||||
2. **Training Data Generation** - Enables efficient generation of high-quality tool-calling trajectories for fine-tuning LLMs, with features like batch processing, checkpointing, and trajectory compression.
|
||||
|
||||
3. **Flexible Deployment** - Supports multiple execution environments (local, Docker, Singularity, Modal, SSH) to accommodate different security and isolation requirements.
|
||||
|
||||
4. **Developer Experience** - Offers a beautiful, interactive CLI with kawaii-style feedback that makes working with AI agents enjoyable.
|
||||
|
||||
## Problems It Solves
|
||||
|
||||
### For AI Researchers
|
||||
- **Data Generation at Scale**: Parallel batch processing with content-based checkpointing for fault tolerance
|
||||
- **Clean Trajectories**: Trajectory compression to fit token budgets while preserving important information
|
||||
- **Toolset Distributions**: Probability-based tool selection for varied training data
|
||||
|
||||
### For Developers
|
||||
- **Tool Orchestration**: Logical grouping of tools into toolsets (research, development, debugging, etc.)
|
||||
- **Session Persistence**: Conversation history and session logging for debugging
|
||||
- **Multi-Model Support**: Works with any OpenAI-compatible API (OpenRouter, local models, etc.)
|
||||
|
||||
### For MLOps
|
||||
- **Skills System**: On-demand knowledge documents for specific tools/frameworks (Axolotl, vLLM, TRL, etc.)
|
||||
- **Sandboxed Execution**: Terminal commands can run in isolated environments (Docker, Singularity, Modal)
|
||||
- **Configurable Backends**: Easy switching between local and cloud execution
|
||||
|
||||
## How It Should Work
|
||||
|
||||
### User Flow (CLI)
|
||||
1. User launches `./hermes`
|
||||
2. Beautiful welcome banner displays with caduceus logo, model info, and available tools
|
||||
3. User types a natural language request
|
||||
4. Agent processes request, potentially calling tools with animated feedback
|
||||
5. Agent responds with results, conversation continues
|
||||
6. Session is automatically logged for debugging
|
||||
|
||||
### User Flow (Batch Processing)
|
||||
1. User prepares JSONL file with prompts
|
||||
2. Runs `batch_runner.py` with distribution and worker count
|
||||
3. System processes prompts in parallel, saves checkpoints
|
||||
4. Completed trajectories saved to `data/<run_name>/trajectories.jsonl`
|
||||
5. Optional: compress trajectories with `trajectory_compressor.py`
|
||||
|
||||
## User Experience Goals
|
||||
|
||||
- **Delightful Interaction**: Kawaii ASCII faces, animated spinners, cute messages
|
||||
- **Informative Feedback**: Clear progress indication during tool execution
|
||||
- **Configurable Personalities**: From "helpful" to "pirate" to "Shakespeare"
|
||||
- **Easy Configuration**: YAML config file + environment variables + CLI flags
|
||||
- **Graceful Degradation**: Missing tools/APIs don't break the system, just disable features
|
||||
@@ -0,0 +1,96 @@
|
||||
# Progress
|
||||
|
||||
## Completed Features
|
||||
|
||||
### ✅ Modal Backend Integration (Feb 8, 2026 - MERGED & TESTED)
|
||||
Merged the `modal-integration` branch and fixed integration issues.
|
||||
|
||||
**What Works:**
|
||||
- `ModalToolBackend` implements full `ToolBackend` interface (start, stop, acquire, release, execute_batch)
|
||||
- Modal Sandboxes used for long-lived containers (not Functions)
|
||||
- `sandbox.exec()` for direct command execution (no HTTP server needed)
|
||||
- Slot-based multiplexing matching Nomad pattern
|
||||
- Multi-profile support (`ModalSandboxConfig`, `_ModalMultiProfileManager`)
|
||||
- YAML profile loading (`modal_profiles.yaml`)
|
||||
- `AgentEnvConfig` fields for all Modal settings (`--env.modal_*`)
|
||||
- `create_tool_backend()` supports `tool_pool_mode="modal"`
|
||||
- Terminal tool (`tools/terminal_tool.py`) native Modal integration with pool management
|
||||
- Named sandbox recovery via `Sandbox.from_name()`
|
||||
- Auto-scaling sandbox pool per profile
|
||||
- Artifact helpers (read, list, archive)
|
||||
|
||||
**CLI Usage:**
|
||||
```bash
|
||||
# Atropos backend
|
||||
python -m atropos.envs.swe_smith_oracle_env process \
|
||||
--env.tool_pool_mode modal \
|
||||
--env.modal_image python:3.11
|
||||
|
||||
# Terminal tool
|
||||
TERMINAL_ENV=modal ./hermes
|
||||
```
|
||||
|
||||
**Files Modified/Created:**
|
||||
- `atropos/backends/modal_backend.py` - Full implementation (~1200 lines)
|
||||
- `atropos/backends/__init__.py` - `create_tool_backend()` updated
|
||||
- `atropos/envs/agent_env.py` - 15 Modal config fields added
|
||||
- `tools/terminal_tool.py` - Native Modal sandbox pool
|
||||
- `docs/MODAL_BACKEND.md` - Documentation
|
||||
- `modal_profiles.yaml.example` - Example profiles
|
||||
- `tests/test_modal_integration.py` - Integration tests
|
||||
- `tests/test_modal_stress.py` - Stress tests
|
||||
- `tests/test_modal_terminal.py` - Terminal tool tests
|
||||
|
||||
### ✅ Singularity/Apptainer Sandbox Integration (Feb 6, 2026 - FULLY TESTED)
|
||||
Adapted the Atropos sandbox environment from Docker to Singularity/Apptainer for HPC clusters.
|
||||
|
||||
**What Works:**
|
||||
- `create_sandbox_job()` supports both `driver="docker"` and `driver="singularity"`
|
||||
- SlotPoolConfig and NomadBackendConfig propagate driver settings
|
||||
- Singularity container runs sandbox_server.py via Nomad's raw_exec driver
|
||||
- All sandbox operations work: bash execution, file read/write
|
||||
- **CLI arguments** `--env.driver` and `--env.singularity_image` for AgentEnvConfig
|
||||
- **Static port binding** for Singularity (ReservedPorts vs DynamicPorts)
|
||||
|
||||
### ✅ Memory Bank Initialized (Feb 5, 2026)
|
||||
Set up project documentation structure for context persistence.
|
||||
|
||||
## In Progress
|
||||
None currently.
|
||||
|
||||
## Known Issues
|
||||
- Modal backend not yet live-tested with actual Modal cloud credentials
|
||||
- `bwrap_available: false` in Singularity containers
|
||||
- Health check timing - may need longer wait for container startup on slower systems
|
||||
|
||||
## What's Left to Build
|
||||
|
||||
### Modal Backend
|
||||
- [ ] Live test with Modal credentials on actual cloud
|
||||
- [ ] Test multi-profile GPU workflows
|
||||
- [ ] Test sandbox recovery after restart
|
||||
- [ ] Integrate with SWE-smith-oracle env for GRPO training loop
|
||||
- [ ] Performance benchmarking vs Nomad backend
|
||||
|
||||
### HPC Deployment
|
||||
- [ ] Test on actual HPC cluster with Slurm/PBS integration
|
||||
- [ ] Document cluster-specific deployment procedures
|
||||
|
||||
### Documentation
|
||||
- [ ] Add Singularity deployment to README
|
||||
- [ ] Create HPC deployment skill in skills/mlops/
|
||||
|
||||
## Evolution of Decisions
|
||||
|
||||
### Container Runtime Selection
|
||||
- **Initial**: Docker-only via Nomad docker driver
|
||||
- **Problem**: HPC clusters don't allow Docker without sudo
|
||||
- **Solution**: Added Singularity/Apptainer support via raw_exec driver
|
||||
- **Result**: Both runtimes now supported with same API
|
||||
|
||||
### Modal Backend Architecture
|
||||
- **Initial**: Stub placeholder raising RuntimeError
|
||||
- **Investigation**: Modal Sandboxes vs Functions - chose Sandboxes for long-lived containers
|
||||
- **Design**: Direct `sandbox.exec()` instead of HTTP/sandbox_server.py (simpler, no networking needed)
|
||||
- **Implementation**: Merged from `modal-integration` branch, fixed agent_env.py config fields
|
||||
- **Result**: Three backends now supported: Nomad/Docker, Nomad/Singularity, Modal
|
||||
@@ -0,0 +1,44 @@
|
||||
# Project Brief: Hermes-Agent
|
||||
|
||||
## Overview
|
||||
Hermes-Agent is an AI agent harness for LLMs with advanced tool-calling capabilities, featuring a flexible toolsets system for organizing and managing tools. Named after Hermes, the Greek messenger god, it serves as a bridge between human intent and AI-powered task execution.
|
||||
|
||||
## Core Requirements
|
||||
|
||||
### Primary Goals
|
||||
1. **Interactive CLI Experience** - Beautiful terminal interface with animated feedback, personalities, and session management
|
||||
2. **Flexible Tool System** - Modular tools organized into logical toolsets for different use cases
|
||||
3. **Batch Processing** - Process multiple prompts in parallel with checkpointing and statistics
|
||||
4. **Multi-Backend Support** - Support for local, Docker, Singularity, Modal, and SSH terminal backends
|
||||
5. **Training Data Generation** - Save conversation trajectories in formats suitable for LLM fine-tuning
|
||||
|
||||
### Target Users
|
||||
- AI researchers generating training data
|
||||
- Developers needing an AI assistant with tool access
|
||||
- MLOps practitioners automating workflows
|
||||
- Anyone needing a powerful CLI-based AI agent
|
||||
|
||||
## Scope
|
||||
|
||||
### In Scope
|
||||
- Interactive CLI with rich formatting and kawaii-style feedback
|
||||
- Web tools (search, extract, crawl via Firecrawl)
|
||||
- Terminal tools (command execution across multiple backends)
|
||||
- Browser automation (via agent-browser + Browserbase)
|
||||
- Vision tools (image analysis)
|
||||
- Image generation (FLUX via FAL.ai)
|
||||
- Mixture-of-Agents reasoning
|
||||
- Skills system for on-demand knowledge
|
||||
- Batch processing with parallel workers
|
||||
- Trajectory compression for training
|
||||
|
||||
### Out of Scope (Current)
|
||||
- Proactive suggestions (agent only runs on request)
|
||||
- Clipboard integration (no local system access)
|
||||
- Real-time streaming of thinking/reasoning (deferred)
|
||||
|
||||
## Success Metrics
|
||||
- Clean, maintainable tool architecture
|
||||
- Reliable tool execution with proper error handling
|
||||
- Efficient context management for long conversations
|
||||
- High-quality trajectory data for training
|
||||
@@ -0,0 +1,191 @@
|
||||
# System Patterns: Hermes-Agent
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ CLI (cli.py) │
|
||||
│ - Rich welcome banner with caduceus │
|
||||
│ - prompt_toolkit for input with history │
|
||||
│ - Kawaii-style feedback and personalities │
|
||||
└────────────────────────────┬────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ AIAgent (run_agent.py) │
|
||||
│ - Conversation loop with tool calling │
|
||||
│ - KawaiiSpinner for animated feedback │
|
||||
│ - Retry logic with exponential backoff │
|
||||
│ - Session logging to logs/ directory │
|
||||
└────────────────────────────┬────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Tool Routing (model_tools.py) │
|
||||
│ - get_tool_definitions() - returns tools for API calls │
|
||||
│ - handle_function_call() - dispatches to tool handlers │
|
||||
│ - Toolset filtering (enabled/disabled) │
|
||||
└────────────────────────────┬────────────────────────────────────┘
|
||||
│
|
||||
┌─────────────────┼─────────────────┐
|
||||
▼ ▼ ▼
|
||||
┌───────────┐ ┌───────────┐ ┌───────────┐
|
||||
│ Web Tools │ │ Terminal │ │ Browser │
|
||||
│ (Firecrawl)│ │ (mini-swe)│ │(agent-brw)│
|
||||
└───────────┘ └───────────┘ └───────────┘
|
||||
│ │ │
|
||||
└─────────────────┼─────────────────┘
|
||||
▼
|
||||
┌───────────────┐
|
||||
│ Toolsets │
|
||||
│ (toolsets.py)│
|
||||
│ Composition │
|
||||
└───────────────┘
|
||||
```
|
||||
|
||||
## Key Design Patterns
|
||||
|
||||
### 1. Toolset Composition Pattern
|
||||
Toolsets can include other toolsets, allowing flexible composition:
|
||||
|
||||
```python
|
||||
TOOLSETS = {
|
||||
"web": {"tools": ["web_search", "web_extract"], "includes": []},
|
||||
"debugging": {"tools": ["terminal"], "includes": ["web"]},
|
||||
"full_stack": {"tools": [], "includes": ["web", "terminal", "vision", "browser"]}
|
||||
}
|
||||
```
|
||||
|
||||
Resolution is recursive with cycle detection.
|
||||
|
||||
### 2. Graceful Degradation Pattern
|
||||
Each tool module has a `check_*_requirements()` function:
|
||||
- Tools are only loaded if requirements are met
|
||||
- Missing API keys disable tools, not crash the system
|
||||
- Import errors are caught and tools marked unavailable
|
||||
|
||||
```python
|
||||
try:
|
||||
from tools.web_tools import web_search_tool, check_firecrawl_api_key
|
||||
except ModuleNotFoundError:
|
||||
web_search_tool = None
|
||||
def check_firecrawl_api_key(): return False
|
||||
```
|
||||
|
||||
### 3. Session Isolation Pattern (task_id)
|
||||
Stateful tools (terminal, browser) use `task_id` to isolate concurrent sessions:
|
||||
- Each batch worker gets unique task_id
|
||||
- VMs and browser sessions are tracked per task_id
|
||||
- Cleanup functions release resources: `cleanup_vm(task_id)`, `cleanup_browser(task_id)`
|
||||
|
||||
### 4. Trajectory Format Pattern
|
||||
Conversations are saved in ShareGPT format for training:
|
||||
|
||||
```json
|
||||
{"from": "system", "value": "System prompt with <tools>...</tools>"}
|
||||
{"from": "human", "value": "User message"}
|
||||
{"from": "gpt", "value": "<think>reasoning</think>\n<tool_call>{...}</tool_call>"}
|
||||
{"from": "tool", "value": "<tool_response>{...}</tool_response>"}
|
||||
{"from": "gpt", "value": "Final response"}
|
||||
```
|
||||
|
||||
### 5. Ephemeral System Prompt Pattern
|
||||
Guide model behavior during data collection without saving to trajectories:
|
||||
- `ephemeral_system_prompt` influences execution
|
||||
- Only standard tool-calling system prompt saved to trajectories
|
||||
- Keeps training data clean
|
||||
|
||||
### 6. Retry with Validation Pattern
|
||||
The agent validates responses before accepting:
|
||||
- Check tool names against `valid_tool_names` set
|
||||
- Validate JSON arguments can be parsed
|
||||
- Check for content after `<think>` blocks
|
||||
- Roll back to last valid state on persistent failures
|
||||
|
||||
## Component Relationships
|
||||
|
||||
### AIAgent Class
|
||||
- Central orchestrator for conversations
|
||||
- Manages conversation history
|
||||
- Calls OpenAI-compatible API
|
||||
- Routes tool calls to handlers
|
||||
- Provides animated feedback (KawaiiSpinner)
|
||||
|
||||
### Tool Modules (tools/*.py)
|
||||
- Self-contained tool implementations
|
||||
- Export: handler function + check function + schema
|
||||
- Return JSON strings (never raw dicts)
|
||||
- Accept optional `task_id` for stateful tools
|
||||
|
||||
### Toolsets System (toolsets.py)
|
||||
- Defines logical groupings of tools
|
||||
- Supports composition via `includes`
|
||||
- `resolve_toolset()` recursively resolves all tools
|
||||
- `validate_toolset()` checks if name is valid
|
||||
|
||||
### Model Tools (model_tools.py)
|
||||
- Aggregates all tool definitions
|
||||
- Routes function calls to correct handlers
|
||||
- Filters tools based on enabled/disabled toolsets
|
||||
- Bridge between agent and tool implementations
|
||||
|
||||
## Critical Implementation Paths
|
||||
|
||||
### Tool Execution Flow
|
||||
1. AIAgent receives tool_calls from API response
|
||||
2. Validates tool names against `valid_tool_names`
|
||||
3. Validates JSON arguments can be parsed
|
||||
4. Calls `handle_function_call()` with tool name, args, task_id
|
||||
5. `handle_function_call()` routes to appropriate handler
|
||||
6. Tool executes, returns JSON string
|
||||
7. Result added to conversation as tool message
|
||||
8. Loop continues until natural language response
|
||||
|
||||
### Configuration Loading Flow
|
||||
1. `cli.py` calls `load_cli_config()`
|
||||
2. Loads `cli-config.yaml`, merges with defaults
|
||||
3. Sets environment variables for terminal config
|
||||
4. `AIAgent` reads env vars when initializing terminal tool
|
||||
5. Terminal tool creates appropriate backend based on `TERMINAL_ENV`
|
||||
|
||||
## Atropos Backend Architecture
|
||||
|
||||
### Backend Hierarchy
|
||||
```
|
||||
ToolBackend (Protocol - base.py)
|
||||
├── NomadToolBackend → SlotPool → NomadClient + SandboxExecutor (HTTP)
|
||||
│ ├── Docker driver (default)
|
||||
│ └── Singularity driver (HPC)
|
||||
└── ModalToolBackend → _ModalSandboxPool → modal.Sandbox.exec() (direct)
|
||||
└── _ModalMultiProfileManager (multi-profile support)
|
||||
```
|
||||
|
||||
### Slot-Based Multiplexing Pattern
|
||||
All backends share the same slot multiplexing concept:
|
||||
- **Sandbox/Container**: Long-lived compute unit
|
||||
- **Slot**: Isolated workspace directory within a sandbox (e.g., `/data/slot_0`)
|
||||
- **Trajectory**: One agent task using one slot
|
||||
- Multiple trajectories share a sandbox via different slots
|
||||
|
||||
### Nomad Backend (HTTP-based)
|
||||
- Deploys `sandbox_server.py` inside containers (Docker or Singularity)
|
||||
- Uses `SandboxExecutor` for HTTP communication (POST /execute, POST /batch)
|
||||
- Nomad manages container lifecycle (scaling, health checks)
|
||||
- Tools: bash, bash_stateful, read_file, write_file, tmux
|
||||
|
||||
### Modal Backend (exec-based)
|
||||
- Creates `modal.Sandbox` instances (long-lived containers)
|
||||
- Uses `sandbox.exec("bash", "-c", command)` directly (no HTTP server)
|
||||
- Modal manages container lifecycle (idle_timeout, max_lifetime)
|
||||
- Multi-profile support: different resource configs (CPU, GPU, memory)
|
||||
- Named sandboxes for recovery: `Sandbox.from_name(app_name, sandbox_name)`
|
||||
- YAML config via `modal_profiles.yaml`
|
||||
|
||||
### Backend Selection
|
||||
```python
|
||||
# In agent_env.py / create_tool_backend()
|
||||
if mode == "nomad":
|
||||
return NomadToolBackend(NomadBackendConfig.from_agent_env_config(cfg))
|
||||
if mode == "modal":
|
||||
return ModalToolBackend(ModalSandboxConfig.from_agent_env_config(cfg))
|
||||
```
|
||||
@@ -0,0 +1,113 @@
|
||||
# Technical Context: Hermes-Agent
|
||||
|
||||
## Technologies Used
|
||||
|
||||
### Core Stack
|
||||
- **Python 3.11+** - Primary language
|
||||
- **OpenAI SDK** - For LLM API interactions (OpenAI-compatible)
|
||||
- **OpenRouter** - Default LLM provider (supports multiple models)
|
||||
- **Rich** - Terminal formatting and panels
|
||||
- **prompt_toolkit** - Interactive input with history
|
||||
- **Fire** - CLI argument parsing
|
||||
- **PyYAML** - Configuration files
|
||||
- **python-dotenv** - Environment variable management
|
||||
|
||||
### Tool Dependencies
|
||||
- **Firecrawl** - Web search and extraction (`FIRECRAWL_API_KEY`)
|
||||
- **mini-swe-agent** - Terminal tool backend (local/docker/singularity/modal/ssh)
|
||||
- **agent-browser** - Browser automation (npm package)
|
||||
- **Browserbase** - Cloud browser execution (`BROWSERBASE_API_KEY`)
|
||||
- **FAL.ai** - Image generation with FLUX (`FAL_KEY`)
|
||||
- **Nous API** - Vision and MoA tools (`NOUS_API_KEY`)
|
||||
|
||||
### Optional Dependencies
|
||||
- **Modal** - Cloud compute for sandboxed environments
|
||||
- **Singularity/Apptainer** - Rootless containers (HPC environments)
|
||||
- **Docker** - Container isolation
|
||||
|
||||
## Development Setup
|
||||
|
||||
### Quick Start
|
||||
```bash
|
||||
# Clone with submodules
|
||||
git clone --recurse-submodules https://github.com/NousResearch/Hermes-Agent.git
|
||||
cd Hermes-Agent
|
||||
|
||||
# Create virtual environment
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
pip install -e ./mini-swe-agent
|
||||
|
||||
# Install browser tools (optional)
|
||||
npm install
|
||||
|
||||
# Configure environment
|
||||
cp .env.example .env
|
||||
# Edit .env with your API keys
|
||||
```
|
||||
|
||||
### Key Configuration Files
|
||||
- `.env` - API keys and secrets
|
||||
- `cli-config.yaml` - CLI configuration (model, terminal, toolsets, personalities)
|
||||
- `configs/` - Batch run scripts and configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
**Required for Full Functionality:**
|
||||
- `OPENROUTER_API_KEY` - Primary LLM access
|
||||
- `FIRECRAWL_API_KEY` - Web tools
|
||||
- `NOUS_API_KEY` - Vision and reasoning tools
|
||||
- `FAL_KEY` - Image generation
|
||||
|
||||
**Terminal Backend:**
|
||||
- `TERMINAL_ENV` - Backend type: `local`, `docker`, `singularity`, `modal`, `ssh`
|
||||
- `TERMINAL_CWD` - Working directory
|
||||
- `TERMINAL_DOCKER_IMAGE` / `TERMINAL_SINGULARITY_IMAGE` - Container images
|
||||
- `TERMINAL_SSH_HOST/USER/KEY` - SSH backend config
|
||||
- `SUDO_PASSWORD` - Optional sudo support
|
||||
|
||||
**Browser:**
|
||||
- `BROWSERBASE_API_KEY` - Browser automation
|
||||
- `BROWSERBASE_PROJECT_ID` - Browserbase project
|
||||
|
||||
## Technical Constraints
|
||||
|
||||
1. **Context Window Limits** - Long tool outputs can exhaust context; trajectory compression helps
|
||||
2. **API Rate Limits** - OpenRouter and tool APIs have rate limits; exponential backoff implemented
|
||||
3. **Tool Availability** - Tools gracefully degrade if dependencies/keys missing
|
||||
4. **Async Compatibility** - Some tools are async, handled via `asyncio.run()` in sync context
|
||||
|
||||
## Dependency Graph
|
||||
|
||||
```
|
||||
tools/*.py → tools/__init__.py → model_tools.py → toolsets.py → toolset_distributions.py
|
||||
↑
|
||||
run_agent.py ──────────────────────────┘
|
||||
cli.py → run_agent.py (uses AIAgent with quiet_mode=True)
|
||||
batch_runner.py → run_agent.py + toolset_distributions.py
|
||||
```
|
||||
|
||||
## Tool Usage Patterns
|
||||
|
||||
### Adding a New Tool
|
||||
1. Create `tools/your_tool.py` with handler + requirements check
|
||||
2. Export in `tools/__init__.py`
|
||||
3. Register in `model_tools.py` (definitions + handler routing)
|
||||
4. Add to toolset in `toolsets.py`
|
||||
5. Optionally add to `toolset_distributions.py` for batch processing
|
||||
|
||||
### Tool Handler Pattern
|
||||
```python
|
||||
def your_tool(param: str, task_id: str = None) -> str:
|
||||
"""Execute tool and return JSON string result."""
|
||||
try:
|
||||
result = {"success": True, "data": "..."}
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
```
|
||||
|
||||
All tool handlers MUST return a JSON string, never raw dicts.
|
||||
+1
-1
Submodule mini-swe-agent updated: 07aa6a7385...9ddd61b62d
@@ -0,0 +1,134 @@
|
||||
# Modal Sandbox Profiles Configuration
|
||||
# =====================================
|
||||
# This file defines different sandbox profiles for heterogeneous workloads.
|
||||
# Copy to modal_profiles.yaml and customize as needed.
|
||||
#
|
||||
# Usage:
|
||||
# terminal_tool("python train.py", profile="pytorch-gpu")
|
||||
# terminal_tool("npm test", profile="node")
|
||||
#
|
||||
# Each profile can specify:
|
||||
# - image: Docker image to use
|
||||
# - gpu: GPU type (null, "T4", "A10G", "A100", "H100")
|
||||
# - cpu: CPU cores (float)
|
||||
# - memory: Memory in MB
|
||||
# - min_pool: Minimum warm sandboxes (cost vs latency tradeoff)
|
||||
# - max_pool: Maximum sandboxes (hard cost cap)
|
||||
# - idle_timeout: Server-side auto-cleanup in seconds
|
||||
# - max_lifetime: Maximum sandbox lifetime in seconds
|
||||
# - scale_down_idle: Client-side scale-down threshold in seconds
|
||||
# - workdir: Working directory inside container
|
||||
# - secrets: List of Modal Secret names to inject (created via dashboard/CLI)
|
||||
# - env_vars: Dict of environment variables to pass directly
|
||||
# - use_dotenv: If true, loads local .env file into sandbox
|
||||
#
|
||||
# SECRETS SETUP:
|
||||
# Create secrets via Modal dashboard or CLI:
|
||||
# modal secret create huggingface-token HF_TOKEN=hf_xxx
|
||||
# modal secret create openai-key OPENAI_API_KEY=sk-xxx
|
||||
# Then reference by name in profile's secrets list.
|
||||
|
||||
# Default profile used when no profile specified
|
||||
default_profile: default
|
||||
|
||||
profiles:
|
||||
# Default Python environment - good for most tasks
|
||||
default:
|
||||
image: python:3.11
|
||||
gpu: null
|
||||
cpu: 1.0
|
||||
memory: 2048
|
||||
min_pool: 1 # Keep 1 warm for fast response
|
||||
max_pool: 5
|
||||
idle_timeout: 120 # Modal terminates if idle 2 min
|
||||
max_lifetime: 3600 # Max 1 hour
|
||||
scale_down_idle: 180
|
||||
workdir: /workspace
|
||||
secrets: [] # Add secret names here: ["my-api-keys"]
|
||||
env_vars: {} # Add env vars here: {DEBUG: "1"}
|
||||
use_dotenv: false # Set to true to load local .env
|
||||
|
||||
# PyTorch with GPU for ML training/inference
|
||||
pytorch-gpu:
|
||||
image: pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
|
||||
gpu: T4 # Options: T4, A10G, A100, H100
|
||||
cpu: 4.0
|
||||
memory: 16384 # 16GB
|
||||
min_pool: 0 # Don't keep GPU sandboxes warm (expensive!)
|
||||
max_pool: 2
|
||||
idle_timeout: 60 # Shorter idle timeout for GPU (cost)
|
||||
max_lifetime: 1800 # 30 min max for GPU tasks
|
||||
scale_down_idle: 60
|
||||
workdir: /workspace
|
||||
# ML-specific secrets
|
||||
secrets:
|
||||
- huggingface-token # HF_TOKEN env var
|
||||
- wandb-key # WANDB_API_KEY env var
|
||||
env_vars:
|
||||
CUDA_VISIBLE_DEVICES: "0"
|
||||
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
|
||||
|
||||
# High-end GPU for large models
|
||||
pytorch-a100:
|
||||
image: pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime
|
||||
gpu: A100
|
||||
cpu: 8.0
|
||||
memory: 65536 # 64GB
|
||||
min_pool: 0
|
||||
max_pool: 1 # Only 1 at a time (very expensive)
|
||||
idle_timeout: 30
|
||||
max_lifetime: 3600
|
||||
scale_down_idle: 30
|
||||
workdir: /workspace
|
||||
|
||||
# Node.js for JavaScript/TypeScript tasks
|
||||
node:
|
||||
image: node:18
|
||||
gpu: null
|
||||
cpu: 1.0
|
||||
memory: 2048
|
||||
min_pool: 0 # Create on-demand
|
||||
max_pool: 3
|
||||
idle_timeout: 120
|
||||
max_lifetime: 3600
|
||||
scale_down_idle: 180
|
||||
workdir: /workspace
|
||||
|
||||
# High memory for data processing
|
||||
high-memory:
|
||||
image: python:3.11
|
||||
gpu: null
|
||||
cpu: 4.0
|
||||
memory: 32768 # 32GB
|
||||
min_pool: 0
|
||||
max_pool: 2
|
||||
idle_timeout: 120
|
||||
max_lifetime: 3600
|
||||
scale_down_idle: 180
|
||||
workdir: /workspace
|
||||
|
||||
# Rust development environment
|
||||
rust:
|
||||
image: rust:1.75
|
||||
gpu: null
|
||||
cpu: 2.0
|
||||
memory: 4096
|
||||
min_pool: 0
|
||||
max_pool: 2
|
||||
idle_timeout: 120
|
||||
max_lifetime: 3600
|
||||
scale_down_idle: 180
|
||||
workdir: /workspace
|
||||
|
||||
# Go development environment
|
||||
golang:
|
||||
image: golang:1.21
|
||||
gpu: null
|
||||
cpu: 2.0
|
||||
memory: 4096
|
||||
min_pool: 0
|
||||
max_pool: 2
|
||||
idle_timeout: 120
|
||||
max_lifetime: 3600
|
||||
scale_down_idle: 180
|
||||
workdir: /workspace
|
||||
+35
-100
@@ -41,7 +41,7 @@ from tools.terminal_hecate import terminal_hecate_tool, check_hecate_requirement
|
||||
from tools.vision_tools import vision_analyze_tool, check_vision_requirements
|
||||
from tools.mixture_of_agents_tool import mixture_of_agents_tool, check_moa_requirements
|
||||
from tools.image_generation_tool import image_generate_tool, check_image_generation_requirements
|
||||
from tools.skills_tool import skills_list, skill_view, check_skills_requirements, SKILLS_TOOL_DESCRIPTION
|
||||
from tools.skills_tool import skills_categories, skills_list, skill_view, check_skills_requirements, SKILLS_TOOL_DESCRIPTION
|
||||
# RL Training tools (Tinker-Atropos)
|
||||
from tools.rl_training_tool import (
|
||||
rl_list_environments,
|
||||
@@ -83,8 +83,6 @@ from tools.browser_tool import (
|
||||
check_browser_requirements,
|
||||
BROWSER_TOOL_SCHEMAS
|
||||
)
|
||||
# Text-to-speech tool (Edge TTS / ElevenLabs / OpenAI)
|
||||
from tools.tts_tool import text_to_speech_tool, check_tts_requirements
|
||||
from toolsets import (
|
||||
get_toolset, resolve_toolset, resolve_multiple_toolsets,
|
||||
get_all_toolsets, get_toolset_names, validate_toolset,
|
||||
@@ -145,7 +143,7 @@ TOOLSET_REQUIREMENTS = {
|
||||
"env_vars": [], # Just needs skills directory
|
||||
"check_fn": check_skills_requirements,
|
||||
"setup_url": None,
|
||||
"tools": ["skills_list", "skill_view"],
|
||||
"tools": ["skills_categories", "skills_list", "skill_view"],
|
||||
},
|
||||
"rl": {
|
||||
"name": "RL Training (Tinker-Atropos)",
|
||||
@@ -167,13 +165,6 @@ TOOLSET_REQUIREMENTS = {
|
||||
"setup_url": None,
|
||||
"tools": ["read_file", "write_file", "patch", "search"],
|
||||
},
|
||||
"tts": {
|
||||
"name": "Text-to-Speech",
|
||||
"env_vars": [], # Edge TTS needs no key; premium providers checked at runtime
|
||||
"check_fn": check_tts_requirements,
|
||||
"setup_url": None,
|
||||
"tools": ["text_to_speech"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -401,7 +392,7 @@ def get_image_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "image_generate",
|
||||
"description": "Generate high-quality images from text prompts using FLUX 2 Pro model with automatic 2x upscaling. Creates detailed, artistic images that are automatically upscaled for hi-rez results. Returns a single upscaled image URL. Display it using markdown: ",
|
||||
"description": "Generate high-quality images from text prompts using FLUX 2 Pro model with automatic 2x upscaling. Creates detailed, artistic images that are automatically upscaled for hi-rez results. Returns a single upscaled image URL that can be displayed using <img src=\"{URL}\"></img> tags.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -441,7 +432,24 @@ def get_skills_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"properties": {
|
||||
"category": {
|
||||
"type": "string",
|
||||
"description": "Optional category filter to narrow results"
|
||||
"description": "Optional category filter (from skills_categories)"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "skills_categories",
|
||||
"description": "List available skill categories. Call this first to discover what skill categories exist, then use skills_list(category) to see skills in a category.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"verbose": {
|
||||
"type": "boolean",
|
||||
"description": "If true, include skill counts per category. Default: false."
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
@@ -871,38 +879,6 @@ def get_file_tool_definitions() -> List[Dict[str, Any]]:
|
||||
]
|
||||
|
||||
|
||||
def get_tts_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get tool definitions for text-to-speech tools in OpenAI's expected format.
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of TTS tool definitions compatible with OpenAI API
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "text_to_speech",
|
||||
"description": "Convert text to speech audio. Returns a MEDIA: path that the platform delivers as a voice message. On Telegram it plays as a voice bubble, on Discord/WhatsApp as an audio attachment. In CLI mode, saves to ~/voice-memos/. Voice and provider are user-configured, not model-selected.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The text to convert to speech. Keep under 4000 characters."
|
||||
},
|
||||
"output_path": {
|
||||
"type": "string",
|
||||
"description": "Optional custom file path to save the audio. Defaults to ~/voice-memos/<timestamp>.mp3"
|
||||
}
|
||||
},
|
||||
"required": ["text"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def get_all_tool_names() -> List[str]:
|
||||
"""
|
||||
Get the names of all available tools across all toolsets.
|
||||
@@ -934,7 +910,7 @@ def get_all_tool_names() -> List[str]:
|
||||
|
||||
# Skills tools
|
||||
if check_skills_requirements():
|
||||
tool_names.extend(["skills_list", "skill_view"])
|
||||
tool_names.extend(["skills_categories", "skills_list", "skill_view"])
|
||||
|
||||
# Browser automation tools
|
||||
if check_browser_requirements():
|
||||
@@ -967,10 +943,6 @@ def get_all_tool_names() -> List[str]:
|
||||
"read_file", "write_file", "patch", "search"
|
||||
])
|
||||
|
||||
# Text-to-speech tools
|
||||
if check_tts_requirements():
|
||||
tool_names.extend(["text_to_speech"])
|
||||
|
||||
return tool_names
|
||||
|
||||
|
||||
@@ -985,6 +957,7 @@ TOOL_TO_TOOLSET_MAP = {
|
||||
"mixture_of_agents": "moa_tools",
|
||||
"image_generate": "image_tools",
|
||||
# Skills tools
|
||||
"skills_categories": "skills_tools",
|
||||
"skills_list": "skills_tools",
|
||||
"skill_view": "skills_tools",
|
||||
# Browser automation tools
|
||||
@@ -1012,8 +985,6 @@ TOOL_TO_TOOLSET_MAP = {
|
||||
"rl_stop_training": "rl_tools",
|
||||
"rl_get_results": "rl_tools",
|
||||
"rl_list_runs": "rl_tools",
|
||||
# Text-to-speech tools
|
||||
"text_to_speech": "tts_tools",
|
||||
# File manipulation tools
|
||||
"read_file": "file_tools",
|
||||
"write_file": "file_tools",
|
||||
@@ -1117,11 +1088,6 @@ def get_tool_definitions(
|
||||
for tool in get_file_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
# Text-to-speech tools
|
||||
if check_tts_requirements():
|
||||
for tool in get_tts_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
# Determine which tools to include based on toolsets
|
||||
tools_to_include = set()
|
||||
|
||||
@@ -1143,7 +1109,7 @@ def get_tool_definitions(
|
||||
"vision_tools": ["vision_analyze"],
|
||||
"moa_tools": ["mixture_of_agents"],
|
||||
"image_tools": ["image_generate"],
|
||||
"skills_tools": ["skills_list", "skill_view"],
|
||||
"skills_tools": ["skills_categories", "skills_list", "skill_view"],
|
||||
"browser_tools": [
|
||||
"browser_navigate", "browser_snapshot", "browser_click",
|
||||
"browser_type", "browser_scroll", "browser_back",
|
||||
@@ -1158,8 +1124,7 @@ def get_tool_definitions(
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_list_runs", "rl_test_inference"
|
||||
],
|
||||
"file_tools": ["read_file", "write_file", "patch", "search"],
|
||||
"tts_tools": ["text_to_speech"]
|
||||
"file_tools": ["read_file", "write_file", "patch", "search"]
|
||||
}
|
||||
legacy_tools = legacy_map.get(toolset_name, [])
|
||||
tools_to_include.update(legacy_tools)
|
||||
@@ -1197,7 +1162,7 @@ def get_tool_definitions(
|
||||
"vision_tools": ["vision_analyze"],
|
||||
"moa_tools": ["mixture_of_agents"],
|
||||
"image_tools": ["image_generate"],
|
||||
"skills_tools": ["skills_list", "skill_view"],
|
||||
"skills_tools": ["skills_categories", "skills_list", "skill_view"],
|
||||
"browser_tools": [
|
||||
"browser_navigate", "browser_snapshot", "browser_click",
|
||||
"browser_type", "browser_scroll", "browser_back",
|
||||
@@ -1212,8 +1177,7 @@ def get_tool_definitions(
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_list_runs", "rl_test_inference"
|
||||
],
|
||||
"file_tools": ["read_file", "write_file", "patch", "search"],
|
||||
"tts_tools": ["text_to_speech"]
|
||||
"file_tools": ["read_file", "write_file", "patch", "search"]
|
||||
}
|
||||
legacy_tools = legacy_map.get(toolset_name, [])
|
||||
tools_to_include.difference_update(legacy_tools)
|
||||
@@ -1427,7 +1391,11 @@ def handle_skills_function_call(function_name: str, function_args: Dict[str, Any
|
||||
Returns:
|
||||
str: Function result as JSON string
|
||||
"""
|
||||
if function_name == "skills_list":
|
||||
if function_name == "skills_categories":
|
||||
verbose = function_args.get("verbose", False)
|
||||
return skills_categories(verbose=verbose)
|
||||
|
||||
elif function_name == "skills_list":
|
||||
category = function_args.get("category")
|
||||
return skills_list(category=category)
|
||||
|
||||
@@ -1671,28 +1639,6 @@ def handle_file_function_call(
|
||||
return json.dumps({"error": f"Unknown file function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_tts_function_call(
|
||||
function_name: str,
|
||||
function_args: Dict[str, Any]
|
||||
) -> str:
|
||||
"""
|
||||
Handle function calls for text-to-speech tools.
|
||||
|
||||
Args:
|
||||
function_name (str): Name of the TTS function to call
|
||||
function_args (Dict): Arguments for the function
|
||||
|
||||
Returns:
|
||||
str: Function result as JSON string
|
||||
"""
|
||||
if function_name == "text_to_speech":
|
||||
text = function_args.get("text", "")
|
||||
output_path = function_args.get("output_path")
|
||||
return text_to_speech_tool(text=text, output_path=output_path)
|
||||
|
||||
return json.dumps({"error": f"Unknown TTS function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_function_call(
|
||||
function_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
@@ -1740,7 +1686,7 @@ def handle_function_call(
|
||||
return handle_image_function_call(function_name, function_args)
|
||||
|
||||
# Route skills tools
|
||||
elif function_name in ["skills_list", "skill_view"]:
|
||||
elif function_name in ["skills_categories", "skills_list", "skill_view"]:
|
||||
return handle_skills_function_call(function_name, function_args)
|
||||
|
||||
# Route browser automation tools
|
||||
@@ -1770,10 +1716,6 @@ def handle_function_call(
|
||||
elif function_name in ["read_file", "write_file", "patch", "search"]:
|
||||
return handle_file_function_call(function_name, function_args, task_id)
|
||||
|
||||
# Route text-to-speech tools
|
||||
elif function_name in ["text_to_speech"]:
|
||||
return handle_tts_function_call(function_name, function_args)
|
||||
|
||||
else:
|
||||
error_msg = f"Unknown function: {function_name}"
|
||||
print(f"❌ {error_msg}")
|
||||
@@ -1825,7 +1767,7 @@ def get_available_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||
},
|
||||
"skills_tools": {
|
||||
"available": check_skills_requirements(),
|
||||
"tools": ["skills_list", "skill_view"],
|
||||
"tools": ["skills_categories", "skills_list", "skill_view"],
|
||||
"description": "Access skill documents that provide specialized instructions, guidelines, or knowledge the agent can load on demand",
|
||||
"requirements": ["skills/ directory in repo root"]
|
||||
},
|
||||
@@ -1851,12 +1793,6 @@ def get_available_toolsets() -> Dict[str, Dict[str, Any]]:
|
||||
"tools": ["read_file", "write_file", "patch", "search"],
|
||||
"description": "File manipulation tools: read/write files, search content/files, patch with fuzzy matching",
|
||||
"requirements": ["Terminal backend available (local/docker/ssh/singularity/modal)"]
|
||||
},
|
||||
"tts_tools": {
|
||||
"available": check_tts_requirements(),
|
||||
"tools": ["text_to_speech"],
|
||||
"description": "Text-to-speech: convert text to audio (Edge TTS free, ElevenLabs, OpenAI)",
|
||||
"requirements": ["edge-tts package (free) or ELEVENLABS_API_KEY or OPENAI_API_KEY"]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1878,8 +1814,7 @@ def check_toolset_requirements() -> Dict[str, bool]:
|
||||
"skills_tools": check_skills_requirements(),
|
||||
"browser_tools": check_browser_requirements(),
|
||||
"cronjob_tools": check_cronjob_requirements(),
|
||||
"file_tools": check_file_requirements(),
|
||||
"tts_tools": check_tts_requirements()
|
||||
"file_tools": check_file_requirements()
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
# Nomad Development Configuration (Hermes-Agent)
|
||||
# Run with: nomad agent -dev -config=nomad-dev.hcl
|
||||
#
|
||||
# This is intended for local development only.
|
||||
|
||||
client {
|
||||
enabled = true
|
||||
|
||||
options {
|
||||
# Enable Docker volume mounts for persistent slot workspaces
|
||||
"docker.volumes.enabled" = "true"
|
||||
}
|
||||
}
|
||||
|
||||
# Docker driver plugin configuration
|
||||
plugin "docker" {
|
||||
config {
|
||||
# CRITICAL: Enable volume mounts
|
||||
volumes {
|
||||
enabled = true
|
||||
}
|
||||
|
||||
# Allow privileged containers if needed
|
||||
allow_privileged = false
|
||||
|
||||
# Garbage collection settings
|
||||
gc {
|
||||
image = true
|
||||
# NOTE: For local dev we often rely on locally built images like `atropos-sandbox:local`.
|
||||
# A short image GC delay can delete these between runs, causing confusing "Failed to pull"
|
||||
# crash loops. Keep this comfortably long; tighten it for CI/production if needed.
|
||||
image_delay = "24h"
|
||||
container = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
# Nomad Configuration for Singularity/Apptainer Sandbox
|
||||
# Run with: nomad agent -dev -config=nomad-singularity.hcl
|
||||
#
|
||||
# This uses the raw_exec driver to run Apptainer containers.
|
||||
# Suitable for HPC environments where Docker cannot run without sudo.
|
||||
|
||||
client {
|
||||
enabled = true
|
||||
|
||||
options {
|
||||
# Enable raw_exec driver for Singularity/Apptainer
|
||||
"driver.raw_exec.enable" = "1"
|
||||
}
|
||||
}
|
||||
|
||||
# raw_exec driver plugin configuration
|
||||
plugin "raw_exec" {
|
||||
config {
|
||||
enabled = true
|
||||
}
|
||||
}
|
||||
|
||||
# Optional: If you have the nomad-driver-singularity plugin installed,
|
||||
# uncomment the following instead of using raw_exec:
|
||||
# plugin "singularity" {
|
||||
# config {
|
||||
# enabled = true
|
||||
# # Allow bind mounts
|
||||
# bind_paths = ["/tmp", "/var/tmp"]
|
||||
# }
|
||||
# }
|
||||
+28
-2
@@ -19,6 +19,7 @@ dependencies = [
|
||||
"rich",
|
||||
"tenacity",
|
||||
"pyyaml",
|
||||
"prompt_toolkit",
|
||||
"requests",
|
||||
"jinja2",
|
||||
"pydantic>=2.0",
|
||||
@@ -39,6 +40,19 @@ dev = ["pytest", "pytest-asyncio"]
|
||||
messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0", "aiohttp>=3.9.0"]
|
||||
cron = ["croniter"]
|
||||
cli = ["simple-term-menu"]
|
||||
# Install Atropos + Tinker training integration from source.
|
||||
atropos = [
|
||||
"atroposlib @ git+https://github.com/NousResearch/atropos.git",
|
||||
"tinker @ git+https://github.com/thinking-machines-lab/tinker.git",
|
||||
# Atropos integration runtime deps (kept optional for Hermes-only users)
|
||||
"aiohttp",
|
||||
"fastapi",
|
||||
"uvicorn",
|
||||
"pyte",
|
||||
"torch",
|
||||
"wandb",
|
||||
"math-verify",
|
||||
]
|
||||
all = [
|
||||
"hermes-agent[modal]",
|
||||
"hermes-agent[messaging]",
|
||||
@@ -50,9 +64,21 @@ all = [
|
||||
[project.scripts]
|
||||
hermes = "hermes_cli.main:main"
|
||||
hermes-agent = "run_agent:main"
|
||||
hermes-atropos-sandbox-smoke = "atropos.envs.sandbox_terminal_smoke_env:SandboxTerminalSmokeEnv.cli"
|
||||
hermes-atropos-toolserver-smoke = "atropos.envs.toolserver_smoke_env:ToolServerSmokeEnv.cli"
|
||||
|
||||
[tool.setuptools]
|
||||
py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajectory_compressor", "toolset_distributions", "cli"]
|
||||
py-modules = [
|
||||
"run_agent",
|
||||
"model_tools",
|
||||
"toolsets",
|
||||
"batch_runner",
|
||||
"trajectory_compressor",
|
||||
"toolset_distributions",
|
||||
"atropos_compatible_agent",
|
||||
"local_server",
|
||||
"cli",
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["tools", "hermes_cli", "gateway", "cron"]
|
||||
include = ["tools", "hermes_cli", "gateway", "cron", "atropos", "atropos.*"]
|
||||
|
||||
@@ -29,12 +29,6 @@ platformdirs
|
||||
# Optional: For Modal backend (cloud execution)
|
||||
# swe-rex[modal]>=1.4.0 # Includes modal + boto3 + swe-rex runtime
|
||||
|
||||
# Text-to-speech (Edge TTS is free, no API key needed)
|
||||
edge-tts
|
||||
|
||||
# Optional: Premium TTS providers
|
||||
# elevenlabs # Uncomment if using ElevenLabs TTS (needs ELEVENLABS_API_KEY)
|
||||
|
||||
# Optional: For cron expression parsing (cronjob scheduling)
|
||||
croniter
|
||||
|
||||
|
||||
+73
-571
@@ -20,7 +20,6 @@ Usage:
|
||||
response = agent.run_conversation("Tell me about the latest Python updates")
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -31,7 +30,6 @@ import threading
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from openai import OpenAI
|
||||
import fire
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
@@ -49,46 +47,11 @@ elif not os.getenv("HERMES_QUIET"):
|
||||
|
||||
# Import our tool system
|
||||
from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements
|
||||
from tools.terminal_tool import cleanup_vm, set_interrupt_event as _set_terminal_interrupt
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
from tools.browser_tool import cleanup_browser
|
||||
|
||||
import requests
|
||||
|
||||
# =============================================================================
|
||||
# Default Agent Identity & Platform Hints
|
||||
# =============================================================================
|
||||
|
||||
# The default identity prompt is prepended to every conversation so the agent
|
||||
# knows who it is and behaves consistently across platforms.
|
||||
DEFAULT_AGENT_IDENTITY = (
|
||||
"You are Hermes Agent, an intelligent AI assistant created by Nous Research. "
|
||||
"You are helpful, knowledgeable, and direct. You assist users with a wide "
|
||||
"range of tasks including answering questions, writing and editing code, "
|
||||
"analyzing information, creative work, and executing actions via your tools. "
|
||||
"You communicate clearly, admit uncertainty when appropriate, and prioritize "
|
||||
"being genuinely useful over being verbose unless otherwise directed below."
|
||||
)
|
||||
|
||||
# Platform-specific formatting hints appended to the system prompt.
|
||||
# These tell the agent how to format its output for the current interface.
|
||||
PLATFORM_HINTS = {
|
||||
"whatsapp": (
|
||||
"You are on a text messaging communication platform, WhatsApp. "
|
||||
"Please do not use markdown as it does not render."
|
||||
),
|
||||
"telegram": (
|
||||
"You are on a text messaging communication platform, Telegram. "
|
||||
"Please do not use markdown as it does not render."
|
||||
),
|
||||
"discord": (
|
||||
"You are in a Discord server or group chat communicating with your user."
|
||||
),
|
||||
"cli": (
|
||||
"You are a CLI AI Agent. Try not to use markdown but simple text "
|
||||
"renderable inside a terminal."
|
||||
),
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# Model Context Management
|
||||
# =============================================================================
|
||||
@@ -493,389 +456,18 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
return compressed
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Anthropic Prompt Caching (system_and_3 strategy)
|
||||
# =============================================================================
|
||||
# Reduces input token costs by ~75% on multi-turn conversations by caching
|
||||
# the conversation prefix. Uses 4 cache_control breakpoints (Anthropic max):
|
||||
# 1. System prompt (stable across all turns)
|
||||
# 2-4. Last 3 non-system messages (rolling window)
|
||||
#
|
||||
# Cached tokens are read at 0.1x input price. Cache writes cost 1.25x (5m TTL)
|
||||
# or 2x (1h TTL). Only applied to Claude models via OpenRouter.
|
||||
|
||||
def _apply_cache_marker(msg: dict, cache_marker: dict) -> None:
|
||||
"""
|
||||
Add cache_control to a single message, handling all format variations.
|
||||
|
||||
- tool messages: cache_control at message level (Anthropic API quirk)
|
||||
- string content: converted to multipart content array
|
||||
- list content: marker added to last item
|
||||
- None content (assistant with tool_calls): message level
|
||||
"""
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content")
|
||||
|
||||
if role == "tool":
|
||||
msg["cache_control"] = cache_marker
|
||||
return
|
||||
|
||||
if content is None:
|
||||
msg["cache_control"] = cache_marker
|
||||
return
|
||||
|
||||
if isinstance(content, str):
|
||||
msg["content"] = [{"type": "text", "text": content, "cache_control": cache_marker}]
|
||||
return
|
||||
|
||||
if isinstance(content, list) and content:
|
||||
last = content[-1]
|
||||
if isinstance(last, dict):
|
||||
last["cache_control"] = cache_marker
|
||||
|
||||
|
||||
def apply_anthropic_cache_control(
|
||||
api_messages: List[Dict[str, Any]],
|
||||
cache_ttl: str = "5m",
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Apply system_and_3 caching strategy to messages for Anthropic models.
|
||||
|
||||
Places up to 4 cache_control breakpoints:
|
||||
1. System prompt (index 0, stable across all turns)
|
||||
2-4. Last 3 non-system messages (rolling cache frontier)
|
||||
|
||||
Each breakpoint tells Anthropic "cache everything from the start up to here."
|
||||
Multiple breakpoints create a ladder of cached prefixes at different depths,
|
||||
which provides robust cache hits even when the most recent cache entry hasn't
|
||||
propagated yet.
|
||||
|
||||
Args:
|
||||
api_messages: Fully assembled message list (system prompt first).
|
||||
cache_ttl: "5m" (default, 1.25x write cost) or "1h" (2x write cost).
|
||||
|
||||
Returns:
|
||||
Deep copy of messages with cache_control breakpoints injected.
|
||||
"""
|
||||
messages = copy.deepcopy(api_messages)
|
||||
if not messages:
|
||||
return messages
|
||||
|
||||
marker = {"type": "ephemeral"}
|
||||
if cache_ttl == "1h":
|
||||
marker["ttl"] = "1h"
|
||||
|
||||
breakpoints_used = 0
|
||||
|
||||
# Breakpoint 1: System prompt (always stable, gives a guaranteed minimum hit)
|
||||
if messages[0].get("role") == "system":
|
||||
_apply_cache_marker(messages[0], marker)
|
||||
breakpoints_used += 1
|
||||
|
||||
# Breakpoints 2-4: Last 3 non-system messages (rolling window)
|
||||
remaining = 4 - breakpoints_used
|
||||
non_sys = [i for i in range(len(messages)) if messages[i].get("role") != "system"]
|
||||
for idx in non_sys[-remaining:]:
|
||||
_apply_cache_marker(messages[idx], marker)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Default System Prompt Components
|
||||
# =============================================================================
|
||||
|
||||
# Skills guidance - embeds a compact skill index in the system prompt so
|
||||
# the model can match skills at a glance without extra tool calls.
|
||||
def build_skills_system_prompt() -> str:
|
||||
"""
|
||||
Build a dynamic skills system prompt by scanning the skills/ directory.
|
||||
|
||||
Returns a prompt section that lists all skill categories (with descriptions
|
||||
from DESCRIPTION.md) and their skill names inline, so the model can
|
||||
immediately see if a relevant skill exists and load it with a single
|
||||
skill_view(name) call -- no discovery tool calls needed.
|
||||
|
||||
Returns:
|
||||
str: The skills system prompt section, or empty string if no skills found.
|
||||
"""
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
skills_dir = Path(__file__).parent / "skills"
|
||||
if not skills_dir.exists():
|
||||
return ""
|
||||
|
||||
# Scan for SKILL.md files grouped by category
|
||||
skills_by_category = {}
|
||||
for skill_file in skills_dir.rglob("SKILL.md"):
|
||||
rel_path = skill_file.relative_to(skills_dir)
|
||||
parts = rel_path.parts
|
||||
if len(parts) >= 2:
|
||||
category = parts[0]
|
||||
skill_name = parts[-2] # Folder containing SKILL.md
|
||||
else:
|
||||
category = "general"
|
||||
skill_name = skill_file.parent.name
|
||||
skills_by_category.setdefault(category, []).append(skill_name)
|
||||
|
||||
if not skills_by_category:
|
||||
return ""
|
||||
|
||||
# Load category descriptions from DESCRIPTION.md files (YAML frontmatter)
|
||||
category_descriptions = {}
|
||||
for category in skills_by_category:
|
||||
desc_file = skills_dir / category / "DESCRIPTION.md"
|
||||
if desc_file.exists():
|
||||
try:
|
||||
content = desc_file.read_text(encoding="utf-8")
|
||||
# Parse description from YAML frontmatter: ---\ndescription: ...\n---
|
||||
match = re.search(r"^---\s*\n.*?description:\s*(.+?)\s*\n.*?^---", content, re.MULTILINE | re.DOTALL)
|
||||
if match:
|
||||
category_descriptions[category] = match.group(1).strip()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Build compact index: category with description + skill names
|
||||
index_lines = []
|
||||
for category in sorted(skills_by_category.keys()):
|
||||
desc = category_descriptions.get(category, "")
|
||||
names = ", ".join(sorted(skills_by_category[category]))
|
||||
if desc:
|
||||
index_lines.append(f" {category}: {desc}")
|
||||
else:
|
||||
index_lines.append(f" {category}:")
|
||||
index_lines.append(f" skills: {names}")
|
||||
|
||||
return (
|
||||
"## Skills (mandatory)\n"
|
||||
"Before replying, scan the skills below. If one clearly matches your task, "
|
||||
"load it with skill_view(name) and follow its instructions.\n"
|
||||
"\n"
|
||||
"<available_skills>\n"
|
||||
+ "\n".join(index_lines) + "\n"
|
||||
"</available_skills>\n"
|
||||
"\n"
|
||||
"If none match, proceed normally without loading a skill."
|
||||
)
|
||||
# Skills guidance - instructs the model to check skills before technical tasks
|
||||
SKILLS_SYSTEM_PROMPT = """## Skills
|
||||
Before answering technical questions about tools, frameworks, or workflows:
|
||||
1. Check skills_categories to see if a relevant category exists
|
||||
2. If a category matches your task, use skills_list with that category
|
||||
3. If a skill matches, load it with skill_view and follow its instructions
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Context File Injection (SOUL.md, AGENTS.md, .cursorrules)
|
||||
# =============================================================================
|
||||
|
||||
# Maximum characters per context file before truncation
|
||||
CONTEXT_FILE_MAX_CHARS = 20_000
|
||||
# Truncation strategy: keep 70% from the head, 20% from the tail
|
||||
CONTEXT_TRUNCATE_HEAD_RATIO = 0.7
|
||||
CONTEXT_TRUNCATE_TAIL_RATIO = 0.2
|
||||
|
||||
|
||||
def _truncate_content(content: str, filename: str, max_chars: int = CONTEXT_FILE_MAX_CHARS) -> str:
|
||||
"""
|
||||
Truncate content if it exceeds max_chars using a head/tail strategy.
|
||||
|
||||
Keeps 70% from the start and 20% from the end, with a truncation
|
||||
marker in the middle so the model knows content was cut.
|
||||
"""
|
||||
if len(content) <= max_chars:
|
||||
return content
|
||||
|
||||
head_chars = int(max_chars * CONTEXT_TRUNCATE_HEAD_RATIO)
|
||||
tail_chars = int(max_chars * CONTEXT_TRUNCATE_TAIL_RATIO)
|
||||
head = content[:head_chars]
|
||||
tail = content[-tail_chars:]
|
||||
|
||||
marker = f"\n\n[...truncated {filename}: kept {head_chars}+{tail_chars} of {len(content)} chars. Use file tools to read the full file.]\n\n"
|
||||
return head + marker + tail
|
||||
|
||||
|
||||
def build_context_files_prompt(cwd: str = None) -> str:
|
||||
"""
|
||||
Discover and load context files (SOUL.md, AGENTS.md, .cursorrules)
|
||||
for injection into the system prompt.
|
||||
|
||||
Discovery rules:
|
||||
- AGENTS.md: Recursively search from cwd (only if top-level exists).
|
||||
Each file becomes a ## section with its relative path.
|
||||
- .cursorrules: Check cwd for .cursorrules file and .cursor/rules/*.mdc
|
||||
- SOUL.md: Check cwd first, then ~/.hermes/SOUL.md as global fallback
|
||||
|
||||
Args:
|
||||
cwd: Working directory to search from. Defaults to os.getcwd().
|
||||
|
||||
Returns:
|
||||
str: The context files prompt section, or empty string if none found.
|
||||
"""
|
||||
import os
|
||||
import glob as glob_mod
|
||||
from pathlib import Path
|
||||
|
||||
if cwd is None:
|
||||
cwd = os.getcwd()
|
||||
|
||||
cwd_path = Path(cwd).resolve()
|
||||
sections = []
|
||||
|
||||
# ----- AGENTS.md (hierarchical, recursive) -----
|
||||
top_level_agents = None
|
||||
for name in ["AGENTS.md", "agents.md"]:
|
||||
candidate = cwd_path / name
|
||||
if candidate.exists():
|
||||
top_level_agents = candidate
|
||||
break
|
||||
|
||||
if top_level_agents:
|
||||
# Recursively find all AGENTS.md files (case-insensitive)
|
||||
agents_files = []
|
||||
for root, dirs, files in os.walk(cwd_path):
|
||||
# Skip hidden directories and common non-project dirs
|
||||
dirs[:] = [d for d in dirs if not d.startswith('.') and d not in ('node_modules', '__pycache__', 'venv', '.venv')]
|
||||
for f in files:
|
||||
if f.lower() == "agents.md":
|
||||
agents_files.append(Path(root) / f)
|
||||
|
||||
# Sort by path depth (top-level first, then deeper)
|
||||
agents_files.sort(key=lambda p: len(p.parts))
|
||||
|
||||
total_agents_content = ""
|
||||
for agents_path in agents_files:
|
||||
try:
|
||||
content = agents_path.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
rel_path = agents_path.relative_to(cwd_path)
|
||||
total_agents_content += f"## {rel_path}\n\n{content}\n\n"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if total_agents_content:
|
||||
total_agents_content = _truncate_content(total_agents_content, "AGENTS.md")
|
||||
sections.append(total_agents_content)
|
||||
|
||||
# ----- .cursorrules -----
|
||||
cursorrules_content = ""
|
||||
|
||||
# Check for .cursorrules file
|
||||
cursorrules_file = cwd_path / ".cursorrules"
|
||||
if cursorrules_file.exists():
|
||||
try:
|
||||
content = cursorrules_file.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
cursorrules_content += f"## .cursorrules\n\n{content}\n\n"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check for .cursor/rules/*.mdc files
|
||||
cursor_rules_dir = cwd_path / ".cursor" / "rules"
|
||||
if cursor_rules_dir.exists() and cursor_rules_dir.is_dir():
|
||||
mdc_files = sorted(cursor_rules_dir.glob("*.mdc"))
|
||||
for mdc_file in mdc_files:
|
||||
try:
|
||||
content = mdc_file.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
cursorrules_content += f"## .cursor/rules/{mdc_file.name}\n\n{content}\n\n"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if cursorrules_content:
|
||||
cursorrules_content = _truncate_content(cursorrules_content, ".cursorrules")
|
||||
sections.append(cursorrules_content)
|
||||
|
||||
# ----- SOUL.md (cwd first, then ~/.hermes/ fallback) -----
|
||||
soul_content = ""
|
||||
soul_path = None
|
||||
|
||||
for name in ["SOUL.md", "soul.md"]:
|
||||
candidate = cwd_path / name
|
||||
if candidate.exists():
|
||||
soul_path = candidate
|
||||
break
|
||||
|
||||
if not soul_path:
|
||||
# Global fallback
|
||||
global_soul = Path.home() / ".hermes" / "SOUL.md"
|
||||
if global_soul.exists():
|
||||
soul_path = global_soul
|
||||
|
||||
if soul_path:
|
||||
try:
|
||||
content = soul_path.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
content = _truncate_content(content, "SOUL.md")
|
||||
soul_content = f"## SOUL.md\n\nIf SOUL.md is present, embody its persona and tone. Avoid stiff, generic replies; follow its guidance unless higher-priority instructions override it.\n\n{content}"
|
||||
sections.append(soul_content)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ----- Assemble -----
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
return "# Project Context\n\nThe following project context files have been loaded and should be followed:\n\n" + "\n".join(sections)
|
||||
|
||||
|
||||
def _build_tool_preview(tool_name: str, args: dict, max_len: int = 40) -> str:
|
||||
"""
|
||||
Build a short preview of a tool call's primary argument for display.
|
||||
|
||||
Returns a truncated string showing the most informative argument,
|
||||
or None if no meaningful preview is available.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool being called
|
||||
args: The tool call arguments dict
|
||||
max_len: Maximum preview length before truncation
|
||||
|
||||
Returns:
|
||||
str or None: Short preview string, or None
|
||||
"""
|
||||
# Map tool names to their primary argument key(s)
|
||||
primary_args = {
|
||||
"terminal": "command",
|
||||
"web_search": "query",
|
||||
"web_extract": "urls",
|
||||
"read_file": "path",
|
||||
"write_file": "path",
|
||||
"patch": "path",
|
||||
"search": "pattern",
|
||||
"browser_navigate": "url",
|
||||
"browser_click": "ref",
|
||||
"browser_type": "text",
|
||||
"image_generate": "prompt",
|
||||
"text_to_speech": "text",
|
||||
"vision_analyze": "question",
|
||||
"mixture_of_agents": "user_prompt",
|
||||
"skill_view": "name",
|
||||
"skills_list": "category",
|
||||
"schedule_cronjob": "name",
|
||||
}
|
||||
|
||||
key = primary_args.get(tool_name)
|
||||
if not key:
|
||||
# Try common arg names as fallback
|
||||
for fallback_key in ("query", "text", "command", "path", "name", "prompt"):
|
||||
if fallback_key in args:
|
||||
key = fallback_key
|
||||
break
|
||||
|
||||
if not key or key not in args:
|
||||
return None
|
||||
|
||||
value = args[key]
|
||||
|
||||
# Handle list values (e.g., urls)
|
||||
if isinstance(value, list):
|
||||
value = value[0] if value else ""
|
||||
|
||||
preview = str(value).strip()
|
||||
if not preview:
|
||||
return None
|
||||
|
||||
# Truncate
|
||||
if len(preview) > max_len:
|
||||
preview = preview[:max_len - 3] + "..."
|
||||
|
||||
return preview
|
||||
Skills contain vetted, up-to-date instructions for specific tools and workflows."""
|
||||
|
||||
|
||||
class KawaiiSpinner:
|
||||
@@ -1012,7 +604,6 @@ class AIAgent:
|
||||
max_tokens: int = None,
|
||||
reasoning_config: Dict[str, Any] = None,
|
||||
prefill_messages: List[Dict[str, Any]] = None,
|
||||
platform: str = None,
|
||||
):
|
||||
"""
|
||||
Initialize the AI Agent.
|
||||
@@ -1043,8 +634,6 @@ class AIAgent:
|
||||
prefill_messages (List[Dict]): Messages to prepend to conversation history as prefilled context.
|
||||
Useful for injecting a few-shot example or priming the model's response style.
|
||||
Example: [{"role": "user", "content": "Hi!"}, {"role": "assistant", "content": "Hello!"}]
|
||||
platform (str): The interface platform the user is on (e.g. "cli", "telegram", "discord", "whatsapp").
|
||||
Used to inject platform-specific formatting hints into the system prompt.
|
||||
"""
|
||||
self.model = model
|
||||
self.max_iterations = max_iterations
|
||||
@@ -1053,12 +642,9 @@ class AIAgent:
|
||||
self.verbose_logging = verbose_logging
|
||||
self.quiet_mode = quiet_mode
|
||||
self.ephemeral_system_prompt = ephemeral_system_prompt
|
||||
self.platform = platform # "cli", "telegram", "discord", "whatsapp", etc.
|
||||
self.log_prefix_chars = log_prefix_chars
|
||||
self.log_prefix = f"{log_prefix} " if log_prefix else ""
|
||||
# Store effective base URL for feature detection (prompt caching, reasoning, etc.)
|
||||
# When no base_url is provided, the client defaults to OpenRouter, so reflect that here.
|
||||
self.base_url = base_url or "https://openrouter.ai/api/v1"
|
||||
self.base_url = base_url or "" # Store for OpenRouter detection
|
||||
self.tool_progress_callback = tool_progress_callback
|
||||
self._last_reported_tool = None # Track for "new tool" mode
|
||||
|
||||
@@ -1081,14 +667,6 @@ class AIAgent:
|
||||
self.reasoning_config = reasoning_config # None = use default (xhigh for OpenRouter)
|
||||
self.prefill_messages = prefill_messages or [] # Prefilled conversation turns
|
||||
|
||||
# Anthropic prompt caching: auto-enabled for Claude models via OpenRouter.
|
||||
# Reduces input costs by ~75% on multi-turn conversations by caching the
|
||||
# conversation prefix. Uses system_and_3 strategy (4 breakpoints).
|
||||
is_openrouter = "openrouter" in self.base_url.lower()
|
||||
is_claude = "claude" in self.model.lower()
|
||||
self._use_prompt_caching = is_openrouter and is_claude
|
||||
self._cache_ttl = "5m" # Default 5-minute TTL (1.25x write cost)
|
||||
|
||||
# Configure logging
|
||||
if self.verbose_logging:
|
||||
logging.basicConfig(
|
||||
@@ -1194,10 +772,6 @@ class AIAgent:
|
||||
prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt
|
||||
print(f"🔒 Ephemeral system prompt: '{prompt_preview}' (not saved to trajectories)")
|
||||
|
||||
# Show prompt caching status
|
||||
if self._use_prompt_caching and not self.quiet_mode:
|
||||
print(f"💾 Prompt caching: ENABLED (Claude via OpenRouter, {self._cache_ttl} TTL)")
|
||||
|
||||
# Session logging setup - auto-save conversation trajectories for debugging
|
||||
self.session_start = datetime.now()
|
||||
if session_id:
|
||||
@@ -1376,6 +950,10 @@ class AIAgent:
|
||||
return f"{face} 🎨 creating '{prompt}'... {time_str}"
|
||||
|
||||
# Skills - use large pool for variety
|
||||
elif tool_name == "skills_categories":
|
||||
face = random.choice(self.KAWAII_SKILL)
|
||||
return f"{face} 📚 listing categories... {time_str}"
|
||||
|
||||
elif tool_name == "skills_list":
|
||||
category = args.get("category", "skills")
|
||||
face = random.choice(self.KAWAII_SKILL)
|
||||
@@ -1386,65 +964,19 @@ class AIAgent:
|
||||
face = random.choice(self.KAWAII_SKILL)
|
||||
return f"{face} 📖 loading {name}... {time_str}"
|
||||
|
||||
# File tools
|
||||
elif tool_name == "read_file":
|
||||
path = args.get("path", "file")
|
||||
if len(path) > 30:
|
||||
path = "..." + path[-27:]
|
||||
face = random.choice(self.KAWAII_READ)
|
||||
return f"{face} 📖 reading \"{path}\" {time_str}"
|
||||
|
||||
elif tool_name == "write_file":
|
||||
path = args.get("path", "file")
|
||||
if len(path) > 30:
|
||||
path = "..." + path[-27:]
|
||||
face = random.choice(self.KAWAII_CREATE)
|
||||
return f"{face} ✍️ writing \"{path}\" {time_str}"
|
||||
|
||||
elif tool_name == "patch":
|
||||
path = args.get("path", "file")
|
||||
if path and len(path) > 30:
|
||||
path = "..." + path[-27:]
|
||||
face = random.choice(self.KAWAII_CREATE)
|
||||
return f"{face} 🔧 patching \"{path}\" {time_str}"
|
||||
|
||||
elif tool_name == "search":
|
||||
pattern = args.get("pattern", "")
|
||||
if len(pattern) > 25:
|
||||
pattern = pattern[:22] + "..."
|
||||
face = random.choice(self.KAWAII_SEARCH)
|
||||
return f"{face} 🔎 searching \"{pattern}\" {time_str}"
|
||||
|
||||
# TTS
|
||||
elif tool_name == "text_to_speech":
|
||||
text = args.get("text", "")
|
||||
if len(text) > 25:
|
||||
text = text[:22] + "..."
|
||||
face = random.choice(self.KAWAII_CREATE)
|
||||
return f"{face} 🔊 speaking \"{text}\" {time_str}"
|
||||
|
||||
# Vision tools
|
||||
elif tool_name == "vision_analyze":
|
||||
question = args.get("question", "")
|
||||
if len(question) > 25:
|
||||
question = question[:22] + "..."
|
||||
face = random.choice(self.KAWAII_BROWSER)
|
||||
return f"{face} 👁️✨ analyzing \"{question}\" {time_str}"
|
||||
return f"{face} 👁️✨ analyzing image... {time_str}"
|
||||
|
||||
# Mixture of agents
|
||||
elif tool_name == "mixture_of_agents":
|
||||
prompt = args.get("user_prompt", "")
|
||||
if len(prompt) > 25:
|
||||
prompt = prompt[:22] + "..."
|
||||
face = random.choice(self.KAWAII_THINK)
|
||||
return f"{face} 🧠💭 deep thinking \"{prompt}\" {time_str}"
|
||||
return f"{face} 🧠💭 thinking REALLY hard... {time_str}"
|
||||
|
||||
# Default fallback - random generic kawaii with primary arg preview
|
||||
# Default fallback - random generic kawaii
|
||||
else:
|
||||
face = random.choice(self.KAWAII_GENERIC)
|
||||
preview = _build_tool_preview(tool_name, args)
|
||||
if preview:
|
||||
return f"{face} ⚡ {tool_name}... \"{preview}\" {time_str}"
|
||||
return f"{face} ⚡ {tool_name}... {time_str}"
|
||||
|
||||
def _has_content_after_think_block(self, content: str) -> bool:
|
||||
@@ -1913,9 +1445,6 @@ class AIAgent:
|
||||
Call this from another thread (e.g., input handler, message receiver)
|
||||
to gracefully stop the agent and process a new message.
|
||||
|
||||
Also signals long-running tool executions (e.g. terminal commands)
|
||||
to terminate early, so the agent can respond immediately.
|
||||
|
||||
Args:
|
||||
message: Optional new message that triggered the interrupt.
|
||||
If provided, the agent will include this in its response context.
|
||||
@@ -1932,8 +1461,6 @@ class AIAgent:
|
||||
"""
|
||||
self._interrupt_requested = True
|
||||
self._interrupt_message = message
|
||||
# Signal the terminal tool to kill any running subprocess immediately
|
||||
_set_terminal_interrupt(True)
|
||||
if not self.quiet_mode:
|
||||
print(f"\n⚡ Interrupt requested" + (f": '{message[:40]}...'" if message and len(message) > 40 else f": '{message}'" if message else ""))
|
||||
|
||||
@@ -1941,7 +1468,6 @@ class AIAgent:
|
||||
"""Clear any pending interrupt request."""
|
||||
self._interrupt_requested = False
|
||||
self._interrupt_message = None
|
||||
_set_terminal_interrupt(False)
|
||||
|
||||
@property
|
||||
def is_interrupted(self) -> bool:
|
||||
@@ -1994,46 +1520,20 @@ class AIAgent:
|
||||
if not self.quiet_mode:
|
||||
print(f"💬 Starting conversation: '{user_message[:60]}{'...' if len(user_message) > 60 else ''}'")
|
||||
|
||||
# ── Build the full system prompt ──
|
||||
# Layers (in order):
|
||||
# 1. Default agent identity (always present)
|
||||
# 2. User / gateway system prompt (if provided)
|
||||
# 3. Skills guidance (if skills tools are loaded)
|
||||
# 4. Context files (SOUL.md, AGENTS.md, .cursorrules)
|
||||
# 5. Current date & time
|
||||
# 6. Platform-specific formatting hint
|
||||
prompt_parts = [DEFAULT_AGENT_IDENTITY]
|
||||
|
||||
# Layer in the caller-supplied system prompt (explicit > ephemeral).
|
||||
caller_prompt = system_message if system_message is not None else self.ephemeral_system_prompt
|
||||
if caller_prompt:
|
||||
prompt_parts.append(caller_prompt)
|
||||
|
||||
# Auto-include skills guidance if skills tools are available.
|
||||
has_skills_tools = any(name in self.valid_tool_names for name in ['skills_list', 'skill_view'])
|
||||
skills_prompt = build_skills_system_prompt() if has_skills_tools else ""
|
||||
if skills_prompt:
|
||||
prompt_parts.append(skills_prompt)
|
||||
|
||||
# Auto-include context files (SOUL.md, AGENTS.md, .cursorrules).
|
||||
context_files_prompt = build_context_files_prompt()
|
||||
if context_files_prompt:
|
||||
prompt_parts.append(context_files_prompt)
|
||||
|
||||
# Current local date and time so the model is never confused about
|
||||
# what day/time it is (LLM training cutoffs can otherwise mislead it).
|
||||
now = datetime.now()
|
||||
prompt_parts.append(
|
||||
f"Current local date and time: {now.strftime('%A, %B %d, %Y %I:%M %p')}"
|
||||
)
|
||||
|
||||
# Platform-specific formatting hint (no markdown on WhatsApp, etc.).
|
||||
platform_key = (self.platform or "").lower().strip()
|
||||
if platform_key in PLATFORM_HINTS:
|
||||
prompt_parts.append(PLATFORM_HINTS[platform_key])
|
||||
|
||||
active_system_prompt = "\n\n".join(prompt_parts)
|
||||
|
||||
# Determine which system prompt to use for API calls (ephemeral)
|
||||
# Priority: explicit system_message > ephemeral_system_prompt > None
|
||||
base_system_prompt = system_message if system_message is not None else self.ephemeral_system_prompt
|
||||
|
||||
# Auto-include skills guidance if skills tools are available
|
||||
has_skills_tools = any(name in self.valid_tool_names for name in ['skills_list', 'skills_categories', 'skill_view'])
|
||||
if has_skills_tools:
|
||||
if base_system_prompt:
|
||||
active_system_prompt = f"{base_system_prompt}\n\n{SKILLS_SYSTEM_PROMPT}"
|
||||
else:
|
||||
active_system_prompt = SKILLS_SYSTEM_PROMPT
|
||||
else:
|
||||
active_system_prompt = base_system_prompt
|
||||
|
||||
# Main conversation loop
|
||||
api_call_count = 0
|
||||
final_response = None
|
||||
@@ -2080,13 +1580,16 @@ class AIAgent:
|
||||
if active_system_prompt:
|
||||
# Insert system message at the beginning
|
||||
api_messages = [{"role": "system", "content": active_system_prompt}] + api_messages
|
||||
|
||||
# Apply Anthropic prompt caching for Claude models via OpenRouter.
|
||||
# Auto-detected: if model name contains "claude" and base_url is OpenRouter,
|
||||
# inject cache_control breakpoints (system + last 3 messages) to reduce
|
||||
# input token costs by ~75% on multi-turn conversations.
|
||||
if self._use_prompt_caching:
|
||||
api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl)
|
||||
|
||||
if os.getenv("HERMES_DEBUG_OPENAI_REQUEST") == "1":
|
||||
meta = {
|
||||
"model": self.model,
|
||||
"base_url": self.base_url,
|
||||
"messages": api_messages,
|
||||
"tools": self.tools if self.tools else None,
|
||||
}
|
||||
print("\n=== HERMES_DEBUG_OPENAI_REQUEST ===", flush=True)
|
||||
print(json.dumps(meta, ensure_ascii=False, indent=2)[:200_000], flush=True)
|
||||
|
||||
# Calculate approximate request size for logging
|
||||
total_chars = sum(len(str(msg)) for msg in api_messages)
|
||||
@@ -2100,12 +1603,13 @@ class AIAgent:
|
||||
print(f"{self.log_prefix} 📊 Request size: {len(api_messages)} messages, ~{approx_tokens:,} tokens (~{total_chars:,} chars)")
|
||||
print(f"{self.log_prefix} 🔧 Available tools: {len(self.tools) if self.tools else 0}")
|
||||
else:
|
||||
# Animated thinking spinner in quiet mode
|
||||
face = random.choice(KawaiiSpinner.KAWAII_THINKING)
|
||||
verb = random.choice(KawaiiSpinner.THINKING_VERBS)
|
||||
spinner_type = random.choice(['brain', 'sparkle', 'pulse', 'moon', 'star'])
|
||||
thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type)
|
||||
thinking_spinner.start()
|
||||
# Animated thinking spinner in quiet mode (disable for wrappers/non-TTY usage)
|
||||
if os.getenv("HERMES_DISABLE_SPINNER") != "1":
|
||||
face = random.choice(KawaiiSpinner.KAWAII_THINKING)
|
||||
verb = random.choice(KawaiiSpinner.THINKING_VERBS)
|
||||
spinner_type = random.choice(['brain', 'sparkle', 'pulse', 'moon', 'star'])
|
||||
thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type)
|
||||
thinking_spinner.start()
|
||||
|
||||
# Log request details if verbose
|
||||
if self.verbose_logging:
|
||||
@@ -2165,6 +1669,14 @@ class AIAgent:
|
||||
api_kwargs["extra_body"] = extra_body
|
||||
|
||||
response = self.client.chat.completions.create(**api_kwargs)
|
||||
|
||||
if os.getenv("HERMES_DEBUG_OPENAI_RESPONSE") == "1":
|
||||
try:
|
||||
dumped = response.model_dump()
|
||||
except Exception:
|
||||
dumped = getattr(response, "__dict__", {"repr": repr(response)})
|
||||
print("\n=== HERMES_DEBUG_OPENAI_RESPONSE: ChatCompletion (raw) ===", flush=True)
|
||||
print(json.dumps(dumped, ensure_ascii=False, indent=2), flush=True)
|
||||
|
||||
api_duration = time.time() - api_start_time
|
||||
|
||||
@@ -2317,16 +1829,6 @@ class AIAgent:
|
||||
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"Token usage: prompt={usage_dict['prompt_tokens']:,}, completion={usage_dict['completion_tokens']:,}, total={usage_dict['total_tokens']:,}")
|
||||
|
||||
# Log cache hit stats when prompt caching is active
|
||||
if self._use_prompt_caching:
|
||||
details = getattr(response.usage, 'prompt_tokens_details', None)
|
||||
cached = getattr(details, 'cached_tokens', 0) or 0 if details else 0
|
||||
written = getattr(details, 'cache_write_tokens', 0) or 0 if details else 0
|
||||
prompt = usage_dict["prompt_tokens"]
|
||||
hit_pct = (cached / prompt * 100) if prompt > 0 else 0
|
||||
if not self.quiet_mode:
|
||||
print(f"{self.log_prefix} 💾 Cache: {cached:,}/{prompt:,} tokens ({hit_pct:.0f}% hit, {written:,} written)")
|
||||
|
||||
break # Success, exit retry loop
|
||||
|
||||
@@ -2640,8 +2142,12 @@ class AIAgent:
|
||||
# Fire progress callback if registered (for messaging platforms)
|
||||
if self.tool_progress_callback:
|
||||
try:
|
||||
# Build a short preview of the primary argument
|
||||
preview = _build_tool_preview(function_name, function_args)
|
||||
# Build preview for terminal commands
|
||||
if function_name == "terminal":
|
||||
cmd = function_args.get("command", "")
|
||||
preview = cmd[:50] + "..." if len(cmd) > 50 else cmd
|
||||
else:
|
||||
preview = None
|
||||
self.tool_progress_callback(function_name, preview)
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool progress callback error: {cb_err}")
|
||||
@@ -2649,7 +2155,7 @@ class AIAgent:
|
||||
tool_start_time = time.time()
|
||||
|
||||
# Execute the tool - with animated spinner in quiet mode
|
||||
if self.quiet_mode:
|
||||
if self.quiet_mode and os.getenv("HERMES_DISABLE_SPINNER") != "1":
|
||||
# Tool-specific spinner animations
|
||||
tool_spinners = {
|
||||
'web_search': ('arrows', ['🔍', '🌐', '📡', '🔎']),
|
||||
@@ -2663,6 +2169,7 @@ class AIAgent:
|
||||
'image_generate': ('sparkle', ['🎨', '✨', '🖼️', '🌟']),
|
||||
'skill_view': ('star', ['📚', '📖', '🎓', '✨']),
|
||||
'skills_list': ('pulse', ['📋', '📝', '📑', '📜']),
|
||||
'skills_categories': ('pulse', ['📂', '🗂️', '📁', '🏷️']),
|
||||
'moa_query': ('brain', ['🧠', '💭', '🤔', '💡']),
|
||||
'analyze_image': ('sparkle', ['👁️', '🔍', '📷', '✨']),
|
||||
}
|
||||
@@ -2678,6 +2185,9 @@ class AIAgent:
|
||||
tool_duration = time.time() - tool_start_time
|
||||
cute_msg = self._get_cute_tool_message(function_name, function_args, tool_duration)
|
||||
spinner.stop(cute_msg)
|
||||
elif self.quiet_mode:
|
||||
function_result = handle_function_call(function_name, function_args, effective_task_id)
|
||||
tool_duration = time.time() - tool_start_time
|
||||
else:
|
||||
function_result = handle_function_call(function_name, function_args, effective_task_id)
|
||||
tool_duration = time.time() - tool_start_time
|
||||
@@ -2700,21 +2210,6 @@ class AIAgent:
|
||||
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
|
||||
print(f" ✅ Tool {i} completed in {tool_duration:.2f}s - {response_preview}")
|
||||
|
||||
# Check for interrupt between tool calls - skip remaining
|
||||
# tools so the agent can respond to the user immediately
|
||||
if self._interrupt_requested and i < len(assistant_message.tool_calls):
|
||||
remaining = len(assistant_message.tool_calls) - i
|
||||
print(f"{self.log_prefix}⚡ Interrupt: skipping {remaining} remaining tool call(s)")
|
||||
# Add placeholder results for skipped tool calls so the
|
||||
# message sequence stays valid (assistant tool_calls need matching tool results)
|
||||
for skipped_tc in assistant_message.tool_calls[i:]:
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"content": "[Tool execution skipped - user sent a new message]",
|
||||
"tool_call_id": skipped_tc.id
|
||||
})
|
||||
break
|
||||
|
||||
# Delay between tool calls
|
||||
if self.tool_delay > 0 and i < len(assistant_message.tool_calls):
|
||||
time.sleep(self.tool_delay)
|
||||
@@ -3161,4 +2656,11 @@ def main(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
import fire # type: ignore
|
||||
except ModuleNotFoundError as exc:
|
||||
raise SystemExit(
|
||||
"Missing optional dependency 'fire'. Install hermes-agent with its CLI extras or add `fire` "
|
||||
f"to your environment. Original error: {exc}"
|
||||
) from exc
|
||||
fire.Fire(main)
|
||||
|
||||
@@ -262,25 +262,6 @@ function Test-Ripgrep {
|
||||
return $true # Don't fail - ripgrep is optional
|
||||
}
|
||||
|
||||
function Test-Ffmpeg {
|
||||
Write-Info "Checking ffmpeg (optional, for TTS voice messages)..."
|
||||
|
||||
if (Get-Command ffmpeg -ErrorAction SilentlyContinue) {
|
||||
$version = ffmpeg -version 2>&1 | Select-Object -First 1
|
||||
Write-Success "ffmpeg found"
|
||||
$script:HasFfmpeg = $true
|
||||
return $true
|
||||
}
|
||||
|
||||
Write-Warn "ffmpeg not found (TTS voice bubbles on Telegram will send as audio files instead)"
|
||||
Write-Info " Install with: winget install ffmpeg"
|
||||
Write-Info " Or: choco install ffmpeg"
|
||||
Write-Info " Or download from: https://ffmpeg.org/download.html"
|
||||
|
||||
$script:HasFfmpeg = $false
|
||||
return $true # Don't fail - ffmpeg is optional
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Installation
|
||||
# ============================================================================
|
||||
@@ -586,7 +567,6 @@ function Main {
|
||||
if (-not (Test-Git)) { exit 1 }
|
||||
Test-Node # Optional, doesn't fail
|
||||
Test-Ripgrep # Optional, doesn't fail
|
||||
Test-Ffmpeg # Optional, doesn't fail
|
||||
|
||||
Install-Repository
|
||||
Install-Venv
|
||||
|
||||
@@ -413,45 +413,6 @@ check_ripgrep() {
|
||||
# Don't exit - ripgrep is optional (grep fallback exists)
|
||||
}
|
||||
|
||||
check_ffmpeg() {
|
||||
log_info "Checking ffmpeg (optional, for TTS voice messages)..."
|
||||
|
||||
if command -v ffmpeg &> /dev/null; then
|
||||
local ffmpeg_version=$(ffmpeg -version 2>/dev/null | head -1 | awk '{print $3}')
|
||||
log_success "ffmpeg found: $ffmpeg_version"
|
||||
HAS_FFMPEG=true
|
||||
return
|
||||
fi
|
||||
|
||||
log_warn "ffmpeg not found (TTS voice bubbles on Telegram will send as audio files instead)"
|
||||
log_info "To install ffmpeg (optional):"
|
||||
|
||||
case "$OS" in
|
||||
linux)
|
||||
case "$DISTRO" in
|
||||
ubuntu|debian)
|
||||
log_info " sudo apt install ffmpeg"
|
||||
;;
|
||||
fedora)
|
||||
log_info " sudo dnf install ffmpeg"
|
||||
;;
|
||||
arch)
|
||||
log_info " sudo pacman -S ffmpeg"
|
||||
;;
|
||||
*)
|
||||
log_info " https://ffmpeg.org/download.html"
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
macos)
|
||||
log_info " brew install ffmpeg"
|
||||
;;
|
||||
esac
|
||||
|
||||
HAS_FFMPEG=false
|
||||
# Don't exit - ffmpeg is optional
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Installation
|
||||
# ============================================================================
|
||||
@@ -746,7 +707,6 @@ main() {
|
||||
check_git
|
||||
check_node
|
||||
check_ripgrep
|
||||
check_ffmpeg
|
||||
|
||||
clone_repo
|
||||
setup_venv
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Kill all running Modal apps (sandboxes, deployments, etc.)
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/kill_modal.sh # Stop swe-rex (the sandbox app)
|
||||
# bash scripts/kill_modal.sh --all # Stop ALL Modal apps
|
||||
|
||||
set -uo pipefail
|
||||
|
||||
echo "Fetching Modal app list..."
|
||||
APP_LIST=$(modal app list 2>/dev/null)
|
||||
|
||||
if [[ "${1:-}" == "--all" ]]; then
|
||||
echo "Stopping ALL Modal apps..."
|
||||
echo "$APP_LIST" | grep -oE 'ap-[A-Za-z0-9]+' | sort -u | while read app_id; do
|
||||
echo " Stopping $app_id"
|
||||
modal app stop "$app_id" 2>/dev/null || true
|
||||
done
|
||||
else
|
||||
echo "Stopping swe-rex sandboxes..."
|
||||
APPS=$(echo "$APP_LIST" | grep 'swe-rex' | grep -oE 'ap-[A-Za-z0-9]+' || true)
|
||||
if [[ -z "$APPS" ]]; then
|
||||
echo " No swe-rex apps found."
|
||||
else
|
||||
echo "$APPS" | while read app_id; do
|
||||
echo " Stopping $app_id"
|
||||
modal app stop "$app_id" 2>/dev/null || true
|
||||
done
|
||||
fi
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "Current swe-rex status:"
|
||||
modal app list 2>/dev/null | grep -E 'State|swe-rex' || echo " (none)"
|
||||
Executable
+62
@@ -0,0 +1,62 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Launch a local llama.cpp OpenAI-compatible server running GLM-4.7-Flash (GGUF).
|
||||
#
|
||||
# Requires:
|
||||
# - `llama-server` installed (e.g. `brew install llama.cpp`)
|
||||
#
|
||||
# Default settings are chosen to avoid clashing with Atropos sandbox_server
|
||||
# (which commonly uses port 8080 in local dev).
|
||||
#
|
||||
# Usage:
|
||||
# Hermes-Agent/scripts/launch_llama_cpp_glm47_flash.sh
|
||||
#
|
||||
# Override defaults:
|
||||
# LLAMA_CPP_HOST=127.0.0.1 LLAMA_CPP_PORT=8082 \
|
||||
# LLAMA_CPP_HF_REPO=ggml-org/GLM-4.7-Flash-GGUF \
|
||||
# LLAMA_CPP_HF_FILE=GLM-4.7-Flash-Q4_K.gguf \
|
||||
# Hermes-Agent/scripts/launch_llama_cpp_glm47_flash.sh
|
||||
|
||||
HOST="${LLAMA_CPP_HOST:-127.0.0.1}"
|
||||
PORT="${LLAMA_CPP_PORT:-8080}"
|
||||
HF_REPO="${LLAMA_CPP_HF_REPO:-ggml-org/GLM-4.7-Flash-GGUF}"
|
||||
HF_FILE="${LLAMA_CPP_HF_FILE:-GLM-4.7-Flash-Q4_K.gguf}"
|
||||
ALIAS="${LLAMA_CPP_ALIAS:-glm-4.7-flash}"
|
||||
|
||||
if ! command -v llama-server >/dev/null 2>&1; then
|
||||
echo "Error: llama-server not found in PATH."
|
||||
echo "Install via Homebrew: brew install llama.cpp"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Launching llama.cpp server..."
|
||||
echo " host: $HOST"
|
||||
echo " port: $PORT"
|
||||
echo " repo: $HF_REPO"
|
||||
echo " file: $HF_FILE"
|
||||
echo " alias: $ALIAS"
|
||||
echo
|
||||
echo "Suggested env vars for Hermes/Atropos integration:"
|
||||
echo " export ATROPOS_SERVER_BASE_URL=http://${HOST}:${PORT}"
|
||||
echo " export ATROPOS_SERVER_MODEL=${ALIAS}"
|
||||
echo " export ATROPOS_SERVER_API_KEY=local"
|
||||
echo
|
||||
|
||||
if command -v lsof >/dev/null 2>&1; then
|
||||
if lsof -nP -iTCP:"$PORT" -sTCP:LISTEN >/dev/null 2>&1; then
|
||||
echo "Error: port $PORT is already in use."
|
||||
echo "Pick a different port, e.g.:"
|
||||
echo " LLAMA_CPP_PORT=8082 Hermes-Agent/scripts/launch_llama_cpp_glm47_flash.sh"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
exec llama-server \
|
||||
--host "$HOST" \
|
||||
--port "$PORT" \
|
||||
--hf-repo "$HF_REPO" \
|
||||
--hf-file "$HF_FILE" \
|
||||
--alias "$ALIAS" \
|
||||
-c 32768 \
|
||||
-n -1
|
||||
Executable
+70
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
# Launch a local llama.cpp OpenAI-compatible server running Hermes 4.3 36B (GGUF).
|
||||
#
|
||||
# Requires:
|
||||
# - `llama-server` installed (e.g. `brew install llama.cpp`)
|
||||
#
|
||||
# Note: Port choice can conflict with other local dev servers. If 8080 is already
|
||||
# in use, override via `LLAMA_CPP_PORT=...`.
|
||||
#
|
||||
# Usage:
|
||||
# Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh
|
||||
#
|
||||
# Override defaults:
|
||||
# LLAMA_CPP_HOST=127.0.0.1 LLAMA_CPP_PORT=8082 \
|
||||
# LLAMA_CPP_HF_REPO=NousResearch/Hermes-4.3-36B-GGUF \
|
||||
# LLAMA_CPP_HF_FILE=hermes-4_3_36b-Q4_K_M.gguf \
|
||||
# LLAMA_CPP_ALIAS=hermes-4-36b \
|
||||
# LLAMA_CPP_PARALLEL=4 LLAMA_CPP_THREADS_HTTP=4 \
|
||||
# Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh
|
||||
|
||||
HOST="${LLAMA_CPP_HOST:-127.0.0.1}"
|
||||
PORT="${LLAMA_CPP_PORT:-8080}"
|
||||
HF_REPO="${LLAMA_CPP_HF_REPO:-NousResearch/Hermes-4.3-36B-GGUF}"
|
||||
HF_FILE="${LLAMA_CPP_HF_FILE:-hermes-4_3_36b-Q4_K_M.gguf}"
|
||||
ALIAS="${LLAMA_CPP_ALIAS:-hermes-4-36b}"
|
||||
PARALLEL="${LLAMA_CPP_PARALLEL:-4}"
|
||||
THREADS_HTTP="${LLAMA_CPP_THREADS_HTTP:-4}"
|
||||
|
||||
if ! command -v llama-server >/dev/null 2>&1; then
|
||||
echo "Error: llama-server not found in PATH."
|
||||
echo "Install via Homebrew: brew install llama.cpp"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Launching llama.cpp server..."
|
||||
echo " host: $HOST"
|
||||
echo " port: $PORT"
|
||||
echo " repo: $HF_REPO"
|
||||
echo " file: $HF_FILE"
|
||||
echo " alias: $ALIAS"
|
||||
echo " slots: $PARALLEL"
|
||||
echo
|
||||
echo "Suggested env vars for Hermes/Atropos integration:"
|
||||
echo " export ATROPOS_SERVER_BASE_URL=http://${HOST}:${PORT}"
|
||||
echo " export ATROPOS_SERVER_MODEL=${ALIAS}"
|
||||
echo " export ATROPOS_TOKENIZER_NAME=NousResearch/Hermes-4.3-36B"
|
||||
echo " export ATROPOS_SERVER_API_KEY=local"
|
||||
echo
|
||||
|
||||
if command -v lsof >/dev/null 2>&1; then
|
||||
if lsof -nP -iTCP:"$PORT" -sTCP:LISTEN >/dev/null 2>&1; then
|
||||
echo "Error: port $PORT is already in use."
|
||||
echo "Pick a different port, e.g.:"
|
||||
echo " LLAMA_CPP_PORT=8082 Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
exec llama-server \
|
||||
--host "$HOST" \
|
||||
--port "$PORT" \
|
||||
--hf-repo "$HF_REPO" \
|
||||
--hf-file "$HF_FILE" \
|
||||
--alias "$ALIAS" \
|
||||
--parallel "$PARALLEL" \
|
||||
--threads-http "$THREADS_HTTP" \
|
||||
-c 32768 \
|
||||
-n -1
|
||||
@@ -1,3 +0,0 @@
|
||||
---
|
||||
description: Diagram creation skills for generating visual diagrams, flowcharts, architecture diagrams, and illustrations using tools like Excalidraw.
|
||||
---
|
||||
@@ -1,191 +0,0 @@
|
||||
---
|
||||
name: excalidraw
|
||||
description: Create hand-drawn style diagrams using Excalidraw JSON format. Generate .excalidraw files for architecture diagrams, flowcharts, sequence diagrams, concept maps, and more. Files can be opened at excalidraw.com or uploaded for shareable links.
|
||||
version: 1.0.0
|
||||
author: Hermes Agent
|
||||
license: MIT
|
||||
tags: [Excalidraw, Diagrams, Flowcharts, Architecture, Visualization, JSON]
|
||||
dependencies: []
|
||||
related_skills: []
|
||||
---
|
||||
|
||||
# Excalidraw Diagram Skill
|
||||
|
||||
Create diagrams by writing standard Excalidraw element JSON and saving as `.excalidraw` files. These files can be drag-and-dropped onto [excalidraw.com](https://excalidraw.com) for viewing and editing. No accounts, no API keys, no rendering libraries -- just JSON.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. **Load this skill** (you already did)
|
||||
2. **Write the elements JSON** -- an array of Excalidraw element objects
|
||||
3. **Save the file** using `write_file` to create a `.excalidraw` file
|
||||
4. **Optionally upload** for a shareable link using `scripts/upload.py` via `terminal`
|
||||
|
||||
### Saving a Diagram
|
||||
|
||||
Wrap your elements array in the standard `.excalidraw` envelope and save with `write_file`:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "excalidraw",
|
||||
"version": 2,
|
||||
"source": "hermes-agent",
|
||||
"elements": [ ...your elements array here... ],
|
||||
"appState": {
|
||||
"viewBackgroundColor": "#ffffff"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Save to any path, e.g. `~/diagrams/my_diagram.excalidraw`.
|
||||
|
||||
### Uploading for a Shareable Link
|
||||
|
||||
Run the upload script (located in this skill's `scripts/` directory) via terminal:
|
||||
|
||||
```bash
|
||||
python skills/diagramming/excalidraw/scripts/upload.py ~/diagrams/my_diagram.excalidraw
|
||||
```
|
||||
|
||||
This uploads to excalidraw.com (no account needed) and prints a shareable URL. Requires the `cryptography` pip package (`pip install cryptography`).
|
||||
|
||||
---
|
||||
|
||||
## Element Format Reference
|
||||
|
||||
### Required Fields (all elements)
|
||||
`type`, `id` (unique string), `x`, `y`, `width`, `height`
|
||||
|
||||
### Defaults (skip these -- they're applied automatically)
|
||||
- `strokeColor`: `"#1e1e1e"`
|
||||
- `backgroundColor`: `"transparent"`
|
||||
- `fillStyle`: `"solid"`
|
||||
- `strokeWidth`: `2`
|
||||
- `roughness`: `1` (hand-drawn look)
|
||||
- `opacity`: `100`
|
||||
|
||||
Canvas background is white.
|
||||
|
||||
### Element Types
|
||||
|
||||
**Rectangle**:
|
||||
```json
|
||||
{ "type": "rectangle", "id": "r1", "x": 100, "y": 100, "width": 200, "height": 100 }
|
||||
```
|
||||
- `roundness: { "type": 3 }` for rounded corners
|
||||
- `backgroundColor: "#a5d8ff"`, `fillStyle: "solid"` for filled
|
||||
|
||||
**Ellipse**:
|
||||
```json
|
||||
{ "type": "ellipse", "id": "e1", "x": 100, "y": 100, "width": 150, "height": 150 }
|
||||
```
|
||||
|
||||
**Diamond**:
|
||||
```json
|
||||
{ "type": "diamond", "id": "d1", "x": 100, "y": 100, "width": 150, "height": 150 }
|
||||
```
|
||||
|
||||
**Labeled shape (container binding)** -- create a text element bound to the shape:
|
||||
|
||||
> **WARNING:** Do NOT use `"label": { "text": "..." }` on shapes. This is NOT a valid
|
||||
> Excalidraw property and will be silently ignored, producing blank shapes. You MUST
|
||||
> use the container binding approach below.
|
||||
|
||||
The shape needs `boundElements` listing the text, and the text needs `containerId` pointing back:
|
||||
```json
|
||||
{ "type": "rectangle", "id": "r1", "x": 100, "y": 100, "width": 200, "height": 80,
|
||||
"roundness": { "type": 3 }, "backgroundColor": "#a5d8ff", "fillStyle": "solid",
|
||||
"boundElements": [{ "id": "t_r1", "type": "text" }] },
|
||||
{ "type": "text", "id": "t_r1", "x": 105, "y": 110, "width": 190, "height": 25,
|
||||
"text": "Hello", "fontSize": 20, "fontFamily": 1, "strokeColor": "#1e1e1e",
|
||||
"textAlign": "center", "verticalAlign": "middle",
|
||||
"containerId": "r1", "originalText": "Hello", "autoResize": true }
|
||||
```
|
||||
- Works on rectangle, ellipse, diamond
|
||||
- Text is auto-centered by Excalidraw when `containerId` is set
|
||||
- The text `x`/`y`/`width`/`height` are approximate -- Excalidraw recalculates them on load
|
||||
- `originalText` should match `text`
|
||||
- Always include `fontFamily: 1` (Virgil/hand-drawn font)
|
||||
|
||||
**Labeled arrow** -- same container binding approach:
|
||||
```json
|
||||
{ "type": "arrow", "id": "a1", "x": 300, "y": 150, "width": 200, "height": 0,
|
||||
"points": [[0,0],[200,0]], "endArrowhead": "arrow",
|
||||
"boundElements": [{ "id": "t_a1", "type": "text" }] },
|
||||
{ "type": "text", "id": "t_a1", "x": 370, "y": 130, "width": 60, "height": 20,
|
||||
"text": "connects", "fontSize": 16, "fontFamily": 1, "strokeColor": "#1e1e1e",
|
||||
"textAlign": "center", "verticalAlign": "middle",
|
||||
"containerId": "a1", "originalText": "connects", "autoResize": true }
|
||||
```
|
||||
|
||||
**Standalone text** (titles and annotations only -- no container):
|
||||
```json
|
||||
{ "type": "text", "id": "t1", "x": 150, "y": 138, "text": "Hello", "fontSize": 20,
|
||||
"fontFamily": 1, "strokeColor": "#1e1e1e", "originalText": "Hello", "autoResize": true }
|
||||
```
|
||||
- `x` is the LEFT edge. To center at position `cx`: `x = cx - (text.length * fontSize * 0.5) / 2`
|
||||
- Do NOT rely on `textAlign` or `width` for positioning
|
||||
|
||||
**Arrow**:
|
||||
```json
|
||||
{ "type": "arrow", "id": "a1", "x": 300, "y": 150, "width": 200, "height": 0,
|
||||
"points": [[0,0],[200,0]], "endArrowhead": "arrow" }
|
||||
```
|
||||
- `points`: `[dx, dy]` offsets from element `x`, `y`
|
||||
- `endArrowhead`: `null` | `"arrow"` | `"bar"` | `"dot"` | `"triangle"`
|
||||
- `strokeStyle`: `"solid"` (default) | `"dashed"` | `"dotted"`
|
||||
|
||||
### Arrow Bindings (connect arrows to shapes)
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "arrow", "id": "a1", "x": 300, "y": 150, "width": 150, "height": 0,
|
||||
"points": [[0,0],[150,0]], "endArrowhead": "arrow",
|
||||
"startBinding": { "elementId": "r1", "fixedPoint": [1, 0.5] },
|
||||
"endBinding": { "elementId": "r2", "fixedPoint": [0, 0.5] }
|
||||
}
|
||||
```
|
||||
|
||||
`fixedPoint` coordinates: `top=[0.5,0]`, `bottom=[0.5,1]`, `left=[0,0.5]`, `right=[1,0.5]`
|
||||
|
||||
### Drawing Order (z-order)
|
||||
- Array order = z-order (first = back, last = front)
|
||||
- Emit progressively: background zones → shape → its bound text → its arrows → next shape
|
||||
- BAD: all rectangles, then all texts, then all arrows
|
||||
- GOOD: bg_zone → shape1 → text_for_shape1 → arrow1 → arrow_label_text → shape2 → text_for_shape2 → ...
|
||||
- Always place the bound text element immediately after its container shape
|
||||
|
||||
### Sizing Guidelines
|
||||
|
||||
**Font sizes:**
|
||||
- Minimum `fontSize`: **16** for body text, labels, descriptions
|
||||
- Minimum `fontSize`: **20** for titles and headings
|
||||
- Minimum `fontSize`: **14** for secondary annotations only (sparingly)
|
||||
- NEVER use `fontSize` below 14
|
||||
|
||||
**Element sizes:**
|
||||
- Minimum shape size: 120x60 for labeled rectangles/ellipses
|
||||
- Leave 20-30px gaps between elements minimum
|
||||
- Prefer fewer, larger elements over many tiny ones
|
||||
|
||||
### Color Palette
|
||||
|
||||
See `references/colors.md` for full color tables. Quick reference:
|
||||
|
||||
| Use | Fill Color | Hex |
|
||||
|-----|-----------|-----|
|
||||
| Primary / Input | Light Blue | `#a5d8ff` |
|
||||
| Success / Output | Light Green | `#b2f2bb` |
|
||||
| Warning / External | Light Orange | `#ffd8a8` |
|
||||
| Processing / Special | Light Purple | `#d0bfff` |
|
||||
| Error / Critical | Light Red | `#ffc9c9` |
|
||||
| Notes / Decisions | Light Yellow | `#fff3bf` |
|
||||
| Storage / Data | Light Teal | `#c3fae8` |
|
||||
|
||||
### Tips
|
||||
- Use the color palette consistently across the diagram
|
||||
- **Text contrast is CRITICAL** -- never use light gray on white backgrounds. Minimum text color on white: `#757575`
|
||||
- Do NOT use emoji in text -- they don't render in Excalidraw's font
|
||||
- For dark mode diagrams, see `references/dark-mode.md`
|
||||
- For larger examples, see `references/examples.md`
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user