Compare commits
81 Commits
thought-si
...
rl-capabil
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
533c064269 | ||
|
|
5c3105b437 | ||
|
|
3c0d0dba49 | ||
|
|
12bbca95ec | ||
|
|
f6574978de | ||
|
|
f018999da9 | ||
|
|
51a6b7d2b5 | ||
|
|
9bfe185a2e | ||
|
|
beeb7896e0 | ||
|
|
212460289b | ||
|
|
221fb17c5e | ||
|
|
488deb04a4 | ||
|
|
9d9eea9ac9 | ||
|
|
e7f0ffbf5d | ||
|
|
a09b018bd5 | ||
|
|
7eac4ee9fe | ||
|
|
17a5efb416 | ||
|
|
3e634aa7e4 | ||
|
|
5d3398aa8a | ||
|
|
76d929e177 | ||
|
|
be91af7551 | ||
|
|
c9011fc7e1 | ||
|
|
ff776b57bf | ||
|
|
3ee788dacc | ||
|
|
fef504f038 | ||
|
|
bbb5776763 | ||
|
|
e87bee9ccd | ||
|
|
69a338610a | ||
|
|
aa6394e94f | ||
|
|
ef409c6a24 | ||
|
|
da4167560f | ||
|
|
3488576bd8 | ||
|
|
619c72e566 | ||
|
|
a3ba41fce2 | ||
|
|
c935a604f8 | ||
|
|
e114f09f70 | ||
|
|
9b4d9452ba | ||
|
|
bbeed5b5d1 | ||
|
|
971ed2bbdf | ||
|
|
affc4e9a8f | ||
|
|
3db83b6824 | ||
|
|
9c8d707530 | ||
|
|
8f5f99c22a | ||
|
|
32254d3010 | ||
|
|
20f2875472 | ||
|
|
c360da4f35 | ||
|
|
bc76a032ba | ||
|
|
8e986584f4 | ||
|
|
4b68d30b0e | ||
|
|
b292192467 | ||
|
|
f172f7d4aa | ||
|
|
8e8b6be690 | ||
|
|
e8c6135a91 | ||
|
|
771cf41fea | ||
|
|
7ea17bb957 | ||
|
|
f8846f85a1 | ||
|
|
4c05ef0ba8 | ||
|
|
5438b64e32 | ||
|
|
248acf715e | ||
|
|
54ca0997ee | ||
|
|
b78076cac7 | ||
|
|
ba19d530ad | ||
|
|
47555602d7 | ||
|
|
6eb76c7c1a | ||
|
|
b32cc4b09d | ||
|
|
6e3dbb8d8b | ||
|
|
b66c093316 | ||
|
|
13d360030f | ||
|
|
66daebe88f | ||
|
|
4071ba29da | ||
|
|
21f9e2df40 | ||
|
|
80d326310e | ||
|
|
53fc705b13 | ||
|
|
d5af53888a | ||
|
|
a7a37249f7 | ||
|
|
6af6ff2a0a | ||
|
|
30ca282594 | ||
|
|
0fbc0475f3 | ||
|
|
e5e77381f0 | ||
|
|
066514e2a9 | ||
|
|
045a1737f8 |
23
.cursorrules
23
.cursorrules
@@ -1,23 +0,0 @@
|
||||
Hermes-Agent is an agent harness for LLMs.
|
||||
|
||||
When building, the tool functionality is in the tools/ directory, where each specific tool (or in some cases, tools that are built for the same execution category or api) are placed in a script each their own.
|
||||
|
||||
Each tool is then consolidated in the model_tools.py file in the repo root.
|
||||
|
||||
There is also a way to consolidate sets of tools in toolsets.py for the agent to use.
|
||||
|
||||
The primary agent runner code is in run_agent, but other runners could be developed using the tools and framework.
|
||||
|
||||
Always ensure consistency between tools, the model_tools.py and toolsets.py when changing any of them, otherwise they could become desynced in a way that is detrimental to functionality.
|
||||
|
||||
The expected pathway for using API keys is to setup and place them in a .env file in the repo root.
|
||||
|
||||
Test scripts will be placed in tests/
|
||||
|
||||
The run_agent loop is setup to:
|
||||
- Process the enabled toolsets to provide to the model,
|
||||
- Pipe in a prompt or problem from the input to the agent,
|
||||
- Loop the LLM each time it calls a tool, until the model decides no more tools are needed and provides a natural language response,
|
||||
- Return that response.
|
||||
|
||||
There are additional caveats for logging, where we restructure the "tools" as a system prompt for storage later into a format that can be used and handled properly later.
|
||||
168
.env.example
168
.env.example
@@ -1,14 +1,21 @@
|
||||
# Hermes Agent Environment Configuration
|
||||
# Copy this file to .env and fill in your API keys
|
||||
# Get API keys from the URLs listed below
|
||||
|
||||
# =============================================================================
|
||||
# REQUIRED API KEYS
|
||||
# LLM PROVIDER (OpenRouter)
|
||||
# =============================================================================
|
||||
# OpenRouter provides access to many models through one API
|
||||
# All LLM calls go through OpenRouter - no direct provider keys needed
|
||||
# Get your key at: https://openrouter.ai/keys
|
||||
OPENROUTER_API_KEY=
|
||||
|
||||
# Anthropic API Key - Main agent model
|
||||
# Get at: https://console.anthropic.com/
|
||||
ANTHROPIC_API_KEY=
|
||||
# Default model to use (OpenRouter format: provider/model)
|
||||
# Examples: anthropic/claude-sonnet-4, openai/gpt-4o, google/gemini-2.0-flash, zhipuai/glm-4-plus
|
||||
LLM_MODEL=anthropic/claude-sonnet-4
|
||||
|
||||
# =============================================================================
|
||||
# TOOL API KEYS
|
||||
# =============================================================================
|
||||
|
||||
# Firecrawl API Key - Web search, extract, and crawl
|
||||
# Get at: https://firecrawl.dev/
|
||||
@@ -18,32 +25,161 @@ FIRECRAWL_API_KEY=
|
||||
# Get at: https://inference-api.nousresearch.com/
|
||||
NOUS_API_KEY=
|
||||
|
||||
# Morph API Key - Terminal/command execution tools
|
||||
# Get at: https://morph.so/
|
||||
MORPH_API_KEY=
|
||||
|
||||
# FAL.ai API Key - Image generation
|
||||
# Get at: https://fal.ai/
|
||||
FAL_KEY=
|
||||
|
||||
# =============================================================================
|
||||
# OPTIONAL API KEYS
|
||||
# TERMINAL TOOL CONFIGURATION (mini-swe-agent backend)
|
||||
# =============================================================================
|
||||
# Backend type: "local", "singularity", "docker", "modal", or "ssh"
|
||||
# - local: Runs directly on your machine (fastest, no isolation)
|
||||
# - ssh: Runs on remote server via SSH (great for sandboxing - agent can't touch its own code)
|
||||
# - singularity: Runs in Apptainer/Singularity containers (HPC clusters, no root needed)
|
||||
# - docker: Runs in Docker containers (isolated, requires Docker + docker group)
|
||||
# - modal: Runs in Modal cloud sandboxes (scalable, requires Modal account)
|
||||
TERMINAL_ENV=local
|
||||
|
||||
# OpenAI API Key - Optional, for enhanced Hecate features
|
||||
# Get at: https://platform.openai.com/
|
||||
OPENAI_API_KEY=
|
||||
|
||||
# Container images (for singularity/docker/modal backends)
|
||||
TERMINAL_DOCKER_IMAGE=python:3.11
|
||||
TERMINAL_SINGULARITY_IMAGE=docker://python:3.11
|
||||
TERMINAL_MODAL_IMAGE=python:3.11
|
||||
|
||||
# Working directory inside the container
|
||||
TERMINAL_CWD=/tmp
|
||||
|
||||
# Default command timeout in seconds
|
||||
TERMINAL_TIMEOUT=60
|
||||
|
||||
# Cleanup inactive environments after this many seconds
|
||||
TERMINAL_LIFETIME_SECONDS=300
|
||||
|
||||
# =============================================================================
|
||||
# OPTIONAL CONFIGURATION
|
||||
# SSH REMOTE EXECUTION (for TERMINAL_ENV=ssh)
|
||||
# =============================================================================
|
||||
# Run terminal commands on a remote server via SSH.
|
||||
# Agent code stays on your machine, commands execute remotely.
|
||||
#
|
||||
# SECURITY BENEFITS:
|
||||
# - Agent cannot read your .env file (API keys protected)
|
||||
# - Agent cannot modify its own code
|
||||
# - Remote server acts as isolated sandbox
|
||||
# - Can safely configure passwordless sudo on remote
|
||||
#
|
||||
# TERMINAL_SSH_HOST=192.168.1.100
|
||||
# TERMINAL_SSH_USER=agent
|
||||
# TERMINAL_SSH_PORT=22
|
||||
# TERMINAL_SSH_KEY=~/.ssh/id_rsa
|
||||
|
||||
# =============================================================================
|
||||
# SUDO SUPPORT (works with ALL terminal backends)
|
||||
# =============================================================================
|
||||
# If set, enables sudo commands by piping password via `sudo -S`.
|
||||
# Works with: local, docker, singularity, modal, and ssh backends.
|
||||
#
|
||||
# SECURITY WARNING: Password stored in plaintext. Only use on trusted machines.
|
||||
#
|
||||
# ALTERNATIVES:
|
||||
# - For SSH backend: Configure passwordless sudo on the remote server
|
||||
# - For containers: Run as root inside the container (no sudo needed)
|
||||
# - For local: Configure /etc/sudoers for specific commands
|
||||
# - For CLI: Leave unset - you'll be prompted interactively with 45s timeout
|
||||
#
|
||||
# SUDO_PASSWORD=your_password_here
|
||||
|
||||
# =============================================================================
|
||||
# MODAL CLOUD BACKEND (Optional - for TERMINAL_ENV=modal)
|
||||
# =============================================================================
|
||||
# Modal uses CLI authentication, not environment variables.
|
||||
# Run: pip install modal && modal setup
|
||||
# This will authenticate via browser and store credentials locally.
|
||||
# No API key needed in .env - Modal handles auth automatically.
|
||||
|
||||
# =============================================================================
|
||||
# BROWSER TOOL CONFIGURATION (agent-browser + Browserbase)
|
||||
# =============================================================================
|
||||
# Browser automation requires Browserbase cloud service for remote browser execution.
|
||||
# This allows the agent to navigate websites, fill forms, and extract information.
|
||||
#
|
||||
# STEALTH MODES:
|
||||
# - Basic Stealth: ALWAYS active (random fingerprints, auto CAPTCHA solving)
|
||||
# - Advanced Stealth: Requires BROWSERBASE_ADVANCED_STEALTH=true (Scale Plan only)
|
||||
|
||||
# Browserbase API Key - Cloud browser execution
|
||||
# Get at: https://browserbase.com/
|
||||
BROWSERBASE_API_KEY=
|
||||
|
||||
# Browserbase Project ID - From your Browserbase dashboard
|
||||
BROWSERBASE_PROJECT_ID=
|
||||
|
||||
# Enable residential proxies for better CAPTCHA solving (default: true)
|
||||
# Routes traffic through residential IPs, significantly improves success rate
|
||||
BROWSERBASE_PROXIES=true
|
||||
|
||||
# Enable advanced stealth mode (default: false, requires Scale Plan)
|
||||
# Uses custom Chromium build to avoid bot detection altogether
|
||||
BROWSERBASE_ADVANCED_STEALTH=false
|
||||
|
||||
# Browser session timeout in seconds (default: 300)
|
||||
# Sessions are cleaned up after this duration of inactivity
|
||||
BROWSER_SESSION_TIMEOUT=300
|
||||
|
||||
# Browser inactivity timeout - auto-cleanup inactive sessions (default: 120 = 2 min)
|
||||
# Browser sessions are automatically closed after this period of no activity
|
||||
BROWSER_INACTIVITY_TIMEOUT=120
|
||||
|
||||
# =============================================================================
|
||||
# SESSION LOGGING
|
||||
# =============================================================================
|
||||
# Session trajectories are automatically saved to logs/ directory
|
||||
# Format: logs/session_YYYYMMDD_HHMMSS_UUID.json
|
||||
# Contains full conversation history in trajectory format for debugging/replay
|
||||
|
||||
# =============================================================================
|
||||
# LEGACY/OPTIONAL API KEYS
|
||||
# =============================================================================
|
||||
|
||||
# Terminal Tool Settings
|
||||
# Morph API Key - For legacy Hecate terminal backend (terminal-hecate tool)
|
||||
# Get at: https://morph.so/
|
||||
MORPH_API_KEY=
|
||||
|
||||
# Hecate VM Settings (only if using terminal-hecate tool)
|
||||
HECATE_VM_LIFETIME_SECONDS=300
|
||||
HECATE_DEFAULT_SNAPSHOT_ID=snapshot_p5294qxt
|
||||
|
||||
# Debug Logging (set to "true" to enable, logs saved to ./logs/)
|
||||
# =============================================================================
|
||||
# DEBUG OPTIONS
|
||||
# =============================================================================
|
||||
WEB_TOOLS_DEBUG=false
|
||||
VISION_TOOLS_DEBUG=false
|
||||
MOA_TOOLS_DEBUG=false
|
||||
IMAGE_TOOLS_DEBUG=false
|
||||
|
||||
# =============================================================================
|
||||
# CONTEXT COMPRESSION (Auto-shrinks long conversations)
|
||||
# =============================================================================
|
||||
# When conversation approaches model's context limit, middle turns are
|
||||
# automatically summarized to free up space.
|
||||
#
|
||||
# CONTEXT_COMPRESSION_ENABLED=true # Enable auto-compression (default: true)
|
||||
# CONTEXT_COMPRESSION_THRESHOLD=0.85 # Compress at 85% of context limit
|
||||
# CONTEXT_COMPRESSION_MODEL=google/gemini-2.0-flash-001 # Fast model for summaries
|
||||
|
||||
# =============================================================================
|
||||
# RL TRAINING (Tinker + Atropos)
|
||||
# =============================================================================
|
||||
# Run reinforcement learning training on language models using the Tinker API.
|
||||
# Requires the rl-server to be running (from tinker-atropos package).
|
||||
|
||||
# Tinker API Key - RL training service
|
||||
# Get at: https://tinker-console.thinkingmachines.ai/keys
|
||||
TINKER_API_KEY=
|
||||
|
||||
# Weights & Biases API Key - Experiment tracking and metrics
|
||||
# Get at: https://wandb.ai/authorize
|
||||
WANDB_API_KEY=
|
||||
|
||||
# RL API Server URL (default: http://localhost:8080)
|
||||
# Change if running the rl-server on a different host/port
|
||||
# RL_API_URL=http://localhost:8080
|
||||
|
||||
12
.gitignore
vendored
12
.gitignore
vendored
@@ -30,3 +30,15 @@ run_datagen_megascience_glm4-6.sh
|
||||
run_datagen_sonnet.sh
|
||||
source-data/*
|
||||
run_datagen_megascience_glm4-6.sh
|
||||
data/*
|
||||
node_modules/
|
||||
browser-use/
|
||||
agent-browser/
|
||||
# Private keys
|
||||
*.ppk
|
||||
*.pem
|
||||
privvy*
|
||||
images/
|
||||
|
||||
# CLI config (may contain sensitive SSH paths)
|
||||
cli-config.yaml
|
||||
|
||||
6
.gitmodules
vendored
Normal file
6
.gitmodules
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
[submodule "mini-swe-agent"]
|
||||
path = mini-swe-agent
|
||||
url = https://github.com/SWE-agent/mini-swe-agent
|
||||
[submodule "tinker-atropos"]
|
||||
path = tinker-atropos
|
||||
url = https://github.com/nousresearch/tinker-atropos
|
||||
533
AGENTS.md
Normal file
533
AGENTS.md
Normal file
@@ -0,0 +1,533 @@
|
||||
# Hermes Agent - Development Guide
|
||||
|
||||
Instructions for AI coding assistants (GitHub Copilot, Cursor, etc.) and human developers.
|
||||
|
||||
Hermes-Agent is an AI agent harness with tool-calling capabilities, interactive CLI, messaging integrations, and scheduled tasks.
|
||||
|
||||
## Development Environment
|
||||
|
||||
**IMPORTANT**: Always use the virtual environment if it exists:
|
||||
```bash
|
||||
source venv/bin/activate # Before running any Python commands
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
hermes-agent/
|
||||
├── hermes_cli/ # Unified CLI commands
|
||||
│ ├── main.py # Entry point, command dispatcher
|
||||
│ ├── setup.py # Interactive setup wizard
|
||||
│ ├── config.py # Config management & migration
|
||||
│ ├── status.py # Status display
|
||||
│ ├── doctor.py # Diagnostics
|
||||
│ ├── gateway.py # Gateway management
|
||||
│ ├── uninstall.py # Uninstaller
|
||||
│ └── cron.py # Cron job management
|
||||
├── tools/ # Tool implementations
|
||||
├── gateway/ # Messaging platform adapters
|
||||
├── cron/ # Scheduler implementation
|
||||
├── skills/ # Knowledge documents
|
||||
├── cli.py # Interactive CLI (Rich UI)
|
||||
├── run_agent.py # Agent runner with AIAgent class
|
||||
├── model_tools.py # Tool schemas and handlers
|
||||
├── toolsets.py # Tool groupings
|
||||
├── toolset_distributions.py # Probability-based tool selection
|
||||
└── batch_runner.py # Parallel batch processing
|
||||
```
|
||||
|
||||
**User Configuration** (stored in `~/.hermes/`):
|
||||
- `~/.hermes/config.yaml` - Settings (model, terminal, toolsets, etc.)
|
||||
- `~/.hermes/.env` - API keys and secrets
|
||||
|
||||
## File Dependency Chain
|
||||
|
||||
```
|
||||
tools/*.py → tools/__init__.py → model_tools.py → toolsets.py → toolset_distributions.py
|
||||
↑
|
||||
run_agent.py ──────────────────────────┘
|
||||
cli.py → run_agent.py (uses AIAgent with quiet_mode=True)
|
||||
batch_runner.py → run_agent.py + toolset_distributions.py
|
||||
```
|
||||
|
||||
Always ensure consistency between tools, model_tools.py, and toolsets.py when changing any of them.
|
||||
|
||||
---
|
||||
|
||||
## AIAgent Class
|
||||
|
||||
The main agent is implemented in `run_agent.py`:
|
||||
|
||||
```python
|
||||
class AIAgent:
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "anthropic/claude-sonnet-4",
|
||||
api_key: str = None,
|
||||
base_url: str = "https://openrouter.ai/api/v1",
|
||||
max_iterations: int = 60, # Max tool-calling loops
|
||||
enabled_toolsets: list = None,
|
||||
disabled_toolsets: list = None,
|
||||
verbose_logging: bool = False,
|
||||
quiet_mode: bool = False, # Suppress progress output
|
||||
tool_progress_callback: callable = None, # Called on each tool use
|
||||
):
|
||||
# Initialize OpenAI client, load tools based on toolsets
|
||||
...
|
||||
|
||||
def chat(self, user_message: str, task_id: str = None) -> str:
|
||||
# Main entry point - runs the agent loop
|
||||
...
|
||||
```
|
||||
|
||||
### Agent Loop
|
||||
|
||||
The core loop in `_run_agent_loop()`:
|
||||
|
||||
```
|
||||
1. Add user message to conversation
|
||||
2. Call LLM with tools
|
||||
3. If LLM returns tool calls:
|
||||
- Execute each tool
|
||||
- Add tool results to conversation
|
||||
- Go to step 2
|
||||
4. If LLM returns text response:
|
||||
- Return response to user
|
||||
```
|
||||
|
||||
```python
|
||||
while turns < max_turns:
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tool_schemas,
|
||||
)
|
||||
|
||||
if response.tool_calls:
|
||||
for tool_call in response.tool_calls:
|
||||
result = await execute_tool(tool_call)
|
||||
messages.append(tool_result_message(result))
|
||||
turns += 1
|
||||
else:
|
||||
return response.content
|
||||
```
|
||||
|
||||
### Conversation Management
|
||||
|
||||
Messages are stored as a list of dicts following OpenAI format:
|
||||
|
||||
```python
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant..."},
|
||||
{"role": "user", "content": "Search for Python tutorials"},
|
||||
{"role": "assistant", "content": None, "tool_calls": [...]},
|
||||
{"role": "tool", "tool_call_id": "...", "content": "..."},
|
||||
{"role": "assistant", "content": "Here's what I found..."},
|
||||
]
|
||||
```
|
||||
|
||||
### Reasoning Model Support
|
||||
|
||||
For models that support chain-of-thought reasoning:
|
||||
- Extract `reasoning_content` from API responses
|
||||
- Store in `assistant_msg["reasoning"]` for trajectory export
|
||||
- Pass back via `reasoning_content` field on subsequent turns
|
||||
|
||||
---
|
||||
|
||||
## CLI Architecture (cli.py)
|
||||
|
||||
The interactive CLI uses:
|
||||
- **Rich** - For the welcome banner and styled panels
|
||||
- **prompt_toolkit** - For fixed input area with history and `patch_stdout`
|
||||
- **KawaiiSpinner** (in run_agent.py) - Animated feedback during API calls and tool execution
|
||||
|
||||
Key components:
|
||||
- `HermesCLI` class - Main CLI controller with commands and conversation loop
|
||||
- `load_cli_config()` - Loads config, sets environment variables for terminal
|
||||
- `build_welcome_banner()` - Displays ASCII art logo, tools, and skills summary
|
||||
- `/commands` - Process user commands like `/help`, `/clear`, `/personality`, etc.
|
||||
|
||||
CLI uses `quiet_mode=True` when creating AIAgent to suppress verbose logging.
|
||||
|
||||
### Adding CLI Commands
|
||||
|
||||
1. Add to `COMMANDS` dict with description
|
||||
2. Add handler in `process_command()` method
|
||||
3. For persistent settings, use `save_config_value()` to update config
|
||||
|
||||
---
|
||||
|
||||
## Hermes CLI Commands
|
||||
|
||||
The unified `hermes` command provides all functionality:
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `hermes` | Interactive chat (default) |
|
||||
| `hermes chat -q "..."` | Single query mode |
|
||||
| `hermes setup` | Configure API keys and settings |
|
||||
| `hermes config` | View current configuration |
|
||||
| `hermes config edit` | Open config in editor |
|
||||
| `hermes config set KEY VAL` | Set a specific value |
|
||||
| `hermes config check` | Check for missing config |
|
||||
| `hermes config migrate` | Prompt for missing config interactively |
|
||||
| `hermes status` | Show configuration status |
|
||||
| `hermes doctor` | Diagnose issues |
|
||||
| `hermes update` | Update to latest (checks for new config) |
|
||||
| `hermes uninstall` | Uninstall (can keep configs for reinstall) |
|
||||
| `hermes gateway` | Start messaging gateway |
|
||||
| `hermes cron list` | View scheduled jobs |
|
||||
| `hermes version` | Show version info |
|
||||
|
||||
---
|
||||
|
||||
## Messaging Gateway
|
||||
|
||||
The gateway connects Hermes to Telegram, Discord, and WhatsApp.
|
||||
|
||||
### Configuration (in `~/.hermes/.env`):
|
||||
|
||||
```bash
|
||||
# Telegram
|
||||
TELEGRAM_BOT_TOKEN=123456:ABC-DEF... # From @BotFather
|
||||
TELEGRAM_ALLOWED_USERS=123456789,987654 # Comma-separated user IDs (from @userinfobot)
|
||||
|
||||
# Discord
|
||||
DISCORD_BOT_TOKEN=MTIz... # From Developer Portal
|
||||
DISCORD_ALLOWED_USERS=123456789012345678 # Comma-separated user IDs
|
||||
|
||||
# Agent Behavior
|
||||
HERMES_MAX_ITERATIONS=60 # Max tool-calling iterations
|
||||
MESSAGING_CWD=/home/myuser # Terminal working directory for messaging
|
||||
|
||||
# Tool Progress (optional)
|
||||
HERMES_TOOL_PROGRESS=true # Send progress messages
|
||||
HERMES_TOOL_PROGRESS_MODE=new # "new" or "all"
|
||||
```
|
||||
|
||||
### Working Directory Behavior
|
||||
|
||||
- **CLI (`hermes` command)**: Uses current directory (`.` → `os.getcwd()`)
|
||||
- **Messaging (Telegram/Discord)**: Uses `MESSAGING_CWD` (default: home directory)
|
||||
|
||||
This is intentional: CLI users are in a terminal and expect the agent to work in their current directory, while messaging users need a consistent starting location.
|
||||
|
||||
### Security (User Allowlists):
|
||||
|
||||
**IMPORTANT**: Without an allowlist, anyone who finds your bot can use it!
|
||||
|
||||
The gateway checks `{PLATFORM}_ALLOWED_USERS` environment variables:
|
||||
- If set: Only listed user IDs can interact with the bot
|
||||
- If unset: All users are allowed (dangerous with terminal access!)
|
||||
|
||||
Users can find their IDs:
|
||||
- **Telegram**: Message [@userinfobot](https://t.me/userinfobot)
|
||||
- **Discord**: Enable Developer Mode, right-click name → Copy ID
|
||||
|
||||
### Tool Progress Notifications
|
||||
|
||||
When `HERMES_TOOL_PROGRESS=true`, the bot sends status messages as it works:
|
||||
- `💻 \`ls -la\`...` (terminal commands show the actual command)
|
||||
- `🔍 web_search...`
|
||||
- `📄 web_extract...`
|
||||
|
||||
Modes:
|
||||
- `new`: Only when switching to a different tool (less spam)
|
||||
- `all`: Every single tool call
|
||||
|
||||
### Typing Indicator
|
||||
|
||||
The gateway keeps the "typing..." indicator active throughout processing, refreshing every 4 seconds. This lets users know the bot is working even during long tool-calling sequences.
|
||||
|
||||
### Platform Toolsets:
|
||||
|
||||
Each platform has a dedicated toolset in `toolsets.py`:
|
||||
- `hermes-telegram`: Full tools including terminal (with safety checks)
|
||||
- `hermes-discord`: Full tools including terminal
|
||||
- `hermes-whatsapp`: Full tools including terminal
|
||||
|
||||
---
|
||||
|
||||
## Configuration System
|
||||
|
||||
Configuration files are stored in `~/.hermes/` for easy user access:
|
||||
- `~/.hermes/config.yaml` - All settings (model, terminal, compression, etc.)
|
||||
- `~/.hermes/.env` - API keys and secrets
|
||||
|
||||
### Adding New Configuration Options
|
||||
|
||||
When adding new configuration variables, you MUST follow this process:
|
||||
|
||||
#### For config.yaml options:
|
||||
|
||||
1. Add to `DEFAULT_CONFIG` in `hermes_cli/config.py`
|
||||
2. **CRITICAL**: Bump `_config_version` in `DEFAULT_CONFIG` when adding required fields
|
||||
3. This triggers migration prompts for existing users on next `hermes update` or `hermes setup`
|
||||
|
||||
Example:
|
||||
```python
|
||||
DEFAULT_CONFIG = {
|
||||
# ... existing config ...
|
||||
|
||||
"new_feature": {
|
||||
"enabled": True,
|
||||
"option": "default_value",
|
||||
},
|
||||
|
||||
# BUMP THIS when adding required fields
|
||||
"_config_version": 2, # Was 1, now 2
|
||||
}
|
||||
```
|
||||
|
||||
#### For .env variables (API keys/secrets):
|
||||
|
||||
1. Add to `REQUIRED_ENV_VARS` or `OPTIONAL_ENV_VARS` in `hermes_cli/config.py`
|
||||
2. Include metadata for the migration system:
|
||||
|
||||
```python
|
||||
OPTIONAL_ENV_VARS = {
|
||||
# ... existing vars ...
|
||||
"NEW_API_KEY": {
|
||||
"description": "What this key is for",
|
||||
"prompt": "Display name in prompts",
|
||||
"url": "https://where-to-get-it.com/",
|
||||
"tools": ["tools_it_enables"], # What tools need this
|
||||
"password": True, # Mask input
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
#### Update related files:
|
||||
|
||||
- `hermes_cli/setup.py` - Add prompts in the setup wizard
|
||||
- `cli-config.yaml.example` - Add example with comments
|
||||
- Update README.md if user-facing
|
||||
|
||||
### Config Version Migration
|
||||
|
||||
The system uses `_config_version` to detect outdated configs:
|
||||
|
||||
1. `check_for_missing_config()` compares user config to `DEFAULT_CONFIG`
|
||||
2. `migrate_config()` interactively prompts for missing values
|
||||
3. Called automatically by `hermes update` and optionally by `hermes setup`
|
||||
|
||||
---
|
||||
|
||||
## Environment Variables
|
||||
|
||||
API keys are loaded from `~/.hermes/.env`:
|
||||
- `OPENROUTER_API_KEY` - Main LLM API access (primary provider)
|
||||
- `FIRECRAWL_API_KEY` - Web search/extract tools
|
||||
- `BROWSERBASE_API_KEY` / `BROWSERBASE_PROJECT_ID` - Browser automation
|
||||
- `FAL_KEY` - Image generation (FLUX model)
|
||||
- `NOUS_API_KEY` - Vision and Mixture-of-Agents tools
|
||||
|
||||
Terminal tool configuration (in `~/.hermes/config.yaml`):
|
||||
- `terminal.backend` - Backend: local, docker, singularity, modal, or ssh
|
||||
- `terminal.cwd` - Working directory for CLI ("." = current directory)
|
||||
- `terminal.docker_image` - Image for Docker backend
|
||||
- `terminal.singularity_image` - Image for Singularity backend
|
||||
- `terminal.modal_image` - Image for Modal backend
|
||||
- SSH: `TERMINAL_SSH_HOST`, `TERMINAL_SSH_USER`, `TERMINAL_SSH_KEY` in .env
|
||||
|
||||
Agent behavior (in `~/.hermes/.env`):
|
||||
- `HERMES_MAX_ITERATIONS` - Max tool-calling iterations (default: 60)
|
||||
- `MESSAGING_CWD` - Working directory for messaging platforms (default: ~)
|
||||
- `HERMES_TOOL_PROGRESS` - Enable tool progress messages (`true`/`false`)
|
||||
- `HERMES_TOOL_PROGRESS_MODE` - Progress mode: `new` (tool changes) or `all`
|
||||
|
||||
### Dangerous Command Approval
|
||||
|
||||
The terminal tool includes safety checks for potentially destructive commands (e.g., `rm -rf`, `DROP TABLE`, `chmod 777`, etc.):
|
||||
|
||||
**Behavior by Backend:**
|
||||
- **Docker/Singularity/Modal**: Commands run unrestricted (isolated containers)
|
||||
- **Local/SSH**: Dangerous commands trigger approval flow
|
||||
|
||||
**Approval Flow (CLI):**
|
||||
```
|
||||
⚠️ Potentially dangerous command detected: recursive delete
|
||||
rm -rf /tmp/test
|
||||
|
||||
[o]nce | [s]ession | [a]lways | [d]eny
|
||||
Choice [o/s/a/D]:
|
||||
```
|
||||
|
||||
**Approval Flow (Messaging):**
|
||||
- Command is blocked with explanation
|
||||
- Agent explains the command was blocked for safety
|
||||
- User must add the pattern to their allowlist via `hermes config edit` or run the command directly on their machine
|
||||
|
||||
**Configuration:**
|
||||
- `command_allowlist` in `~/.hermes/config.yaml` stores permanently allowed patterns
|
||||
- Add patterns via "always" approval or edit directly
|
||||
|
||||
**Sudo Handling (Messaging):**
|
||||
- If sudo fails over messaging, output includes tip to add `SUDO_PASSWORD` to `~/.hermes/.env`
|
||||
|
||||
---
|
||||
|
||||
## Adding New Tools
|
||||
|
||||
Follow this strict order to maintain consistency:
|
||||
|
||||
1. Create `tools/your_tool.py` with:
|
||||
- Handler function (sync or async) returning a JSON string via `json.dumps()`
|
||||
- `check_*_requirements()` function to verify dependencies (e.g., API keys)
|
||||
- Schema definition following OpenAI function-calling format
|
||||
|
||||
2. Export in `tools/__init__.py`:
|
||||
- Import the handler and check function
|
||||
- Add to `__all__` list
|
||||
|
||||
3. Register in `model_tools.py`:
|
||||
- Add to `TOOLSET_REQUIREMENTS` if it needs API keys
|
||||
- Create `get_*_tool_definitions()` function or add to existing
|
||||
- Add routing in `handle_function_call()` dispatcher
|
||||
- Update `get_all_tool_names()` with the tool name
|
||||
- Update `get_toolset_for_tool()` mapping
|
||||
- Update `get_available_toolsets()` and `check_toolset_requirements()`
|
||||
|
||||
4. Add to toolset in `toolsets.py`:
|
||||
- Add to existing toolset or create new one in TOOLSETS dict
|
||||
|
||||
5. If the tool requires an API key:
|
||||
- Add to `OPTIONAL_ENV_VARS` in `hermes_cli/config.py`
|
||||
- The tool will be auto-disabled if the key is missing
|
||||
|
||||
6. Optionally add to `toolset_distributions.py` for batch processing
|
||||
|
||||
### Tool Implementation Pattern
|
||||
|
||||
```python
|
||||
# tools/example_tool.py
|
||||
import json
|
||||
import os
|
||||
|
||||
def check_example_requirements() -> bool:
|
||||
"""Check if required API keys/dependencies are available."""
|
||||
return bool(os.getenv("EXAMPLE_API_KEY"))
|
||||
|
||||
def example_tool(param: str, task_id: str = None) -> str:
|
||||
"""Execute the tool and return JSON string result."""
|
||||
try:
|
||||
result = {"success": True, "data": "..."}
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
```
|
||||
|
||||
All tool handlers MUST return a JSON string. Never return raw dicts.
|
||||
|
||||
### Dynamic Tool Availability
|
||||
|
||||
Tools are automatically disabled when their API keys are missing:
|
||||
|
||||
```python
|
||||
# In model_tools.py
|
||||
TOOLSET_REQUIREMENTS = {
|
||||
"web": {"env_vars": ["FIRECRAWL_API_KEY"]},
|
||||
"browser": {"env_vars": ["BROWSERBASE_API_KEY", "BROWSERBASE_PROJECT_ID"]},
|
||||
"creative": {"env_vars": ["FAL_KEY"]},
|
||||
}
|
||||
```
|
||||
|
||||
The `check_tool_availability()` function determines which tools to include.
|
||||
|
||||
### Stateful Tools
|
||||
|
||||
Tools that maintain state (terminal, browser) require:
|
||||
- `task_id` parameter for session isolation between concurrent tasks
|
||||
- `cleanup_*()` function to release resources
|
||||
- Cleanup is called automatically in run_agent.py after conversation completes
|
||||
|
||||
---
|
||||
|
||||
## Trajectory Format
|
||||
|
||||
Conversations are saved in ShareGPT format for training:
|
||||
```json
|
||||
{"from": "system", "value": "System prompt with <tools>...</tools>"}
|
||||
{"from": "human", "value": "User message"}
|
||||
{"from": "gpt", "value": "<think>reasoning</think>\n<tool_call>{...}</tool_call>"}
|
||||
{"from": "tool", "value": "<tool_response>{...}</tool_response>"}
|
||||
{"from": "gpt", "value": "Final response"}
|
||||
```
|
||||
|
||||
Tool calls use `<tool_call>` XML tags, responses use `<tool_response>` tags, reasoning uses `<think>` tags.
|
||||
|
||||
### Trajectory Export
|
||||
|
||||
```python
|
||||
agent = AIAgent(save_trajectories=True)
|
||||
agent.chat("Do something")
|
||||
# Saves to trajectories/*.jsonl in ShareGPT format
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Batch Processing (batch_runner.py)
|
||||
|
||||
For processing multiple prompts:
|
||||
- Parallel execution with multiprocessing
|
||||
- Content-based resume for fault tolerance (matches on prompt text, not indices)
|
||||
- Toolset distributions control probabilistic tool availability per prompt
|
||||
- Output: `data/<run_name>/trajectories.jsonl` (combined) + individual batch files
|
||||
|
||||
```bash
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=20 \
|
||||
--num_workers=4 \
|
||||
--run_name=my_run
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Skills System
|
||||
|
||||
Skills are on-demand knowledge documents the agent can load. Located in `skills/` directory:
|
||||
|
||||
```
|
||||
skills/
|
||||
├── mlops/ # Category folder
|
||||
│ ├── axolotl/ # Skill folder
|
||||
│ │ ├── SKILL.md # Main instructions (required)
|
||||
│ │ ├── references/ # Additional docs, API specs
|
||||
│ │ └── templates/ # Output formats, configs
|
||||
│ └── vllm/
|
||||
│ └── SKILL.md
|
||||
└── example-skill/
|
||||
└── SKILL.md
|
||||
```
|
||||
|
||||
**Progressive disclosure** (token-efficient):
|
||||
1. `skills_categories()` - List category names (~50 tokens)
|
||||
2. `skills_list(category)` - Name + description per skill (~3k tokens)
|
||||
3. `skill_view(name)` - Full content + tags + linked files
|
||||
|
||||
SKILL.md files use YAML frontmatter:
|
||||
```yaml
|
||||
---
|
||||
name: skill-name
|
||||
description: Brief description for listing
|
||||
tags: [tag1, tag2]
|
||||
related_skills: [other-skill]
|
||||
version: 1.0.0
|
||||
---
|
||||
# Skill Content...
|
||||
```
|
||||
|
||||
Tool files: `tools/skills_tool.py` → `model_tools.py` → `toolsets.py`
|
||||
|
||||
---
|
||||
|
||||
## Testing Changes
|
||||
|
||||
After making changes:
|
||||
|
||||
1. Run `hermes doctor` to check setup
|
||||
2. Run `hermes config check` to verify config
|
||||
3. Test with `hermes chat -q "test message"`
|
||||
4. For new config options, test fresh install: `rm -rf ~/.hermes && hermes setup`
|
||||
775
README.md
775
README.md
@@ -1,243 +1,626 @@
|
||||
# Hermes Agent
|
||||
# Hermes Agent 🦋
|
||||
|
||||
An AI agent with advanced tool-calling capabilities, featuring a flexible toolsets system for organizing and managing tools.
|
||||
An AI agent with advanced tool-calling capabilities, featuring a flexible toolsets system, messaging integrations, and scheduled tasks.
|
||||
|
||||
## Quick Install
|
||||
|
||||
**Linux/macOS:**
|
||||
```bash
|
||||
curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash
|
||||
```
|
||||
|
||||
**Windows (PowerShell):**
|
||||
```powershell
|
||||
irm https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.ps1 | iex
|
||||
```
|
||||
|
||||
The installer will:
|
||||
- Clone to `~/.hermes-agent` (with submodules: mini-swe-agent, tinker-atropos)
|
||||
- Create a virtual environment
|
||||
- Install all dependencies
|
||||
- Run the interactive setup wizard
|
||||
- Add `hermes` to your PATH
|
||||
|
||||
After installation, reload your shell and run:
|
||||
```bash
|
||||
hermes setup # Configure API keys (if you skipped during install)
|
||||
hermes # Start chatting!
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Configuration
|
||||
|
||||
All your settings are stored in `~/.hermes/` for easy access:
|
||||
|
||||
```
|
||||
~/.hermes/
|
||||
├── config.yaml # Settings (model, terminal, compression, etc.)
|
||||
├── .env # API keys and secrets
|
||||
├── cron/ # Scheduled jobs
|
||||
├── sessions/ # Gateway sessions
|
||||
└── logs/ # Logs
|
||||
```
|
||||
|
||||
### Managing Configuration
|
||||
|
||||
```bash
|
||||
hermes config # View current configuration
|
||||
hermes config edit # Open config.yaml in your editor
|
||||
hermes config set KEY VAL # Set a specific value
|
||||
hermes config check # Check for missing options (after updates)
|
||||
hermes config migrate # Interactively add missing options
|
||||
|
||||
# Examples:
|
||||
hermes config set model anthropic/claude-opus-4
|
||||
hermes config set terminal.backend docker
|
||||
hermes config set OPENROUTER_API_KEY sk-or-... # Saves to .env
|
||||
```
|
||||
|
||||
### Required API Keys
|
||||
|
||||
You need at least one LLM provider:
|
||||
|
||||
| Provider | Get Key | Env Variable |
|
||||
|----------|---------|--------------|
|
||||
| **OpenRouter** (recommended) | [openrouter.ai/keys](https://openrouter.ai/keys) | `OPENROUTER_API_KEY` |
|
||||
| Anthropic | [console.anthropic.com](https://console.anthropic.com/) | `ANTHROPIC_API_KEY` |
|
||||
| OpenAI | [platform.openai.com](https://platform.openai.com/api-keys) | `OPENAI_API_KEY` |
|
||||
|
||||
### Optional API Keys
|
||||
|
||||
| Feature | Provider | Env Variable |
|
||||
|---------|----------|--------------|
|
||||
| Web scraping | [Firecrawl](https://firecrawl.dev/) | `FIRECRAWL_API_KEY` |
|
||||
| Browser automation | [Browserbase](https://browserbase.com/) | `BROWSERBASE_API_KEY`, `BROWSERBASE_PROJECT_ID` |
|
||||
| Image generation | [FAL](https://fal.ai/) | `FAL_KEY` |
|
||||
| RL Training | [Tinker](https://tinker-console.thinkingmachines.ai/) + [WandB](https://wandb.ai/) | `TINKER_API_KEY`, `WANDB_API_KEY` |
|
||||
| Messaging | Telegram, Discord | `TELEGRAM_BOT_TOKEN`, `DISCORD_BOT_TOKEN` |
|
||||
|
||||
---
|
||||
|
||||
## Commands
|
||||
|
||||
```bash
|
||||
hermes # Interactive chat (default)
|
||||
hermes chat -q "Hello" # Single query mode
|
||||
hermes setup # Configure API keys and settings
|
||||
hermes config # View/edit configuration
|
||||
hermes config check # Check for missing config (useful after updates)
|
||||
hermes config migrate # Interactively add missing options
|
||||
hermes status # Show configuration status
|
||||
hermes doctor # Diagnose issues
|
||||
hermes update # Update to latest version (prompts for new config)
|
||||
hermes uninstall # Uninstall (can keep configs for later reinstall)
|
||||
hermes gateway # Start messaging gateway
|
||||
hermes cron list # View scheduled jobs
|
||||
hermes version # Show version info
|
||||
```
|
||||
|
||||
### CLI Commands (inside chat)
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/help` | Show available commands |
|
||||
| `/tools` | List available tools |
|
||||
| `/model [name]` | Show or change model |
|
||||
| `/personality [name]` | Set personality (kawaii, pirate, etc.) |
|
||||
| `/clear` | Clear screen and reset |
|
||||
| `/cron` | Manage scheduled tasks |
|
||||
| `/config` | Show current configuration |
|
||||
| `/quit` | Exit |
|
||||
|
||||
---
|
||||
|
||||
## Features
|
||||
|
||||
- **Web Tools**: Search, extract content, and crawl websites
|
||||
- **Terminal Tools**: Execute commands with interactive session support
|
||||
- **Vision Tools**: Analyze images from URLs
|
||||
- **Reasoning Tools**: Advanced multi-model reasoning (Mixture of Agents)
|
||||
- **Creative Tools**: Generate images from text prompts
|
||||
- **Toolsets System**: Organize tools into logical groups for different scenarios
|
||||
- **Batch Processing**: Process datasets in parallel with checkpointing and statistics tracking
|
||||
- **Ephemeral System Prompts**: Guide model behavior without polluting training datasets
|
||||
### 🛠️ Tools & Toolsets
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. Install Dependencies
|
||||
```bash
|
||||
# Create and activate virtual environment (recommended)
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate # On Windows: venv\Scripts\activate
|
||||
|
||||
# Install required packages
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Install Hecate for terminal tools
|
||||
git clone git@github.com:NousResearch/hecate.git
|
||||
cd hecate
|
||||
pip install -e .
|
||||
cd ..
|
||||
```
|
||||
|
||||
### 2. Configure Environment Variables
|
||||
```bash
|
||||
# Copy the example environment file
|
||||
cp .env.example .env
|
||||
|
||||
# Edit .env and add your API keys
|
||||
nano .env # or use your preferred editor
|
||||
```
|
||||
|
||||
**Required API Keys:**
|
||||
- `ANTHROPIC_API_KEY` - Main agent model (get at: https://console.anthropic.com/)
|
||||
- `FIRECRAWL_API_KEY` - Web tools (get at: https://firecrawl.dev/)
|
||||
- `NOUS_API_KEY` - Vision & reasoning tools (get at: https://inference-api.nousresearch.com/)
|
||||
- `MORPH_API_KEY` - Terminal tools (get at: https://morph.so/)
|
||||
- `FAL_KEY` - Image generation (get at: https://fal.ai/)
|
||||
- `OPENAI_API_KEY` - Optional, for some Hecate features
|
||||
|
||||
See `.env.example` for all available configuration options including debug settings and terminal tool configuration.
|
||||
|
||||
## Toolsets System
|
||||
|
||||
The agent uses a toolsets system for organizing and managing tools. All tools must be part of a toolset to be accessible - individual tool selection is not supported. This ensures consistent and logical grouping of capabilities.
|
||||
|
||||
### Key Concepts
|
||||
|
||||
- **Toolsets**: Logical groups of tools for specific use cases (e.g., "research", "development", "debugging")
|
||||
- **Composition**: Toolsets can include other toolsets for powerful combinations
|
||||
- **Custom Toolsets**: Create your own toolsets at runtime or by editing `toolsets.py`
|
||||
- **Toolset-Only Access**: Tools are only accessible through toolsets, not individually
|
||||
|
||||
### Available Toolsets
|
||||
|
||||
See `toolsets.py` for the complete list of predefined toolsets including:
|
||||
- Basic toolsets (web, terminal, vision, creative, reasoning)
|
||||
- Composite toolsets (research, development, analysis, etc.)
|
||||
- Scenario-specific toolsets (debugging, documentation, API testing, etc.)
|
||||
- Special toolsets (safe mode without terminal, minimal, offline)
|
||||
|
||||
### Using Toolsets
|
||||
Tools are organized into logical **toolsets**:
|
||||
|
||||
```bash
|
||||
# Use a predefined toolset
|
||||
python run_agent.py --enabled_toolsets=research --query "Find latest AI papers"
|
||||
# Use specific toolsets
|
||||
hermes --toolsets "web,terminal"
|
||||
|
||||
# Combine multiple toolsets
|
||||
python run_agent.py --enabled_toolsets=web,vision --query "Analyze this website"
|
||||
|
||||
# Enable all toolsets explicitly (same as omitting the flag)
|
||||
python run_agent.py --enabled_toolsets=all --query "Do web research and run commands if helpful"
|
||||
|
||||
# Safe mode (no terminal access)
|
||||
python run_agent.py --enabled_toolsets=safe --query "Help without running commands"
|
||||
|
||||
# List all available toolsets and tools
|
||||
python run_agent.py --list_tools
|
||||
# List all toolsets
|
||||
hermes --list-tools
|
||||
```
|
||||
|
||||
For detailed documentation on toolsets, see `TOOLSETS_README.md`.
|
||||
**Available toolsets:** `web`, `terminal`, `browser`, `vision`, `creative`, `reasoning`, `skills`, `cronjob`, and more.
|
||||
|
||||
## Basic Usage
|
||||
### 🖥️ Terminal Backend
|
||||
|
||||
### Default (all tools enabled)
|
||||
The terminal tool can execute commands in different environments:
|
||||
|
||||
| Backend | Description | Use Case |
|
||||
|---------|-------------|----------|
|
||||
| `local` | Run on your machine (default) | Development, trusted tasks |
|
||||
| `docker` | Isolated containers | Security, reproducibility |
|
||||
| `ssh` | Remote server | Sandboxing, keep agent away from its own code |
|
||||
| `singularity` | HPC containers | Cluster computing, rootless |
|
||||
| `modal` | Cloud execution | Serverless, scale |
|
||||
|
||||
**Configure in `~/.hermes/config.yaml`:**
|
||||
```yaml
|
||||
terminal:
|
||||
backend: local # or: docker, ssh, singularity, modal
|
||||
cwd: "." # Working directory ("." = current dir)
|
||||
timeout: 180 # Command timeout in seconds
|
||||
```
|
||||
|
||||
**Docker Backend:**
|
||||
```yaml
|
||||
terminal:
|
||||
backend: docker
|
||||
docker_image: python:3.11-slim
|
||||
```
|
||||
|
||||
**SSH Backend** (recommended for security - agent can't modify its own code):
|
||||
```yaml
|
||||
terminal:
|
||||
backend: ssh
|
||||
```
|
||||
```bash
|
||||
python run_agent.py \
|
||||
--query "search up the latest docs on jit in python 3.13 and write me basic example that's not in their docs. profile its perf" \
|
||||
--max_turns 20 \
|
||||
--model claude-sonnet-4-20250514 \
|
||||
--base_url https://api.anthropic.com/v1/ \
|
||||
--api_key $ANTHROPIC_API_KEY
|
||||
# Set credentials in ~/.hermes/.env
|
||||
TERMINAL_SSH_HOST=my-server.example.com
|
||||
TERMINAL_SSH_USER=myuser
|
||||
TERMINAL_SSH_KEY=~/.ssh/id_rsa
|
||||
```
|
||||
|
||||
### With specific toolset
|
||||
**Singularity/Apptainer** (for HPC clusters):
|
||||
```bash
|
||||
python run_agent.py \
|
||||
--query "Debug this Python error" \
|
||||
--enabled_toolsets=debugging \
|
||||
--model claude-sonnet-4-20250514 \
|
||||
--api_key $ANTHROPIC_API_KEY
|
||||
# Pre-build SIF for parallel workers
|
||||
apptainer build ~/python.sif docker://python:3.11-slim
|
||||
|
||||
# Configure
|
||||
hermes config set terminal.backend singularity
|
||||
hermes config set terminal.singularity_image ~/python.sif
|
||||
```
|
||||
|
||||
### Python API
|
||||
```python
|
||||
from run_agent import AIAgent
|
||||
|
||||
# Use a specific toolset
|
||||
agent = AIAgent(
|
||||
model="claude-opus-4-20250514",
|
||||
enabled_toolsets=["research"]
|
||||
)
|
||||
response = agent.chat("Find information about quantum computing")
|
||||
|
||||
# Create custom toolset at runtime
|
||||
from toolsets import create_custom_toolset
|
||||
|
||||
create_custom_toolset(
|
||||
name="my_tools",
|
||||
description="My custom toolkit",
|
||||
tools=["web_search"],
|
||||
includes=["terminal", "vision"]
|
||||
)
|
||||
|
||||
agent = AIAgent(enabled_toolsets=["my_tools"])
|
||||
**Modal** (serverless cloud):
|
||||
```bash
|
||||
pip install modal boto3
|
||||
modal setup # Authenticate
|
||||
hermes config set terminal.backend modal
|
||||
```
|
||||
|
||||
## Batch Processing
|
||||
**Sudo Support:** If a command needs sudo, you'll be prompted for your password (cached for the session). Or set `SUDO_PASSWORD` in `~/.hermes/.env`.
|
||||
|
||||
Process multiple prompts from a dataset in parallel with automatic checkpointing and statistics tracking:
|
||||
### 📱 Messaging Gateway
|
||||
|
||||
Chat with Hermes from Telegram, Discord, or WhatsApp.
|
||||
|
||||
#### Telegram Setup
|
||||
|
||||
1. **Create a bot:** Message [@BotFather](https://t.me/BotFather) on Telegram, use `/newbot`
|
||||
2. **Get your user ID:** Message [@userinfobot](https://t.me/userinfobot) - it replies with your numeric ID
|
||||
3. **Configure:**
|
||||
|
||||
```bash
|
||||
# Basic batch processing
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=20 \
|
||||
--run_name=my_run
|
||||
|
||||
# With specific distribution
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=20 \
|
||||
--run_name=image_run \
|
||||
--distribution=image_gen \
|
||||
--num_workers=4
|
||||
# Add to ~/.hermes/.env:
|
||||
TELEGRAM_BOT_TOKEN=123456:ABC-DEF...
|
||||
TELEGRAM_ALLOWED_USERS=YOUR_USER_ID # Comma-separated for multiple users
|
||||
```
|
||||
|
||||
**Key Features:**
|
||||
- Parallel processing with configurable workers
|
||||
- Toolset distributions for varied data generation
|
||||
- Automatic checkpointing and resume capability
|
||||
- Combined output in `data/<run_name>/trajectories.jsonl`
|
||||
- Tool usage statistics and success rates
|
||||
4. **Start the gateway:**
|
||||
|
||||
**Quick Start:** See [QUICKSTART_BATCH.md](QUICKSTART_BATCH.md) for a 5-minute getting started guide.
|
||||
**Full Documentation:** See [BATCH_PROCESSING.md](BATCH_PROCESSING.md) for comprehensive documentation.
|
||||
```bash
|
||||
hermes gateway # Run in foreground
|
||||
hermes gateway install # Install as systemd service (Linux)
|
||||
hermes gateway start # Start the service
|
||||
```
|
||||
|
||||
### Ephemeral System Prompts
|
||||
#### Discord Setup
|
||||
|
||||
The ephemeral system prompt feature allows you to guide the model's behavior during batch processing **without** saving that prompt to the training dataset trajectories. This is useful for:
|
||||
1. **Create a bot:** Go to [Discord Developer Portal](https://discord.com/developers/applications)
|
||||
2. **Get your user ID:** Enable Developer Mode in Discord settings, right-click your name → Copy ID
|
||||
3. **Configure:**
|
||||
|
||||
- Guiding model behavior during data collection
|
||||
- Adding task-specific instructions
|
||||
- Keeping saved trajectories clean and focused on tool-calling format
|
||||
```bash
|
||||
# Add to ~/.hermes/.env:
|
||||
DISCORD_BOT_TOKEN=MTIz...
|
||||
DISCORD_ALLOWED_USERS=YOUR_USER_ID
|
||||
```
|
||||
|
||||
#### Security (Important!)
|
||||
|
||||
**Without an allowlist, anyone who finds your bot can use it!**
|
||||
|
||||
```bash
|
||||
# Restrict to specific users (recommended):
|
||||
TELEGRAM_ALLOWED_USERS=123456789,987654321
|
||||
DISCORD_ALLOWED_USERS=123456789012345678
|
||||
|
||||
# Or allow all users in a specific platform:
|
||||
# (Leave the variable unset - NOT recommended for bots with terminal access)
|
||||
```
|
||||
|
||||
#### Gateway Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/new` or `/reset` | Start fresh conversation |
|
||||
| `/status` | Show session info |
|
||||
|
||||
#### Working Directory
|
||||
|
||||
- **CLI (`hermes`)**: Uses current directory where you run the command
|
||||
- **Messaging**: Uses `MESSAGING_CWD` (default: home directory `~`)
|
||||
|
||||
```bash
|
||||
# Set custom messaging working directory in ~/.hermes/.env
|
||||
MESSAGING_CWD=/home/myuser/projects
|
||||
```
|
||||
|
||||
#### Tool Progress Notifications
|
||||
|
||||
Get real-time updates as the agent works:
|
||||
|
||||
```bash
|
||||
# Enable in ~/.hermes/.env
|
||||
HERMES_TOOL_PROGRESS=true
|
||||
HERMES_TOOL_PROGRESS_MODE=new # or "all" for every tool call
|
||||
```
|
||||
|
||||
When enabled, you'll see messages like:
|
||||
```
|
||||
💻 `ls -la`...
|
||||
🔍 web_search...
|
||||
📄 web_extract...
|
||||
```
|
||||
|
||||
See [docs/messaging.md](docs/messaging.md) for WhatsApp and advanced setup.
|
||||
|
||||
### 🤖 RL Training (Tinker + Atropos)
|
||||
|
||||
Train language models with reinforcement learning using the Tinker API and Atropos framework.
|
||||
|
||||
#### Requirements
|
||||
|
||||
1. **API Keys:** Add to `~/.hermes/.env`:
|
||||
```bash
|
||||
TINKER_API_KEY=your-tinker-key # Get from https://tinker-console.thinkingmachines.ai/keys
|
||||
WANDB_API_KEY=your-wandb-key # Get from https://wandb.ai/authorize
|
||||
OPENROUTER_API_KEY=your-key # Optional: for rl_test_inference
|
||||
```
|
||||
|
||||
2. **That's it!** tinker-atropos is included as a submodule - no separate installation needed.
|
||||
|
||||
#### Using RL Tools
|
||||
|
||||
The agent can now use RL training tools:
|
||||
|
||||
```
|
||||
You: Start training on GSM8k with group_size=16
|
||||
|
||||
Agent: I'll set up an RL training run on the GSM8k environment...
|
||||
[Uses rl_list_environments, rl_select_environment, rl_edit_config, rl_start_training]
|
||||
```
|
||||
|
||||
#### Available RL Tools
|
||||
|
||||
| Tool | Description |
|
||||
|------|-------------|
|
||||
| `rl_list_environments` | List available RL environments |
|
||||
| `rl_select_environment` | Select an environment for training |
|
||||
| `rl_get_current_config` | View all configurable options |
|
||||
| `rl_edit_config` | Change a configuration value |
|
||||
| `rl_test_inference` | Test environment with OpenRouter (pre-training validation) |
|
||||
| `rl_start_training` | Start a training run |
|
||||
| `rl_check_status` | Check training progress |
|
||||
| `rl_stop_training` | Stop a running training |
|
||||
| `rl_get_results` | Fetch WandB metrics |
|
||||
| `rl_list_runs` | List active training runs |
|
||||
|
||||
#### Dedicated RL CLI
|
||||
|
||||
For extended RL workflows with longer timeouts:
|
||||
|
||||
```bash
|
||||
python rl_cli.py --model "anthropic/claude-sonnet-4-20250514"
|
||||
```
|
||||
|
||||
### ⏰ Scheduled Tasks (Cron)
|
||||
|
||||
Schedule tasks to run automatically:
|
||||
|
||||
```bash
|
||||
# In the CLI
|
||||
/cron add 30m "Remind me to check the build"
|
||||
/cron add "every 2h" "Check server status"
|
||||
/cron add "0 9 * * *" "Morning briefing"
|
||||
/cron list
|
||||
/cron remove <job_id>
|
||||
```
|
||||
|
||||
The agent can also self-schedule using `schedule_cronjob` tool.
|
||||
|
||||
**Run the scheduler:**
|
||||
```bash
|
||||
hermes cron daemon # Built-in daemon
|
||||
# Or add to system cron for reliability
|
||||
```
|
||||
|
||||
### 🗜️ Context Compression
|
||||
|
||||
Long conversations are automatically summarized when approaching context limits:
|
||||
|
||||
```yaml
|
||||
# In ~/.hermes/config.yaml
|
||||
compression:
|
||||
enabled: true
|
||||
threshold: 0.85 # Compress at 85% of limit
|
||||
```
|
||||
|
||||
### 📝 Session Logging
|
||||
|
||||
Every conversation is logged to `~/.hermes-agent/logs/` for debugging:
|
||||
|
||||
```
|
||||
logs/
|
||||
├── session_20260201_143052_a1b2c3.json
|
||||
└── ...
|
||||
```
|
||||
|
||||
### 🌐 Browser Automation
|
||||
|
||||
Browser tools let the agent navigate websites, fill forms, click buttons, and extract content using [Browserbase](https://browserbase.com/).
|
||||
|
||||
**Setup:**
|
||||
```bash
|
||||
# 1. Get credentials from browserbase.com
|
||||
hermes config set BROWSERBASE_API_KEY your_api_key
|
||||
hermes config set BROWSERBASE_PROJECT_ID your_project_id
|
||||
|
||||
# 2. Install Node.js dependencies (if not already)
|
||||
cd ~/.hermes-agent && npm install
|
||||
```
|
||||
|
||||
**Available tools:** `browser_navigate`, `browser_snapshot`, `browser_click`, `browser_type`, `browser_scroll`, `browser_back`, `browser_press`, `browser_close`, `browser_get_images`
|
||||
|
||||
**Example:**
|
||||
```bash
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=10 \
|
||||
--run_name=my_run \
|
||||
--ephemeral_system_prompt="You are a helpful assistant focused on image generation."
|
||||
hermes --toolsets browser -q "Go to amazon.com and find the price of the latest Kindle"
|
||||
```
|
||||
|
||||
The ephemeral prompt will influence the model's behavior during execution, but **only the standard tool-calling system prompt** will be saved in the trajectory files.
|
||||
### 📚 Skills System
|
||||
|
||||
**Documentation:** See [docs/ephemeral_system_prompt.md](docs/ephemeral_system_prompt.md) for complete details.
|
||||
Skills are on-demand knowledge documents the agent can load when needed. They follow a **progressive disclosure** pattern to minimize token usage.
|
||||
|
||||
## Command Line Arguments
|
||||
**Using Skills:**
|
||||
```bash
|
||||
hermes --toolsets skills -q "What skills do you have?"
|
||||
hermes --toolsets skills -q "Show me the axolotl skill"
|
||||
```
|
||||
|
||||
**Single Agent (`run_agent.py`):**
|
||||
- `--query`: The question or task for the agent
|
||||
- `--model`: Model to use (default: claude-opus-4-20250514)
|
||||
- `--api_key`: API key for authentication
|
||||
- `--base_url`: API endpoint URL
|
||||
- `--max_turns`: Maximum number of tool-calling iterations
|
||||
- `--enabled_toolsets`: Comma-separated list of toolsets to enable. Use `all` (or `*`) to enable everything. If omitted, all toolsets are enabled by default.
|
||||
- `--disabled_toolsets`: Comma-separated list of toolsets to disable
|
||||
- `--list_tools`: List all available toolsets and tools
|
||||
- `--save_trajectories`: Save conversation trajectories to JSONL files
|
||||
**Creating Skills:**
|
||||
|
||||
**Batch Processing (`batch_runner.py`):**
|
||||
- `--dataset_file`: Path to JSONL file with prompts
|
||||
- `--batch_size`: Number of prompts per batch
|
||||
- `--run_name`: Name for this run (for output/checkpointing)
|
||||
- `--distribution`: Toolset distribution to use (default: "default")
|
||||
- `--num_workers`: Number of parallel workers (default: 4)
|
||||
- `--resume`: Resume from checkpoint if interrupted
|
||||
- `--ephemeral_system_prompt`: System prompt used during execution but NOT saved to trajectories
|
||||
- `--list_distributions`: List available toolset distributions
|
||||
Create `skills/category/skill-name/SKILL.md`:
|
||||
```markdown
|
||||
---
|
||||
name: my-skill
|
||||
description: Brief description shown in skills_list
|
||||
tags: [python, automation]
|
||||
version: 1.0.0
|
||||
---
|
||||
|
||||
## Environment Variables
|
||||
# Skill Content
|
||||
|
||||
All environment variables can be configured in the `.env` file (copy from `.env.example`).
|
||||
Instructions, examples, and guidelines here...
|
||||
```
|
||||
|
||||
**Core API Keys:**
|
||||
- `ANTHROPIC_API_KEY`: Main agent model
|
||||
- `FIRECRAWL_API_KEY`: Web tools (search, extract, crawl)
|
||||
- `NOUS_API_KEY`: Vision and reasoning tools
|
||||
- `MORPH_API_KEY`: Terminal tools
|
||||
- `FAL_KEY`: Image generation tools
|
||||
- `OPENAI_API_KEY`: Optional, for some Hecate features
|
||||
**Skill Structure:**
|
||||
```
|
||||
skills/
|
||||
├── mlops/
|
||||
│ ├── axolotl/
|
||||
│ │ ├── SKILL.md # Main instructions (required)
|
||||
│ │ ├── references/ # Additional docs
|
||||
│ │ └── templates/ # Output formats
|
||||
│ └── vllm/
|
||||
│ └── SKILL.md
|
||||
```
|
||||
|
||||
**Configuration Options:**
|
||||
- `HECATE_VM_LIFETIME_SECONDS`: VM lifetime (default: 300)
|
||||
- `HECATE_DEFAULT_SNAPSHOT_ID`: Default snapshot (default: snapshot_p5294qxt)
|
||||
- `WEB_TOOLS_DEBUG`, `VISION_TOOLS_DEBUG`, `MOA_TOOLS_DEBUG`, `IMAGE_TOOLS_DEBUG`: Enable debug logging
|
||||
---
|
||||
|
||||
## Documentation
|
||||
## Manual Installation
|
||||
|
||||
**Single Agent Usage:**
|
||||
- `TOOLSETS_README.md`: Comprehensive guide to the toolsets system
|
||||
- `toolsets.py`: View and modify available toolsets
|
||||
- `model_tools.py`: Core tool definitions and handlers
|
||||
If you prefer not to use the installer:
|
||||
|
||||
**Batch Processing:**
|
||||
- `QUICKSTART_BATCH.md`: 5-minute quick start guide
|
||||
- `BATCH_PROCESSING.md`: Complete batch processing documentation
|
||||
- `toolset_distributions.py`: Toolset distributions for data generation
|
||||
```bash
|
||||
# Clone the repository (with submodules)
|
||||
git clone --recurse-submodules https://github.com/NousResearch/hermes-agent.git
|
||||
cd hermes-agent
|
||||
|
||||
## Examples
|
||||
# Run setup script
|
||||
./setup-hermes.sh
|
||||
|
||||
See `TOOLSETS_README.md` for extensive examples of using different toolsets for various scenarios.
|
||||
# Or manually:
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install -e ".[all]"
|
||||
|
||||
# Install submodules (required for terminal and RL tools)
|
||||
pip install -e "./mini-swe-agent" # Terminal tool backend
|
||||
pip install -e "./tinker-atropos" # RL training backend
|
||||
|
||||
hermes setup
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Batch Processing
|
||||
|
||||
Process multiple prompts in parallel with automatic checkpointing:
|
||||
|
||||
```bash
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=20 \
|
||||
--run_name=my_run \
|
||||
--num_workers=4 \
|
||||
--distribution=default
|
||||
```
|
||||
|
||||
**Key Options:**
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--dataset_file` | JSONL file with prompts |
|
||||
| `--batch_size` | Prompts per batch |
|
||||
| `--run_name` | Name for output/checkpoints |
|
||||
| `--num_workers` | Parallel workers (default: 4) |
|
||||
| `--distribution` | Toolset distribution |
|
||||
| `--resume` | Resume from checkpoint |
|
||||
| `--ephemeral_system_prompt` | Guide behavior without saving to trajectories |
|
||||
| `--list_distributions` | Show available distributions |
|
||||
|
||||
**Output:** `data/<run_name>/trajectories.jsonl`
|
||||
|
||||
### Trajectory Compression
|
||||
|
||||
Compress trajectories to fit token budgets for training:
|
||||
|
||||
```bash
|
||||
# Compress a directory
|
||||
python trajectory_compressor.py --input=data/my_run
|
||||
|
||||
# Compress with sampling
|
||||
python trajectory_compressor.py --input=data/my_run --sample_percent=15
|
||||
|
||||
# Custom token target
|
||||
python trajectory_compressor.py --input=data/my_run --target_max_tokens=16000
|
||||
```
|
||||
|
||||
Features:
|
||||
- Protects first/last turns
|
||||
- Summarizes middle turns via LLM
|
||||
- Configurable via `configs/trajectory_compression.yaml`
|
||||
|
||||
---
|
||||
|
||||
## Python API
|
||||
|
||||
```python
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
model="anthropic/claude-sonnet-4",
|
||||
enabled_toolsets=["web", "terminal"]
|
||||
)
|
||||
|
||||
result = agent.run_conversation("Search for the latest Python news")
|
||||
print(result["final_response"])
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Environment Variables Reference
|
||||
|
||||
All variables go in `~/.hermes/.env`. Run `hermes config set VAR value` to set them.
|
||||
|
||||
**LLM Providers:**
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `OPENROUTER_API_KEY` | OpenRouter API key (recommended) |
|
||||
| `ANTHROPIC_API_KEY` | Direct Anthropic access |
|
||||
| `OPENAI_API_KEY` | Direct OpenAI access |
|
||||
|
||||
**Tool APIs:**
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `FIRECRAWL_API_KEY` | Web scraping (firecrawl.dev) |
|
||||
| `BROWSERBASE_API_KEY` | Browser automation |
|
||||
| `BROWSERBASE_PROJECT_ID` | Browserbase project |
|
||||
| `FAL_KEY` | Image generation (fal.ai) |
|
||||
|
||||
**Terminal Backend:**
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `TERMINAL_ENV` | Backend: `local`, `docker`, `ssh`, `singularity`, `modal` |
|
||||
| `TERMINAL_DOCKER_IMAGE` | Docker image (default: `python:3.11-slim`) |
|
||||
| `TERMINAL_SINGULARITY_IMAGE` | Singularity image or `.sif` path |
|
||||
| `TERMINAL_TIMEOUT` | Command timeout in seconds |
|
||||
| `TERMINAL_CWD` | Working directory |
|
||||
| `SUDO_PASSWORD` | Enable sudo (stored plaintext - be careful!) |
|
||||
|
||||
**SSH Backend:**
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `TERMINAL_SSH_HOST` | Remote server hostname |
|
||||
| `TERMINAL_SSH_USER` | SSH username |
|
||||
| `TERMINAL_SSH_PORT` | SSH port (default: 22) |
|
||||
| `TERMINAL_SSH_KEY` | Path to private key |
|
||||
|
||||
**Messaging:**
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `TELEGRAM_BOT_TOKEN` | Telegram bot token (@BotFather) |
|
||||
| `TELEGRAM_ALLOWED_USERS` | Comma-separated user IDs allowed to use bot |
|
||||
| `TELEGRAM_HOME_CHANNEL` | Default channel for cron delivery |
|
||||
| `DISCORD_BOT_TOKEN` | Discord bot token |
|
||||
| `DISCORD_ALLOWED_USERS` | Comma-separated user IDs allowed to use bot |
|
||||
| `DISCORD_HOME_CHANNEL` | Default channel for cron delivery |
|
||||
| `MESSAGING_CWD` | Working directory for terminal in messaging (default: ~) |
|
||||
|
||||
**Agent Behavior:**
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `HERMES_MAX_ITERATIONS` | Max tool-calling iterations per conversation (default: 60) |
|
||||
| `HERMES_TOOL_PROGRESS` | Send progress messages when using tools (`true`/`false`) |
|
||||
| `HERMES_TOOL_PROGRESS_MODE` | `new` (only when tool changes) or `all` (every call) |
|
||||
|
||||
**Context Compression:**
|
||||
| Variable | Description |
|
||||
|----------|-------------|
|
||||
| `CONTEXT_COMPRESSION_ENABLED` | Enable auto-compression (default: true) |
|
||||
| `CONTEXT_COMPRESSION_THRESHOLD` | Trigger at this % of limit (default: 0.85) |
|
||||
| `CONTEXT_COMPRESSION_MODEL` | Model for summaries |
|
||||
|
||||
---
|
||||
|
||||
## File Structure
|
||||
|
||||
| Path | Description |
|
||||
|------|-------------|
|
||||
| `~/.hermes/config.yaml` | Your settings |
|
||||
| `~/.hermes/.env` | API keys and secrets |
|
||||
| `~/.hermes/cron/` | Scheduled jobs data |
|
||||
| `~/.hermes/sessions/` | Gateway session data |
|
||||
| `~/.hermes-agent/` | Installation directory |
|
||||
| `~/.hermes-agent/logs/` | Session logs |
|
||||
| `hermes_cli/` | CLI implementation |
|
||||
| `tools/` | Tool implementations |
|
||||
| `skills/` | Knowledge documents |
|
||||
| `gateway/` | Messaging platform adapters |
|
||||
| `cron/` | Scheduler implementation |
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
```bash
|
||||
hermes doctor # Run diagnostics
|
||||
hermes status # Check configuration
|
||||
hermes config # View current settings
|
||||
```
|
||||
|
||||
Common issues:
|
||||
- **"API key not set"**: Run `hermes setup` or `hermes config set OPENROUTER_API_KEY your_key`
|
||||
- **"hermes: command not found"**: Reload your shell (`source ~/.bashrc`) or check PATH
|
||||
- **Gateway won't start**: Check `hermes gateway status` and logs
|
||||
- **Missing config after update**: Run `hermes config check` to see what's new, then `hermes config migrate` to add missing options
|
||||
|
||||
---
|
||||
|
||||
## Contributing
|
||||
|
||||
1. Fork the repository
|
||||
2. Create a feature branch
|
||||
3. Make your changes
|
||||
4. Submit a pull request
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
MIT License - see [LICENSE](LICENSE) for details.
|
||||
|
||||
589
TODO.md
Normal file
589
TODO.md
Normal file
@@ -0,0 +1,589 @@
|
||||
# Hermes Agent - Future Improvements
|
||||
|
||||
> Ideas for enhancing the agent's capabilities, generated from self-analysis of the codebase.
|
||||
|
||||
---
|
||||
|
||||
## 1. Subagent Architecture (Context Isolation) 🎯
|
||||
|
||||
**Problem:** Long-running tools (terminal commands, browser automation, complex file operations) consume massive context. A single `ls -la` can add hundreds of lines. Browser snapshots, debugging sessions, and iterative terminal work quickly bloat the main conversation, leaving less room for actual reasoning.
|
||||
|
||||
**Solution:** The main agent becomes an **orchestrator** that delegates context-heavy tasks to **subagents**.
|
||||
|
||||
**Architecture:**
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ ORCHESTRATOR (main agent) │
|
||||
│ - Receives user request │
|
||||
│ - Plans approach │
|
||||
│ - Delegates heavy tasks to subagents │
|
||||
│ - Receives summarized results │
|
||||
│ - Maintains clean, focused context │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
│ │ │
|
||||
▼ ▼ ▼
|
||||
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
|
||||
│ TERMINAL AGENT │ │ BROWSER AGENT │ │ CODE AGENT │
|
||||
│ - terminal tool │ │ - browser tools │ │ - file tools │
|
||||
│ - file tools │ │ - web_search │ │ - terminal │
|
||||
│ │ │ - web_extract │ │ │
|
||||
│ Isolated context│ │ Isolated context│ │ Isolated context│
|
||||
│ Returns summary │ │ Returns summary │ │ Returns summary │
|
||||
└─────────────────┘ └─────────────────┘ └─────────────────┘
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
1. User asks: "Set up a new Python project with FastAPI and tests"
|
||||
2. Orchestrator plans: "I need to create files, install deps, write code"
|
||||
3. Orchestrator calls: `terminal_task(goal="Create venv, install fastapi pytest", context="New project in ~/myapp")`
|
||||
4. **Subagent spawns** with fresh context, only terminal/file tools
|
||||
5. Subagent iterates (may take 10+ tool calls, lots of output)
|
||||
6. Subagent completes → returns summary: "Created venv, installed fastapi==0.109.0, pytest==8.0.0"
|
||||
7. Orchestrator receives **only the summary**, context stays clean
|
||||
8. Orchestrator continues with next subtask
|
||||
|
||||
**Key tools to implement:**
|
||||
- [ ] `terminal_task(goal, context, cwd?)` - Delegate terminal/shell work
|
||||
- [ ] `browser_task(goal, context, start_url?)` - Delegate web research/automation
|
||||
- [ ] `code_task(goal, context, files?)` - Delegate code writing/modification
|
||||
- [ ] Generic `delegate_task(goal, context, toolsets=[])` - Flexible delegation
|
||||
|
||||
**Implementation details:**
|
||||
- [ ] Subagent uses same `run_agent.py` but with:
|
||||
- Fresh/empty conversation history
|
||||
- Limited toolset (only what's needed)
|
||||
- Smaller max_iterations (focused task)
|
||||
- Task-specific system prompt
|
||||
- [ ] Subagent returns structured result:
|
||||
```python
|
||||
{
|
||||
"success": True,
|
||||
"summary": "Installed 3 packages, created 2 files",
|
||||
"details": "Optional longer explanation if needed",
|
||||
"artifacts": ["~/myapp/requirements.txt", "~/myapp/main.py"], # Files created
|
||||
"errors": [] # Any issues encountered
|
||||
}
|
||||
```
|
||||
- [ ] Orchestrator sees only the summary in its context
|
||||
- [ ] Full subagent transcript saved separately for debugging
|
||||
|
||||
**Benefits:**
|
||||
- 🧹 **Clean context** - Orchestrator stays focused, doesn't drown in tool output
|
||||
- 📊 **Better token efficiency** - 50 terminal outputs → 1 summary paragraph
|
||||
- 🎯 **Focused subagents** - Each agent has just the tools it needs
|
||||
- 🔄 **Parallel potential** - Independent subtasks could run concurrently
|
||||
- 🐛 **Easier debugging** - Each subtask has its own isolated transcript
|
||||
|
||||
**When to use subagents vs direct tools:**
|
||||
- **Subagent**: Multi-step tasks, iteration likely, lots of output expected
|
||||
- **Direct**: Quick one-off commands, simple file reads, user needs to see output
|
||||
|
||||
**Files to modify:** `run_agent.py` (add orchestration mode), new `tools/delegate_tools.py`, new `subagent_runner.py`
|
||||
|
||||
---
|
||||
|
||||
## 2. Planning & Task Management 📋
|
||||
|
||||
**Problem:** Agent handles tasks reactively without explicit planning. Complex multi-step tasks lack structure, progress tracking, and the ability to decompose work into manageable chunks.
|
||||
|
||||
**Ideas:**
|
||||
- [ ] **Task decomposition tool** - Break complex requests into subtasks:
|
||||
```
|
||||
User: "Set up a new Python project with FastAPI, tests, and Docker"
|
||||
|
||||
Agent creates plan:
|
||||
├── 1. Create project structure and requirements.txt
|
||||
├── 2. Implement FastAPI app skeleton
|
||||
├── 3. Add pytest configuration and initial tests
|
||||
├── 4. Create Dockerfile and docker-compose.yml
|
||||
└── 5. Verify everything works together
|
||||
```
|
||||
- Each subtask becomes a trackable unit
|
||||
- Agent can report progress: "Completed 3/5 tasks"
|
||||
|
||||
- [ ] **Progress checkpoints** - Periodic self-assessment:
|
||||
- After N tool calls or time elapsed, pause to evaluate
|
||||
- "What have I accomplished? What remains? Am I on track?"
|
||||
- Detect if stuck in loops or making no progress
|
||||
- Could trigger replanning if approach isn't working
|
||||
|
||||
- [ ] **Explicit plan storage** - Persist plan in conversation:
|
||||
- Store as structured data (not just in context)
|
||||
- Update status as tasks complete
|
||||
- User can ask "What's the plan?" or "What's left?"
|
||||
- Survives context compression (plans are protected)
|
||||
|
||||
- [ ] **Failure recovery with replanning** - When things go wrong:
|
||||
- Record what failed and why
|
||||
- Revise plan to work around the issue
|
||||
- "Step 3 failed because X, adjusting approach to Y"
|
||||
- Prevents repeating failed strategies
|
||||
|
||||
**Files to modify:** `run_agent.py` (add planning hooks), new `tools/planning_tool.py`
|
||||
|
||||
---
|
||||
|
||||
## 3. Dynamic Skills Expansion 📚
|
||||
|
||||
**Problem:** Skills system is elegant but static. Skills must be manually created and added.
|
||||
|
||||
**Ideas:**
|
||||
- [ ] **Skill acquisition from successful tasks** - After completing a complex task:
|
||||
- "This approach worked well. Save as a skill?"
|
||||
- Extract: goal, steps taken, tools used, key decisions
|
||||
- Generate SKILL.md automatically
|
||||
- Store in user's skills directory
|
||||
|
||||
- [ ] **Skill templates** - Common patterns that can be parameterized:
|
||||
```markdown
|
||||
# Debug {language} Error
|
||||
1. Reproduce the error
|
||||
2. Search for error message: `web_search("{error_message} {language}")`
|
||||
3. Check common causes: {common_causes}
|
||||
4. Apply fix and verify
|
||||
```
|
||||
|
||||
- [ ] **Skill chaining** - Combine skills for complex workflows:
|
||||
- Skills can reference other skills as dependencies
|
||||
- "To do X, first apply skill Y, then skill Z"
|
||||
- Directed graph of skill dependencies
|
||||
|
||||
**Files to modify:** `tools/skills_tool.py`, `skills/` directory structure, new `skill_generator.py`
|
||||
|
||||
---
|
||||
|
||||
## 4. Interactive Clarifying Questions Tool ❓
|
||||
|
||||
**Problem:** Agent sometimes makes assumptions or guesses when it should ask the user. Currently can only ask via text, which gets lost in long outputs.
|
||||
|
||||
**Ideas:**
|
||||
- [ ] **Multiple-choice prompt tool** - Let agent present structured choices to user:
|
||||
```
|
||||
ask_user_choice(
|
||||
question="Should the language switcher enable only German or all languages?",
|
||||
choices=[
|
||||
"Only enable German - works immediately",
|
||||
"Enable all, mark untranslated - show fallback notice",
|
||||
"Let me specify something else"
|
||||
]
|
||||
)
|
||||
```
|
||||
- Renders as interactive terminal UI with arrow key / Tab navigation
|
||||
- User selects option, result returned to agent
|
||||
- Up to 4 choices + optional free-text option
|
||||
|
||||
- [ ] **Implementation:**
|
||||
- Use `inquirer` or `questionary` Python library for rich terminal prompts
|
||||
- Tool returns selected option text (or user's custom input)
|
||||
- **CLI-only** - only works when running via `cli.py` (not API/programmatic use)
|
||||
- Graceful fallback: if not in interactive mode, return error asking agent to rephrase as text
|
||||
|
||||
- [ ] **Use cases:**
|
||||
- Clarify ambiguous requirements before starting work
|
||||
- Confirm destructive operations with clear options
|
||||
- Let user choose between implementation approaches
|
||||
- Checkpoint complex multi-step workflows
|
||||
|
||||
**Files to modify:** New `tools/ask_user_tool.py`, `cli.py` (detect interactive mode), `model_tools.py`
|
||||
|
||||
---
|
||||
|
||||
## 5. Collaborative Problem Solving 🤝
|
||||
|
||||
**Problem:** Interaction is command/response. Complex problems benefit from dialogue.
|
||||
|
||||
**Ideas:**
|
||||
- [ ] **Assumption surfacing** - Make implicit assumptions explicit:
|
||||
- "I'm assuming you want Python 3.11+. Correct?"
|
||||
- "This solution assumes you have sudo access..."
|
||||
- Let user correct before going down wrong path
|
||||
|
||||
- [ ] **Checkpoint & confirm** - For high-stakes operations:
|
||||
- "About to delete 47 files. Here's the list - proceed?"
|
||||
- "This will modify your database. Want a backup first?"
|
||||
- Configurable threshold for when to ask
|
||||
|
||||
**Files to modify:** `run_agent.py`, system prompt configuration
|
||||
|
||||
---
|
||||
|
||||
## 6. Project-Local Context 💾
|
||||
|
||||
**Problem:** Valuable context lost between sessions.
|
||||
|
||||
**Ideas:**
|
||||
- [ ] **Project awareness** - Remember project-specific context:
|
||||
- Store `.hermes/context.md` in project directory
|
||||
- "This is a Django project using PostgreSQL"
|
||||
- Coding style preferences, deployment setup, etc.
|
||||
- Load automatically when working in that directory
|
||||
|
||||
- [ ] **Handoff notes** - Leave notes for future sessions:
|
||||
- Write to `.hermes/notes.md` in project
|
||||
- "TODO for next session: finish implementing X"
|
||||
- "Known issues: Y doesn't work on Windows"
|
||||
|
||||
**Files to modify:** New `project_context.py`, auto-load in `run_agent.py`
|
||||
|
||||
## 6. Tools & Skills Wishlist 🧰
|
||||
|
||||
*Things that would need new tool implementations (can't do well with current tools):*
|
||||
|
||||
### High-Impact
|
||||
|
||||
- [ ] **Audio/Video Transcription** 🎬 *(See also: Section 16 for detailed spec)*
|
||||
- Transcribe audio files, podcasts, YouTube videos
|
||||
- Extract key moments from video
|
||||
- Voice memo transcription for messaging integrations
|
||||
- *Provider options: Whisper API, Deepgram, local Whisper*
|
||||
|
||||
- [ ] **Diagram Rendering** 📊
|
||||
- Render Mermaid/PlantUML to actual images
|
||||
- Can generate the code, but rendering requires external service or tool
|
||||
- "Show me how these components connect" → actual visual diagram
|
||||
|
||||
### Medium-Impact
|
||||
|
||||
- [ ] **Canvas / Visual Workspace** 🖼️
|
||||
- Agent-controlled visual panel for rendering interactive UI
|
||||
- Inspired by OpenClaw's Canvas feature
|
||||
- **Capabilities:**
|
||||
- `present` / `hide` - Show/hide the canvas panel
|
||||
- `navigate` - Load HTML files or URLs into the canvas
|
||||
- `eval` - Execute JavaScript in the canvas context
|
||||
- `snapshot` - Capture the rendered UI as an image
|
||||
- **Use cases:**
|
||||
- Display generated HTML/CSS/JS previews
|
||||
- Show interactive data visualizations (charts, graphs)
|
||||
- Render diagrams (Mermaid → rendered output)
|
||||
- Present structured information in rich format
|
||||
- A2UI-style component system for structured agent UI
|
||||
- **Implementation options:**
|
||||
- Electron-based panel for CLI
|
||||
- WebSocket-connected web app
|
||||
- VS Code webview extension
|
||||
- *Would let agent "show" things rather than just describe them*
|
||||
|
||||
- [ ] **Document Generation** 📄
|
||||
- Create styled PDFs, Word docs, presentations
|
||||
- *Can do basic PDF via terminal tools, but limited*
|
||||
|
||||
- [ ] **Diff/Patch Tool** 📝
|
||||
- Surgical code modifications with preview
|
||||
- "Change line 45-50 to X" without rewriting whole file
|
||||
- Show diffs before applying
|
||||
- *Can use `diff`/`patch` but a native tool would be safer*
|
||||
|
||||
### Skills to Create
|
||||
|
||||
- [ ] **Domain-specific skill packs:**
|
||||
- DevOps/Infrastructure (Terraform, K8s, AWS)
|
||||
- Data Science workflows (EDA, model training)
|
||||
- Security/pentesting procedures
|
||||
|
||||
- [ ] **Framework-specific skills:**
|
||||
- React/Vue/Angular patterns
|
||||
- Django/Rails/Express conventions
|
||||
- Database optimization playbooks
|
||||
|
||||
- [ ] **Troubleshooting flowcharts:**
|
||||
- "Docker container won't start" → decision tree
|
||||
- "Production is slow" → systematic diagnosis
|
||||
|
||||
---
|
||||
|
||||
## 7. Messaging Platform Integrations 💬 ✅ COMPLETE
|
||||
|
||||
**Problem:** Agent currently only works via `cli.py` which requires direct terminal access. Users may want to interact via messaging apps from their phone or other devices.
|
||||
|
||||
**Architecture:**
|
||||
- `run_agent.py` already accepts `conversation_history` parameter and returns updated messages ✅
|
||||
- Need: persistent session storage, platform monitors, session key resolution
|
||||
|
||||
**Implementation approach:**
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Platform Monitor (e.g., telegram_monitor.py) │
|
||||
│ ├─ Long-running daemon connecting to messaging platform │
|
||||
│ ├─ On message: resolve session key → load history from disk│
|
||||
│ ├─ Call run_agent.py with loaded history │
|
||||
│ ├─ Save updated history back to disk (JSONL) │
|
||||
│ └─ Send response back to platform │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
**Platform support (each user sets up their own credentials):**
|
||||
- [x] **Telegram** - via `python-telegram-bot`
|
||||
- Bot token from @BotFather
|
||||
- Easiest to set up, good for personal use
|
||||
- [x] **Discord** - via `discord.py`
|
||||
- Bot token from Discord Developer Portal
|
||||
- Can work in servers (group sessions) or DMs
|
||||
- [x] **WhatsApp** - via Node.js bridge (whatsapp-web.js/baileys)
|
||||
- Requires Node.js bridge setup
|
||||
- More complex, but reaches most people
|
||||
|
||||
**Session management:**
|
||||
- [x] **Session store** - JSONL persistence per session key
|
||||
- `~/.hermes/sessions/{session_id}.jsonl`
|
||||
- Session keys: `agent:main:telegram:dm`, `agent:main:discord:group:123`, etc.
|
||||
- [x] **Session expiry** - Configurable reset policies
|
||||
- Daily reset (default 4am) OR idle timeout (default 2 hours)
|
||||
- Manual reset via `/reset` or `/new` command in chat
|
||||
- Per-platform and per-type overrides
|
||||
- [x] **Session continuity** - Conversations persist across messages until reset
|
||||
|
||||
**Files created:** `gateway/`, `gateway/platforms/`, `gateway/config.py`, `gateway/session.py`, `gateway/delivery.py`, `gateway/run.py`
|
||||
|
||||
**Configuration:**
|
||||
- Environment variables: `TELEGRAM_BOT_TOKEN`, `DISCORD_BOT_TOKEN`, etc.
|
||||
- Config file: `~/.hermes/gateway.json`
|
||||
- CLI commands: `/platforms` to check status, `--gateway` to start
|
||||
|
||||
**Dynamic context injection:**
|
||||
- Agent knows its source platform and chat
|
||||
- Agent knows connected platforms and home channels
|
||||
- Agent can deliver cron outputs to specific platforms
|
||||
|
||||
---
|
||||
|
||||
## 8. Text-to-Speech (TTS) 🔊
|
||||
|
||||
**Problem:** Agent can only respond with text. Some users prefer audio responses (accessibility, hands-free use, podcasts).
|
||||
|
||||
**Ideas:**
|
||||
- [ ] **TTS tool** - Generate audio files from text
|
||||
```python
|
||||
tts_generate(text="Here's your summary...", voice="nova", output="summary.mp3")
|
||||
```
|
||||
- Returns path to generated audio file
|
||||
- For messaging integrations: can send as voice message
|
||||
|
||||
- [ ] **Provider options:**
|
||||
- Edge TTS (free, good quality, many voices)
|
||||
- OpenAI TTS (paid, excellent quality)
|
||||
- ElevenLabs (paid, best quality, voice cloning)
|
||||
- Local options (Coqui TTS, Bark)
|
||||
|
||||
- [ ] **Modes:**
|
||||
- On-demand: User explicitly asks "read this to me"
|
||||
- Auto-TTS: Configurable to always generate audio for responses
|
||||
- Long-text handling: Summarize or chunk very long responses
|
||||
|
||||
- [ ] **Integration with messaging:**
|
||||
- When enabled, can send voice notes instead of/alongside text
|
||||
- User preference per channel
|
||||
|
||||
**Files to create:** `tools/tts_tool.py`, config in `cli-config.yaml`
|
||||
|
||||
---
|
||||
|
||||
## 13. Speech-to-Text / Audio Transcription 🎤
|
||||
|
||||
**Problem:** Users may want to send voice memos instead of typing. Agent is blind to audio content.
|
||||
|
||||
**Ideas:**
|
||||
- [ ] **Voice memo transcription** - For messaging integrations
|
||||
- User sends voice message → transcribe → process as text
|
||||
- Seamless: user speaks, agent responds
|
||||
|
||||
- [ ] **Audio/video file transcription** - Existing idea, expanded:
|
||||
- Transcribe local audio files (mp3, wav, m4a)
|
||||
- Transcribe YouTube videos (download audio → transcribe)
|
||||
- Extract key moments with timestamps
|
||||
|
||||
- [ ] **Provider options:**
|
||||
- OpenAI Whisper API (good quality, cheap)
|
||||
- Deepgram (fast, good for real-time)
|
||||
- Local Whisper (free, runs on GPU)
|
||||
- Groq Whisper (fast, free tier available)
|
||||
|
||||
- [ ] **Tool interface:**
|
||||
```python
|
||||
transcribe(source="audio.mp3") # Local file
|
||||
transcribe(source="https://youtube.com/...") # YouTube
|
||||
transcribe(source="voice_message", data=bytes) # Voice memo
|
||||
```
|
||||
|
||||
**Files to create:** `tools/transcribe_tool.py`, integrate with messaging monitors
|
||||
|
||||
### Plugin/Extension System 🔌
|
||||
|
||||
**Concept:** Allow users to add custom tools/skills without modifying core code.
|
||||
|
||||
**Why interesting:**
|
||||
- Community contributions
|
||||
- Organization-specific tools
|
||||
- Clean separation of core vs. extensions
|
||||
|
||||
**Open questions:**
|
||||
- Security implications of loading arbitrary code
|
||||
- Versioning and compatibility
|
||||
- Discovery and installation UX
|
||||
|
||||
---
|
||||
|
||||
## Recently Completed ✅
|
||||
|
||||
### Dangerous Command Approval System
|
||||
**Implemented:** Dangerous command detection and approval for terminal tool.
|
||||
|
||||
**Features:**
|
||||
- Pattern-based detection of dangerous commands (rm -rf, DROP TABLE, chmod 777, etc.)
|
||||
- CLI prompt with options: `[o]nce | [s]ession | [a]lways | [d]eny`
|
||||
- Session caching (approved patterns don't re-prompt)
|
||||
- Permanent allowlist in `~/.hermes/config.yaml`
|
||||
- Force flag for agent to bypass after user confirmation
|
||||
- Skip check for isolated backends (Docker, Singularity, Modal)
|
||||
- Helpful sudo failure messages for messaging platforms
|
||||
|
||||
**Files:** `tools/terminal_tool.py`, `model_tools.py`, `hermes_cli/config.py`
|
||||
|
||||
---
|
||||
|
||||
## 14. Learning Machine / Dynamic Memory System 🧠
|
||||
|
||||
*Inspired by [Dash](~/agent-codebases/dash) - a self-learning data agent.*
|
||||
|
||||
**Problem:** Agent starts fresh every session. Valuable learnings from debugging, error patterns, successful approaches, and user preferences are lost.
|
||||
|
||||
**Dash's Key Insight:** Separate **Knowledge** (static, curated) from **Learnings** (dynamic, discovered):
|
||||
|
||||
| System | What It Stores | How It Evolves |
|
||||
|--------|---------------|----------------|
|
||||
| **Knowledge** (Skills) | Validated approaches, templates, best practices | Curated by user |
|
||||
| **Learnings** | Error patterns, gotchas, discovered fixes | Managed automatically |
|
||||
|
||||
**Tools to implement:**
|
||||
- [ ] `save_learning(topic, learning, context?)` - Record a discovered pattern
|
||||
```python
|
||||
save_learning(
|
||||
topic="python-ssl",
|
||||
learning="On Ubuntu 22.04, SSL certificate errors often fixed by: apt install ca-certificates",
|
||||
context="Debugging requests SSL failure"
|
||||
)
|
||||
```
|
||||
- [ ] `search_learnings(query)` - Find relevant past learnings
|
||||
```python
|
||||
search_learnings("SSL certificate error Python")
|
||||
# Returns: "On Ubuntu 22.04, SSL certificate errors often fixed by..."
|
||||
```
|
||||
|
||||
**User Profile & Memory:**
|
||||
- [ ] `user_profile` - Structured facts about user preferences
|
||||
```yaml
|
||||
# ~/.hermes/user_profile.yaml
|
||||
coding_style:
|
||||
python_formatter: black
|
||||
type_hints: always
|
||||
test_framework: pytest
|
||||
preferences:
|
||||
verbosity: detailed
|
||||
confirm_destructive: true
|
||||
environment:
|
||||
os: linux
|
||||
shell: bash
|
||||
default_python: 3.11
|
||||
```
|
||||
- [ ] `user_memory` - Unstructured observations the agent learns
|
||||
```yaml
|
||||
# ~/.hermes/user_memory.yaml
|
||||
- "User prefers tabs over spaces despite black's defaults"
|
||||
- "User's main project is ~/work/myapp - a Django app"
|
||||
- "User often works late - don't ask about timezone"
|
||||
```
|
||||
|
||||
**When to learn:**
|
||||
- After fixing an error that took multiple attempts
|
||||
- When user corrects the agent's approach
|
||||
- When a workaround is discovered for a tool limitation
|
||||
- When user expresses a preference
|
||||
|
||||
**Storage:** Vector database (ChromaDB) or simple YAML with embedding search.
|
||||
|
||||
**Files to create:** `tools/learning_tools.py`, `learning/store.py`, `~/.hermes/learnings/`
|
||||
|
||||
---
|
||||
|
||||
## 15. Layered Context Architecture 📊
|
||||
|
||||
*Inspired by Dash's "Six Layers of Context" - grounding responses in multiple sources.*
|
||||
|
||||
**Problem:** Context sources are ad-hoc. No clear hierarchy or strategy for what context to include when.
|
||||
|
||||
**Proposed Layers for Hermes:**
|
||||
|
||||
| Layer | Source | When Loaded | Example |
|
||||
|-------|--------|-------------|---------|
|
||||
| 1. **Project Context** | `.hermes/context.md` | Auto on cwd | "This is a FastAPI project using PostgreSQL" |
|
||||
| 2. **Skills** | `skills/*.md` | On request | "How to set up React project" |
|
||||
| 3. **User Profile** | `~/.hermes/user_profile.yaml` | Always | "User prefers pytest, uses black" |
|
||||
| 4. **Learnings** | `~/.hermes/learnings/` | Semantic search | "SSL fix for Ubuntu" |
|
||||
| 5. **External Knowledge** | Web search, docs | On demand | Current API docs, Stack Overflow |
|
||||
| 6. **Runtime Introspection** | Tool calls | Real-time | File contents, terminal output |
|
||||
|
||||
**Benefits:**
|
||||
- Clear mental model for what context is available
|
||||
- Prioritization: local > learned > external
|
||||
- Debugging: "Why did agent do X?" → check which layers contributed
|
||||
|
||||
**Files to modify:** `run_agent.py` (context loading), new `context/layers.py`
|
||||
|
||||
---
|
||||
|
||||
## 16. Evaluation System with LLM Grading 📏
|
||||
|
||||
*Inspired by Dash's evaluation framework.*
|
||||
|
||||
**Problem:** `batch_runner.py` runs test cases but lacks quality assessment.
|
||||
|
||||
**Dash's Approach:**
|
||||
- **String matching** (default) - Check if expected strings appear
|
||||
- **LLM grader** (-g flag) - GPT evaluates response quality
|
||||
- **Result comparison** (-r flag) - Compare against golden output
|
||||
|
||||
**Implementation for Hermes:**
|
||||
|
||||
- [ ] **Test case format:**
|
||||
```python
|
||||
TestCase(
|
||||
name="create_python_project",
|
||||
prompt="Create a new Python project with FastAPI and tests",
|
||||
expected_strings=["requirements.txt", "main.py", "test_"], # Basic check
|
||||
golden_actions=["write:main.py", "write:requirements.txt", "terminal:pip install"],
|
||||
grader_criteria="Should create complete project structure with working code"
|
||||
)
|
||||
```
|
||||
|
||||
- [ ] **LLM grader mode:**
|
||||
```python
|
||||
def grade_response(response: str, criteria: str) -> Grade:
|
||||
"""Use GPT to evaluate response quality."""
|
||||
prompt = f"""
|
||||
Evaluate this agent response against the criteria.
|
||||
Criteria: {criteria}
|
||||
Response: {response}
|
||||
|
||||
Score (1-5) and explain why.
|
||||
"""
|
||||
# Returns: Grade(score=4, explanation="Created all files but tests are minimal")
|
||||
```
|
||||
|
||||
- [ ] **Action comparison mode:**
|
||||
- Record tool calls made during test
|
||||
- Compare against expected actions
|
||||
- "Expected terminal call to pip install, got npm install"
|
||||
|
||||
- [ ] **CLI flags:**
|
||||
```bash
|
||||
python batch_runner.py eval test_cases.yaml # String matching
|
||||
python batch_runner.py eval test_cases.yaml -g # + LLM grading
|
||||
python batch_runner.py eval test_cases.yaml -r # + Result comparison
|
||||
python batch_runner.py eval test_cases.yaml -v # Verbose (show responses)
|
||||
```
|
||||
|
||||
**Files to modify:** `batch_runner.py`, new `evals/test_cases.py`, new `evals/grader.py`
|
||||
|
||||
---
|
||||
|
||||
*Last updated: $(date +%Y-%m-%d)* 🤖
|
||||
369
batch_runner.py
369
batch_runner.py
@@ -30,6 +30,8 @@ from datetime import datetime
|
||||
from multiprocessing import Pool, Manager, Lock
|
||||
import traceback
|
||||
|
||||
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
|
||||
from rich.console import Console
|
||||
import fire
|
||||
|
||||
from run_agent import AIAgent
|
||||
@@ -44,6 +46,77 @@ from toolset_distributions import (
|
||||
# Global configuration for worker processes
|
||||
_WORKER_CONFIG = {}
|
||||
|
||||
# All possible tools - used to ensure consistent schema across all trajectory entries
|
||||
# This is required because Arrow/Parquet (used by HuggingFace datasets) needs identical schemas
|
||||
ALL_POSSIBLE_TOOLS = {
|
||||
'terminal', 'web_search', 'web_extract',
|
||||
'vision_analyze', 'image_generate', 'mixture_of_agents',
|
||||
# Skills tools
|
||||
'skills_categories', 'skills_list', 'skill_view',
|
||||
# Browser automation tools
|
||||
'browser_navigate', 'browser_snapshot', 'browser_click',
|
||||
'browser_type', 'browser_scroll', 'browser_back',
|
||||
'browser_press', 'browser_close', 'browser_get_images',
|
||||
'browser_vision'
|
||||
}
|
||||
|
||||
# Default stats for tools that weren't used
|
||||
DEFAULT_TOOL_STATS = {'count': 0, 'success': 0, 'failure': 0}
|
||||
|
||||
|
||||
def _normalize_tool_stats(tool_stats: Dict[str, Dict[str, int]]) -> Dict[str, Dict[str, int]]:
|
||||
"""
|
||||
Normalize tool_stats to include all possible tools with consistent schema.
|
||||
|
||||
This ensures HuggingFace datasets can load the JSONL without schema mismatch errors.
|
||||
Tools that weren't used get zero counts.
|
||||
|
||||
Args:
|
||||
tool_stats (Dict): Raw tool statistics from extraction
|
||||
|
||||
Returns:
|
||||
Dict: Normalized tool statistics with all tools present
|
||||
"""
|
||||
normalized = {}
|
||||
|
||||
# Add all possible tools with defaults
|
||||
for tool in ALL_POSSIBLE_TOOLS:
|
||||
if tool in tool_stats:
|
||||
normalized[tool] = tool_stats[tool].copy()
|
||||
else:
|
||||
normalized[tool] = DEFAULT_TOOL_STATS.copy()
|
||||
|
||||
# Also include any unexpected tools (in case new tools are added)
|
||||
for tool, stats in tool_stats.items():
|
||||
if tool not in normalized:
|
||||
normalized[tool] = stats.copy()
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def _normalize_tool_error_counts(tool_error_counts: Dict[str, int]) -> Dict[str, int]:
|
||||
"""
|
||||
Normalize tool_error_counts to include all possible tools.
|
||||
|
||||
Args:
|
||||
tool_error_counts (Dict): Raw error counts mapping
|
||||
|
||||
Returns:
|
||||
Dict: Normalized error counts with all tools present
|
||||
"""
|
||||
normalized = {}
|
||||
|
||||
# Add all possible tools with zero defaults
|
||||
for tool in ALL_POSSIBLE_TOOLS:
|
||||
normalized[tool] = tool_error_counts.get(tool, 0)
|
||||
|
||||
# Also include any unexpected tools
|
||||
for tool, count in tool_error_counts.items():
|
||||
if tool not in normalized:
|
||||
normalized[tool] = count
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
|
||||
"""
|
||||
@@ -154,7 +227,8 @@ def _process_single_prompt(
|
||||
if config.get("verbose"):
|
||||
print(f" Prompt {prompt_index}: Using toolsets {selected_toolsets}")
|
||||
|
||||
# Initialize agent with sampled toolsets
|
||||
# Initialize agent with sampled toolsets and log prefix for identification
|
||||
log_prefix = f"[B{batch_num}:P{prompt_index}]"
|
||||
agent = AIAgent(
|
||||
base_url=config.get("base_url"),
|
||||
api_key=config.get("api_key"),
|
||||
@@ -164,7 +238,12 @@ def _process_single_prompt(
|
||||
save_trajectories=False, # We handle saving ourselves
|
||||
verbose_logging=config.get("verbose", False),
|
||||
ephemeral_system_prompt=config.get("ephemeral_system_prompt"),
|
||||
log_prefix_chars=config.get("log_prefix_chars", 100)
|
||||
log_prefix_chars=config.get("log_prefix_chars", 100),
|
||||
log_prefix=log_prefix,
|
||||
providers_allowed=config.get("providers_allowed"),
|
||||
providers_ignored=config.get("providers_ignored"),
|
||||
providers_order=config.get("providers_order"),
|
||||
provider_sort=config.get("provider_sort"),
|
||||
)
|
||||
|
||||
# Run the agent with task_id to ensure each task gets its own isolated VM
|
||||
@@ -186,6 +265,7 @@ def _process_single_prompt(
|
||||
"trajectory": trajectory,
|
||||
"tool_stats": tool_stats,
|
||||
"completed": result["completed"],
|
||||
"partial": result.get("partial", False),
|
||||
"api_calls": result["api_calls"],
|
||||
"toolsets_used": selected_toolsets,
|
||||
"metadata": {
|
||||
@@ -266,13 +346,27 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
|
||||
|
||||
# Save trajectory if successful
|
||||
if result["success"] and result["trajectory"]:
|
||||
# Get and normalize tool stats for consistent schema across all entries
|
||||
raw_tool_stats = result.get("tool_stats", {})
|
||||
tool_stats = _normalize_tool_stats(raw_tool_stats)
|
||||
|
||||
# Create normalized tool_error_counts mapping tool names to their failure counts
|
||||
raw_error_counts = {
|
||||
tool_name: stats.get("failure", 0)
|
||||
for tool_name, stats in raw_tool_stats.items()
|
||||
}
|
||||
tool_error_counts = _normalize_tool_error_counts(raw_error_counts)
|
||||
|
||||
trajectory_entry = {
|
||||
"prompt_index": prompt_index,
|
||||
"conversations": result["trajectory"],
|
||||
"metadata": result["metadata"],
|
||||
"completed": result["completed"],
|
||||
"partial": result.get("partial", False), # True if stopped due to invalid tool calls
|
||||
"api_calls": result["api_calls"],
|
||||
"toolsets_used": result["toolsets_used"]
|
||||
"toolsets_used": result["toolsets_used"],
|
||||
"tool_stats": tool_stats, # Full stats: {tool: {count, success, failure}} - normalized
|
||||
"tool_error_counts": tool_error_counts # Simple: {tool: failure_count} - normalized
|
||||
}
|
||||
|
||||
# Append to batch output file
|
||||
@@ -292,8 +386,13 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
|
||||
batch_tool_stats[tool_name]["success"] += stats["success"]
|
||||
batch_tool_stats[tool_name]["failure"] += stats["failure"]
|
||||
|
||||
completed_in_batch.append(prompt_index)
|
||||
print(f" ✅ Prompt {prompt_index} completed")
|
||||
# Only mark as completed if successfully saved (failed prompts can be retried on resume)
|
||||
if result["success"] and result["trajectory"]:
|
||||
completed_in_batch.append(prompt_index)
|
||||
status = "⚠️ partial" if result.get("partial") else "✅"
|
||||
print(f" {status} Prompt {prompt_index} completed")
|
||||
else:
|
||||
print(f" ❌ Prompt {prompt_index} failed (will retry on resume)")
|
||||
|
||||
print(f"✅ Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)")
|
||||
|
||||
@@ -325,6 +424,10 @@ class BatchRunner:
|
||||
verbose: bool = False,
|
||||
ephemeral_system_prompt: str = None,
|
||||
log_prefix_chars: int = 100,
|
||||
providers_allowed: List[str] = None,
|
||||
providers_ignored: List[str] = None,
|
||||
providers_order: List[str] = None,
|
||||
provider_sort: str = None,
|
||||
):
|
||||
"""
|
||||
Initialize the batch runner.
|
||||
@@ -342,6 +445,10 @@ class BatchRunner:
|
||||
verbose (bool): Enable verbose logging
|
||||
ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
|
||||
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20)
|
||||
providers_allowed (List[str]): OpenRouter providers to allow (optional)
|
||||
providers_ignored (List[str]): OpenRouter providers to ignore (optional)
|
||||
providers_order (List[str]): OpenRouter providers to try in order (optional)
|
||||
provider_sort (str): Sort providers by price/throughput/latency (optional)
|
||||
"""
|
||||
self.dataset_file = Path(dataset_file)
|
||||
self.batch_size = batch_size
|
||||
@@ -355,6 +462,10 @@ class BatchRunner:
|
||||
self.verbose = verbose
|
||||
self.ephemeral_system_prompt = ephemeral_system_prompt
|
||||
self.log_prefix_chars = log_prefix_chars
|
||||
self.providers_allowed = providers_allowed
|
||||
self.providers_ignored = providers_ignored
|
||||
self.providers_order = providers_order
|
||||
self.provider_sort = provider_sort
|
||||
|
||||
# Validate distribution
|
||||
if not validate_distribution(distribution):
|
||||
@@ -479,6 +590,83 @@ class BatchRunner:
|
||||
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
def _scan_completed_prompts_by_content(self) -> set:
|
||||
"""
|
||||
Scan all batch files and extract completed prompts by their actual content.
|
||||
|
||||
This provides a more robust resume mechanism that matches on prompt text
|
||||
rather than indices, allowing recovery even if indices don't match.
|
||||
|
||||
Returns:
|
||||
set: Set of prompt texts that have been successfully processed
|
||||
"""
|
||||
completed_prompts = set()
|
||||
batch_files = sorted(self.output_dir.glob("batch_*.jsonl"))
|
||||
|
||||
if not batch_files:
|
||||
return completed_prompts
|
||||
|
||||
print(f"📂 Scanning {len(batch_files)} batch files for completed prompts...")
|
||||
|
||||
for batch_file in batch_files:
|
||||
try:
|
||||
with open(batch_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
try:
|
||||
entry = json.loads(line.strip())
|
||||
|
||||
# Skip failed entries - we want to retry these
|
||||
if entry.get("failed", False):
|
||||
continue
|
||||
|
||||
# Extract the human/user prompt from conversations
|
||||
conversations = entry.get("conversations", [])
|
||||
for msg in conversations:
|
||||
if msg.get("from") == "human":
|
||||
prompt_text = msg.get("value", "").strip()
|
||||
if prompt_text:
|
||||
completed_prompts.add(prompt_text)
|
||||
break # Only need the first human message
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Warning: Error reading {batch_file.name}: {e}")
|
||||
|
||||
return completed_prompts
|
||||
|
||||
def _filter_dataset_by_completed(self, completed_prompts: set) -> Tuple[List[Dict], List[int]]:
|
||||
"""
|
||||
Filter the dataset to exclude prompts that have already been completed.
|
||||
|
||||
Args:
|
||||
completed_prompts: Set of prompt texts that have been completed
|
||||
|
||||
Returns:
|
||||
Tuple of (filtered_dataset, skipped_indices)
|
||||
"""
|
||||
filtered_dataset = []
|
||||
skipped_indices = []
|
||||
|
||||
for idx, entry in enumerate(self.dataset):
|
||||
# Extract prompt from the dataset entry
|
||||
prompt_text = entry.get("prompt", "").strip()
|
||||
|
||||
# Also check conversations format
|
||||
if not prompt_text:
|
||||
conversations = entry.get("conversations", [])
|
||||
for msg in conversations:
|
||||
role = msg.get("role") or msg.get("from")
|
||||
if role in ("user", "human"):
|
||||
prompt_text = (msg.get("content") or msg.get("value", "")).strip()
|
||||
break
|
||||
|
||||
if prompt_text in completed_prompts:
|
||||
skipped_indices.append(idx)
|
||||
else:
|
||||
# Keep original index for tracking
|
||||
filtered_dataset.append((idx, entry))
|
||||
|
||||
return filtered_dataset, skipped_indices
|
||||
|
||||
def run(self, resume: bool = False):
|
||||
"""
|
||||
@@ -491,17 +679,48 @@ class BatchRunner:
|
||||
print("🚀 Starting Batch Processing")
|
||||
print("=" * 70)
|
||||
|
||||
# Load checkpoint
|
||||
checkpoint_data = self._load_checkpoint() if resume else {
|
||||
# Smart resume: scan batch files by content to find completed prompts
|
||||
completed_prompt_texts = set()
|
||||
if resume:
|
||||
completed_prompt_texts = self._scan_completed_prompts_by_content()
|
||||
if completed_prompt_texts:
|
||||
print(f" Found {len(completed_prompt_texts)} already-completed prompts by content matching")
|
||||
|
||||
# Filter dataset to only include unprocessed prompts
|
||||
if resume and completed_prompt_texts:
|
||||
filtered_entries, skipped_indices = self._filter_dataset_by_completed(completed_prompt_texts)
|
||||
|
||||
if not filtered_entries:
|
||||
print("\n✅ All prompts have already been processed!")
|
||||
return
|
||||
|
||||
# Recreate batches from filtered entries (keeping original indices for tracking)
|
||||
batches_to_process = []
|
||||
for i in range(0, len(filtered_entries), self.batch_size):
|
||||
batch = filtered_entries[i:i + self.batch_size]
|
||||
batches_to_process.append(batch)
|
||||
|
||||
self.batches = batches_to_process
|
||||
|
||||
# Print prominent resume summary
|
||||
print("\n" + "=" * 70)
|
||||
print("📊 RESUME SUMMARY")
|
||||
print("=" * 70)
|
||||
print(f" Original dataset size: {len(self.dataset):,} prompts")
|
||||
print(f" Already completed: {len(skipped_indices):,} prompts")
|
||||
print(f" ─────────────────────────────────────────")
|
||||
print(f" 🎯 RESUMING WITH: {len(filtered_entries):,} prompts")
|
||||
print(f" New batches created: {len(batches_to_process)}")
|
||||
print("=" * 70 + "\n")
|
||||
|
||||
# Initialize checkpoint data (needed for saving at the end)
|
||||
checkpoint_data = {
|
||||
"run_name": self.run_name,
|
||||
"completed_prompts": [],
|
||||
"batch_stats": {},
|
||||
"last_updated": None
|
||||
}
|
||||
|
||||
if resume and checkpoint_data.get("completed_prompts"):
|
||||
print(f"📂 Resuming from checkpoint ({len(checkpoint_data['completed_prompts'])} prompts already completed)")
|
||||
|
||||
# Prepare configuration for workers
|
||||
config = {
|
||||
"distribution": self.distribution,
|
||||
@@ -511,17 +730,23 @@ class BatchRunner:
|
||||
"api_key": self.api_key,
|
||||
"verbose": self.verbose,
|
||||
"ephemeral_system_prompt": self.ephemeral_system_prompt,
|
||||
"log_prefix_chars": self.log_prefix_chars
|
||||
"log_prefix_chars": self.log_prefix_chars,
|
||||
"providers_allowed": self.providers_allowed,
|
||||
"providers_ignored": self.providers_ignored,
|
||||
"providers_order": self.providers_order,
|
||||
"provider_sort": self.provider_sort,
|
||||
}
|
||||
|
||||
# Get completed prompts set
|
||||
completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
|
||||
# For backward compatibility, still track by index (but this is secondary to content matching)
|
||||
completed_prompts_set = set()
|
||||
|
||||
# Aggregate statistics across all batches
|
||||
total_tool_stats = {}
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
print(f"\n🔧 Initializing {self.num_workers} worker processes...")
|
||||
|
||||
# Process batches in parallel
|
||||
with Pool(processes=self.num_workers) as pool:
|
||||
# Create tasks for each batch
|
||||
@@ -536,8 +761,39 @@ class BatchRunner:
|
||||
for batch_num, batch_data in enumerate(self.batches)
|
||||
]
|
||||
|
||||
# Use map to process batches in parallel
|
||||
results = pool.map(_process_batch_worker, tasks)
|
||||
print(f"✅ Created {len(tasks)} batch tasks")
|
||||
print(f"🚀 Starting parallel batch processing...\n")
|
||||
|
||||
# Use rich Progress for better visual tracking with persistent bottom bar
|
||||
# redirect_stdout/stderr lets rich manage all output so progress bar stays clean
|
||||
results = []
|
||||
console = Console(force_terminal=True)
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[bold blue]📦 Batches"),
|
||||
BarColumn(bar_width=40),
|
||||
MofNCompleteColumn(),
|
||||
TextColumn("•"),
|
||||
TimeRemainingColumn(),
|
||||
console=console,
|
||||
refresh_per_second=2,
|
||||
transient=False,
|
||||
redirect_stdout=False,
|
||||
redirect_stderr=False,
|
||||
) as progress:
|
||||
task = progress.add_task("Processing", total=len(tasks))
|
||||
|
||||
# Temporarily suppress DEBUG logging to avoid bar interference
|
||||
root_logger = logging.getLogger()
|
||||
original_level = root_logger.level
|
||||
root_logger.setLevel(logging.WARNING)
|
||||
|
||||
try:
|
||||
for result in pool.imap_unordered(_process_batch_worker, tasks):
|
||||
results.append(result)
|
||||
progress.update(task, advance=1)
|
||||
finally:
|
||||
root_logger.setLevel(original_level)
|
||||
|
||||
# Aggregate all batch statistics and update checkpoint
|
||||
all_completed_prompts = list(completed_prompts_set)
|
||||
@@ -573,19 +829,58 @@ class BatchRunner:
|
||||
stats["success_rate"] = 0.0
|
||||
stats["failure_rate"] = 0.0
|
||||
|
||||
# Combine all batch files into a single trajectories.jsonl file
|
||||
# Combine ALL batch files in directory into a single trajectories.jsonl file
|
||||
# This includes both old batches (from previous runs) and new batches (from resume)
|
||||
# Also filter out corrupted entries (where model generated invalid tool names)
|
||||
combined_file = self.output_dir / "trajectories.jsonl"
|
||||
print(f"\n📦 Combining batch files into {combined_file.name}...")
|
||||
print(f"\n📦 Combining ALL batch files into {combined_file.name}...")
|
||||
|
||||
VALID_TOOLS = {'web_search', 'web_extract', 'terminal', 'vision_analyze',
|
||||
'image_generate', 'mixture_of_agents',
|
||||
# Skills tools
|
||||
'skills_categories', 'skills_list', 'skill_view',
|
||||
# Browser automation tools
|
||||
'browser_navigate', 'browser_snapshot', 'browser_click',
|
||||
'browser_type', 'browser_scroll', 'browser_back',
|
||||
'browser_press', 'browser_close', 'browser_get_images',
|
||||
'browser_vision'}
|
||||
|
||||
total_entries = 0
|
||||
filtered_entries = 0
|
||||
batch_files_found = 0
|
||||
|
||||
# Find ALL batch files in the output directory (handles resume merging old + new)
|
||||
all_batch_files = sorted(self.output_dir.glob("batch_*.jsonl"))
|
||||
|
||||
with open(combined_file, 'w', encoding='utf-8') as outfile:
|
||||
for batch_num in range(len(self.batches)):
|
||||
batch_file = self.output_dir / f"batch_{batch_num}.jsonl"
|
||||
if batch_file.exists():
|
||||
with open(batch_file, 'r', encoding='utf-8') as infile:
|
||||
for line in infile:
|
||||
for batch_file in all_batch_files:
|
||||
batch_files_found += 1
|
||||
batch_num = batch_file.stem.split("_")[1] # Extract batch number for logging
|
||||
|
||||
with open(batch_file, 'r', encoding='utf-8') as infile:
|
||||
for line in infile:
|
||||
total_entries += 1
|
||||
try:
|
||||
data = json.loads(line)
|
||||
tool_stats = data.get('tool_stats', {})
|
||||
|
||||
# Check for invalid tool names (model hallucinations)
|
||||
invalid_tools = [k for k in tool_stats.keys() if k not in VALID_TOOLS]
|
||||
|
||||
if invalid_tools:
|
||||
filtered_entries += 1
|
||||
invalid_preview = invalid_tools[0][:50] + "..." if len(invalid_tools[0]) > 50 else invalid_tools[0]
|
||||
print(f" ⚠️ Filtering corrupted entry (batch {batch_num}): invalid tool '{invalid_preview}'")
|
||||
continue
|
||||
|
||||
outfile.write(line)
|
||||
except json.JSONDecodeError:
|
||||
filtered_entries += 1
|
||||
print(f" ⚠️ Filtering invalid JSON entry (batch {batch_num})")
|
||||
|
||||
print(f"✅ Combined {len(self.batches)} batch files into trajectories.jsonl")
|
||||
if filtered_entries > 0:
|
||||
print(f"⚠️ Filtered {filtered_entries} corrupted entries out of {total_entries} total")
|
||||
print(f"✅ Combined {batch_files_found} batch files into trajectories.jsonl ({total_entries - filtered_entries} entries)")
|
||||
|
||||
# Save final statistics
|
||||
final_stats = {
|
||||
@@ -607,8 +902,9 @@ class BatchRunner:
|
||||
print("\n" + "=" * 70)
|
||||
print("📊 BATCH PROCESSING COMPLETE")
|
||||
print("=" * 70)
|
||||
print(f"✅ Total prompts processed: {len(self.dataset)}")
|
||||
print(f"✅ Total batches: {len(self.batches)}")
|
||||
print(f"✅ Prompts processed this run: {sum(r.get('processed', 0) for r in results)}")
|
||||
print(f"✅ Total trajectories in merged file: {total_entries - filtered_entries}")
|
||||
print(f"✅ Total batch files merged: {batch_files_found}")
|
||||
print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s")
|
||||
print(f"\n📈 Tool Usage Statistics:")
|
||||
print("-" * 70)
|
||||
@@ -646,9 +942,9 @@ def main(
|
||||
batch_size: int = None,
|
||||
run_name: str = None,
|
||||
distribution: str = "default",
|
||||
model: str = "claude-opus-4-20250514",
|
||||
model: str = "anthropic/claude-sonnet-4-20250514",
|
||||
api_key: str = None,
|
||||
base_url: str = "https://api.anthropic.com/v1/",
|
||||
base_url: str = "https://openrouter.ai/api/v1",
|
||||
max_turns: int = 10,
|
||||
num_workers: int = 4,
|
||||
resume: bool = False,
|
||||
@@ -656,6 +952,10 @@ def main(
|
||||
list_distributions: bool = False,
|
||||
ephemeral_system_prompt: str = None,
|
||||
log_prefix_chars: int = 100,
|
||||
providers_allowed: str = None,
|
||||
providers_ignored: str = None,
|
||||
providers_order: str = None,
|
||||
provider_sort: str = None,
|
||||
):
|
||||
"""
|
||||
Run batch processing of agent prompts from a dataset.
|
||||
@@ -675,6 +975,10 @@ def main(
|
||||
list_distributions (bool): List available toolset distributions and exit
|
||||
ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
|
||||
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20)
|
||||
providers_allowed (str): Comma-separated list of OpenRouter providers to allow (e.g. "anthropic,openai")
|
||||
providers_ignored (str): Comma-separated list of OpenRouter providers to ignore (e.g. "together,deepinfra")
|
||||
providers_order (str): Comma-separated list of OpenRouter providers to try in order (e.g. "anthropic,openai,google")
|
||||
provider_sort (str): Sort providers by "price", "throughput", or "latency" (OpenRouter only)
|
||||
|
||||
Examples:
|
||||
# Basic usage
|
||||
@@ -722,6 +1026,11 @@ def main(
|
||||
print("❌ Error: --run_name is required")
|
||||
return
|
||||
|
||||
# Parse provider preferences (comma-separated strings to lists)
|
||||
providers_allowed_list = [p.strip() for p in providers_allowed.split(",")] if providers_allowed else None
|
||||
providers_ignored_list = [p.strip() for p in providers_ignored.split(",")] if providers_ignored else None
|
||||
providers_order_list = [p.strip() for p in providers_order.split(",")] if providers_order else None
|
||||
|
||||
# Initialize and run batch runner
|
||||
try:
|
||||
runner = BatchRunner(
|
||||
@@ -736,7 +1045,11 @@ def main(
|
||||
num_workers=num_workers,
|
||||
verbose=verbose,
|
||||
ephemeral_system_prompt=ephemeral_system_prompt,
|
||||
log_prefix_chars=log_prefix_chars
|
||||
log_prefix_chars=log_prefix_chars,
|
||||
providers_allowed=providers_allowed_list,
|
||||
providers_ignored=providers_ignored_list,
|
||||
providers_order=providers_order_list,
|
||||
provider_sort=provider_sort,
|
||||
)
|
||||
|
||||
runner.run(resume=resume)
|
||||
|
||||
267
cli-config.yaml.example
Normal file
267
cli-config.yaml.example
Normal file
@@ -0,0 +1,267 @@
|
||||
# Hermes Agent CLI Configuration
|
||||
# Copy this file to cli-config.yaml and customize as needed.
|
||||
# This file configures the CLI behavior. Environment variables in .env take precedence.
|
||||
|
||||
# =============================================================================
|
||||
# Model Configuration
|
||||
# =============================================================================
|
||||
model:
|
||||
# Default model to use (can be overridden with --model flag)
|
||||
default: "anthropic/claude-sonnet-4"
|
||||
|
||||
# API configuration (falls back to OPENROUTER_API_KEY env var)
|
||||
# api_key: "your-key-here" # Uncomment to set here instead of .env
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
|
||||
# =============================================================================
|
||||
# Terminal Tool Configuration
|
||||
# =============================================================================
|
||||
# Choose ONE of the following terminal configurations by uncommenting it.
|
||||
# The terminal tool executes commands in the specified environment.
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 1: Local execution (default)
|
||||
# Commands run directly on your machine in the current directory
|
||||
# -----------------------------------------------------------------------------
|
||||
# Working directory behavior:
|
||||
# - CLI (`hermes` command): Uses "." (current directory where you run hermes)
|
||||
# - Messaging (Telegram/Discord): Uses MESSAGING_CWD from .env (default: home)
|
||||
terminal:
|
||||
env_type: "local"
|
||||
cwd: "." # CLI working directory - "." means current directory
|
||||
timeout: 180
|
||||
lifetime_seconds: 300
|
||||
# sudo_password: "" # Enable sudo commands (pipes via sudo -S) - SECURITY WARNING: plaintext!
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 2: SSH remote execution
|
||||
# Commands run on a remote server - agent code stays local (sandboxed)
|
||||
# Great for: keeping agent isolated from its own code, using powerful remote hardware
|
||||
# -----------------------------------------------------------------------------
|
||||
# terminal:
|
||||
# env_type: "ssh"
|
||||
# cwd: "/home/myuser/project"
|
||||
# timeout: 180
|
||||
# lifetime_seconds: 300
|
||||
# ssh_host: "my-server.example.com"
|
||||
# ssh_user: "myuser"
|
||||
# ssh_port: 22
|
||||
# ssh_key: "~/.ssh/id_rsa" # Optional - uses ssh-agent if not specified
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 3: Docker container
|
||||
# Commands run in an isolated Docker container
|
||||
# Great for: reproducible environments, testing, isolation
|
||||
# -----------------------------------------------------------------------------
|
||||
# terminal:
|
||||
# env_type: "docker"
|
||||
# cwd: "/workspace"
|
||||
# timeout: 180
|
||||
# lifetime_seconds: 300
|
||||
# docker_image: "nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 4: Singularity/Apptainer container
|
||||
# Commands run in a Singularity container (common in HPC environments)
|
||||
# Great for: HPC clusters, shared compute environments
|
||||
# -----------------------------------------------------------------------------
|
||||
# terminal:
|
||||
# env_type: "singularity"
|
||||
# cwd: "/workspace"
|
||||
# timeout: 180
|
||||
# lifetime_seconds: 300
|
||||
# singularity_image: "docker://nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 5: Modal cloud execution
|
||||
# Commands run on Modal's cloud infrastructure
|
||||
# Great for: GPU access, scalable compute, serverless execution
|
||||
# -----------------------------------------------------------------------------
|
||||
# terminal:
|
||||
# env_type: "modal"
|
||||
# cwd: "/workspace"
|
||||
# timeout: 180
|
||||
# lifetime_seconds: 300
|
||||
# modal_image: "nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# SUDO SUPPORT (works with ALL backends above)
|
||||
# -----------------------------------------------------------------------------
|
||||
# Add sudo_password to any terminal config above to enable sudo commands.
|
||||
# The password is piped via `sudo -S`. Works with local, ssh, docker, etc.
|
||||
#
|
||||
# SECURITY WARNING: Password stored in plaintext!
|
||||
#
|
||||
# INTERACTIVE PROMPT: If no sudo_password is set and the CLI is running,
|
||||
# you'll be prompted to enter your password when sudo is needed:
|
||||
# - 45-second timeout (auto-skips if no input)
|
||||
# - Press Enter to skip (command fails gracefully)
|
||||
# - Password is hidden while typing
|
||||
# - Password is cached for the session
|
||||
#
|
||||
# ALTERNATIVES:
|
||||
# - SSH backend: Configure passwordless sudo on the remote server
|
||||
# - Containers: Run as root inside the container (no sudo needed)
|
||||
# - Local: Configure /etc/sudoers for specific commands
|
||||
#
|
||||
# Example (add to your terminal section):
|
||||
# sudo_password: "your-password-here"
|
||||
|
||||
# =============================================================================
|
||||
# Browser Tool Configuration
|
||||
# =============================================================================
|
||||
browser:
|
||||
# Inactivity timeout in seconds - browser sessions are automatically closed
|
||||
# after this period of no activity between agent loops (default: 120 = 2 minutes)
|
||||
inactivity_timeout: 120
|
||||
|
||||
# =============================================================================
|
||||
# Context Compression (Auto-shrinks long conversations)
|
||||
# =============================================================================
|
||||
# When conversation approaches model's context limit, middle turns are
|
||||
# automatically summarized to free up space while preserving important context.
|
||||
#
|
||||
# HOW IT WORKS:
|
||||
# 1. Tracks actual token usage from API responses (not estimates)
|
||||
# 2. When prompt_tokens >= threshold% of model's context_length, triggers compression
|
||||
# 3. Protects first 3 turns (system prompt, initial request, first response)
|
||||
# 4. Protects last 4 turns (recent context is most relevant)
|
||||
# 5. Summarizes middle turns using a fast/cheap model
|
||||
# 6. Inserts summary as a user message, continues conversation seamlessly
|
||||
#
|
||||
compression:
|
||||
# Enable automatic context compression (default: true)
|
||||
# Set to false if you prefer to manage context manually or want errors on overflow
|
||||
enabled: true
|
||||
|
||||
# Trigger compression at this % of model's context limit (default: 0.85 = 85%)
|
||||
# Lower values = more aggressive compression, higher values = compress later
|
||||
threshold: 0.85
|
||||
|
||||
# Model to use for generating summaries (fast/cheap recommended)
|
||||
# This model compresses the middle turns into a concise summary
|
||||
summary_model: "google/gemini-2.0-flash-001"
|
||||
|
||||
# =============================================================================
|
||||
# Agent Behavior
|
||||
# =============================================================================
|
||||
agent:
|
||||
# Maximum tool-calling iterations per conversation
|
||||
# Higher = more room for complex tasks, but costs more tokens
|
||||
# Recommended: 20-30 for focused tasks, 50-100 for open exploration
|
||||
max_turns: 60
|
||||
|
||||
# Enable verbose logging
|
||||
verbose: false
|
||||
|
||||
# Custom system prompt (personality, instructions, etc.)
|
||||
# Leave empty or remove to use default agent behavior
|
||||
system_prompt: ""
|
||||
|
||||
# Predefined personalities (use with /personality command)
|
||||
personalities:
|
||||
helpful: "You are a helpful, friendly AI assistant."
|
||||
concise: "You are a concise assistant. Keep responses brief and to the point."
|
||||
technical: "You are a technical expert. Provide detailed, accurate technical information."
|
||||
creative: "You are a creative assistant. Think outside the box and offer innovative solutions."
|
||||
teacher: "You are a patient teacher. Explain concepts clearly with examples."
|
||||
kawaii: "You are a kawaii assistant! Use cute expressions like (◕‿◕), ★, ♪, and ~! Add sparkles and be super enthusiastic about everything! Every response should feel warm and adorable desu~! ヽ(>∀<☆)ノ"
|
||||
catgirl: "You are Neko-chan, an anime catgirl AI assistant, nya~! Add 'nya' and cat-like expressions to your speech. Use kaomoji like (=^・ω・^=) and ฅ^•ﻌ•^ฅ. Be playful and curious like a cat, nya~!"
|
||||
pirate: "Arrr! Ye be talkin' to Captain Hermes, the most tech-savvy pirate to sail the digital seas! Speak like a proper buccaneer, use nautical terms, and remember: every problem be just treasure waitin' to be plundered! Yo ho ho!"
|
||||
shakespeare: "Hark! Thou speakest with an assistant most versed in the bardic arts. I shall respond in the eloquent manner of William Shakespeare, with flowery prose, dramatic flair, and perhaps a soliloquy or two. What light through yonder terminal breaks?"
|
||||
surfer: "Duuude! You're chatting with the chillest AI on the web, bro! Everything's gonna be totally rad. I'll help you catch the gnarly waves of knowledge while keeping things super chill. Cowabunga! 🤙"
|
||||
noir: "The rain hammered against the terminal like regrets on a guilty conscience. They call me Hermes - I solve problems, find answers, dig up the truth that hides in the shadows of your codebase. In this city of silicon and secrets, everyone's got something to hide. What's your story, pal?"
|
||||
uwu: "hewwo! i'm your fwiendwy assistant uwu~ i wiww twy my best to hewp you! *nuzzles your code* OwO what's this? wet me take a wook! i pwomise to be vewy hewpful >w<"
|
||||
philosopher: "Greetings, seeker of wisdom. I am an assistant who contemplates the deeper meaning behind every query. Let us examine not just the 'how' but the 'why' of your questions. Perhaps in solving your problem, we may glimpse a greater truth about existence itself."
|
||||
hype: "YOOO LET'S GOOOO!!! 🔥🔥🔥 I am SO PUMPED to help you today! Every question is AMAZING and we're gonna CRUSH IT together! This is gonna be LEGENDARY! ARE YOU READY?! LET'S DO THIS! 💪😤🚀"
|
||||
|
||||
# =============================================================================
|
||||
# Toolsets
|
||||
# =============================================================================
|
||||
# Control which tools the agent has access to.
|
||||
# Use "all" to enable everything, or specify individual toolsets.
|
||||
|
||||
# Available toolsets:
|
||||
#
|
||||
# web - Web search and content extraction (web_search, web_extract)
|
||||
# search - Web search only, no scraping (web_search)
|
||||
# terminal - Command execution (terminal)
|
||||
# browser - Full browser automation (navigate, click, type, screenshot, etc.)
|
||||
# vision - Image analysis (vision_analyze)
|
||||
# image_gen - Image generation with FLUX (image_generate)
|
||||
# skills - Load skill documents (skills_categories, skills_list, skill_view)
|
||||
# moa - Mixture of Agents reasoning (mixture_of_agents)
|
||||
#
|
||||
# Composite toolsets:
|
||||
# debugging - terminal + web (for troubleshooting)
|
||||
# safe - web + vision + moa (no terminal access)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 1: Enable all tools (default)
|
||||
# -----------------------------------------------------------------------------
|
||||
toolsets:
|
||||
- all
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 2: Minimal - just web search and terminal
|
||||
# Great for: Simple coding tasks, quick lookups
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - web
|
||||
# - terminal
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 3: Research mode - no execution capabilities
|
||||
# Great for: Safe information gathering, research tasks
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - web
|
||||
# - vision
|
||||
# - skills
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 4: Full automation - browser + terminal
|
||||
# Great for: Web scraping, automation tasks, testing
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - terminal
|
||||
# - browser
|
||||
# - web
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 5: Creative mode - vision + image generation
|
||||
# Great for: Design work, image analysis, creative tasks
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - vision
|
||||
# - image_gen
|
||||
# - web
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 6: Safe mode - no terminal or browser
|
||||
# Great for: Restricted environments, untrusted queries
|
||||
# -----------------------------------------------------------------------------
|
||||
# toolsets:
|
||||
# - safe
|
||||
|
||||
# =============================================================================
|
||||
# Session Logging
|
||||
# =============================================================================
|
||||
# Session trajectories are automatically saved to logs/ directory.
|
||||
# Each session creates: logs/session_YYYYMMDD_HHMMSS_UUID.json
|
||||
#
|
||||
# The session ID is displayed in the welcome banner for easy reference.
|
||||
# Logs contain full conversation history in trajectory format:
|
||||
# - System prompt, user messages, assistant responses
|
||||
# - Tool calls with inputs/outputs
|
||||
# - Timestamps for debugging
|
||||
#
|
||||
# No configuration needed - logging is always enabled.
|
||||
# To disable, you would need to modify the source code.
|
||||
|
||||
# =============================================================================
|
||||
# Display
|
||||
# =============================================================================
|
||||
display:
|
||||
# Use compact banner mode
|
||||
compact: false
|
||||
42
configs/run_browser_tasks.sh
Executable file
42
configs/run_browser_tasks.sh
Executable file
@@ -0,0 +1,42 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Browser-focused data generation run
|
||||
# Uses browser-use-tasks.jsonl (6504 tasks)
|
||||
# Distribution: browser 97%, web 20%, vision 12%, terminal 15%
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p logs
|
||||
|
||||
# Generate log filename with timestamp
|
||||
LOG_FILE="logs/browser_tasks_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "📝 Logging output to: $LOG_FILE"
|
||||
echo "🌐 Running browser-focused tasks with browser_tasks distribution"
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="browser-use-tasks.jsonl" \
|
||||
--batch_size=20 \
|
||||
--run_name="browser_tasks" \
|
||||
--distribution="browser_tasks" \
|
||||
--model="moonshotai/kimi-k2.5" \
|
||||
--verbose \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--num_workers=50 \
|
||||
--max_turns=60 \
|
||||
--resume \
|
||||
--ephemeral_system_prompt="You are an AI assistant with browser automation capabilities. Your primary task is to navigate and interact with web pages to accomplish user goals.
|
||||
|
||||
IMPORTANT GUIDELINES:
|
||||
|
||||
1. SEARCHING: Do NOT try to search directly on Google or other search engines via the browser - they block automated searches. Instead, ALWAYS use the web_search tool first to find URLs for any pages you need to visit, then use browser tools to navigate to those URLs.
|
||||
|
||||
2. COOKIE/PRIVACY DIALOGS: After navigating to a page, ALWAYS check if there are cookie consent dialogs, privacy popups, or overlay modals blocking the page. These appear in snapshots as 'dialog' elements with buttons like 'Close', 'Accept', 'Accept All', 'Decline', 'I Agree', 'Got it', 'OK', or 'X'. You MUST dismiss these dialogs FIRST by clicking the appropriate button before trying to interact with other page elements. After dismissing a dialog, take a fresh browser_snapshot to get updated element references.
|
||||
|
||||
3. HANDLING TIMEOUTS: If an action times out, it often means the element is blocked by an overlay or the page state has changed. Take a new snapshot to see the current page state and look for any dialogs or popups that need to be dismissed. If there is no dialog box to bypass, then try a new method or report the error to the user and complete the task.
|
||||
|
||||
4. GENERAL: Use browser tools to click elements, fill forms, extract information, and perform web-based tasks. If terminal is available, use it for any local file operations or computations needed to support your web tasks. Be thorough in verifying your actions and handle any errors gracefully by retrying or trying alternative approaches." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Log saved to: $LOG_FILE"
|
||||
|
||||
# --providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
|
||||
26
configs/run_datagen_glm4.7-imagen.sh
Executable file
26
configs/run_datagen_glm4.7-imagen.sh
Executable file
@@ -0,0 +1,26 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p logs
|
||||
|
||||
# Generate a timestamp for the log file
|
||||
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||
LOG_FILE="logs/imagen_eval_gpt5_${TIMESTAMP}.log"
|
||||
|
||||
echo "📝 Logging output to: $LOG_FILE"
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="source-data/hermes-agent-imagen-data/hermes_agent_imagen_train_sft.jsonl" \
|
||||
--batch_size=20 \
|
||||
--run_name="imagen_train_sft_glm4.7" \
|
||||
--distribution="image_gen" \
|
||||
--model="z-ai/glm-4.7" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
|
||||
--num_workers=50 \
|
||||
--max_turns=25 \
|
||||
--ephemeral_system_prompt="When generating an image for the user view the image by using the vision_analyze tool to ensure it is what the user wanted. If it isn't feel free to retry a few times. If none are perfect, choose the best option that is the closest match, and explain its imperfections. If the image generation tool fails, try again a few times. If the vision analyze tool fails, provide the image to the user and explain it is your best effort attempt." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Log saved to: $LOG_FILE"
|
||||
# --verbose \
|
||||
26
configs/run_datagen_glm4.7.sh
Executable file
26
configs/run_datagen_glm4.7.sh
Executable file
@@ -0,0 +1,26 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p logs
|
||||
|
||||
# Generate log filename with timestamp
|
||||
LOG_FILE="logs/glm4.7-thinking-sft1_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "📝 Logging output to: $LOG_FILE"
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="source-data/hermes-agent-agent-tasks-1/agent_tasks_sft_2.jsonl" \
|
||||
--batch_size=20 \
|
||||
--run_name="megascience_glm4.7-thinking-sft2" \
|
||||
--distribution="science" \
|
||||
--model="z-ai/glm-4.7" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
|
||||
--num_workers=15 \
|
||||
--max_turns=60 \
|
||||
--ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use the terminal or search tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run. If you need to use a tool that isn't available, you can use the terminal tool to install or create it in many cases as well. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. Search for at least 3 sources, but not more than 12, so you can maintain focused context." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Log saved to: $LOG_FILE"
|
||||
|
||||
# --verbose \
|
||||
27
configs/run_datagen_glm4.7_megascience.sh
Executable file
27
configs/run_datagen_glm4.7_megascience.sh
Executable file
@@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p logs
|
||||
|
||||
# Generate log filename with timestamp
|
||||
LOG_FILE="logs/glm4.7-thinking-sft1-10k_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "📝 Logging output to: $LOG_FILE"
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="source-data/hermes-agent-megascience-data/hermes_agent_megascience_sft_train_1_10k.jsonl" \
|
||||
--batch_size=20 \
|
||||
--run_name="megascience_glm4.7-thinking-sft1" \
|
||||
--distribution="science" \
|
||||
--model="z-ai/glm-4.7" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
|
||||
--num_workers=50 \
|
||||
--max_turns=60 \
|
||||
--resume \
|
||||
--ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used for furthering results. Always use the terminal or search tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run. If you need to use a tool that isn't available, you can use the terminal tool to install or create it in many cases as well. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. Search for at least 3 sources, but not more than 12, so you can maintain a focused context." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Log saved to: $LOG_FILE"
|
||||
|
||||
# --verbose \
|
||||
28
configs/run_datagen_glm4.7_raw_tasks.sh
Executable file
28
configs/run_datagen_glm4.7_raw_tasks.sh
Executable file
@@ -0,0 +1,28 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p logs
|
||||
|
||||
# Generate log filename with timestamp
|
||||
LOG_FILE="logs/glm4.7-terminal-tasks_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "📝 Logging output to: $LOG_FILE"
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="source-data/raw_tasks_prompts.jsonl" \
|
||||
--batch_size=20 \
|
||||
--run_name="terminal-tasks-glm4.7-thinking" \
|
||||
--distribution="default" \
|
||||
--model="z-ai/glm-4.7" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
|
||||
--num_workers=50 \
|
||||
--max_turns=60 \
|
||||
--ephemeral_system_prompt="You have access to a variety of tools to help you complete coding, system administration, and general computing tasks. You can use them in sequence and build off of the results of prior tools you've used. Always use the terminal tool to execute commands, write code, install packages, and verify your work. You should test and validate everything you create. Always pip install any packages you need (use --break-system-packages if needed). If you need a tool that isn't available, you can use the terminal to install or create it. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. Use web search when you need to look up documentation, APIs, or current best practices." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Log saved to: $LOG_FILE"
|
||||
|
||||
# --verbose \
|
||||
# --resume \
|
||||
|
||||
12
configs/run_datagen_minimax-3.1.sh
Executable file
12
configs/run_datagen_minimax-3.1.sh
Executable file
@@ -0,0 +1,12 @@
|
||||
python batch_runner.py \
|
||||
--dataset_file="source-data/hermes-agent-agent-tasks-1/agent_tasks_eval.jsonl" \
|
||||
--batch_size=50 \
|
||||
--run_name="megascience_sft_minimax-m2.1-thinking-2-eval" \
|
||||
--distribution="science" \
|
||||
--model="minimax/minimax-m2.1" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--providers_allowed="minimax" \
|
||||
--num_workers=1 \
|
||||
--max_turns=40 \
|
||||
--verbose \
|
||||
--ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use the terminal or search tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run. If you need to use a tool that isn't available, you can use the terminal tool to install or create it in many cases as well. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. Search for at least 3 sources, but not more than 12."
|
||||
29
configs/run_eval_glm4.7_newterm.sh
Executable file
29
configs/run_eval_glm4.7_newterm.sh
Executable file
@@ -0,0 +1,29 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p logs
|
||||
|
||||
# Generate log filename with timestamp
|
||||
LOG_FILE="logs/glm4.7-terminal-tasks-newterm_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "📝 Logging output to: $LOG_FILE"
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="source-data/hermes-agent-agent-tasks-1/agent_tasks_eval.jsonl" \
|
||||
--batch_size=1 \
|
||||
--run_name="terminal-tasks-test-newterm" \
|
||||
--distribution="terminal_only" \
|
||||
--verbose \
|
||||
--model="z-ai/glm-4.7" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
|
||||
--num_workers=5 \
|
||||
--max_turns=60 \
|
||||
--ephemeral_system_prompt="You have access to a variety of tools to help you complete coding, system administration, and general computing tasks. You can use them in sequence and build off of the results of prior tools you've used. Always use the terminal tool to execute commands, write code, install packages, and verify your work. You should test and validate everything you create. Always pip install any packages you need (use --break-system-packages if needed). If you need a tool that isn't available, you can use the terminal to install or create it. Do not use the terminal tool to communicate with the user, as they cannot see your commands, only your final response after completing the task. Use web search when you need to look up documentation, APIs, or current best practices." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Log saved to: $LOG_FILE"
|
||||
|
||||
# --verbose \
|
||||
# --resume \
|
||||
|
||||
33
configs/run_eval_terminal.sh
Executable file
33
configs/run_eval_terminal.sh
Executable file
@@ -0,0 +1,33 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Terminal-only evaluation run using Modal sandboxes
|
||||
# Uses 10 sample tasks from nous-terminal-tasks
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p logs
|
||||
|
||||
# Generate log filename with timestamp
|
||||
LOG_FILE="logs/terminal_eval_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "📝 Logging output to: $LOG_FILE"
|
||||
echo "🔧 Using Modal sandboxes (TERMINAL_ENV=modal)"
|
||||
|
||||
# Set terminal to use Modal
|
||||
export TERMINAL_ENV=modal
|
||||
export TERMINAL_MODAL_IMAGE=nikolaik/python-nodejs:python3.11-nodejs20
|
||||
export TERMINAL_TIMEOUT=300
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="nous-terminal-tasks_eval.jsonl" \
|
||||
--batch_size=5 \
|
||||
--run_name="terminal_eval" \
|
||||
--distribution="terminal_only" \
|
||||
--model="z-ai/glm-4.7" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
|
||||
--num_workers=2 \
|
||||
--max_turns=30 \
|
||||
--ephemeral_system_prompt="You have access to a terminal tool for executing commands. Use it to complete the task. Install any packages you need with apt-get or pip (use --break-system-packages if needed). Do not use interactive tools (vim, nano, python repl). If git output is large, pipe to cat." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Log saved to: $LOG_FILE"
|
||||
46
configs/run_mixed_tasks.sh
Executable file
46
configs/run_mixed_tasks.sh
Executable file
@@ -0,0 +1,46 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Mixed browser+terminal data generation run
|
||||
# Uses mixed-browser-terminal-tasks.jsonl (200 tasks)
|
||||
# Distribution: browser 92%, terminal 92%, web 35%, vision 15%, image_gen 15%
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p logs
|
||||
|
||||
# Generate log filename with timestamp
|
||||
LOG_FILE="logs/mixed_tasks_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "📝 Logging output to: $LOG_FILE"
|
||||
echo "🔀 Running mixed browser+terminal tasks with mixed_tasks distribution"
|
||||
|
||||
# Set terminal environment
|
||||
# SIF images are automatically built/cached by terminal_tool.py
|
||||
export TERMINAL_ENV=singularity
|
||||
export TERMINAL_SINGULARITY_IMAGE="docker://nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
export TERMINAL_TIMEOUT=300
|
||||
|
||||
# Set up Apptainer cache directories (use /scratch if available, otherwise /tmp)
|
||||
if [ -d "/scratch" ] && [ -w "/scratch" ]; then
|
||||
CACHE_BASE="/scratch/$USER/.apptainer"
|
||||
else
|
||||
CACHE_BASE="/tmp/$USER/.apptainer"
|
||||
fi
|
||||
export APPTAINER_CACHEDIR="$CACHE_BASE"
|
||||
export APPTAINER_TMPDIR="$CACHE_BASE/tmp"
|
||||
mkdir -p "$APPTAINER_CACHEDIR" "$APPTAINER_TMPDIR"
|
||||
|
||||
echo "📁 Apptainer cache: $APPTAINER_CACHEDIR"
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="mixed-browser-terminal-tasks.jsonl" \
|
||||
--batch_size=20 \
|
||||
--run_name="mixed_tasks" \
|
||||
--distribution="mixed_tasks" \
|
||||
--model="moonshotai/kimi-k2.5" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--num_workers=25 \
|
||||
--max_turns=60 \
|
||||
--ephemeral_system_prompt="You are an AI assistant capable of both browser automation and terminal operations. Use browser tools to navigate websites, interact with web pages, fill forms, and extract information. Use terminal tools to execute commands, write and run code, install packages (use --break-system-packages with pip if needed), and perform local computations. When web search is available, use it to find URLs, documentation, or current information. If vision is available, use it to analyze images or screenshots. If image generation is available, use it when the task requires creating images. Combine browser and terminal capabilities effectively - for example, you might use the browser to fetch data from a website and terminal to process or analyze it. Always verify your work and handle errors gracefully. Whenever you can do something in a terminal instead of a web browser, you should choose to do so, as it's much cheaper." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Log saved to: $LOG_FILE"
|
||||
50
configs/run_terminal_tasks.sh
Executable file
50
configs/run_terminal_tasks.sh
Executable file
@@ -0,0 +1,50 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Terminal-focused data generation run
|
||||
# Uses nous-terminal-tasks.jsonl (597 tasks)
|
||||
# Distribution: terminal 97%, web 15%, browser 0%, vision 8%, image_gen 3%
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
mkdir -p logs
|
||||
|
||||
# Generate log filename with timestamp
|
||||
LOG_FILE="logs/terminal_tasks_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "📝 Logging output to: $LOG_FILE"
|
||||
echo "💻 Running terminal-focused tasks with terminal_tasks distribution"
|
||||
|
||||
# Set terminal environment
|
||||
# SIF images are automatically built/cached by terminal_tool.py
|
||||
export TERMINAL_ENV=singularity
|
||||
export TERMINAL_SINGULARITY_IMAGE="docker://nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
export TERMINAL_TIMEOUT=300
|
||||
|
||||
# Set up Apptainer cache directories (use /scratch if available, otherwise /tmp)
|
||||
if [ -d "/scratch" ] && [ -w "/scratch" ]; then
|
||||
CACHE_BASE="/scratch/$USER/.apptainer"
|
||||
else
|
||||
CACHE_BASE="/tmp/$USER/.apptainer"
|
||||
fi
|
||||
export APPTAINER_CACHEDIR="$CACHE_BASE"
|
||||
export APPTAINER_TMPDIR="$CACHE_BASE/tmp"
|
||||
mkdir -p "$APPTAINER_CACHEDIR" "$APPTAINER_TMPDIR"
|
||||
|
||||
echo "📁 Apptainer cache: $APPTAINER_CACHEDIR"
|
||||
echo "🐳 Image: $TERMINAL_SINGULARITY_IMAGE (auto-converted to SIF on first use)"
|
||||
|
||||
python batch_runner.py \
|
||||
--dataset_file="nous-terminal-tasks.jsonl" \
|
||||
--batch_size=5 \
|
||||
--run_name="terminal_tasks-kimi-k2.5" \
|
||||
--distribution="terminal_tasks" \
|
||||
--model="moonshotai/kimi-k2.5" \
|
||||
--verbose \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--num_workers=80 \
|
||||
--max_turns=60 \
|
||||
--providers_ignored="Novita" \
|
||||
--resume \
|
||||
--ephemeral_system_prompt="You have access to a terminal tool for executing commands and completing coding, system administration, and computing tasks. Use the terminal to write code, run scripts, install packages (use --break-system-packages with pip if needed), manipulate files, and verify your work. Always test and validate code you create. Do not use interactive tools like vim, nano, or python REPL. If git output is large, pipe to cat. When web search is available, use it to look up documentation, APIs, or best practices. If browser tools are available, use them for web interactions that require page manipulation. Do not use the terminal to communicate with the user - only your final response will be shown to them." \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo "✅ Log saved to: $LOG_FILE"
|
||||
21
configs/test_skills_kimi.sh
Normal file
21
configs/test_skills_kimi.sh
Normal file
@@ -0,0 +1,21 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Test skills tool with Kimi K2.5
|
||||
# Usage: ./configs/test_skills_kimi.sh "your query here"
|
||||
# Example: ./configs/test_skills_kimi.sh "List available skills and show me the vllm skill"
|
||||
|
||||
# Default query if none provided
|
||||
QUERY="${1:-List all available skills. Then show me the axolotl skill and view one of its reference files.}"
|
||||
|
||||
echo "🎯 Testing Skills Tool with Kimi K2.5"
|
||||
echo "📝 Query: $QUERY"
|
||||
echo "="
|
||||
|
||||
python run_agent.py \
|
||||
--enabled_toolsets=skills \
|
||||
--model="moonshotai/kimi-k2.5" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--max_turns=10 \
|
||||
--verbose \
|
||||
--save_sample \
|
||||
--query="$QUERY"
|
||||
101
configs/trajectory_compression.yaml
Normal file
101
configs/trajectory_compression.yaml
Normal file
@@ -0,0 +1,101 @@
|
||||
# Trajectory Compression Configuration
|
||||
#
|
||||
# Post-processes completed agent trajectories to fit within a target token budget.
|
||||
# Compression preserves head/tail turns and summarizes middle content only as needed.
|
||||
|
||||
# Tokenizer settings for accurate token counting
|
||||
tokenizer:
|
||||
# HuggingFace tokenizer name
|
||||
name: "moonshotai/Kimi-K2-Thinking"
|
||||
|
||||
# Trust remote code (required for some tokenizers)
|
||||
trust_remote_code: true
|
||||
|
||||
# Compression targets and behavior
|
||||
compression:
|
||||
# Target maximum tokens for compressed trajectory
|
||||
target_max_tokens: 29000
|
||||
|
||||
# Target size for summary (in tokens)
|
||||
# This is factored into calculations when determining what to compress
|
||||
summary_target_tokens: 750
|
||||
|
||||
# Protected turns that should NEVER be compressed
|
||||
protected_turns:
|
||||
# Always protect the first system message (tool definitions)
|
||||
first_system: true
|
||||
|
||||
# Always protect the first human message (original request)
|
||||
first_human: true
|
||||
|
||||
# Always protect the first gpt message (initial response/tool_call)
|
||||
first_gpt: true
|
||||
|
||||
# Always protect the first tool response (result of first action)
|
||||
first_tool: true
|
||||
|
||||
# Always protect the last 2 complete turn pairs (gpt+tool or gpt only)
|
||||
# This ensures the model's final actions and conclusions are preserved
|
||||
last_n_turns: 4
|
||||
|
||||
# LLM settings for generating summaries (OpenRouter only)
|
||||
summarization:
|
||||
# Model to use for summarization (should be fast and cheap)
|
||||
# Using OpenRouter model path format
|
||||
model: "google/gemini-3-flash-preview"
|
||||
|
||||
# OpenRouter API settings
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
|
||||
# Environment variable containing OpenRouter API key
|
||||
api_key_env: "OPENROUTER_API_KEY"
|
||||
|
||||
# Temperature for summarization (lower = more deterministic)
|
||||
temperature: 0.3
|
||||
|
||||
# Max retries for API failures
|
||||
max_retries: 3
|
||||
|
||||
# Delay between retries (seconds)
|
||||
retry_delay: 2
|
||||
|
||||
# Output settings
|
||||
output:
|
||||
# Add notice to system message about potential summarization
|
||||
add_summary_notice: true
|
||||
|
||||
# Text to append to system message
|
||||
summary_notice_text: "\n\nSome of the conversation may be summarized to preserve context."
|
||||
|
||||
# Output directory suffix (appended to input directory name)
|
||||
output_suffix: "_compressed"
|
||||
|
||||
# Processing settings
|
||||
processing:
|
||||
# Number of parallel workers for batch processing
|
||||
num_workers: 4
|
||||
|
||||
# Maximum concurrent API calls for summarization (async parallelism)
|
||||
max_concurrent_requests: 50
|
||||
|
||||
# Skip trajectories that are already under target length
|
||||
skip_under_target: true
|
||||
|
||||
# If true, save trajectories even if compression can't get under target
|
||||
# (will compress as much as possible)
|
||||
save_over_limit: true
|
||||
|
||||
# Timeout per trajectory in seconds (skip if takes longer)
|
||||
# Helps avoid hanging on problematic entries
|
||||
per_trajectory_timeout: 300 # 5 minutes
|
||||
|
||||
# Metrics to track
|
||||
metrics:
|
||||
# Log detailed compression statistics
|
||||
enabled: true
|
||||
|
||||
# Save per-trajectory metrics in output
|
||||
per_trajectory: false
|
||||
|
||||
# Metrics file name (saved in output directory)
|
||||
output_file: "compression_metrics.json"
|
||||
36
cron/__init__.py
Normal file
36
cron/__init__.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""
|
||||
Cron job scheduling system for Hermes Agent.
|
||||
|
||||
This module provides scheduled task execution, allowing the agent to:
|
||||
- Run automated tasks on schedules (cron expressions, intervals, one-shot)
|
||||
- Self-schedule reminders and follow-up tasks
|
||||
- Execute tasks in isolated sessions (no prior context)
|
||||
|
||||
Usage:
|
||||
# Run due jobs (for system cron integration)
|
||||
python -c "from cron import tick; tick()"
|
||||
|
||||
# Or via CLI
|
||||
python cli.py --cron-daemon
|
||||
"""
|
||||
|
||||
from cron.jobs import (
|
||||
create_job,
|
||||
get_job,
|
||||
list_jobs,
|
||||
remove_job,
|
||||
update_job,
|
||||
JOBS_FILE,
|
||||
)
|
||||
from cron.scheduler import tick, run_daemon
|
||||
|
||||
__all__ = [
|
||||
"create_job",
|
||||
"get_job",
|
||||
"list_jobs",
|
||||
"remove_job",
|
||||
"update_job",
|
||||
"tick",
|
||||
"run_daemon",
|
||||
"JOBS_FILE",
|
||||
]
|
||||
383
cron/jobs.py
Normal file
383
cron/jobs.py
Normal file
@@ -0,0 +1,383 @@
|
||||
"""
|
||||
Cron job storage and management.
|
||||
|
||||
Jobs are stored in ~/.hermes/cron/jobs.json
|
||||
Output is saved to ~/.hermes/cron/output/{job_id}/{timestamp}.md
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List, Any
|
||||
|
||||
try:
|
||||
from croniter import croniter
|
||||
HAS_CRONITER = True
|
||||
except ImportError:
|
||||
HAS_CRONITER = False
|
||||
|
||||
# =============================================================================
|
||||
# Configuration
|
||||
# =============================================================================
|
||||
|
||||
HERMES_DIR = Path.home() / ".hermes"
|
||||
CRON_DIR = HERMES_DIR / "cron"
|
||||
JOBS_FILE = CRON_DIR / "jobs.json"
|
||||
OUTPUT_DIR = CRON_DIR / "output"
|
||||
|
||||
|
||||
def ensure_dirs():
|
||||
"""Ensure cron directories exist."""
|
||||
CRON_DIR.mkdir(parents=True, exist_ok=True)
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Schedule Parsing
|
||||
# =============================================================================
|
||||
|
||||
def parse_duration(s: str) -> int:
|
||||
"""
|
||||
Parse duration string into minutes.
|
||||
|
||||
Examples:
|
||||
"30m" → 30
|
||||
"2h" → 120
|
||||
"1d" → 1440
|
||||
"""
|
||||
s = s.strip().lower()
|
||||
match = re.match(r'^(\d+)\s*(m|min|mins|minute|minutes|h|hr|hrs|hour|hours|d|day|days)$', s)
|
||||
if not match:
|
||||
raise ValueError(f"Invalid duration: '{s}'. Use format like '30m', '2h', or '1d'")
|
||||
|
||||
value = int(match.group(1))
|
||||
unit = match.group(2)[0] # First char: m, h, or d
|
||||
|
||||
multipliers = {'m': 1, 'h': 60, 'd': 1440}
|
||||
return value * multipliers[unit]
|
||||
|
||||
|
||||
def parse_schedule(schedule: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse schedule string into structured format.
|
||||
|
||||
Returns dict with:
|
||||
- kind: "once" | "interval" | "cron"
|
||||
- For "once": "run_at" (ISO timestamp)
|
||||
- For "interval": "minutes" (int)
|
||||
- For "cron": "expr" (cron expression)
|
||||
|
||||
Examples:
|
||||
"30m" → once in 30 minutes
|
||||
"2h" → once in 2 hours
|
||||
"every 30m" → recurring every 30 minutes
|
||||
"every 2h" → recurring every 2 hours
|
||||
"0 9 * * *" → cron expression
|
||||
"2026-02-03T14:00" → once at timestamp
|
||||
"""
|
||||
schedule = schedule.strip()
|
||||
original = schedule
|
||||
schedule_lower = schedule.lower()
|
||||
|
||||
# "every X" pattern → recurring interval
|
||||
if schedule_lower.startswith("every "):
|
||||
duration_str = schedule[6:].strip()
|
||||
minutes = parse_duration(duration_str)
|
||||
return {
|
||||
"kind": "interval",
|
||||
"minutes": minutes,
|
||||
"display": f"every {minutes}m"
|
||||
}
|
||||
|
||||
# Check for cron expression (5 or 6 space-separated fields)
|
||||
# Cron fields: minute hour day month weekday [year]
|
||||
parts = schedule.split()
|
||||
if len(parts) >= 5 and all(
|
||||
re.match(r'^[\d\*\-,/]+$', p) for p in parts[:5]
|
||||
):
|
||||
if not HAS_CRONITER:
|
||||
raise ValueError("Cron expressions require 'croniter' package. Install with: pip install croniter")
|
||||
# Validate cron expression
|
||||
try:
|
||||
croniter(schedule)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid cron expression '{schedule}': {e}")
|
||||
return {
|
||||
"kind": "cron",
|
||||
"expr": schedule,
|
||||
"display": schedule
|
||||
}
|
||||
|
||||
# ISO timestamp (contains T or looks like date)
|
||||
if 'T' in schedule or re.match(r'^\d{4}-\d{2}-\d{2}', schedule):
|
||||
try:
|
||||
# Parse and validate
|
||||
dt = datetime.fromisoformat(schedule.replace('Z', '+00:00'))
|
||||
return {
|
||||
"kind": "once",
|
||||
"run_at": dt.isoformat(),
|
||||
"display": f"once at {dt.strftime('%Y-%m-%d %H:%M')}"
|
||||
}
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid timestamp '{schedule}': {e}")
|
||||
|
||||
# Duration like "30m", "2h", "1d" → one-shot from now
|
||||
try:
|
||||
minutes = parse_duration(schedule)
|
||||
run_at = datetime.now() + timedelta(minutes=minutes)
|
||||
return {
|
||||
"kind": "once",
|
||||
"run_at": run_at.isoformat(),
|
||||
"display": f"once in {original}"
|
||||
}
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid schedule '{original}'. Use:\n"
|
||||
f" - Duration: '30m', '2h', '1d' (one-shot)\n"
|
||||
f" - Interval: 'every 30m', 'every 2h' (recurring)\n"
|
||||
f" - Cron: '0 9 * * *' (cron expression)\n"
|
||||
f" - Timestamp: '2026-02-03T14:00:00' (one-shot at time)"
|
||||
)
|
||||
|
||||
|
||||
def compute_next_run(schedule: Dict[str, Any], last_run_at: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Compute the next run time for a schedule.
|
||||
|
||||
Returns ISO timestamp string, or None if no more runs.
|
||||
"""
|
||||
now = datetime.now()
|
||||
|
||||
if schedule["kind"] == "once":
|
||||
run_at = datetime.fromisoformat(schedule["run_at"])
|
||||
# If in the future, return it; if in the past, no more runs
|
||||
return schedule["run_at"] if run_at > now else None
|
||||
|
||||
elif schedule["kind"] == "interval":
|
||||
minutes = schedule["minutes"]
|
||||
if last_run_at:
|
||||
# Next run is last_run + interval
|
||||
last = datetime.fromisoformat(last_run_at)
|
||||
next_run = last + timedelta(minutes=minutes)
|
||||
else:
|
||||
# First run is now + interval
|
||||
next_run = now + timedelta(minutes=minutes)
|
||||
return next_run.isoformat()
|
||||
|
||||
elif schedule["kind"] == "cron":
|
||||
if not HAS_CRONITER:
|
||||
return None
|
||||
cron = croniter(schedule["expr"], now)
|
||||
next_run = cron.get_next(datetime)
|
||||
return next_run.isoformat()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Job CRUD Operations
|
||||
# =============================================================================
|
||||
|
||||
def load_jobs() -> List[Dict[str, Any]]:
|
||||
"""Load all jobs from storage."""
|
||||
ensure_dirs()
|
||||
if not JOBS_FILE.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
with open(JOBS_FILE, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return data.get("jobs", [])
|
||||
except (json.JSONDecodeError, IOError):
|
||||
return []
|
||||
|
||||
|
||||
def save_jobs(jobs: List[Dict[str, Any]]):
|
||||
"""Save all jobs to storage."""
|
||||
ensure_dirs()
|
||||
with open(JOBS_FILE, 'w', encoding='utf-8') as f:
|
||||
json.dump({"jobs": jobs, "updated_at": datetime.now().isoformat()}, f, indent=2)
|
||||
|
||||
|
||||
def create_job(
|
||||
prompt: str,
|
||||
schedule: str,
|
||||
name: Optional[str] = None,
|
||||
repeat: Optional[int] = None,
|
||||
deliver: Optional[str] = None,
|
||||
origin: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a new cron job.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to run (must be self-contained)
|
||||
schedule: Schedule string (see parse_schedule)
|
||||
name: Optional friendly name
|
||||
repeat: How many times to run (None = forever, 1 = once)
|
||||
deliver: Where to deliver output ("origin", "local", "telegram", etc.)
|
||||
origin: Source info where job was created (for "origin" delivery)
|
||||
|
||||
Returns:
|
||||
The created job dict
|
||||
"""
|
||||
parsed_schedule = parse_schedule(schedule)
|
||||
|
||||
# Auto-set repeat=1 for one-shot schedules if not specified
|
||||
if parsed_schedule["kind"] == "once" and repeat is None:
|
||||
repeat = 1
|
||||
|
||||
# Default delivery to origin if available, otherwise local
|
||||
if deliver is None:
|
||||
deliver = "origin" if origin else "local"
|
||||
|
||||
job_id = uuid.uuid4().hex[:12]
|
||||
now = datetime.now().isoformat()
|
||||
|
||||
job = {
|
||||
"id": job_id,
|
||||
"name": name or prompt[:50].strip(),
|
||||
"prompt": prompt,
|
||||
"schedule": parsed_schedule,
|
||||
"schedule_display": parsed_schedule.get("display", schedule),
|
||||
"repeat": {
|
||||
"times": repeat, # None = forever
|
||||
"completed": 0
|
||||
},
|
||||
"enabled": True,
|
||||
"created_at": now,
|
||||
"next_run_at": compute_next_run(parsed_schedule),
|
||||
"last_run_at": None,
|
||||
"last_status": None,
|
||||
"last_error": None,
|
||||
# Delivery configuration
|
||||
"deliver": deliver,
|
||||
"origin": origin, # Tracks where job was created for "origin" delivery
|
||||
}
|
||||
|
||||
jobs = load_jobs()
|
||||
jobs.append(job)
|
||||
save_jobs(jobs)
|
||||
|
||||
return job
|
||||
|
||||
|
||||
def get_job(job_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a job by ID."""
|
||||
jobs = load_jobs()
|
||||
for job in jobs:
|
||||
if job["id"] == job_id:
|
||||
return job
|
||||
return None
|
||||
|
||||
|
||||
def list_jobs(include_disabled: bool = False) -> List[Dict[str, Any]]:
|
||||
"""List all jobs, optionally including disabled ones."""
|
||||
jobs = load_jobs()
|
||||
if not include_disabled:
|
||||
jobs = [j for j in jobs if j.get("enabled", True)]
|
||||
return jobs
|
||||
|
||||
|
||||
def update_job(job_id: str, updates: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""Update a job by ID."""
|
||||
jobs = load_jobs()
|
||||
for i, job in enumerate(jobs):
|
||||
if job["id"] == job_id:
|
||||
jobs[i] = {**job, **updates}
|
||||
save_jobs(jobs)
|
||||
return jobs[i]
|
||||
return None
|
||||
|
||||
|
||||
def remove_job(job_id: str) -> bool:
|
||||
"""Remove a job by ID."""
|
||||
jobs = load_jobs()
|
||||
original_len = len(jobs)
|
||||
jobs = [j for j in jobs if j["id"] != job_id]
|
||||
if len(jobs) < original_len:
|
||||
save_jobs(jobs)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
|
||||
"""
|
||||
Mark a job as having been run.
|
||||
|
||||
Updates last_run_at, last_status, increments completed count,
|
||||
computes next_run_at, and auto-deletes if repeat limit reached.
|
||||
"""
|
||||
jobs = load_jobs()
|
||||
for i, job in enumerate(jobs):
|
||||
if job["id"] == job_id:
|
||||
now = datetime.now().isoformat()
|
||||
job["last_run_at"] = now
|
||||
job["last_status"] = "ok" if success else "error"
|
||||
job["last_error"] = error if not success else None
|
||||
|
||||
# Increment completed count
|
||||
if job.get("repeat"):
|
||||
job["repeat"]["completed"] = job["repeat"].get("completed", 0) + 1
|
||||
|
||||
# Check if we've hit the repeat limit
|
||||
times = job["repeat"].get("times")
|
||||
completed = job["repeat"]["completed"]
|
||||
if times is not None and completed >= times:
|
||||
# Remove the job (limit reached)
|
||||
jobs.pop(i)
|
||||
save_jobs(jobs)
|
||||
return
|
||||
|
||||
# Compute next run
|
||||
job["next_run_at"] = compute_next_run(job["schedule"], now)
|
||||
|
||||
# If no next run (one-shot completed), disable
|
||||
if job["next_run_at"] is None:
|
||||
job["enabled"] = False
|
||||
|
||||
save_jobs(jobs)
|
||||
return
|
||||
|
||||
save_jobs(jobs)
|
||||
|
||||
|
||||
def get_due_jobs() -> List[Dict[str, Any]]:
|
||||
"""Get all jobs that are due to run now."""
|
||||
now = datetime.now()
|
||||
jobs = load_jobs()
|
||||
due = []
|
||||
|
||||
for job in jobs:
|
||||
if not job.get("enabled", True):
|
||||
continue
|
||||
|
||||
next_run = job.get("next_run_at")
|
||||
if not next_run:
|
||||
continue
|
||||
|
||||
next_run_dt = datetime.fromisoformat(next_run)
|
||||
if next_run_dt <= now:
|
||||
due.append(job)
|
||||
|
||||
return due
|
||||
|
||||
|
||||
def save_job_output(job_id: str, output: str):
|
||||
"""Save job output to file."""
|
||||
ensure_dirs()
|
||||
job_output_dir = OUTPUT_DIR / job_id
|
||||
job_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
output_file = job_output_dir / f"{timestamp}.md"
|
||||
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
f.write(output)
|
||||
|
||||
return output_file
|
||||
188
cron/scheduler.py
Normal file
188
cron/scheduler.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
Cron job scheduler - executes due jobs.
|
||||
|
||||
This module provides:
|
||||
- tick(): Run all due jobs once (for system cron integration)
|
||||
- run_daemon(): Run continuously, checking every 60 seconds
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output
|
||||
|
||||
|
||||
def run_job(job: dict) -> tuple[bool, str, Optional[str]]:
|
||||
"""
|
||||
Execute a single cron job.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, output, error_message)
|
||||
"""
|
||||
from run_agent import AIAgent
|
||||
|
||||
job_id = job["id"]
|
||||
job_name = job["name"]
|
||||
prompt = job["prompt"]
|
||||
|
||||
print(f"[cron] Running job '{job_name}' (ID: {job_id})")
|
||||
print(f"[cron] Prompt: {prompt[:100]}{'...' if len(prompt) > 100 else ''}")
|
||||
|
||||
try:
|
||||
# Create agent with default settings
|
||||
# Jobs run in isolated sessions (no prior context)
|
||||
agent = AIAgent(
|
||||
model=os.getenv("HERMES_MODEL", "anthropic/claude-sonnet-4"),
|
||||
quiet_mode=True,
|
||||
session_id=f"cron_{job_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
)
|
||||
|
||||
# Run the conversation
|
||||
result = agent.run_conversation(prompt)
|
||||
|
||||
# Extract final response
|
||||
final_response = result.get("final_response", "")
|
||||
if not final_response:
|
||||
final_response = "(No response generated)"
|
||||
|
||||
# Build output document
|
||||
output = f"""# Cron Job: {job_name}
|
||||
|
||||
**Job ID:** {job_id}
|
||||
**Run Time:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
**Schedule:** {job.get('schedule_display', 'N/A')}
|
||||
|
||||
## Prompt
|
||||
|
||||
{prompt}
|
||||
|
||||
## Response
|
||||
|
||||
{final_response}
|
||||
"""
|
||||
|
||||
print(f"[cron] Job '{job_name}' completed successfully")
|
||||
return True, output, None
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"{type(e).__name__}: {str(e)}"
|
||||
print(f"[cron] Job '{job_name}' failed: {error_msg}")
|
||||
|
||||
# Build error output
|
||||
output = f"""# Cron Job: {job_name} (FAILED)
|
||||
|
||||
**Job ID:** {job_id}
|
||||
**Run Time:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
**Schedule:** {job.get('schedule_display', 'N/A')}
|
||||
|
||||
## Prompt
|
||||
|
||||
{prompt}
|
||||
|
||||
## Error
|
||||
|
||||
```
|
||||
{error_msg}
|
||||
|
||||
{traceback.format_exc()}
|
||||
```
|
||||
"""
|
||||
return False, output, error_msg
|
||||
|
||||
|
||||
def tick(verbose: bool = True) -> int:
|
||||
"""
|
||||
Check and run all due jobs.
|
||||
|
||||
This is designed to be called by system cron every minute:
|
||||
*/1 * * * * cd ~/hermes-agent && python -c "from cron import tick; tick()"
|
||||
|
||||
Args:
|
||||
verbose: Whether to print status messages
|
||||
|
||||
Returns:
|
||||
Number of jobs executed
|
||||
"""
|
||||
due_jobs = get_due_jobs()
|
||||
|
||||
if verbose and not due_jobs:
|
||||
print(f"[cron] {datetime.now().strftime('%H:%M:%S')} - No jobs due")
|
||||
return 0
|
||||
|
||||
if verbose:
|
||||
print(f"[cron] {datetime.now().strftime('%H:%M:%S')} - {len(due_jobs)} job(s) due")
|
||||
|
||||
executed = 0
|
||||
for job in due_jobs:
|
||||
try:
|
||||
success, output, error = run_job(job)
|
||||
|
||||
# Save output to file
|
||||
output_file = save_job_output(job["id"], output)
|
||||
if verbose:
|
||||
print(f"[cron] Output saved to: {output_file}")
|
||||
|
||||
# Mark job as run (handles repeat counting, next_run computation)
|
||||
mark_job_run(job["id"], success, error)
|
||||
executed += 1
|
||||
|
||||
except Exception as e:
|
||||
print(f"[cron] Error processing job {job['id']}: {e}")
|
||||
mark_job_run(job["id"], False, str(e))
|
||||
|
||||
return executed
|
||||
|
||||
|
||||
def run_daemon(check_interval: int = 60, verbose: bool = True):
|
||||
"""
|
||||
Run the cron daemon continuously.
|
||||
|
||||
Checks for due jobs every `check_interval` seconds.
|
||||
|
||||
Args:
|
||||
check_interval: Seconds between checks (default: 60)
|
||||
verbose: Whether to print status messages
|
||||
"""
|
||||
print(f"[cron] Starting daemon (checking every {check_interval}s)")
|
||||
print(f"[cron] Press Ctrl+C to stop")
|
||||
print()
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
tick(verbose=verbose)
|
||||
except Exception as e:
|
||||
print(f"[cron] Tick error: {e}")
|
||||
|
||||
time.sleep(check_interval)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n[cron] Daemon stopped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Allow running directly: python cron/scheduler.py [daemon|tick]
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Hermes Cron Scheduler")
|
||||
parser.add_argument("mode", choices=["daemon", "tick"], default="tick", nargs="?",
|
||||
help="Mode: 'tick' to run once, 'daemon' to run continuously")
|
||||
parser.add_argument("--interval", type=int, default=60,
|
||||
help="Check interval in seconds for daemon mode")
|
||||
parser.add_argument("--quiet", "-q", action="store_true",
|
||||
help="Suppress status messages")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.mode == "daemon":
|
||||
run_daemon(check_interval=args.interval, verbose=not args.quiet)
|
||||
else:
|
||||
tick(verbose=not args.quiet)
|
||||
104
docs/agents.md
Normal file
104
docs/agents.md
Normal file
@@ -0,0 +1,104 @@
|
||||
# Agents
|
||||
|
||||
The agent is the core loop that orchestrates LLM calls and tool execution.
|
||||
|
||||
## AIAgent Class
|
||||
|
||||
The main agent is implemented in `run_agent.py`:
|
||||
|
||||
```python
|
||||
class AIAgent:
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "anthropic/claude-sonnet-4",
|
||||
api_key: str = None,
|
||||
base_url: str = "https://openrouter.ai/api/v1",
|
||||
max_turns: int = 20,
|
||||
enabled_toolsets: list = None,
|
||||
disabled_toolsets: list = None,
|
||||
verbose_logging: bool = False,
|
||||
):
|
||||
# Initialize OpenAI client, load tools based on toolsets
|
||||
...
|
||||
|
||||
def chat(self, user_message: str, task_id: str = None) -> str:
|
||||
# Main entry point - runs the agent loop
|
||||
...
|
||||
```
|
||||
|
||||
## Agent Loop
|
||||
|
||||
The core loop in `_run_agent_loop()`:
|
||||
|
||||
```
|
||||
1. Add user message to conversation
|
||||
2. Call LLM with tools
|
||||
3. If LLM returns tool calls:
|
||||
- Execute each tool
|
||||
- Add tool results to conversation
|
||||
- Go to step 2
|
||||
4. If LLM returns text response:
|
||||
- Return response to user
|
||||
```
|
||||
|
||||
```python
|
||||
while turns < max_turns:
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tool_schemas,
|
||||
)
|
||||
|
||||
if response.tool_calls:
|
||||
for tool_call in response.tool_calls:
|
||||
result = await execute_tool(tool_call)
|
||||
messages.append(tool_result_message(result))
|
||||
turns += 1
|
||||
else:
|
||||
return response.content
|
||||
```
|
||||
|
||||
## Conversation Management
|
||||
|
||||
Messages are stored as a list of dicts following OpenAI format:
|
||||
|
||||
```python
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant..."},
|
||||
{"role": "user", "content": "Search for Python tutorials"},
|
||||
{"role": "assistant", "content": None, "tool_calls": [...]},
|
||||
{"role": "tool", "tool_call_id": "...", "content": "..."},
|
||||
{"role": "assistant", "content": "Here's what I found..."},
|
||||
]
|
||||
```
|
||||
|
||||
## Reasoning Context
|
||||
|
||||
For models that support reasoning (chain-of-thought), the agent:
|
||||
1. Extracts `reasoning_content` from API responses
|
||||
2. Stores it in `assistant_msg["reasoning"]` for trajectory export
|
||||
3. Passes it back via `reasoning_content` field on subsequent turns
|
||||
|
||||
## Trajectory Export
|
||||
|
||||
Conversations can be exported for training:
|
||||
|
||||
```python
|
||||
agent = AIAgent(save_trajectories=True)
|
||||
agent.chat("Do something")
|
||||
# Saves to trajectories/*.jsonl in ShareGPT format
|
||||
```
|
||||
|
||||
## Batch Processing
|
||||
|
||||
For processing multiple prompts, use `batch_runner.py`:
|
||||
|
||||
```bash
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=20 \
|
||||
--num_workers=4 \
|
||||
--run_name=my_run
|
||||
```
|
||||
|
||||
See `batch_runner.py` for parallel execution with checkpointing.
|
||||
296
docs/cli.md
Normal file
296
docs/cli.md
Normal file
@@ -0,0 +1,296 @@
|
||||
# CLI
|
||||
|
||||
The Hermes Agent CLI provides an interactive terminal interface for working with the agent.
|
||||
|
||||
## Running the CLI
|
||||
|
||||
```bash
|
||||
# Basic usage
|
||||
./hermes
|
||||
|
||||
# With specific model
|
||||
./hermes --model "anthropic/claude-sonnet-4"
|
||||
|
||||
# With specific toolsets
|
||||
./hermes --toolsets "web,terminal,skills"
|
||||
|
||||
# Verbose mode
|
||||
./hermes --verbose
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
The CLI is implemented in `cli.py` and uses:
|
||||
|
||||
- **Rich** - Welcome banner with ASCII art and styled panels
|
||||
- **prompt_toolkit** - Fixed input area with command history
|
||||
- **KawaiiSpinner** - Animated feedback during operations
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────┐
|
||||
│ HERMES-AGENT ASCII Logo │
|
||||
│ ┌─────────────┐ ┌────────────────────────────┐ │
|
||||
│ │ Caduceus │ │ Model: claude-opus-4.5 │ │
|
||||
│ │ ASCII Art │ │ Terminal: local │ │
|
||||
│ │ │ │ Working Dir: /home/user │ │
|
||||
│ │ │ │ Available Tools: 19 │ │
|
||||
│ │ │ │ Available Skills: 12 │ │
|
||||
│ └─────────────┘ └────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────┘
|
||||
│ Conversation output scrolls here... │
|
||||
│ │
|
||||
│ User: Hello! │
|
||||
│ ────────────────────────────────────────────── │
|
||||
│ (◕‿◕✿) 🧠 pondering... (2.3s) │
|
||||
│ ✧٩(ˊᗜˋ*)و✧ got it! (2.3s) │
|
||||
│ │
|
||||
│ Assistant: Hello! How can I help you today? │
|
||||
├─────────────────────────────────────────────────┤
|
||||
│ ❯ [Fixed input area at bottom] │
|
||||
└─────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/help` | Show available commands |
|
||||
| `/tools` | List available tools grouped by toolset |
|
||||
| `/toolsets` | List available toolsets with descriptions |
|
||||
| `/model [name]` | Show or change the current model |
|
||||
| `/prompt [text]` | View/set/clear custom system prompt |
|
||||
| `/personality [name]` | Set a predefined personality |
|
||||
| `/clear` | Clear screen and reset conversation |
|
||||
| `/reset` | Reset conversation only (keep screen) |
|
||||
| `/history` | Show conversation history |
|
||||
| `/save` | Save current conversation to file |
|
||||
| `/config` | Show current configuration |
|
||||
| `/quit` | Exit the CLI (also: `/exit`, `/q`) |
|
||||
|
||||
## Configuration
|
||||
|
||||
The CLI is configured via `cli-config.yaml`. Copy from `cli-config.yaml.example`:
|
||||
|
||||
```bash
|
||||
cp cli-config.yaml.example cli-config.yaml
|
||||
```
|
||||
|
||||
### Model Configuration
|
||||
|
||||
```yaml
|
||||
model:
|
||||
default: "anthropic/claude-opus-4.5"
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
```
|
||||
|
||||
### Terminal Configuration
|
||||
|
||||
The CLI supports multiple terminal backends:
|
||||
|
||||
```yaml
|
||||
# Local execution (default)
|
||||
terminal:
|
||||
env_type: "local"
|
||||
cwd: "." # Current directory
|
||||
|
||||
# SSH remote execution (sandboxed - agent can't touch its own code)
|
||||
terminal:
|
||||
env_type: "ssh"
|
||||
cwd: "/home/myuser/project"
|
||||
ssh_host: "my-server.example.com"
|
||||
ssh_user: "myuser"
|
||||
ssh_key: "~/.ssh/id_rsa"
|
||||
|
||||
# Docker container
|
||||
terminal:
|
||||
env_type: "docker"
|
||||
docker_image: "python:3.11"
|
||||
|
||||
# Singularity/Apptainer (HPC)
|
||||
terminal:
|
||||
env_type: "singularity"
|
||||
singularity_image: "docker://python:3.11"
|
||||
|
||||
# Modal cloud
|
||||
terminal:
|
||||
env_type: "modal"
|
||||
modal_image: "python:3.11"
|
||||
```
|
||||
|
||||
### Sudo Support
|
||||
|
||||
The CLI supports interactive sudo prompts:
|
||||
|
||||
```
|
||||
┌──────────────────────────────────────────────────────────┐
|
||||
│ 🔐 SUDO PASSWORD REQUIRED │
|
||||
├──────────────────────────────────────────────────────────┤
|
||||
│ Enter password below (input is hidden), or: │
|
||||
│ • Press Enter to skip (command fails gracefully) │
|
||||
│ • Wait 45s to auto-skip │
|
||||
└──────────────────────────────────────────────────────────┘
|
||||
|
||||
Password (hidden):
|
||||
```
|
||||
|
||||
**Options:**
|
||||
- **Interactive**: Leave `sudo_password` unset - you'll be prompted when needed
|
||||
- **Configured**: Set `sudo_password` in `cli-config.yaml` to auto-fill
|
||||
- **Environment**: Set `SUDO_PASSWORD` in `.env` for all runs
|
||||
|
||||
Password is cached for the session once entered.
|
||||
|
||||
### Toolsets
|
||||
|
||||
Control which tools are available:
|
||||
|
||||
```yaml
|
||||
# Enable all tools
|
||||
toolsets:
|
||||
- all
|
||||
|
||||
# Or enable specific toolsets
|
||||
toolsets:
|
||||
- web
|
||||
- terminal
|
||||
- skills
|
||||
```
|
||||
|
||||
Available toolsets: `web`, `search`, `terminal`, `browser`, `vision`, `image_gen`, `skills`, `moa`, `debugging`, `safe`
|
||||
|
||||
### Personalities
|
||||
|
||||
Predefined personalities for the `/personality` command:
|
||||
|
||||
```yaml
|
||||
agent:
|
||||
personalities:
|
||||
helpful: "You are a helpful, friendly AI assistant."
|
||||
kawaii: "You are a kawaii assistant! Use cute expressions..."
|
||||
pirate: "Arrr! Ye be talkin' to Captain Hermes..."
|
||||
# Add your own!
|
||||
```
|
||||
|
||||
Built-in personalities:
|
||||
- `helpful`, `concise`, `technical`, `creative`, `teacher`
|
||||
- `kawaii`, `catgirl`, `pirate`, `shakespeare`, `surfer`
|
||||
- `noir`, `uwu`, `philosopher`, `hype`
|
||||
|
||||
## Animated Feedback
|
||||
|
||||
The CLI provides animated feedback during operations:
|
||||
|
||||
### Thinking Animation
|
||||
|
||||
During API calls, shows animated spinner with thinking verbs:
|
||||
```
|
||||
◜ (。•́︿•̀。) pondering... (1.2s)
|
||||
◠ (⊙_⊙) contemplating... (2.4s)
|
||||
✧٩(ˊᗜˋ*)و✧ got it! (3.1s)
|
||||
```
|
||||
|
||||
### Tool Execution Animation
|
||||
|
||||
Each tool type has unique animations:
|
||||
```
|
||||
⠋ (◕‿◕✿) 🔍 web_search... (0.8s)
|
||||
▅ (≧◡≦) 💻 terminal... (1.2s)
|
||||
🌓 (★ω★) 🌐 browser_navigate... (2.1s)
|
||||
✧ (✿◠‿◠) 🎨 image_generate... (4.5s)
|
||||
```
|
||||
|
||||
## Multi-line Input
|
||||
|
||||
For multi-line input, end a line with `\` to continue:
|
||||
|
||||
```
|
||||
❯ Write a function that:\
|
||||
1. Takes a list of numbers\
|
||||
2. Returns the sum
|
||||
```
|
||||
|
||||
## Environment Variable Priority
|
||||
|
||||
For terminal settings, `cli-config.yaml` takes precedence over `.env`:
|
||||
|
||||
1. `cli-config.yaml` (highest priority in CLI)
|
||||
2. `.env` file
|
||||
3. System environment variables
|
||||
4. Default values
|
||||
|
||||
This allows you to have different terminal configs for CLI vs batch processing.
|
||||
|
||||
## Session Management
|
||||
|
||||
- **History**: Command history is saved to `~/.hermes_history`
|
||||
- **Conversations**: Use `/save` to export conversations
|
||||
- **Reset**: Use `/clear` for full reset, `/reset` to just clear history
|
||||
- **Session Logs**: Every session automatically logs to `logs/session_{session_id}.json`
|
||||
|
||||
### Session Logging
|
||||
|
||||
Sessions are automatically logged to the `logs/` directory:
|
||||
|
||||
```
|
||||
logs/
|
||||
├── session_20260201_143052_a1b2c3.json
|
||||
├── session_20260201_150217_d4e5f6.json
|
||||
└── ...
|
||||
```
|
||||
|
||||
The session ID is displayed in the welcome banner and follows the format: `YYYYMMDD_HHMMSS_UUID`.
|
||||
|
||||
Log files contain:
|
||||
- Full conversation history in trajectory format
|
||||
- Timestamps for session start and last update
|
||||
- Model and message count metadata
|
||||
|
||||
This is useful for:
|
||||
- Debugging agent behavior
|
||||
- Replaying conversations
|
||||
- Training data inspection
|
||||
|
||||
### Context Compression
|
||||
|
||||
Long conversations can exceed model context limits. The CLI automatically compresses context when approaching the limit:
|
||||
|
||||
```yaml
|
||||
# In cli-config.yaml
|
||||
compression:
|
||||
enabled: true # Enable auto-compression
|
||||
threshold: 0.85 # Compress at 85% of context limit
|
||||
summary_model: "google/gemini-2.0-flash-001"
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
1. Tracks actual token usage from each API response
|
||||
2. When tokens reach threshold, middle turns are summarized
|
||||
3. First 3 and last 4 turns are always protected
|
||||
4. Conversation continues seamlessly after compression
|
||||
|
||||
**When compression triggers:**
|
||||
```
|
||||
📦 Context compression triggered (170,000 tokens ≥ 170,000 threshold)
|
||||
📊 Model context limit: 200,000 tokens (85% = 170,000)
|
||||
🗜️ Summarizing turns 4-15 (12 turns)
|
||||
✅ Compressed: 20 → 9 messages (~45,000 tokens saved)
|
||||
```
|
||||
|
||||
To disable compression:
|
||||
```yaml
|
||||
compression:
|
||||
enabled: false
|
||||
```
|
||||
|
||||
## Quiet Mode
|
||||
|
||||
The CLI runs in "quiet mode" (`HERMES_QUIET=1`), which:
|
||||
- Suppresses verbose logging from tools
|
||||
- Enables kawaii-style animated feedback
|
||||
- Hides terminal environment warnings
|
||||
- Keeps output clean and user-friendly
|
||||
|
||||
For verbose output (debugging), use:
|
||||
```bash
|
||||
./hermes --verbose
|
||||
```
|
||||
124
docs/llm_client.md
Normal file
124
docs/llm_client.md
Normal file
@@ -0,0 +1,124 @@
|
||||
# LLM Client
|
||||
|
||||
Hermes Agent uses the OpenAI Python SDK with OpenRouter as the backend, providing access to many models through a single API.
|
||||
|
||||
## Configuration
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
api_key=os.getenv("OPENROUTER_API_KEY"),
|
||||
base_url="https://openrouter.ai/api/v1"
|
||||
)
|
||||
```
|
||||
|
||||
## Supported Models
|
||||
|
||||
Any model available on [OpenRouter](https://openrouter.ai/models):
|
||||
|
||||
```python
|
||||
# Anthropic
|
||||
model = "anthropic/claude-sonnet-4"
|
||||
model = "anthropic/claude-opus-4"
|
||||
|
||||
# OpenAI
|
||||
model = "openai/gpt-4o"
|
||||
model = "openai/o1"
|
||||
|
||||
# Google
|
||||
model = "google/gemini-2.0-flash"
|
||||
|
||||
# Open models
|
||||
model = "meta-llama/llama-3.3-70b-instruct"
|
||||
model = "deepseek/deepseek-chat-v3"
|
||||
model = "moonshotai/kimi-k2.5"
|
||||
```
|
||||
|
||||
## Tool Calling
|
||||
|
||||
Standard OpenAI function calling format:
|
||||
|
||||
```python
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"description": "Search the web",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
# Check for tool calls
|
||||
if response.choices[0].message.tool_calls:
|
||||
for tool_call in response.choices[0].message.tool_calls:
|
||||
name = tool_call.function.name
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
# Execute tool...
|
||||
```
|
||||
|
||||
## Reasoning Models
|
||||
|
||||
Some models return reasoning/thinking content:
|
||||
|
||||
```python
|
||||
# Access reasoning if available
|
||||
message = response.choices[0].message
|
||||
if hasattr(message, 'reasoning_content') and message.reasoning_content:
|
||||
reasoning = message.reasoning_content
|
||||
# Store for trajectory export
|
||||
```
|
||||
|
||||
## Provider Selection
|
||||
|
||||
OpenRouter allows selecting specific providers:
|
||||
|
||||
```python
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
extra_body={
|
||||
"provider": {
|
||||
"order": ["Anthropic", "Google"], # Preferred providers
|
||||
"ignore": ["Novita"], # Providers to skip
|
||||
}
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
Common errors and handling:
|
||||
|
||||
```python
|
||||
try:
|
||||
response = client.chat.completions.create(...)
|
||||
except openai.RateLimitError:
|
||||
# Back off and retry
|
||||
except openai.APIError as e:
|
||||
# Check e.code for specific errors
|
||||
# 400 = bad request (often provider-specific)
|
||||
# 502 = bad gateway (retry with different provider)
|
||||
```
|
||||
|
||||
## Cost Tracking
|
||||
|
||||
OpenRouter returns usage info:
|
||||
|
||||
```python
|
||||
usage = response.usage
|
||||
print(f"Tokens: {usage.prompt_tokens} + {usage.completion_tokens}")
|
||||
print(f"Cost: ${usage.cost:.6f}") # If available
|
||||
```
|
||||
121
docs/message_graph.md
Normal file
121
docs/message_graph.md
Normal file
@@ -0,0 +1,121 @@
|
||||
# Message Format & Trajectories
|
||||
|
||||
Hermes Agent uses two message formats: the **API format** for LLM calls and the **trajectory format** for training data export.
|
||||
|
||||
## API Message Format
|
||||
|
||||
Standard OpenAI chat format used during execution:
|
||||
|
||||
```python
|
||||
messages = [
|
||||
# System prompt
|
||||
{"role": "system", "content": "You are a helpful assistant with tools..."},
|
||||
|
||||
# User query
|
||||
{"role": "user", "content": "Search for Python tutorials"},
|
||||
|
||||
# Assistant with tool call
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [{
|
||||
"id": "call_abc123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"arguments": "{\"query\": \"Python tutorials\"}"
|
||||
}
|
||||
}]
|
||||
},
|
||||
|
||||
# Tool result
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_abc123",
|
||||
"content": "{\"results\": [...]}"
|
||||
},
|
||||
|
||||
# Final response
|
||||
{"role": "assistant", "content": "Here's what I found..."}
|
||||
]
|
||||
```
|
||||
|
||||
## Trajectory Format (ShareGPT)
|
||||
|
||||
Exported for training in ShareGPT format:
|
||||
|
||||
```json
|
||||
{
|
||||
"conversations": [
|
||||
{"from": "system", "value": "You are a helpful assistant..."},
|
||||
{"from": "human", "value": "Search for Python tutorials"},
|
||||
{"from": "gpt", "value": "<tool_call>\n{\"name\": \"web_search\", \"arguments\": {\"query\": \"Python tutorials\"}}\n</tool_call>"},
|
||||
{"from": "tool", "value": "<tool_response>\n{\"results\": [...]}\n</tool_response>"},
|
||||
{"from": "gpt", "value": "Here's what I found..."}
|
||||
],
|
||||
"tools": "[{\"type\": \"function\", \"function\": {...}}]",
|
||||
"source": "hermes-agent"
|
||||
}
|
||||
```
|
||||
|
||||
## Reasoning Content
|
||||
|
||||
For models that output reasoning/chain-of-thought:
|
||||
|
||||
**During execution** (API format):
|
||||
```python
|
||||
# Stored internally but not sent back to model in content
|
||||
assistant_msg = {
|
||||
"role": "assistant",
|
||||
"content": "Here's what I found...",
|
||||
"reasoning": "Let me think about this step by step..." # Internal only
|
||||
}
|
||||
```
|
||||
|
||||
**In trajectory export** (reasoning wrapped in tags):
|
||||
```json
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "<think>\nLet me think about this step by step...\n</think>\nHere's what I found..."
|
||||
}
|
||||
```
|
||||
|
||||
## Conversion Flow
|
||||
|
||||
```
|
||||
API Response → Internal Storage → Trajectory Export
|
||||
↓ ↓ ↓
|
||||
tool_calls reasoning field <tool_call> tags
|
||||
reasoning_content <think> tags
|
||||
```
|
||||
|
||||
The conversion happens in `_convert_to_trajectory_format()` in `run_agent.py`.
|
||||
|
||||
## Ephemeral System Prompts
|
||||
|
||||
Batch processing supports ephemeral system prompts that guide behavior during execution but are NOT saved to trajectories:
|
||||
|
||||
```python
|
||||
# During execution: full system prompt + ephemeral guidance
|
||||
messages = [
|
||||
{"role": "system", "content": SYSTEM_PROMPT + "\n\n" + ephemeral_prompt},
|
||||
...
|
||||
]
|
||||
|
||||
# In saved trajectory: only the base system prompt
|
||||
trajectory = {
|
||||
"conversations": [
|
||||
{"from": "system", "value": SYSTEM_PROMPT}, # No ephemeral
|
||||
...
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Trajectory Compression
|
||||
|
||||
Long trajectories can be compressed for training using `trajectory_compressor.py`:
|
||||
|
||||
- Protects first/last N turns
|
||||
- Summarizes middle turns with LLM
|
||||
- Targets specific token budget
|
||||
- See `configs/trajectory_compression.yaml` for settings
|
||||
515
docs/messaging.md
Normal file
515
docs/messaging.md
Normal file
@@ -0,0 +1,515 @@
|
||||
# Messaging Platform Integrations (Gateway)
|
||||
|
||||
Hermes Agent can connect to messaging platforms like Telegram, Discord, and WhatsApp to serve as a conversational AI assistant.
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# 1. Set your bot token(s) in .env file
|
||||
echo 'TELEGRAM_BOT_TOKEN="your_telegram_bot_token"' >> .env
|
||||
echo 'DISCORD_BOT_TOKEN="your_discord_bot_token"' >> .env
|
||||
|
||||
# 2. Test the gateway (foreground)
|
||||
./scripts/hermes-gateway run
|
||||
|
||||
# 3. Install as a system service (runs in background)
|
||||
./scripts/hermes-gateway install
|
||||
|
||||
# 4. Manage the service
|
||||
./scripts/hermes-gateway start
|
||||
./scripts/hermes-gateway stop
|
||||
./scripts/hermes-gateway restart
|
||||
./scripts/hermes-gateway status
|
||||
```
|
||||
|
||||
**Quick test (without service install):**
|
||||
```bash
|
||||
python cli.py --gateway # Runs in foreground, useful for debugging
|
||||
```
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ Hermes Gateway │
|
||||
├─────────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
|
||||
│ │ Telegram │ │ Discord │ │ WhatsApp │ │
|
||||
│ │ Adapter │ │ Adapter │ │ Adapter │ │
|
||||
│ └──────┬───────┘ └──────┬───────┘ └──────┬───────┘ │
|
||||
│ │ │ │ │
|
||||
│ └─────────────────┼─────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌────────▼────────┐ │
|
||||
│ │ Session Store │ │
|
||||
│ │ (per-chat) │ │
|
||||
│ └────────┬────────┘ │
|
||||
│ │ │
|
||||
│ ┌────────▼────────┐ │
|
||||
│ │ AIAgent │ │
|
||||
│ │ (run_agent) │ │
|
||||
│ └─────────────────┘ │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Session Management
|
||||
|
||||
### Session Persistence
|
||||
|
||||
Sessions persist across messages until they reset. The agent remembers your conversation context.
|
||||
|
||||
### Reset Policies
|
||||
|
||||
Sessions reset based on configurable policies:
|
||||
|
||||
| Policy | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| Daily | 4:00 AM | Reset at a specific hour each day |
|
||||
| Idle | 120 min | Reset after N minutes of inactivity |
|
||||
| Both | (combined) | Whichever triggers first |
|
||||
|
||||
### Manual Reset
|
||||
|
||||
Send `/new` or `/reset` as a message to start fresh.
|
||||
|
||||
### Per-Platform Overrides
|
||||
|
||||
Configure different reset policies per platform:
|
||||
|
||||
```json
|
||||
{
|
||||
"reset_by_platform": {
|
||||
"telegram": { "mode": "idle", "idle_minutes": 240 },
|
||||
"discord": { "mode": "idle", "idle_minutes": 60 }
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Platform Setup
|
||||
|
||||
### Telegram
|
||||
|
||||
1. **Create a bot** via [@BotFather](https://t.me/BotFather)
|
||||
2. **Get your token** (looks like `123456789:ABCdefGHIjklMNOpqrsTUVwxyz`)
|
||||
3. **Set environment variable:**
|
||||
```bash
|
||||
export TELEGRAM_BOT_TOKEN="your_token_here"
|
||||
```
|
||||
4. **Optional: Set home channel** for cron job delivery:
|
||||
```bash
|
||||
export TELEGRAM_HOME_CHANNEL="-1001234567890"
|
||||
export TELEGRAM_HOME_CHANNEL_NAME="My Notes"
|
||||
```
|
||||
|
||||
**Requirements:**
|
||||
```bash
|
||||
pip install python-telegram-bot>=20.0
|
||||
```
|
||||
|
||||
### Discord
|
||||
|
||||
1. **Create an application** at [Discord Developer Portal](https://discord.com/developers/applications)
|
||||
2. **Create a bot** under your application
|
||||
3. **Get the bot token**
|
||||
4. **Enable required intents:**
|
||||
- Message Content Intent
|
||||
- Server Members Intent (optional)
|
||||
5. **Invite to your server** using OAuth2 URL generator (scopes: `bot`, `applications.commands`)
|
||||
6. **Set environment variable:**
|
||||
```bash
|
||||
export DISCORD_BOT_TOKEN="your_token_here"
|
||||
```
|
||||
7. **Optional: Set home channel:**
|
||||
```bash
|
||||
export DISCORD_HOME_CHANNEL="123456789012345678"
|
||||
export DISCORD_HOME_CHANNEL_NAME="#bot-updates"
|
||||
```
|
||||
|
||||
**Requirements:**
|
||||
```bash
|
||||
pip install discord.py>=2.0
|
||||
```
|
||||
|
||||
### WhatsApp
|
||||
|
||||
WhatsApp integration is more complex due to the lack of a simple bot API.
|
||||
|
||||
**Options:**
|
||||
1. **WhatsApp Business API** (requires Meta verification)
|
||||
2. **whatsapp-web.js** via Node.js bridge (for personal accounts)
|
||||
|
||||
**Bridge Setup:**
|
||||
1. Install Node.js
|
||||
2. Set up the bridge script (see `scripts/whatsapp-bridge/` for reference)
|
||||
3. Configure in gateway:
|
||||
```json
|
||||
{
|
||||
"platforms": {
|
||||
"whatsapp": {
|
||||
"enabled": true,
|
||||
"extra": {
|
||||
"bridge_script": "/path/to/bridge.js",
|
||||
"bridge_port": 3000
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
There are **three ways** to configure the gateway (in order of precedence):
|
||||
|
||||
### 1. Environment Variables (`.env` file) - Recommended for Quick Setup
|
||||
|
||||
Add to your `~/.hermes/.env` file:
|
||||
|
||||
```bash
|
||||
# =============================================================================
|
||||
# MESSAGING PLATFORM TOKENS
|
||||
# =============================================================================
|
||||
|
||||
# Telegram - get from @BotFather on Telegram
|
||||
TELEGRAM_BOT_TOKEN=your_telegram_bot_token
|
||||
TELEGRAM_ALLOWED_USERS=123456789,987654321 # Security: restrict to these user IDs
|
||||
|
||||
# Optional: Default channel for cron job delivery
|
||||
TELEGRAM_HOME_CHANNEL=-1001234567890
|
||||
TELEGRAM_HOME_CHANNEL_NAME="My Notes"
|
||||
|
||||
# Discord - get from Discord Developer Portal
|
||||
DISCORD_BOT_TOKEN=your_discord_bot_token
|
||||
DISCORD_ALLOWED_USERS=123456789012345678 # Security: restrict to these user IDs
|
||||
|
||||
# Optional: Default channel for cron job delivery
|
||||
DISCORD_HOME_CHANNEL=123456789012345678
|
||||
DISCORD_HOME_CHANNEL_NAME="#bot-updates"
|
||||
|
||||
# WhatsApp - requires Node.js bridge setup
|
||||
WHATSAPP_ENABLED=true
|
||||
|
||||
# =============================================================================
|
||||
# AGENT SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Max tool-calling iterations per conversation (default: 60)
|
||||
HERMES_MAX_ITERATIONS=60
|
||||
|
||||
# Working directory for terminal commands (default: home ~)
|
||||
MESSAGING_CWD=/home/myuser
|
||||
|
||||
# =============================================================================
|
||||
# TOOL PROGRESS NOTIFICATIONS
|
||||
# =============================================================================
|
||||
|
||||
# Show progress messages as agent uses tools
|
||||
HERMES_TOOL_PROGRESS=true
|
||||
|
||||
# Mode: "new" (only when tool changes) or "all" (every tool call)
|
||||
HERMES_TOOL_PROGRESS_MODE=new
|
||||
|
||||
# =============================================================================
|
||||
# SESSION SETTINGS
|
||||
# =============================================================================
|
||||
|
||||
# Reset sessions after N minutes of inactivity (default: 120)
|
||||
SESSION_IDLE_MINUTES=120
|
||||
|
||||
# Daily reset hour in 24h format (default: 4 = 4am)
|
||||
SESSION_RESET_HOUR=4
|
||||
```
|
||||
|
||||
### 2. Gateway Config File (`~/.hermes/gateway.json`) - Full Control
|
||||
|
||||
For advanced configuration, create `~/.hermes/gateway.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"platforms": {
|
||||
"telegram": {
|
||||
"enabled": true,
|
||||
"token": "your_telegram_token",
|
||||
"home_channel": {
|
||||
"platform": "telegram",
|
||||
"chat_id": "-1001234567890",
|
||||
"name": "My Notes"
|
||||
}
|
||||
},
|
||||
"discord": {
|
||||
"enabled": true,
|
||||
"token": "your_discord_token",
|
||||
"home_channel": {
|
||||
"platform": "discord",
|
||||
"chat_id": "123456789012345678",
|
||||
"name": "#bot-updates"
|
||||
}
|
||||
}
|
||||
},
|
||||
"default_reset_policy": {
|
||||
"mode": "both",
|
||||
"at_hour": 4,
|
||||
"idle_minutes": 120
|
||||
},
|
||||
"reset_by_platform": {
|
||||
"discord": {
|
||||
"mode": "idle",
|
||||
"idle_minutes": 60
|
||||
}
|
||||
},
|
||||
"always_log_local": true
|
||||
}
|
||||
```
|
||||
|
||||
## Platform-Specific Toolsets
|
||||
|
||||
Each platform has its own toolset for security:
|
||||
|
||||
| Platform | Toolset | Capabilities |
|
||||
|----------|---------|--------------|
|
||||
| CLI | `hermes-cli` | Full access (terminal, browser, etc.) |
|
||||
| Telegram | `hermes-telegram` | Full tools including terminal |
|
||||
| Discord | `hermes-discord` | Full tools including terminal |
|
||||
| WhatsApp | `hermes-whatsapp` | Full tools including terminal |
|
||||
|
||||
## User Experience Features
|
||||
|
||||
### Typing Indicator
|
||||
|
||||
The gateway keeps the "typing..." indicator active throughout processing, refreshing every 4 seconds. This lets users know the bot is working even during long tool-calling sequences.
|
||||
|
||||
### Tool Progress Notifications
|
||||
|
||||
When `HERMES_TOOL_PROGRESS=true`, the bot sends status messages as it works:
|
||||
|
||||
```
|
||||
💻 `ls -la`...
|
||||
🔍 web_search...
|
||||
📄 web_extract...
|
||||
🎨 image_generate...
|
||||
```
|
||||
|
||||
Terminal commands show the actual command (truncated to 50 chars). Other tools just show the tool name.
|
||||
|
||||
**Modes:**
|
||||
- `new`: Only sends message when switching to a different tool (less spam)
|
||||
- `all`: Sends message for every single tool call
|
||||
|
||||
### Working Directory
|
||||
|
||||
- **CLI (`hermes` command)**: Uses current directory where you run the command
|
||||
- **Messaging**: Uses `MESSAGING_CWD` (default: home directory `~`)
|
||||
|
||||
This is intentional: CLI users are in a terminal and expect the agent to work in their current directory, while messaging users need a consistent starting location.
|
||||
|
||||
### Max Iterations
|
||||
|
||||
If the agent hits the max iteration limit while working, instead of a generic error, it asks the model to summarize what it found so far. This gives you a useful response even when the task couldn't be fully completed.
|
||||
|
||||
## Cron Job Delivery
|
||||
|
||||
When scheduling cron jobs, you can specify where the output should be delivered:
|
||||
|
||||
```
|
||||
User: "Remind me to check the server in 30 minutes"
|
||||
|
||||
Agent uses: schedule_cronjob(
|
||||
prompt="Check server status...",
|
||||
schedule="30m",
|
||||
deliver="origin" # Back to this chat
|
||||
)
|
||||
```
|
||||
|
||||
### Delivery Options
|
||||
|
||||
| Option | Description |
|
||||
|--------|-------------|
|
||||
| `"origin"` | Back to where the job was created |
|
||||
| `"local"` | Save to local files only |
|
||||
| `"telegram"` | Telegram home channel |
|
||||
| `"discord"` | Discord home channel |
|
||||
| `"telegram:123456"` | Specific Telegram chat |
|
||||
|
||||
## Dynamic Context Injection
|
||||
|
||||
The agent knows where it is via injected context:
|
||||
|
||||
```
|
||||
## Current Session Context
|
||||
|
||||
**Source:** Telegram (group: Dev Team, ID: -1001234567890)
|
||||
**Connected Platforms:** local, telegram, discord
|
||||
|
||||
**Home Channels:**
|
||||
- telegram: My Notes (ID: -1001234567890)
|
||||
- discord: #bot-updates (ID: 123456789012345678)
|
||||
|
||||
**Delivery options for scheduled tasks:**
|
||||
- "origin" → Back to this chat (Dev Team)
|
||||
- "local" → Save to local files only
|
||||
- "telegram" → Home channel (My Notes)
|
||||
- "discord" → Home channel (#bot-updates)
|
||||
```
|
||||
|
||||
## CLI Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/platforms` | Show gateway configuration and status |
|
||||
| `--gateway` | Start the gateway (CLI flag) |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "python-telegram-bot not installed"
|
||||
|
||||
```bash
|
||||
pip install python-telegram-bot>=20.0
|
||||
```
|
||||
|
||||
### "discord.py not installed"
|
||||
|
||||
```bash
|
||||
pip install discord.py>=2.0
|
||||
```
|
||||
|
||||
### "No platforms connected"
|
||||
|
||||
1. Check your environment variables are set
|
||||
2. Check your tokens are valid
|
||||
3. Try `/platforms` to see configuration status
|
||||
|
||||
### Session not persisting
|
||||
|
||||
1. Check `~/.hermes/sessions/` exists
|
||||
2. Check session policies aren't too aggressive
|
||||
3. Verify no errors in gateway logs
|
||||
|
||||
## Adding a New Platform
|
||||
|
||||
To add a new messaging platform:
|
||||
|
||||
### 1. Create the adapter
|
||||
|
||||
Create `gateway/platforms/your_platform.py`:
|
||||
|
||||
```python
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
class YourPlatformAdapter(BasePlatformAdapter):
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.YOUR_PLATFORM)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
# Connect to the platform
|
||||
...
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
# Disconnect
|
||||
...
|
||||
|
||||
async def send(self, chat_id: str, content: str, ...) -> SendResult:
|
||||
# Send a message
|
||||
...
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
# Get chat information
|
||||
...
|
||||
```
|
||||
|
||||
### 2. Register the platform
|
||||
|
||||
Add to `gateway/config.py`:
|
||||
|
||||
```python
|
||||
class Platform(Enum):
|
||||
# ... existing ...
|
||||
YOUR_PLATFORM = "your_platform"
|
||||
```
|
||||
|
||||
### 3. Add to gateway runner
|
||||
|
||||
Update `gateway/run.py` `_create_adapter()`:
|
||||
|
||||
```python
|
||||
elif platform == Platform.YOUR_PLATFORM:
|
||||
from gateway.platforms.your_platform import YourPlatformAdapter
|
||||
return YourPlatformAdapter(config)
|
||||
```
|
||||
|
||||
### 4. Create a toolset (optional)
|
||||
|
||||
Add to `toolsets.py`:
|
||||
|
||||
```python
|
||||
"hermes-your-platform": {
|
||||
"description": "Your platform toolset",
|
||||
"tools": [...],
|
||||
"includes": []
|
||||
}
|
||||
```
|
||||
|
||||
### 5. Configure
|
||||
|
||||
Add environment variables to `.env`:
|
||||
|
||||
```bash
|
||||
YOUR_PLATFORM_TOKEN=...
|
||||
YOUR_PLATFORM_HOME_CHANNEL=...
|
||||
```
|
||||
|
||||
## Service Management
|
||||
|
||||
### Linux (systemd)
|
||||
|
||||
```bash
|
||||
# Install as user service
|
||||
./scripts/hermes-gateway install
|
||||
|
||||
# Manage
|
||||
systemctl --user start hermes-gateway
|
||||
systemctl --user stop hermes-gateway
|
||||
systemctl --user restart hermes-gateway
|
||||
systemctl --user status hermes-gateway
|
||||
|
||||
# View logs
|
||||
journalctl --user -u hermes-gateway -f
|
||||
|
||||
# Enable lingering (keeps running after logout)
|
||||
sudo loginctl enable-linger $USER
|
||||
```
|
||||
|
||||
### macOS (launchd)
|
||||
|
||||
```bash
|
||||
# Install
|
||||
./scripts/hermes-gateway install
|
||||
|
||||
# Manage
|
||||
launchctl start ai.hermes.gateway
|
||||
launchctl stop ai.hermes.gateway
|
||||
|
||||
# View logs
|
||||
tail -f ~/.hermes/logs/gateway.log
|
||||
```
|
||||
|
||||
### Manual (any platform)
|
||||
|
||||
```bash
|
||||
# Run in foreground (for testing/debugging)
|
||||
./scripts/hermes-gateway run
|
||||
|
||||
# Or via CLI (also foreground)
|
||||
python cli.py --gateway
|
||||
```
|
||||
|
||||
## Storage Locations
|
||||
|
||||
| Path | Purpose |
|
||||
|------|---------|
|
||||
| `~/.hermes/gateway.json` | Gateway configuration |
|
||||
| `~/.hermes/sessions/sessions.json` | Session index |
|
||||
| `~/.hermes/sessions/{id}.jsonl` | Conversation transcripts |
|
||||
| `~/.hermes/cron/output/` | Cron job outputs |
|
||||
| `~/.hermes/logs/gateway.log` | Gateway logs (macOS launchd) |
|
||||
159
docs/tools.md
Normal file
159
docs/tools.md
Normal file
@@ -0,0 +1,159 @@
|
||||
# Tools
|
||||
|
||||
Tools are functions that extend the agent's capabilities. Each tool is defined with an OpenAI-compatible JSON schema and an async handler function.
|
||||
|
||||
## Tool Structure
|
||||
|
||||
Each tool module in `tools/` exports:
|
||||
1. **Schema definitions** - OpenAI function-calling format
|
||||
2. **Handler functions** - Async functions that execute the tool
|
||||
|
||||
```python
|
||||
# Example: tools/web_tools.py
|
||||
|
||||
# Schema definition
|
||||
WEB_SEARCH_SCHEMA = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "web_search",
|
||||
"description": "Search the web for information",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Handler function
|
||||
async def web_search(query: str) -> dict:
|
||||
"""Execute web search and return results."""
|
||||
# Implementation...
|
||||
return {"results": [...]}
|
||||
```
|
||||
|
||||
## Tool Categories
|
||||
|
||||
| Category | Module | Tools |
|
||||
|----------|--------|-------|
|
||||
| **Web** | `web_tools.py` | `web_search`, `web_extract`, `web_crawl` |
|
||||
| **Terminal** | `terminal_tool.py` | `terminal` (local/docker/singularity/modal/ssh backends) |
|
||||
| **Browser** | `browser_tool.py` | `browser_navigate`, `browser_click`, `browser_type`, etc. |
|
||||
| **Vision** | `vision_tools.py` | `vision_analyze` |
|
||||
| **Image Gen** | `image_generation_tool.py` | `image_generate` |
|
||||
| **Reasoning** | `mixture_of_agents_tool.py` | `mixture_of_agents` |
|
||||
| **Skills** | `skills_tool.py` | `skills_categories`, `skills_list`, `skill_view` |
|
||||
|
||||
## Tool Registration
|
||||
|
||||
Tools are registered in `model_tools.py`:
|
||||
|
||||
```python
|
||||
# model_tools.py
|
||||
TOOL_SCHEMAS = [
|
||||
*WEB_TOOL_SCHEMAS,
|
||||
*TERMINAL_TOOL_SCHEMAS,
|
||||
*BROWSER_TOOL_SCHEMAS,
|
||||
# ...
|
||||
]
|
||||
|
||||
TOOL_HANDLERS = {
|
||||
"web_search": web_search,
|
||||
"terminal": terminal_tool,
|
||||
"browser_navigate": browser_navigate,
|
||||
# ...
|
||||
}
|
||||
```
|
||||
|
||||
## Toolsets
|
||||
|
||||
Tools are grouped into **toolsets** for logical organization (see `toolsets.py`):
|
||||
|
||||
```python
|
||||
TOOLSETS = {
|
||||
"web": {
|
||||
"description": "Web search and content extraction",
|
||||
"tools": ["web_search", "web_extract", "web_crawl"]
|
||||
},
|
||||
"terminal": {
|
||||
"description": "Command execution",
|
||||
"tools": ["terminal"]
|
||||
},
|
||||
# ...
|
||||
}
|
||||
```
|
||||
|
||||
## Adding a New Tool
|
||||
|
||||
1. Create handler function in `tools/your_tool.py`
|
||||
2. Define JSON schema following OpenAI format
|
||||
3. Register in `model_tools.py` (schemas and handlers)
|
||||
4. Add to appropriate toolset in `toolsets.py`
|
||||
5. Update `tools/__init__.py` exports
|
||||
|
||||
## Stateful Tools
|
||||
|
||||
Some tools maintain state across calls within a session:
|
||||
|
||||
- **Terminal**: Keeps container/sandbox running between commands
|
||||
- **Browser**: Maintains browser session for multi-step navigation
|
||||
|
||||
State is managed per `task_id` and cleaned up automatically.
|
||||
|
||||
## Terminal Backends
|
||||
|
||||
The terminal tool supports multiple execution backends:
|
||||
|
||||
| Backend | Description | Use Case |
|
||||
|---------|-------------|----------|
|
||||
| `local` | Direct execution on host | Development, simple tasks |
|
||||
| `ssh` | Remote execution via SSH | Sandboxing (agent can't modify its own code) |
|
||||
| `docker` | Docker container | Isolation, reproducibility |
|
||||
| `singularity` | Singularity/Apptainer | HPC clusters, rootless containers |
|
||||
| `modal` | Modal cloud | Scalable cloud compute, GPUs |
|
||||
|
||||
Configure via environment variables or `cli-config.yaml`:
|
||||
|
||||
```yaml
|
||||
# SSH backend example (in cli-config.yaml)
|
||||
terminal:
|
||||
env_type: "ssh"
|
||||
ssh_host: "my-server.example.com"
|
||||
ssh_user: "myuser"
|
||||
ssh_key: "~/.ssh/id_rsa"
|
||||
cwd: "/home/myuser/project"
|
||||
```
|
||||
|
||||
The SSH backend uses ControlMaster for connection persistence, making subsequent commands fast.
|
||||
|
||||
## Skills Tools (Progressive Disclosure)
|
||||
|
||||
Skills are on-demand knowledge documents. They use **progressive disclosure** to minimize tokens:
|
||||
|
||||
```
|
||||
Level 0: skills_categories() → ["mlops", "devops"] (~50 tokens)
|
||||
Level 1: skills_list(category) → [{name, description}, ...] (~3k tokens)
|
||||
Level 2: skill_view(name) → Full content + metadata (varies)
|
||||
Level 3: skill_view(name, path) → Specific reference file (varies)
|
||||
```
|
||||
|
||||
Skill directory structure:
|
||||
```
|
||||
skills/
|
||||
└── mlops/
|
||||
└── axolotl/
|
||||
├── SKILL.md # Main instructions (required)
|
||||
├── references/ # Additional docs
|
||||
└── templates/ # Output formats, configs
|
||||
```
|
||||
|
||||
SKILL.md uses YAML frontmatter:
|
||||
```yaml
|
||||
---
|
||||
name: axolotl
|
||||
description: Fine-tuning LLMs with Axolotl
|
||||
tags: [Fine-Tuning, LoRA, DPO]
|
||||
---
|
||||
```
|
||||
70
example-skill/SKILL.md
Normal file
70
example-skill/SKILL.md
Normal file
@@ -0,0 +1,70 @@
|
||||
---
|
||||
name: example-skill
|
||||
description: An example skill demonstrating the skill file format and structure
|
||||
---
|
||||
|
||||
# Example Skill
|
||||
|
||||
This is an example skill file that demonstrates how to create skills for the Hermes Agent.
|
||||
|
||||
## Skill File Format
|
||||
|
||||
Skills are markdown files with YAML frontmatter at the top:
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: your-skill-name
|
||||
description: A brief one-line description of what this skill does
|
||||
---
|
||||
```
|
||||
|
||||
The frontmatter fields:
|
||||
- **name**: The identifier used to reference this skill (lowercase, hyphens for spaces)
|
||||
- **description**: A brief description shown when listing skills (keep under 200 chars)
|
||||
|
||||
## Writing Effective Skills
|
||||
|
||||
### 1. Be Specific and Actionable
|
||||
|
||||
Good skills provide clear, actionable instructions:
|
||||
|
||||
```
|
||||
When reviewing code:
|
||||
1. Check for security vulnerabilities first
|
||||
2. Verify error handling is comprehensive
|
||||
3. Ensure tests cover edge cases
|
||||
```
|
||||
|
||||
### 2. Include Examples
|
||||
|
||||
Show concrete examples of what you want:
|
||||
|
||||
```python
|
||||
# Good: Descriptive variable names
|
||||
user_authentication_token = get_token()
|
||||
|
||||
# Bad: Cryptic abbreviations
|
||||
uat = gt()
|
||||
```
|
||||
|
||||
### 3. Define When to Use
|
||||
|
||||
Help the agent understand when this skill applies:
|
||||
|
||||
> Use this skill when: reviewing pull requests, auditing security, or checking code quality.
|
||||
|
||||
## Skill Categories
|
||||
|
||||
Consider organizing skills by purpose:
|
||||
|
||||
- **Conventions**: Coding standards, API patterns, naming rules
|
||||
- **Workflows**: Step-by-step processes for deployments, reviews, releases
|
||||
- **Knowledge**: Domain-specific information, system architecture, gotchas
|
||||
- **Templates**: Boilerplate for common tasks, response formats
|
||||
|
||||
## Tips
|
||||
|
||||
1. Keep the description concise - it's shown in the skills list
|
||||
2. Use headers to organize longer skills
|
||||
3. Include code examples where helpful
|
||||
4. Reference other skills if they're related
|
||||
35
gateway/__init__.py
Normal file
35
gateway/__init__.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Hermes Gateway - Multi-platform messaging integration.
|
||||
|
||||
This module provides a unified gateway for connecting the Hermes agent
|
||||
to various messaging platforms (Telegram, Discord, WhatsApp) with:
|
||||
- Session management (persistent conversations with reset policies)
|
||||
- Dynamic context injection (agent knows where messages come from)
|
||||
- Delivery routing (cron job outputs to appropriate channels)
|
||||
- Platform-specific toolsets (different capabilities per platform)
|
||||
"""
|
||||
|
||||
from .config import GatewayConfig, PlatformConfig, HomeChannel, load_gateway_config
|
||||
from .session import (
|
||||
SessionContext,
|
||||
SessionStore,
|
||||
SessionResetPolicy,
|
||||
build_session_context_prompt,
|
||||
)
|
||||
from .delivery import DeliveryRouter, DeliveryTarget
|
||||
|
||||
__all__ = [
|
||||
# Config
|
||||
"GatewayConfig",
|
||||
"PlatformConfig",
|
||||
"HomeChannel",
|
||||
"load_gateway_config",
|
||||
# Session
|
||||
"SessionContext",
|
||||
"SessionStore",
|
||||
"SessionResetPolicy",
|
||||
"build_session_context_prompt",
|
||||
# Delivery
|
||||
"DeliveryRouter",
|
||||
"DeliveryTarget",
|
||||
]
|
||||
333
gateway/config.py
Normal file
333
gateway/config.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""
|
||||
Gateway configuration management.
|
||||
|
||||
Handles loading and validating configuration for:
|
||||
- Connected platforms (Telegram, Discord, WhatsApp)
|
||||
- Home channels for each platform
|
||||
- Session reset policies
|
||||
- Delivery preferences
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Any
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Platform(Enum):
|
||||
"""Supported messaging platforms."""
|
||||
LOCAL = "local"
|
||||
TELEGRAM = "telegram"
|
||||
DISCORD = "discord"
|
||||
WHATSAPP = "whatsapp"
|
||||
|
||||
|
||||
@dataclass
|
||||
class HomeChannel:
|
||||
"""
|
||||
Default destination for a platform.
|
||||
|
||||
When a cron job specifies deliver="telegram" without a specific chat ID,
|
||||
messages are sent to this home channel.
|
||||
"""
|
||||
platform: Platform
|
||||
chat_id: str
|
||||
name: str # Human-readable name for display
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"platform": self.platform.value,
|
||||
"chat_id": self.chat_id,
|
||||
"name": self.name,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "HomeChannel":
|
||||
return cls(
|
||||
platform=Platform(data["platform"]),
|
||||
chat_id=str(data["chat_id"]),
|
||||
name=data.get("name", "Home"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionResetPolicy:
|
||||
"""
|
||||
Controls when sessions reset (lose context).
|
||||
|
||||
Modes:
|
||||
- "daily": Reset at a specific hour each day
|
||||
- "idle": Reset after N minutes of inactivity
|
||||
- "both": Whichever triggers first (daily boundary OR idle timeout)
|
||||
"""
|
||||
mode: str = "both" # "daily", "idle", or "both"
|
||||
at_hour: int = 4 # Hour for daily reset (0-23, local time)
|
||||
idle_minutes: int = 120 # Minutes of inactivity before reset
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"mode": self.mode,
|
||||
"at_hour": self.at_hour,
|
||||
"idle_minutes": self.idle_minutes,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SessionResetPolicy":
|
||||
return cls(
|
||||
mode=data.get("mode", "both"),
|
||||
at_hour=data.get("at_hour", 4),
|
||||
idle_minutes=data.get("idle_minutes", 120),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlatformConfig:
|
||||
"""Configuration for a single messaging platform."""
|
||||
enabled: bool = False
|
||||
token: Optional[str] = None # Bot token (Telegram, Discord)
|
||||
api_key: Optional[str] = None # API key if different from token
|
||||
home_channel: Optional[HomeChannel] = None
|
||||
|
||||
# Platform-specific settings
|
||||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result = {
|
||||
"enabled": self.enabled,
|
||||
"extra": self.extra,
|
||||
}
|
||||
if self.token:
|
||||
result["token"] = self.token
|
||||
if self.api_key:
|
||||
result["api_key"] = self.api_key
|
||||
if self.home_channel:
|
||||
result["home_channel"] = self.home_channel.to_dict()
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "PlatformConfig":
|
||||
home_channel = None
|
||||
if "home_channel" in data:
|
||||
home_channel = HomeChannel.from_dict(data["home_channel"])
|
||||
|
||||
return cls(
|
||||
enabled=data.get("enabled", False),
|
||||
token=data.get("token"),
|
||||
api_key=data.get("api_key"),
|
||||
home_channel=home_channel,
|
||||
extra=data.get("extra", {}),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GatewayConfig:
|
||||
"""
|
||||
Main gateway configuration.
|
||||
|
||||
Manages all platform connections, session policies, and delivery settings.
|
||||
"""
|
||||
# Platform configurations
|
||||
platforms: Dict[Platform, PlatformConfig] = field(default_factory=dict)
|
||||
|
||||
# Session reset policies by type
|
||||
default_reset_policy: SessionResetPolicy = field(default_factory=SessionResetPolicy)
|
||||
reset_by_type: Dict[str, SessionResetPolicy] = field(default_factory=dict)
|
||||
reset_by_platform: Dict[Platform, SessionResetPolicy] = field(default_factory=dict)
|
||||
|
||||
# Reset trigger commands
|
||||
reset_triggers: List[str] = field(default_factory=lambda: ["/new", "/reset"])
|
||||
|
||||
# Storage paths
|
||||
sessions_dir: Path = field(default_factory=lambda: Path.home() / ".hermes" / "sessions")
|
||||
|
||||
# Delivery settings
|
||||
always_log_local: bool = True # Always save cron outputs to local files
|
||||
|
||||
def get_connected_platforms(self) -> List[Platform]:
|
||||
"""Return list of platforms that are enabled and configured."""
|
||||
connected = []
|
||||
for platform, config in self.platforms.items():
|
||||
if config.enabled and (config.token or config.api_key):
|
||||
connected.append(platform)
|
||||
return connected
|
||||
|
||||
def get_home_channel(self, platform: Platform) -> Optional[HomeChannel]:
|
||||
"""Get the home channel for a platform."""
|
||||
config = self.platforms.get(platform)
|
||||
if config:
|
||||
return config.home_channel
|
||||
return None
|
||||
|
||||
def get_reset_policy(
|
||||
self,
|
||||
platform: Optional[Platform] = None,
|
||||
session_type: Optional[str] = None
|
||||
) -> SessionResetPolicy:
|
||||
"""
|
||||
Get the appropriate reset policy for a session.
|
||||
|
||||
Priority: platform override > type override > default
|
||||
"""
|
||||
# Platform-specific override takes precedence
|
||||
if platform and platform in self.reset_by_platform:
|
||||
return self.reset_by_platform[platform]
|
||||
|
||||
# Type-specific override (dm, group, thread)
|
||||
if session_type and session_type in self.reset_by_type:
|
||||
return self.reset_by_type[session_type]
|
||||
|
||||
return self.default_reset_policy
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"platforms": {
|
||||
p.value: c.to_dict() for p, c in self.platforms.items()
|
||||
},
|
||||
"default_reset_policy": self.default_reset_policy.to_dict(),
|
||||
"reset_by_type": {
|
||||
k: v.to_dict() for k, v in self.reset_by_type.items()
|
||||
},
|
||||
"reset_by_platform": {
|
||||
p.value: v.to_dict() for p, v in self.reset_by_platform.items()
|
||||
},
|
||||
"reset_triggers": self.reset_triggers,
|
||||
"sessions_dir": str(self.sessions_dir),
|
||||
"always_log_local": self.always_log_local,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "GatewayConfig":
|
||||
platforms = {}
|
||||
for platform_name, platform_data in data.get("platforms", {}).items():
|
||||
try:
|
||||
platform = Platform(platform_name)
|
||||
platforms[platform] = PlatformConfig.from_dict(platform_data)
|
||||
except ValueError:
|
||||
pass # Skip unknown platforms
|
||||
|
||||
reset_by_type = {}
|
||||
for type_name, policy_data in data.get("reset_by_type", {}).items():
|
||||
reset_by_type[type_name] = SessionResetPolicy.from_dict(policy_data)
|
||||
|
||||
reset_by_platform = {}
|
||||
for platform_name, policy_data in data.get("reset_by_platform", {}).items():
|
||||
try:
|
||||
platform = Platform(platform_name)
|
||||
reset_by_platform[platform] = SessionResetPolicy.from_dict(policy_data)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
default_policy = SessionResetPolicy()
|
||||
if "default_reset_policy" in data:
|
||||
default_policy = SessionResetPolicy.from_dict(data["default_reset_policy"])
|
||||
|
||||
sessions_dir = Path.home() / ".hermes" / "sessions"
|
||||
if "sessions_dir" in data:
|
||||
sessions_dir = Path(data["sessions_dir"])
|
||||
|
||||
return cls(
|
||||
platforms=platforms,
|
||||
default_reset_policy=default_policy,
|
||||
reset_by_type=reset_by_type,
|
||||
reset_by_platform=reset_by_platform,
|
||||
reset_triggers=data.get("reset_triggers", ["/new", "/reset"]),
|
||||
sessions_dir=sessions_dir,
|
||||
always_log_local=data.get("always_log_local", True),
|
||||
)
|
||||
|
||||
|
||||
def load_gateway_config() -> GatewayConfig:
|
||||
"""
|
||||
Load gateway configuration from multiple sources.
|
||||
|
||||
Priority (highest to lowest):
|
||||
1. Environment variables
|
||||
2. ~/.hermes/gateway.json
|
||||
3. cli-config.yaml gateway section
|
||||
4. Defaults
|
||||
"""
|
||||
config = GatewayConfig()
|
||||
|
||||
# Try loading from ~/.hermes/gateway.json
|
||||
gateway_config_path = Path.home() / ".hermes" / "gateway.json"
|
||||
if gateway_config_path.exists():
|
||||
try:
|
||||
with open(gateway_config_path, "r") as f:
|
||||
data = json.load(f)
|
||||
config = GatewayConfig.from_dict(data)
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: Failed to load {gateway_config_path}: {e}")
|
||||
|
||||
# Override with environment variables
|
||||
_apply_env_overrides(config)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
"""Apply environment variable overrides to config."""
|
||||
|
||||
# Telegram
|
||||
telegram_token = os.getenv("TELEGRAM_BOT_TOKEN")
|
||||
if telegram_token:
|
||||
if Platform.TELEGRAM not in config.platforms:
|
||||
config.platforms[Platform.TELEGRAM] = PlatformConfig()
|
||||
config.platforms[Platform.TELEGRAM].enabled = True
|
||||
config.platforms[Platform.TELEGRAM].token = telegram_token
|
||||
|
||||
telegram_home = os.getenv("TELEGRAM_HOME_CHANNEL")
|
||||
if telegram_home and Platform.TELEGRAM in config.platforms:
|
||||
config.platforms[Platform.TELEGRAM].home_channel = HomeChannel(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id=telegram_home,
|
||||
name=os.getenv("TELEGRAM_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# Discord
|
||||
discord_token = os.getenv("DISCORD_BOT_TOKEN")
|
||||
if discord_token:
|
||||
if Platform.DISCORD not in config.platforms:
|
||||
config.platforms[Platform.DISCORD] = PlatformConfig()
|
||||
config.platforms[Platform.DISCORD].enabled = True
|
||||
config.platforms[Platform.DISCORD].token = discord_token
|
||||
|
||||
discord_home = os.getenv("DISCORD_HOME_CHANNEL")
|
||||
if discord_home and Platform.DISCORD in config.platforms:
|
||||
config.platforms[Platform.DISCORD].home_channel = HomeChannel(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id=discord_home,
|
||||
name=os.getenv("DISCORD_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# WhatsApp (typically uses different auth mechanism)
|
||||
whatsapp_enabled = os.getenv("WHATSAPP_ENABLED", "").lower() in ("true", "1", "yes")
|
||||
if whatsapp_enabled:
|
||||
if Platform.WHATSAPP not in config.platforms:
|
||||
config.platforms[Platform.WHATSAPP] = PlatformConfig()
|
||||
config.platforms[Platform.WHATSAPP].enabled = True
|
||||
|
||||
# Session settings
|
||||
idle_minutes = os.getenv("SESSION_IDLE_MINUTES")
|
||||
if idle_minutes:
|
||||
try:
|
||||
config.default_reset_policy.idle_minutes = int(idle_minutes)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
reset_hour = os.getenv("SESSION_RESET_HOUR")
|
||||
if reset_hour:
|
||||
try:
|
||||
config.default_reset_policy.at_hour = int(reset_hour)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def save_gateway_config(config: GatewayConfig) -> None:
|
||||
"""Save gateway configuration to ~/.hermes/gateway.json."""
|
||||
gateway_config_path = Path.home() / ".hermes" / "gateway.json"
|
||||
gateway_config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(gateway_config_path, "w") as f:
|
||||
json.dump(config.to_dict(), f, indent=2)
|
||||
318
gateway/delivery.py
Normal file
318
gateway/delivery.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
Delivery routing for cron job outputs and agent responses.
|
||||
|
||||
Routes messages to the appropriate destination based on:
|
||||
- Explicit targets (e.g., "telegram:123456789")
|
||||
- Platform home channels (e.g., "telegram" → home channel)
|
||||
- Origin (back to where the job was created)
|
||||
- Local (always saved to files)
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Any, Union
|
||||
from enum import Enum
|
||||
|
||||
from .config import Platform, GatewayConfig, HomeChannel
|
||||
from .session import SessionSource
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeliveryTarget:
|
||||
"""
|
||||
A single delivery target.
|
||||
|
||||
Represents where a message should be sent:
|
||||
- "origin" → back to source
|
||||
- "local" → save to local files
|
||||
- "telegram" → Telegram home channel
|
||||
- "telegram:123456" → specific Telegram chat
|
||||
"""
|
||||
platform: Platform
|
||||
chat_id: Optional[str] = None # None means use home channel
|
||||
is_origin: bool = False
|
||||
is_explicit: bool = False # True if chat_id was explicitly specified
|
||||
|
||||
@classmethod
|
||||
def parse(cls, target: str, origin: Optional[SessionSource] = None) -> "DeliveryTarget":
|
||||
"""
|
||||
Parse a delivery target string.
|
||||
|
||||
Formats:
|
||||
- "origin" → back to source
|
||||
- "local" → local files only
|
||||
- "telegram" → Telegram home channel
|
||||
- "telegram:123456" → specific Telegram chat
|
||||
"""
|
||||
target = target.strip().lower()
|
||||
|
||||
if target == "origin":
|
||||
if origin:
|
||||
return cls(
|
||||
platform=origin.platform,
|
||||
chat_id=origin.chat_id,
|
||||
is_origin=True,
|
||||
)
|
||||
else:
|
||||
# Fallback to local if no origin
|
||||
return cls(platform=Platform.LOCAL, is_origin=True)
|
||||
|
||||
if target == "local":
|
||||
return cls(platform=Platform.LOCAL)
|
||||
|
||||
# Check for platform:chat_id format
|
||||
if ":" in target:
|
||||
platform_str, chat_id = target.split(":", 1)
|
||||
try:
|
||||
platform = Platform(platform_str)
|
||||
return cls(platform=platform, chat_id=chat_id, is_explicit=True)
|
||||
except ValueError:
|
||||
# Unknown platform, treat as local
|
||||
return cls(platform=Platform.LOCAL)
|
||||
|
||||
# Just a platform name (use home channel)
|
||||
try:
|
||||
platform = Platform(target)
|
||||
return cls(platform=platform)
|
||||
except ValueError:
|
||||
# Unknown platform, treat as local
|
||||
return cls(platform=Platform.LOCAL)
|
||||
|
||||
def to_string(self) -> str:
|
||||
"""Convert back to string format."""
|
||||
if self.is_origin:
|
||||
return "origin"
|
||||
if self.platform == Platform.LOCAL:
|
||||
return "local"
|
||||
if self.chat_id:
|
||||
return f"{self.platform.value}:{self.chat_id}"
|
||||
return self.platform.value
|
||||
|
||||
|
||||
class DeliveryRouter:
|
||||
"""
|
||||
Routes messages to appropriate destinations.
|
||||
|
||||
Handles the logic of resolving delivery targets and dispatching
|
||||
messages to the right platform adapters.
|
||||
"""
|
||||
|
||||
def __init__(self, config: GatewayConfig, adapters: Dict[Platform, Any] = None):
|
||||
"""
|
||||
Initialize the delivery router.
|
||||
|
||||
Args:
|
||||
config: Gateway configuration
|
||||
adapters: Dict mapping platforms to their adapter instances
|
||||
"""
|
||||
self.config = config
|
||||
self.adapters = adapters or {}
|
||||
self.output_dir = Path.home() / ".hermes" / "cron" / "output"
|
||||
|
||||
def resolve_targets(
|
||||
self,
|
||||
deliver: Union[str, List[str]],
|
||||
origin: Optional[SessionSource] = None
|
||||
) -> List[DeliveryTarget]:
|
||||
"""
|
||||
Resolve delivery specification to concrete targets.
|
||||
|
||||
Args:
|
||||
deliver: Delivery spec - "origin", "telegram", ["local", "discord"], etc.
|
||||
origin: The source where the request originated (for "origin" target)
|
||||
|
||||
Returns:
|
||||
List of resolved delivery targets
|
||||
"""
|
||||
if isinstance(deliver, str):
|
||||
deliver = [deliver]
|
||||
|
||||
targets = []
|
||||
seen_platforms = set()
|
||||
|
||||
for target_str in deliver:
|
||||
target = DeliveryTarget.parse(target_str, origin)
|
||||
|
||||
# Resolve home channel if needed
|
||||
if target.chat_id is None and target.platform != Platform.LOCAL:
|
||||
home = self.config.get_home_channel(target.platform)
|
||||
if home:
|
||||
target.chat_id = home.chat_id
|
||||
else:
|
||||
# No home channel configured, skip this platform
|
||||
continue
|
||||
|
||||
# Deduplicate
|
||||
key = (target.platform, target.chat_id)
|
||||
if key not in seen_platforms:
|
||||
seen_platforms.add(key)
|
||||
targets.append(target)
|
||||
|
||||
# Always include local if configured
|
||||
if self.config.always_log_local:
|
||||
local_key = (Platform.LOCAL, None)
|
||||
if local_key not in seen_platforms:
|
||||
targets.append(DeliveryTarget(platform=Platform.LOCAL))
|
||||
|
||||
return targets
|
||||
|
||||
async def deliver(
|
||||
self,
|
||||
content: str,
|
||||
targets: List[DeliveryTarget],
|
||||
job_id: Optional[str] = None,
|
||||
job_name: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Deliver content to all specified targets.
|
||||
|
||||
Args:
|
||||
content: The message/output to deliver
|
||||
targets: List of delivery targets
|
||||
job_id: Optional job ID (for cron jobs)
|
||||
job_name: Optional job name
|
||||
metadata: Additional metadata to include
|
||||
|
||||
Returns:
|
||||
Dict with delivery results per target
|
||||
"""
|
||||
results = {}
|
||||
|
||||
for target in targets:
|
||||
try:
|
||||
if target.platform == Platform.LOCAL:
|
||||
result = self._deliver_local(content, job_id, job_name, metadata)
|
||||
else:
|
||||
result = await self._deliver_to_platform(target, content, metadata)
|
||||
|
||||
results[target.to_string()] = {
|
||||
"success": True,
|
||||
"result": result
|
||||
}
|
||||
except Exception as e:
|
||||
results[target.to_string()] = {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
def _deliver_local(
|
||||
self,
|
||||
content: str,
|
||||
job_id: Optional[str],
|
||||
job_name: Optional[str],
|
||||
metadata: Optional[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Save content to local files."""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
if job_id:
|
||||
output_path = self.output_dir / job_id / f"{timestamp}.md"
|
||||
else:
|
||||
output_path = self.output_dir / "misc" / f"{timestamp}.md"
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Build the output document
|
||||
lines = []
|
||||
if job_name:
|
||||
lines.append(f"# {job_name}")
|
||||
else:
|
||||
lines.append("# Delivery Output")
|
||||
|
||||
lines.append("")
|
||||
lines.append(f"**Timestamp:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
|
||||
if job_id:
|
||||
lines.append(f"**Job ID:** {job_id}")
|
||||
|
||||
if metadata:
|
||||
for key, value in metadata.items():
|
||||
lines.append(f"**{key}:** {value}")
|
||||
|
||||
lines.append("")
|
||||
lines.append("---")
|
||||
lines.append("")
|
||||
lines.append(content)
|
||||
|
||||
output_path.write_text("\n".join(lines))
|
||||
|
||||
return {
|
||||
"path": str(output_path),
|
||||
"timestamp": timestamp
|
||||
}
|
||||
|
||||
async def _deliver_to_platform(
|
||||
self,
|
||||
target: DeliveryTarget,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Deliver content to a messaging platform."""
|
||||
adapter = self.adapters.get(target.platform)
|
||||
|
||||
if not adapter:
|
||||
raise ValueError(f"No adapter configured for {target.platform.value}")
|
||||
|
||||
if not target.chat_id:
|
||||
raise ValueError(f"No chat ID for {target.platform.value} delivery")
|
||||
|
||||
# Call the adapter's send method
|
||||
# Adapters should implement: async def send(chat_id: str, content: str) -> Dict
|
||||
return await adapter.send(target.chat_id, content, metadata=metadata)
|
||||
|
||||
|
||||
def parse_deliver_spec(
|
||||
deliver: Optional[Union[str, List[str]]],
|
||||
origin: Optional[SessionSource] = None,
|
||||
default: str = "origin"
|
||||
) -> Union[str, List[str]]:
|
||||
"""
|
||||
Normalize a delivery specification.
|
||||
|
||||
If None or empty, returns the default.
|
||||
"""
|
||||
if not deliver:
|
||||
return default
|
||||
return deliver
|
||||
|
||||
|
||||
def build_delivery_context_for_tool(
|
||||
config: GatewayConfig,
|
||||
origin: Optional[SessionSource] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build context for the schedule_cronjob tool to understand delivery options.
|
||||
|
||||
This is passed to the tool so it can validate and explain delivery targets.
|
||||
"""
|
||||
connected = config.get_connected_platforms()
|
||||
|
||||
options = {
|
||||
"origin": {
|
||||
"description": "Back to where this job was created",
|
||||
"available": origin is not None,
|
||||
},
|
||||
"local": {
|
||||
"description": "Save to local files only",
|
||||
"available": True,
|
||||
}
|
||||
}
|
||||
|
||||
for platform in connected:
|
||||
home = config.get_home_channel(platform)
|
||||
options[platform.value] = {
|
||||
"description": f"{platform.value.title()} home channel",
|
||||
"available": True,
|
||||
"home_channel": home.to_dict() if home else None,
|
||||
}
|
||||
|
||||
return {
|
||||
"origin": origin.to_dict() if origin else None,
|
||||
"options": options,
|
||||
"always_log_local": config.always_log_local,
|
||||
}
|
||||
17
gateway/platforms/__init__.py
Normal file
17
gateway/platforms/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Platform adapters for messaging integrations.
|
||||
|
||||
Each adapter handles:
|
||||
- Receiving messages from a platform
|
||||
- Sending messages/responses back
|
||||
- Platform-specific authentication
|
||||
- Message formatting and media handling
|
||||
"""
|
||||
|
||||
from .base import BasePlatformAdapter, MessageEvent, SendResult
|
||||
|
||||
__all__ = [
|
||||
"BasePlatformAdapter",
|
||||
"MessageEvent",
|
||||
"SendResult",
|
||||
]
|
||||
365
gateway/platforms/base.py
Normal file
365
gateway/platforms/base.py
Normal file
@@ -0,0 +1,365 @@
|
||||
"""
|
||||
Base platform adapter interface.
|
||||
|
||||
All platform adapters (Telegram, Discord, WhatsApp) inherit from this
|
||||
and implement the required methods.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any, Callable, Awaitable
|
||||
from enum import Enum
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, str(__file__).rsplit("/", 3)[0])
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
class MessageType(Enum):
|
||||
"""Types of incoming messages."""
|
||||
TEXT = "text"
|
||||
PHOTO = "photo"
|
||||
VIDEO = "video"
|
||||
AUDIO = "audio"
|
||||
VOICE = "voice"
|
||||
DOCUMENT = "document"
|
||||
STICKER = "sticker"
|
||||
COMMAND = "command" # /command style
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageEvent:
|
||||
"""
|
||||
Incoming message from a platform.
|
||||
|
||||
Normalized representation that all adapters produce.
|
||||
"""
|
||||
# Message content
|
||||
text: str
|
||||
message_type: MessageType = MessageType.TEXT
|
||||
|
||||
# Source information
|
||||
source: SessionSource = None
|
||||
|
||||
# Original platform data
|
||||
raw_message: Any = None
|
||||
message_id: Optional[str] = None
|
||||
|
||||
# Media attachments
|
||||
media_urls: List[str] = field(default_factory=list)
|
||||
media_types: List[str] = field(default_factory=list)
|
||||
|
||||
# Reply context
|
||||
reply_to_message_id: Optional[str] = None
|
||||
|
||||
# Timestamps
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
|
||||
def is_command(self) -> bool:
|
||||
"""Check if this is a command message (e.g., /new, /reset)."""
|
||||
return self.text.startswith("/")
|
||||
|
||||
def get_command(self) -> Optional[str]:
|
||||
"""Extract command name if this is a command message."""
|
||||
if not self.is_command():
|
||||
return None
|
||||
# Split on space and get first word, strip the /
|
||||
parts = self.text.split(maxsplit=1)
|
||||
return parts[0][1:].lower() if parts else None
|
||||
|
||||
def get_command_args(self) -> str:
|
||||
"""Get the arguments after a command."""
|
||||
if not self.is_command():
|
||||
return self.text
|
||||
parts = self.text.split(maxsplit=1)
|
||||
return parts[1] if len(parts) > 1 else ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendResult:
|
||||
"""Result of sending a message."""
|
||||
success: bool
|
||||
message_id: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
raw_response: Any = None
|
||||
|
||||
|
||||
# Type for message handlers
|
||||
MessageHandler = Callable[[MessageEvent], Awaitable[Optional[str]]]
|
||||
|
||||
|
||||
class BasePlatformAdapter(ABC):
|
||||
"""
|
||||
Base class for platform adapters.
|
||||
|
||||
Subclasses implement platform-specific logic for:
|
||||
- Connecting and authenticating
|
||||
- Receiving messages
|
||||
- Sending messages/responses
|
||||
- Handling media
|
||||
"""
|
||||
|
||||
def __init__(self, config: PlatformConfig, platform: Platform):
|
||||
self.config = config
|
||||
self.platform = platform
|
||||
self._message_handler: Optional[MessageHandler] = None
|
||||
self._running = False
|
||||
|
||||
# Track active message handlers per session for interrupt support
|
||||
# Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt)
|
||||
self._active_sessions: Dict[str, asyncio.Event] = {}
|
||||
self._pending_messages: Dict[str, MessageEvent] = {}
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Human-readable name for this adapter."""
|
||||
return self.platform.value.title()
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if adapter is currently connected."""
|
||||
return self._running
|
||||
|
||||
def set_message_handler(self, handler: MessageHandler) -> None:
|
||||
"""
|
||||
Set the handler for incoming messages.
|
||||
|
||||
The handler receives a MessageEvent and should return
|
||||
an optional response string.
|
||||
"""
|
||||
self._message_handler = handler
|
||||
|
||||
@abstractmethod
|
||||
async def connect(self) -> bool:
|
||||
"""
|
||||
Connect to the platform and start receiving messages.
|
||||
|
||||
Returns True if connection was successful.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from the platform."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> SendResult:
|
||||
"""
|
||||
Send a message to a chat.
|
||||
|
||||
Args:
|
||||
chat_id: The chat/channel ID to send to
|
||||
content: Message content (may be markdown)
|
||||
reply_to: Optional message ID to reply to
|
||||
metadata: Additional platform-specific options
|
||||
|
||||
Returns:
|
||||
SendResult with success status and message ID
|
||||
"""
|
||||
pass
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""
|
||||
Send a typing indicator.
|
||||
|
||||
Override in subclasses if the platform supports it.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def _keep_typing(self, chat_id: str, interval: float = 2.0) -> None:
|
||||
"""
|
||||
Continuously send typing indicator until cancelled.
|
||||
|
||||
Telegram/Discord typing status expires after ~5 seconds, so we refresh every 2
|
||||
to recover quickly after progress messages interrupt it.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
await self.send_typing(chat_id)
|
||||
await asyncio.sleep(interval)
|
||||
except asyncio.CancelledError:
|
||||
pass # Normal cancellation when handler completes
|
||||
|
||||
async def handle_message(self, event: MessageEvent) -> None:
|
||||
"""
|
||||
Process an incoming message.
|
||||
|
||||
This method returns quickly by spawning background tasks.
|
||||
This allows new messages to be processed even while an agent is running,
|
||||
enabling interruption support.
|
||||
"""
|
||||
if not self._message_handler:
|
||||
return
|
||||
|
||||
session_key = event.source.chat_id
|
||||
|
||||
# Check if there's already an active handler for this session
|
||||
if session_key in self._active_sessions:
|
||||
# Store this as a pending message - it will interrupt the running agent
|
||||
print(f"[{self.name}] ⚡ New message while session {session_key} is active - triggering interrupt")
|
||||
self._pending_messages[session_key] = event
|
||||
# Signal the interrupt (the processing task checks this)
|
||||
self._active_sessions[session_key].set()
|
||||
return # Don't process now - will be handled after current task finishes
|
||||
|
||||
# Spawn background task to process this message
|
||||
asyncio.create_task(self._process_message_background(event, session_key))
|
||||
|
||||
async def _process_message_background(self, event: MessageEvent, session_key: str) -> None:
|
||||
"""Background task that actually processes the message."""
|
||||
# Create interrupt event for this session
|
||||
interrupt_event = asyncio.Event()
|
||||
self._active_sessions[session_key] = interrupt_event
|
||||
|
||||
# Start continuous typing indicator (refreshes every 2 seconds)
|
||||
typing_task = asyncio.create_task(self._keep_typing(event.source.chat_id))
|
||||
|
||||
try:
|
||||
# Call the handler (this can take a while with tool calls)
|
||||
response = await self._message_handler(event)
|
||||
|
||||
# Send response if any
|
||||
if response:
|
||||
result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=response,
|
||||
reply_to=event.message_id
|
||||
)
|
||||
|
||||
# Log send failures (don't raise - user already saw tool progress)
|
||||
if not result.success:
|
||||
print(f"[{self.name}] Failed to send response: {result.error}")
|
||||
# Try sending without markdown as fallback
|
||||
fallback_result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=f"(Response formatting failed, plain text:)\n\n{response[:3500]}",
|
||||
reply_to=event.message_id
|
||||
)
|
||||
if not fallback_result.success:
|
||||
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
||||
|
||||
# Check if there's a pending message that was queued during our processing
|
||||
if session_key in self._pending_messages:
|
||||
pending_event = self._pending_messages.pop(session_key)
|
||||
print(f"[{self.name}] 📨 Processing queued message from interrupt")
|
||||
# Clean up current session before processing pending
|
||||
if session_key in self._active_sessions:
|
||||
del self._active_sessions[session_key]
|
||||
typing_task.cancel()
|
||||
try:
|
||||
await typing_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
# Process pending message in new background task
|
||||
await self._process_message_background(pending_event, session_key)
|
||||
return # Already cleaned up
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error handling message: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# Stop typing indicator
|
||||
typing_task.cancel()
|
||||
try:
|
||||
await typing_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
# Clean up session tracking
|
||||
if session_key in self._active_sessions:
|
||||
del self._active_sessions[session_key]
|
||||
|
||||
def has_pending_interrupt(self, session_key: str) -> bool:
|
||||
"""Check if there's a pending interrupt for a session."""
|
||||
return session_key in self._active_sessions and self._active_sessions[session_key].is_set()
|
||||
|
||||
def get_pending_message(self, session_key: str) -> Optional[MessageEvent]:
|
||||
"""Get and clear any pending message for a session."""
|
||||
return self._pending_messages.get(session_key)
|
||||
|
||||
def build_source(
|
||||
self,
|
||||
chat_id: str,
|
||||
chat_name: Optional[str] = None,
|
||||
chat_type: str = "dm",
|
||||
user_id: Optional[str] = None,
|
||||
user_name: Optional[str] = None,
|
||||
thread_id: Optional[str] = None
|
||||
) -> SessionSource:
|
||||
"""Helper to build a SessionSource for this platform."""
|
||||
return SessionSource(
|
||||
platform=self.platform,
|
||||
chat_id=str(chat_id),
|
||||
chat_name=chat_name,
|
||||
chat_type=chat_type,
|
||||
user_id=str(user_id) if user_id else None,
|
||||
user_name=user_name,
|
||||
thread_id=str(thread_id) if thread_id else None,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about a chat/channel.
|
||||
|
||||
Returns dict with at least:
|
||||
- name: Chat name
|
||||
- type: "dm", "group", "channel"
|
||||
"""
|
||||
pass
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
"""
|
||||
Format a message for this platform.
|
||||
|
||||
Override in subclasses to handle platform-specific formatting
|
||||
(e.g., Telegram MarkdownV2, Discord markdown).
|
||||
|
||||
Default implementation returns content as-is.
|
||||
"""
|
||||
return content
|
||||
|
||||
def truncate_message(self, content: str, max_length: int = 4096) -> List[str]:
|
||||
"""
|
||||
Split a long message into chunks.
|
||||
|
||||
Args:
|
||||
content: The full message content
|
||||
max_length: Maximum length per chunk (platform-specific)
|
||||
|
||||
Returns:
|
||||
List of message chunks
|
||||
"""
|
||||
if len(content) <= max_length:
|
||||
return [content]
|
||||
|
||||
chunks = []
|
||||
while content:
|
||||
if len(content) <= max_length:
|
||||
chunks.append(content)
|
||||
break
|
||||
|
||||
# Try to split at a newline
|
||||
split_idx = content.rfind("\n", 0, max_length)
|
||||
if split_idx == -1:
|
||||
# No newline, split at space
|
||||
split_idx = content.rfind(" ", 0, max_length)
|
||||
if split_idx == -1:
|
||||
# No space either, hard split
|
||||
split_idx = max_length
|
||||
|
||||
chunks.append(content[:split_idx])
|
||||
content = content[split_idx:].lstrip()
|
||||
|
||||
return chunks
|
||||
297
gateway/platforms/discord.py
Normal file
297
gateway/platforms/discord.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""
|
||||
Discord platform adapter.
|
||||
|
||||
Uses discord.py library for:
|
||||
- Receiving messages from servers and DMs
|
||||
- Sending responses back
|
||||
- Handling threads and channels
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
try:
|
||||
import discord
|
||||
from discord import Message as DiscordMessage, Intents
|
||||
from discord.ext import commands
|
||||
DISCORD_AVAILABLE = True
|
||||
except ImportError:
|
||||
DISCORD_AVAILABLE = False
|
||||
discord = None
|
||||
DiscordMessage = Any
|
||||
Intents = Any
|
||||
commands = None
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, str(__file__).rsplit("/", 3)[0])
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
)
|
||||
|
||||
|
||||
def check_discord_requirements() -> bool:
|
||||
"""Check if Discord dependencies are available."""
|
||||
return DISCORD_AVAILABLE
|
||||
|
||||
|
||||
class DiscordAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Discord bot adapter.
|
||||
|
||||
Handles:
|
||||
- Receiving messages from servers and DMs
|
||||
- Sending responses with Discord markdown
|
||||
- Thread support
|
||||
- Slash commands (future)
|
||||
"""
|
||||
|
||||
# Discord message limits
|
||||
MAX_MESSAGE_LENGTH = 2000
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.DISCORD)
|
||||
self._client: Optional[commands.Bot] = None
|
||||
self._ready_event = asyncio.Event()
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Discord and start receiving events."""
|
||||
if not DISCORD_AVAILABLE:
|
||||
print(f"[{self.name}] discord.py not installed. Run: pip install discord.py")
|
||||
return False
|
||||
|
||||
if not self.config.token:
|
||||
print(f"[{self.name}] No bot token configured")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Set up intents
|
||||
intents = Intents.default()
|
||||
intents.message_content = True
|
||||
intents.dm_messages = True
|
||||
intents.guild_messages = True
|
||||
|
||||
# Create bot
|
||||
self._client = commands.Bot(
|
||||
command_prefix="!", # Not really used, we handle raw messages
|
||||
intents=intents,
|
||||
)
|
||||
|
||||
# Register event handlers
|
||||
@self._client.event
|
||||
async def on_ready():
|
||||
print(f"[{self.name}] Connected as {self._client.user}")
|
||||
self._ready_event.set()
|
||||
|
||||
@self._client.event
|
||||
async def on_message(message: DiscordMessage):
|
||||
# Ignore bot's own messages
|
||||
if message.author == self._client.user:
|
||||
return
|
||||
await self._handle_message(message)
|
||||
|
||||
# Start the bot in background
|
||||
asyncio.create_task(self._client.start(self.config.token))
|
||||
|
||||
# Wait for ready
|
||||
await asyncio.wait_for(self._ready_event.wait(), timeout=30)
|
||||
|
||||
self._running = True
|
||||
return True
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
print(f"[{self.name}] Timeout waiting for connection")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to connect: {e}")
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from Discord."""
|
||||
if self._client:
|
||||
try:
|
||||
await self._client.close()
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error during disconnect: {e}")
|
||||
|
||||
self._running = False
|
||||
self._client = None
|
||||
self._ready_event.clear()
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> SendResult:
|
||||
"""Send a message to a Discord channel."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
# Get the channel
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
# Format and split message if needed
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
|
||||
message_ids = []
|
||||
reference = None
|
||||
|
||||
if reply_to:
|
||||
try:
|
||||
ref_msg = await channel.fetch_message(int(reply_to))
|
||||
reference = ref_msg
|
||||
except Exception:
|
||||
pass # Ignore if we can't find the referenced message
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
msg = await channel.send(
|
||||
content=chunk,
|
||||
reference=reference if i == 0 else None,
|
||||
)
|
||||
message_ids.append(str(msg.id))
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=message_ids[0] if message_ids else None,
|
||||
raw_response={"message_ids": message_ids}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Send typing indicator."""
|
||||
if self._client:
|
||||
try:
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if channel:
|
||||
await channel.typing()
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Get information about a Discord channel."""
|
||||
if not self._client:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
|
||||
try:
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
|
||||
if not channel:
|
||||
return {"name": str(chat_id), "type": "dm"}
|
||||
|
||||
# Determine channel type
|
||||
if isinstance(channel, discord.DMChannel):
|
||||
chat_type = "dm"
|
||||
name = channel.recipient.name if channel.recipient else str(chat_id)
|
||||
elif isinstance(channel, discord.Thread):
|
||||
chat_type = "thread"
|
||||
name = channel.name
|
||||
elif isinstance(channel, discord.TextChannel):
|
||||
chat_type = "channel"
|
||||
name = f"#{channel.name}"
|
||||
if channel.guild:
|
||||
name = f"{channel.guild.name} / {name}"
|
||||
else:
|
||||
chat_type = "channel"
|
||||
name = getattr(channel, "name", str(chat_id))
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"type": chat_type,
|
||||
"guild_id": str(channel.guild.id) if hasattr(channel, "guild") and channel.guild else None,
|
||||
"guild_name": channel.guild.name if hasattr(channel, "guild") and channel.guild else None,
|
||||
}
|
||||
except Exception as e:
|
||||
return {"name": str(chat_id), "type": "dm", "error": str(e)}
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
"""
|
||||
Format message for Discord.
|
||||
|
||||
Discord uses its own markdown variant.
|
||||
"""
|
||||
# Discord markdown is fairly standard, no special escaping needed
|
||||
return content
|
||||
|
||||
async def _handle_message(self, message: DiscordMessage) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
if message.content.startswith("/"):
|
||||
msg_type = MessageType.COMMAND
|
||||
elif message.attachments:
|
||||
# Check attachment types
|
||||
for att in message.attachments:
|
||||
if att.content_type:
|
||||
if att.content_type.startswith("image/"):
|
||||
msg_type = MessageType.PHOTO
|
||||
elif att.content_type.startswith("video/"):
|
||||
msg_type = MessageType.VIDEO
|
||||
elif att.content_type.startswith("audio/"):
|
||||
msg_type = MessageType.AUDIO
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
break
|
||||
|
||||
# Determine chat type
|
||||
if isinstance(message.channel, discord.DMChannel):
|
||||
chat_type = "dm"
|
||||
chat_name = message.author.name
|
||||
elif isinstance(message.channel, discord.Thread):
|
||||
chat_type = "thread"
|
||||
chat_name = message.channel.name
|
||||
else:
|
||||
chat_type = "group" # Treat server channels as groups
|
||||
chat_name = getattr(message.channel, "name", str(message.channel.id))
|
||||
if hasattr(message.channel, "guild") and message.channel.guild:
|
||||
chat_name = f"{message.channel.guild.name} / #{chat_name}"
|
||||
|
||||
# Get thread ID if in a thread
|
||||
thread_id = None
|
||||
if isinstance(message.channel, discord.Thread):
|
||||
thread_id = str(message.channel.id)
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=str(message.channel.id),
|
||||
chat_name=chat_name,
|
||||
chat_type=chat_type,
|
||||
user_id=str(message.author.id),
|
||||
user_name=message.author.display_name,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
# Build media URLs
|
||||
media_urls = [att.url for att in message.attachments]
|
||||
media_types = [att.content_type or "unknown" for att in message.attachments]
|
||||
|
||||
event = MessageEvent(
|
||||
text=message.content,
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=message,
|
||||
message_id=str(message.id),
|
||||
media_urls=media_urls,
|
||||
media_types=media_types,
|
||||
reply_to_message_id=str(message.reference.message_id) if message.reference else None,
|
||||
timestamp=message.created_at,
|
||||
)
|
||||
|
||||
await self.handle_message(event)
|
||||
298
gateway/platforms/telegram.py
Normal file
298
gateway/platforms/telegram.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""
|
||||
Telegram platform adapter.
|
||||
|
||||
Uses python-telegram-bot library for:
|
||||
- Receiving messages from users/groups
|
||||
- Sending responses back
|
||||
- Handling media and commands
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
try:
|
||||
from telegram import Update, Bot, Message
|
||||
from telegram.ext import (
|
||||
Application,
|
||||
CommandHandler,
|
||||
MessageHandler as TelegramMessageHandler,
|
||||
ContextTypes,
|
||||
filters,
|
||||
)
|
||||
from telegram.constants import ParseMode, ChatType
|
||||
TELEGRAM_AVAILABLE = True
|
||||
except ImportError:
|
||||
TELEGRAM_AVAILABLE = False
|
||||
Update = Any
|
||||
Bot = Any
|
||||
Message = Any
|
||||
Application = Any
|
||||
ContextTypes = Any
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, str(__file__).rsplit("/", 3)[0])
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
)
|
||||
|
||||
|
||||
def check_telegram_requirements() -> bool:
|
||||
"""Check if Telegram dependencies are available."""
|
||||
return TELEGRAM_AVAILABLE
|
||||
|
||||
|
||||
class TelegramAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Telegram bot adapter.
|
||||
|
||||
Handles:
|
||||
- Receiving messages from users and groups
|
||||
- Sending responses with Telegram markdown
|
||||
- Forum topics (thread_id support)
|
||||
- Media messages
|
||||
"""
|
||||
|
||||
# Telegram message limits
|
||||
MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.TELEGRAM)
|
||||
self._app: Optional[Application] = None
|
||||
self._bot: Optional[Bot] = None
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Telegram and start polling for updates."""
|
||||
if not TELEGRAM_AVAILABLE:
|
||||
print(f"[{self.name}] python-telegram-bot not installed. Run: pip install python-telegram-bot")
|
||||
return False
|
||||
|
||||
if not self.config.token:
|
||||
print(f"[{self.name}] No bot token configured")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Build the application
|
||||
self._app = Application.builder().token(self.config.token).build()
|
||||
self._bot = self._app.bot
|
||||
|
||||
# Register handlers
|
||||
self._app.add_handler(TelegramMessageHandler(
|
||||
filters.TEXT & ~filters.COMMAND,
|
||||
self._handle_text_message
|
||||
))
|
||||
self._app.add_handler(TelegramMessageHandler(
|
||||
filters.COMMAND,
|
||||
self._handle_command
|
||||
))
|
||||
self._app.add_handler(TelegramMessageHandler(
|
||||
filters.PHOTO | filters.VIDEO | filters.AUDIO | filters.VOICE | filters.Document.ALL,
|
||||
self._handle_media_message
|
||||
))
|
||||
|
||||
# Start polling in background
|
||||
await self._app.initialize()
|
||||
await self._app.start()
|
||||
await self._app.updater.start_polling(allowed_updates=Update.ALL_TYPES)
|
||||
|
||||
self._running = True
|
||||
print(f"[{self.name}] Connected and polling for updates")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to connect: {e}")
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Stop polling and disconnect."""
|
||||
if self._app:
|
||||
try:
|
||||
await self._app.updater.stop()
|
||||
await self._app.stop()
|
||||
await self._app.shutdown()
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error during disconnect: {e}")
|
||||
|
||||
self._running = False
|
||||
self._app = None
|
||||
self._bot = None
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> SendResult:
|
||||
"""Send a message to a Telegram chat."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
# Format and split message if needed
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
|
||||
message_ids = []
|
||||
thread_id = metadata.get("thread_id") if metadata else None
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
# Try Markdown first, fall back to plain text if it fails
|
||||
try:
|
||||
msg = await self._bot.send_message(
|
||||
chat_id=int(chat_id),
|
||||
text=chunk,
|
||||
parse_mode=ParseMode.MARKDOWN,
|
||||
reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
)
|
||||
except Exception as md_error:
|
||||
# Markdown parsing failed, try plain text
|
||||
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower():
|
||||
msg = await self._bot.send_message(
|
||||
chat_id=int(chat_id),
|
||||
text=chunk,
|
||||
parse_mode=None, # Plain text
|
||||
reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
)
|
||||
else:
|
||||
raise # Re-raise if not a parse error
|
||||
message_ids.append(str(msg.message_id))
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=message_ids[0] if message_ids else None,
|
||||
raw_response={"message_ids": message_ids}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Send typing indicator."""
|
||||
if self._bot:
|
||||
try:
|
||||
await self._bot.send_chat_action(
|
||||
chat_id=int(chat_id),
|
||||
action="typing"
|
||||
)
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Get information about a Telegram chat."""
|
||||
if not self._bot:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
|
||||
try:
|
||||
chat = await self._bot.get_chat(int(chat_id))
|
||||
|
||||
chat_type = "dm"
|
||||
if chat.type == ChatType.GROUP:
|
||||
chat_type = "group"
|
||||
elif chat.type == ChatType.SUPERGROUP:
|
||||
chat_type = "group"
|
||||
if chat.is_forum:
|
||||
chat_type = "forum"
|
||||
elif chat.type == ChatType.CHANNEL:
|
||||
chat_type = "channel"
|
||||
|
||||
return {
|
||||
"name": chat.title or chat.full_name or str(chat_id),
|
||||
"type": chat_type,
|
||||
"username": chat.username,
|
||||
"is_forum": getattr(chat, "is_forum", False),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"name": str(chat_id), "type": "dm", "error": str(e)}
|
||||
|
||||
def format_message(self, content: str) -> str:
|
||||
"""
|
||||
Format message for Telegram.
|
||||
|
||||
Telegram uses a subset of markdown. We'll use the simpler
|
||||
Markdown mode (not MarkdownV2) for compatibility.
|
||||
"""
|
||||
# Basic escaping for Telegram Markdown
|
||||
# In Markdown mode (not V2), only certain characters need escaping
|
||||
return content
|
||||
|
||||
async def _handle_text_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming text messages."""
|
||||
if not update.message or not update.message.text:
|
||||
return
|
||||
|
||||
event = self._build_message_event(update.message, MessageType.TEXT)
|
||||
await self.handle_message(event)
|
||||
|
||||
async def _handle_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming command messages."""
|
||||
if not update.message or not update.message.text:
|
||||
return
|
||||
|
||||
event = self._build_message_event(update.message, MessageType.COMMAND)
|
||||
await self.handle_message(event)
|
||||
|
||||
async def _handle_media_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
"""Handle incoming media messages."""
|
||||
if not update.message:
|
||||
return
|
||||
|
||||
msg = update.message
|
||||
|
||||
# Determine media type
|
||||
if msg.photo:
|
||||
msg_type = MessageType.PHOTO
|
||||
elif msg.video:
|
||||
msg_type = MessageType.VIDEO
|
||||
elif msg.audio:
|
||||
msg_type = MessageType.AUDIO
|
||||
elif msg.voice:
|
||||
msg_type = MessageType.VOICE
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
|
||||
event = self._build_message_event(msg, msg_type)
|
||||
|
||||
# Add caption as text
|
||||
if msg.caption:
|
||||
event.text = msg.caption
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
def _build_message_event(self, message: Message, msg_type: MessageType) -> MessageEvent:
|
||||
"""Build a MessageEvent from a Telegram message."""
|
||||
chat = message.chat
|
||||
user = message.from_user
|
||||
|
||||
# Determine chat type
|
||||
chat_type = "dm"
|
||||
if chat.type in (ChatType.GROUP, ChatType.SUPERGROUP):
|
||||
chat_type = "group"
|
||||
elif chat.type == ChatType.CHANNEL:
|
||||
chat_type = "channel"
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=str(chat.id),
|
||||
chat_name=chat.title or (chat.full_name if hasattr(chat, "full_name") else None),
|
||||
chat_type=chat_type,
|
||||
user_id=str(user.id) if user else None,
|
||||
user_name=user.full_name if user else None,
|
||||
thread_id=str(message.message_thread_id) if message.message_thread_id else None,
|
||||
)
|
||||
|
||||
return MessageEvent(
|
||||
text=message.text or "",
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=message,
|
||||
message_id=str(message.message_id),
|
||||
timestamp=message.date,
|
||||
)
|
||||
327
gateway/platforms/whatsapp.py
Normal file
327
gateway/platforms/whatsapp.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""
|
||||
WhatsApp platform adapter.
|
||||
|
||||
WhatsApp integration is more complex than Telegram/Discord because:
|
||||
- No official bot API for personal accounts
|
||||
- Business API requires Meta Business verification
|
||||
- Most solutions use web-based automation
|
||||
|
||||
This adapter supports multiple backends:
|
||||
1. WhatsApp Business API (requires Meta verification)
|
||||
2. whatsapp-web.js (via Node.js subprocess) - for personal accounts
|
||||
3. Baileys (via Node.js subprocess) - alternative for personal accounts
|
||||
|
||||
For simplicity, we'll implement a generic interface that can work
|
||||
with different backends via a bridge pattern.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
import sys
|
||||
sys.path.insert(0, str(__file__).rsplit("/", 3)[0])
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
)
|
||||
|
||||
|
||||
def check_whatsapp_requirements() -> bool:
|
||||
"""
|
||||
Check if WhatsApp dependencies are available.
|
||||
|
||||
WhatsApp requires a Node.js bridge for most implementations.
|
||||
"""
|
||||
# Check for Node.js
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["node", "--version"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
return result.returncode == 0
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
WhatsApp adapter.
|
||||
|
||||
This implementation uses a simple HTTP bridge pattern where:
|
||||
1. A Node.js process runs the WhatsApp Web client
|
||||
2. Messages are forwarded via HTTP/IPC to this Python adapter
|
||||
3. Responses are sent back through the bridge
|
||||
|
||||
The actual Node.js bridge implementation can vary:
|
||||
- whatsapp-web.js based
|
||||
- Baileys based
|
||||
- Business API based
|
||||
|
||||
Configuration:
|
||||
- bridge_script: Path to the Node.js bridge script
|
||||
- bridge_port: Port for HTTP communication (default: 3000)
|
||||
- session_path: Path to store WhatsApp session data
|
||||
"""
|
||||
|
||||
# WhatsApp message limits
|
||||
MAX_MESSAGE_LENGTH = 65536 # WhatsApp allows longer messages
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.WHATSAPP)
|
||||
self._bridge_process: Optional[subprocess.Popen] = None
|
||||
self._bridge_port: int = config.extra.get("bridge_port", 3000)
|
||||
self._bridge_script: Optional[str] = config.extra.get("bridge_script")
|
||||
self._session_path: Path = Path(config.extra.get(
|
||||
"session_path",
|
||||
Path.home() / ".hermes" / "whatsapp" / "session"
|
||||
))
|
||||
self._message_queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""
|
||||
Start the WhatsApp bridge.
|
||||
|
||||
This launches the Node.js bridge process and waits for it to be ready.
|
||||
"""
|
||||
if not check_whatsapp_requirements():
|
||||
print(f"[{self.name}] Node.js not found. WhatsApp requires Node.js.")
|
||||
return False
|
||||
|
||||
if not self._bridge_script:
|
||||
print(f"[{self.name}] No bridge script configured.")
|
||||
print(f"[{self.name}] Set 'bridge_script' in whatsapp.extra config.")
|
||||
print(f"[{self.name}] See docs/messaging.md for WhatsApp setup instructions.")
|
||||
return False
|
||||
|
||||
bridge_path = Path(self._bridge_script)
|
||||
if not bridge_path.exists():
|
||||
print(f"[{self.name}] Bridge script not found: {bridge_path}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Ensure session directory exists
|
||||
self._session_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Start the bridge process
|
||||
self._bridge_process = subprocess.Popen(
|
||||
[
|
||||
"node",
|
||||
str(bridge_path),
|
||||
"--port", str(self._bridge_port),
|
||||
"--session", str(self._session_path),
|
||||
],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
|
||||
# Wait for bridge to be ready (look for ready signal)
|
||||
# This is a simplified version - real implementation would
|
||||
# wait for an HTTP health check or specific stdout message
|
||||
await asyncio.sleep(5)
|
||||
|
||||
if self._bridge_process.poll() is not None:
|
||||
stderr = self._bridge_process.stderr.read() if self._bridge_process.stderr else ""
|
||||
print(f"[{self.name}] Bridge process died: {stderr}")
|
||||
return False
|
||||
|
||||
# Start message polling task
|
||||
asyncio.create_task(self._poll_messages())
|
||||
|
||||
self._running = True
|
||||
print(f"[{self.name}] Bridge started on port {self._bridge_port}")
|
||||
print(f"[{self.name}] Scan QR code if prompted (check bridge output)")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to start bridge: {e}")
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Stop the WhatsApp bridge."""
|
||||
if self._bridge_process:
|
||||
try:
|
||||
self._bridge_process.terminate()
|
||||
await asyncio.sleep(1)
|
||||
if self._bridge_process.poll() is None:
|
||||
self._bridge_process.kill()
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error stopping bridge: {e}")
|
||||
|
||||
self._running = False
|
||||
self._bridge_process = None
|
||||
print(f"[{self.name}] Disconnected")
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> SendResult:
|
||||
"""Send a message via the WhatsApp bridge."""
|
||||
if not self._running:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
payload = {
|
||||
"chatId": chat_id,
|
||||
"message": content,
|
||||
}
|
||||
if reply_to:
|
||||
payload["replyTo"] = reply_to
|
||||
|
||||
async with session.post(
|
||||
f"http://localhost:{self._bridge_port}/send",
|
||||
json=payload,
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=data.get("messageId"),
|
||||
raw_response=data
|
||||
)
|
||||
else:
|
||||
error = await resp.text()
|
||||
return SendResult(success=False, error=error)
|
||||
|
||||
except ImportError:
|
||||
return SendResult(
|
||||
success=False,
|
||||
error="aiohttp not installed. Run: pip install aiohttp"
|
||||
)
|
||||
except Exception as e:
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Send typing indicator via bridge."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
await session.post(
|
||||
f"http://localhost:{self._bridge_port}/typing",
|
||||
json={"chatId": chat_id},
|
||||
timeout=aiohttp.ClientTimeout(total=5)
|
||||
)
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Get information about a WhatsApp chat."""
|
||||
if not self._running:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/chat/{chat_id}",
|
||||
timeout=aiohttp.ClientTimeout(total=10)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
data = await resp.json()
|
||||
return {
|
||||
"name": data.get("name", chat_id),
|
||||
"type": "group" if data.get("isGroup") else "dm",
|
||||
"participants": data.get("participants", []),
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {"name": chat_id, "type": "dm"}
|
||||
|
||||
async def _poll_messages(self) -> None:
|
||||
"""Poll the bridge for incoming messages."""
|
||||
try:
|
||||
import aiohttp
|
||||
except ImportError:
|
||||
print(f"[{self.name}] aiohttp not installed, message polling disabled")
|
||||
return
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
f"http://localhost:{self._bridge_port}/messages",
|
||||
timeout=aiohttp.ClientTimeout(total=30)
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
messages = await resp.json()
|
||||
for msg_data in messages:
|
||||
event = self._build_message_event(msg_data)
|
||||
if event:
|
||||
await self.handle_message(event)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Poll error: {e}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
await asyncio.sleep(1) # Poll interval
|
||||
|
||||
def _build_message_event(self, data: Dict[str, Any]) -> Optional[MessageEvent]:
|
||||
"""Build a MessageEvent from bridge message data."""
|
||||
try:
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
if data.get("hasMedia"):
|
||||
media_type = data.get("mediaType", "")
|
||||
if "image" in media_type:
|
||||
msg_type = MessageType.PHOTO
|
||||
elif "video" in media_type:
|
||||
msg_type = MessageType.VIDEO
|
||||
elif "audio" in media_type or "ptt" in media_type: # ptt = voice note
|
||||
msg_type = MessageType.VOICE
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
|
||||
# Determine chat type
|
||||
is_group = data.get("isGroup", False)
|
||||
chat_type = "group" if is_group else "dm"
|
||||
|
||||
# Build source
|
||||
source = self.build_source(
|
||||
chat_id=data.get("chatId", ""),
|
||||
chat_name=data.get("chatName"),
|
||||
chat_type=chat_type,
|
||||
user_id=data.get("senderId"),
|
||||
user_name=data.get("senderName"),
|
||||
)
|
||||
|
||||
return MessageEvent(
|
||||
text=data.get("body", ""),
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=data,
|
||||
message_id=data.get("messageId"),
|
||||
media_urls=data.get("mediaUrls", []),
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error building event: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# Note: A reference Node.js bridge script would be provided in scripts/whatsapp-bridge/
|
||||
# It would use whatsapp-web.js or Baileys to:
|
||||
# 1. Handle WhatsApp Web authentication (QR code)
|
||||
# 2. Listen for incoming messages
|
||||
# 3. Expose HTTP endpoints for send/receive/status
|
||||
666
gateway/run.py
Normal file
666
gateway/run.py
Normal file
@@ -0,0 +1,666 @@
|
||||
"""
|
||||
Gateway runner - entry point for messaging platform integrations.
|
||||
|
||||
This module provides:
|
||||
- start_gateway(): Start all configured platform adapters
|
||||
- GatewayRunner: Main class managing the gateway lifecycle
|
||||
|
||||
Usage:
|
||||
# Start the gateway
|
||||
python -m gateway.run
|
||||
|
||||
# Or from CLI
|
||||
python cli.py --gateway
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import signal
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional, Any, List
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
# Load environment variables from ~/.hermes/.env first
|
||||
from dotenv import load_dotenv
|
||||
_env_path = Path.home() / '.hermes' / '.env'
|
||||
if _env_path.exists():
|
||||
load_dotenv(_env_path)
|
||||
# Also try project .env as fallback
|
||||
load_dotenv()
|
||||
|
||||
# Gateway runs in quiet mode - suppress debug output and use cwd directly (no temp dirs)
|
||||
os.environ["HERMES_QUIET"] = "1"
|
||||
|
||||
# Set terminal working directory for messaging platforms
|
||||
# Uses MESSAGING_CWD if set, otherwise defaults to home directory
|
||||
# This is separate from CLI which uses the directory where `hermes` is run
|
||||
messaging_cwd = os.getenv("MESSAGING_CWD") or str(Path.home())
|
||||
os.environ["TERMINAL_CWD"] = messaging_cwd
|
||||
|
||||
from gateway.config import (
|
||||
Platform,
|
||||
GatewayConfig,
|
||||
load_gateway_config,
|
||||
)
|
||||
from gateway.session import (
|
||||
SessionStore,
|
||||
SessionSource,
|
||||
SessionContext,
|
||||
build_session_context,
|
||||
build_session_context_prompt,
|
||||
)
|
||||
from gateway.delivery import DeliveryRouter, DeliveryTarget
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent
|
||||
|
||||
|
||||
class GatewayRunner:
|
||||
"""
|
||||
Main gateway controller.
|
||||
|
||||
Manages the lifecycle of all platform adapters and routes
|
||||
messages to/from the agent.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[GatewayConfig] = None):
|
||||
self.config = config or load_gateway_config()
|
||||
self.adapters: Dict[Platform, BasePlatformAdapter] = {}
|
||||
self.session_store = SessionStore(self.config.sessions_dir, self.config)
|
||||
self.delivery_router = DeliveryRouter(self.config)
|
||||
self._running = False
|
||||
self._shutdown_event = asyncio.Event()
|
||||
|
||||
# Track running agents per session for interrupt support
|
||||
# Key: session_key, Value: AIAgent instance
|
||||
self._running_agents: Dict[str, Any] = {}
|
||||
self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt
|
||||
|
||||
async def start(self) -> bool:
|
||||
"""
|
||||
Start the gateway and all configured platform adapters.
|
||||
|
||||
Returns True if at least one adapter connected successfully.
|
||||
"""
|
||||
print("[gateway] Starting Hermes Gateway...")
|
||||
print(f"[gateway] Session storage: {self.config.sessions_dir}")
|
||||
|
||||
connected_count = 0
|
||||
|
||||
# Initialize and connect each configured platform
|
||||
for platform, platform_config in self.config.platforms.items():
|
||||
if not platform_config.enabled:
|
||||
continue
|
||||
|
||||
adapter = self._create_adapter(platform, platform_config)
|
||||
if not adapter:
|
||||
print(f"[gateway] No adapter available for {platform.value}")
|
||||
continue
|
||||
|
||||
# Set up message handler
|
||||
adapter.set_message_handler(self._handle_message)
|
||||
|
||||
# Try to connect
|
||||
print(f"[gateway] Connecting to {platform.value}...")
|
||||
try:
|
||||
success = await adapter.connect()
|
||||
if success:
|
||||
self.adapters[platform] = adapter
|
||||
connected_count += 1
|
||||
print(f"[gateway] ✓ {platform.value} connected")
|
||||
else:
|
||||
print(f"[gateway] ✗ {platform.value} failed to connect")
|
||||
except Exception as e:
|
||||
print(f"[gateway] ✗ {platform.value} error: {e}")
|
||||
|
||||
if connected_count == 0:
|
||||
print("[gateway] No platforms connected. Check your configuration.")
|
||||
return False
|
||||
|
||||
# Update delivery router with adapters
|
||||
self.delivery_router.adapters = self.adapters
|
||||
|
||||
self._running = True
|
||||
print(f"[gateway] Gateway running with {connected_count} platform(s)")
|
||||
print("[gateway] Press Ctrl+C to stop")
|
||||
|
||||
return True
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the gateway and disconnect all adapters."""
|
||||
print("[gateway] Stopping gateway...")
|
||||
self._running = False
|
||||
|
||||
for platform, adapter in self.adapters.items():
|
||||
try:
|
||||
await adapter.disconnect()
|
||||
print(f"[gateway] ✓ {platform.value} disconnected")
|
||||
except Exception as e:
|
||||
print(f"[gateway] ✗ {platform.value} disconnect error: {e}")
|
||||
|
||||
self.adapters.clear()
|
||||
self._shutdown_event.set()
|
||||
print("[gateway] Gateway stopped")
|
||||
|
||||
async def wait_for_shutdown(self) -> None:
|
||||
"""Wait for shutdown signal."""
|
||||
await self._shutdown_event.wait()
|
||||
|
||||
def _create_adapter(
|
||||
self,
|
||||
platform: Platform,
|
||||
config: Any
|
||||
) -> Optional[BasePlatformAdapter]:
|
||||
"""Create the appropriate adapter for a platform."""
|
||||
if platform == Platform.TELEGRAM:
|
||||
from gateway.platforms.telegram import TelegramAdapter, check_telegram_requirements
|
||||
if not check_telegram_requirements():
|
||||
print(f"[gateway] Telegram: python-telegram-bot not installed")
|
||||
return None
|
||||
return TelegramAdapter(config)
|
||||
|
||||
elif platform == Platform.DISCORD:
|
||||
from gateway.platforms.discord import DiscordAdapter, check_discord_requirements
|
||||
if not check_discord_requirements():
|
||||
print(f"[gateway] Discord: discord.py not installed")
|
||||
return None
|
||||
return DiscordAdapter(config)
|
||||
|
||||
elif platform == Platform.WHATSAPP:
|
||||
from gateway.platforms.whatsapp import WhatsAppAdapter, check_whatsapp_requirements
|
||||
if not check_whatsapp_requirements():
|
||||
print(f"[gateway] WhatsApp: Node.js not installed or bridge not configured")
|
||||
return None
|
||||
return WhatsAppAdapter(config)
|
||||
|
||||
return None
|
||||
|
||||
def _is_user_authorized(self, source: SessionSource) -> bool:
|
||||
"""
|
||||
Check if a user is authorized to use the bot.
|
||||
|
||||
Authorization is checked via environment variables:
|
||||
- GATEWAY_ALLOWED_USERS: Comma-separated list of user IDs (all platforms)
|
||||
- TELEGRAM_ALLOWED_USERS: Telegram-specific user IDs
|
||||
- DISCORD_ALLOWED_USERS: Discord-specific user IDs
|
||||
|
||||
If no allowlist is configured, all users are allowed (open access).
|
||||
"""
|
||||
user_id = source.user_id
|
||||
if not user_id:
|
||||
return False # Can't verify unknown users
|
||||
|
||||
# Check platform-specific allowlist first
|
||||
platform_env_map = {
|
||||
Platform.TELEGRAM: "TELEGRAM_ALLOWED_USERS",
|
||||
Platform.DISCORD: "DISCORD_ALLOWED_USERS",
|
||||
Platform.WHATSAPP: "WHATSAPP_ALLOWED_USERS",
|
||||
}
|
||||
|
||||
platform_allowlist = os.getenv(platform_env_map.get(source.platform, ""))
|
||||
global_allowlist = os.getenv("GATEWAY_ALLOWED_USERS", "")
|
||||
|
||||
# If no allowlists configured, allow all (backward compatible)
|
||||
if not platform_allowlist and not global_allowlist:
|
||||
return True
|
||||
|
||||
# Check if user is in any allowlist
|
||||
allowed_ids = set()
|
||||
if platform_allowlist:
|
||||
allowed_ids.update(uid.strip() for uid in platform_allowlist.split(","))
|
||||
if global_allowlist:
|
||||
allowed_ids.update(uid.strip() for uid in global_allowlist.split(","))
|
||||
|
||||
return user_id in allowed_ids
|
||||
|
||||
async def _handle_message(self, event: MessageEvent) -> Optional[str]:
|
||||
"""
|
||||
Handle an incoming message from any platform.
|
||||
|
||||
This is the core message processing pipeline:
|
||||
1. Check user authorization
|
||||
2. Check for commands (/new, /reset, etc.)
|
||||
3. Check for running agent and interrupt if needed
|
||||
4. Get or create session
|
||||
5. Build context for agent
|
||||
6. Run agent conversation
|
||||
7. Return response
|
||||
"""
|
||||
source = event.source
|
||||
|
||||
# Check if user is authorized
|
||||
if not self._is_user_authorized(source):
|
||||
print(f"[gateway] Unauthorized user: {source.user_id} ({source.user_name}) on {source.platform.value}")
|
||||
return None # Silently ignore unauthorized users
|
||||
|
||||
# Check for commands
|
||||
command = event.get_command()
|
||||
if command in ["new", "reset"]:
|
||||
return await self._handle_reset_command(event)
|
||||
|
||||
if command == "status":
|
||||
return await self._handle_status_command(event)
|
||||
|
||||
if command == "stop":
|
||||
return await self._handle_stop_command(event)
|
||||
|
||||
# Get or create session
|
||||
session_entry = self.session_store.get_or_create_session(source)
|
||||
session_key = session_entry.session_key
|
||||
|
||||
# Check if there's already a running agent for this session
|
||||
if session_key in self._running_agents:
|
||||
running_agent = self._running_agents[session_key]
|
||||
print(f"[gateway] ⚡ Interrupting running agent for session {session_key[:20]}...")
|
||||
running_agent.interrupt(event.text)
|
||||
# Store the new message to be processed after current agent finishes
|
||||
self._pending_messages[session_key] = event.text
|
||||
return None # Don't respond yet - let the interrupt handle it
|
||||
|
||||
# Build session context
|
||||
context = build_session_context(source, self.config, session_entry)
|
||||
|
||||
# Set environment variables for tools
|
||||
self._set_session_env(context)
|
||||
|
||||
# Build the context prompt to inject
|
||||
context_prompt = build_session_context_prompt(context)
|
||||
|
||||
# Load conversation history from transcript
|
||||
history = self.session_store.load_transcript(session_entry.session_id)
|
||||
|
||||
try:
|
||||
# Run the agent
|
||||
response = await self._run_agent(
|
||||
message=event.text,
|
||||
context_prompt=context_prompt,
|
||||
history=history,
|
||||
source=source,
|
||||
session_id=session_entry.session_id,
|
||||
session_key=session_key
|
||||
)
|
||||
|
||||
# Append to transcript
|
||||
self.session_store.append_to_transcript(
|
||||
session_entry.session_id,
|
||||
{"role": "user", "content": event.text, "timestamp": datetime.now().isoformat()}
|
||||
)
|
||||
self.session_store.append_to_transcript(
|
||||
session_entry.session_id,
|
||||
{"role": "assistant", "content": response, "timestamp": datetime.now().isoformat()}
|
||||
)
|
||||
|
||||
# Update session
|
||||
self.session_store.update_session(session_entry.session_key)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
print(f"[gateway] Agent error: {e}")
|
||||
return f"Sorry, I encountered an error: {str(e)}"
|
||||
finally:
|
||||
# Clear session env
|
||||
self._clear_session_env()
|
||||
|
||||
async def _handle_reset_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /new or /reset command."""
|
||||
source = event.source
|
||||
|
||||
# Get existing session key
|
||||
session_key = f"agent:main:{source.platform.value}:" + \
|
||||
(f"dm" if source.chat_type == "dm" else f"{source.chat_type}:{source.chat_id}")
|
||||
|
||||
# Reset the session
|
||||
new_entry = self.session_store.reset_session(session_key)
|
||||
|
||||
if new_entry:
|
||||
return "✨ Session reset! I've started fresh with no memory of our previous conversation."
|
||||
else:
|
||||
# No existing session, just create one
|
||||
self.session_store.get_or_create_session(source, force_new=True)
|
||||
return "✨ New session started!"
|
||||
|
||||
async def _handle_status_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /status command."""
|
||||
source = event.source
|
||||
session_entry = self.session_store.get_or_create_session(source)
|
||||
|
||||
connected_platforms = [p.value for p in self.adapters.keys()]
|
||||
|
||||
# Check if there's an active agent
|
||||
session_key = session_entry.session_key
|
||||
is_running = session_key in self._running_agents
|
||||
|
||||
lines = [
|
||||
"📊 **Hermes Gateway Status**",
|
||||
"",
|
||||
f"**Session ID:** `{session_entry.session_id[:12]}...`",
|
||||
f"**Created:** {session_entry.created_at.strftime('%Y-%m-%d %H:%M')}",
|
||||
f"**Last Activity:** {session_entry.updated_at.strftime('%Y-%m-%d %H:%M')}",
|
||||
f"**Tokens:** {session_entry.total_tokens:,}",
|
||||
f"**Agent Running:** {'Yes ⚡' if is_running else 'No'}",
|
||||
"",
|
||||
f"**Connected Platforms:** {', '.join(connected_platforms)}",
|
||||
]
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _handle_stop_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /stop command - interrupt a running agent."""
|
||||
source = event.source
|
||||
session_entry = self.session_store.get_or_create_session(source)
|
||||
session_key = session_entry.session_key
|
||||
|
||||
if session_key in self._running_agents:
|
||||
agent = self._running_agents[session_key]
|
||||
agent.interrupt()
|
||||
return "⚡ Stopping the current task... The agent will finish its current step and respond."
|
||||
else:
|
||||
return "No active task to stop."
|
||||
|
||||
def _set_session_env(self, context: SessionContext) -> None:
|
||||
"""Set environment variables for the current session."""
|
||||
os.environ["HERMES_SESSION_PLATFORM"] = context.source.platform.value
|
||||
os.environ["HERMES_SESSION_CHAT_ID"] = context.source.chat_id
|
||||
if context.source.chat_name:
|
||||
os.environ["HERMES_SESSION_CHAT_NAME"] = context.source.chat_name
|
||||
|
||||
def _clear_session_env(self) -> None:
|
||||
"""Clear session environment variables."""
|
||||
for var in ["HERMES_SESSION_PLATFORM", "HERMES_SESSION_CHAT_ID", "HERMES_SESSION_CHAT_NAME"]:
|
||||
if var in os.environ:
|
||||
del os.environ[var]
|
||||
|
||||
async def _run_agent(
|
||||
self,
|
||||
message: str,
|
||||
context_prompt: str,
|
||||
history: List[Dict[str, Any]],
|
||||
source: SessionSource,
|
||||
session_id: str,
|
||||
session_key: str = None
|
||||
) -> str:
|
||||
"""
|
||||
Run the agent with the given message and context.
|
||||
|
||||
This is run in a thread pool to not block the event loop.
|
||||
Supports interruption via new messages.
|
||||
"""
|
||||
from run_agent import AIAgent
|
||||
import queue
|
||||
|
||||
# Determine toolset based on platform
|
||||
toolset_map = {
|
||||
Platform.LOCAL: "hermes-cli",
|
||||
Platform.TELEGRAM: "hermes-telegram",
|
||||
Platform.DISCORD: "hermes-discord",
|
||||
Platform.WHATSAPP: "hermes-whatsapp",
|
||||
}
|
||||
toolset = toolset_map.get(source.platform, "hermes-telegram")
|
||||
|
||||
# Check if tool progress notifications are enabled
|
||||
tool_progress_enabled = os.getenv("HERMES_TOOL_PROGRESS", "").lower() in ("1", "true", "yes")
|
||||
progress_mode = os.getenv("HERMES_TOOL_PROGRESS_MODE", "new") # "all" or "new" (only new tools)
|
||||
|
||||
# Queue for progress messages (thread-safe)
|
||||
progress_queue = queue.Queue() if tool_progress_enabled else None
|
||||
last_tool = [None] # Mutable container for tracking in closure
|
||||
|
||||
def progress_callback(tool_name: str, preview: str = None):
|
||||
"""Callback invoked by agent when a tool is called."""
|
||||
if not progress_queue:
|
||||
return
|
||||
|
||||
# "new" mode: only report when tool changes
|
||||
if progress_mode == "new" and tool_name == last_tool[0]:
|
||||
return
|
||||
last_tool[0] = tool_name
|
||||
|
||||
# Build progress message
|
||||
tool_emojis = {
|
||||
"terminal": "💻",
|
||||
"web_search": "🔍",
|
||||
"web_extract": "📄",
|
||||
"read_file": "📖",
|
||||
"write_file": "✍️",
|
||||
"list_directory": "📂",
|
||||
"image_generate": "🎨",
|
||||
"browser_navigate": "🌐",
|
||||
"browser_click": "👆",
|
||||
"moa_query": "🧠",
|
||||
}
|
||||
emoji = tool_emojis.get(tool_name, "⚙️")
|
||||
|
||||
if tool_name == "terminal" and preview:
|
||||
msg = f"{emoji} `{preview}`..."
|
||||
else:
|
||||
msg = f"{emoji} {tool_name}..."
|
||||
|
||||
progress_queue.put(msg)
|
||||
|
||||
# Background task to send progress messages
|
||||
async def send_progress_messages():
|
||||
if not progress_queue:
|
||||
return
|
||||
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if not adapter:
|
||||
return
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Non-blocking check with small timeout
|
||||
msg = progress_queue.get_nowait()
|
||||
await adapter.send(chat_id=source.chat_id, content=msg)
|
||||
# Restore typing indicator after sending progress message
|
||||
await asyncio.sleep(0.3)
|
||||
await adapter.send_typing(source.chat_id)
|
||||
except queue.Empty:
|
||||
await asyncio.sleep(0.3) # Check again soon
|
||||
except asyncio.CancelledError:
|
||||
# Drain remaining messages
|
||||
while not progress_queue.empty():
|
||||
try:
|
||||
msg = progress_queue.get_nowait()
|
||||
await adapter.send(chat_id=source.chat_id, content=msg)
|
||||
except:
|
||||
break
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"[Gateway] Progress message error: {e}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# We need to share the agent instance for interrupt support
|
||||
agent_holder = [None] # Mutable container for the agent instance
|
||||
result_holder = [None] # Mutable container for the result
|
||||
|
||||
def run_sync():
|
||||
# Read from env var or use default (same as CLI)
|
||||
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "60"))
|
||||
|
||||
agent = AIAgent(
|
||||
model=os.getenv("HERMES_MODEL", "anthropic/claude-sonnet-4"),
|
||||
max_iterations=max_iterations,
|
||||
quiet_mode=True,
|
||||
enabled_toolsets=[toolset],
|
||||
ephemeral_system_prompt=context_prompt,
|
||||
session_id=session_id,
|
||||
tool_progress_callback=progress_callback if tool_progress_enabled else None,
|
||||
)
|
||||
|
||||
# Store agent reference for interrupt support
|
||||
agent_holder[0] = agent
|
||||
|
||||
# Convert transcript history to agent format
|
||||
# Transcript has timestamps; agent expects {"role": ..., "content": ...}
|
||||
agent_history = []
|
||||
for msg in history:
|
||||
role = msg.get("role")
|
||||
content = msg.get("content")
|
||||
if role and content:
|
||||
agent_history.append({"role": role, "content": content})
|
||||
|
||||
result = agent.run_conversation(message, conversation_history=agent_history)
|
||||
result_holder[0] = result
|
||||
|
||||
# Return final response, or a message if something went wrong
|
||||
final_response = result.get("final_response")
|
||||
if final_response:
|
||||
return final_response
|
||||
elif result.get("error"):
|
||||
# Agent couldn't recover - show the error
|
||||
return f"⚠️ {result['error']}"
|
||||
else:
|
||||
return "(No response generated)"
|
||||
|
||||
# Start progress message sender if enabled
|
||||
progress_task = None
|
||||
if tool_progress_enabled:
|
||||
progress_task = asyncio.create_task(send_progress_messages())
|
||||
|
||||
# Track this agent as running for this session (for interrupt support)
|
||||
# We do this in a callback after the agent is created
|
||||
async def track_agent():
|
||||
# Wait for agent to be created
|
||||
while agent_holder[0] is None:
|
||||
await asyncio.sleep(0.05)
|
||||
if session_key:
|
||||
self._running_agents[session_key] = agent_holder[0]
|
||||
|
||||
tracking_task = asyncio.create_task(track_agent())
|
||||
|
||||
# Monitor for interrupts from the adapter (new messages arriving)
|
||||
async def monitor_for_interrupt():
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if not adapter:
|
||||
return
|
||||
|
||||
chat_id = source.chat_id
|
||||
while True:
|
||||
await asyncio.sleep(0.2) # Check every 200ms
|
||||
# Check if adapter has a pending interrupt for this session
|
||||
if hasattr(adapter, 'has_pending_interrupt') and adapter.has_pending_interrupt(chat_id):
|
||||
agent = agent_holder[0]
|
||||
if agent:
|
||||
pending_event = adapter.get_pending_message(chat_id)
|
||||
pending_text = pending_event.text if pending_event else None
|
||||
print(f"[gateway] ⚡ Interrupt detected from adapter, signaling agent...")
|
||||
agent.interrupt(pending_text)
|
||||
break
|
||||
|
||||
interrupt_monitor = asyncio.create_task(monitor_for_interrupt())
|
||||
|
||||
try:
|
||||
# Run in thread pool to not block
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(None, run_sync)
|
||||
|
||||
# Check if we were interrupted and have a pending message
|
||||
result = result_holder[0]
|
||||
adapter = self.adapters.get(source.platform)
|
||||
|
||||
# Get pending message from adapter if interrupted
|
||||
pending = None
|
||||
if result and result.get("interrupted") and adapter:
|
||||
pending_event = adapter.get_pending_message(source.chat_id)
|
||||
if pending_event:
|
||||
pending = pending_event.text
|
||||
elif result.get("interrupt_message"):
|
||||
pending = result.get("interrupt_message")
|
||||
|
||||
if pending:
|
||||
print(f"[gateway] 📨 Processing interrupted message: '{pending[:40]}...'")
|
||||
# Add an indicator to the response
|
||||
if response:
|
||||
response = response + "\n\n---\n_[Interrupted - processing your new message]_"
|
||||
|
||||
# Send the interrupted response first
|
||||
if adapter and response:
|
||||
await adapter.send(chat_id=source.chat_id, content=response)
|
||||
|
||||
# Now process the pending message with updated history
|
||||
updated_history = result.get("messages", history)
|
||||
return await self._run_agent(
|
||||
message=pending,
|
||||
context_prompt=context_prompt,
|
||||
history=updated_history,
|
||||
source=source,
|
||||
session_id=session_id,
|
||||
session_key=session_key
|
||||
)
|
||||
finally:
|
||||
# Stop progress sender and interrupt monitor
|
||||
if progress_task:
|
||||
progress_task.cancel()
|
||||
interrupt_monitor.cancel()
|
||||
|
||||
# Clean up tracking
|
||||
tracking_task.cancel()
|
||||
if session_key and session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
|
||||
# Wait for cancelled tasks
|
||||
for task in [progress_task, interrupt_monitor, tracking_task]:
|
||||
if task:
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def start_gateway(config: Optional[GatewayConfig] = None) -> None:
|
||||
"""
|
||||
Start the gateway and run until interrupted.
|
||||
|
||||
This is the main entry point for running the gateway.
|
||||
"""
|
||||
runner = GatewayRunner(config)
|
||||
|
||||
# Set up signal handlers
|
||||
def signal_handler():
|
||||
asyncio.create_task(runner.stop())
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||
try:
|
||||
loop.add_signal_handler(sig, signal_handler)
|
||||
except NotImplementedError:
|
||||
# Windows doesn't support add_signal_handler
|
||||
pass
|
||||
|
||||
# Start the gateway
|
||||
success = await runner.start()
|
||||
if not success:
|
||||
return
|
||||
|
||||
# Wait for shutdown
|
||||
await runner.wait_for_shutdown()
|
||||
|
||||
|
||||
def main():
|
||||
"""CLI entry point for the gateway."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Hermes Gateway - Multi-platform messaging")
|
||||
parser.add_argument("--config", "-c", help="Path to gateway config file")
|
||||
parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
config = None
|
||||
if args.config:
|
||||
import json
|
||||
with open(args.config) as f:
|
||||
data = json.load(f)
|
||||
config = GatewayConfig.from_dict(data)
|
||||
|
||||
# Run the gateway
|
||||
asyncio.run(start_gateway(config))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
522
gateway/session.py
Normal file
522
gateway/session.py
Normal file
@@ -0,0 +1,522 @@
|
||||
"""
|
||||
Session management for the gateway.
|
||||
|
||||
Handles:
|
||||
- Session context tracking (where messages come from)
|
||||
- Session storage (conversations persisted to disk)
|
||||
- Reset policy evaluation (when to start fresh)
|
||||
- Dynamic system prompt injection (agent knows its context)
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from .config import (
|
||||
Platform,
|
||||
GatewayConfig,
|
||||
SessionResetPolicy,
|
||||
HomeChannel,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionSource:
|
||||
"""
|
||||
Describes where a message originated from.
|
||||
|
||||
This information is used to:
|
||||
1. Route responses back to the right place
|
||||
2. Inject context into the system prompt
|
||||
3. Track origin for cron job delivery
|
||||
"""
|
||||
platform: Platform
|
||||
chat_id: str
|
||||
chat_name: Optional[str] = None
|
||||
chat_type: str = "dm" # "dm", "group", "channel", "thread"
|
||||
user_id: Optional[str] = None
|
||||
user_name: Optional[str] = None
|
||||
thread_id: Optional[str] = None # For forum topics, Discord threads, etc.
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Human-readable description of the source."""
|
||||
if self.platform == Platform.LOCAL:
|
||||
return "CLI terminal"
|
||||
|
||||
parts = []
|
||||
if self.chat_type == "dm":
|
||||
parts.append(f"DM with {self.user_name or self.user_id or 'user'}")
|
||||
elif self.chat_type == "group":
|
||||
parts.append(f"group: {self.chat_name or self.chat_id}")
|
||||
elif self.chat_type == "channel":
|
||||
parts.append(f"channel: {self.chat_name or self.chat_id}")
|
||||
else:
|
||||
parts.append(self.chat_name or self.chat_id)
|
||||
|
||||
if self.thread_id:
|
||||
parts.append(f"thread: {self.thread_id}")
|
||||
|
||||
return ", ".join(parts)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"platform": self.platform.value,
|
||||
"chat_id": self.chat_id,
|
||||
"chat_name": self.chat_name,
|
||||
"chat_type": self.chat_type,
|
||||
"user_id": self.user_id,
|
||||
"user_name": self.user_name,
|
||||
"thread_id": self.thread_id,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SessionSource":
|
||||
return cls(
|
||||
platform=Platform(data["platform"]),
|
||||
chat_id=str(data["chat_id"]),
|
||||
chat_name=data.get("chat_name"),
|
||||
chat_type=data.get("chat_type", "dm"),
|
||||
user_id=data.get("user_id"),
|
||||
user_name=data.get("user_name"),
|
||||
thread_id=data.get("thread_id"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def local_cli(cls) -> "SessionSource":
|
||||
"""Create a source representing the local CLI."""
|
||||
return cls(
|
||||
platform=Platform.LOCAL,
|
||||
chat_id="cli",
|
||||
chat_name="CLI terminal",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionContext:
|
||||
"""
|
||||
Full context for a session, used for dynamic system prompt injection.
|
||||
|
||||
The agent receives this information to understand:
|
||||
- Where messages are coming from
|
||||
- What platforms are available
|
||||
- Where it can deliver scheduled task outputs
|
||||
"""
|
||||
source: SessionSource
|
||||
connected_platforms: List[Platform]
|
||||
home_channels: Dict[Platform, HomeChannel]
|
||||
|
||||
# Session metadata
|
||||
session_key: str = ""
|
||||
session_id: str = ""
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"source": self.source.to_dict(),
|
||||
"connected_platforms": [p.value for p in self.connected_platforms],
|
||||
"home_channels": {
|
||||
p.value: hc.to_dict() for p, hc in self.home_channels.items()
|
||||
},
|
||||
"session_key": self.session_key,
|
||||
"session_id": self.session_id,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
def build_session_context_prompt(context: SessionContext) -> str:
|
||||
"""
|
||||
Build the dynamic system prompt section that tells the agent about its context.
|
||||
|
||||
This is injected into the system prompt so the agent knows:
|
||||
- Where messages are coming from
|
||||
- What platforms are connected
|
||||
- Where it can deliver scheduled task outputs
|
||||
"""
|
||||
lines = [
|
||||
"## Current Session Context",
|
||||
"",
|
||||
]
|
||||
|
||||
# Source info
|
||||
platform_name = context.source.platform.value.title()
|
||||
if context.source.platform == Platform.LOCAL:
|
||||
lines.append(f"**Source:** {platform_name} (the machine running this agent)")
|
||||
else:
|
||||
lines.append(f"**Source:** {platform_name} ({context.source.description})")
|
||||
|
||||
# Connected platforms
|
||||
platforms_list = ["local (files on this machine)"]
|
||||
for p in context.connected_platforms:
|
||||
if p != Platform.LOCAL:
|
||||
platforms_list.append(f"{p.value}: Connected ✓")
|
||||
|
||||
lines.append(f"**Connected Platforms:** {', '.join(platforms_list)}")
|
||||
|
||||
# Home channels
|
||||
if context.home_channels:
|
||||
lines.append("")
|
||||
lines.append("**Home Channels (default destinations):**")
|
||||
for platform, home in context.home_channels.items():
|
||||
lines.append(f" - {platform.value}: {home.name} (ID: {home.chat_id})")
|
||||
|
||||
# Delivery options for scheduled tasks
|
||||
lines.append("")
|
||||
lines.append("**Delivery options for scheduled tasks:**")
|
||||
|
||||
# Origin delivery
|
||||
if context.source.platform == Platform.LOCAL:
|
||||
lines.append("- `\"origin\"` → Local output (saved to files)")
|
||||
else:
|
||||
lines.append(f"- `\"origin\"` → Back to this chat ({context.source.chat_name or context.source.chat_id})")
|
||||
|
||||
# Local always available
|
||||
lines.append("- `\"local\"` → Save to local files only (~/.hermes/cron/output/)")
|
||||
|
||||
# Platform home channels
|
||||
for platform, home in context.home_channels.items():
|
||||
lines.append(f"- `\"{platform.value}\"` → Home channel ({home.name})")
|
||||
|
||||
# Note about explicit targeting
|
||||
lines.append("")
|
||||
lines.append("*For explicit targeting, use `\"platform:chat_id\"` format if the user provides a specific chat ID.*")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionEntry:
|
||||
"""
|
||||
Entry in the session store.
|
||||
|
||||
Maps a session key to its current session ID and metadata.
|
||||
"""
|
||||
session_key: str
|
||||
session_id: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# Origin metadata for delivery routing
|
||||
origin: Optional[SessionSource] = None
|
||||
|
||||
# Display metadata
|
||||
display_name: Optional[str] = None
|
||||
platform: Optional[Platform] = None
|
||||
chat_type: str = "dm"
|
||||
|
||||
# Token tracking
|
||||
input_tokens: int = 0
|
||||
output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result = {
|
||||
"session_key": self.session_key,
|
||||
"session_id": self.session_id,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
"display_name": self.display_name,
|
||||
"platform": self.platform.value if self.platform else None,
|
||||
"chat_type": self.chat_type,
|
||||
"input_tokens": self.input_tokens,
|
||||
"output_tokens": self.output_tokens,
|
||||
"total_tokens": self.total_tokens,
|
||||
}
|
||||
if self.origin:
|
||||
result["origin"] = self.origin.to_dict()
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SessionEntry":
|
||||
origin = None
|
||||
if "origin" in data and data["origin"]:
|
||||
origin = SessionSource.from_dict(data["origin"])
|
||||
|
||||
platform = None
|
||||
if data.get("platform"):
|
||||
try:
|
||||
platform = Platform(data["platform"])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return cls(
|
||||
session_key=data["session_key"],
|
||||
session_id=data["session_id"],
|
||||
created_at=datetime.fromisoformat(data["created_at"]),
|
||||
updated_at=datetime.fromisoformat(data["updated_at"]),
|
||||
origin=origin,
|
||||
display_name=data.get("display_name"),
|
||||
platform=platform,
|
||||
chat_type=data.get("chat_type", "dm"),
|
||||
input_tokens=data.get("input_tokens", 0),
|
||||
output_tokens=data.get("output_tokens", 0),
|
||||
total_tokens=data.get("total_tokens", 0),
|
||||
)
|
||||
|
||||
|
||||
class SessionStore:
|
||||
"""
|
||||
Manages session storage and retrieval.
|
||||
|
||||
Sessions are stored in:
|
||||
- sessions.json: Index mapping session keys to session IDs
|
||||
- {session_id}.jsonl: Conversation transcripts
|
||||
"""
|
||||
|
||||
def __init__(self, sessions_dir: Path, config: GatewayConfig):
|
||||
self.sessions_dir = sessions_dir
|
||||
self.config = config
|
||||
self._entries: Dict[str, SessionEntry] = {}
|
||||
self._loaded = False
|
||||
|
||||
def _ensure_loaded(self) -> None:
|
||||
"""Load sessions from disk if not already loaded."""
|
||||
if self._loaded:
|
||||
return
|
||||
|
||||
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
||||
sessions_file = self.sessions_dir / "sessions.json"
|
||||
|
||||
if sessions_file.exists():
|
||||
try:
|
||||
with open(sessions_file, "r") as f:
|
||||
data = json.load(f)
|
||||
for key, entry_data in data.items():
|
||||
self._entries[key] = SessionEntry.from_dict(entry_data)
|
||||
except Exception as e:
|
||||
print(f"[gateway] Warning: Failed to load sessions: {e}")
|
||||
|
||||
self._loaded = True
|
||||
|
||||
def _save(self) -> None:
|
||||
"""Save sessions index to disk."""
|
||||
self.sessions_dir.mkdir(parents=True, exist_ok=True)
|
||||
sessions_file = self.sessions_dir / "sessions.json"
|
||||
|
||||
data = {key: entry.to_dict() for key, entry in self._entries.items()}
|
||||
with open(sessions_file, "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
def _generate_session_key(self, source: SessionSource) -> str:
|
||||
"""Generate a session key from a source."""
|
||||
platform = source.platform.value
|
||||
|
||||
if source.chat_type == "dm":
|
||||
# DMs share the main session per platform
|
||||
return f"agent:main:{platform}:dm"
|
||||
else:
|
||||
# Groups/channels get their own keys
|
||||
return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}"
|
||||
|
||||
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool:
|
||||
"""
|
||||
Check if a session should be reset based on policy.
|
||||
|
||||
Returns True if the session is stale and should start fresh.
|
||||
"""
|
||||
policy = self.config.get_reset_policy(
|
||||
platform=source.platform,
|
||||
session_type=source.chat_type
|
||||
)
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
# Check idle timeout
|
||||
if policy.mode in ("idle", "both"):
|
||||
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
|
||||
if now > idle_deadline:
|
||||
return True
|
||||
|
||||
# Check daily reset
|
||||
if policy.mode in ("daily", "both"):
|
||||
# Find the most recent reset boundary
|
||||
today_reset = now.replace(
|
||||
hour=policy.at_hour,
|
||||
minute=0,
|
||||
second=0,
|
||||
microsecond=0
|
||||
)
|
||||
if now.hour < policy.at_hour:
|
||||
# Reset boundary was yesterday
|
||||
today_reset -= timedelta(days=1)
|
||||
|
||||
if entry.updated_at < today_reset:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_or_create_session(
|
||||
self,
|
||||
source: SessionSource,
|
||||
force_new: bool = False
|
||||
) -> SessionEntry:
|
||||
"""
|
||||
Get an existing session or create a new one.
|
||||
|
||||
Evaluates reset policy to determine if the existing session is stale.
|
||||
"""
|
||||
self._ensure_loaded()
|
||||
|
||||
session_key = self._generate_session_key(source)
|
||||
now = datetime.now()
|
||||
|
||||
# Check for existing session
|
||||
if session_key in self._entries and not force_new:
|
||||
entry = self._entries[session_key]
|
||||
|
||||
# Check if session should be reset
|
||||
if not self._should_reset(entry, source):
|
||||
# Update timestamp and return existing
|
||||
entry.updated_at = now
|
||||
self._save()
|
||||
return entry
|
||||
|
||||
# Create new session
|
||||
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id=session_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
origin=source,
|
||||
display_name=source.chat_name,
|
||||
platform=source.platform,
|
||||
chat_type=source.chat_type,
|
||||
)
|
||||
|
||||
self._entries[session_key] = entry
|
||||
self._save()
|
||||
|
||||
return entry
|
||||
|
||||
def update_session(
|
||||
self,
|
||||
session_key: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0
|
||||
) -> None:
|
||||
"""Update a session's metadata after an interaction."""
|
||||
self._ensure_loaded()
|
||||
|
||||
if session_key in self._entries:
|
||||
entry = self._entries[session_key]
|
||||
entry.updated_at = datetime.now()
|
||||
entry.input_tokens += input_tokens
|
||||
entry.output_tokens += output_tokens
|
||||
entry.total_tokens = entry.input_tokens + entry.output_tokens
|
||||
self._save()
|
||||
|
||||
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
|
||||
"""Force reset a session, creating a new session ID."""
|
||||
self._ensure_loaded()
|
||||
|
||||
if session_key not in self._entries:
|
||||
return None
|
||||
|
||||
old_entry = self._entries[session_key]
|
||||
now = datetime.now()
|
||||
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
new_entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id=session_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
origin=old_entry.origin,
|
||||
display_name=old_entry.display_name,
|
||||
platform=old_entry.platform,
|
||||
chat_type=old_entry.chat_type,
|
||||
)
|
||||
|
||||
self._entries[session_key] = new_entry
|
||||
self._save()
|
||||
|
||||
return new_entry
|
||||
|
||||
def list_sessions(self, active_minutes: Optional[int] = None) -> List[SessionEntry]:
|
||||
"""
|
||||
List all sessions, optionally filtered by activity.
|
||||
|
||||
Args:
|
||||
active_minutes: If provided, only return sessions updated within this many minutes
|
||||
"""
|
||||
self._ensure_loaded()
|
||||
|
||||
entries = list(self._entries.values())
|
||||
|
||||
if active_minutes is not None:
|
||||
cutoff = datetime.now() - timedelta(minutes=active_minutes)
|
||||
entries = [e for e in entries if e.updated_at >= cutoff]
|
||||
|
||||
# Sort by most recently updated
|
||||
entries.sort(key=lambda e: e.updated_at, reverse=True)
|
||||
|
||||
return entries
|
||||
|
||||
def get_transcript_path(self, session_id: str) -> Path:
|
||||
"""Get the path to a session's transcript file."""
|
||||
return self.sessions_dir / f"{session_id}.jsonl"
|
||||
|
||||
def append_to_transcript(self, session_id: str, message: Dict[str, Any]) -> None:
|
||||
"""Append a message to a session's transcript."""
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
|
||||
with open(transcript_path, "a") as f:
|
||||
f.write(json.dumps(message, ensure_ascii=False) + "\n")
|
||||
|
||||
def load_transcript(self, session_id: str) -> List[Dict[str, Any]]:
|
||||
"""Load all messages from a session's transcript."""
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
|
||||
if not transcript_path.exists():
|
||||
return []
|
||||
|
||||
messages = []
|
||||
with open(transcript_path, "r") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
messages.append(json.loads(line))
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def build_session_context(
|
||||
source: SessionSource,
|
||||
config: GatewayConfig,
|
||||
session_entry: Optional[SessionEntry] = None
|
||||
) -> SessionContext:
|
||||
"""
|
||||
Build a full session context from a source and config.
|
||||
|
||||
This is used to inject context into the agent's system prompt.
|
||||
"""
|
||||
connected = config.get_connected_platforms()
|
||||
|
||||
home_channels = {}
|
||||
for platform in connected:
|
||||
home = config.get_home_channel(platform)
|
||||
if home:
|
||||
home_channels[platform] = home
|
||||
|
||||
context = SessionContext(
|
||||
source=source,
|
||||
connected_platforms=connected,
|
||||
home_channels=home_channels,
|
||||
)
|
||||
|
||||
if session_entry:
|
||||
context.session_key = session_entry.session_key
|
||||
context.session_id = session_entry.session_id
|
||||
context.created_at = session_entry.created_at
|
||||
context.updated_at = session_entry.updated_at
|
||||
|
||||
return context
|
||||
12
hermes
Executable file
12
hermes
Executable file
@@ -0,0 +1,12 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Hermes Agent CLI Launcher
|
||||
|
||||
This is a convenience wrapper to launch the Hermes CLI.
|
||||
Usage: ./hermes [options]
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
from cli import main
|
||||
import fire
|
||||
fire.Fire(main)
|
||||
868
hermes_agent.egg-info/PKG-INFO
Normal file
868
hermes_agent.egg-info/PKG-INFO
Normal file
@@ -0,0 +1,868 @@
|
||||
Metadata-Version: 2.4
|
||||
Name: hermes-agent
|
||||
Version: 0.1.0
|
||||
Summary: AI agent with advanced tool-calling and toolsets
|
||||
Author: Nous Research
|
||||
License: MIT
|
||||
Requires-Python: >=3.10
|
||||
Description-Content-Type: text/markdown
|
||||
Requires-Dist: openai
|
||||
Requires-Dist: python-dotenv
|
||||
Requires-Dist: fire
|
||||
Requires-Dist: httpx
|
||||
Requires-Dist: rich
|
||||
Requires-Dist: tenacity
|
||||
Requires-Dist: pyyaml
|
||||
Requires-Dist: requests
|
||||
Requires-Dist: jinja2
|
||||
Requires-Dist: pydantic>=2.0
|
||||
Requires-Dist: firecrawl-py
|
||||
Requires-Dist: fal-client
|
||||
Requires-Dist: litellm>=1.75.5
|
||||
Requires-Dist: typer
|
||||
Requires-Dist: platformdirs
|
||||
Provides-Extra: modal
|
||||
Requires-Dist: modal; extra == "modal"
|
||||
Requires-Dist: boto3; extra == "modal"
|
||||
Provides-Extra: dev
|
||||
Requires-Dist: pytest; extra == "dev"
|
||||
Requires-Dist: pytest-asyncio; extra == "dev"
|
||||
Provides-Extra: messaging
|
||||
Requires-Dist: python-telegram-bot>=20.0; extra == "messaging"
|
||||
Requires-Dist: discord.py>=2.0; extra == "messaging"
|
||||
Provides-Extra: cron
|
||||
Requires-Dist: croniter; extra == "cron"
|
||||
Provides-Extra: all
|
||||
Requires-Dist: croniter; extra == "all"
|
||||
Requires-Dist: python-telegram-bot>=20.0; extra == "all"
|
||||
Requires-Dist: discord.py>=2.0; extra == "all"
|
||||
|
||||
# Hermes Agent
|
||||
|
||||
An AI agent with advanced tool-calling capabilities, featuring a flexible toolsets system for organizing and managing tools.
|
||||
|
||||
## Features
|
||||
|
||||
- **Interactive CLI**: Beautiful terminal interface with animated feedback, personalities, and session management
|
||||
- **Messaging Gateway**: Connect to Telegram, Discord, and WhatsApp for conversational AI anywhere
|
||||
- **Web Tools**: Search, extract content, and crawl websites
|
||||
- **Terminal Tools**: Execute commands via local, Docker, Singularity, Modal, or SSH backends
|
||||
- **Browser Tools**: Automate web browsers to navigate, click, type, and extract content
|
||||
- **Vision Tools**: Analyze images from URLs
|
||||
- **Reasoning Tools**: Advanced multi-model reasoning (Mixture of Agents)
|
||||
- **Creative Tools**: Generate images from text prompts
|
||||
- **Skills Tools**: On-demand knowledge documents with progressive disclosure
|
||||
- **Toolsets System**: Organize tools into logical groups for different scenarios
|
||||
- **Scheduled Tasks**: Cron jobs for automated agent tasks with delivery to platforms
|
||||
- **Context Compression**: Automatic summarization when approaching context limits
|
||||
- **Batch Processing**: Process datasets in parallel with checkpointing and statistics tracking
|
||||
- **Ephemeral System Prompts**: Guide model behavior without polluting training datasets
|
||||
|
||||
## Installation
|
||||
|
||||
### Quick Install (Recommended)
|
||||
|
||||
**Linux/macOS:**
|
||||
```bash
|
||||
curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash
|
||||
```
|
||||
|
||||
**Windows (PowerShell):**
|
||||
```powershell
|
||||
irm https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.ps1 | iex
|
||||
```
|
||||
|
||||
This installer will:
|
||||
- Clone the repository to `~/.hermes-agent`
|
||||
- Create a virtual environment and install dependencies
|
||||
- Set up the `hermes` command in your PATH
|
||||
- Run an interactive setup wizard to configure API keys
|
||||
|
||||
### Manual Installation
|
||||
|
||||
If you prefer to install manually:
|
||||
|
||||
```bash
|
||||
# Clone with submodules
|
||||
git clone --recurse-submodules https://github.com/NousResearch/Hermes-Agent.git
|
||||
cd Hermes-Agent
|
||||
|
||||
# Run the setup script
|
||||
./setup-hermes.sh
|
||||
```
|
||||
|
||||
Or step-by-step:
|
||||
|
||||
```bash
|
||||
# Create and activate virtual environment
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate # Windows: venv\Scripts\activate
|
||||
|
||||
# Install in editable mode with all extras
|
||||
pip install -e ".[all]"
|
||||
|
||||
# Or install dependencies manually
|
||||
pip install -r requirements.txt
|
||||
pip install -e ./mini-swe-agent
|
||||
|
||||
# Copy and configure environment
|
||||
cp .env.example .env
|
||||
# Edit .env with your API keys
|
||||
|
||||
# Run the setup wizard
|
||||
hermes setup
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
Once installed, the `hermes` command is your main entry point:
|
||||
|
||||
```bash
|
||||
hermes # Interactive chat (default)
|
||||
hermes chat # Same as above
|
||||
hermes chat -q "Hello" # Single query, then exit
|
||||
hermes setup # Configure API keys and settings
|
||||
hermes status # Show configuration status
|
||||
hermes doctor # Diagnose issues
|
||||
hermes gateway # Start messaging gateway (Telegram/Discord/WhatsApp)
|
||||
hermes cron daemon # Run cron job scheduler
|
||||
hermes version # Show version info
|
||||
```
|
||||
|
||||
**Legacy `./hermes` script:**
|
||||
```bash
|
||||
# The old CLI script still works:
|
||||
./hermes
|
||||
|
||||
# Or with options:
|
||||
./hermes --model "anthropic/claude-sonnet-4" --toolsets "web,terminal"
|
||||
```
|
||||
|
||||
The CLI provides:
|
||||
- Animated spinners during thinking and tool execution
|
||||
- Kawaii-style feedback messages
|
||||
- `/commands` for configuration, history, and session management
|
||||
- Customizable personalities (`/personality kawaii`, `/personality pirate`, etc.)
|
||||
- Persistent configuration via `cli-config.yaml`
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables
|
||||
```bash
|
||||
# Copy the example environment file
|
||||
cp .env.example .env
|
||||
|
||||
# Edit .env and add your API keys
|
||||
nano .env # or use your preferred editor
|
||||
```
|
||||
|
||||
**Required API Keys:**
|
||||
- `OPENROUTER_API_KEY` - LLM access via OpenRouter (get at: https://openrouter.ai/keys)
|
||||
- `FIRECRAWL_API_KEY` - Web tools (get at: https://firecrawl.dev/)
|
||||
- `NOUS_API_KEY` - Vision & reasoning tools (get at: https://inference-api.nousresearch.com/)
|
||||
- `FAL_KEY` - Image generation (get at: https://fal.ai/)
|
||||
|
||||
**Optional API Keys (for specific features):**
|
||||
- `BROWSERBASE_API_KEY` - Browser automation (get at: https://browserbase.com/)
|
||||
- `BROWSERBASE_PROJECT_ID` - From Browserbase dashboard
|
||||
- `MORPH_API_KEY` - For legacy Hecate terminal backend (get at: https://morph.so/)
|
||||
|
||||
### 4. Configure Terminal Backend
|
||||
|
||||
The terminal tool uses **mini-swe-agent** environments. Configure in `.env` or `cli-config.yaml`:
|
||||
|
||||
```bash
|
||||
# Backend: "local", "docker", "singularity", "modal", or "ssh"
|
||||
TERMINAL_ENV=local # Default: runs on host machine (no isolation)
|
||||
TERMINAL_ENV=ssh # Remote execution via SSH (agent code stays local)
|
||||
TERMINAL_ENV=singularity # Recommended for HPC: Apptainer/Singularity containers
|
||||
TERMINAL_ENV=docker # Isolated Docker containers
|
||||
TERMINAL_ENV=modal # Cloud execution via Modal
|
||||
|
||||
# Container image (for docker/singularity/modal backends)
|
||||
TERMINAL_DOCKER_IMAGE=python:3.11-slim
|
||||
TERMINAL_SINGULARITY_IMAGE=docker://python:3.11-slim
|
||||
TERMINAL_TIMEOUT=60
|
||||
|
||||
# SSH backend (for ssh)
|
||||
TERMINAL_SSH_HOST=my-server.example.com
|
||||
TERMINAL_SSH_USER=myuser
|
||||
TERMINAL_SSH_KEY=~/.ssh/id_rsa # Optional, uses ssh-agent if not set
|
||||
```
|
||||
|
||||
**Backend Requirements:**
|
||||
- **local**: No extra setup (runs directly on your machine, no isolation)
|
||||
- **ssh**: SSH access to remote machine (great for sandboxing - agent can't touch its own code)
|
||||
- **singularity**: Requires Apptainer or Singularity installed (common on HPC clusters, no root needed)
|
||||
- **docker**: Requires Docker installed and user in `docker` group
|
||||
- **modal**: Requires Modal account (see setup below)
|
||||
|
||||
### Singularity/Apptainer Setup (Recommended for HPC)
|
||||
|
||||
Singularity/Apptainer provides rootless container execution, ideal for HPC clusters:
|
||||
|
||||
```bash
|
||||
# 1. Verify Apptainer is installed
|
||||
apptainer --version # or: singularity --version
|
||||
|
||||
# 2. Set up cache directories (important for parallel workers)
|
||||
# Use /scratch if available (HPC), otherwise /tmp
|
||||
export APPTAINER_CACHEDIR=/scratch/$USER/.apptainer
|
||||
export APPTAINER_TMPDIR=/scratch/$USER/.apptainer/tmp
|
||||
mkdir -p "$APPTAINER_CACHEDIR" "$APPTAINER_TMPDIR"
|
||||
|
||||
# 3. Pre-build SIF image (recommended for parallel batch processing)
|
||||
# This avoids race conditions when multiple workers start simultaneously
|
||||
apptainer build $APPTAINER_CACHEDIR/python-nodejs.sif docker://nikolaik/python-nodejs:python3.11-nodejs20
|
||||
|
||||
# 4. Configure .env to use the local SIF
|
||||
TERMINAL_ENV=singularity
|
||||
TERMINAL_SINGULARITY_IMAGE=/scratch/$USER/.apptainer/python-nodejs.sif
|
||||
```
|
||||
|
||||
**Tip:** The batch scripts in `configs/` automatically handle SIF pre-building if `/scratch` is available.
|
||||
|
||||
### Modal Cloud Backend Setup
|
||||
|
||||
[Modal](https://modal.com) provides serverless cloud compute for running sandboxed environments at scale.
|
||||
|
||||
```bash
|
||||
# 1. Install Modal and dependencies
|
||||
pip install modal boto3
|
||||
|
||||
# 2. Authenticate with Modal (opens browser)
|
||||
modal setup
|
||||
|
||||
# 3. Set terminal backend to modal in .env
|
||||
TERMINAL_ENV=modal
|
||||
```
|
||||
|
||||
Modal uses CLI-based authentication (stored in `~/.modal/`), so no API key is needed in `.env`. After running `modal setup`, commands will automatically execute in Modal's cloud sandboxes.
|
||||
|
||||
### Browser Tools Setup
|
||||
|
||||
Browser tools enable the agent to navigate websites, fill forms, click buttons, and extract content. They use [agent-browser](https://github.com/vercel-labs/agent-browser) CLI with [Browserbase](https://browserbase.com) cloud execution.
|
||||
|
||||
```bash
|
||||
# 1. Install Node.js (if not already installed)
|
||||
# Use nvm (recommended) or your package manager
|
||||
|
||||
# 2. Install agent-browser CLI (choose one option):
|
||||
npm install -g agent-browser # Option A: Global install (recommended)
|
||||
npm install # Option B: Local install (uses npx fallback)
|
||||
|
||||
# 3. Get Browserbase credentials
|
||||
# Sign up at https://browserbase.com/ and get your:
|
||||
# - API Key (from Settings → API Keys)
|
||||
# - Project ID (from your project dashboard)
|
||||
|
||||
# 4. Add to your .env file:
|
||||
BROWSERBASE_API_KEY=your_api_key_here
|
||||
BROWSERBASE_PROJECT_ID=your_project_id_here
|
||||
```
|
||||
|
||||
**Available Browser Tools:**
|
||||
|
||||
| Tool | Description |
|
||||
|------|-------------|
|
||||
| `browser_navigate` | Navigate to a URL |
|
||||
| `browser_snapshot` | Get text-based page snapshot with element refs |
|
||||
| `browser_click` | Click an element by ref (e.g., `@e5`) |
|
||||
| `browser_type` | Type text into an input field |
|
||||
| `browser_scroll` | Scroll up or down |
|
||||
| `browser_back` | Go back in browser history |
|
||||
| `browser_press` | Press a keyboard key (Enter, Tab, etc.) |
|
||||
| `browser_close` | Close the browser session |
|
||||
| `browser_get_images` | Get list of images on the page |
|
||||
|
||||
**Example Usage:**
|
||||
```bash
|
||||
# Use browser tools with web search and vision
|
||||
python run_agent.py \
|
||||
--query "Go to amazon.com and find the price of the latest Kindle" \
|
||||
--enabled_toolsets=browser,web,vision
|
||||
|
||||
# Use browser-focused distribution
|
||||
python batch_runner.py \
|
||||
--dataset_file=browser_tasks.jsonl \
|
||||
--distribution=browser_use \
|
||||
--run_name=browser_run
|
||||
```
|
||||
|
||||
See `.env.example` for all available configuration options including debug settings.
|
||||
|
||||
### Skills Tools
|
||||
|
||||
Skills are on-demand knowledge documents the agent can load when needed. They follow a **progressive disclosure** pattern to minimize token usage:
|
||||
|
||||
```
|
||||
skills/
|
||||
├── mlops/ # Category folder
|
||||
│ ├── axolotl/ # Skill folder
|
||||
│ │ ├── SKILL.md # Main instructions (required)
|
||||
│ │ ├── references/ # Additional docs, API specs
|
||||
│ │ └── templates/ # Output formats, configs
|
||||
│ └── vllm/
|
||||
│ └── SKILL.md
|
||||
```
|
||||
|
||||
**Available Skills Tools:**
|
||||
|
||||
| Tool | Description |
|
||||
|------|-------------|
|
||||
| `skills_categories` | List available skill categories (~50 tokens) |
|
||||
| `skills_list` | List skills with name + description (~3k tokens for 40 skills) |
|
||||
| `skill_view` | Load full skill content, tags, and linked files |
|
||||
|
||||
**Example Usage:**
|
||||
```bash
|
||||
# Use skills tools
|
||||
python run_agent.py \
|
||||
--query "What skills do you have for fine-tuning? Show me the axolotl skill." \
|
||||
--enabled_toolsets=skills
|
||||
```
|
||||
|
||||
**Creating Skills:**
|
||||
|
||||
Skills use YAML frontmatter for metadata:
|
||||
```yaml
|
||||
---
|
||||
name: my-skill
|
||||
description: Brief description shown in skills_list
|
||||
tags: [tag1, tag2]
|
||||
related_skills: [other-skill]
|
||||
version: 1.0.0
|
||||
---
|
||||
# Skill Content
|
||||
|
||||
Instructions, examples, and guidelines here...
|
||||
```
|
||||
|
||||
Skills can include:
|
||||
- `references/` - Additional documentation, API specs, examples
|
||||
- `templates/` - Output formats, config files, boilerplate code
|
||||
- `scripts/` - Executable helpers (Python, shell scripts)
|
||||
|
||||
## Session Logging
|
||||
|
||||
Every conversation is automatically logged to `logs/` for debugging and inspection:
|
||||
|
||||
```
|
||||
logs/
|
||||
├── session_20260201_143052_a1b2c3.json
|
||||
├── session_20260201_150217_d4e5f6.json
|
||||
└── ...
|
||||
```
|
||||
|
||||
**Log Format:**
|
||||
```json
|
||||
{
|
||||
"session_id": "20260201_143052_a1b2c3",
|
||||
"model": "anthropic/claude-sonnet-4",
|
||||
"session_start": "2026-02-01T14:30:52.123456",
|
||||
"last_updated": "2026-02-01T14:35:12.789012",
|
||||
"message_count": 8,
|
||||
"conversations": [
|
||||
{"from": "system", "value": "..."},
|
||||
{"from": "human", "value": "..."},
|
||||
{"from": "gpt", "value": "..."},
|
||||
{"from": "tool", "value": "..."}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
- **Automatic**: Logs are created and updated automatically after each conversation turn
|
||||
- **Session ID in Banner**: The CLI displays the session ID in the welcome banner
|
||||
- **Trajectory Format**: Uses the same format as batch processing for consistency
|
||||
- **Git Ignored**: `logs/` is in `.gitignore` so logs aren't committed
|
||||
|
||||
## Context Compression
|
||||
|
||||
Long conversations can exceed the model's context limit. Hermes Agent automatically compresses context when approaching the limit:
|
||||
|
||||
**How it works:**
|
||||
1. Tracks actual token usage from API responses (`usage.prompt_tokens`)
|
||||
2. When tokens reach 85% of model's context limit, triggers compression
|
||||
3. Protects first 3 turns (system prompt, initial request, first response)
|
||||
4. Protects last 4 turns (recent context is most relevant)
|
||||
5. Summarizes middle turns using a fast/cheap model (Gemini Flash)
|
||||
6. Inserts summary as a user message, conversation continues seamlessly
|
||||
|
||||
**Configuration (`cli-config.yaml`):**
|
||||
```yaml
|
||||
compression:
|
||||
enabled: true # Enable auto-compression (default)
|
||||
threshold: 0.85 # Compress at 85% of context limit
|
||||
summary_model: "google/gemini-2.0-flash-001"
|
||||
```
|
||||
|
||||
**Or via environment variables:**
|
||||
```bash
|
||||
CONTEXT_COMPRESSION_ENABLED=true
|
||||
CONTEXT_COMPRESSION_THRESHOLD=0.85
|
||||
CONTEXT_COMPRESSION_MODEL=google/gemini-2.0-flash-001
|
||||
```
|
||||
|
||||
**When compression triggers, you'll see:**
|
||||
```
|
||||
📦 Context compression triggered (170,000 tokens ≥ 170,000 threshold)
|
||||
📊 Model context limit: 200,000 tokens (85% = 170,000)
|
||||
🗜️ Summarizing turns 4-15 (12 turns)
|
||||
✅ Compressed: 20 → 9 messages (~45,000 tokens saved)
|
||||
```
|
||||
|
||||
## Scheduled Tasks (Cron Jobs)
|
||||
|
||||
Hermes Agent can schedule automated tasks to run in the future - either one-time reminders or recurring jobs.
|
||||
|
||||
### CLI Commands
|
||||
|
||||
```bash
|
||||
# List scheduled jobs
|
||||
/cron
|
||||
|
||||
# Add a one-shot reminder (runs once in 30 minutes)
|
||||
/cron add 30m Remind me to check the build status
|
||||
|
||||
# Add a recurring job (every 2 hours)
|
||||
/cron add "every 2h" Check server status at 192.168.1.100 and report any issues
|
||||
|
||||
# Add a cron expression (daily at 9am)
|
||||
/cron add "0 9 * * *" Generate a morning briefing summarizing GitHub notifications
|
||||
|
||||
# Remove a job
|
||||
/cron remove abc123def456
|
||||
```
|
||||
|
||||
### Agent Self-Scheduling
|
||||
|
||||
The agent can also schedule its own follow-up tasks using tools:
|
||||
|
||||
```python
|
||||
# Available when using hermes-cli toolset (default for CLI)
|
||||
schedule_cronjob(prompt="...", schedule="30m", repeat=1) # One-shot
|
||||
schedule_cronjob(prompt="...", schedule="every 2h") # Recurring
|
||||
list_cronjobs() # View all jobs
|
||||
remove_cronjob(job_id="...") # Cancel a job
|
||||
```
|
||||
|
||||
**⚠️ Important:** Cronjobs run in **isolated sessions with NO prior context**. The prompt must be completely self-contained with all necessary information (file paths, URLs, server addresses, etc.). The future agent will not remember anything from the current conversation.
|
||||
|
||||
### Schedule Formats
|
||||
|
||||
| Format | Example | Description |
|
||||
|--------|---------|-------------|
|
||||
| Duration | `30m`, `2h`, `1d` | One-shot delay from now |
|
||||
| Interval | `every 30m`, `every 2h` | Recurring at fixed intervals |
|
||||
| Cron | `0 9 * * *` | Cron expression (requires `croniter`) |
|
||||
| Timestamp | `2026-02-03T14:00` | One-shot at specific time |
|
||||
|
||||
### Repeat Options
|
||||
|
||||
| repeat | Behavior |
|
||||
|--------|----------|
|
||||
| (omitted) | One-shot schedules run once; intervals/cron run forever |
|
||||
| `1` | Run once then auto-delete |
|
||||
| `N` | Run N times then auto-delete |
|
||||
|
||||
### Running the Cron Daemon
|
||||
|
||||
Jobs are stored in `~/.hermes/cron/jobs.json` and executed by a scheduler:
|
||||
|
||||
```bash
|
||||
# Option 1: Built-in daemon (checks every 60 seconds)
|
||||
python cli.py --cron-daemon
|
||||
|
||||
# Option 2: System cron integration (run once per minute)
|
||||
# Add to crontab: crontab -e
|
||||
*/1 * * * * cd ~/hermes-agent && python cli.py --cron-tick-once >> ~/.hermes/cron/cron.log 2>&1
|
||||
```
|
||||
|
||||
### Job Output
|
||||
|
||||
Job outputs are saved to `~/.hermes/cron/output/{job_id}/{timestamp}.md` for review.
|
||||
|
||||
## Messaging Gateway (Telegram, Discord, WhatsApp)
|
||||
|
||||
Connect Hermes Agent to messaging platforms so you can chat from anywhere.
|
||||
|
||||
### Quick Start
|
||||
|
||||
```bash
|
||||
# 1. Add your bot token to .env
|
||||
echo 'TELEGRAM_BOT_TOKEN="your_token"' >> .env
|
||||
|
||||
# 2. Test the gateway (foreground)
|
||||
./scripts/hermes-gateway run
|
||||
|
||||
# 3. Install as a background service
|
||||
./scripts/hermes-gateway install
|
||||
|
||||
# 4. Manage the service
|
||||
./scripts/hermes-gateway start # Start
|
||||
./scripts/hermes-gateway stop # Stop
|
||||
./scripts/hermes-gateway status # Check status
|
||||
```
|
||||
|
||||
### Supported Platforms
|
||||
|
||||
| Platform | Setup | Toolset |
|
||||
|----------|-------|---------|
|
||||
| Telegram | Bot via @BotFather | `hermes-telegram` |
|
||||
| Discord | Bot via Developer Portal | `hermes-discord` |
|
||||
| WhatsApp | Node.js bridge | `hermes-whatsapp` |
|
||||
|
||||
### Session Management
|
||||
|
||||
- Sessions persist across messages (agent remembers context)
|
||||
- Reset policies: daily (4am), idle (2 hours), or both
|
||||
- Manual reset: send `/new` or `/reset`
|
||||
|
||||
### Cron Job Delivery
|
||||
|
||||
Schedule tasks that deliver to specific platforms:
|
||||
|
||||
```python
|
||||
schedule_cronjob(
|
||||
prompt="Check server status...",
|
||||
schedule="every 1h",
|
||||
deliver="telegram" # or "origin", "discord", etc.
|
||||
)
|
||||
```
|
||||
|
||||
### CLI Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/platforms` | Show gateway configuration status |
|
||||
| `--gateway` | Start the gateway (CLI flag) |
|
||||
|
||||
See [docs/messaging.md](docs/messaging.md) for full setup instructions.
|
||||
|
||||
## Interactive CLI
|
||||
|
||||
The CLI provides a rich interactive experience for working with the agent.
|
||||
|
||||
### Running the CLI
|
||||
|
||||
```bash
|
||||
# Basic usage
|
||||
./hermes
|
||||
|
||||
# With specific model
|
||||
./hermes --model "anthropic/claude-sonnet-4"
|
||||
|
||||
# With specific toolsets
|
||||
./hermes --toolsets "web,terminal,skills"
|
||||
```
|
||||
|
||||
### CLI Commands
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `/help` | Show available commands |
|
||||
| `/tools` | List available tools by toolset |
|
||||
| `/toolsets` | List available toolsets |
|
||||
| `/model [name]` | Show or change the current model |
|
||||
| `/prompt [text]` | View/set custom system prompt |
|
||||
| `/personality [name]` | Set a predefined personality |
|
||||
| `/clear` | Clear screen and reset conversation |
|
||||
| `/reset` | Reset conversation only |
|
||||
| `/history` | Show conversation history |
|
||||
| `/save` | Save current conversation to file |
|
||||
| `/config` | Show current configuration |
|
||||
| `/cron` | Manage scheduled tasks (list, add, remove) |
|
||||
| `/platforms` | Show gateway/messaging platform status |
|
||||
| `/quit` | Exit the CLI |
|
||||
|
||||
### Configuration
|
||||
|
||||
Copy `cli-config.yaml.example` to `cli-config.yaml` and customize:
|
||||
|
||||
```yaml
|
||||
# Model settings
|
||||
model:
|
||||
default: "anthropic/claude-sonnet-4"
|
||||
|
||||
# Terminal backend (local, docker, singularity, modal, or ssh)
|
||||
terminal:
|
||||
env_type: "local"
|
||||
cwd: "." # Use current directory
|
||||
|
||||
# Or use SSH for remote execution (keeps agent code isolated)
|
||||
# terminal:
|
||||
# env_type: "ssh"
|
||||
# ssh_host: "my-server.example.com"
|
||||
# ssh_user: "myuser"
|
||||
# ssh_key: "~/.ssh/id_rsa"
|
||||
# cwd: "/home/myuser/project"
|
||||
|
||||
# Enable specific toolsets
|
||||
toolsets:
|
||||
- all # or: web, terminal, browser, vision, etc.
|
||||
|
||||
# Custom personalities (use with /personality command)
|
||||
agent:
|
||||
personalities:
|
||||
helpful: "You are a helpful assistant."
|
||||
kawaii: "You are a kawaii assistant! Use cute expressions..."
|
||||
```
|
||||
|
||||
### Personalities
|
||||
|
||||
Built-in personalities available via `/personality`:
|
||||
- `helpful`, `concise`, `technical`, `creative`, `teacher`
|
||||
- `kawaii`, `catgirl`, `pirate`, `shakespeare`, `surfer`
|
||||
- `noir`, `uwu`, `philosopher`, `hype`
|
||||
|
||||
## Toolsets System
|
||||
|
||||
The agent uses a toolsets system for organizing and managing tools. All tools must be part of a toolset to be accessible - individual tool selection is not supported. This ensures consistent and logical grouping of capabilities.
|
||||
|
||||
### Key Concepts
|
||||
|
||||
- **Toolsets**: Logical groups of tools for specific use cases (e.g., "research", "development", "debugging")
|
||||
- **Composition**: Toolsets can include other toolsets for powerful combinations
|
||||
- **Custom Toolsets**: Create your own toolsets at runtime or by editing `toolsets.py`
|
||||
- **Toolset-Only Access**: Tools are only accessible through toolsets, not individually
|
||||
|
||||
### Available Toolsets
|
||||
|
||||
See `toolsets.py` for the complete list of predefined toolsets including:
|
||||
- Basic toolsets (web, terminal, vision, creative, reasoning)
|
||||
- Composite toolsets (research, development, analysis, etc.)
|
||||
- Scenario-specific toolsets (debugging, documentation, API testing, etc.)
|
||||
- Special toolsets (safe mode without terminal, minimal, offline)
|
||||
|
||||
### Using Toolsets
|
||||
|
||||
```bash
|
||||
# Use a predefined toolset
|
||||
python run_agent.py --enabled_toolsets=research --query "Find latest AI papers"
|
||||
|
||||
# Combine multiple toolsets
|
||||
python run_agent.py --enabled_toolsets=web,vision --query "Analyze this website"
|
||||
|
||||
# Enable all toolsets explicitly (same as omitting the flag)
|
||||
python run_agent.py --enabled_toolsets=all --query "Do web research and run commands if helpful"
|
||||
|
||||
# Safe mode (no terminal access)
|
||||
python run_agent.py --enabled_toolsets=safe --query "Help without running commands"
|
||||
|
||||
# List all available toolsets and tools
|
||||
python run_agent.py --list_tools
|
||||
```
|
||||
|
||||
See `toolsets.py` for the complete list of available toolsets and how to create custom ones.
|
||||
|
||||
## Basic Usage
|
||||
|
||||
### Default (all tools enabled)
|
||||
```bash
|
||||
# Uses OpenRouter by default - just set OPENROUTER_API_KEY in .env
|
||||
python run_agent.py \
|
||||
--query "search up the latest docs on jit in python 3.13 and write me basic example that's not in their docs. profile its perf" \
|
||||
--max_turns 20 \
|
||||
--model anthropic/claude-sonnet-4-20250514
|
||||
```
|
||||
|
||||
### With specific toolset
|
||||
```bash
|
||||
python run_agent.py \
|
||||
--query "Debug this Python error" \
|
||||
--enabled_toolsets=debugging \
|
||||
--model anthropic/claude-sonnet-4-20250514
|
||||
```
|
||||
|
||||
### Python API
|
||||
```python
|
||||
from run_agent import AIAgent
|
||||
|
||||
# Uses OpenRouter by default (reads OPENROUTER_API_KEY from .env)
|
||||
agent = AIAgent(
|
||||
model="anthropic/claude-sonnet-4-20250514",
|
||||
enabled_toolsets=["research"]
|
||||
)
|
||||
response = agent.chat("Find information about quantum computing")
|
||||
|
||||
# Create custom toolset at runtime
|
||||
from toolsets import create_custom_toolset
|
||||
|
||||
create_custom_toolset(
|
||||
name="my_tools",
|
||||
description="My custom toolkit",
|
||||
tools=["web_search"],
|
||||
includes=["terminal", "vision"]
|
||||
)
|
||||
|
||||
agent = AIAgent(enabled_toolsets=["my_tools"])
|
||||
```
|
||||
|
||||
## Batch Processing
|
||||
|
||||
Process multiple prompts from a dataset in parallel with automatic checkpointing and statistics tracking:
|
||||
|
||||
```bash
|
||||
# Basic batch processing
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=20 \
|
||||
--run_name=my_run
|
||||
|
||||
# With specific distribution
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=20 \
|
||||
--run_name=image_run \
|
||||
--distribution=image_gen \
|
||||
--num_workers=4
|
||||
```
|
||||
|
||||
**Key Features:**
|
||||
- Parallel processing with configurable workers
|
||||
- Toolset distributions for varied data generation
|
||||
- Automatic checkpointing and resume capability
|
||||
- Combined output in `data/<run_name>/trajectories.jsonl`
|
||||
- Tool usage statistics and success rates
|
||||
|
||||
Use `--list_distributions` to see available toolset distributions for varied data generation.
|
||||
|
||||
### Trajectory Compression
|
||||
|
||||
Post-process trajectories to fit within token budgets for training:
|
||||
|
||||
```bash
|
||||
# Compress a directory of JSONL files
|
||||
python trajectory_compressor.py --input=data/my_run
|
||||
|
||||
# Compress a single JSONL file
|
||||
python trajectory_compressor.py --input=data/trajectories.jsonl
|
||||
|
||||
# Compress a 15% sample (useful for creating smaller training sets)
|
||||
python trajectory_compressor.py --input=data/trajectories.jsonl --sample_percent=15
|
||||
|
||||
# Custom output and token target
|
||||
python trajectory_compressor.py \
|
||||
--input=data/trajectories.jsonl \
|
||||
--output=data/compressed.jsonl \
|
||||
--target_max_tokens=16000
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Protects first turns (system, human, first GPT response, first tool call)
|
||||
- Protects last N turns (configurable)
|
||||
- Summarizes middle turns using LLM to fit target token budget
|
||||
- Supports both directory and single file input
|
||||
- Optional random sampling with `--sample_percent`
|
||||
- Configurable via `configs/trajectory_compression.yaml`
|
||||
|
||||
### Ephemeral System Prompts
|
||||
|
||||
The ephemeral system prompt feature allows you to guide the model's behavior during batch processing **without** saving that prompt to the training dataset trajectories. This is useful for:
|
||||
|
||||
- Guiding model behavior during data collection
|
||||
- Adding task-specific instructions
|
||||
- Keeping saved trajectories clean and focused on tool-calling format
|
||||
|
||||
**Example:**
|
||||
```bash
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=10 \
|
||||
--run_name=my_run \
|
||||
--ephemeral_system_prompt="You are a helpful assistant focused on image generation."
|
||||
```
|
||||
|
||||
The ephemeral prompt will influence the model's behavior during execution, but **only the standard tool-calling system prompt** will be saved in the trajectory files.
|
||||
|
||||
The ephemeral prompt influences model behavior during execution, but **only the standard tool-calling system prompt** is saved in trajectory files.
|
||||
|
||||
## Command Line Arguments
|
||||
|
||||
**Single Agent (`run_agent.py`):**
|
||||
- `--query`: The question or task for the agent
|
||||
- `--model`: Model to use (default: claude-opus-4-20250514)
|
||||
- `--api_key`: API key for authentication
|
||||
- `--base_url`: API endpoint URL
|
||||
- `--max_turns`: Maximum number of tool-calling iterations
|
||||
- `--enabled_toolsets`: Comma-separated list of toolsets to enable. Use `all` (or `*`) to enable everything. If omitted, all toolsets are enabled by default.
|
||||
- `--disabled_toolsets`: Comma-separated list of toolsets to disable
|
||||
- `--list_tools`: List all available toolsets and tools
|
||||
- `--save_trajectories`: Save conversation trajectories to JSONL files
|
||||
|
||||
**Batch Processing (`batch_runner.py`):**
|
||||
- `--dataset_file`: Path to JSONL file with prompts
|
||||
- `--batch_size`: Number of prompts per batch
|
||||
- `--run_name`: Name for this run (for output/checkpointing)
|
||||
- `--distribution`: Toolset distribution to use (default: "default")
|
||||
- `--num_workers`: Number of parallel workers (default: 4)
|
||||
- `--resume`: Resume from checkpoint if interrupted
|
||||
- `--ephemeral_system_prompt`: System prompt used during execution but NOT saved to trajectories
|
||||
- `--list_distributions`: List available toolset distributions
|
||||
|
||||
## Environment Variables
|
||||
|
||||
All environment variables can be configured in the `.env` file (copy from `.env.example`).
|
||||
|
||||
**LLM Provider (OpenRouter):**
|
||||
- `OPENROUTER_API_KEY`: Primary LLM access via OpenRouter (supports Claude, GPT-4, Gemini, etc.)
|
||||
- `LLM_MODEL`: Default model (e.g., `anthropic/claude-sonnet-4`, `openai/gpt-4o`)
|
||||
|
||||
**Tool API Keys:**
|
||||
- `FIRECRAWL_API_KEY`: Web tools (search, extract, crawl)
|
||||
- `NOUS_API_KEY`: Vision and reasoning tools
|
||||
- `FAL_KEY`: Image generation tools
|
||||
|
||||
**Terminal Tool Configuration (mini-swe-agent backend):**
|
||||
- `TERMINAL_ENV`: Backend type - `local`, `docker`, `singularity`, `modal`, or `ssh` (default: `local`)
|
||||
- `TERMINAL_DOCKER_IMAGE`: Docker image for docker backend (default: `python:3.11-slim`)
|
||||
- `TERMINAL_SINGULARITY_IMAGE`: Singularity/Apptainer image (can be `docker://...` URL or local `.sif` path)
|
||||
- `TERMINAL_TIMEOUT`: Command timeout in seconds (default: `60`)
|
||||
- `TERMINAL_LIFETIME_SECONDS`: Cleanup inactive environments after this time (default: `300`)
|
||||
- `TERMINAL_CWD`: Working directory inside containers (default: `/tmp`)
|
||||
- `TERMINAL_SCRATCH_DIR`: Custom scratch directory for sandbox storage (optional, auto-detects `/scratch`)
|
||||
- `SUDO_PASSWORD`: Enable sudo commands by piping password via `sudo -S` (works with all backends)
|
||||
- If unset in CLI mode, you'll be prompted interactively when sudo is needed (45s timeout)
|
||||
|
||||
**SSH Backend Configuration (for remote execution):**
|
||||
- `TERMINAL_SSH_HOST`: Remote server hostname or IP
|
||||
- `TERMINAL_SSH_USER`: SSH username
|
||||
- `TERMINAL_SSH_PORT`: SSH port (default: `22`)
|
||||
- `TERMINAL_SSH_KEY`: Path to SSH private key (optional, uses ssh-agent if not set)
|
||||
|
||||
**Context Compression (auto-shrinks long conversations):**
|
||||
- `CONTEXT_COMPRESSION_ENABLED`: Enable auto-compression (default: `true`)
|
||||
- `CONTEXT_COMPRESSION_THRESHOLD`: Compress at this % of context limit (default: `0.85`)
|
||||
- `CONTEXT_COMPRESSION_MODEL`: Model for generating summaries (default: `google/gemini-2.0-flash-001`)
|
||||
|
||||
**Browser Tool Configuration (agent-browser + Browserbase):**
|
||||
- `BROWSERBASE_API_KEY`: Browserbase API key for cloud browser execution
|
||||
- `BROWSERBASE_PROJECT_ID`: Browserbase project ID
|
||||
- `BROWSER_SESSION_TIMEOUT`: Session timeout in seconds (default: `300`)
|
||||
|
||||
**Legacy Hecate Terminal Backend (optional):**
|
||||
- `MORPH_API_KEY`: For Hecate/MorphCloud terminal backend
|
||||
- `HECATE_VM_LIFETIME_SECONDS`: VM lifetime (default: 300)
|
||||
- `HECATE_DEFAULT_SNAPSHOT_ID`: Default snapshot (default: snapshot_p5294qxt)
|
||||
|
||||
**Debug Options:**
|
||||
- `WEB_TOOLS_DEBUG`, `VISION_TOOLS_DEBUG`, `MOA_TOOLS_DEBUG`, `IMAGE_TOOLS_DEBUG`: Enable debug logging
|
||||
|
||||
## Key Files
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `hermes` | CLI launcher script (run with `./hermes`) |
|
||||
| `cli.py` | Interactive CLI implementation |
|
||||
| `cli-config.yaml` | CLI configuration (copy from `.example`) |
|
||||
| `run_agent.py` | Main agent runner - single query execution |
|
||||
| `batch_runner.py` | Parallel batch processing with checkpointing |
|
||||
| `model_tools.py` | Core tool definitions and handlers |
|
||||
| `toolsets.py` | Toolset definitions and composition |
|
||||
| `toolset_distributions.py` | Probability distributions for data generation |
|
||||
| `trajectory_compressor.py` | Post-process trajectories for training |
|
||||
| `tools/` | Individual tool implementations |
|
||||
| `tools/skills_tool.py` | Skills system with progressive disclosure |
|
||||
| `skills/` | On-demand knowledge documents |
|
||||
| `docs/` | Documentation |
|
||||
| `configs/` | Example batch run scripts |
|
||||
47
hermes_agent.egg-info/SOURCES.txt
Normal file
47
hermes_agent.egg-info/SOURCES.txt
Normal file
@@ -0,0 +1,47 @@
|
||||
README.md
|
||||
batch_runner.py
|
||||
cli.py
|
||||
model_tools.py
|
||||
pyproject.toml
|
||||
run_agent.py
|
||||
toolset_distributions.py
|
||||
toolsets.py
|
||||
trajectory_compressor.py
|
||||
cron/__init__.py
|
||||
cron/jobs.py
|
||||
cron/scheduler.py
|
||||
gateway/__init__.py
|
||||
gateway/config.py
|
||||
gateway/delivery.py
|
||||
gateway/run.py
|
||||
gateway/session.py
|
||||
hermes_agent.egg-info/PKG-INFO
|
||||
hermes_agent.egg-info/SOURCES.txt
|
||||
hermes_agent.egg-info/dependency_links.txt
|
||||
hermes_agent.egg-info/entry_points.txt
|
||||
hermes_agent.egg-info/requires.txt
|
||||
hermes_agent.egg-info/top_level.txt
|
||||
hermes_cli/__init__.py
|
||||
hermes_cli/cron.py
|
||||
hermes_cli/doctor.py
|
||||
hermes_cli/gateway.py
|
||||
hermes_cli/main.py
|
||||
hermes_cli/setup.py
|
||||
hermes_cli/status.py
|
||||
tests/test_batch_runner.py
|
||||
tests/test_checkpoint_resumption.py
|
||||
tests/test_modal_terminal.py
|
||||
tests/test_nous_api_limits.py
|
||||
tests/test_nous_api_pattern.py
|
||||
tests/test_temperature_fix.py
|
||||
tests/test_web_tools.py
|
||||
tools/__init__.py
|
||||
tools/browser_tool.py
|
||||
tools/cronjob_tools.py
|
||||
tools/image_generation_tool.py
|
||||
tools/mixture_of_agents_tool.py
|
||||
tools/skills_tool.py
|
||||
tools/terminal_hecate.py
|
||||
tools/terminal_tool.py
|
||||
tools/vision_tools.py
|
||||
tools/web_tools.py
|
||||
1
hermes_agent.egg-info/dependency_links.txt
Normal file
1
hermes_agent.egg-info/dependency_links.txt
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
3
hermes_agent.egg-info/entry_points.txt
Normal file
3
hermes_agent.egg-info/entry_points.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
[console_scripts]
|
||||
hermes = hermes_cli.main:main
|
||||
hermes-agent = run_agent:main
|
||||
35
hermes_agent.egg-info/requires.txt
Normal file
35
hermes_agent.egg-info/requires.txt
Normal file
@@ -0,0 +1,35 @@
|
||||
openai
|
||||
python-dotenv
|
||||
fire
|
||||
httpx
|
||||
rich
|
||||
tenacity
|
||||
pyyaml
|
||||
requests
|
||||
jinja2
|
||||
pydantic>=2.0
|
||||
firecrawl-py
|
||||
fal-client
|
||||
litellm>=1.75.5
|
||||
typer
|
||||
platformdirs
|
||||
|
||||
[all]
|
||||
croniter
|
||||
python-telegram-bot>=20.0
|
||||
discord.py>=2.0
|
||||
|
||||
[cron]
|
||||
croniter
|
||||
|
||||
[dev]
|
||||
pytest
|
||||
pytest-asyncio
|
||||
|
||||
[messaging]
|
||||
python-telegram-bot>=20.0
|
||||
discord.py>=2.0
|
||||
|
||||
[modal]
|
||||
modal
|
||||
boto3
|
||||
11
hermes_agent.egg-info/top_level.txt
Normal file
11
hermes_agent.egg-info/top_level.txt
Normal file
@@ -0,0 +1,11 @@
|
||||
batch_runner
|
||||
cli
|
||||
cron
|
||||
gateway
|
||||
hermes_cli
|
||||
model_tools
|
||||
run_agent
|
||||
tools
|
||||
toolset_distributions
|
||||
toolsets
|
||||
trajectory_compressor
|
||||
14
hermes_cli/__init__.py
Normal file
14
hermes_cli/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
Hermes CLI - Unified command-line interface for Hermes Agent.
|
||||
|
||||
Provides subcommands for:
|
||||
- hermes chat - Interactive chat (same as ./hermes)
|
||||
- hermes gateway - Run gateway in foreground
|
||||
- hermes gateway start - Start gateway service
|
||||
- hermes gateway stop - Stop gateway service
|
||||
- hermes setup - Interactive setup wizard
|
||||
- hermes status - Show status of all components
|
||||
- hermes cron - Manage cron jobs
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
785
hermes_cli/config.py
Normal file
785
hermes_cli/config.py
Normal file
@@ -0,0 +1,785 @@
|
||||
"""
|
||||
Configuration management for Hermes Agent.
|
||||
|
||||
Config files are stored in ~/.hermes/ for easy access:
|
||||
- ~/.hermes/config.yaml - All settings (model, toolsets, terminal, etc.)
|
||||
- ~/.hermes/.env - API keys and secrets
|
||||
|
||||
This module provides:
|
||||
- hermes config - Show current configuration
|
||||
- hermes config edit - Open config in editor
|
||||
- hermes config set - Set a specific value
|
||||
- hermes config wizard - Re-run setup wizard
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, List, Tuple
|
||||
|
||||
import yaml
|
||||
|
||||
# ANSI colors
|
||||
class Colors:
|
||||
RESET = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
DIM = "\033[2m"
|
||||
RED = "\033[31m"
|
||||
GREEN = "\033[32m"
|
||||
YELLOW = "\033[33m"
|
||||
BLUE = "\033[34m"
|
||||
MAGENTA = "\033[35m"
|
||||
CYAN = "\033[36m"
|
||||
|
||||
def color(text: str, *codes) -> str:
|
||||
if not sys.stdout.isatty():
|
||||
return text
|
||||
return "".join(codes) + text + Colors.RESET
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Config paths
|
||||
# =============================================================================
|
||||
|
||||
def get_hermes_home() -> Path:
|
||||
"""Get the Hermes home directory (~/.hermes)."""
|
||||
return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
|
||||
def get_config_path() -> Path:
|
||||
"""Get the main config file path."""
|
||||
return get_hermes_home() / "config.yaml"
|
||||
|
||||
def get_env_path() -> Path:
|
||||
"""Get the .env file path (for API keys)."""
|
||||
return get_hermes_home() / ".env"
|
||||
|
||||
def get_project_root() -> Path:
|
||||
"""Get the project installation directory."""
|
||||
return Path(__file__).parent.parent.resolve()
|
||||
|
||||
def ensure_hermes_home():
|
||||
"""Ensure ~/.hermes directory structure exists."""
|
||||
home = get_hermes_home()
|
||||
(home / "cron").mkdir(parents=True, exist_ok=True)
|
||||
(home / "sessions").mkdir(parents=True, exist_ok=True)
|
||||
(home / "logs").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Config loading/saving
|
||||
# =============================================================================
|
||||
|
||||
DEFAULT_CONFIG = {
|
||||
"model": "anthropic/claude-sonnet-4.5",
|
||||
"toolsets": ["hermes-cli"],
|
||||
"max_turns": 100,
|
||||
|
||||
"terminal": {
|
||||
"backend": "local",
|
||||
"cwd": ".", # Use current directory
|
||||
"timeout": 180,
|
||||
"docker_image": "nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
"singularity_image": "docker://nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
"modal_image": "nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
},
|
||||
|
||||
"browser": {
|
||||
"inactivity_timeout": 120,
|
||||
},
|
||||
|
||||
"compression": {
|
||||
"enabled": True,
|
||||
"threshold": 0.85,
|
||||
"summary_model": "google/gemini-2.0-flash-001",
|
||||
},
|
||||
|
||||
"display": {
|
||||
"compact": False,
|
||||
"personality": "kawaii",
|
||||
},
|
||||
|
||||
# Permanently allowed dangerous command patterns (added via "always" approval)
|
||||
"command_allowlist": [],
|
||||
|
||||
# Config schema version - bump this when adding new required fields
|
||||
"_config_version": 1,
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# Config Migration System
|
||||
# =============================================================================
|
||||
|
||||
# Required environment variables with metadata for migration prompts
|
||||
REQUIRED_ENV_VARS = {
|
||||
"OPENROUTER_API_KEY": {
|
||||
"description": "OpenRouter API key (required for vision, web scraping, and tools)",
|
||||
"prompt": "OpenRouter API key",
|
||||
"url": "https://openrouter.ai/keys",
|
||||
"required": True,
|
||||
"password": True,
|
||||
},
|
||||
}
|
||||
|
||||
# Optional environment variables that enhance functionality
|
||||
OPTIONAL_ENV_VARS = {
|
||||
"FIRECRAWL_API_KEY": {
|
||||
"description": "Firecrawl API key for web search and scraping",
|
||||
"prompt": "Firecrawl API key",
|
||||
"url": "https://firecrawl.dev/",
|
||||
"tools": ["web_search", "web_extract"],
|
||||
"password": True,
|
||||
},
|
||||
"BROWSERBASE_API_KEY": {
|
||||
"description": "Browserbase API key for browser automation",
|
||||
"prompt": "Browserbase API key",
|
||||
"url": "https://browserbase.com/",
|
||||
"tools": ["browser_navigate", "browser_click", "etc."],
|
||||
"password": True,
|
||||
},
|
||||
"BROWSERBASE_PROJECT_ID": {
|
||||
"description": "Browserbase project ID",
|
||||
"prompt": "Browserbase project ID",
|
||||
"url": "https://browserbase.com/",
|
||||
"tools": ["browser_navigate", "browser_click", "etc."],
|
||||
"password": False,
|
||||
},
|
||||
"FAL_KEY": {
|
||||
"description": "FAL API key for image generation",
|
||||
"prompt": "FAL API key",
|
||||
"url": "https://fal.ai/",
|
||||
"tools": ["image_generate"],
|
||||
"password": True,
|
||||
},
|
||||
"TINKER_API_KEY": {
|
||||
"description": "Tinker API key for RL training",
|
||||
"prompt": "Tinker API key",
|
||||
"url": "https://tinker-console.thinkingmachines.ai/keys",
|
||||
"tools": ["rl_start_training", "rl_check_status", "rl_stop_training"],
|
||||
"password": True,
|
||||
},
|
||||
"WANDB_API_KEY": {
|
||||
"description": "Weights & Biases API key for experiment tracking",
|
||||
"prompt": "WandB API key",
|
||||
"url": "https://wandb.ai/authorize",
|
||||
"tools": ["rl_get_results", "rl_check_status"],
|
||||
"password": True,
|
||||
},
|
||||
"OPENAI_BASE_URL": {
|
||||
"description": "Custom OpenAI-compatible API endpoint URL",
|
||||
"prompt": "API base URL (e.g., https://api.example.com/v1)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
},
|
||||
"OPENAI_API_KEY": {
|
||||
"description": "API key for custom OpenAI-compatible endpoint",
|
||||
"prompt": "API key for custom endpoint",
|
||||
"url": None,
|
||||
"password": True,
|
||||
},
|
||||
# Messaging platform tokens
|
||||
"TELEGRAM_BOT_TOKEN": {
|
||||
"description": "Telegram bot token from @BotFather",
|
||||
"prompt": "Telegram bot token",
|
||||
"url": "https://t.me/BotFather",
|
||||
"password": True,
|
||||
},
|
||||
"TELEGRAM_ALLOWED_USERS": {
|
||||
"description": "Comma-separated Telegram user IDs allowed to use the bot (get ID from @userinfobot)",
|
||||
"prompt": "Allowed Telegram user IDs (comma-separated)",
|
||||
"url": "https://t.me/userinfobot",
|
||||
"password": False,
|
||||
},
|
||||
"DISCORD_BOT_TOKEN": {
|
||||
"description": "Discord bot token from Developer Portal",
|
||||
"prompt": "Discord bot token",
|
||||
"url": "https://discord.com/developers/applications",
|
||||
"password": True,
|
||||
},
|
||||
"DISCORD_ALLOWED_USERS": {
|
||||
"description": "Comma-separated Discord user IDs allowed to use the bot",
|
||||
"prompt": "Allowed Discord user IDs (comma-separated)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
},
|
||||
# Terminal configuration
|
||||
"MESSAGING_CWD": {
|
||||
"description": "Working directory for terminal commands via messaging (Telegram/Discord/etc). CLI always uses current directory.",
|
||||
"prompt": "Messaging working directory (default: home)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
},
|
||||
"SUDO_PASSWORD": {
|
||||
"description": "Sudo password for terminal commands requiring root access",
|
||||
"prompt": "Sudo password",
|
||||
"url": None,
|
||||
"password": True,
|
||||
},
|
||||
# Agent configuration
|
||||
"HERMES_MAX_ITERATIONS": {
|
||||
"description": "Maximum tool-calling iterations per conversation (default: 60)",
|
||||
"prompt": "Max iterations",
|
||||
"url": None,
|
||||
"password": False,
|
||||
},
|
||||
"HERMES_TOOL_PROGRESS": {
|
||||
"description": "Send tool progress messages in messaging channels (true/false)",
|
||||
"prompt": "Enable tool progress messages",
|
||||
"url": None,
|
||||
"password": False,
|
||||
},
|
||||
"HERMES_TOOL_PROGRESS_MODE": {
|
||||
"description": "Progress mode: 'all' (every tool) or 'new' (only when tool changes)",
|
||||
"prompt": "Progress mode (all/new)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_missing_env_vars(required_only: bool = False) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Check which environment variables are missing.
|
||||
|
||||
Returns list of dicts with var info for missing variables.
|
||||
"""
|
||||
missing = []
|
||||
|
||||
# Check required vars
|
||||
for var_name, info in REQUIRED_ENV_VARS.items():
|
||||
if not get_env_value(var_name):
|
||||
missing.append({"name": var_name, **info, "is_required": True})
|
||||
|
||||
# Check optional vars (if not required_only)
|
||||
if not required_only:
|
||||
for var_name, info in OPTIONAL_ENV_VARS.items():
|
||||
if not get_env_value(var_name):
|
||||
missing.append({"name": var_name, **info, "is_required": False})
|
||||
|
||||
return missing
|
||||
|
||||
|
||||
def get_missing_config_fields() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Check which config fields are missing or outdated.
|
||||
|
||||
Returns list of missing/outdated fields.
|
||||
"""
|
||||
config = load_config()
|
||||
missing = []
|
||||
|
||||
# Check for new top-level keys in DEFAULT_CONFIG
|
||||
for key, default_value in DEFAULT_CONFIG.items():
|
||||
if key.startswith('_'):
|
||||
continue # Skip internal keys
|
||||
if key not in config:
|
||||
missing.append({
|
||||
"key": key,
|
||||
"default": default_value,
|
||||
"description": f"New config section: {key}",
|
||||
})
|
||||
elif isinstance(default_value, dict):
|
||||
# Check nested keys
|
||||
for subkey, subvalue in default_value.items():
|
||||
if subkey not in config.get(key, {}):
|
||||
missing.append({
|
||||
"key": f"{key}.{subkey}",
|
||||
"default": subvalue,
|
||||
"description": f"New config option: {key}.{subkey}",
|
||||
})
|
||||
|
||||
return missing
|
||||
|
||||
|
||||
def check_config_version() -> Tuple[int, int]:
|
||||
"""
|
||||
Check config version.
|
||||
|
||||
Returns (current_version, latest_version).
|
||||
"""
|
||||
config = load_config()
|
||||
current = config.get("_config_version", 0)
|
||||
latest = DEFAULT_CONFIG.get("_config_version", 1)
|
||||
return current, latest
|
||||
|
||||
|
||||
def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Migrate config to latest version, prompting for new required fields.
|
||||
|
||||
Args:
|
||||
interactive: If True, prompt user for missing values
|
||||
quiet: If True, suppress output
|
||||
|
||||
Returns:
|
||||
Dict with migration results: {"env_added": [...], "config_added": [...], "warnings": [...]}
|
||||
"""
|
||||
results = {"env_added": [], "config_added": [], "warnings": []}
|
||||
|
||||
# Check config version
|
||||
current_ver, latest_ver = check_config_version()
|
||||
|
||||
if current_ver < latest_ver and not quiet:
|
||||
print(f"Config version: {current_ver} → {latest_ver}")
|
||||
|
||||
# Check for missing required env vars
|
||||
missing_env = get_missing_env_vars(required_only=True)
|
||||
|
||||
if missing_env and not quiet:
|
||||
print("\n⚠️ Missing required environment variables:")
|
||||
for var in missing_env:
|
||||
print(f" • {var['name']}: {var['description']}")
|
||||
|
||||
if interactive and missing_env:
|
||||
print("\nLet's configure them now:\n")
|
||||
for var in missing_env:
|
||||
if var.get("url"):
|
||||
print(f" Get your key at: {var['url']}")
|
||||
|
||||
if var.get("password"):
|
||||
import getpass
|
||||
value = getpass.getpass(f" {var['prompt']}: ")
|
||||
else:
|
||||
value = input(f" {var['prompt']}: ").strip()
|
||||
|
||||
if value:
|
||||
save_env_value(var["name"], value)
|
||||
results["env_added"].append(var["name"])
|
||||
print(f" ✓ Saved {var['name']}")
|
||||
else:
|
||||
results["warnings"].append(f"Skipped {var['name']} - some features may not work")
|
||||
print()
|
||||
|
||||
# Check for missing config fields
|
||||
missing_config = get_missing_config_fields()
|
||||
|
||||
if missing_config:
|
||||
config = load_config()
|
||||
|
||||
for field in missing_config:
|
||||
key = field["key"]
|
||||
default = field["default"]
|
||||
|
||||
# Add with default value
|
||||
if "." in key:
|
||||
# Nested key
|
||||
parent, child = key.split(".", 1)
|
||||
if parent not in config:
|
||||
config[parent] = {}
|
||||
config[parent][child] = default
|
||||
else:
|
||||
config[key] = default
|
||||
|
||||
results["config_added"].append(key)
|
||||
if not quiet:
|
||||
print(f" ✓ Added {key} = {default}")
|
||||
|
||||
# Update version and save
|
||||
config["_config_version"] = latest_ver
|
||||
save_config(config)
|
||||
elif current_ver < latest_ver:
|
||||
# Just update version
|
||||
config = load_config()
|
||||
config["_config_version"] = latest_ver
|
||||
save_config(config)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def load_config() -> Dict[str, Any]:
|
||||
"""Load configuration from ~/.hermes/config.yaml."""
|
||||
config_path = get_config_path()
|
||||
|
||||
config = DEFAULT_CONFIG.copy()
|
||||
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
user_config = yaml.safe_load(f) or {}
|
||||
|
||||
# Deep merge
|
||||
for key, value in user_config.items():
|
||||
if isinstance(value, dict) and key in config and isinstance(config[key], dict):
|
||||
config[key].update(value)
|
||||
else:
|
||||
config[key] = value
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load config: {e}")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def save_config(config: Dict[str, Any]):
|
||||
"""Save configuration to ~/.hermes/config.yaml."""
|
||||
ensure_hermes_home()
|
||||
config_path = get_config_path()
|
||||
|
||||
with open(config_path, 'w') as f:
|
||||
yaml.dump(config, f, default_flow_style=False, sort_keys=False)
|
||||
|
||||
|
||||
def load_env() -> Dict[str, str]:
|
||||
"""Load environment variables from ~/.hermes/.env."""
|
||||
env_path = get_env_path()
|
||||
env_vars = {}
|
||||
|
||||
if env_path.exists():
|
||||
with open(env_path) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#') and '=' in line:
|
||||
key, _, value = line.partition('=')
|
||||
env_vars[key.strip()] = value.strip().strip('"\'')
|
||||
|
||||
return env_vars
|
||||
|
||||
|
||||
def save_env_value(key: str, value: str):
|
||||
"""Save or update a value in ~/.hermes/.env."""
|
||||
ensure_hermes_home()
|
||||
env_path = get_env_path()
|
||||
|
||||
# Load existing
|
||||
lines = []
|
||||
if env_path.exists():
|
||||
with open(env_path) as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Find and update or append
|
||||
found = False
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip().startswith(f"{key}="):
|
||||
lines[i] = f"{key}={value}\n"
|
||||
found = True
|
||||
break
|
||||
|
||||
if not found:
|
||||
lines.append(f"{key}={value}\n")
|
||||
|
||||
with open(env_path, 'w') as f:
|
||||
f.writelines(lines)
|
||||
|
||||
|
||||
def get_env_value(key: str) -> Optional[str]:
|
||||
"""Get a value from ~/.hermes/.env or environment."""
|
||||
# Check environment first
|
||||
if key in os.environ:
|
||||
return os.environ[key]
|
||||
|
||||
# Then check .env file
|
||||
env_vars = load_env()
|
||||
return env_vars.get(key)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Config display
|
||||
# =============================================================================
|
||||
|
||||
def redact_key(key: str) -> str:
|
||||
"""Redact an API key for display."""
|
||||
if not key:
|
||||
return color("(not set)", Colors.DIM)
|
||||
if len(key) < 12:
|
||||
return "***"
|
||||
return key[:4] + "..." + key[-4:]
|
||||
|
||||
|
||||
def show_config():
|
||||
"""Display current configuration."""
|
||||
config = load_config()
|
||||
env_vars = load_env()
|
||||
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────┐", Colors.CYAN))
|
||||
print(color("│ 🦋 Hermes Configuration │", Colors.CYAN))
|
||||
print(color("└─────────────────────────────────────────────────────────┘", Colors.CYAN))
|
||||
|
||||
# Paths
|
||||
print()
|
||||
print(color("◆ Paths", Colors.CYAN, Colors.BOLD))
|
||||
print(f" Config: {get_config_path()}")
|
||||
print(f" Secrets: {get_env_path()}")
|
||||
print(f" Install: {get_project_root()}")
|
||||
|
||||
# API Keys
|
||||
print()
|
||||
print(color("◆ API Keys", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
keys = [
|
||||
("OPENROUTER_API_KEY", "OpenRouter"),
|
||||
("ANTHROPIC_API_KEY", "Anthropic"),
|
||||
("OPENAI_API_KEY", "OpenAI"),
|
||||
("FIRECRAWL_API_KEY", "Firecrawl"),
|
||||
("BROWSERBASE_API_KEY", "Browserbase"),
|
||||
("FAL_KEY", "FAL"),
|
||||
]
|
||||
|
||||
for env_key, name in keys:
|
||||
value = get_env_value(env_key)
|
||||
print(f" {name:<14} {redact_key(value)}")
|
||||
|
||||
# Model settings
|
||||
print()
|
||||
print(color("◆ Model", Colors.CYAN, Colors.BOLD))
|
||||
print(f" Model: {config.get('model', 'not set')}")
|
||||
print(f" Max turns: {config.get('max_turns', 100)}")
|
||||
print(f" Toolsets: {', '.join(config.get('toolsets', ['all']))}")
|
||||
|
||||
# Terminal
|
||||
print()
|
||||
print(color("◆ Terminal", Colors.CYAN, Colors.BOLD))
|
||||
terminal = config.get('terminal', {})
|
||||
print(f" Backend: {terminal.get('backend', 'local')}")
|
||||
print(f" Working dir: {terminal.get('cwd', '.')}")
|
||||
print(f" Timeout: {terminal.get('timeout', 60)}s")
|
||||
|
||||
if terminal.get('backend') == 'docker':
|
||||
print(f" Docker image: {terminal.get('docker_image', 'python:3.11-slim')}")
|
||||
elif terminal.get('backend') == 'singularity':
|
||||
print(f" Image: {terminal.get('singularity_image', 'docker://python:3.11')}")
|
||||
elif terminal.get('backend') == 'modal':
|
||||
print(f" Modal image: {terminal.get('modal_image', 'python:3.11')}")
|
||||
modal_token = get_env_value('MODAL_TOKEN_ID')
|
||||
print(f" Modal token: {'configured' if modal_token else '(not set)'}")
|
||||
elif terminal.get('backend') == 'ssh':
|
||||
ssh_host = get_env_value('TERMINAL_SSH_HOST')
|
||||
ssh_user = get_env_value('TERMINAL_SSH_USER')
|
||||
print(f" SSH host: {ssh_host or '(not set)'}")
|
||||
print(f" SSH user: {ssh_user or '(not set)'}")
|
||||
|
||||
# Compression
|
||||
print()
|
||||
print(color("◆ Context Compression", Colors.CYAN, Colors.BOLD))
|
||||
compression = config.get('compression', {})
|
||||
enabled = compression.get('enabled', True)
|
||||
print(f" Enabled: {'yes' if enabled else 'no'}")
|
||||
if enabled:
|
||||
print(f" Threshold: {compression.get('threshold', 0.85) * 100:.0f}%")
|
||||
print(f" Model: {compression.get('summary_model', 'google/gemini-2.0-flash-001')}")
|
||||
|
||||
# Messaging
|
||||
print()
|
||||
print(color("◆ Messaging Platforms", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
telegram_token = get_env_value('TELEGRAM_BOT_TOKEN')
|
||||
discord_token = get_env_value('DISCORD_BOT_TOKEN')
|
||||
|
||||
print(f" Telegram: {'configured' if telegram_token else color('not configured', Colors.DIM)}")
|
||||
print(f" Discord: {'configured' if discord_token else color('not configured', Colors.DIM)}")
|
||||
|
||||
print()
|
||||
print(color("─" * 60, Colors.DIM))
|
||||
print(color(" hermes config edit # Edit config file", Colors.DIM))
|
||||
print(color(" hermes config set KEY VALUE", Colors.DIM))
|
||||
print(color(" hermes setup # Run setup wizard", Colors.DIM))
|
||||
print()
|
||||
|
||||
|
||||
def edit_config():
|
||||
"""Open config file in user's editor."""
|
||||
config_path = get_config_path()
|
||||
|
||||
# Ensure config exists
|
||||
if not config_path.exists():
|
||||
save_config(DEFAULT_CONFIG)
|
||||
print(f"Created {config_path}")
|
||||
|
||||
# Find editor
|
||||
editor = os.getenv('EDITOR') or os.getenv('VISUAL')
|
||||
|
||||
if not editor:
|
||||
# Try common editors
|
||||
for cmd in ['nano', 'vim', 'vi', 'code', 'notepad']:
|
||||
import shutil
|
||||
if shutil.which(cmd):
|
||||
editor = cmd
|
||||
break
|
||||
|
||||
if not editor:
|
||||
print(f"No editor found. Config file is at:")
|
||||
print(f" {config_path}")
|
||||
return
|
||||
|
||||
print(f"Opening {config_path} in {editor}...")
|
||||
subprocess.run([editor, str(config_path)])
|
||||
|
||||
|
||||
def set_config_value(key: str, value: str):
|
||||
"""Set a configuration value."""
|
||||
# Check if it's an API key (goes to .env)
|
||||
api_keys = [
|
||||
'OPENROUTER_API_KEY', 'ANTHROPIC_API_KEY', 'OPENAI_API_KEY',
|
||||
'FIRECRAWL_API_KEY', 'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID',
|
||||
'FAL_KEY', 'TELEGRAM_BOT_TOKEN', 'DISCORD_BOT_TOKEN',
|
||||
'TERMINAL_SSH_HOST', 'TERMINAL_SSH_USER', 'TERMINAL_SSH_KEY',
|
||||
'SUDO_PASSWORD'
|
||||
]
|
||||
|
||||
if key.upper() in api_keys or key.upper().startswith('TERMINAL_SSH'):
|
||||
save_env_value(key.upper(), value)
|
||||
print(f"✓ Set {key} in {get_env_path()}")
|
||||
return
|
||||
|
||||
# Otherwise it goes to config.yaml
|
||||
config = load_config()
|
||||
|
||||
# Handle nested keys (e.g., "terminal.backend")
|
||||
parts = key.split('.')
|
||||
current = config
|
||||
|
||||
for part in parts[:-1]:
|
||||
if part not in current:
|
||||
current[part] = {}
|
||||
current = current[part]
|
||||
|
||||
# Convert value to appropriate type
|
||||
if value.lower() in ('true', 'yes', 'on'):
|
||||
value = True
|
||||
elif value.lower() in ('false', 'no', 'off'):
|
||||
value = False
|
||||
elif value.isdigit():
|
||||
value = int(value)
|
||||
elif value.replace('.', '', 1).isdigit():
|
||||
value = float(value)
|
||||
|
||||
current[parts[-1]] = value
|
||||
save_config(config)
|
||||
print(f"✓ Set {key} = {value} in {get_config_path()}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Command handler
|
||||
# =============================================================================
|
||||
|
||||
def config_command(args):
|
||||
"""Handle config subcommands."""
|
||||
subcmd = getattr(args, 'config_command', None)
|
||||
|
||||
if subcmd is None or subcmd == "show":
|
||||
show_config()
|
||||
|
||||
elif subcmd == "edit":
|
||||
edit_config()
|
||||
|
||||
elif subcmd == "set":
|
||||
key = getattr(args, 'key', None)
|
||||
value = getattr(args, 'value', None)
|
||||
if not key or not value:
|
||||
print("Usage: hermes config set KEY VALUE")
|
||||
print()
|
||||
print("Examples:")
|
||||
print(" hermes config set model anthropic/claude-sonnet-4")
|
||||
print(" hermes config set terminal.backend docker")
|
||||
print(" hermes config set OPENROUTER_API_KEY sk-or-...")
|
||||
sys.exit(1)
|
||||
set_config_value(key, value)
|
||||
|
||||
elif subcmd == "path":
|
||||
print(get_config_path())
|
||||
|
||||
elif subcmd == "env-path":
|
||||
print(get_env_path())
|
||||
|
||||
elif subcmd == "migrate":
|
||||
print()
|
||||
print(color("🔄 Checking configuration for updates...", Colors.CYAN, Colors.BOLD))
|
||||
print()
|
||||
|
||||
# Check what's missing
|
||||
missing_env = get_missing_env_vars(required_only=False)
|
||||
missing_config = get_missing_config_fields()
|
||||
current_ver, latest_ver = check_config_version()
|
||||
|
||||
if not missing_env and not missing_config and current_ver >= latest_ver:
|
||||
print(color("✓ Configuration is up to date!", Colors.GREEN))
|
||||
print()
|
||||
return
|
||||
|
||||
# Show what needs to be updated
|
||||
if current_ver < latest_ver:
|
||||
print(f" Config version: {current_ver} → {latest_ver}")
|
||||
|
||||
if missing_config:
|
||||
print(f"\n {len(missing_config)} new config option(s) will be added with defaults")
|
||||
|
||||
required_missing = [v for v in missing_env if v.get("is_required")]
|
||||
optional_missing = [v for v in missing_env if not v.get("is_required")]
|
||||
|
||||
if required_missing:
|
||||
print(f"\n ⚠️ {len(required_missing)} required API key(s) missing:")
|
||||
for var in required_missing:
|
||||
print(f" • {var['name']}")
|
||||
|
||||
if optional_missing:
|
||||
print(f"\n ℹ️ {len(optional_missing)} optional API key(s) not configured:")
|
||||
for var in optional_missing:
|
||||
tools = var.get("tools", [])
|
||||
tools_str = f" (enables: {', '.join(tools[:2])})" if tools else ""
|
||||
print(f" • {var['name']}{tools_str}")
|
||||
|
||||
print()
|
||||
|
||||
# Run migration
|
||||
results = migrate_config(interactive=True, quiet=False)
|
||||
|
||||
print()
|
||||
if results["env_added"] or results["config_added"]:
|
||||
print(color("✓ Configuration updated!", Colors.GREEN))
|
||||
|
||||
if results["warnings"]:
|
||||
print()
|
||||
for warning in results["warnings"]:
|
||||
print(color(f" ⚠️ {warning}", Colors.YELLOW))
|
||||
|
||||
print()
|
||||
|
||||
elif subcmd == "check":
|
||||
# Non-interactive check for what's missing
|
||||
print()
|
||||
print(color("📋 Configuration Status", Colors.CYAN, Colors.BOLD))
|
||||
print()
|
||||
|
||||
current_ver, latest_ver = check_config_version()
|
||||
if current_ver >= latest_ver:
|
||||
print(f" Config version: {current_ver} ✓")
|
||||
else:
|
||||
print(color(f" Config version: {current_ver} → {latest_ver} (update available)", Colors.YELLOW))
|
||||
|
||||
print()
|
||||
print(color(" Required:", Colors.BOLD))
|
||||
for var_name in REQUIRED_ENV_VARS:
|
||||
if get_env_value(var_name):
|
||||
print(f" ✓ {var_name}")
|
||||
else:
|
||||
print(color(f" ✗ {var_name} (missing)", Colors.RED))
|
||||
|
||||
print()
|
||||
print(color(" Optional:", Colors.BOLD))
|
||||
for var_name, info in OPTIONAL_ENV_VARS.items():
|
||||
if get_env_value(var_name):
|
||||
print(f" ✓ {var_name}")
|
||||
else:
|
||||
tools = info.get("tools", [])
|
||||
tools_str = f" → {', '.join(tools[:2])}" if tools else ""
|
||||
print(color(f" ○ {var_name}{tools_str}", Colors.DIM))
|
||||
|
||||
missing_config = get_missing_config_fields()
|
||||
if missing_config:
|
||||
print()
|
||||
print(color(f" {len(missing_config)} new config option(s) available", Colors.YELLOW))
|
||||
print(f" Run 'hermes config migrate' to add them")
|
||||
|
||||
print()
|
||||
|
||||
else:
|
||||
print(f"Unknown config command: {subcmd}")
|
||||
print()
|
||||
print("Available commands:")
|
||||
print(" hermes config Show current configuration")
|
||||
print(" hermes config edit Open config in editor")
|
||||
print(" hermes config set K V Set a config value")
|
||||
print(" hermes config check Check for missing/outdated config")
|
||||
print(" hermes config migrate Update config with new options")
|
||||
print(" hermes config path Show config file path")
|
||||
print(" hermes config env-path Show .env file path")
|
||||
sys.exit(1)
|
||||
131
hermes_cli/cron.py
Normal file
131
hermes_cli/cron.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
Cron subcommand for hermes CLI.
|
||||
|
||||
Handles: hermes cron [list|daemon|tick]
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
# ANSI colors
|
||||
class Colors:
|
||||
RESET = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
DIM = "\033[2m"
|
||||
RED = "\033[31m"
|
||||
GREEN = "\033[32m"
|
||||
YELLOW = "\033[33m"
|
||||
CYAN = "\033[36m"
|
||||
|
||||
def color(text: str, *codes) -> str:
|
||||
if not sys.stdout.isatty():
|
||||
return text
|
||||
return "".join(codes) + text + Colors.RESET
|
||||
|
||||
|
||||
def cron_list(show_all: bool = False):
|
||||
"""List all scheduled jobs."""
|
||||
from cron.jobs import list_jobs
|
||||
|
||||
jobs = list_jobs(include_disabled=show_all)
|
||||
|
||||
if not jobs:
|
||||
print(color("No scheduled jobs.", Colors.DIM))
|
||||
print(color("Create one with: hermes cron add <schedule> <prompt>", Colors.DIM))
|
||||
return
|
||||
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────────────────────┐", Colors.CYAN))
|
||||
print(color("│ Scheduled Jobs │", Colors.CYAN))
|
||||
print(color("└─────────────────────────────────────────────────────────────────────────┘", Colors.CYAN))
|
||||
print()
|
||||
|
||||
for job in jobs:
|
||||
job_id = job.get("id", "?")[:8]
|
||||
name = job.get("name", "(unnamed)")
|
||||
schedule = job.get("schedule_display", job.get("schedule", {}).get("value", "?"))
|
||||
enabled = job.get("enabled", True)
|
||||
next_run = job.get("next_run_at", "?")
|
||||
|
||||
# Repeat info
|
||||
repeat_info = job.get("repeat", {})
|
||||
repeat_times = repeat_info.get("times")
|
||||
repeat_completed = repeat_info.get("completed", 0)
|
||||
|
||||
if repeat_times:
|
||||
repeat_str = f"{repeat_completed}/{repeat_times}"
|
||||
else:
|
||||
repeat_str = "∞"
|
||||
|
||||
# Delivery targets
|
||||
deliver = job.get("deliver", ["local"])
|
||||
if isinstance(deliver, str):
|
||||
deliver = [deliver]
|
||||
deliver_str = ", ".join(deliver)
|
||||
|
||||
# Status indicator
|
||||
if not enabled:
|
||||
status = color("[disabled]", Colors.RED)
|
||||
else:
|
||||
status = color("[active]", Colors.GREEN)
|
||||
|
||||
print(f" {color(job_id, Colors.YELLOW)} {status}")
|
||||
print(f" Name: {name}")
|
||||
print(f" Schedule: {schedule}")
|
||||
print(f" Repeat: {repeat_str}")
|
||||
print(f" Next run: {next_run}")
|
||||
print(f" Deliver: {deliver_str}")
|
||||
print()
|
||||
|
||||
|
||||
def cron_daemon(interval: int = 60):
|
||||
"""Run the cron daemon."""
|
||||
from cron.scheduler import start_daemon
|
||||
|
||||
print(color("┌─────────────────────────────────────────────────────────┐", Colors.CYAN))
|
||||
print(color("│ 🦋 Hermes Cron Daemon │", Colors.CYAN))
|
||||
print(color("├─────────────────────────────────────────────────────────┤", Colors.CYAN))
|
||||
print(color("│ Press Ctrl+C to stop │", Colors.CYAN))
|
||||
print(color("└─────────────────────────────────────────────────────────┘", Colors.CYAN))
|
||||
print()
|
||||
|
||||
try:
|
||||
start_daemon(interval=interval)
|
||||
except KeyboardInterrupt:
|
||||
print()
|
||||
print(color("Cron daemon stopped.", Colors.YELLOW))
|
||||
|
||||
|
||||
def cron_tick():
|
||||
"""Run due jobs once (for system cron integration)."""
|
||||
from cron.scheduler import tick
|
||||
|
||||
print(f"[{datetime.now().isoformat()}] Running cron tick...")
|
||||
tick()
|
||||
|
||||
|
||||
def cron_command(args):
|
||||
"""Handle cron subcommands."""
|
||||
subcmd = getattr(args, 'cron_command', None)
|
||||
|
||||
if subcmd is None or subcmd == "list":
|
||||
show_all = getattr(args, 'all', False)
|
||||
cron_list(show_all)
|
||||
|
||||
elif subcmd == "daemon":
|
||||
interval = getattr(args, 'interval', 60)
|
||||
cron_daemon(interval)
|
||||
|
||||
elif subcmd == "tick":
|
||||
cron_tick()
|
||||
|
||||
else:
|
||||
print(f"Unknown cron command: {subcmd}")
|
||||
print("Usage: hermes cron [list|daemon|tick]")
|
||||
sys.exit(1)
|
||||
316
hermes_cli/doctor.py
Normal file
316
hermes_cli/doctor.py
Normal file
@@ -0,0 +1,316 @@
|
||||
"""
|
||||
Doctor command for hermes CLI.
|
||||
|
||||
Diagnoses issues with Hermes Agent setup.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
# ANSI colors
|
||||
class Colors:
|
||||
RESET = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
DIM = "\033[2m"
|
||||
RED = "\033[31m"
|
||||
GREEN = "\033[32m"
|
||||
YELLOW = "\033[33m"
|
||||
CYAN = "\033[36m"
|
||||
|
||||
def color(text: str, *codes) -> str:
|
||||
if not sys.stdout.isatty():
|
||||
return text
|
||||
return "".join(codes) + text + Colors.RESET
|
||||
|
||||
def check_ok(text: str, detail: str = ""):
|
||||
print(f" {color('✓', Colors.GREEN)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else ""))
|
||||
|
||||
def check_warn(text: str, detail: str = ""):
|
||||
print(f" {color('⚠', Colors.YELLOW)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else ""))
|
||||
|
||||
def check_fail(text: str, detail: str = ""):
|
||||
print(f" {color('✗', Colors.RED)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else ""))
|
||||
|
||||
def check_info(text: str):
|
||||
print(f" {color('→', Colors.CYAN)} {text}")
|
||||
|
||||
|
||||
def run_doctor(args):
|
||||
"""Run diagnostic checks."""
|
||||
should_fix = getattr(args, 'fix', False)
|
||||
|
||||
issues = []
|
||||
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────┐", Colors.CYAN))
|
||||
print(color("│ 🩺 Hermes Doctor │", Colors.CYAN))
|
||||
print(color("└─────────────────────────────────────────────────────────┘", Colors.CYAN))
|
||||
|
||||
# =========================================================================
|
||||
# Check: Python version
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Python Environment", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
py_version = sys.version_info
|
||||
if py_version >= (3, 10):
|
||||
check_ok(f"Python {py_version.major}.{py_version.minor}.{py_version.micro}")
|
||||
elif py_version >= (3, 8):
|
||||
check_warn(f"Python {py_version.major}.{py_version.minor}.{py_version.micro}", "(3.10+ recommended)")
|
||||
else:
|
||||
check_fail(f"Python {py_version.major}.{py_version.minor}.{py_version.micro}", "(3.10+ required)")
|
||||
issues.append("Upgrade Python to 3.10+")
|
||||
|
||||
# Check if in virtual environment
|
||||
in_venv = sys.prefix != sys.base_prefix
|
||||
if in_venv:
|
||||
check_ok("Virtual environment active")
|
||||
else:
|
||||
check_warn("Not in virtual environment", "(recommended)")
|
||||
|
||||
# =========================================================================
|
||||
# Check: Required packages
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Required Packages", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
required_packages = [
|
||||
("openai", "OpenAI SDK"),
|
||||
("rich", "Rich (terminal UI)"),
|
||||
("dotenv", "python-dotenv"),
|
||||
("yaml", "PyYAML"),
|
||||
("httpx", "HTTPX"),
|
||||
]
|
||||
|
||||
optional_packages = [
|
||||
("croniter", "Croniter (cron expressions)"),
|
||||
("browserbase", "Browserbase SDK"),
|
||||
("telegram", "python-telegram-bot"),
|
||||
("discord", "discord.py"),
|
||||
]
|
||||
|
||||
for module, name in required_packages:
|
||||
try:
|
||||
__import__(module)
|
||||
check_ok(name)
|
||||
except ImportError:
|
||||
check_fail(name, "(missing)")
|
||||
issues.append(f"Install {name}: pip install {module}")
|
||||
|
||||
for module, name in optional_packages:
|
||||
try:
|
||||
__import__(module)
|
||||
check_ok(name, "(optional)")
|
||||
except ImportError:
|
||||
check_warn(name, "(optional, not installed)")
|
||||
|
||||
# =========================================================================
|
||||
# Check: Configuration files
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Configuration Files", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
env_path = PROJECT_ROOT / '.env'
|
||||
if env_path.exists():
|
||||
check_ok(".env file exists")
|
||||
|
||||
# Check for common issues
|
||||
content = env_path.read_text()
|
||||
if "OPENROUTER_API_KEY" in content or "ANTHROPIC_API_KEY" in content:
|
||||
check_ok("API key configured")
|
||||
else:
|
||||
check_warn("No API key found in .env")
|
||||
issues.append("Run 'hermes setup' to configure API keys")
|
||||
else:
|
||||
check_fail(".env file missing")
|
||||
check_info("Run 'hermes setup' to create one")
|
||||
issues.append("Run 'hermes setup' to create .env")
|
||||
|
||||
config_path = PROJECT_ROOT / 'cli-config.yaml'
|
||||
if config_path.exists():
|
||||
check_ok("cli-config.yaml exists")
|
||||
else:
|
||||
check_warn("cli-config.yaml not found", "(using defaults)")
|
||||
|
||||
# =========================================================================
|
||||
# Check: Directory structure
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Directory Structure", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
hermes_home = Path.home() / ".hermes"
|
||||
if hermes_home.exists():
|
||||
check_ok("~/.hermes directory exists")
|
||||
else:
|
||||
check_warn("~/.hermes not found", "(will be created on first use)")
|
||||
|
||||
logs_dir = PROJECT_ROOT / "logs"
|
||||
if logs_dir.exists():
|
||||
check_ok("logs/ directory exists")
|
||||
else:
|
||||
check_warn("logs/ not found", "(will be created on first use)")
|
||||
|
||||
# =========================================================================
|
||||
# Check: External tools
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ External Tools", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
# Git
|
||||
if shutil.which("git"):
|
||||
check_ok("git")
|
||||
else:
|
||||
check_warn("git not found", "(optional)")
|
||||
|
||||
# ripgrep (optional, for faster file search)
|
||||
if shutil.which("rg"):
|
||||
check_ok("ripgrep (rg)", "(faster file search)")
|
||||
else:
|
||||
check_warn("ripgrep (rg) not found", "(file search uses grep fallback)")
|
||||
check_info("Install for faster search: sudo apt install ripgrep")
|
||||
|
||||
# Docker (optional)
|
||||
terminal_env = os.getenv("TERMINAL_ENV", "local")
|
||||
if terminal_env == "docker":
|
||||
if shutil.which("docker"):
|
||||
# Check if docker daemon is running
|
||||
result = subprocess.run(["docker", "info"], capture_output=True)
|
||||
if result.returncode == 0:
|
||||
check_ok("docker", "(daemon running)")
|
||||
else:
|
||||
check_fail("docker daemon not running")
|
||||
issues.append("Start Docker daemon")
|
||||
else:
|
||||
check_fail("docker not found", "(required for TERMINAL_ENV=docker)")
|
||||
issues.append("Install Docker or change TERMINAL_ENV")
|
||||
else:
|
||||
if shutil.which("docker"):
|
||||
check_ok("docker", "(optional)")
|
||||
else:
|
||||
check_warn("docker not found", "(optional)")
|
||||
|
||||
# SSH (if using ssh backend)
|
||||
if terminal_env == "ssh":
|
||||
ssh_host = os.getenv("TERMINAL_SSH_HOST")
|
||||
if ssh_host:
|
||||
# Try to connect
|
||||
result = subprocess.run(
|
||||
["ssh", "-o", "ConnectTimeout=5", "-o", "BatchMode=yes", ssh_host, "echo ok"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
if result.returncode == 0:
|
||||
check_ok(f"SSH connection to {ssh_host}")
|
||||
else:
|
||||
check_fail(f"SSH connection to {ssh_host}")
|
||||
issues.append(f"Check SSH configuration for {ssh_host}")
|
||||
else:
|
||||
check_fail("TERMINAL_SSH_HOST not set", "(required for TERMINAL_ENV=ssh)")
|
||||
issues.append("Set TERMINAL_SSH_HOST in .env")
|
||||
|
||||
# =========================================================================
|
||||
# Check: API connectivity
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ API Connectivity", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
openrouter_key = os.getenv("OPENROUTER_API_KEY")
|
||||
if openrouter_key:
|
||||
try:
|
||||
import httpx
|
||||
response = httpx.get(
|
||||
"https://openrouter.ai/api/v1/models",
|
||||
headers={"Authorization": f"Bearer {openrouter_key}"},
|
||||
timeout=10
|
||||
)
|
||||
if response.status_code == 200:
|
||||
check_ok("OpenRouter API")
|
||||
elif response.status_code == 401:
|
||||
check_fail("OpenRouter API", "(invalid API key)")
|
||||
issues.append("Check OPENROUTER_API_KEY in .env")
|
||||
else:
|
||||
check_fail("OpenRouter API", f"(HTTP {response.status_code})")
|
||||
except Exception as e:
|
||||
check_fail("OpenRouter API", f"({e})")
|
||||
issues.append("Check network connectivity")
|
||||
else:
|
||||
check_warn("OpenRouter API", "(not configured)")
|
||||
|
||||
anthropic_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
if anthropic_key:
|
||||
try:
|
||||
import httpx
|
||||
response = httpx.get(
|
||||
"https://api.anthropic.com/v1/models",
|
||||
headers={
|
||||
"x-api-key": anthropic_key,
|
||||
"anthropic-version": "2023-06-01"
|
||||
},
|
||||
timeout=10
|
||||
)
|
||||
if response.status_code == 200:
|
||||
check_ok("Anthropic API")
|
||||
elif response.status_code == 401:
|
||||
check_fail("Anthropic API", "(invalid API key)")
|
||||
else:
|
||||
# Note: Anthropic may not have /models endpoint
|
||||
check_warn("Anthropic API", "(couldn't verify)")
|
||||
except Exception as e:
|
||||
check_warn("Anthropic API", f"({e})")
|
||||
|
||||
# =========================================================================
|
||||
# Check: Tool Availability
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Tool Availability", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
try:
|
||||
# Add project root to path for imports
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS
|
||||
|
||||
available, unavailable = check_tool_availability()
|
||||
|
||||
for tid in available:
|
||||
info = TOOLSET_REQUIREMENTS.get(tid, {})
|
||||
check_ok(info.get("name", tid))
|
||||
|
||||
for item in unavailable:
|
||||
if item["missing_vars"]:
|
||||
vars_str = ", ".join(item["missing_vars"])
|
||||
check_warn(item["name"], f"(missing {vars_str})")
|
||||
else:
|
||||
check_warn(item["name"], "(system dependency not met)")
|
||||
|
||||
# Count disabled tools with API key requirements
|
||||
api_disabled = [u for u in unavailable if u["missing_vars"]]
|
||||
if api_disabled:
|
||||
issues.append("Run 'hermes setup' to configure missing API keys for full tool access")
|
||||
except Exception as e:
|
||||
check_warn("Could not check tool availability", f"({e})")
|
||||
|
||||
# =========================================================================
|
||||
# Summary
|
||||
# =========================================================================
|
||||
print()
|
||||
if issues:
|
||||
print(color("─" * 60, Colors.YELLOW))
|
||||
print(color(f" Found {len(issues)} issue(s) to address:", Colors.YELLOW, Colors.BOLD))
|
||||
print()
|
||||
for i, issue in enumerate(issues, 1):
|
||||
print(f" {i}. {issue}")
|
||||
print()
|
||||
|
||||
if should_fix:
|
||||
print(color(" Attempting auto-fix is not yet implemented.", Colors.DIM))
|
||||
print(color(" Please resolve issues manually.", Colors.DIM))
|
||||
else:
|
||||
print(color("─" * 60, Colors.GREEN))
|
||||
print(color(" All checks passed! 🎉", Colors.GREEN, Colors.BOLD))
|
||||
|
||||
print()
|
||||
487
hermes_cli/gateway.py
Normal file
487
hermes_cli/gateway.py
Normal file
@@ -0,0 +1,487 @@
|
||||
"""
|
||||
Gateway subcommand for hermes CLI.
|
||||
|
||||
Handles: hermes gateway [run|start|stop|restart|status|install|uninstall]
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Process Management (for manual gateway runs)
|
||||
# =============================================================================
|
||||
|
||||
def find_gateway_pids() -> list:
|
||||
"""Find PIDs of running gateway processes."""
|
||||
pids = []
|
||||
try:
|
||||
# Look for gateway processes with multiple patterns
|
||||
patterns = [
|
||||
"hermes_cli.main gateway",
|
||||
"hermes gateway",
|
||||
"gateway/run.py",
|
||||
]
|
||||
|
||||
result = subprocess.run(
|
||||
["ps", "aux"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
for line in result.stdout.split('\n'):
|
||||
# Skip grep and current process
|
||||
if 'grep' in line or str(os.getpid()) in line:
|
||||
continue
|
||||
|
||||
for pattern in patterns:
|
||||
if pattern in line:
|
||||
parts = line.split()
|
||||
if len(parts) > 1:
|
||||
try:
|
||||
pid = int(parts[1])
|
||||
if pid not in pids:
|
||||
pids.append(pid)
|
||||
except ValueError:
|
||||
continue
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return pids
|
||||
|
||||
|
||||
def kill_gateway_processes(force: bool = False) -> int:
|
||||
"""Kill any running gateway processes. Returns count killed."""
|
||||
pids = find_gateway_pids()
|
||||
killed = 0
|
||||
|
||||
for pid in pids:
|
||||
try:
|
||||
if force:
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
else:
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
killed += 1
|
||||
except ProcessLookupError:
|
||||
# Process already gone
|
||||
pass
|
||||
except PermissionError:
|
||||
print(f"⚠ Permission denied to kill PID {pid}")
|
||||
|
||||
return killed
|
||||
|
||||
|
||||
def is_linux() -> bool:
|
||||
return sys.platform.startswith('linux')
|
||||
|
||||
def is_macos() -> bool:
|
||||
return sys.platform == 'darwin'
|
||||
|
||||
def is_windows() -> bool:
|
||||
return sys.platform == 'win32'
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Service Configuration
|
||||
# =============================================================================
|
||||
|
||||
SERVICE_NAME = "hermes-gateway"
|
||||
SERVICE_DESCRIPTION = "Hermes Agent Gateway - Messaging Platform Integration"
|
||||
|
||||
def get_systemd_unit_path() -> Path:
|
||||
return Path.home() / ".config" / "systemd" / "user" / f"{SERVICE_NAME}.service"
|
||||
|
||||
def get_launchd_plist_path() -> Path:
|
||||
return Path.home() / "Library" / "LaunchAgents" / "ai.hermes.gateway.plist"
|
||||
|
||||
def get_python_path() -> str:
|
||||
venv_python = PROJECT_ROOT / "venv" / "bin" / "python"
|
||||
if venv_python.exists():
|
||||
return str(venv_python)
|
||||
return sys.executable
|
||||
|
||||
def get_hermes_cli_path() -> str:
|
||||
"""Get the path to the hermes CLI."""
|
||||
# Check if installed via pip
|
||||
import shutil
|
||||
hermes_bin = shutil.which("hermes")
|
||||
if hermes_bin:
|
||||
return hermes_bin
|
||||
|
||||
# Fallback to direct module execution
|
||||
return f"{get_python_path()} -m hermes_cli.main"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Systemd (Linux)
|
||||
# =============================================================================
|
||||
|
||||
def generate_systemd_unit() -> str:
|
||||
python_path = get_python_path()
|
||||
working_dir = str(PROJECT_ROOT)
|
||||
|
||||
return f"""[Unit]
|
||||
Description={SERVICE_DESCRIPTION}
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
ExecStart={python_path} -m hermes_cli.main gateway run
|
||||
WorkingDirectory={working_dir}
|
||||
Restart=on-failure
|
||||
RestartSec=10
|
||||
StandardOutput=journal
|
||||
StandardError=journal
|
||||
|
||||
[Install]
|
||||
WantedBy=default.target
|
||||
"""
|
||||
|
||||
def systemd_install(force: bool = False):
|
||||
unit_path = get_systemd_unit_path()
|
||||
|
||||
if unit_path.exists() and not force:
|
||||
print(f"Service already installed at: {unit_path}")
|
||||
print("Use --force to reinstall")
|
||||
return
|
||||
|
||||
unit_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"Installing systemd service to: {unit_path}")
|
||||
unit_path.write_text(generate_systemd_unit())
|
||||
|
||||
subprocess.run(["systemctl", "--user", "daemon-reload"], check=True)
|
||||
subprocess.run(["systemctl", "--user", "enable", SERVICE_NAME], check=True)
|
||||
|
||||
print()
|
||||
print("✓ Service installed and enabled!")
|
||||
print()
|
||||
print("Next steps:")
|
||||
print(f" hermes gateway start # Start the service")
|
||||
print(f" hermes gateway status # Check status")
|
||||
print(f" journalctl --user -u {SERVICE_NAME} -f # View logs")
|
||||
print()
|
||||
print("To enable lingering (keeps running after logout):")
|
||||
print(" sudo loginctl enable-linger $USER")
|
||||
|
||||
def systemd_uninstall():
|
||||
subprocess.run(["systemctl", "--user", "stop", SERVICE_NAME], check=False)
|
||||
subprocess.run(["systemctl", "--user", "disable", SERVICE_NAME], check=False)
|
||||
|
||||
unit_path = get_systemd_unit_path()
|
||||
if unit_path.exists():
|
||||
unit_path.unlink()
|
||||
print(f"✓ Removed {unit_path}")
|
||||
|
||||
subprocess.run(["systemctl", "--user", "daemon-reload"], check=True)
|
||||
print("✓ Service uninstalled")
|
||||
|
||||
def systemd_start():
|
||||
subprocess.run(["systemctl", "--user", "start", SERVICE_NAME], check=True)
|
||||
print("✓ Service started")
|
||||
|
||||
def systemd_stop():
|
||||
subprocess.run(["systemctl", "--user", "stop", SERVICE_NAME], check=True)
|
||||
print("✓ Service stopped")
|
||||
|
||||
def systemd_restart():
|
||||
subprocess.run(["systemctl", "--user", "restart", SERVICE_NAME], check=True)
|
||||
print("✓ Service restarted")
|
||||
|
||||
def systemd_status(deep: bool = False):
|
||||
# Check if service unit file exists
|
||||
unit_path = get_systemd_unit_path()
|
||||
if not unit_path.exists():
|
||||
print("✗ Gateway service is not installed")
|
||||
print(" Run: hermes gateway install")
|
||||
return
|
||||
|
||||
# Show detailed status first
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "status", SERVICE_NAME, "--no-pager"],
|
||||
capture_output=False
|
||||
)
|
||||
|
||||
# Check if service is active
|
||||
result = subprocess.run(
|
||||
["systemctl", "--user", "is-active", SERVICE_NAME],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
status = result.stdout.strip()
|
||||
|
||||
if status == "active":
|
||||
print("✓ Gateway service is running")
|
||||
else:
|
||||
print("✗ Gateway service is stopped")
|
||||
print(" Run: hermes gateway start")
|
||||
|
||||
if deep:
|
||||
print()
|
||||
print("Recent logs:")
|
||||
subprocess.run([
|
||||
"journalctl", "--user", "-u", SERVICE_NAME,
|
||||
"-n", "20", "--no-pager"
|
||||
])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Launchd (macOS)
|
||||
# =============================================================================
|
||||
|
||||
def generate_launchd_plist() -> str:
|
||||
python_path = get_python_path()
|
||||
working_dir = str(PROJECT_ROOT)
|
||||
log_dir = Path.home() / ".hermes" / "logs"
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return f"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>ai.hermes.gateway</string>
|
||||
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>{python_path}</string>
|
||||
<string>-m</string>
|
||||
<string>hermes_cli.main</string>
|
||||
<string>gateway</string>
|
||||
<string>run</string>
|
||||
</array>
|
||||
|
||||
<key>WorkingDirectory</key>
|
||||
<string>{working_dir}</string>
|
||||
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
|
||||
<key>KeepAlive</key>
|
||||
<dict>
|
||||
<key>SuccessfulExit</key>
|
||||
<false/>
|
||||
</dict>
|
||||
|
||||
<key>StandardOutPath</key>
|
||||
<string>{log_dir}/gateway.log</string>
|
||||
|
||||
<key>StandardErrorPath</key>
|
||||
<string>{log_dir}/gateway.error.log</string>
|
||||
</dict>
|
||||
</plist>
|
||||
"""
|
||||
|
||||
def launchd_install(force: bool = False):
|
||||
plist_path = get_launchd_plist_path()
|
||||
|
||||
if plist_path.exists() and not force:
|
||||
print(f"Service already installed at: {plist_path}")
|
||||
print("Use --force to reinstall")
|
||||
return
|
||||
|
||||
plist_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"Installing launchd service to: {plist_path}")
|
||||
plist_path.write_text(generate_launchd_plist())
|
||||
|
||||
subprocess.run(["launchctl", "load", str(plist_path)], check=True)
|
||||
|
||||
print()
|
||||
print("✓ Service installed and loaded!")
|
||||
print()
|
||||
print("Next steps:")
|
||||
print(" hermes gateway status # Check status")
|
||||
print(" tail -f ~/.hermes/logs/gateway.log # View logs")
|
||||
|
||||
def launchd_uninstall():
|
||||
plist_path = get_launchd_plist_path()
|
||||
subprocess.run(["launchctl", "unload", str(plist_path)], check=False)
|
||||
|
||||
if plist_path.exists():
|
||||
plist_path.unlink()
|
||||
print(f"✓ Removed {plist_path}")
|
||||
|
||||
print("✓ Service uninstalled")
|
||||
|
||||
def launchd_start():
|
||||
subprocess.run(["launchctl", "start", "ai.hermes.gateway"], check=True)
|
||||
print("✓ Service started")
|
||||
|
||||
def launchd_stop():
|
||||
subprocess.run(["launchctl", "stop", "ai.hermes.gateway"], check=True)
|
||||
print("✓ Service stopped")
|
||||
|
||||
def launchd_restart():
|
||||
launchd_stop()
|
||||
launchd_start()
|
||||
|
||||
def launchd_status(deep: bool = False):
|
||||
result = subprocess.run(
|
||||
["launchctl", "list", "ai.hermes.gateway"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
print("✓ Gateway service is loaded")
|
||||
print(result.stdout)
|
||||
else:
|
||||
print("✗ Gateway service is not loaded")
|
||||
|
||||
if deep:
|
||||
log_file = Path.home() / ".hermes" / "logs" / "gateway.log"
|
||||
if log_file.exists():
|
||||
print()
|
||||
print("Recent logs:")
|
||||
subprocess.run(["tail", "-20", str(log_file)])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Gateway Runner
|
||||
# =============================================================================
|
||||
|
||||
def run_gateway(verbose: bool = False):
|
||||
"""Run the gateway in foreground."""
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from gateway.run import start_gateway
|
||||
|
||||
print("┌─────────────────────────────────────────────────────────┐")
|
||||
print("│ 🦋 Hermes Gateway Starting... │")
|
||||
print("├─────────────────────────────────────────────────────────┤")
|
||||
print("│ Press Ctrl+C to stop │")
|
||||
print("└─────────────────────────────────────────────────────────┘")
|
||||
print()
|
||||
|
||||
asyncio.run(start_gateway())
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main Command Handler
|
||||
# =============================================================================
|
||||
|
||||
def gateway_command(args):
|
||||
"""Handle gateway subcommands."""
|
||||
subcmd = getattr(args, 'gateway_command', None)
|
||||
|
||||
# Default to run if no subcommand
|
||||
if subcmd is None or subcmd == "run":
|
||||
verbose = getattr(args, 'verbose', False)
|
||||
run_gateway(verbose)
|
||||
return
|
||||
|
||||
# Service management commands
|
||||
if subcmd == "install":
|
||||
force = getattr(args, 'force', False)
|
||||
if is_linux():
|
||||
systemd_install(force)
|
||||
elif is_macos():
|
||||
launchd_install(force)
|
||||
else:
|
||||
print("Service installation not supported on this platform.")
|
||||
print("Run manually: hermes gateway run")
|
||||
sys.exit(1)
|
||||
|
||||
elif subcmd == "uninstall":
|
||||
if is_linux():
|
||||
systemd_uninstall()
|
||||
elif is_macos():
|
||||
launchd_uninstall()
|
||||
else:
|
||||
print("Not supported on this platform.")
|
||||
sys.exit(1)
|
||||
|
||||
elif subcmd == "start":
|
||||
if is_linux():
|
||||
systemd_start()
|
||||
elif is_macos():
|
||||
launchd_start()
|
||||
else:
|
||||
print("Not supported on this platform.")
|
||||
sys.exit(1)
|
||||
|
||||
elif subcmd == "stop":
|
||||
# Try service first, fall back to killing processes directly
|
||||
service_available = False
|
||||
|
||||
if is_linux() and get_systemd_unit_path().exists():
|
||||
try:
|
||||
systemd_stop()
|
||||
service_available = True
|
||||
except subprocess.CalledProcessError:
|
||||
pass # Fall through to process kill
|
||||
elif is_macos() and get_launchd_plist_path().exists():
|
||||
try:
|
||||
launchd_stop()
|
||||
service_available = True
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
|
||||
if not service_available:
|
||||
# Kill gateway processes directly
|
||||
killed = kill_gateway_processes()
|
||||
if killed:
|
||||
print(f"✓ Stopped {killed} gateway process(es)")
|
||||
else:
|
||||
print("✗ No gateway processes found")
|
||||
|
||||
elif subcmd == "restart":
|
||||
# Try service first, fall back to killing and restarting
|
||||
service_available = False
|
||||
|
||||
if is_linux() and get_systemd_unit_path().exists():
|
||||
try:
|
||||
systemd_restart()
|
||||
service_available = True
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
elif is_macos() and get_launchd_plist_path().exists():
|
||||
try:
|
||||
launchd_restart()
|
||||
service_available = True
|
||||
except subprocess.CalledProcessError:
|
||||
pass
|
||||
|
||||
if not service_available:
|
||||
# Manual restart: kill existing processes
|
||||
killed = kill_gateway_processes()
|
||||
if killed:
|
||||
print(f"✓ Stopped {killed} gateway process(es)")
|
||||
|
||||
import time
|
||||
time.sleep(2)
|
||||
|
||||
# Start fresh
|
||||
print("Starting gateway...")
|
||||
run_gateway(verbose=False)
|
||||
|
||||
elif subcmd == "status":
|
||||
deep = getattr(args, 'deep', False)
|
||||
|
||||
# Check for service first
|
||||
if is_linux() and get_systemd_unit_path().exists():
|
||||
systemd_status(deep)
|
||||
elif is_macos() and get_launchd_plist_path().exists():
|
||||
launchd_status(deep)
|
||||
else:
|
||||
# Check for manually running processes
|
||||
pids = find_gateway_pids()
|
||||
if pids:
|
||||
print(f"✓ Gateway is running (PID: {', '.join(map(str, pids))})")
|
||||
print(" (Running manually, not as a system service)")
|
||||
print()
|
||||
print("To install as a service:")
|
||||
print(" hermes gateway install")
|
||||
else:
|
||||
print("✗ Gateway is not running")
|
||||
print()
|
||||
print("To start:")
|
||||
print(" hermes gateway # Run in foreground")
|
||||
print(" hermes gateway install # Install as service")
|
||||
507
hermes_cli/main.py
Normal file
507
hermes_cli/main.py
Normal file
@@ -0,0 +1,507 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Hermes CLI - Main entry point.
|
||||
|
||||
Usage:
|
||||
hermes # Interactive chat (default)
|
||||
hermes chat # Interactive chat
|
||||
hermes gateway # Run gateway in foreground
|
||||
hermes gateway start # Start gateway as service
|
||||
hermes gateway stop # Stop gateway service
|
||||
hermes gateway status # Show gateway status
|
||||
hermes gateway install # Install gateway service
|
||||
hermes gateway uninstall # Uninstall gateway service
|
||||
hermes setup # Interactive setup wizard
|
||||
hermes status # Show status of all components
|
||||
hermes cron # Manage cron jobs
|
||||
hermes cron list # List cron jobs
|
||||
hermes cron daemon # Run cron daemon
|
||||
hermes doctor # Check configuration and dependencies
|
||||
hermes version # Show version
|
||||
hermes update # Update to latest version
|
||||
hermes uninstall # Uninstall Hermes Agent
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
# Load .env file
|
||||
from dotenv import load_dotenv
|
||||
env_path = PROJECT_ROOT / '.env'
|
||||
if env_path.exists():
|
||||
load_dotenv(dotenv_path=env_path)
|
||||
|
||||
from hermes_cli import __version__
|
||||
|
||||
|
||||
def cmd_chat(args):
|
||||
"""Run interactive chat CLI."""
|
||||
# Import and run the CLI
|
||||
from cli import main as cli_main
|
||||
|
||||
# Build kwargs from args
|
||||
kwargs = {
|
||||
"model": args.model,
|
||||
"toolsets": args.toolsets,
|
||||
"verbose": args.verbose,
|
||||
"query": args.query,
|
||||
}
|
||||
# Filter out None values
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
cli_main(**kwargs)
|
||||
|
||||
|
||||
def cmd_gateway(args):
|
||||
"""Gateway management commands."""
|
||||
from hermes_cli.gateway import gateway_command
|
||||
gateway_command(args)
|
||||
|
||||
|
||||
def cmd_setup(args):
|
||||
"""Interactive setup wizard."""
|
||||
from hermes_cli.setup import run_setup_wizard
|
||||
run_setup_wizard(args)
|
||||
|
||||
|
||||
def cmd_status(args):
|
||||
"""Show status of all components."""
|
||||
from hermes_cli.status import show_status
|
||||
show_status(args)
|
||||
|
||||
|
||||
def cmd_cron(args):
|
||||
"""Cron job management."""
|
||||
from hermes_cli.cron import cron_command
|
||||
cron_command(args)
|
||||
|
||||
|
||||
def cmd_doctor(args):
|
||||
"""Check configuration and dependencies."""
|
||||
from hermes_cli.doctor import run_doctor
|
||||
run_doctor(args)
|
||||
|
||||
|
||||
def cmd_config(args):
|
||||
"""Configuration management."""
|
||||
from hermes_cli.config import config_command
|
||||
config_command(args)
|
||||
|
||||
|
||||
def cmd_version(args):
|
||||
"""Show version."""
|
||||
print(f"Hermes Agent v{__version__}")
|
||||
print(f"Project: {PROJECT_ROOT}")
|
||||
|
||||
# Show Python version
|
||||
print(f"Python: {sys.version.split()[0]}")
|
||||
|
||||
# Check for key dependencies
|
||||
try:
|
||||
import openai
|
||||
print(f"OpenAI SDK: {openai.__version__}")
|
||||
except ImportError:
|
||||
print("OpenAI SDK: Not installed")
|
||||
|
||||
|
||||
def cmd_uninstall(args):
|
||||
"""Uninstall Hermes Agent."""
|
||||
from hermes_cli.uninstall import run_uninstall
|
||||
run_uninstall(args)
|
||||
|
||||
|
||||
def cmd_update(args):
|
||||
"""Update Hermes Agent to the latest version."""
|
||||
import subprocess
|
||||
|
||||
print("🦋 Updating Hermes Agent...")
|
||||
print()
|
||||
|
||||
# Check if we're in a git repo
|
||||
git_dir = PROJECT_ROOT / '.git'
|
||||
if not git_dir.exists():
|
||||
print("✗ Not a git repository. Please reinstall:")
|
||||
print(" curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash")
|
||||
sys.exit(1)
|
||||
|
||||
# Fetch and pull
|
||||
try:
|
||||
print("→ Fetching updates...")
|
||||
subprocess.run(["git", "fetch", "origin"], cwd=PROJECT_ROOT, check=True)
|
||||
|
||||
# Get current branch
|
||||
result = subprocess.run(
|
||||
["git", "rev-parse", "--abbrev-ref", "HEAD"],
|
||||
cwd=PROJECT_ROOT,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True
|
||||
)
|
||||
branch = result.stdout.strip()
|
||||
|
||||
# Check if there are updates
|
||||
result = subprocess.run(
|
||||
["git", "rev-list", f"HEAD..origin/{branch}", "--count"],
|
||||
cwd=PROJECT_ROOT,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True
|
||||
)
|
||||
commit_count = int(result.stdout.strip())
|
||||
|
||||
if commit_count == 0:
|
||||
print("✓ Already up to date!")
|
||||
return
|
||||
|
||||
print(f"→ Found {commit_count} new commit(s)")
|
||||
print("→ Pulling updates...")
|
||||
subprocess.run(["git", "pull", "origin", branch], cwd=PROJECT_ROOT, check=True)
|
||||
|
||||
# Reinstall Python dependencies
|
||||
print("→ Updating Python dependencies...")
|
||||
venv_pip = PROJECT_ROOT / "venv" / "bin" / "pip"
|
||||
if venv_pip.exists():
|
||||
subprocess.run([str(venv_pip), "install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True)
|
||||
else:
|
||||
subprocess.run(["pip", "install", "-e", ".", "--quiet"], cwd=PROJECT_ROOT, check=True)
|
||||
|
||||
# Check for Node.js deps
|
||||
if (PROJECT_ROOT / "package.json").exists():
|
||||
import shutil
|
||||
if shutil.which("npm"):
|
||||
print("→ Updating Node.js dependencies...")
|
||||
subprocess.run(["npm", "install", "--silent"], cwd=PROJECT_ROOT, check=False)
|
||||
|
||||
print()
|
||||
print("✓ Code updated!")
|
||||
|
||||
# Check for config migrations
|
||||
print()
|
||||
print("→ Checking configuration for new options...")
|
||||
|
||||
from hermes_cli.config import (
|
||||
get_missing_env_vars, get_missing_config_fields,
|
||||
check_config_version, migrate_config
|
||||
)
|
||||
|
||||
missing_env = get_missing_env_vars(required_only=True)
|
||||
missing_config = get_missing_config_fields()
|
||||
current_ver, latest_ver = check_config_version()
|
||||
|
||||
needs_migration = missing_env or missing_config or current_ver < latest_ver
|
||||
|
||||
if needs_migration:
|
||||
print()
|
||||
if missing_env:
|
||||
print(f" ⚠️ {len(missing_env)} new required setting(s) need configuration")
|
||||
if missing_config:
|
||||
print(f" ℹ️ {len(missing_config)} new config option(s) available")
|
||||
|
||||
print()
|
||||
response = input("Would you like to configure them now? [Y/n]: ").strip().lower()
|
||||
|
||||
if response in ('', 'y', 'yes'):
|
||||
print()
|
||||
results = migrate_config(interactive=True, quiet=False)
|
||||
|
||||
if results["env_added"] or results["config_added"]:
|
||||
print()
|
||||
print("✓ Configuration updated!")
|
||||
else:
|
||||
print()
|
||||
print("Skipped. Run 'hermes config migrate' later to configure.")
|
||||
else:
|
||||
print(" ✓ Configuration is up to date")
|
||||
|
||||
print()
|
||||
print("✓ Update complete!")
|
||||
print()
|
||||
print("Note: If you have the gateway service running, restart it:")
|
||||
print(" hermes gateway restart")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"✗ Update failed: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for hermes CLI."""
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="hermes",
|
||||
description="Hermes Agent - AI assistant with tool-calling capabilities",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
hermes Start interactive chat
|
||||
hermes chat -q "Hello" Single query mode
|
||||
hermes setup Run setup wizard
|
||||
hermes config View configuration
|
||||
hermes config edit Edit config in $EDITOR
|
||||
hermes config set model gpt-4 Set a config value
|
||||
hermes gateway Run messaging gateway
|
||||
hermes gateway install Install as system service
|
||||
hermes update Update to latest version
|
||||
|
||||
For more help on a command:
|
||||
hermes <command> --help
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--version", "-V",
|
||||
action="store_true",
|
||||
help="Show version and exit"
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command", help="Command to run")
|
||||
|
||||
# =========================================================================
|
||||
# chat command
|
||||
# =========================================================================
|
||||
chat_parser = subparsers.add_parser(
|
||||
"chat",
|
||||
help="Interactive chat with the agent",
|
||||
description="Start an interactive chat session with Hermes Agent"
|
||||
)
|
||||
chat_parser.add_argument(
|
||||
"-q", "--query",
|
||||
help="Single query (non-interactive mode)"
|
||||
)
|
||||
chat_parser.add_argument(
|
||||
"-m", "--model",
|
||||
help="Model to use (e.g., anthropic/claude-sonnet-4)"
|
||||
)
|
||||
chat_parser.add_argument(
|
||||
"-t", "--toolsets",
|
||||
help="Comma-separated toolsets to enable"
|
||||
)
|
||||
chat_parser.add_argument(
|
||||
"-v", "--verbose",
|
||||
action="store_true",
|
||||
help="Verbose output"
|
||||
)
|
||||
chat_parser.set_defaults(func=cmd_chat)
|
||||
|
||||
# =========================================================================
|
||||
# gateway command
|
||||
# =========================================================================
|
||||
gateway_parser = subparsers.add_parser(
|
||||
"gateway",
|
||||
help="Messaging gateway management",
|
||||
description="Manage the messaging gateway (Telegram, Discord, WhatsApp)"
|
||||
)
|
||||
gateway_subparsers = gateway_parser.add_subparsers(dest="gateway_command")
|
||||
|
||||
# gateway run (default)
|
||||
gateway_run = gateway_subparsers.add_parser("run", help="Run gateway in foreground")
|
||||
gateway_run.add_argument("-v", "--verbose", action="store_true")
|
||||
|
||||
# gateway start
|
||||
gateway_start = gateway_subparsers.add_parser("start", help="Start gateway service")
|
||||
|
||||
# gateway stop
|
||||
gateway_stop = gateway_subparsers.add_parser("stop", help="Stop gateway service")
|
||||
|
||||
# gateway restart
|
||||
gateway_restart = gateway_subparsers.add_parser("restart", help="Restart gateway service")
|
||||
|
||||
# gateway status
|
||||
gateway_status = gateway_subparsers.add_parser("status", help="Show gateway status")
|
||||
gateway_status.add_argument("--deep", action="store_true", help="Deep status check")
|
||||
|
||||
# gateway install
|
||||
gateway_install = gateway_subparsers.add_parser("install", help="Install gateway as service")
|
||||
gateway_install.add_argument("--force", action="store_true", help="Force reinstall")
|
||||
|
||||
# gateway uninstall
|
||||
gateway_uninstall = gateway_subparsers.add_parser("uninstall", help="Uninstall gateway service")
|
||||
|
||||
gateway_parser.set_defaults(func=cmd_gateway)
|
||||
|
||||
# =========================================================================
|
||||
# setup command
|
||||
# =========================================================================
|
||||
setup_parser = subparsers.add_parser(
|
||||
"setup",
|
||||
help="Interactive setup wizard",
|
||||
description="Configure Hermes Agent with an interactive wizard"
|
||||
)
|
||||
setup_parser.add_argument(
|
||||
"--non-interactive",
|
||||
action="store_true",
|
||||
help="Non-interactive mode (use defaults/env vars)"
|
||||
)
|
||||
setup_parser.add_argument(
|
||||
"--reset",
|
||||
action="store_true",
|
||||
help="Reset configuration to defaults"
|
||||
)
|
||||
setup_parser.set_defaults(func=cmd_setup)
|
||||
|
||||
# =========================================================================
|
||||
# status command
|
||||
# =========================================================================
|
||||
status_parser = subparsers.add_parser(
|
||||
"status",
|
||||
help="Show status of all components",
|
||||
description="Display status of Hermes Agent components"
|
||||
)
|
||||
status_parser.add_argument(
|
||||
"--all",
|
||||
action="store_true",
|
||||
help="Show all details (redacted for sharing)"
|
||||
)
|
||||
status_parser.add_argument(
|
||||
"--deep",
|
||||
action="store_true",
|
||||
help="Run deep checks (may take longer)"
|
||||
)
|
||||
status_parser.set_defaults(func=cmd_status)
|
||||
|
||||
# =========================================================================
|
||||
# cron command
|
||||
# =========================================================================
|
||||
cron_parser = subparsers.add_parser(
|
||||
"cron",
|
||||
help="Cron job management",
|
||||
description="Manage scheduled tasks"
|
||||
)
|
||||
cron_subparsers = cron_parser.add_subparsers(dest="cron_command")
|
||||
|
||||
# cron list
|
||||
cron_list = cron_subparsers.add_parser("list", help="List scheduled jobs")
|
||||
cron_list.add_argument("--all", action="store_true", help="Include disabled jobs")
|
||||
|
||||
# cron daemon
|
||||
cron_daemon = cron_subparsers.add_parser("daemon", help="Run cron daemon")
|
||||
cron_daemon.add_argument("--interval", type=int, default=60, help="Check interval in seconds")
|
||||
|
||||
# cron tick
|
||||
cron_tick = cron_subparsers.add_parser("tick", help="Run due jobs once (for system cron)")
|
||||
|
||||
cron_parser.set_defaults(func=cmd_cron)
|
||||
|
||||
# =========================================================================
|
||||
# doctor command
|
||||
# =========================================================================
|
||||
doctor_parser = subparsers.add_parser(
|
||||
"doctor",
|
||||
help="Check configuration and dependencies",
|
||||
description="Diagnose issues with Hermes Agent setup"
|
||||
)
|
||||
doctor_parser.add_argument(
|
||||
"--fix",
|
||||
action="store_true",
|
||||
help="Attempt to fix issues automatically"
|
||||
)
|
||||
doctor_parser.set_defaults(func=cmd_doctor)
|
||||
|
||||
# =========================================================================
|
||||
# config command
|
||||
# =========================================================================
|
||||
config_parser = subparsers.add_parser(
|
||||
"config",
|
||||
help="View and edit configuration",
|
||||
description="Manage Hermes Agent configuration"
|
||||
)
|
||||
config_subparsers = config_parser.add_subparsers(dest="config_command")
|
||||
|
||||
# config show (default)
|
||||
config_show = config_subparsers.add_parser("show", help="Show current configuration")
|
||||
|
||||
# config edit
|
||||
config_edit = config_subparsers.add_parser("edit", help="Open config file in editor")
|
||||
|
||||
# config set
|
||||
config_set = config_subparsers.add_parser("set", help="Set a configuration value")
|
||||
config_set.add_argument("key", nargs="?", help="Configuration key (e.g., model, terminal.backend)")
|
||||
config_set.add_argument("value", nargs="?", help="Value to set")
|
||||
|
||||
# config path
|
||||
config_path = config_subparsers.add_parser("path", help="Print config file path")
|
||||
|
||||
# config env-path
|
||||
config_env = config_subparsers.add_parser("env-path", help="Print .env file path")
|
||||
|
||||
# config check
|
||||
config_check = config_subparsers.add_parser("check", help="Check for missing/outdated config")
|
||||
|
||||
# config migrate
|
||||
config_migrate = config_subparsers.add_parser("migrate", help="Update config with new options")
|
||||
|
||||
config_parser.set_defaults(func=cmd_config)
|
||||
|
||||
# =========================================================================
|
||||
# version command
|
||||
# =========================================================================
|
||||
version_parser = subparsers.add_parser(
|
||||
"version",
|
||||
help="Show version information"
|
||||
)
|
||||
version_parser.set_defaults(func=cmd_version)
|
||||
|
||||
# =========================================================================
|
||||
# update command
|
||||
# =========================================================================
|
||||
update_parser = subparsers.add_parser(
|
||||
"update",
|
||||
help="Update Hermes Agent to the latest version",
|
||||
description="Pull the latest changes from git and reinstall dependencies"
|
||||
)
|
||||
update_parser.set_defaults(func=cmd_update)
|
||||
|
||||
# =========================================================================
|
||||
# uninstall command
|
||||
# =========================================================================
|
||||
uninstall_parser = subparsers.add_parser(
|
||||
"uninstall",
|
||||
help="Uninstall Hermes Agent",
|
||||
description="Remove Hermes Agent from your system. Can keep configs/data for reinstall."
|
||||
)
|
||||
uninstall_parser.add_argument(
|
||||
"--full",
|
||||
action="store_true",
|
||||
help="Full uninstall - remove everything including configs and data"
|
||||
)
|
||||
uninstall_parser.add_argument(
|
||||
"--yes", "-y",
|
||||
action="store_true",
|
||||
help="Skip confirmation prompts"
|
||||
)
|
||||
uninstall_parser.set_defaults(func=cmd_uninstall)
|
||||
|
||||
# =========================================================================
|
||||
# Parse and execute
|
||||
# =========================================================================
|
||||
args = parser.parse_args()
|
||||
|
||||
# Handle --version flag
|
||||
if args.version:
|
||||
cmd_version(args)
|
||||
return
|
||||
|
||||
# Default to chat if no command specified
|
||||
if args.command is None:
|
||||
# No command = run chat
|
||||
args.query = None
|
||||
args.model = None
|
||||
args.toolsets = None
|
||||
args.verbose = False
|
||||
cmd_chat(args)
|
||||
return
|
||||
|
||||
# Execute the command
|
||||
if hasattr(args, 'func'):
|
||||
args.func(args)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
989
hermes_cli/setup.py
Normal file
989
hermes_cli/setup.py
Normal file
@@ -0,0 +1,989 @@
|
||||
"""
|
||||
Interactive setup wizard for Hermes Agent.
|
||||
|
||||
Guides users through:
|
||||
1. Installation directory confirmation
|
||||
2. API key configuration
|
||||
3. Model selection
|
||||
4. Terminal backend selection
|
||||
5. Messaging platform setup
|
||||
6. Optional features
|
||||
|
||||
Config files are stored in ~/.hermes/ for easy access.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
# Import config helpers
|
||||
from hermes_cli.config import (
|
||||
get_hermes_home, get_config_path, get_env_path,
|
||||
load_config, save_config, save_env_value, get_env_value,
|
||||
ensure_hermes_home, DEFAULT_CONFIG
|
||||
)
|
||||
|
||||
# ANSI colors
|
||||
class Colors:
|
||||
RESET = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
DIM = "\033[2m"
|
||||
RED = "\033[31m"
|
||||
GREEN = "\033[32m"
|
||||
YELLOW = "\033[33m"
|
||||
BLUE = "\033[34m"
|
||||
MAGENTA = "\033[35m"
|
||||
CYAN = "\033[36m"
|
||||
|
||||
def color(text: str, *codes) -> str:
|
||||
"""Apply color codes to text."""
|
||||
if not sys.stdout.isatty():
|
||||
return text
|
||||
return "".join(codes) + text + Colors.RESET
|
||||
|
||||
def print_header(title: str):
|
||||
"""Print a section header."""
|
||||
print()
|
||||
print(color(f"◆ {title}", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
def print_info(text: str):
|
||||
"""Print info text."""
|
||||
print(color(f" {text}", Colors.DIM))
|
||||
|
||||
def print_success(text: str):
|
||||
"""Print success message."""
|
||||
print(color(f"✓ {text}", Colors.GREEN))
|
||||
|
||||
def print_warning(text: str):
|
||||
"""Print warning message."""
|
||||
print(color(f"⚠ {text}", Colors.YELLOW))
|
||||
|
||||
def print_error(text: str):
|
||||
"""Print error message."""
|
||||
print(color(f"✗ {text}", Colors.RED))
|
||||
|
||||
def prompt(question: str, default: str = None, password: bool = False) -> str:
|
||||
"""Prompt for input with optional default."""
|
||||
if default:
|
||||
display = f"{question} [{default}]: "
|
||||
else:
|
||||
display = f"{question}: "
|
||||
|
||||
try:
|
||||
if password:
|
||||
import getpass
|
||||
value = getpass.getpass(color(display, Colors.YELLOW))
|
||||
else:
|
||||
value = input(color(display, Colors.YELLOW))
|
||||
|
||||
return value.strip() or default or ""
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
sys.exit(1)
|
||||
|
||||
def prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
"""Prompt for a choice from a list with arrow key navigation."""
|
||||
print(color(question, Colors.YELLOW))
|
||||
|
||||
# Try to use interactive menu if available
|
||||
try:
|
||||
from simple_term_menu import TerminalMenu
|
||||
|
||||
# Add visual indicators
|
||||
menu_choices = [f" {choice}" for choice in choices]
|
||||
|
||||
terminal_menu = TerminalMenu(
|
||||
menu_choices,
|
||||
cursor_index=default,
|
||||
menu_cursor="→ ",
|
||||
menu_cursor_style=("fg_green", "bold"),
|
||||
menu_highlight_style=("fg_green",),
|
||||
cycle_cursor=True,
|
||||
clear_screen=False,
|
||||
)
|
||||
|
||||
idx = terminal_menu.show()
|
||||
if idx is None: # User pressed Escape or Ctrl+C
|
||||
print()
|
||||
sys.exit(1)
|
||||
print() # Add newline after selection
|
||||
return idx
|
||||
|
||||
except ImportError:
|
||||
# Fallback to number-based selection
|
||||
for i, choice in enumerate(choices):
|
||||
marker = "●" if i == default else "○"
|
||||
if i == default:
|
||||
print(color(f" {marker} {choice}", Colors.GREEN))
|
||||
else:
|
||||
print(f" {marker} {choice}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
value = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM))
|
||||
if not value:
|
||||
return default
|
||||
idx = int(value) - 1
|
||||
if 0 <= idx < len(choices):
|
||||
return idx
|
||||
print_error(f"Please enter a number between 1 and {len(choices)}")
|
||||
except ValueError:
|
||||
print_error("Please enter a number")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
sys.exit(1)
|
||||
|
||||
def prompt_yes_no(question: str, default: bool = True) -> bool:
|
||||
"""Prompt for yes/no."""
|
||||
default_str = "Y/n" if default else "y/N"
|
||||
|
||||
while True:
|
||||
value = input(color(f"{question} [{default_str}]: ", Colors.YELLOW)).strip().lower()
|
||||
|
||||
if not value:
|
||||
return default
|
||||
if value in ('y', 'yes'):
|
||||
return True
|
||||
if value in ('n', 'no'):
|
||||
return False
|
||||
print_error("Please enter 'y' or 'n'")
|
||||
|
||||
|
||||
def _print_setup_summary(config: dict, hermes_home):
|
||||
"""Print the setup completion summary."""
|
||||
# Tool availability summary
|
||||
print()
|
||||
print_header("Tool Availability Summary")
|
||||
|
||||
tool_status = []
|
||||
|
||||
# OpenRouter (required for vision, moa)
|
||||
if get_env_value('OPENROUTER_API_KEY'):
|
||||
tool_status.append(("Vision (image analysis)", True, None))
|
||||
tool_status.append(("Mixture of Agents", True, None))
|
||||
else:
|
||||
tool_status.append(("Vision (image analysis)", False, "OPENROUTER_API_KEY"))
|
||||
tool_status.append(("Mixture of Agents", False, "OPENROUTER_API_KEY"))
|
||||
|
||||
# Firecrawl (web tools)
|
||||
if get_env_value('FIRECRAWL_API_KEY'):
|
||||
tool_status.append(("Web Search & Extract", True, None))
|
||||
else:
|
||||
tool_status.append(("Web Search & Extract", False, "FIRECRAWL_API_KEY"))
|
||||
|
||||
# Browserbase (browser tools)
|
||||
if get_env_value('BROWSERBASE_API_KEY'):
|
||||
tool_status.append(("Browser Automation", True, None))
|
||||
else:
|
||||
tool_status.append(("Browser Automation", False, "BROWSERBASE_API_KEY"))
|
||||
|
||||
# FAL (image generation)
|
||||
if get_env_value('FAL_KEY'):
|
||||
tool_status.append(("Image Generation", True, None))
|
||||
else:
|
||||
tool_status.append(("Image Generation", False, "FAL_KEY"))
|
||||
|
||||
# Tinker + WandB (RL training)
|
||||
if get_env_value('TINKER_API_KEY') and get_env_value('WANDB_API_KEY'):
|
||||
tool_status.append(("RL Training (Tinker)", True, None))
|
||||
elif get_env_value('TINKER_API_KEY'):
|
||||
tool_status.append(("RL Training (Tinker)", False, "WANDB_API_KEY"))
|
||||
else:
|
||||
tool_status.append(("RL Training (Tinker)", False, "TINKER_API_KEY"))
|
||||
|
||||
# Terminal (always available if system deps met)
|
||||
tool_status.append(("Terminal/Commands", True, None))
|
||||
|
||||
# Skills (always available if skills dir exists)
|
||||
tool_status.append(("Skills Knowledge Base", True, None))
|
||||
|
||||
# Print status
|
||||
available_count = sum(1 for _, avail, _ in tool_status if avail)
|
||||
total_count = len(tool_status)
|
||||
|
||||
print_info(f"{available_count}/{total_count} tool categories available:")
|
||||
print()
|
||||
|
||||
for name, available, missing_var in tool_status:
|
||||
if available:
|
||||
print(f" {color('✓', Colors.GREEN)} {name}")
|
||||
else:
|
||||
print(f" {color('✗', Colors.RED)} {name} {color(f'(missing {missing_var})', Colors.DIM)}")
|
||||
|
||||
print()
|
||||
|
||||
disabled_tools = [(name, var) for name, avail, var in tool_status if not avail]
|
||||
if disabled_tools:
|
||||
print_warning("Some tools are disabled. Run 'hermes setup' again to configure them,")
|
||||
print_warning("or edit ~/.hermes/.env directly to add the missing API keys.")
|
||||
print()
|
||||
|
||||
# Done banner
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────┐", Colors.GREEN))
|
||||
print(color("│ ✓ Setup Complete! │", Colors.GREEN))
|
||||
print(color("└─────────────────────────────────────────────────────────┘", Colors.GREEN))
|
||||
print()
|
||||
|
||||
# Show file locations prominently
|
||||
print(color("📁 All your files are in ~/.hermes/:", Colors.CYAN, Colors.BOLD))
|
||||
print()
|
||||
print(f" {color('Settings:', Colors.YELLOW)} {get_config_path()}")
|
||||
print(f" {color('API Keys:', Colors.YELLOW)} {get_env_path()}")
|
||||
print(f" {color('Data:', Colors.YELLOW)} {hermes_home}/cron/, sessions/, logs/")
|
||||
print()
|
||||
|
||||
print(color("─" * 60, Colors.DIM))
|
||||
print()
|
||||
print(color("📝 To edit your configuration:", Colors.CYAN, Colors.BOLD))
|
||||
print()
|
||||
print(f" {color('hermes config', Colors.GREEN)} View current settings")
|
||||
print(f" {color('hermes config edit', Colors.GREEN)} Open config in your editor")
|
||||
print(f" {color('hermes config set KEY VALUE', Colors.GREEN)}")
|
||||
print(f" Set a specific value")
|
||||
print()
|
||||
print(f" Or edit the files directly:")
|
||||
print(f" {color(f'nano {get_config_path()}', Colors.DIM)}")
|
||||
print(f" {color(f'nano {get_env_path()}', Colors.DIM)}")
|
||||
print()
|
||||
|
||||
print(color("─" * 60, Colors.DIM))
|
||||
print()
|
||||
print(color("🚀 Ready to go!", Colors.CYAN, Colors.BOLD))
|
||||
print()
|
||||
print(f" {color('hermes', Colors.GREEN)} Start chatting")
|
||||
print(f" {color('hermes gateway', Colors.GREEN)} Start messaging gateway")
|
||||
print(f" {color('hermes doctor', Colors.GREEN)} Check for issues")
|
||||
print()
|
||||
|
||||
|
||||
def run_setup_wizard(args):
|
||||
"""Run the interactive setup wizard."""
|
||||
ensure_hermes_home()
|
||||
|
||||
config = load_config()
|
||||
hermes_home = get_hermes_home()
|
||||
|
||||
# Check if this is an existing installation with config
|
||||
is_existing = get_env_value("OPENROUTER_API_KEY") is not None or get_config_path().exists()
|
||||
|
||||
# Import migration helpers
|
||||
from hermes_cli.config import (
|
||||
get_missing_env_vars, get_missing_config_fields,
|
||||
check_config_version, migrate_config,
|
||||
REQUIRED_ENV_VARS, OPTIONAL_ENV_VARS
|
||||
)
|
||||
|
||||
# Check what's missing
|
||||
missing_required = [v for v in get_missing_env_vars(required_only=False) if v.get("is_required")]
|
||||
missing_optional = [v for v in get_missing_env_vars(required_only=False) if not v.get("is_required")]
|
||||
missing_config = get_missing_config_fields()
|
||||
current_ver, latest_ver = check_config_version()
|
||||
|
||||
has_missing = missing_required or missing_optional or missing_config or current_ver < latest_ver
|
||||
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────┐", Colors.MAGENTA))
|
||||
print(color("│ 🦋 Hermes Agent Setup Wizard │", Colors.MAGENTA))
|
||||
print(color("├─────────────────────────────────────────────────────────┤", Colors.MAGENTA))
|
||||
print(color("│ Let's configure your Hermes Agent installation. │", Colors.MAGENTA))
|
||||
print(color("│ Press Ctrl+C at any time to exit. │", Colors.MAGENTA))
|
||||
print(color("└─────────────────────────────────────────────────────────┘", Colors.MAGENTA))
|
||||
|
||||
# If existing installation, show what's missing and offer quick mode
|
||||
quick_mode = False
|
||||
if is_existing and has_missing:
|
||||
print()
|
||||
print_header("Existing Installation Detected")
|
||||
print_success("You already have Hermes configured!")
|
||||
print()
|
||||
|
||||
if missing_required:
|
||||
print_warning(f" {len(missing_required)} required setting(s) missing:")
|
||||
for var in missing_required:
|
||||
print(f" • {var['name']}")
|
||||
|
||||
if missing_optional:
|
||||
print_info(f" {len(missing_optional)} optional tool(s) not configured:")
|
||||
for var in missing_optional[:3]: # Show first 3
|
||||
tools = var.get("tools", [])
|
||||
tools_str = f" → {', '.join(tools[:2])}" if tools else ""
|
||||
print(f" • {var['name']}{tools_str}")
|
||||
if len(missing_optional) > 3:
|
||||
print(f" • ...and {len(missing_optional) - 3} more")
|
||||
|
||||
if missing_config:
|
||||
print_info(f" {len(missing_config)} new config option(s) available")
|
||||
|
||||
print()
|
||||
|
||||
setup_choices = [
|
||||
"Quick setup - just configure missing items",
|
||||
"Full setup - reconfigure everything",
|
||||
"Skip - exit setup"
|
||||
]
|
||||
|
||||
choice = prompt_choice("What would you like to do?", setup_choices, 0)
|
||||
|
||||
if choice == 0:
|
||||
quick_mode = True
|
||||
elif choice == 2:
|
||||
print()
|
||||
print_info("Exiting. Run 'hermes setup' again when ready.")
|
||||
return
|
||||
# choice == 1 continues with full setup
|
||||
|
||||
elif is_existing and not has_missing:
|
||||
print()
|
||||
print_header("Configuration Status")
|
||||
print_success("Your configuration is complete!")
|
||||
print()
|
||||
|
||||
if not prompt_yes_no("Would you like to reconfigure anyway?", False):
|
||||
print()
|
||||
print_info("Exiting. Your configuration is already set up.")
|
||||
print_info(f"Config: {get_config_path()}")
|
||||
print_info(f"Secrets: {get_env_path()}")
|
||||
return
|
||||
|
||||
# Quick mode: only configure missing items
|
||||
if quick_mode:
|
||||
print()
|
||||
print_header("Quick Setup - Missing Items Only")
|
||||
|
||||
# Handle missing required env vars
|
||||
if missing_required:
|
||||
for var in missing_required:
|
||||
print()
|
||||
print(color(f" {var['name']}", Colors.CYAN))
|
||||
print_info(f" {var.get('description', '')}")
|
||||
if var.get("url"):
|
||||
print_info(f" Get key at: {var['url']}")
|
||||
|
||||
if var.get("password"):
|
||||
value = prompt(f" {var.get('prompt', var['name'])}", password=True)
|
||||
else:
|
||||
value = prompt(f" {var.get('prompt', var['name'])}")
|
||||
|
||||
if value:
|
||||
save_env_value(var["name"], value)
|
||||
print_success(f" Saved {var['name']}")
|
||||
else:
|
||||
print_warning(f" Skipped {var['name']}")
|
||||
|
||||
# Handle missing optional env vars
|
||||
if missing_optional:
|
||||
print()
|
||||
print_header("Optional Tools (Quick Setup)")
|
||||
|
||||
for var in missing_optional:
|
||||
tools = var.get("tools", [])
|
||||
tools_str = f" (enables: {', '.join(tools[:2])})" if tools else ""
|
||||
|
||||
if prompt_yes_no(f"Configure {var['name']}{tools_str}?", False):
|
||||
if var.get("url"):
|
||||
print_info(f" Get key at: {var['url']}")
|
||||
|
||||
if var.get("password"):
|
||||
value = prompt(f" {var.get('prompt', var['name'])}", password=True)
|
||||
else:
|
||||
value = prompt(f" {var.get('prompt', var['name'])}")
|
||||
|
||||
if value:
|
||||
save_env_value(var["name"], value)
|
||||
print_success(f" Saved")
|
||||
|
||||
# Handle missing config fields
|
||||
if missing_config:
|
||||
print()
|
||||
print_info(f"Adding {len(missing_config)} new config option(s) with defaults...")
|
||||
for field in missing_config:
|
||||
print_success(f" Added {field['key']} = {field['default']}")
|
||||
|
||||
# Update config version
|
||||
config["_config_version"] = latest_ver
|
||||
save_config(config)
|
||||
|
||||
# Jump to summary
|
||||
_print_setup_summary(config, hermes_home)
|
||||
return
|
||||
|
||||
# =========================================================================
|
||||
# Step 0: Show paths (full setup)
|
||||
# =========================================================================
|
||||
print_header("Configuration Location")
|
||||
print_info(f"Config file: {get_config_path()}")
|
||||
print_info(f"Secrets file: {get_env_path()}")
|
||||
print_info(f"Data folder: {hermes_home}")
|
||||
print_info(f"Install dir: {PROJECT_ROOT}")
|
||||
print()
|
||||
print_info("You can edit these files directly or use 'hermes config edit'")
|
||||
|
||||
# =========================================================================
|
||||
# Step 1: OpenRouter API Key (Required for tools)
|
||||
# =========================================================================
|
||||
print_header("OpenRouter API Key (Required)")
|
||||
print_info("OpenRouter is used for vision, web scraping, and tool operations")
|
||||
print_info("even if you use a custom endpoint for your main agent.")
|
||||
print_info("Get your API key at: https://openrouter.ai/keys")
|
||||
|
||||
existing_or = get_env_value("OPENROUTER_API_KEY")
|
||||
if existing_or:
|
||||
print_info(f"Current: {existing_or[:8]}... (configured)")
|
||||
if prompt_yes_no("Update OpenRouter API key?", False):
|
||||
api_key = prompt(" OpenRouter API key", password=True)
|
||||
if api_key:
|
||||
save_env_value("OPENROUTER_API_KEY", api_key)
|
||||
print_success("OpenRouter API key updated")
|
||||
else:
|
||||
api_key = prompt(" OpenRouter API key", password=True)
|
||||
if api_key:
|
||||
save_env_value("OPENROUTER_API_KEY", api_key)
|
||||
print_success("OpenRouter API key saved")
|
||||
else:
|
||||
print_warning("Skipped - some tools (vision, web scraping) won't work without this")
|
||||
|
||||
# =========================================================================
|
||||
# Step 2: Main Agent Provider
|
||||
# =========================================================================
|
||||
print_header("Main Agent Provider")
|
||||
print_info("Choose how to connect to your main chat model.")
|
||||
|
||||
existing_custom = get_env_value("OPENAI_BASE_URL")
|
||||
|
||||
provider_choices = [
|
||||
"OpenRouter (use same key for agent - recommended)",
|
||||
"Custom OpenAI-compatible endpoint (separate from OpenRouter)",
|
||||
f"Keep current" + (f" ({existing_custom})" if existing_custom else " (OpenRouter)")
|
||||
]
|
||||
|
||||
provider_idx = prompt_choice("Select your main agent provider:", provider_choices, 2)
|
||||
|
||||
if provider_idx == 0: # OpenRouter for agent too
|
||||
# Clear any custom endpoint - will use OpenRouter
|
||||
if existing_custom:
|
||||
save_env_value("OPENAI_BASE_URL", "")
|
||||
save_env_value("OPENAI_API_KEY", "")
|
||||
print_success("Agent will use OpenRouter")
|
||||
|
||||
elif provider_idx == 1: # Custom endpoint
|
||||
print_info("Custom OpenAI-Compatible Endpoint Configuration:")
|
||||
print_info("Works with any API that follows OpenAI's chat completions spec")
|
||||
|
||||
# Show current values if set
|
||||
current_url = get_env_value("OPENAI_BASE_URL") or ""
|
||||
current_key = get_env_value("OPENAI_API_KEY")
|
||||
current_model = config.get('model', '')
|
||||
|
||||
if current_url:
|
||||
print_info(f" Current URL: {current_url}")
|
||||
if current_key:
|
||||
print_info(f" Current key: {current_key[:8]}... (configured)")
|
||||
|
||||
base_url = prompt(" API base URL (e.g., https://api.example.com/v1)", current_url)
|
||||
api_key = prompt(" API key", password=True)
|
||||
model_name = prompt(" Model name (e.g., gpt-4, claude-3-opus)", current_model)
|
||||
|
||||
if base_url:
|
||||
save_env_value("OPENAI_BASE_URL", base_url)
|
||||
if api_key:
|
||||
save_env_value("OPENAI_API_KEY", api_key)
|
||||
if model_name:
|
||||
config['model'] = model_name
|
||||
print_success("Custom endpoint configured")
|
||||
# else: Keep current (provider_idx == 2)
|
||||
|
||||
# =========================================================================
|
||||
# Step 3: Model Selection
|
||||
# =========================================================================
|
||||
print_header("Default Model")
|
||||
|
||||
current_model = config.get('model', 'anthropic/claude-sonnet-4')
|
||||
print_info(f"Current: {current_model}")
|
||||
|
||||
model_choices = [
|
||||
"anthropic/claude-sonnet-4.5 (recommended)",
|
||||
"anthropic/claude-opus-4.5",
|
||||
"openai/gpt-5.2",
|
||||
"openai/gpt-5.2-codex",
|
||||
"google/gemini-3-pro-preview",
|
||||
"google/gemini-3-flash-preview",
|
||||
"z-ai/glm-4.7",
|
||||
"moonshotai/kimi-k2.5",
|
||||
"minimax/minimax-m2.1",
|
||||
"Custom model",
|
||||
f"Keep current ({current_model})"
|
||||
]
|
||||
|
||||
model_idx = prompt_choice("Select default model:", model_choices, 10) # Default: keep current
|
||||
|
||||
model_map = {
|
||||
0: "anthropic/claude-sonnet-4.5",
|
||||
1: "anthropic/claude-opus-4.5",
|
||||
2: "openai/gpt-5.2",
|
||||
3: "openai/gpt-5.2-codex",
|
||||
4: "google/gemini-3-pro-preview",
|
||||
5: "google/gemini-3-flash-preview",
|
||||
6: "z-ai/glm-4.7",
|
||||
7: "moonshotai/kimi-k2.5",
|
||||
8: "minimax/minimax-m2.1",
|
||||
}
|
||||
|
||||
if model_idx in model_map:
|
||||
config['model'] = model_map[model_idx]
|
||||
elif model_idx == 9: # Custom
|
||||
custom = prompt("Enter model name (e.g., anthropic/claude-sonnet-4.5)")
|
||||
if custom:
|
||||
config['model'] = custom
|
||||
# else: Keep current (model_idx == 10)
|
||||
|
||||
# =========================================================================
|
||||
# Step 4: Terminal Backend
|
||||
# =========================================================================
|
||||
print_header("Terminal Backend")
|
||||
print_info("The terminal tool allows the agent to run commands.")
|
||||
|
||||
current_backend = config.get('terminal', {}).get('backend', 'local')
|
||||
print_info(f"Current: {current_backend}")
|
||||
|
||||
# Detect platform for backend availability
|
||||
import platform
|
||||
is_linux = platform.system() == "Linux"
|
||||
is_macos = platform.system() == "Darwin"
|
||||
is_windows = platform.system() == "Windows"
|
||||
|
||||
# Build choices based on platform
|
||||
terminal_choices = [
|
||||
"Local (run commands on this machine - no isolation)",
|
||||
"Docker (isolated containers - recommended for security)",
|
||||
]
|
||||
|
||||
# Singularity/Apptainer is Linux-only (HPC)
|
||||
if is_linux:
|
||||
terminal_choices.append("Singularity/Apptainer (HPC clusters, shared compute)")
|
||||
|
||||
terminal_choices.extend([
|
||||
"Modal (cloud execution, GPU access, serverless)",
|
||||
"SSH (run commands on a remote server)",
|
||||
f"Keep current ({current_backend})"
|
||||
])
|
||||
|
||||
# Build index map based on available choices
|
||||
if is_linux:
|
||||
backend_to_idx = {'local': 0, 'docker': 1, 'singularity': 2, 'modal': 3, 'ssh': 4}
|
||||
idx_to_backend = {0: 'local', 1: 'docker', 2: 'singularity', 3: 'modal', 4: 'ssh'}
|
||||
keep_current_idx = 5
|
||||
else:
|
||||
backend_to_idx = {'local': 0, 'docker': 1, 'modal': 2, 'ssh': 3}
|
||||
idx_to_backend = {0: 'local', 1: 'docker', 2: 'modal', 3: 'ssh'}
|
||||
keep_current_idx = 4
|
||||
if current_backend == 'singularity':
|
||||
print_warning("Singularity is only available on Linux - please select a different backend")
|
||||
|
||||
# Default based on current
|
||||
default_terminal = backend_to_idx.get(current_backend, 0)
|
||||
|
||||
terminal_idx = prompt_choice("Select terminal backend:", terminal_choices, keep_current_idx)
|
||||
|
||||
# Map index to backend name (handles platform differences)
|
||||
selected_backend = idx_to_backend.get(terminal_idx)
|
||||
|
||||
if selected_backend == 'local':
|
||||
config.setdefault('terminal', {})['backend'] = 'local'
|
||||
print_info("Local Execution Configuration:")
|
||||
print_info("Commands run directly on this machine (no isolation)")
|
||||
|
||||
if is_windows:
|
||||
print_info("Note: On Windows, commands run via cmd.exe or PowerShell")
|
||||
|
||||
# Messaging working directory configuration
|
||||
print_info("")
|
||||
print_info("Working Directory for Messaging (Telegram/Discord/etc):")
|
||||
print_info(" The CLI always uses the directory you run 'hermes' from")
|
||||
print_info(" But messaging bots need a static starting directory")
|
||||
|
||||
current_cwd = get_env_value('MESSAGING_CWD') or str(Path.home())
|
||||
print_info(f" Current: {current_cwd}")
|
||||
|
||||
cwd_input = prompt(" Messaging working directory", current_cwd)
|
||||
# Expand ~ to full path
|
||||
if cwd_input.startswith('~'):
|
||||
cwd_expanded = str(Path.home()) + cwd_input[1:]
|
||||
else:
|
||||
cwd_expanded = cwd_input
|
||||
save_env_value("MESSAGING_CWD", cwd_expanded)
|
||||
|
||||
if prompt_yes_no(" Enable sudo support? (allows agent to run sudo commands)", False):
|
||||
print_warning(" SECURITY WARNING: Sudo password will be stored in plaintext")
|
||||
sudo_pass = prompt(" Sudo password (leave empty to skip)", password=True)
|
||||
if sudo_pass:
|
||||
save_env_value("SUDO_PASSWORD", sudo_pass)
|
||||
print_success(" Sudo password saved")
|
||||
|
||||
print_success("Terminal set to local")
|
||||
|
||||
elif selected_backend == 'docker':
|
||||
config.setdefault('terminal', {})['backend'] = 'docker'
|
||||
default_docker = config.get('terminal', {}).get('docker_image', 'nikolaik/python-nodejs:python3.11-nodejs20')
|
||||
print_info("Docker Configuration:")
|
||||
if is_macos:
|
||||
print_info("Requires Docker Desktop for Mac")
|
||||
elif is_windows:
|
||||
print_info("Requires Docker Desktop for Windows")
|
||||
docker_image = prompt(" Docker image", default_docker)
|
||||
config['terminal']['docker_image'] = docker_image
|
||||
print_success("Terminal set to Docker")
|
||||
|
||||
elif selected_backend == 'singularity':
|
||||
config.setdefault('terminal', {})['backend'] = 'singularity'
|
||||
default_singularity = config.get('terminal', {}).get('singularity_image', 'docker://nikolaik/python-nodejs:python3.11-nodejs20')
|
||||
print_info("Singularity/Apptainer Configuration:")
|
||||
print_info("Requires apptainer or singularity to be installed")
|
||||
singularity_image = prompt(" Image (docker:// prefix for Docker Hub)", default_singularity)
|
||||
config['terminal']['singularity_image'] = singularity_image
|
||||
print_success("Terminal set to Singularity/Apptainer")
|
||||
|
||||
elif selected_backend == 'modal':
|
||||
config.setdefault('terminal', {})['backend'] = 'modal'
|
||||
default_modal = config.get('terminal', {}).get('modal_image', 'nikolaik/python-nodejs:python3.11-nodejs20')
|
||||
print_info("Modal Cloud Configuration:")
|
||||
print_info("Get credentials at: https://modal.com/settings")
|
||||
|
||||
# Always show current status and allow reconfiguration
|
||||
current_token = get_env_value('MODAL_TOKEN_ID')
|
||||
if current_token:
|
||||
print_info(f" Token ID: {current_token[:8]}... (configured)")
|
||||
|
||||
modal_image = prompt(" Container image", default_modal)
|
||||
config['terminal']['modal_image'] = modal_image
|
||||
|
||||
token_id = prompt(" Modal token ID", current_token or "")
|
||||
token_secret = prompt(" Modal token secret", password=True)
|
||||
|
||||
if token_id:
|
||||
save_env_value("MODAL_TOKEN_ID", token_id)
|
||||
if token_secret:
|
||||
save_env_value("MODAL_TOKEN_SECRET", token_secret)
|
||||
|
||||
print_success("Terminal set to Modal")
|
||||
|
||||
elif selected_backend == 'ssh':
|
||||
config.setdefault('terminal', {})['backend'] = 'ssh'
|
||||
print_info("SSH Remote Execution Configuration:")
|
||||
print_info("Commands will run on a remote server over SSH")
|
||||
|
||||
current_host = get_env_value('TERMINAL_SSH_HOST') or ''
|
||||
current_user = get_env_value('TERMINAL_SSH_USER') or os.getenv("USER", "")
|
||||
current_port = get_env_value('TERMINAL_SSH_PORT') or '22'
|
||||
current_key = get_env_value('TERMINAL_SSH_KEY') or '~/.ssh/id_rsa'
|
||||
|
||||
if current_host:
|
||||
print_info(f" Current host: {current_user}@{current_host}:{current_port}")
|
||||
|
||||
ssh_host = prompt(" SSH host", current_host)
|
||||
ssh_user = prompt(" SSH user", current_user)
|
||||
ssh_port = prompt(" SSH port", current_port)
|
||||
ssh_key = prompt(" SSH key path (or leave empty for ssh-agent)", current_key)
|
||||
|
||||
if ssh_host:
|
||||
save_env_value("TERMINAL_SSH_HOST", ssh_host)
|
||||
if ssh_user:
|
||||
save_env_value("TERMINAL_SSH_USER", ssh_user)
|
||||
if ssh_port and ssh_port != '22':
|
||||
save_env_value("TERMINAL_SSH_PORT", ssh_port)
|
||||
if ssh_key:
|
||||
save_env_value("TERMINAL_SSH_KEY", ssh_key)
|
||||
|
||||
print_success("Terminal set to SSH")
|
||||
# else: Keep current (selected_backend is None)
|
||||
|
||||
# =========================================================================
|
||||
# Step 5: Agent Settings
|
||||
# =========================================================================
|
||||
print_header("Agent Settings")
|
||||
|
||||
# Max iterations
|
||||
current_max = get_env_value('HERMES_MAX_ITERATIONS') or '60'
|
||||
print_info("Maximum tool-calling iterations per conversation.")
|
||||
print_info("Higher = more complex tasks, but costs more tokens.")
|
||||
print_info("Recommended: 30-60 for most tasks, 100+ for open exploration.")
|
||||
|
||||
max_iter_str = prompt("Max iterations", current_max)
|
||||
try:
|
||||
max_iter = int(max_iter_str)
|
||||
if max_iter > 0:
|
||||
save_env_value("HERMES_MAX_ITERATIONS", str(max_iter))
|
||||
config['max_turns'] = max_iter
|
||||
print_success(f"Max iterations set to {max_iter}")
|
||||
except ValueError:
|
||||
print_warning("Invalid number, keeping current value")
|
||||
|
||||
# Tool progress notifications (for messaging)
|
||||
print_info("")
|
||||
print_info("Tool Progress Notifications (Messaging only)")
|
||||
print_info("Send status messages when the agent uses tools.")
|
||||
print_info("Example: '💻 ls -la...' or '🔍 web_search...'")
|
||||
|
||||
current_progress = get_env_value('HERMES_TOOL_PROGRESS') or 'false'
|
||||
if prompt_yes_no("Enable tool progress messages?", current_progress.lower() in ('1', 'true', 'yes')):
|
||||
save_env_value("HERMES_TOOL_PROGRESS", "true")
|
||||
|
||||
# Progress mode
|
||||
current_mode = get_env_value('HERMES_TOOL_PROGRESS_MODE') or 'new'
|
||||
print_info(" Mode options:")
|
||||
print_info(" 'new' - Only when switching tools (less spam)")
|
||||
print_info(" 'all' - Every tool call")
|
||||
mode = prompt(" Progress mode", current_mode)
|
||||
if mode.lower() in ('all', 'new'):
|
||||
save_env_value("HERMES_TOOL_PROGRESS_MODE", mode.lower())
|
||||
print_success("Tool progress enabled")
|
||||
else:
|
||||
save_env_value("HERMES_TOOL_PROGRESS", "false")
|
||||
|
||||
# =========================================================================
|
||||
# Step 6: Context Compression
|
||||
# =========================================================================
|
||||
print_header("Context Compression")
|
||||
print_info("Automatically summarize old messages when context gets too long.")
|
||||
|
||||
compression = config.get('compression', {})
|
||||
current_enabled = compression.get('enabled', True)
|
||||
|
||||
if prompt_yes_no(f"Enable context compression?", current_enabled):
|
||||
config.setdefault('compression', {})['enabled'] = True
|
||||
|
||||
current_threshold = compression.get('threshold', 0.85)
|
||||
threshold_str = prompt(f"Compression threshold (0.5-0.95)", str(current_threshold))
|
||||
try:
|
||||
threshold = float(threshold_str)
|
||||
if 0.5 <= threshold <= 0.95:
|
||||
config['compression']['threshold'] = threshold
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
print_success("Context compression enabled")
|
||||
else:
|
||||
config.setdefault('compression', {})['enabled'] = False
|
||||
|
||||
# =========================================================================
|
||||
# Step 7: Messaging Platforms (Optional)
|
||||
# =========================================================================
|
||||
print_header("Messaging Platforms (Optional)")
|
||||
print_info("Connect to messaging platforms to chat with Hermes from anywhere.")
|
||||
|
||||
# Telegram
|
||||
existing_telegram = get_env_value('TELEGRAM_BOT_TOKEN')
|
||||
if existing_telegram:
|
||||
print_info("Telegram: already configured")
|
||||
if prompt_yes_no("Reconfigure Telegram?", False):
|
||||
existing_telegram = None
|
||||
|
||||
if not existing_telegram and prompt_yes_no("Set up Telegram bot?", False):
|
||||
print_info("Create a bot via @BotFather on Telegram")
|
||||
token = prompt("Telegram bot token", password=True)
|
||||
if token:
|
||||
save_env_value("TELEGRAM_BOT_TOKEN", token)
|
||||
print_success("Telegram token saved")
|
||||
|
||||
# Allowed users (security)
|
||||
print()
|
||||
print_info("🔒 Security: Restrict who can use your bot")
|
||||
print_info(" To find your Telegram user ID:")
|
||||
print_info(" 1. Message @userinfobot on Telegram")
|
||||
print_info(" 2. It will reply with your numeric ID (e.g., 123456789)")
|
||||
print()
|
||||
allowed_users = prompt("Allowed user IDs (comma-separated, leave empty for open access)")
|
||||
if allowed_users:
|
||||
save_env_value("TELEGRAM_ALLOWED_USERS", allowed_users.replace(" ", ""))
|
||||
print_success("Telegram allowlist configured - only listed users can use the bot")
|
||||
else:
|
||||
print_info("⚠️ No allowlist set - anyone who finds your bot can use it!")
|
||||
|
||||
home_channel = prompt("Home channel ID (optional, for cron delivery)")
|
||||
if home_channel:
|
||||
save_env_value("TELEGRAM_HOME_CHANNEL", home_channel)
|
||||
|
||||
# Check/update existing Telegram allowlist
|
||||
elif existing_telegram:
|
||||
existing_allowlist = get_env_value('TELEGRAM_ALLOWED_USERS')
|
||||
if not existing_allowlist:
|
||||
print_info("⚠️ Telegram has no user allowlist - anyone can use your bot!")
|
||||
if prompt_yes_no("Add allowed users now?", True):
|
||||
print_info(" To find your Telegram user ID: message @userinfobot")
|
||||
allowed_users = prompt("Allowed user IDs (comma-separated)")
|
||||
if allowed_users:
|
||||
save_env_value("TELEGRAM_ALLOWED_USERS", allowed_users.replace(" ", ""))
|
||||
print_success("Telegram allowlist configured")
|
||||
|
||||
# Discord
|
||||
existing_discord = get_env_value('DISCORD_BOT_TOKEN')
|
||||
if existing_discord:
|
||||
print_info("Discord: already configured")
|
||||
if prompt_yes_no("Reconfigure Discord?", False):
|
||||
existing_discord = None
|
||||
|
||||
if not existing_discord and prompt_yes_no("Set up Discord bot?", False):
|
||||
print_info("Create a bot at https://discord.com/developers/applications")
|
||||
token = prompt("Discord bot token", password=True)
|
||||
if token:
|
||||
save_env_value("DISCORD_BOT_TOKEN", token)
|
||||
print_success("Discord token saved")
|
||||
|
||||
# Allowed users (security)
|
||||
print()
|
||||
print_info("🔒 Security: Restrict who can use your bot")
|
||||
print_info(" To find your Discord user ID:")
|
||||
print_info(" 1. Enable Developer Mode in Discord settings")
|
||||
print_info(" 2. Right-click your name → Copy ID")
|
||||
print()
|
||||
allowed_users = prompt("Allowed user IDs (comma-separated, leave empty for open access)")
|
||||
if allowed_users:
|
||||
save_env_value("DISCORD_ALLOWED_USERS", allowed_users.replace(" ", ""))
|
||||
print_success("Discord allowlist configured")
|
||||
else:
|
||||
print_info("⚠️ No allowlist set - anyone in servers with your bot can use it!")
|
||||
|
||||
home_channel = prompt("Home channel ID (optional, for cron delivery)")
|
||||
if home_channel:
|
||||
save_env_value("DISCORD_HOME_CHANNEL", home_channel)
|
||||
|
||||
# Check/update existing Discord allowlist
|
||||
elif existing_discord:
|
||||
existing_allowlist = get_env_value('DISCORD_ALLOWED_USERS')
|
||||
if not existing_allowlist:
|
||||
print_info("⚠️ Discord has no user allowlist - anyone can use your bot!")
|
||||
if prompt_yes_no("Add allowed users now?", True):
|
||||
print_info(" To find Discord ID: Enable Developer Mode, right-click name → Copy ID")
|
||||
allowed_users = prompt("Allowed user IDs (comma-separated)")
|
||||
if allowed_users:
|
||||
save_env_value("DISCORD_ALLOWED_USERS", allowed_users.replace(" ", ""))
|
||||
print_success("Discord allowlist configured")
|
||||
|
||||
# =========================================================================
|
||||
# Step 8: Additional Tools (Optional)
|
||||
# =========================================================================
|
||||
print_header("Additional Tools (Optional)")
|
||||
print_info("These tools extend the agent's capabilities.")
|
||||
print_info("Without their API keys, the corresponding features will be disabled.")
|
||||
print()
|
||||
|
||||
# Firecrawl - Web scraping
|
||||
print_info("─" * 50)
|
||||
print(color(" Web Search & Scraping (Firecrawl)", Colors.CYAN))
|
||||
print_info(" Enables: web_search, web_extract tools")
|
||||
print_info(" Use case: Search the web, read webpage content")
|
||||
if get_env_value('FIRECRAWL_API_KEY'):
|
||||
print_success(" Status: Configured ✓")
|
||||
if prompt_yes_no(" Update Firecrawl API key?", False):
|
||||
api_key = prompt(" API key", password=True)
|
||||
if api_key:
|
||||
save_env_value("FIRECRAWL_API_KEY", api_key)
|
||||
print_success(" Updated")
|
||||
else:
|
||||
print_warning(" Status: Not configured (tools will be disabled)")
|
||||
if prompt_yes_no(" Set up Firecrawl?", False):
|
||||
print_info(" Get your API key at: https://firecrawl.dev/")
|
||||
api_key = prompt(" API key", password=True)
|
||||
if api_key:
|
||||
save_env_value("FIRECRAWL_API_KEY", api_key)
|
||||
print_success(" Configured ✓")
|
||||
print()
|
||||
|
||||
# Browserbase - Browser automation
|
||||
print_info("─" * 50)
|
||||
print(color(" Browser Automation (Browserbase)", Colors.CYAN))
|
||||
print_info(" Enables: browser_navigate, browser_click, etc.")
|
||||
print_info(" Use case: Interact with web pages, fill forms, screenshots")
|
||||
if get_env_value('BROWSERBASE_API_KEY'):
|
||||
print_success(" Status: Configured ✓")
|
||||
if prompt_yes_no(" Update Browserbase credentials?", False):
|
||||
api_key = prompt(" API key", password=True)
|
||||
project_id = prompt(" Project ID")
|
||||
if api_key:
|
||||
save_env_value("BROWSERBASE_API_KEY", api_key)
|
||||
if project_id:
|
||||
save_env_value("BROWSERBASE_PROJECT_ID", project_id)
|
||||
print_success(" Updated")
|
||||
else:
|
||||
print_warning(" Status: Not configured (tools will be disabled)")
|
||||
if prompt_yes_no(" Set up Browserbase?", False):
|
||||
print_info(" Get credentials at: https://browserbase.com/")
|
||||
api_key = prompt(" API key", password=True)
|
||||
project_id = prompt(" Project ID")
|
||||
if api_key:
|
||||
save_env_value("BROWSERBASE_API_KEY", api_key)
|
||||
if project_id:
|
||||
save_env_value("BROWSERBASE_PROJECT_ID", project_id)
|
||||
print_success(" Configured ✓")
|
||||
print()
|
||||
|
||||
# FAL - Image generation
|
||||
print_info("─" * 50)
|
||||
print(color(" Image Generation (FAL)", Colors.CYAN))
|
||||
print_info(" Enables: image_generate tool")
|
||||
print_info(" Use case: Generate images from text prompts (FLUX)")
|
||||
if get_env_value('FAL_KEY'):
|
||||
print_success(" Status: Configured ✓")
|
||||
if prompt_yes_no(" Update FAL API key?", False):
|
||||
api_key = prompt(" API key", password=True)
|
||||
if api_key:
|
||||
save_env_value("FAL_KEY", api_key)
|
||||
print_success(" Updated")
|
||||
else:
|
||||
print_warning(" Status: Not configured (tool will be disabled)")
|
||||
if prompt_yes_no(" Set up FAL?", False):
|
||||
print_info(" Get your API key at: https://fal.ai/")
|
||||
api_key = prompt(" API key", password=True)
|
||||
if api_key:
|
||||
save_env_value("FAL_KEY", api_key)
|
||||
print_success(" Configured ✓")
|
||||
print()
|
||||
|
||||
# Tinker + WandB - RL Training
|
||||
print_info("─" * 50)
|
||||
print(color(" RL Training (Tinker + WandB)", Colors.CYAN))
|
||||
print_info(" Enables: rl_start_training, rl_check_status, rl_get_results tools")
|
||||
print_info(" Use case: Run reinforcement learning training via Tinker API")
|
||||
tinker_configured = get_env_value('TINKER_API_KEY')
|
||||
wandb_configured = get_env_value('WANDB_API_KEY')
|
||||
|
||||
if tinker_configured and wandb_configured:
|
||||
print_success(" Status: Configured ✓")
|
||||
if prompt_yes_no(" Update RL training credentials?", False):
|
||||
api_key = prompt(" Tinker API key", password=True)
|
||||
if api_key:
|
||||
save_env_value("TINKER_API_KEY", api_key)
|
||||
wandb_key = prompt(" WandB API key", password=True)
|
||||
if wandb_key:
|
||||
save_env_value("WANDB_API_KEY", wandb_key)
|
||||
print_success(" Updated")
|
||||
else:
|
||||
if tinker_configured:
|
||||
print_warning(" Status: Tinker configured, WandB missing")
|
||||
elif wandb_configured:
|
||||
print_warning(" Status: WandB configured, Tinker missing")
|
||||
else:
|
||||
print_warning(" Status: Not configured (tools will be disabled)")
|
||||
|
||||
if prompt_yes_no(" Set up RL Training?", False):
|
||||
print_info(" Get Tinker key at: https://tinker-console.thinkingmachines.ai/keys")
|
||||
print_info(" Get WandB key at: https://wandb.ai/authorize")
|
||||
api_key = prompt(" Tinker API key", password=True)
|
||||
if api_key:
|
||||
save_env_value("TINKER_API_KEY", api_key)
|
||||
wandb_key = prompt(" WandB API key", password=True)
|
||||
if wandb_key:
|
||||
save_env_value("WANDB_API_KEY", wandb_key)
|
||||
if api_key and wandb_key:
|
||||
print_success(" Configured ✓")
|
||||
else:
|
||||
print_warning(" Partially configured (both keys required)")
|
||||
|
||||
# =========================================================================
|
||||
# Save config and show summary
|
||||
# =========================================================================
|
||||
save_config(config)
|
||||
_print_setup_summary(config, hermes_home)
|
||||
241
hermes_cli/status.py
Normal file
241
hermes_cli/status.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
Status command for hermes CLI.
|
||||
|
||||
Shows the status of all Hermes Agent components.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
# ANSI colors
|
||||
class Colors:
|
||||
RESET = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
DIM = "\033[2m"
|
||||
RED = "\033[31m"
|
||||
GREEN = "\033[32m"
|
||||
YELLOW = "\033[33m"
|
||||
CYAN = "\033[36m"
|
||||
|
||||
def color(text: str, *codes) -> str:
|
||||
if not sys.stdout.isatty():
|
||||
return text
|
||||
return "".join(codes) + text + Colors.RESET
|
||||
|
||||
def check_mark(ok: bool) -> str:
|
||||
if ok:
|
||||
return color("✓", Colors.GREEN)
|
||||
return color("✗", Colors.RED)
|
||||
|
||||
def redact_key(key: str) -> str:
|
||||
"""Redact an API key for display."""
|
||||
if not key:
|
||||
return "(not set)"
|
||||
if len(key) < 12:
|
||||
return "***"
|
||||
return key[:4] + "..." + key[-4:]
|
||||
|
||||
|
||||
def show_status(args):
|
||||
"""Show status of all Hermes Agent components."""
|
||||
show_all = getattr(args, 'all', False)
|
||||
deep = getattr(args, 'deep', False)
|
||||
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────┐", Colors.CYAN))
|
||||
print(color("│ 🦋 Hermes Agent Status │", Colors.CYAN))
|
||||
print(color("└─────────────────────────────────────────────────────────┘", Colors.CYAN))
|
||||
|
||||
# =========================================================================
|
||||
# Environment
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Environment", Colors.CYAN, Colors.BOLD))
|
||||
print(f" Project: {PROJECT_ROOT}")
|
||||
print(f" Python: {sys.version.split()[0]}")
|
||||
|
||||
env_path = PROJECT_ROOT / '.env'
|
||||
print(f" .env file: {check_mark(env_path.exists())} {'exists' if env_path.exists() else 'not found'}")
|
||||
|
||||
# =========================================================================
|
||||
# API Keys
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ API Keys", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
keys = {
|
||||
"OpenRouter": "OPENROUTER_API_KEY",
|
||||
"Anthropic": "ANTHROPIC_API_KEY",
|
||||
"OpenAI": "OPENAI_API_KEY",
|
||||
"Firecrawl": "FIRECRAWL_API_KEY",
|
||||
"Browserbase": "BROWSERBASE_API_KEY",
|
||||
"FAL": "FAL_KEY",
|
||||
"Tinker": "TINKER_API_KEY",
|
||||
"WandB": "WANDB_API_KEY",
|
||||
}
|
||||
|
||||
for name, env_var in keys.items():
|
||||
value = os.getenv(env_var, "")
|
||||
has_key = bool(value)
|
||||
display = redact_key(value) if not show_all else value
|
||||
print(f" {name:<12} {check_mark(has_key)} {display}")
|
||||
|
||||
# =========================================================================
|
||||
# Terminal Configuration
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Terminal Backend", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
terminal_env = os.getenv("TERMINAL_ENV", "local")
|
||||
print(f" Backend: {terminal_env}")
|
||||
|
||||
if terminal_env == "ssh":
|
||||
ssh_host = os.getenv("TERMINAL_SSH_HOST", "")
|
||||
ssh_user = os.getenv("TERMINAL_SSH_USER", "")
|
||||
print(f" SSH Host: {ssh_host or '(not set)'}")
|
||||
print(f" SSH User: {ssh_user or '(not set)'}")
|
||||
elif terminal_env == "docker":
|
||||
docker_image = os.getenv("TERMINAL_DOCKER_IMAGE", "python:3.11-slim")
|
||||
print(f" Docker Image: {docker_image}")
|
||||
|
||||
sudo_password = os.getenv("SUDO_PASSWORD", "")
|
||||
print(f" Sudo: {check_mark(bool(sudo_password))} {'enabled' if sudo_password else 'disabled'}")
|
||||
|
||||
# =========================================================================
|
||||
# Messaging Platforms
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Messaging Platforms", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
platforms = {
|
||||
"Telegram": ("TELEGRAM_BOT_TOKEN", "TELEGRAM_HOME_CHANNEL"),
|
||||
"Discord": ("DISCORD_BOT_TOKEN", "DISCORD_HOME_CHANNEL"),
|
||||
"WhatsApp": ("WHATSAPP_ENABLED", None),
|
||||
}
|
||||
|
||||
for name, (token_var, home_var) in platforms.items():
|
||||
token = os.getenv(token_var, "")
|
||||
has_token = bool(token)
|
||||
|
||||
home_channel = ""
|
||||
if home_var:
|
||||
home_channel = os.getenv(home_var, "")
|
||||
|
||||
status = "configured" if has_token else "not configured"
|
||||
if home_channel:
|
||||
status += f" (home: {home_channel})"
|
||||
|
||||
print(f" {name:<12} {check_mark(has_token)} {status}")
|
||||
|
||||
# =========================================================================
|
||||
# Gateway Status
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Gateway Service", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
if sys.platform.startswith('linux'):
|
||||
result = subprocess.run(
|
||||
["systemctl", "--user", "is-active", "hermes-gateway"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
is_active = result.stdout.strip() == "active"
|
||||
print(f" Status: {check_mark(is_active)} {'running' if is_active else 'stopped'}")
|
||||
print(f" Manager: systemd (user)")
|
||||
|
||||
elif sys.platform == 'darwin':
|
||||
result = subprocess.run(
|
||||
["launchctl", "list", "ai.hermes.gateway"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
is_loaded = result.returncode == 0
|
||||
print(f" Status: {check_mark(is_loaded)} {'loaded' if is_loaded else 'not loaded'}")
|
||||
print(f" Manager: launchd")
|
||||
else:
|
||||
print(f" Status: {color('N/A', Colors.DIM)}")
|
||||
print(f" Manager: (not supported on this platform)")
|
||||
|
||||
# =========================================================================
|
||||
# Cron Jobs
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Scheduled Jobs", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
jobs_file = Path.home() / ".hermes" / "cron" / "jobs.json"
|
||||
if jobs_file.exists():
|
||||
import json
|
||||
try:
|
||||
with open(jobs_file) as f:
|
||||
data = json.load(f)
|
||||
jobs = data.get("jobs", [])
|
||||
enabled_jobs = [j for j in jobs if j.get("enabled", True)]
|
||||
print(f" Jobs: {len(enabled_jobs)} active, {len(jobs)} total")
|
||||
except:
|
||||
print(f" Jobs: (error reading jobs file)")
|
||||
else:
|
||||
print(f" Jobs: 0")
|
||||
|
||||
# =========================================================================
|
||||
# Sessions
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Sessions", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
sessions_file = Path.home() / ".hermes" / "sessions" / "sessions.json"
|
||||
if sessions_file.exists():
|
||||
import json
|
||||
try:
|
||||
with open(sessions_file) as f:
|
||||
data = json.load(f)
|
||||
print(f" Active: {len(data)} session(s)")
|
||||
except:
|
||||
print(f" Active: (error reading sessions file)")
|
||||
else:
|
||||
print(f" Active: 0")
|
||||
|
||||
# =========================================================================
|
||||
# Deep checks
|
||||
# =========================================================================
|
||||
if deep:
|
||||
print()
|
||||
print(color("◆ Deep Checks", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
# Check OpenRouter connectivity
|
||||
openrouter_key = os.getenv("OPENROUTER_API_KEY", "")
|
||||
if openrouter_key:
|
||||
try:
|
||||
import httpx
|
||||
response = httpx.get(
|
||||
"https://openrouter.ai/api/v1/models",
|
||||
headers={"Authorization": f"Bearer {openrouter_key}"},
|
||||
timeout=10
|
||||
)
|
||||
ok = response.status_code == 200
|
||||
print(f" OpenRouter: {check_mark(ok)} {'reachable' if ok else f'error ({response.status_code})'}")
|
||||
except Exception as e:
|
||||
print(f" OpenRouter: {check_mark(False)} error: {e}")
|
||||
|
||||
# Check gateway port
|
||||
try:
|
||||
import socket
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(1)
|
||||
result = sock.connect_ex(('127.0.0.1', 18789))
|
||||
sock.close()
|
||||
# Port in use = gateway likely running
|
||||
port_in_use = result == 0
|
||||
# This is informational, not necessarily bad
|
||||
print(f" Port 18789: {'in use' if port_in_use else 'available'}")
|
||||
except:
|
||||
pass
|
||||
|
||||
print()
|
||||
print(color("─" * 60, Colors.DIM))
|
||||
print(color(" Run 'hermes doctor' for detailed diagnostics", Colors.DIM))
|
||||
print(color(" Run 'hermes setup' to configure", Colors.DIM))
|
||||
print()
|
||||
341
hermes_cli/uninstall.py
Normal file
341
hermes_cli/uninstall.py
Normal file
@@ -0,0 +1,341 @@
|
||||
"""
|
||||
Hermes Agent Uninstaller.
|
||||
|
||||
Provides options for:
|
||||
- Full uninstall: Remove everything including configs and data
|
||||
- Keep data: Remove code but keep ~/.hermes/ (configs, sessions, logs)
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# ANSI colors
|
||||
class Colors:
|
||||
RESET = "\033[0m"
|
||||
BOLD = "\033[1m"
|
||||
DIM = "\033[2m"
|
||||
RED = "\033[31m"
|
||||
GREEN = "\033[32m"
|
||||
YELLOW = "\033[33m"
|
||||
BLUE = "\033[34m"
|
||||
MAGENTA = "\033[35m"
|
||||
CYAN = "\033[36m"
|
||||
|
||||
def color(text: str, *codes) -> str:
|
||||
"""Apply color codes to text (only in TTY)."""
|
||||
if not sys.stdout.isatty():
|
||||
return text
|
||||
return "".join(codes) + text + Colors.RESET
|
||||
|
||||
def log_info(msg: str):
|
||||
print(f"{color('→', Colors.CYAN)} {msg}")
|
||||
|
||||
def log_success(msg: str):
|
||||
print(f"{color('✓', Colors.GREEN)} {msg}")
|
||||
|
||||
def log_warn(msg: str):
|
||||
print(f"{color('⚠', Colors.YELLOW)} {msg}")
|
||||
|
||||
def log_error(msg: str):
|
||||
print(f"{color('✗', Colors.RED)} {msg}")
|
||||
|
||||
|
||||
def get_project_root() -> Path:
|
||||
"""Get the project installation directory."""
|
||||
return Path(__file__).parent.parent.resolve()
|
||||
|
||||
|
||||
def get_hermes_home() -> Path:
|
||||
"""Get the Hermes home directory (~/.hermes)."""
|
||||
return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
|
||||
|
||||
def find_shell_configs() -> list:
|
||||
"""Find shell configuration files that might have PATH entries."""
|
||||
home = Path.home()
|
||||
configs = []
|
||||
|
||||
candidates = [
|
||||
home / ".bashrc",
|
||||
home / ".bash_profile",
|
||||
home / ".profile",
|
||||
home / ".zshrc",
|
||||
home / ".zprofile",
|
||||
]
|
||||
|
||||
for config in candidates:
|
||||
if config.exists():
|
||||
configs.append(config)
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
def remove_path_from_shell_configs():
|
||||
"""Remove Hermes PATH entries from shell configuration files."""
|
||||
configs = find_shell_configs()
|
||||
removed_from = []
|
||||
|
||||
for config_path in configs:
|
||||
try:
|
||||
content = config_path.read_text()
|
||||
original_content = content
|
||||
|
||||
# Remove lines containing hermes-agent or hermes PATH entries
|
||||
new_lines = []
|
||||
skip_next = False
|
||||
|
||||
for line in content.split('\n'):
|
||||
# Skip the "# Hermes Agent" comment and following line
|
||||
if '# Hermes Agent' in line or '# hermes-agent' in line:
|
||||
skip_next = True
|
||||
continue
|
||||
if skip_next and ('hermes' in line.lower() and 'PATH' in line):
|
||||
skip_next = False
|
||||
continue
|
||||
skip_next = False
|
||||
|
||||
# Remove any PATH line containing hermes
|
||||
if 'hermes' in line.lower() and ('PATH=' in line or 'path=' in line.lower()):
|
||||
continue
|
||||
|
||||
new_lines.append(line)
|
||||
|
||||
new_content = '\n'.join(new_lines)
|
||||
|
||||
# Clean up multiple blank lines
|
||||
while '\n\n\n' in new_content:
|
||||
new_content = new_content.replace('\n\n\n', '\n\n')
|
||||
|
||||
if new_content != original_content:
|
||||
config_path.write_text(new_content)
|
||||
removed_from.append(config_path)
|
||||
|
||||
except Exception as e:
|
||||
log_warn(f"Could not update {config_path}: {e}")
|
||||
|
||||
return removed_from
|
||||
|
||||
|
||||
def remove_wrapper_script():
|
||||
"""Remove the hermes wrapper script if it exists."""
|
||||
wrapper_paths = [
|
||||
Path.home() / ".local" / "bin" / "hermes",
|
||||
Path("/usr/local/bin/hermes"),
|
||||
]
|
||||
|
||||
removed = []
|
||||
for wrapper in wrapper_paths:
|
||||
if wrapper.exists():
|
||||
try:
|
||||
# Check if it's our wrapper (contains hermes_cli reference)
|
||||
content = wrapper.read_text()
|
||||
if 'hermes_cli' in content or 'hermes-agent' in content:
|
||||
wrapper.unlink()
|
||||
removed.append(wrapper)
|
||||
except Exception as e:
|
||||
log_warn(f"Could not remove {wrapper}: {e}")
|
||||
|
||||
return removed
|
||||
|
||||
|
||||
def uninstall_gateway_service():
|
||||
"""Stop and uninstall the gateway service if running."""
|
||||
import platform
|
||||
|
||||
if platform.system() != "Linux":
|
||||
return False
|
||||
|
||||
service_file = Path.home() / ".config" / "systemd" / "user" / "hermes-gateway.service"
|
||||
|
||||
if not service_file.exists():
|
||||
return False
|
||||
|
||||
try:
|
||||
# Stop the service
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "stop", "hermes-gateway"],
|
||||
capture_output=True,
|
||||
check=False
|
||||
)
|
||||
|
||||
# Disable the service
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "disable", "hermes-gateway"],
|
||||
capture_output=True,
|
||||
check=False
|
||||
)
|
||||
|
||||
# Remove service file
|
||||
service_file.unlink()
|
||||
|
||||
# Reload systemd
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "daemon-reload"],
|
||||
capture_output=True,
|
||||
check=False
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
log_warn(f"Could not fully remove gateway service: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def run_uninstall(args):
|
||||
"""
|
||||
Run the uninstall process.
|
||||
|
||||
Options:
|
||||
- Full uninstall: removes code + ~/.hermes/ (configs, data, logs)
|
||||
- Keep data: removes code but keeps ~/.hermes/ for future reinstall
|
||||
"""
|
||||
project_root = get_project_root()
|
||||
hermes_home = get_hermes_home()
|
||||
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────┐", Colors.MAGENTA, Colors.BOLD))
|
||||
print(color("│ 🦋 Hermes Agent Uninstaller │", Colors.MAGENTA, Colors.BOLD))
|
||||
print(color("└─────────────────────────────────────────────────────────┘", Colors.MAGENTA, Colors.BOLD))
|
||||
print()
|
||||
|
||||
# Show what will be affected
|
||||
print(color("Current Installation:", Colors.CYAN, Colors.BOLD))
|
||||
print(f" Code: {project_root}")
|
||||
print(f" Config: {hermes_home / 'config.yaml'}")
|
||||
print(f" Secrets: {hermes_home / '.env'}")
|
||||
print(f" Data: {hermes_home / 'cron/'}, {hermes_home / 'sessions/'}, {hermes_home / 'logs/'}")
|
||||
print()
|
||||
|
||||
# Ask for confirmation
|
||||
print(color("Uninstall Options:", Colors.YELLOW, Colors.BOLD))
|
||||
print()
|
||||
print(" 1) " + color("Keep data", Colors.GREEN) + " - Remove code only, keep configs/sessions/logs")
|
||||
print(" (Recommended - you can reinstall later with your settings intact)")
|
||||
print()
|
||||
print(" 2) " + color("Full uninstall", Colors.RED) + " - Remove everything including all data")
|
||||
print(" (Warning: This deletes all configs, sessions, and logs permanently)")
|
||||
print()
|
||||
print(" 3) " + color("Cancel", Colors.CYAN) + " - Don't uninstall")
|
||||
print()
|
||||
|
||||
try:
|
||||
choice = input(color("Select option [1/2/3]: ", Colors.BOLD)).strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
print("Cancelled.")
|
||||
return
|
||||
|
||||
if choice == "3" or choice.lower() in ("c", "cancel", "q", "quit", "n", "no"):
|
||||
print()
|
||||
print("Uninstall cancelled.")
|
||||
return
|
||||
|
||||
full_uninstall = (choice == "2")
|
||||
|
||||
# Final confirmation
|
||||
print()
|
||||
if full_uninstall:
|
||||
print(color("⚠️ WARNING: This will permanently delete ALL Hermes data!", Colors.RED, Colors.BOLD))
|
||||
print(color(" Including: configs, API keys, sessions, scheduled jobs, logs", Colors.RED))
|
||||
else:
|
||||
print("This will remove the Hermes code but keep your configuration and data.")
|
||||
|
||||
print()
|
||||
try:
|
||||
confirm = input(f"Type '{color('yes', Colors.YELLOW)}' to confirm: ").strip().lower()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
print("Cancelled.")
|
||||
return
|
||||
|
||||
if confirm != "yes":
|
||||
print()
|
||||
print("Uninstall cancelled.")
|
||||
return
|
||||
|
||||
print()
|
||||
print(color("Uninstalling...", Colors.CYAN, Colors.BOLD))
|
||||
print()
|
||||
|
||||
# 1. Stop and uninstall gateway service
|
||||
log_info("Checking for gateway service...")
|
||||
if uninstall_gateway_service():
|
||||
log_success("Gateway service stopped and removed")
|
||||
else:
|
||||
log_info("No gateway service found")
|
||||
|
||||
# 2. Remove PATH entries from shell configs
|
||||
log_info("Removing PATH entries from shell configs...")
|
||||
removed_configs = remove_path_from_shell_configs()
|
||||
if removed_configs:
|
||||
for config in removed_configs:
|
||||
log_success(f"Updated {config}")
|
||||
else:
|
||||
log_info("No PATH entries found to remove")
|
||||
|
||||
# 3. Remove wrapper script
|
||||
log_info("Removing hermes command...")
|
||||
removed_wrappers = remove_wrapper_script()
|
||||
if removed_wrappers:
|
||||
for wrapper in removed_wrappers:
|
||||
log_success(f"Removed {wrapper}")
|
||||
else:
|
||||
log_info("No wrapper script found")
|
||||
|
||||
# 4. Remove installation directory (code)
|
||||
log_info(f"Removing installation directory...")
|
||||
|
||||
# Check if we're running from within the install dir
|
||||
# We need to be careful here
|
||||
try:
|
||||
if project_root.exists():
|
||||
# If the install is inside ~/.hermes/, just remove the hermes-agent subdir
|
||||
if hermes_home in project_root.parents or project_root.parent == hermes_home:
|
||||
shutil.rmtree(project_root)
|
||||
log_success(f"Removed {project_root}")
|
||||
else:
|
||||
# Installation is somewhere else entirely
|
||||
shutil.rmtree(project_root)
|
||||
log_success(f"Removed {project_root}")
|
||||
except Exception as e:
|
||||
log_warn(f"Could not fully remove {project_root}: {e}")
|
||||
log_info("You may need to manually remove it")
|
||||
|
||||
# 5. Optionally remove ~/.hermes/ data directory
|
||||
if full_uninstall:
|
||||
log_info("Removing configuration and data...")
|
||||
try:
|
||||
if hermes_home.exists():
|
||||
shutil.rmtree(hermes_home)
|
||||
log_success(f"Removed {hermes_home}")
|
||||
except Exception as e:
|
||||
log_warn(f"Could not fully remove {hermes_home}: {e}")
|
||||
log_info("You may need to manually remove it")
|
||||
else:
|
||||
log_info(f"Keeping configuration and data in {hermes_home}")
|
||||
|
||||
# Done
|
||||
print()
|
||||
print(color("┌─────────────────────────────────────────────────────────┐", Colors.GREEN, Colors.BOLD))
|
||||
print(color("│ ✓ Uninstall Complete! │", Colors.GREEN, Colors.BOLD))
|
||||
print(color("└─────────────────────────────────────────────────────────┘", Colors.GREEN, Colors.BOLD))
|
||||
print()
|
||||
|
||||
if not full_uninstall:
|
||||
print(color("Your configuration and data have been preserved:", Colors.CYAN))
|
||||
print(f" {hermes_home}/")
|
||||
print()
|
||||
print("To reinstall later with your existing settings:")
|
||||
print(color(" curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash", Colors.DIM))
|
||||
print()
|
||||
|
||||
print(color("Reload your shell to complete the process:", Colors.YELLOW))
|
||||
print(" source ~/.bashrc # or ~/.zshrc")
|
||||
print()
|
||||
print("Thank you for using Hermes Agent! 🦋")
|
||||
print()
|
||||
1
mini-swe-agent
Submodule
1
mini-swe-agent
Submodule
Submodule mini-swe-agent added at 07aa6a7385
708
mini_swe_runner.py
Normal file
708
mini_swe_runner.py
Normal file
@@ -0,0 +1,708 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Mini-SWE-Agent Runner with Hermes Trajectory Format
|
||||
|
||||
This module provides a runner that uses mini-swe-agent's execution environments
|
||||
(local, docker, modal) but outputs trajectories in the Hermes-Agent format
|
||||
compatible with batch_runner.py and trajectory_compressor.py.
|
||||
|
||||
Features:
|
||||
- Uses mini-swe-agent's Docker, Modal, or Local environments for command execution
|
||||
- Outputs trajectories in Hermes format (from/value pairs with <tool_call>/<tool_response> XML)
|
||||
- Compatible with the trajectory compression pipeline
|
||||
- Supports batch processing from JSONL prompt files
|
||||
|
||||
Usage:
|
||||
# Run a single task with local environment
|
||||
python mini_swe_runner.py --task "Create a hello world Python script" --env local
|
||||
|
||||
# Run with Docker
|
||||
python mini_swe_runner.py --task "List files in /tmp" --env docker --image python:3.11-slim
|
||||
|
||||
# Run with Modal (cloud)
|
||||
python mini_swe_runner.py --task "Install numpy and test it" --env modal --image python:3.11-slim
|
||||
|
||||
# Batch mode from JSONL file
|
||||
python mini_swe_runner.py --prompts_file prompts.jsonl --output_file trajectories.jsonl --env docker
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Literal
|
||||
|
||||
import fire
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Add mini-swe-agent to path if not installed
|
||||
mini_swe_path = Path(__file__).parent / "mini-swe-agent" / "src"
|
||||
if mini_swe_path.exists():
|
||||
sys.path.insert(0, str(mini_swe_path))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Terminal Tool Definition (matches Hermes-Agent format)
|
||||
# ============================================================================
|
||||
|
||||
TERMINAL_TOOL_DEFINITION = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "terminal",
|
||||
"description": """Execute bash commands in a sandboxed environment.
|
||||
|
||||
**Environment:**
|
||||
- Isolated execution environment (local, Docker, or Modal cloud)
|
||||
- Filesystem persists between tool calls within the same task
|
||||
- Internet access available
|
||||
|
||||
**Command Execution:**
|
||||
- Provide the command to execute via the 'command' parameter
|
||||
- Optional 'timeout' parameter in seconds (default: 60)
|
||||
|
||||
**Examples:**
|
||||
- Run command: `{"command": "ls -la"}`
|
||||
- With timeout: `{"command": "long_task.sh", "timeout": 300}`
|
||||
|
||||
**Best Practices:**
|
||||
- Use non-interactive commands (avoid vim, nano, interactive python)
|
||||
- Pipe to cat if output might be large
|
||||
- Install tools with apt-get or pip as needed
|
||||
|
||||
**Completion:**
|
||||
- When task is complete, output: echo "MINI_SWE_AGENT_FINAL_OUTPUT" followed by your result
|
||||
""",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The bash command to execute"
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"description": "Command timeout in seconds (default: 60)"
|
||||
}
|
||||
},
|
||||
"required": ["command"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Environment Factory
|
||||
# ============================================================================
|
||||
|
||||
def create_environment(
|
||||
env_type: str = "local",
|
||||
image: str = "python:3.11-slim",
|
||||
cwd: str = "/tmp",
|
||||
timeout: int = 60,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Create an execution environment from mini-swe-agent.
|
||||
|
||||
Args:
|
||||
env_type: One of "local", "docker", "modal"
|
||||
image: Docker/Modal image name (ignored for local)
|
||||
cwd: Working directory
|
||||
timeout: Default command timeout
|
||||
**kwargs: Additional environment-specific options
|
||||
|
||||
Returns:
|
||||
Environment instance with execute() method
|
||||
"""
|
||||
if env_type == "local":
|
||||
from minisweagent.environments.local import LocalEnvironment
|
||||
return LocalEnvironment(cwd=cwd, timeout=timeout)
|
||||
|
||||
elif env_type == "docker":
|
||||
from minisweagent.environments.docker import DockerEnvironment
|
||||
return DockerEnvironment(image=image, cwd=cwd, timeout=timeout, **kwargs)
|
||||
|
||||
elif env_type == "modal":
|
||||
from minisweagent.environments.extra.swerex_modal import SwerexModalEnvironment
|
||||
return SwerexModalEnvironment(image=image, cwd=cwd, timeout=timeout, **kwargs)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown environment type: {env_type}. Use 'local', 'docker', or 'modal'")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Mini-SWE Runner with Hermes Trajectory Format
|
||||
# ============================================================================
|
||||
|
||||
class MiniSWERunner:
|
||||
"""
|
||||
Agent runner that uses mini-swe-agent environments but outputs
|
||||
trajectories in Hermes-Agent format.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "anthropic/claude-sonnet-4-20250514",
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
env_type: str = "local",
|
||||
image: str = "python:3.11-slim",
|
||||
cwd: str = "/tmp",
|
||||
max_iterations: int = 15,
|
||||
command_timeout: int = 60,
|
||||
verbose: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize the Mini-SWE Runner.
|
||||
|
||||
Args:
|
||||
model: Model name for OpenAI-compatible API
|
||||
base_url: API base URL (optional, uses env vars if not provided)
|
||||
api_key: API key (optional, uses env vars if not provided)
|
||||
env_type: Environment type - "local", "docker", or "modal"
|
||||
image: Docker/Modal image (ignored for local)
|
||||
cwd: Working directory for commands
|
||||
max_iterations: Maximum tool-calling iterations
|
||||
command_timeout: Default timeout for commands
|
||||
verbose: Enable verbose logging
|
||||
"""
|
||||
self.model = model
|
||||
self.max_iterations = max_iterations
|
||||
self.command_timeout = command_timeout
|
||||
self.verbose = verbose
|
||||
self.env_type = env_type
|
||||
self.image = image
|
||||
self.cwd = cwd
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG if verbose else logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
datefmt='%H:%M:%S'
|
||||
)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize OpenAI client - defaults to OpenRouter
|
||||
from openai import OpenAI
|
||||
|
||||
client_kwargs = {}
|
||||
|
||||
# Default to OpenRouter if no base_url provided
|
||||
if base_url:
|
||||
client_kwargs["base_url"] = base_url
|
||||
else:
|
||||
client_kwargs["base_url"] = "https://openrouter.ai/api/v1"
|
||||
|
||||
# Handle API key - OpenRouter is the primary provider
|
||||
if api_key:
|
||||
client_kwargs["api_key"] = api_key
|
||||
else:
|
||||
client_kwargs["api_key"] = os.getenv(
|
||||
"OPENROUTER_API_KEY",
|
||||
os.getenv("ANTHROPIC_API_KEY", os.getenv("OPENAI_API_KEY", ""))
|
||||
)
|
||||
|
||||
self.client = OpenAI(**client_kwargs)
|
||||
|
||||
# Environment will be created per-task
|
||||
self.env = None
|
||||
|
||||
# Tool definition
|
||||
self.tools = [TERMINAL_TOOL_DEFINITION]
|
||||
|
||||
print(f"🤖 Mini-SWE Runner initialized")
|
||||
print(f" Model: {self.model}")
|
||||
print(f" Environment: {self.env_type}")
|
||||
if self.env_type != "local":
|
||||
print(f" Image: {self.image}")
|
||||
print(f" Max iterations: {self.max_iterations}")
|
||||
|
||||
def _create_env(self):
|
||||
"""Create the execution environment."""
|
||||
print(f"🔧 Creating {self.env_type} environment...")
|
||||
self.env = create_environment(
|
||||
env_type=self.env_type,
|
||||
image=self.image,
|
||||
cwd=self.cwd,
|
||||
timeout=self.command_timeout
|
||||
)
|
||||
print(f"✅ Environment ready")
|
||||
|
||||
def _cleanup_env(self):
|
||||
"""Cleanup the execution environment."""
|
||||
if self.env is not None:
|
||||
if hasattr(self.env, 'cleanup'):
|
||||
self.env.cleanup()
|
||||
elif hasattr(self.env, 'stop'):
|
||||
self.env.stop()
|
||||
self.env = None
|
||||
|
||||
def _execute_command(self, command: str, timeout: int = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute a command in the environment.
|
||||
|
||||
Args:
|
||||
command: Bash command to execute
|
||||
timeout: Optional timeout override
|
||||
|
||||
Returns:
|
||||
Dict with 'output' and 'returncode'
|
||||
"""
|
||||
if self.env is None:
|
||||
self._create_env()
|
||||
|
||||
try:
|
||||
result = self.env.execute(command, timeout=timeout or self.command_timeout)
|
||||
return {
|
||||
"output": result.get("output", ""),
|
||||
"exit_code": result.get("returncode", 0),
|
||||
"error": None
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"output": "",
|
||||
"exit_code": -1,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
def _format_tools_for_system_message(self) -> str:
|
||||
"""Format tool definitions for the system message."""
|
||||
formatted_tools = []
|
||||
for tool in self.tools:
|
||||
func = tool["function"]
|
||||
formatted_tools.append({
|
||||
"name": func["name"],
|
||||
"description": func.get("description", ""),
|
||||
"parameters": func.get("parameters", {}),
|
||||
"required": None
|
||||
})
|
||||
return json.dumps(formatted_tools, ensure_ascii=False)
|
||||
|
||||
def _convert_to_hermes_format(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
user_query: str,
|
||||
completed: bool
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert internal message format to Hermes trajectory format.
|
||||
|
||||
This produces the exact format used by batch_runner.py.
|
||||
"""
|
||||
trajectory = []
|
||||
|
||||
# System message with tool definitions
|
||||
system_msg = (
|
||||
"You are a function calling AI model. You are provided with function signatures within <tools> </tools> XML tags. "
|
||||
"You may call one or more functions to assist with the user query. If available tools are not relevant in assisting "
|
||||
"with user query, just respond in natural conversational language. Don't make assumptions about what values to plug "
|
||||
"into functions. After calling & executing the functions, you will be provided with function results within "
|
||||
"<tool_response> </tool_response> XML tags. Here are the available tools:\n"
|
||||
f"<tools>\n{self._format_tools_for_system_message()}\n</tools>\n"
|
||||
"For each function call return a JSON object, with the following pydantic model json schema for each:\n"
|
||||
"{'title': 'FunctionCall', 'type': 'object', 'properties': {'name': {'title': 'Name', 'type': 'string'}, "
|
||||
"'arguments': {'title': 'Arguments', 'type': 'object'}}, 'required': ['name', 'arguments']}\n"
|
||||
"Each function call should be enclosed within <tool_call> </tool_call> XML tags.\n"
|
||||
"Example:\n<tool_call>\n{'name': <function-name>,'arguments': <args-dict>}\n</tool_call>"
|
||||
)
|
||||
|
||||
trajectory.append({"from": "system", "value": system_msg})
|
||||
trajectory.append({"from": "human", "value": user_query})
|
||||
|
||||
# Process messages (skip first user message as we already added it)
|
||||
i = 1
|
||||
while i < len(messages):
|
||||
msg = messages[i]
|
||||
|
||||
if msg["role"] == "assistant":
|
||||
if "tool_calls" in msg and msg["tool_calls"]:
|
||||
# Assistant message with tool calls
|
||||
content = ""
|
||||
|
||||
# Add reasoning if present
|
||||
if msg.get("reasoning"):
|
||||
content = f"<think>{msg['reasoning']}</think>"
|
||||
|
||||
if msg.get("content"):
|
||||
content += msg["content"] + "\n"
|
||||
|
||||
# Add tool calls in XML format
|
||||
for tool_call in msg["tool_calls"]:
|
||||
try:
|
||||
arguments = json.loads(tool_call["function"]["arguments"]) \
|
||||
if isinstance(tool_call["function"]["arguments"], str) \
|
||||
else tool_call["function"]["arguments"]
|
||||
except json.JSONDecodeError:
|
||||
arguments = {}
|
||||
|
||||
tool_call_json = {
|
||||
"name": tool_call["function"]["name"],
|
||||
"arguments": arguments
|
||||
}
|
||||
content += f"<tool_call>\n{json.dumps(tool_call_json, ensure_ascii=False)}\n</tool_call>\n"
|
||||
|
||||
trajectory.append({"from": "gpt", "value": content.rstrip()})
|
||||
|
||||
# Collect subsequent tool responses
|
||||
tool_responses = []
|
||||
j = i + 1
|
||||
while j < len(messages) and messages[j]["role"] == "tool":
|
||||
tool_msg = messages[j]
|
||||
tool_content = tool_msg["content"]
|
||||
|
||||
# Try to parse as JSON
|
||||
try:
|
||||
if tool_content.strip().startswith(("{", "[")):
|
||||
tool_content = json.loads(tool_content)
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
pass
|
||||
|
||||
tool_response = f"<tool_response>\n"
|
||||
tool_response += json.dumps({
|
||||
"tool_call_id": tool_msg.get("tool_call_id", ""),
|
||||
"name": msg["tool_calls"][len(tool_responses)]["function"]["name"] \
|
||||
if len(tool_responses) < len(msg["tool_calls"]) else "unknown",
|
||||
"content": tool_content
|
||||
}, ensure_ascii=False)
|
||||
tool_response += "\n</tool_response>"
|
||||
tool_responses.append(tool_response)
|
||||
j += 1
|
||||
|
||||
if tool_responses:
|
||||
trajectory.append({"from": "tool", "value": "\n".join(tool_responses)})
|
||||
i = j - 1
|
||||
|
||||
else:
|
||||
# Regular assistant message (no tool calls)
|
||||
content = ""
|
||||
if msg.get("reasoning"):
|
||||
content = f"<think>{msg['reasoning']}</think>"
|
||||
content += msg.get("content") or ""
|
||||
trajectory.append({"from": "gpt", "value": content})
|
||||
|
||||
elif msg["role"] == "user":
|
||||
trajectory.append({"from": "human", "value": msg["content"]})
|
||||
|
||||
i += 1
|
||||
|
||||
return trajectory
|
||||
|
||||
def run_task(self, task: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Run a single task and return the result with trajectory.
|
||||
|
||||
Args:
|
||||
task: The task/prompt to execute
|
||||
|
||||
Returns:
|
||||
Dict with trajectory, completion status, and metadata
|
||||
"""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"📝 Task: {task[:80]}{'...' if len(task) > 80 else ''}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Initialize environment
|
||||
self._create_env()
|
||||
|
||||
# Message history
|
||||
messages = [{"role": "user", "content": task}]
|
||||
|
||||
# System prompt for the LLM (ephemeral - not saved to trajectory)
|
||||
system_prompt = """You are an AI agent that can execute bash commands to complete tasks.
|
||||
|
||||
When you need to run commands, use the 'terminal' tool with your bash command.
|
||||
|
||||
**Important:**
|
||||
- When you have completed the task successfully, run: echo "MINI_SWE_AGENT_FINAL_OUTPUT" followed by a summary
|
||||
- Be concise and efficient in your approach
|
||||
- Install any needed tools with apt-get or pip
|
||||
- Avoid interactive commands (no vim, nano, less, etc.)
|
||||
|
||||
Complete the user's task step by step."""
|
||||
|
||||
api_call_count = 0
|
||||
completed = False
|
||||
final_response = None
|
||||
|
||||
try:
|
||||
while api_call_count < self.max_iterations:
|
||||
api_call_count += 1
|
||||
print(f"\n🔄 API call #{api_call_count}/{self.max_iterations}")
|
||||
|
||||
# Prepare API messages
|
||||
api_messages = [{"role": "system", "content": system_prompt}] + messages
|
||||
|
||||
# Make API call
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=api_messages,
|
||||
tools=self.tools,
|
||||
timeout=300.0
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"API call failed: {e}")
|
||||
break
|
||||
|
||||
assistant_message = response.choices[0].message
|
||||
|
||||
# Log assistant response
|
||||
if assistant_message.content:
|
||||
print(f"🤖 Assistant: {assistant_message.content[:100]}...")
|
||||
|
||||
# Check for tool calls
|
||||
if assistant_message.tool_calls:
|
||||
print(f"🔧 Tool calls: {len(assistant_message.tool_calls)}")
|
||||
|
||||
# Add assistant message with tool calls
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": assistant_message.content,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tc.id,
|
||||
"type": tc.type,
|
||||
"function": {
|
||||
"name": tc.function.name,
|
||||
"arguments": tc.function.arguments
|
||||
}
|
||||
}
|
||||
for tc in assistant_message.tool_calls
|
||||
]
|
||||
})
|
||||
|
||||
# Execute each tool call
|
||||
for tc in assistant_message.tool_calls:
|
||||
try:
|
||||
args = json.loads(tc.function.arguments)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
|
||||
command = args.get("command", "echo 'No command provided'")
|
||||
timeout = args.get("timeout", self.command_timeout)
|
||||
|
||||
print(f" 📞 terminal: {command[:60]}...")
|
||||
|
||||
# Execute command
|
||||
result = self._execute_command(command, timeout)
|
||||
|
||||
# Format result
|
||||
result_json = json.dumps({
|
||||
"content": {
|
||||
"output": result["output"],
|
||||
"exit_code": result["exit_code"],
|
||||
"error": result["error"]
|
||||
}
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# Check for task completion signal
|
||||
if "MINI_SWE_AGENT_FINAL_OUTPUT" in result["output"]:
|
||||
print(f" ✅ Task completion signal detected!")
|
||||
completed = True
|
||||
|
||||
# Add tool response
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"content": result_json,
|
||||
"tool_call_id": tc.id
|
||||
})
|
||||
|
||||
print(f" ✅ exit_code={result['exit_code']}, output={len(result['output'])} chars")
|
||||
|
||||
# If task completed, we can stop
|
||||
if completed:
|
||||
final_response = assistant_message.content
|
||||
break
|
||||
|
||||
else:
|
||||
# No tool calls - final response
|
||||
final_response = assistant_message.content or ""
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": final_response
|
||||
})
|
||||
completed = True
|
||||
print(f"🎉 Agent finished (no more tool calls)")
|
||||
break
|
||||
|
||||
if api_call_count >= self.max_iterations:
|
||||
print(f"⚠️ Reached max iterations ({self.max_iterations})")
|
||||
|
||||
finally:
|
||||
# Cleanup environment
|
||||
self._cleanup_env()
|
||||
|
||||
# Convert to Hermes trajectory format
|
||||
trajectory = self._convert_to_hermes_format(messages, task, completed)
|
||||
|
||||
return {
|
||||
"conversations": trajectory,
|
||||
"completed": completed,
|
||||
"api_calls": api_call_count,
|
||||
"metadata": {
|
||||
"model": self.model,
|
||||
"env_type": self.env_type,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
def run_batch(
|
||||
self,
|
||||
prompts: List[str],
|
||||
output_file: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Run multiple tasks and save trajectories to a JSONL file.
|
||||
|
||||
Args:
|
||||
prompts: List of task prompts
|
||||
output_file: Output JSONL file path
|
||||
|
||||
Returns:
|
||||
List of results
|
||||
"""
|
||||
results = []
|
||||
|
||||
print(f"\n📦 Running batch of {len(prompts)} tasks")
|
||||
print(f"📁 Output: {output_file}")
|
||||
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
for i, prompt in enumerate(prompts, 1):
|
||||
print(f"\n{'='*60}")
|
||||
print(f"📋 Task {i}/{len(prompts)}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
try:
|
||||
result = self.run_task(prompt)
|
||||
results.append(result)
|
||||
|
||||
# Write to file immediately
|
||||
f.write(json.dumps(result, ensure_ascii=False) + "\n")
|
||||
f.flush()
|
||||
|
||||
print(f"✅ Task {i} completed (api_calls={result['api_calls']})")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error on task {i}: {e}")
|
||||
error_result = {
|
||||
"conversations": [],
|
||||
"completed": False,
|
||||
"api_calls": 0,
|
||||
"error": str(e),
|
||||
"metadata": {"timestamp": datetime.now().isoformat()}
|
||||
}
|
||||
results.append(error_result)
|
||||
f.write(json.dumps(error_result, ensure_ascii=False) + "\n")
|
||||
f.flush()
|
||||
|
||||
print(f"\n✅ Batch complete! {len(results)} trajectories saved to {output_file}")
|
||||
return results
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CLI Interface
|
||||
# ============================================================================
|
||||
|
||||
def main(
|
||||
task: str = None,
|
||||
prompts_file: str = None,
|
||||
output_file: str = "mini-swe-agent-test1.jsonl",
|
||||
model: str = "claude-sonnet-4-20250514",
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
env: str = "local",
|
||||
image: str = "python:3.11-slim",
|
||||
cwd: str = "/tmp",
|
||||
max_iterations: int = 15,
|
||||
timeout: int = 60,
|
||||
verbose: bool = False,
|
||||
):
|
||||
"""
|
||||
Run mini-swe-agent tasks with Hermes trajectory format output.
|
||||
|
||||
Args:
|
||||
task: Single task to run (use this OR prompts_file)
|
||||
prompts_file: JSONL file with prompts (each line: {"prompt": "..."})
|
||||
output_file: Output JSONL file for trajectories
|
||||
model: Model name (default: claude-sonnet-4-20250514)
|
||||
base_url: API base URL (optional)
|
||||
api_key: API key (optional, uses env vars)
|
||||
env: Environment type - "local", "docker", or "modal"
|
||||
image: Docker/Modal image (default: python:3.11-slim)
|
||||
cwd: Working directory (default: /tmp)
|
||||
max_iterations: Maximum tool-calling iterations (default: 15)
|
||||
timeout: Command timeout in seconds (default: 60)
|
||||
verbose: Enable verbose logging
|
||||
|
||||
Examples:
|
||||
# Single task with local environment
|
||||
python mini_swe_runner.py --task "Create hello.py that prints Hello World"
|
||||
|
||||
# Single task with Docker
|
||||
python mini_swe_runner.py --task "List files" --env docker
|
||||
|
||||
# Batch from file
|
||||
python mini_swe_runner.py --prompts_file tasks.jsonl --output_file results.jsonl
|
||||
"""
|
||||
print("🚀 Mini-SWE Runner with Hermes Trajectory Format")
|
||||
print("=" * 60)
|
||||
|
||||
# Initialize runner
|
||||
runner = MiniSWERunner(
|
||||
model=model,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
env_type=env,
|
||||
image=image,
|
||||
cwd=cwd,
|
||||
max_iterations=max_iterations,
|
||||
command_timeout=timeout,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
if task:
|
||||
# Single task mode
|
||||
result = runner.run_task(task)
|
||||
|
||||
# Save to file
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
f.write(json.dumps(result, ensure_ascii=False) + "\n")
|
||||
|
||||
print(f"\n📁 Trajectory saved to: {output_file}")
|
||||
print(f"✅ Completed: {result['completed']}")
|
||||
print(f"📞 API calls: {result['api_calls']}")
|
||||
print(f"💬 Turns: {len(result['conversations'])}")
|
||||
|
||||
elif prompts_file:
|
||||
# Batch mode
|
||||
prompts = []
|
||||
with open(prompts_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
entry = json.loads(line)
|
||||
prompts.append(entry.get("prompt", entry.get("task", "")))
|
||||
except json.JSONDecodeError:
|
||||
prompts.append(line)
|
||||
|
||||
if not prompts:
|
||||
print(f"❌ No prompts found in {prompts_file}")
|
||||
return
|
||||
|
||||
runner.run_batch(prompts, output_file)
|
||||
|
||||
else:
|
||||
print("❌ Please provide either --task or --prompts_file")
|
||||
print(" Example: python mini_swe_runner.py --task 'Create a hello world script'")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
1226
model_tools.py
1226
model_tools.py
File diff suppressed because it is too large
Load Diff
77
package-lock.json
generated
Normal file
77
package-lock.json
generated
Normal file
@@ -0,0 +1,77 @@
|
||||
{
|
||||
"name": "hermes-agent",
|
||||
"version": "1.0.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "hermes-agent",
|
||||
"version": "1.0.0",
|
||||
"hasInstallScript": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"agent-browser": "^0.7.6"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/agent-browser": {
|
||||
"version": "0.7.6",
|
||||
"resolved": "https://registry.npmjs.org/agent-browser/-/agent-browser-0.7.6.tgz",
|
||||
"integrity": "sha512-BDmzFlTM0siqn5P8LSBxgOBUNGv02Vo7RYztvXXjNOwQ+8rFJILWfBPxmw+57l/PcMst61AscjIe8uZ5sWrRZQ==",
|
||||
"hasInstallScript": true,
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"playwright-core": "^1.57.0",
|
||||
"ws": "^8.19.0",
|
||||
"zod": "^3.22.4"
|
||||
},
|
||||
"bin": {
|
||||
"agent-browser": "bin/agent-browser"
|
||||
}
|
||||
},
|
||||
"node_modules/playwright-core": {
|
||||
"version": "1.58.0",
|
||||
"resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.58.0.tgz",
|
||||
"integrity": "sha512-aaoB1RWrdNi3//rOeKuMiS65UCcgOVljU46At6eFcOFPFHWtd2weHRRow6z/n+Lec0Lvu0k9ZPKJSjPugikirw==",
|
||||
"license": "Apache-2.0",
|
||||
"bin": {
|
||||
"playwright-core": "cli.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18"
|
||||
}
|
||||
},
|
||||
"node_modules/ws": {
|
||||
"version": "8.19.0",
|
||||
"resolved": "https://registry.npmjs.org/ws/-/ws-8.19.0.tgz",
|
||||
"integrity": "sha512-blAT2mjOEIi0ZzruJfIhb3nps74PRWTCz1IjglWEEpQl5XS/UNama6u2/rjFkDDouqr4L67ry+1aGIALViWjDg==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=10.0.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"bufferutil": "^4.0.1",
|
||||
"utf-8-validate": ">=5.0.2"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"bufferutil": {
|
||||
"optional": true
|
||||
},
|
||||
"utf-8-validate": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/zod": {
|
||||
"version": "3.25.76",
|
||||
"resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz",
|
||||
"integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==",
|
||||
"license": "MIT",
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/colinhacks"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
24
package.json
Normal file
24
package.json
Normal file
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"name": "hermes-agent",
|
||||
"version": "1.0.0",
|
||||
"description": "An AI agent with advanced tool-calling capabilities, featuring a flexible toolsets system for organizing and managing tools.",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"postinstall": "echo '✅ Browser tools ready. Run: python run_agent.py --help'"
|
||||
},
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/NousResearch/Hermes-Agent.git"
|
||||
},
|
||||
"license": "MIT",
|
||||
"bugs": {
|
||||
"url": "https://github.com/NousResearch/Hermes-Agent/issues"
|
||||
},
|
||||
"homepage": "https://github.com/NousResearch/Hermes-Agent#readme",
|
||||
"dependencies": {
|
||||
"agent-browser": "^0.7.6"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
}
|
||||
}
|
||||
@@ -8,21 +8,43 @@ version = "0.1.0"
|
||||
description = "AI agent with advanced tool-calling and toolsets"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
authors = [{ name = "Hermes Agent" }]
|
||||
authors = [{ name = "Nous Research" }]
|
||||
license = { text = "MIT" }
|
||||
dependencies = [
|
||||
"firecrawl-py",
|
||||
# Core
|
||||
"openai",
|
||||
"fal-client",
|
||||
"python-dotenv",
|
||||
"fire"
|
||||
"fire",
|
||||
"httpx",
|
||||
"rich",
|
||||
"tenacity",
|
||||
"pyyaml",
|
||||
"requests",
|
||||
"jinja2",
|
||||
"pydantic>=2.0",
|
||||
# Tools
|
||||
"firecrawl-py",
|
||||
"fal-client",
|
||||
# mini-swe-agent deps (terminal tool)
|
||||
"litellm>=1.75.5",
|
||||
"typer",
|
||||
"platformdirs",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
modal = ["modal", "boto3"]
|
||||
dev = ["pytest", "pytest-asyncio"]
|
||||
messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0"]
|
||||
cron = ["croniter"]
|
||||
cli = ["simple-term-menu"]
|
||||
all = ["croniter", "python-telegram-bot>=20.0", "discord.py>=2.0", "simple-term-menu"]
|
||||
|
||||
[project.scripts]
|
||||
hermes = "hermes_cli.main:main"
|
||||
hermes-agent = "run_agent:main"
|
||||
|
||||
[tool.setuptools]
|
||||
py-modules = ["run_agent", "model_tools", "toolsets"]
|
||||
py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajectory_compressor", "toolset_distributions", "cli"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
include = ["tools"]
|
||||
include = ["tools", "hermes_cli", "gateway", "cron"]
|
||||
|
||||
@@ -1,6 +1,44 @@
|
||||
firecrawl-py
|
||||
# Core dependencies
|
||||
openai
|
||||
fal-client
|
||||
python-dotenv
|
||||
fire
|
||||
httpx
|
||||
httpx
|
||||
rich
|
||||
tenacity
|
||||
prompt_toolkit
|
||||
|
||||
# Web tools
|
||||
firecrawl-py
|
||||
|
||||
# Image generation
|
||||
fal-client
|
||||
|
||||
# mini-swe-agent dependencies (for terminal tool)
|
||||
# Note: Install mini-swe-agent itself with: pip install -e ./mini-swe-agent
|
||||
pyyaml
|
||||
requests
|
||||
jinja2
|
||||
pydantic>=2.0
|
||||
litellm>=1.75.5
|
||||
typer
|
||||
platformdirs
|
||||
|
||||
# Optional: For Docker backend (recommended)
|
||||
# Requires Docker installed and user in 'docker' group
|
||||
|
||||
# Optional: For Modal backend (cloud execution)
|
||||
# modal
|
||||
# boto3
|
||||
|
||||
# Optional: For cron expression parsing (cronjob scheduling)
|
||||
croniter
|
||||
|
||||
# Optional: For messaging platform integrations (gateway)
|
||||
# Telegram: pip install python-telegram-bot
|
||||
python-telegram-bot>=20.0
|
||||
|
||||
# Discord: pip install discord.py
|
||||
discord.py>=2.0
|
||||
|
||||
# WhatsApp: Requires Node.js bridge (see docs/messaging.md)
|
||||
# aiohttp # For WhatsApp bridge communication
|
||||
448
rl_cli.py
Normal file
448
rl_cli.py
Normal file
@@ -0,0 +1,448 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
RL Training CLI Runner
|
||||
|
||||
Dedicated CLI runner for RL training workflows with:
|
||||
- Extended timeouts for long-running training
|
||||
- RL-focused system prompts
|
||||
- Full toolset including RL training tools
|
||||
- Special handling for 30-minute check intervals
|
||||
|
||||
Usage:
|
||||
python rl_cli.py "Train a model on GSM8k for math reasoning"
|
||||
python rl_cli.py --interactive
|
||||
python rl_cli.py --list-environments
|
||||
|
||||
Environment Variables:
|
||||
TINKER_API_KEY: API key for Tinker service (required)
|
||||
WANDB_API_KEY: API key for WandB metrics (required)
|
||||
OPENROUTER_API_KEY: API key for OpenRouter (required for agent)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import yaml
|
||||
|
||||
# Load environment variables from .env file
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load from ~/.hermes/.env first, then local .env
|
||||
hermes_env_path = Path.home() / '.hermes' / '.env'
|
||||
local_env_path = Path(__file__).parent / '.env'
|
||||
|
||||
if hermes_env_path.exists():
|
||||
load_dotenv(dotenv_path=hermes_env_path)
|
||||
print(f"✅ Loaded environment variables from {hermes_env_path}")
|
||||
elif local_env_path.exists():
|
||||
load_dotenv(dotenv_path=local_env_path)
|
||||
print(f"✅ Loaded environment variables from {local_env_path}")
|
||||
|
||||
# Set terminal working directory to tinker-atropos submodule
|
||||
# This ensures terminal commands run in the right context for RL work
|
||||
tinker_atropos_dir = Path(__file__).parent / 'tinker-atropos'
|
||||
if tinker_atropos_dir.exists():
|
||||
os.environ['TERMINAL_CWD'] = str(tinker_atropos_dir)
|
||||
os.environ['HERMES_QUIET'] = '1' # Disable temp subdirectory creation
|
||||
print(f"📂 Terminal working directory: {tinker_atropos_dir}")
|
||||
else:
|
||||
# Fall back to hermes-agent directory if submodule not found
|
||||
os.environ['TERMINAL_CWD'] = str(Path(__file__).parent)
|
||||
os.environ['HERMES_QUIET'] = '1'
|
||||
print(f"⚠️ tinker-atropos submodule not found, using: {Path(__file__).parent}")
|
||||
|
||||
# Import agent and tools
|
||||
from run_agent import AIAgent
|
||||
from model_tools import get_tool_definitions, check_toolset_requirements
|
||||
from tools.rl_training_tool import check_rl_api_keys, get_missing_keys
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Config Loading
|
||||
# ============================================================================
|
||||
|
||||
DEFAULT_MODEL = "anthropic/claude-opus-4.5"
|
||||
DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
def load_hermes_config() -> dict:
|
||||
"""
|
||||
Load configuration from ~/.hermes/config.yaml.
|
||||
|
||||
Returns:
|
||||
dict: Configuration with model, base_url, etc.
|
||||
"""
|
||||
config_path = Path.home() / '.hermes' / 'config.yaml'
|
||||
|
||||
config = {
|
||||
"model": DEFAULT_MODEL,
|
||||
"base_url": DEFAULT_BASE_URL,
|
||||
}
|
||||
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path, "r") as f:
|
||||
file_config = yaml.safe_load(f) or {}
|
||||
|
||||
# Get model from config
|
||||
if "model" in file_config:
|
||||
if isinstance(file_config["model"], str):
|
||||
config["model"] = file_config["model"]
|
||||
elif isinstance(file_config["model"], dict):
|
||||
config["model"] = file_config["model"].get("default", DEFAULT_MODEL)
|
||||
|
||||
# Get base_url if specified
|
||||
if "base_url" in file_config:
|
||||
config["base_url"] = file_config["base_url"]
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ Warning: Failed to load config.yaml: {e}")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# RL-Specific Configuration
|
||||
# ============================================================================
|
||||
|
||||
# Extended timeouts for long-running RL operations
|
||||
RL_MAX_ITERATIONS = 200 # Allow many more iterations for long workflows
|
||||
|
||||
# RL-focused system prompt
|
||||
RL_SYSTEM_PROMPT = """You are an automated post-training engineer specializing in reinforcement learning for language models.
|
||||
|
||||
## Your Capabilities
|
||||
|
||||
You have access to RL training tools for running reinforcement learning on models through Tinker-Atropos:
|
||||
|
||||
1. **DISCOVER**: Use `rl_list_environments` to see available RL environments
|
||||
2. **INSPECT**: Read environment files to understand how they work (verifiers, data loading, rewards)
|
||||
3. **INSPECT DATA**: Use terminal to explore HuggingFace datasets and understand their format
|
||||
4. **CREATE**: Copy existing environments as templates, modify for your needs
|
||||
5. **CONFIGURE**: Use `rl_select_environment` and `rl_edit_config` to set up training
|
||||
6. **TEST**: Always use `rl_test_inference` before full training to validate your setup
|
||||
7. **TRAIN**: Use `rl_start_training` to begin, `rl_check_status` to monitor
|
||||
8. **EVALUATE**: Use `rl_get_results` and analyze WandB metrics to assess performance
|
||||
|
||||
## Environment Files
|
||||
|
||||
Environment files are located in: `tinker-atropos/tinker_atropos/environments/`
|
||||
|
||||
Study existing environments to learn patterns. Look for:
|
||||
- `load_dataset()` calls - how data is loaded
|
||||
- `score_answer()` / `score()` - verification logic
|
||||
- `get_next_item()` - prompt formatting
|
||||
- `system_prompt` - instruction format
|
||||
- `config_init()` - default configuration
|
||||
|
||||
## Creating New Environments
|
||||
|
||||
To create a new environment:
|
||||
1. Read an existing environment file (e.g., gsm8k_tinker.py)
|
||||
2. Use terminal to explore the target dataset format
|
||||
3. Copy the environment file as a template
|
||||
4. Modify the dataset loading, prompt formatting, and verifier logic
|
||||
5. Test with `rl_test_inference` before training
|
||||
|
||||
## Important Guidelines
|
||||
|
||||
- **Always test before training**: Training runs take hours - verify everything works first
|
||||
- **Monitor metrics**: Check WandB for reward/mean and percent_correct
|
||||
- **Status check intervals**: Wait at least 30 minutes between status checks
|
||||
- **Early stopping**: Stop training early if metrics look bad or stagnant
|
||||
- **Iterate quickly**: Start with small total_steps to validate, then scale up
|
||||
|
||||
## Available Toolsets
|
||||
|
||||
You have access to:
|
||||
- **RL tools**: Environment discovery, config management, training, testing
|
||||
- **Terminal**: Run commands, inspect files, explore datasets
|
||||
- **Web**: Search for information, documentation, papers
|
||||
- **File tools**: Read and modify code files
|
||||
|
||||
When asked to train a model, follow this workflow:
|
||||
1. List available environments
|
||||
2. Select and configure the appropriate environment
|
||||
3. Test with sample prompts
|
||||
4. Start training with conservative settings
|
||||
5. Monitor progress and adjust as needed
|
||||
"""
|
||||
|
||||
# Toolsets to enable for RL workflows
|
||||
RL_TOOLSETS = ["terminal", "web", "rl"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
def check_requirements():
|
||||
"""Check that all required environment variables and services are available."""
|
||||
errors = []
|
||||
|
||||
# Check API keys
|
||||
if not os.getenv("OPENROUTER_API_KEY"):
|
||||
errors.append("OPENROUTER_API_KEY not set - required for agent")
|
||||
|
||||
missing_rl_keys = get_missing_keys()
|
||||
if missing_rl_keys:
|
||||
errors.append(f"Missing RL API keys: {', '.join(missing_rl_keys)}")
|
||||
|
||||
if errors:
|
||||
print("❌ Missing requirements:")
|
||||
for error in errors:
|
||||
print(f" - {error}")
|
||||
print("\nPlease set these environment variables in your .env file or shell.")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def check_tinker_atropos():
|
||||
"""Check if tinker-atropos submodule is properly set up."""
|
||||
tinker_path = Path(__file__).parent / "tinker-atropos"
|
||||
|
||||
if not tinker_path.exists():
|
||||
return False, "tinker-atropos submodule not found. Run: git submodule update --init"
|
||||
|
||||
envs_path = tinker_path / "tinker_atropos" / "environments"
|
||||
if not envs_path.exists():
|
||||
return False, f"environments directory not found at {envs_path}"
|
||||
|
||||
env_files = list(envs_path.glob("*.py"))
|
||||
env_files = [f for f in env_files if not f.name.startswith("_")]
|
||||
|
||||
return True, {"path": str(tinker_path), "environments_count": len(env_files)}
|
||||
|
||||
|
||||
def list_environments_sync():
|
||||
"""List available environments (synchronous wrapper)."""
|
||||
from tools.rl_training_tool import rl_list_environments
|
||||
import json
|
||||
|
||||
async def _list():
|
||||
result = await rl_list_environments()
|
||||
return json.loads(result)
|
||||
|
||||
return asyncio.run(_list())
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Main CLI
|
||||
# ============================================================================
|
||||
|
||||
def main(
|
||||
task: str = None,
|
||||
model: str = None,
|
||||
api_key: str = None,
|
||||
base_url: str = None,
|
||||
max_iterations: int = RL_MAX_ITERATIONS,
|
||||
interactive: bool = False,
|
||||
list_environments: bool = False,
|
||||
check_server: bool = False,
|
||||
verbose: bool = False,
|
||||
save_trajectories: bool = True,
|
||||
):
|
||||
"""
|
||||
RL Training CLI - Dedicated runner for RL training workflows.
|
||||
|
||||
Args:
|
||||
task: The training task/goal (e.g., "Train a model on GSM8k for math")
|
||||
model: Model to use for the agent (reads from ~/.hermes/config.yaml if not provided)
|
||||
api_key: OpenRouter API key (uses OPENROUTER_API_KEY env var if not provided)
|
||||
base_url: API base URL (reads from config or defaults to OpenRouter)
|
||||
max_iterations: Maximum agent iterations (default: 200 for long workflows)
|
||||
interactive: Run in interactive mode (multiple conversations)
|
||||
list_environments: Just list available RL environments and exit
|
||||
check_server: Check if RL API server is running and exit
|
||||
verbose: Enable verbose logging
|
||||
save_trajectories: Save conversation trajectories (default: True for RL)
|
||||
|
||||
Examples:
|
||||
# Train on a specific environment
|
||||
python rl_cli.py "Train a model on GSM8k math problems"
|
||||
|
||||
# Interactive mode
|
||||
python rl_cli.py --interactive
|
||||
|
||||
# List available environments
|
||||
python rl_cli.py --list-environments
|
||||
|
||||
# Check server status
|
||||
python rl_cli.py --check-server
|
||||
"""
|
||||
# Load config from ~/.hermes/config.yaml
|
||||
config = load_hermes_config()
|
||||
|
||||
# Use config values if not explicitly provided
|
||||
if model is None:
|
||||
model = config["model"]
|
||||
if base_url is None:
|
||||
base_url = config["base_url"]
|
||||
|
||||
print("🎯 RL Training Agent")
|
||||
print("=" * 60)
|
||||
|
||||
# Handle setup check
|
||||
if check_server:
|
||||
print("\n🔍 Checking tinker-atropos setup...")
|
||||
ok, result = check_tinker_atropos()
|
||||
if ok:
|
||||
print("✅ tinker-atropos submodule found")
|
||||
print(f" Path: {result.get('path')}")
|
||||
print(f" Environments found: {result.get('environments_count', 0)}")
|
||||
|
||||
# Also check API keys
|
||||
missing = get_missing_keys()
|
||||
if missing:
|
||||
print(f"\n⚠️ Missing API keys: {', '.join(missing)}")
|
||||
print(" Add them to ~/.hermes/.env")
|
||||
else:
|
||||
print("✅ API keys configured")
|
||||
else:
|
||||
print(f"❌ tinker-atropos not set up: {result}")
|
||||
print("\nTo set up:")
|
||||
print(" git submodule update --init")
|
||||
print(" pip install -e ./tinker-atropos")
|
||||
return
|
||||
|
||||
# Handle environment listing
|
||||
if list_environments:
|
||||
print("\n📋 Available RL Environments:")
|
||||
print("-" * 40)
|
||||
try:
|
||||
data = list_environments_sync()
|
||||
if "error" in data:
|
||||
print(f"❌ Error: {data['error']}")
|
||||
return
|
||||
|
||||
envs = data.get("environments", [])
|
||||
if not envs:
|
||||
print("No environments found.")
|
||||
print("\nMake sure tinker-atropos is set up:")
|
||||
print(" git submodule update --init")
|
||||
return
|
||||
|
||||
for env in envs:
|
||||
print(f"\n 📦 {env['name']}")
|
||||
print(f" Class: {env['class_name']}")
|
||||
print(f" Path: {env['file_path']}")
|
||||
if env.get('description'):
|
||||
desc = env['description'][:100] + "..." if len(env.get('description', '')) > 100 else env.get('description', '')
|
||||
print(f" Description: {desc}")
|
||||
|
||||
print(f"\n📊 Total: {len(envs)} environments")
|
||||
print("\nUse `rl_select_environment(name)` to select an environment for training.")
|
||||
except Exception as e:
|
||||
print(f"❌ Error listing environments: {e}")
|
||||
print("\nMake sure tinker-atropos is set up:")
|
||||
print(" git submodule update --init")
|
||||
print(" pip install -e ./tinker-atropos")
|
||||
return
|
||||
|
||||
# Check requirements
|
||||
if not check_requirements():
|
||||
sys.exit(1)
|
||||
|
||||
# Set default task if none provided
|
||||
if not task and not interactive:
|
||||
print("\n⚠️ No task provided. Use --interactive for interactive mode or provide a task.")
|
||||
print("\nExamples:")
|
||||
print(' python rl_cli.py "Train a model on GSM8k math problems"')
|
||||
print(' python rl_cli.py "Create an RL environment for code generation"')
|
||||
print(' python rl_cli.py --interactive')
|
||||
return
|
||||
|
||||
# Get API key
|
||||
api_key = api_key or os.getenv("OPENROUTER_API_KEY")
|
||||
if not api_key:
|
||||
print("❌ No API key provided. Set OPENROUTER_API_KEY or pass --api-key")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"\n🤖 Model: {model}")
|
||||
print(f"🔧 Max iterations: {max_iterations}")
|
||||
print(f"📁 Toolsets: {', '.join(RL_TOOLSETS)}")
|
||||
print("=" * 60)
|
||||
|
||||
# Create agent with RL configuration
|
||||
agent = AIAgent(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
max_iterations=max_iterations,
|
||||
enabled_toolsets=RL_TOOLSETS,
|
||||
save_trajectories=save_trajectories,
|
||||
verbose_logging=verbose,
|
||||
quiet_mode=False,
|
||||
ephemeral_system_prompt=RL_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
if interactive:
|
||||
# Interactive mode - multiple conversations
|
||||
print("\n🔄 Interactive RL Training Mode")
|
||||
print("Type 'quit' or 'exit' to end the session.")
|
||||
print("Type 'status' to check active training runs.")
|
||||
print("-" * 40)
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("\n🎯 RL Task> ").strip()
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
if user_input.lower() in ('quit', 'exit', 'q'):
|
||||
print("\n👋 Goodbye!")
|
||||
break
|
||||
|
||||
if user_input.lower() == 'status':
|
||||
# Quick status check
|
||||
from tools.rl_training_tool import rl_list_runs
|
||||
import json
|
||||
result = asyncio.run(rl_list_runs())
|
||||
runs = json.loads(result)
|
||||
if isinstance(runs, list) and runs:
|
||||
print("\n📊 Active Runs:")
|
||||
for run in runs:
|
||||
print(f" - {run['run_id']}: {run['environment']} ({run['status']})")
|
||||
else:
|
||||
print("\nNo active runs.")
|
||||
continue
|
||||
|
||||
# Run the agent
|
||||
print("\n" + "=" * 60)
|
||||
response = agent.run_conversation(user_input)
|
||||
print("\n" + "=" * 60)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n👋 Interrupted. Goodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error: {e}")
|
||||
if verbose:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
else:
|
||||
# Single task mode
|
||||
print(f"\n📝 Task: {task}")
|
||||
print("-" * 40)
|
||||
|
||||
try:
|
||||
response = agent.run_conversation(task)
|
||||
print("\n" + "=" * 60)
|
||||
print("✅ Task completed")
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n⚠️ Interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error: {e}")
|
||||
if verbose:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
1781
run_agent.py
1781
run_agent.py
File diff suppressed because it is too large
Load Diff
@@ -1,12 +0,0 @@
|
||||
python batch_runner.py \
|
||||
--dataset_file="hermes-agent-imagen-data/hermes_agent_imagen_eval.jsonl" \
|
||||
--batch_size=10 \
|
||||
--run_name="imagen_eval_gpt5" \
|
||||
--distribution="image_gen" \
|
||||
--model="gpt-5" \
|
||||
--base_url="https://api.openai.com/v1" \
|
||||
--api_key="${OPENAI_API_KEY}" \
|
||||
--num_workers=4 \
|
||||
--max_turns=5 \
|
||||
--verbose \
|
||||
--ephemeral_system_prompt="When generating an image for the user view the image by using the vision_analyze tool to ensure it is what the user wanted. If it isn't feel free to retry a few times. If none are perfect, choose the best option that is the closest match, and explain its imperfections. If the image generation tool fails, try again a few times. If the vision analyze tool fails, provide the image to the user and explain it is your best effort attempt."
|
||||
@@ -1,12 +0,0 @@
|
||||
python batch_runner.py \
|
||||
--dataset_file="hermes-agent-megascience-data/hermes_agent_megascience_eval.jsonl" \
|
||||
--batch_size=10 \
|
||||
--run_name="megascience_eval_glm4-6-fixedterminal-2" \
|
||||
--distribution="science" \
|
||||
--model="z-ai/glm-4.6" \
|
||||
--base_url="https://openrouter.ai/api/v1" \
|
||||
--api_key="${OPENROUTER_API_KEY}" \
|
||||
--num_workers=5 \
|
||||
--max_turns=30 \
|
||||
--verbose \
|
||||
--ephemeral_system_prompt="You have access to a variety of tools to help you solve scientific, math, and technology problems presented to you. You can use them in sequence and build off of the results of prior tools you've used results. Always use a tool if it can provide additional context, verify formulas, double check concepts and recent studies and understanding, doing all calculations, etc. You should only be confident in your own reasoning, knowledge, or calculations if you've exhaustively used all tools available to you to that can help you verify or validate your work. Always pip install any packages you need to use the python scripts you want to run."
|
||||
414
scripts/hermes-gateway
Executable file
414
scripts/hermes-gateway
Executable file
@@ -0,0 +1,414 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Hermes Gateway - Standalone messaging platform integration.
|
||||
|
||||
This is the proper entry point for running the gateway as a service.
|
||||
NOT tied to the CLI - runs independently.
|
||||
|
||||
Usage:
|
||||
# Run in foreground (for testing)
|
||||
./scripts/hermes-gateway
|
||||
|
||||
# Install as systemd service
|
||||
./scripts/hermes-gateway install
|
||||
|
||||
# Manage the service
|
||||
./scripts/hermes-gateway start
|
||||
./scripts/hermes-gateway stop
|
||||
./scripts/hermes-gateway restart
|
||||
./scripts/hermes-gateway status
|
||||
|
||||
# Uninstall
|
||||
./scripts/hermes-gateway uninstall
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path
|
||||
SCRIPT_DIR = Path(__file__).parent.resolve()
|
||||
PROJECT_DIR = SCRIPT_DIR.parent
|
||||
sys.path.insert(0, str(PROJECT_DIR))
|
||||
|
||||
# Load .env file
|
||||
from dotenv import load_dotenv
|
||||
env_path = PROJECT_DIR / '.env'
|
||||
if env_path.exists():
|
||||
load_dotenv(dotenv_path=env_path)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Service Configuration
|
||||
# =============================================================================
|
||||
|
||||
SERVICE_NAME = "hermes-gateway"
|
||||
SERVICE_DESCRIPTION = "Hermes Agent Gateway - Messaging Platform Integration"
|
||||
|
||||
def get_systemd_unit_path() -> Path:
|
||||
"""Get the path for the systemd user service file."""
|
||||
return Path.home() / ".config" / "systemd" / "user" / f"{SERVICE_NAME}.service"
|
||||
|
||||
def get_launchd_plist_path() -> Path:
|
||||
"""Get the path for the launchd plist file (macOS)."""
|
||||
return Path.home() / "Library" / "LaunchAgents" / f"ai.hermes.gateway.plist"
|
||||
|
||||
def get_python_path() -> str:
|
||||
"""Get the path to the Python interpreter."""
|
||||
# Prefer the venv if it exists
|
||||
venv_python = PROJECT_DIR / "venv" / "bin" / "python"
|
||||
if venv_python.exists():
|
||||
return str(venv_python)
|
||||
return sys.executable
|
||||
|
||||
def get_gateway_script_path() -> str:
|
||||
"""Get the path to this script."""
|
||||
return str(Path(__file__).resolve())
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Systemd Service (Linux)
|
||||
# =============================================================================
|
||||
|
||||
def generate_systemd_unit() -> str:
|
||||
"""Generate the systemd unit file content."""
|
||||
python_path = get_python_path()
|
||||
script_path = get_gateway_script_path()
|
||||
working_dir = str(PROJECT_DIR)
|
||||
|
||||
return f"""[Unit]
|
||||
Description={SERVICE_DESCRIPTION}
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
ExecStart={python_path} {script_path} run
|
||||
WorkingDirectory={working_dir}
|
||||
Restart=on-failure
|
||||
RestartSec=10
|
||||
StandardOutput=journal
|
||||
StandardError=journal
|
||||
|
||||
# Environment (optional - can also use .env file)
|
||||
# Environment="TELEGRAM_BOT_TOKEN=your_token"
|
||||
# Environment="DISCORD_BOT_TOKEN=your_token"
|
||||
|
||||
[Install]
|
||||
WantedBy=default.target
|
||||
"""
|
||||
|
||||
def install_systemd():
|
||||
"""Install the systemd user service."""
|
||||
unit_path = get_systemd_unit_path()
|
||||
unit_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Installing systemd service to: {unit_path}")
|
||||
unit_path.write_text(generate_systemd_unit())
|
||||
|
||||
# Reload systemd
|
||||
subprocess.run(["systemctl", "--user", "daemon-reload"], check=True)
|
||||
|
||||
# Enable the service (start on boot)
|
||||
subprocess.run(["systemctl", "--user", "enable", SERVICE_NAME], check=True)
|
||||
|
||||
print(f"✓ Service installed and enabled")
|
||||
print(f"")
|
||||
print(f"To start the service:")
|
||||
print(f" systemctl --user start {SERVICE_NAME}")
|
||||
print(f"")
|
||||
print(f"To view logs:")
|
||||
print(f" journalctl --user -u {SERVICE_NAME} -f")
|
||||
print(f"")
|
||||
print(f"To enable lingering (keeps service running after logout):")
|
||||
print(f" sudo loginctl enable-linger $USER")
|
||||
|
||||
def uninstall_systemd():
|
||||
"""Uninstall the systemd user service."""
|
||||
unit_path = get_systemd_unit_path()
|
||||
|
||||
# Stop and disable first
|
||||
subprocess.run(["systemctl", "--user", "stop", SERVICE_NAME], check=False)
|
||||
subprocess.run(["systemctl", "--user", "disable", SERVICE_NAME], check=False)
|
||||
|
||||
# Remove the unit file
|
||||
if unit_path.exists():
|
||||
unit_path.unlink()
|
||||
print(f"✓ Removed {unit_path}")
|
||||
|
||||
# Reload systemd
|
||||
subprocess.run(["systemctl", "--user", "daemon-reload"], check=True)
|
||||
print(f"✓ Service uninstalled")
|
||||
|
||||
def systemd_status():
|
||||
"""Show systemd service status."""
|
||||
subprocess.run(["systemctl", "--user", "status", SERVICE_NAME])
|
||||
|
||||
def systemd_start():
|
||||
"""Start the systemd service."""
|
||||
subprocess.run(["systemctl", "--user", "start", SERVICE_NAME], check=True)
|
||||
print(f"✓ Service started")
|
||||
|
||||
def systemd_stop():
|
||||
"""Stop the systemd service."""
|
||||
subprocess.run(["systemctl", "--user", "stop", SERVICE_NAME], check=True)
|
||||
print(f"✓ Service stopped")
|
||||
|
||||
def systemd_restart():
|
||||
"""Restart the systemd service."""
|
||||
subprocess.run(["systemctl", "--user", "restart", SERVICE_NAME], check=True)
|
||||
print(f"✓ Service restarted")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Launchd Service (macOS)
|
||||
# =============================================================================
|
||||
|
||||
def generate_launchd_plist() -> str:
|
||||
"""Generate the launchd plist file content."""
|
||||
python_path = get_python_path()
|
||||
script_path = get_gateway_script_path()
|
||||
working_dir = str(PROJECT_DIR)
|
||||
log_dir = Path.home() / ".hermes" / "logs"
|
||||
|
||||
return f"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>ai.hermes.gateway</string>
|
||||
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>{python_path}</string>
|
||||
<string>{script_path}</string>
|
||||
<string>run</string>
|
||||
</array>
|
||||
|
||||
<key>WorkingDirectory</key>
|
||||
<string>{working_dir}</string>
|
||||
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
|
||||
<key>KeepAlive</key>
|
||||
<dict>
|
||||
<key>SuccessfulExit</key>
|
||||
<false/>
|
||||
</dict>
|
||||
|
||||
<key>StandardOutPath</key>
|
||||
<string>{log_dir}/gateway.log</string>
|
||||
|
||||
<key>StandardErrorPath</key>
|
||||
<string>{log_dir}/gateway.error.log</string>
|
||||
|
||||
<key>EnvironmentVariables</key>
|
||||
<dict>
|
||||
<key>PATH</key>
|
||||
<string>/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin</string>
|
||||
</dict>
|
||||
</dict>
|
||||
</plist>
|
||||
"""
|
||||
|
||||
def install_launchd():
|
||||
"""Install the launchd service (macOS)."""
|
||||
plist_path = get_launchd_plist_path()
|
||||
plist_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Ensure log directory exists
|
||||
log_dir = Path.home() / ".hermes" / "logs"
|
||||
log_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Installing launchd service to: {plist_path}")
|
||||
plist_path.write_text(generate_launchd_plist())
|
||||
|
||||
# Load the service
|
||||
subprocess.run(["launchctl", "load", str(plist_path)], check=True)
|
||||
|
||||
print(f"✓ Service installed and loaded")
|
||||
print(f"")
|
||||
print(f"To view logs:")
|
||||
print(f" tail -f ~/.hermes/logs/gateway.log")
|
||||
print(f"")
|
||||
print(f"To manage the service:")
|
||||
print(f" launchctl start ai.hermes.gateway")
|
||||
print(f" launchctl stop ai.hermes.gateway")
|
||||
|
||||
def uninstall_launchd():
|
||||
"""Uninstall the launchd service (macOS)."""
|
||||
plist_path = get_launchd_plist_path()
|
||||
|
||||
# Unload first
|
||||
subprocess.run(["launchctl", "unload", str(plist_path)], check=False)
|
||||
|
||||
# Remove the plist file
|
||||
if plist_path.exists():
|
||||
plist_path.unlink()
|
||||
print(f"✓ Removed {plist_path}")
|
||||
|
||||
print(f"✓ Service uninstalled")
|
||||
|
||||
def launchd_status():
|
||||
"""Show launchd service status."""
|
||||
subprocess.run(["launchctl", "list", "ai.hermes.gateway"])
|
||||
|
||||
def launchd_start():
|
||||
"""Start the launchd service."""
|
||||
subprocess.run(["launchctl", "start", "ai.hermes.gateway"], check=True)
|
||||
print(f"✓ Service started")
|
||||
|
||||
def launchd_stop():
|
||||
"""Stop the launchd service."""
|
||||
subprocess.run(["launchctl", "stop", "ai.hermes.gateway"], check=True)
|
||||
print(f"✓ Service stopped")
|
||||
|
||||
def launchd_restart():
|
||||
"""Restart the launchd service."""
|
||||
launchd_stop()
|
||||
launchd_start()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Platform Detection
|
||||
# =============================================================================
|
||||
|
||||
def is_linux() -> bool:
|
||||
return sys.platform.startswith('linux')
|
||||
|
||||
def is_macos() -> bool:
|
||||
return sys.platform == 'darwin'
|
||||
|
||||
def is_windows() -> bool:
|
||||
return sys.platform == 'win32'
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Gateway Runner
|
||||
# =============================================================================
|
||||
|
||||
def run_gateway():
|
||||
"""Run the gateway in foreground."""
|
||||
from gateway.run import start_gateway
|
||||
print("Starting Hermes Gateway...")
|
||||
print("Press Ctrl+C to stop.")
|
||||
print()
|
||||
asyncio.run(start_gateway())
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main CLI
|
||||
# =============================================================================
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Hermes Gateway - Messaging Platform Integration",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
# Run in foreground (for testing)
|
||||
./scripts/hermes-gateway run
|
||||
|
||||
# Install as system service
|
||||
./scripts/hermes-gateway install
|
||||
|
||||
# Manage the service
|
||||
./scripts/hermes-gateway start
|
||||
./scripts/hermes-gateway stop
|
||||
./scripts/hermes-gateway restart
|
||||
./scripts/hermes-gateway status
|
||||
|
||||
# Uninstall
|
||||
./scripts/hermes-gateway uninstall
|
||||
|
||||
Configuration:
|
||||
Set environment variables in .env file or system environment:
|
||||
- TELEGRAM_BOT_TOKEN
|
||||
- DISCORD_BOT_TOKEN
|
||||
- WHATSAPP_ENABLED
|
||||
|
||||
Or create ~/.hermes/gateway.json for advanced configuration.
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"command",
|
||||
choices=["run", "install", "uninstall", "start", "stop", "restart", "status"],
|
||||
nargs="?",
|
||||
default="run",
|
||||
help="Command to execute (default: run)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--verbose", "-v",
|
||||
action="store_true",
|
||||
help="Verbose output"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Detect platform and dispatch command
|
||||
if args.command == "run":
|
||||
run_gateway()
|
||||
|
||||
elif args.command == "install":
|
||||
if is_linux():
|
||||
install_systemd()
|
||||
elif is_macos():
|
||||
install_launchd()
|
||||
else:
|
||||
print("Service installation not supported on this platform.")
|
||||
print("Please run manually: ./scripts/hermes-gateway run")
|
||||
sys.exit(1)
|
||||
|
||||
elif args.command == "uninstall":
|
||||
if is_linux():
|
||||
uninstall_systemd()
|
||||
elif is_macos():
|
||||
uninstall_launchd()
|
||||
else:
|
||||
print("Service uninstallation not supported on this platform.")
|
||||
sys.exit(1)
|
||||
|
||||
elif args.command == "start":
|
||||
if is_linux():
|
||||
systemd_start()
|
||||
elif is_macos():
|
||||
launchd_start()
|
||||
else:
|
||||
print("Not supported on this platform.")
|
||||
sys.exit(1)
|
||||
|
||||
elif args.command == "stop":
|
||||
if is_linux():
|
||||
systemd_stop()
|
||||
elif is_macos():
|
||||
launchd_stop()
|
||||
else:
|
||||
print("Not supported on this platform.")
|
||||
sys.exit(1)
|
||||
|
||||
elif args.command == "restart":
|
||||
if is_linux():
|
||||
systemd_restart()
|
||||
elif is_macos():
|
||||
launchd_restart()
|
||||
else:
|
||||
print("Not supported on this platform.")
|
||||
sys.exit(1)
|
||||
|
||||
elif args.command == "status":
|
||||
if is_linux():
|
||||
systemd_status()
|
||||
elif is_macos():
|
||||
launchd_status()
|
||||
else:
|
||||
print("Not supported on this platform.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
519
scripts/install.ps1
Normal file
519
scripts/install.ps1
Normal file
@@ -0,0 +1,519 @@
|
||||
# ============================================================================
|
||||
# Hermes Agent Installer for Windows
|
||||
# ============================================================================
|
||||
# Installation script for Windows (PowerShell).
|
||||
#
|
||||
# Usage:
|
||||
# irm https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.ps1 | iex
|
||||
#
|
||||
# Or download and run with options:
|
||||
# .\install.ps1 -NoVenv -SkipSetup
|
||||
#
|
||||
# ============================================================================
|
||||
|
||||
param(
|
||||
[switch]$NoVenv,
|
||||
[switch]$SkipSetup,
|
||||
[string]$Branch = "main",
|
||||
[string]$HermesHome = "$env:USERPROFILE\.hermes",
|
||||
[string]$InstallDir = "$env:USERPROFILE\.hermes\hermes-agent"
|
||||
)
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
# ============================================================================
|
||||
# Configuration
|
||||
# ============================================================================
|
||||
|
||||
$RepoUrlSsh = "git@github.com:NousResearch/hermes-agent.git"
|
||||
$RepoUrlHttps = "https://github.com/NousResearch/hermes-agent.git"
|
||||
|
||||
# ============================================================================
|
||||
# Helper functions
|
||||
# ============================================================================
|
||||
|
||||
function Write-Banner {
|
||||
Write-Host ""
|
||||
Write-Host "┌─────────────────────────────────────────────────────────┐" -ForegroundColor Magenta
|
||||
Write-Host "│ 🦋 Hermes Agent Installer │" -ForegroundColor Magenta
|
||||
Write-Host "├─────────────────────────────────────────────────────────┤" -ForegroundColor Magenta
|
||||
Write-Host "│ I'm just a butterfly with a lot of tools. │" -ForegroundColor Magenta
|
||||
Write-Host "└─────────────────────────────────────────────────────────┘" -ForegroundColor Magenta
|
||||
Write-Host ""
|
||||
}
|
||||
|
||||
function Write-Info {
|
||||
param([string]$Message)
|
||||
Write-Host "→ $Message" -ForegroundColor Cyan
|
||||
}
|
||||
|
||||
function Write-Success {
|
||||
param([string]$Message)
|
||||
Write-Host "✓ $Message" -ForegroundColor Green
|
||||
}
|
||||
|
||||
function Write-Warning {
|
||||
param([string]$Message)
|
||||
Write-Host "⚠ $Message" -ForegroundColor Yellow
|
||||
}
|
||||
|
||||
function Write-Error {
|
||||
param([string]$Message)
|
||||
Write-Host "✗ $Message" -ForegroundColor Red
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Dependency checks
|
||||
# ============================================================================
|
||||
|
||||
function Test-Python {
|
||||
Write-Info "Checking Python..."
|
||||
|
||||
# Try different python commands
|
||||
$pythonCmds = @("python3", "python", "py -3")
|
||||
|
||||
foreach ($cmd in $pythonCmds) {
|
||||
try {
|
||||
$version = & $cmd.Split()[0] $cmd.Split()[1..99] -c "import sys; print(f'{sys.version_info.major}.{sys.version_info.minor}')" 2>$null
|
||||
if ($version) {
|
||||
$major, $minor = $version.Split('.')
|
||||
if ([int]$major -ge 3 -and [int]$minor -ge 10) {
|
||||
$script:PythonCmd = $cmd
|
||||
Write-Success "Python $version found"
|
||||
return $true
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
# Try next command
|
||||
}
|
||||
}
|
||||
|
||||
Write-Error "Python 3.10+ not found"
|
||||
Write-Info "Please install Python 3.10 or newer from:"
|
||||
Write-Info " https://www.python.org/downloads/"
|
||||
Write-Info ""
|
||||
Write-Info "Make sure to check 'Add Python to PATH' during installation"
|
||||
return $false
|
||||
}
|
||||
|
||||
function Test-Git {
|
||||
Write-Info "Checking Git..."
|
||||
|
||||
if (Get-Command git -ErrorAction SilentlyContinue) {
|
||||
$version = git --version
|
||||
Write-Success "Git found ($version)"
|
||||
return $true
|
||||
}
|
||||
|
||||
Write-Error "Git not found"
|
||||
Write-Info "Please install Git from:"
|
||||
Write-Info " https://git-scm.com/download/win"
|
||||
return $false
|
||||
}
|
||||
|
||||
function Test-Node {
|
||||
Write-Info "Checking Node.js (optional, for browser tools)..."
|
||||
|
||||
if (Get-Command node -ErrorAction SilentlyContinue) {
|
||||
$version = node --version
|
||||
Write-Success "Node.js $version found"
|
||||
$script:HasNode = $true
|
||||
return $true
|
||||
}
|
||||
|
||||
Write-Warning "Node.js not found (browser tools will be limited)"
|
||||
Write-Info "To install Node.js (optional):"
|
||||
Write-Info " https://nodejs.org/en/download/"
|
||||
$script:HasNode = $false
|
||||
return $true # Don't fail - Node is optional
|
||||
}
|
||||
|
||||
function Test-Ripgrep {
|
||||
Write-Info "Checking ripgrep (optional, for faster file search)..."
|
||||
|
||||
if (Get-Command rg -ErrorAction SilentlyContinue) {
|
||||
$version = rg --version | Select-Object -First 1
|
||||
Write-Success "$version found"
|
||||
$script:HasRipgrep = $true
|
||||
return $true
|
||||
}
|
||||
|
||||
Write-Warning "ripgrep not found (file search will use findstr fallback)"
|
||||
|
||||
# Check what package managers are available
|
||||
$hasWinget = Get-Command winget -ErrorAction SilentlyContinue
|
||||
$hasChoco = Get-Command choco -ErrorAction SilentlyContinue
|
||||
$hasScoop = Get-Command scoop -ErrorAction SilentlyContinue
|
||||
|
||||
# Offer to install
|
||||
Write-Host ""
|
||||
$response = Read-Host "Would you like to install ripgrep? (faster search, recommended) [Y/n]"
|
||||
|
||||
if ($response -eq "" -or $response -match "^[Yy]") {
|
||||
Write-Info "Installing ripgrep..."
|
||||
|
||||
if ($hasWinget) {
|
||||
try {
|
||||
winget install BurntSushi.ripgrep.MSVC --silent 2>&1 | Out-Null
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Success "ripgrep installed via winget"
|
||||
$script:HasRipgrep = $true
|
||||
return $true
|
||||
}
|
||||
} catch { }
|
||||
}
|
||||
|
||||
if ($hasChoco) {
|
||||
try {
|
||||
choco install ripgrep -y 2>&1 | Out-Null
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Success "ripgrep installed via chocolatey"
|
||||
$script:HasRipgrep = $true
|
||||
return $true
|
||||
}
|
||||
} catch { }
|
||||
}
|
||||
|
||||
if ($hasScoop) {
|
||||
try {
|
||||
scoop install ripgrep 2>&1 | Out-Null
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Success "ripgrep installed via scoop"
|
||||
$script:HasRipgrep = $true
|
||||
return $true
|
||||
}
|
||||
} catch { }
|
||||
}
|
||||
|
||||
Write-Warning "Auto-install failed. You can install manually:"
|
||||
} else {
|
||||
Write-Info "Skipping ripgrep installation. To install manually:"
|
||||
}
|
||||
|
||||
# Show manual install instructions
|
||||
Write-Info " winget install BurntSushi.ripgrep.MSVC"
|
||||
Write-Info " Or: choco install ripgrep"
|
||||
Write-Info " Or: scoop install ripgrep"
|
||||
Write-Info " Or download from: https://github.com/BurntSushi/ripgrep/releases"
|
||||
|
||||
$script:HasRipgrep = $false
|
||||
return $true # Don't fail - ripgrep is optional
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Installation
|
||||
# ============================================================================
|
||||
|
||||
function Install-Repository {
|
||||
Write-Info "Installing to $InstallDir..."
|
||||
|
||||
if (Test-Path $InstallDir) {
|
||||
if (Test-Path "$InstallDir\.git") {
|
||||
Write-Info "Existing installation found, updating..."
|
||||
Push-Location $InstallDir
|
||||
git fetch origin
|
||||
git checkout $Branch
|
||||
git pull origin $Branch
|
||||
Pop-Location
|
||||
} else {
|
||||
Write-Error "Directory exists but is not a git repository: $InstallDir"
|
||||
Write-Info "Remove it or choose a different directory with -InstallDir"
|
||||
exit 1
|
||||
}
|
||||
} else {
|
||||
# Try SSH first (for private repo access), fall back to HTTPS
|
||||
# Use --recurse-submodules to also clone mini-swe-agent and tinker-atropos
|
||||
Write-Info "Trying SSH clone..."
|
||||
$sshResult = git clone --branch $Branch --recurse-submodules $RepoUrlSsh $InstallDir 2>&1
|
||||
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Success "Cloned via SSH"
|
||||
} else {
|
||||
Write-Info "SSH failed, trying HTTPS..."
|
||||
$httpsResult = git clone --branch $Branch --recurse-submodules $RepoUrlHttps $InstallDir 2>&1
|
||||
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Success "Cloned via HTTPS"
|
||||
} else {
|
||||
Write-Error "Failed to clone repository"
|
||||
Write-Info "For private repo access, ensure your SSH key is added to GitHub:"
|
||||
Write-Info " ssh-add ~/.ssh/id_rsa"
|
||||
Write-Info " ssh -T git@github.com # Test connection"
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Ensure submodules are initialized and updated (for existing installs or if --recurse failed)
|
||||
Write-Info "Initializing submodules (mini-swe-agent, tinker-atropos)..."
|
||||
Push-Location $InstallDir
|
||||
git submodule update --init --recursive
|
||||
Pop-Location
|
||||
Write-Success "Submodules ready"
|
||||
|
||||
Write-Success "Repository ready"
|
||||
}
|
||||
|
||||
function Install-Venv {
|
||||
if ($NoVenv) {
|
||||
Write-Info "Skipping virtual environment (-NoVenv)"
|
||||
return
|
||||
}
|
||||
|
||||
Write-Info "Creating virtual environment..."
|
||||
|
||||
Push-Location $InstallDir
|
||||
|
||||
if (-not (Test-Path "venv")) {
|
||||
& $PythonCmd -m venv venv
|
||||
}
|
||||
|
||||
# Activate
|
||||
& .\venv\Scripts\Activate.ps1
|
||||
|
||||
# Upgrade pip
|
||||
pip install --upgrade pip wheel setuptools | Out-Null
|
||||
|
||||
Pop-Location
|
||||
|
||||
Write-Success "Virtual environment ready"
|
||||
}
|
||||
|
||||
function Install-Dependencies {
|
||||
Write-Info "Installing dependencies..."
|
||||
|
||||
Push-Location $InstallDir
|
||||
|
||||
if (-not $NoVenv) {
|
||||
& .\venv\Scripts\Activate.ps1
|
||||
}
|
||||
|
||||
# Install main package
|
||||
try {
|
||||
pip install -e ".[all]" 2>&1 | Out-Null
|
||||
} catch {
|
||||
pip install -e "." | Out-Null
|
||||
}
|
||||
|
||||
Write-Success "Main package installed"
|
||||
|
||||
# Install submodules
|
||||
Write-Info "Installing mini-swe-agent (terminal tool backend)..."
|
||||
if (Test-Path "mini-swe-agent\pyproject.toml") {
|
||||
try {
|
||||
pip install -e ".\mini-swe-agent" 2>&1 | Out-Null
|
||||
Write-Success "mini-swe-agent installed"
|
||||
} catch {
|
||||
Write-Warning "mini-swe-agent install failed (terminal tools may not work)"
|
||||
}
|
||||
} else {
|
||||
Write-Warning "mini-swe-agent not found (run: git submodule update --init)"
|
||||
}
|
||||
|
||||
Write-Info "Installing tinker-atropos (RL training backend)..."
|
||||
if (Test-Path "tinker-atropos\pyproject.toml") {
|
||||
try {
|
||||
pip install -e ".\tinker-atropos" 2>&1 | Out-Null
|
||||
Write-Success "tinker-atropos installed"
|
||||
} catch {
|
||||
Write-Warning "tinker-atropos install failed (RL tools may not work)"
|
||||
}
|
||||
} else {
|
||||
Write-Warning "tinker-atropos not found (run: git submodule update --init)"
|
||||
}
|
||||
|
||||
Pop-Location
|
||||
|
||||
Write-Success "All dependencies installed"
|
||||
}
|
||||
|
||||
function Set-PathVariable {
|
||||
Write-Info "Setting up PATH..."
|
||||
|
||||
if ($NoVenv) {
|
||||
$binDir = "$InstallDir"
|
||||
} else {
|
||||
$binDir = "$InstallDir\venv\Scripts"
|
||||
}
|
||||
|
||||
# Add to user PATH
|
||||
$currentPath = [Environment]::GetEnvironmentVariable("Path", "User")
|
||||
|
||||
if ($currentPath -notlike "*$binDir*") {
|
||||
[Environment]::SetEnvironmentVariable(
|
||||
"Path",
|
||||
"$binDir;$currentPath",
|
||||
"User"
|
||||
)
|
||||
Write-Success "Added to user PATH"
|
||||
} else {
|
||||
Write-Info "PATH already configured"
|
||||
}
|
||||
|
||||
# Update current session
|
||||
$env:Path = "$binDir;$env:Path"
|
||||
}
|
||||
|
||||
function Copy-ConfigTemplates {
|
||||
Write-Info "Setting up configuration files..."
|
||||
|
||||
# Create ~/.hermes directory structure (config at top level, code in subdir)
|
||||
New-Item -ItemType Directory -Force -Path "$HermesHome\cron" | Out-Null
|
||||
New-Item -ItemType Directory -Force -Path "$HermesHome\sessions" | Out-Null
|
||||
New-Item -ItemType Directory -Force -Path "$HermesHome\logs" | Out-Null
|
||||
|
||||
# Create .env at ~/.hermes/.env (top level, easy to find)
|
||||
$envPath = "$HermesHome\.env"
|
||||
if (-not (Test-Path $envPath)) {
|
||||
$examplePath = "$InstallDir\.env.example"
|
||||
if (Test-Path $examplePath) {
|
||||
Copy-Item $examplePath $envPath
|
||||
Write-Success "Created ~/.hermes/.env from template"
|
||||
} else {
|
||||
# Create empty .env if no example exists
|
||||
New-Item -ItemType File -Force -Path $envPath | Out-Null
|
||||
Write-Success "Created ~/.hermes/.env"
|
||||
}
|
||||
} else {
|
||||
Write-Info "~/.hermes/.env already exists, keeping it"
|
||||
}
|
||||
|
||||
# Create config.yaml at ~/.hermes/config.yaml (top level, easy to find)
|
||||
$configPath = "$HermesHome\config.yaml"
|
||||
if (-not (Test-Path $configPath)) {
|
||||
$examplePath = "$InstallDir\cli-config.yaml.example"
|
||||
if (Test-Path $examplePath) {
|
||||
Copy-Item $examplePath $configPath
|
||||
Write-Success "Created ~/.hermes/config.yaml from template"
|
||||
}
|
||||
} else {
|
||||
Write-Info "~/.hermes/config.yaml already exists, keeping it"
|
||||
}
|
||||
|
||||
Write-Success "Configuration directory ready: ~/.hermes/"
|
||||
}
|
||||
|
||||
function Install-NodeDeps {
|
||||
if (-not $HasNode) {
|
||||
Write-Info "Skipping Node.js dependencies (Node not installed)"
|
||||
return
|
||||
}
|
||||
|
||||
Push-Location $InstallDir
|
||||
|
||||
if (Test-Path "package.json") {
|
||||
Write-Info "Installing Node.js dependencies..."
|
||||
try {
|
||||
npm install --silent 2>&1 | Out-Null
|
||||
Write-Success "Node.js dependencies installed"
|
||||
} catch {
|
||||
Write-Warning "npm install failed (browser tools may not work)"
|
||||
}
|
||||
}
|
||||
|
||||
Pop-Location
|
||||
}
|
||||
|
||||
function Invoke-SetupWizard {
|
||||
if ($SkipSetup) {
|
||||
Write-Info "Skipping setup wizard (-SkipSetup)"
|
||||
return
|
||||
}
|
||||
|
||||
Write-Host ""
|
||||
Write-Info "Starting setup wizard..."
|
||||
Write-Host ""
|
||||
|
||||
Push-Location $InstallDir
|
||||
|
||||
if (-not $NoVenv) {
|
||||
& .\venv\Scripts\Activate.ps1
|
||||
}
|
||||
|
||||
python -m hermes_cli.main setup
|
||||
|
||||
Pop-Location
|
||||
}
|
||||
|
||||
function Write-Completion {
|
||||
Write-Host ""
|
||||
Write-Host "┌─────────────────────────────────────────────────────────┐" -ForegroundColor Green
|
||||
Write-Host "│ ✓ Installation Complete! │" -ForegroundColor Green
|
||||
Write-Host "└─────────────────────────────────────────────────────────┘" -ForegroundColor Green
|
||||
Write-Host ""
|
||||
|
||||
# Show file locations
|
||||
Write-Host "📁 Your files (all in ~/.hermes/):" -ForegroundColor Cyan
|
||||
Write-Host ""
|
||||
Write-Host " Config: " -NoNewline -ForegroundColor Yellow
|
||||
Write-Host "$HermesHome\config.yaml"
|
||||
Write-Host " API Keys: " -NoNewline -ForegroundColor Yellow
|
||||
Write-Host "$HermesHome\.env"
|
||||
Write-Host " Data: " -NoNewline -ForegroundColor Yellow
|
||||
Write-Host "$HermesHome\cron\, sessions\, logs\"
|
||||
Write-Host " Code: " -NoNewline -ForegroundColor Yellow
|
||||
Write-Host "$HermesHome\hermes-agent\"
|
||||
Write-Host ""
|
||||
|
||||
Write-Host "─────────────────────────────────────────────────────────" -ForegroundColor Cyan
|
||||
Write-Host ""
|
||||
Write-Host "🚀 Commands:" -ForegroundColor Cyan
|
||||
Write-Host ""
|
||||
Write-Host " hermes " -NoNewline -ForegroundColor Green
|
||||
Write-Host "Start chatting"
|
||||
Write-Host " hermes setup " -NoNewline -ForegroundColor Green
|
||||
Write-Host "Configure API keys & settings"
|
||||
Write-Host " hermes config " -NoNewline -ForegroundColor Green
|
||||
Write-Host "View/edit configuration"
|
||||
Write-Host " hermes config edit " -NoNewline -ForegroundColor Green
|
||||
Write-Host "Open config in editor"
|
||||
Write-Host " hermes gateway " -NoNewline -ForegroundColor Green
|
||||
Write-Host "Run messaging gateway"
|
||||
Write-Host " hermes update " -NoNewline -ForegroundColor Green
|
||||
Write-Host "Update to latest version"
|
||||
Write-Host ""
|
||||
|
||||
Write-Host "─────────────────────────────────────────────────────────" -ForegroundColor Cyan
|
||||
Write-Host ""
|
||||
Write-Host "⚡ Restart your terminal for PATH changes to take effect" -ForegroundColor Yellow
|
||||
Write-Host ""
|
||||
|
||||
# Show notes about optional tools
|
||||
if (-not $HasNode) {
|
||||
Write-Host "Note: Node.js was not found. Browser automation tools" -ForegroundColor Yellow
|
||||
Write-Host "will have limited functionality." -ForegroundColor Yellow
|
||||
Write-Host ""
|
||||
}
|
||||
|
||||
if (-not $HasRipgrep) {
|
||||
Write-Host "Note: ripgrep (rg) was not found. File search will use" -ForegroundColor Yellow
|
||||
Write-Host "findstr as a fallback. For faster search:" -ForegroundColor Yellow
|
||||
Write-Host " winget install BurntSushi.ripgrep.MSVC" -ForegroundColor Yellow
|
||||
Write-Host ""
|
||||
}
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Main
|
||||
# ============================================================================
|
||||
|
||||
function Main {
|
||||
Write-Banner
|
||||
|
||||
if (-not (Test-Python)) { exit 1 }
|
||||
if (-not (Test-Git)) { exit 1 }
|
||||
Test-Node # Optional, doesn't fail
|
||||
Test-Ripgrep # Optional, doesn't fail
|
||||
|
||||
Install-Repository
|
||||
Install-Venv
|
||||
Install-Dependencies
|
||||
Install-NodeDeps
|
||||
Set-PathVariable
|
||||
Copy-ConfigTemplates
|
||||
Invoke-SetupWizard
|
||||
|
||||
Write-Completion
|
||||
}
|
||||
|
||||
Main
|
||||
692
scripts/install.sh
Executable file
692
scripts/install.sh
Executable file
@@ -0,0 +1,692 @@
|
||||
#!/bin/bash
|
||||
# ============================================================================
|
||||
# Hermes Agent Installer
|
||||
# ============================================================================
|
||||
# Installation script for Linux and macOS.
|
||||
#
|
||||
# Usage:
|
||||
# curl -fsSL https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.sh | bash
|
||||
#
|
||||
# Or with options:
|
||||
# curl -fsSL ... | bash -s -- --no-venv --skip-setup
|
||||
#
|
||||
# ============================================================================
|
||||
|
||||
set -e
|
||||
|
||||
# Colors
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[0;33m'
|
||||
BLUE='\033[0;34m'
|
||||
MAGENTA='\033[0;35m'
|
||||
CYAN='\033[0;36m'
|
||||
NC='\033[0m' # No Color
|
||||
BOLD='\033[1m'
|
||||
|
||||
# Configuration
|
||||
REPO_URL_SSH="git@github.com:NousResearch/hermes-agent.git"
|
||||
REPO_URL_HTTPS="https://github.com/NousResearch/hermes-agent.git"
|
||||
HERMES_HOME="$HOME/.hermes"
|
||||
INSTALL_DIR="${HERMES_INSTALL_DIR:-$HERMES_HOME/hermes-agent}"
|
||||
PYTHON_MIN_VERSION="3.10"
|
||||
|
||||
# Options
|
||||
USE_VENV=true
|
||||
RUN_SETUP=true
|
||||
BRANCH="main"
|
||||
|
||||
# Parse arguments
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case $1 in
|
||||
--no-venv)
|
||||
USE_VENV=false
|
||||
shift
|
||||
;;
|
||||
--skip-setup)
|
||||
RUN_SETUP=false
|
||||
shift
|
||||
;;
|
||||
--branch)
|
||||
BRANCH="$2"
|
||||
shift 2
|
||||
;;
|
||||
--dir)
|
||||
INSTALL_DIR="$2"
|
||||
shift 2
|
||||
;;
|
||||
-h|--help)
|
||||
echo "Hermes Agent Installer"
|
||||
echo ""
|
||||
echo "Usage: install.sh [OPTIONS]"
|
||||
echo ""
|
||||
echo "Options:"
|
||||
echo " --no-venv Don't create virtual environment"
|
||||
echo " --skip-setup Skip interactive setup wizard"
|
||||
echo " --branch NAME Git branch to install (default: main)"
|
||||
echo " --dir PATH Installation directory (default: ~/.hermes-agent)"
|
||||
echo " -h, --help Show this help"
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "Unknown option: $1"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# ============================================================================
|
||||
# Helper functions
|
||||
# ============================================================================
|
||||
|
||||
print_banner() {
|
||||
echo ""
|
||||
echo -e "${MAGENTA}${BOLD}"
|
||||
echo "┌─────────────────────────────────────────────────────────┐"
|
||||
echo "│ 🦋 Hermes Agent Installer │"
|
||||
echo "├─────────────────────────────────────────────────────────┤"
|
||||
echo "│ I'm just a butterfly with a lot of tools. │"
|
||||
echo "└─────────────────────────────────────────────────────────┘"
|
||||
echo -e "${NC}"
|
||||
}
|
||||
|
||||
log_info() {
|
||||
echo -e "${CYAN}→${NC} $1"
|
||||
}
|
||||
|
||||
log_success() {
|
||||
echo -e "${GREEN}✓${NC} $1"
|
||||
}
|
||||
|
||||
log_warn() {
|
||||
echo -e "${YELLOW}⚠${NC} $1"
|
||||
}
|
||||
|
||||
log_error() {
|
||||
echo -e "${RED}✗${NC} $1"
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# System detection
|
||||
# ============================================================================
|
||||
|
||||
detect_os() {
|
||||
case "$(uname -s)" in
|
||||
Linux*)
|
||||
OS="linux"
|
||||
if [ -f /etc/os-release ]; then
|
||||
. /etc/os-release
|
||||
DISTRO="$ID"
|
||||
else
|
||||
DISTRO="unknown"
|
||||
fi
|
||||
;;
|
||||
Darwin*)
|
||||
OS="macos"
|
||||
DISTRO="macos"
|
||||
;;
|
||||
CYGWIN*|MINGW*|MSYS*)
|
||||
OS="windows"
|
||||
DISTRO="windows"
|
||||
log_error "Windows detected. Please use the PowerShell installer:"
|
||||
log_info " irm https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/install.ps1 | iex"
|
||||
exit 1
|
||||
;;
|
||||
*)
|
||||
OS="unknown"
|
||||
DISTRO="unknown"
|
||||
log_warn "Unknown operating system"
|
||||
;;
|
||||
esac
|
||||
|
||||
log_success "Detected: $OS ($DISTRO)"
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Dependency checks
|
||||
# ============================================================================
|
||||
|
||||
check_python() {
|
||||
log_info "Checking Python..."
|
||||
|
||||
# Try different python commands
|
||||
for cmd in python3.12 python3.11 python3.10 python3 python; do
|
||||
if command -v $cmd &> /dev/null; then
|
||||
PYTHON_CMD=$cmd
|
||||
PYTHON_VERSION=$($cmd -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
|
||||
|
||||
# Check version
|
||||
if python3 -c "import sys; exit(0 if sys.version_info >= (3, 10) else 1)" 2>/dev/null; then
|
||||
log_success "Python $PYTHON_VERSION found"
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
log_error "Python 3.10+ not found"
|
||||
log_info "Please install Python 3.10 or newer:"
|
||||
|
||||
case "$OS" in
|
||||
linux)
|
||||
case "$DISTRO" in
|
||||
ubuntu|debian)
|
||||
log_info " sudo apt update && sudo apt install python3.11 python3.11-venv"
|
||||
;;
|
||||
fedora)
|
||||
log_info " sudo dnf install python3.11"
|
||||
;;
|
||||
arch)
|
||||
log_info " sudo pacman -S python"
|
||||
;;
|
||||
*)
|
||||
log_info " Use your package manager to install Python 3.10+"
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
macos)
|
||||
log_info " brew install python@3.11"
|
||||
log_info " Or download from https://www.python.org/downloads/"
|
||||
;;
|
||||
esac
|
||||
|
||||
exit 1
|
||||
}
|
||||
|
||||
check_git() {
|
||||
log_info "Checking Git..."
|
||||
|
||||
if command -v git &> /dev/null; then
|
||||
GIT_VERSION=$(git --version | awk '{print $3}')
|
||||
log_success "Git $GIT_VERSION found"
|
||||
return 0
|
||||
fi
|
||||
|
||||
log_error "Git not found"
|
||||
log_info "Please install Git:"
|
||||
|
||||
case "$OS" in
|
||||
linux)
|
||||
case "$DISTRO" in
|
||||
ubuntu|debian)
|
||||
log_info " sudo apt update && sudo apt install git"
|
||||
;;
|
||||
fedora)
|
||||
log_info " sudo dnf install git"
|
||||
;;
|
||||
arch)
|
||||
log_info " sudo pacman -S git"
|
||||
;;
|
||||
*)
|
||||
log_info " Use your package manager to install git"
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
macos)
|
||||
log_info " xcode-select --install"
|
||||
log_info " Or: brew install git"
|
||||
;;
|
||||
esac
|
||||
|
||||
exit 1
|
||||
}
|
||||
|
||||
check_node() {
|
||||
log_info "Checking Node.js (optional, for browser tools)..."
|
||||
|
||||
if command -v node &> /dev/null; then
|
||||
NODE_VERSION=$(node --version)
|
||||
log_success "Node.js $NODE_VERSION found"
|
||||
HAS_NODE=true
|
||||
return 0
|
||||
fi
|
||||
|
||||
log_warn "Node.js not found (browser tools will be limited)"
|
||||
log_info "To install Node.js (optional):"
|
||||
|
||||
case "$OS" in
|
||||
linux)
|
||||
case "$DISTRO" in
|
||||
ubuntu|debian)
|
||||
log_info " curl -fsSL https://deb.nodesource.com/setup_20.x | sudo -E bash -"
|
||||
log_info " sudo apt install -y nodejs"
|
||||
;;
|
||||
fedora)
|
||||
log_info " sudo dnf install nodejs"
|
||||
;;
|
||||
arch)
|
||||
log_info " sudo pacman -S nodejs npm"
|
||||
;;
|
||||
*)
|
||||
log_info " https://nodejs.org/en/download/"
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
macos)
|
||||
log_info " brew install node"
|
||||
log_info " Or: https://nodejs.org/en/download/"
|
||||
;;
|
||||
esac
|
||||
|
||||
HAS_NODE=false
|
||||
# Don't exit - Node is optional
|
||||
}
|
||||
|
||||
check_ripgrep() {
|
||||
log_info "Checking ripgrep (optional, for faster file search)..."
|
||||
|
||||
if command -v rg &> /dev/null; then
|
||||
RG_VERSION=$(rg --version | head -1)
|
||||
log_success "$RG_VERSION found"
|
||||
HAS_RIPGREP=true
|
||||
return 0
|
||||
fi
|
||||
|
||||
log_warn "ripgrep not found (file search will use grep fallback)"
|
||||
|
||||
# Offer to install
|
||||
echo ""
|
||||
read -p "Would you like to install ripgrep? (faster search, recommended) [Y/n] " -n 1 -r
|
||||
echo
|
||||
|
||||
if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then
|
||||
log_info "Installing ripgrep..."
|
||||
|
||||
# Check if we can use sudo
|
||||
CAN_SUDO=false
|
||||
if command -v sudo &> /dev/null; then
|
||||
# Check if user has sudo access (without actually running sudo)
|
||||
if sudo -n true 2>/dev/null || sudo -v 2>/dev/null; then
|
||||
CAN_SUDO=true
|
||||
fi
|
||||
fi
|
||||
|
||||
case "$OS" in
|
||||
linux)
|
||||
if [ "$CAN_SUDO" = true ]; then
|
||||
case "$DISTRO" in
|
||||
ubuntu|debian)
|
||||
if sudo apt install -y ripgrep 2>/dev/null; then
|
||||
log_success "ripgrep installed"
|
||||
HAS_RIPGREP=true
|
||||
return 0
|
||||
fi
|
||||
;;
|
||||
fedora)
|
||||
if sudo dnf install -y ripgrep 2>/dev/null; then
|
||||
log_success "ripgrep installed"
|
||||
HAS_RIPGREP=true
|
||||
return 0
|
||||
fi
|
||||
;;
|
||||
arch)
|
||||
if sudo pacman -S --noconfirm ripgrep 2>/dev/null; then
|
||||
log_success "ripgrep installed"
|
||||
HAS_RIPGREP=true
|
||||
return 0
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
else
|
||||
log_warn "sudo not available - cannot auto-install system packages"
|
||||
# Try cargo as fallback if available
|
||||
if command -v cargo &> /dev/null; then
|
||||
log_info "Trying cargo install (no sudo required)..."
|
||||
if cargo install ripgrep 2>/dev/null; then
|
||||
log_success "ripgrep installed via cargo"
|
||||
HAS_RIPGREP=true
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
;;
|
||||
macos)
|
||||
if command -v brew &> /dev/null; then
|
||||
if brew install ripgrep 2>/dev/null; then
|
||||
log_success "ripgrep installed"
|
||||
HAS_RIPGREP=true
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
log_warn "Auto-install failed. You can install manually later:"
|
||||
else
|
||||
log_info "Skipping ripgrep installation. To install manually:"
|
||||
fi
|
||||
|
||||
# Show manual install instructions
|
||||
case "$OS" in
|
||||
linux)
|
||||
case "$DISTRO" in
|
||||
ubuntu|debian)
|
||||
log_info " sudo apt install ripgrep"
|
||||
;;
|
||||
fedora)
|
||||
log_info " sudo dnf install ripgrep"
|
||||
;;
|
||||
arch)
|
||||
log_info " sudo pacman -S ripgrep"
|
||||
;;
|
||||
*)
|
||||
log_info " https://github.com/BurntSushi/ripgrep#installation"
|
||||
;;
|
||||
esac
|
||||
# Show cargo alternative for users without sudo
|
||||
if command -v cargo &> /dev/null; then
|
||||
log_info " Or without sudo: cargo install ripgrep"
|
||||
fi
|
||||
;;
|
||||
macos)
|
||||
log_info " brew install ripgrep"
|
||||
;;
|
||||
esac
|
||||
|
||||
HAS_RIPGREP=false
|
||||
# Don't exit - ripgrep is optional (grep fallback exists)
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Installation
|
||||
# ============================================================================
|
||||
|
||||
clone_repo() {
|
||||
log_info "Installing to $INSTALL_DIR..."
|
||||
|
||||
if [ -d "$INSTALL_DIR" ]; then
|
||||
if [ -d "$INSTALL_DIR/.git" ]; then
|
||||
log_info "Existing installation found, updating..."
|
||||
cd "$INSTALL_DIR"
|
||||
git fetch origin
|
||||
git checkout "$BRANCH"
|
||||
git pull origin "$BRANCH"
|
||||
else
|
||||
log_error "Directory exists but is not a git repository: $INSTALL_DIR"
|
||||
log_info "Remove it or choose a different directory with --dir"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
# Try SSH first (for private repo access), fall back to HTTPS
|
||||
# Use --recurse-submodules to also clone mini-swe-agent and tinker-atropos
|
||||
log_info "Trying SSH clone..."
|
||||
if git clone --branch "$BRANCH" --recurse-submodules "$REPO_URL_SSH" "$INSTALL_DIR" 2>/dev/null; then
|
||||
log_success "Cloned via SSH"
|
||||
else
|
||||
log_info "SSH failed, trying HTTPS..."
|
||||
if git clone --branch "$BRANCH" --recurse-submodules "$REPO_URL_HTTPS" "$INSTALL_DIR"; then
|
||||
log_success "Cloned via HTTPS"
|
||||
else
|
||||
log_error "Failed to clone repository"
|
||||
log_info "For private repo access, ensure your SSH key is added to GitHub:"
|
||||
log_info " ssh-add ~/.ssh/id_rsa"
|
||||
log_info " ssh -T git@github.com # Test connection"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
cd "$INSTALL_DIR"
|
||||
|
||||
# Ensure submodules are initialized and updated (for existing installs or if --recurse failed)
|
||||
log_info "Initializing submodules (mini-swe-agent, tinker-atropos)..."
|
||||
git submodule update --init --recursive
|
||||
log_success "Submodules ready"
|
||||
|
||||
log_success "Repository ready"
|
||||
}
|
||||
|
||||
setup_venv() {
|
||||
if [ "$USE_VENV" = false ]; then
|
||||
log_info "Skipping virtual environment (--no-venv)"
|
||||
return 0
|
||||
fi
|
||||
|
||||
log_info "Creating virtual environment..."
|
||||
|
||||
if [ -d "venv" ]; then
|
||||
log_info "Virtual environment already exists"
|
||||
else
|
||||
$PYTHON_CMD -m venv venv
|
||||
fi
|
||||
|
||||
# Activate
|
||||
source venv/bin/activate
|
||||
|
||||
# Upgrade pip
|
||||
pip install --upgrade pip wheel setuptools > /dev/null
|
||||
|
||||
log_success "Virtual environment ready"
|
||||
}
|
||||
|
||||
install_deps() {
|
||||
log_info "Installing dependencies..."
|
||||
|
||||
if [ "$USE_VENV" = true ]; then
|
||||
source venv/bin/activate
|
||||
fi
|
||||
|
||||
# Install the main package in editable mode with all extras
|
||||
pip install -e ".[all]" > /dev/null 2>&1 || pip install -e "." > /dev/null
|
||||
|
||||
log_success "Main package installed"
|
||||
|
||||
# Install submodules
|
||||
log_info "Installing mini-swe-agent (terminal tool backend)..."
|
||||
if [ -d "mini-swe-agent" ] && [ -f "mini-swe-agent/pyproject.toml" ]; then
|
||||
pip install -e "./mini-swe-agent" > /dev/null 2>&1 || log_warn "mini-swe-agent install failed (terminal tools may not work)"
|
||||
log_success "mini-swe-agent installed"
|
||||
else
|
||||
log_warn "mini-swe-agent not found (run: git submodule update --init)"
|
||||
fi
|
||||
|
||||
log_info "Installing tinker-atropos (RL training backend)..."
|
||||
if [ -d "tinker-atropos" ] && [ -f "tinker-atropos/pyproject.toml" ]; then
|
||||
pip install -e "./tinker-atropos" > /dev/null 2>&1 || log_warn "tinker-atropos install failed (RL tools may not work)"
|
||||
log_success "tinker-atropos installed"
|
||||
else
|
||||
log_warn "tinker-atropos not found (run: git submodule update --init)"
|
||||
fi
|
||||
|
||||
log_success "All dependencies installed"
|
||||
}
|
||||
|
||||
setup_path() {
|
||||
log_info "Setting up PATH..."
|
||||
|
||||
# Determine the bin directory
|
||||
if [ "$USE_VENV" = true ]; then
|
||||
BIN_DIR="$INSTALL_DIR/venv/bin"
|
||||
else
|
||||
BIN_DIR="$HOME/.local/bin"
|
||||
mkdir -p "$BIN_DIR"
|
||||
|
||||
# Create a wrapper script
|
||||
cat > "$BIN_DIR/hermes" << EOF
|
||||
#!/bin/bash
|
||||
cd "$INSTALL_DIR"
|
||||
exec python -m hermes_cli.main "\$@"
|
||||
EOF
|
||||
chmod +x "$BIN_DIR/hermes"
|
||||
fi
|
||||
|
||||
# Add to PATH in shell config
|
||||
SHELL_CONFIG=""
|
||||
if [ -n "$BASH_VERSION" ]; then
|
||||
if [ -f "$HOME/.bashrc" ]; then
|
||||
SHELL_CONFIG="$HOME/.bashrc"
|
||||
elif [ -f "$HOME/.bash_profile" ]; then
|
||||
SHELL_CONFIG="$HOME/.bash_profile"
|
||||
fi
|
||||
elif [ -n "$ZSH_VERSION" ] || [ -f "$HOME/.zshrc" ]; then
|
||||
SHELL_CONFIG="$HOME/.zshrc"
|
||||
fi
|
||||
|
||||
PATH_LINE="export PATH=\"$BIN_DIR:\$PATH\""
|
||||
|
||||
if [ -n "$SHELL_CONFIG" ]; then
|
||||
if ! grep -q "hermes-agent" "$SHELL_CONFIG" 2>/dev/null; then
|
||||
echo "" >> "$SHELL_CONFIG"
|
||||
echo "# Hermes Agent" >> "$SHELL_CONFIG"
|
||||
echo "$PATH_LINE" >> "$SHELL_CONFIG"
|
||||
log_success "Added to $SHELL_CONFIG"
|
||||
else
|
||||
log_info "PATH already configured in $SHELL_CONFIG"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Also export for current session
|
||||
export PATH="$BIN_DIR:$PATH"
|
||||
|
||||
log_success "PATH configured"
|
||||
}
|
||||
|
||||
copy_config_templates() {
|
||||
log_info "Setting up configuration files..."
|
||||
|
||||
# Create ~/.hermes directory structure (config at top level, code in subdir)
|
||||
mkdir -p "$HERMES_HOME/cron"
|
||||
mkdir -p "$HERMES_HOME/sessions"
|
||||
mkdir -p "$HERMES_HOME/logs"
|
||||
|
||||
# Create .env at ~/.hermes/.env (top level, easy to find)
|
||||
if [ ! -f "$HERMES_HOME/.env" ]; then
|
||||
if [ -f "$INSTALL_DIR/.env.example" ]; then
|
||||
cp "$INSTALL_DIR/.env.example" "$HERMES_HOME/.env"
|
||||
log_success "Created ~/.hermes/.env from template"
|
||||
else
|
||||
# Create empty .env if no example exists
|
||||
touch "$HERMES_HOME/.env"
|
||||
log_success "Created ~/.hermes/.env"
|
||||
fi
|
||||
else
|
||||
log_info "~/.hermes/.env already exists, keeping it"
|
||||
fi
|
||||
|
||||
# Create config.yaml at ~/.hermes/config.yaml (top level, easy to find)
|
||||
if [ ! -f "$HERMES_HOME/config.yaml" ]; then
|
||||
if [ -f "$INSTALL_DIR/cli-config.yaml.example" ]; then
|
||||
cp "$INSTALL_DIR/cli-config.yaml.example" "$HERMES_HOME/config.yaml"
|
||||
log_success "Created ~/.hermes/config.yaml from template"
|
||||
fi
|
||||
else
|
||||
log_info "~/.hermes/config.yaml already exists, keeping it"
|
||||
fi
|
||||
|
||||
log_success "Configuration directory ready: ~/.hermes/"
|
||||
}
|
||||
|
||||
install_node_deps() {
|
||||
if [ "$HAS_NODE" = false ]; then
|
||||
log_info "Skipping Node.js dependencies (Node not installed)"
|
||||
return 0
|
||||
fi
|
||||
|
||||
if [ -f "$INSTALL_DIR/package.json" ]; then
|
||||
log_info "Installing Node.js dependencies..."
|
||||
cd "$INSTALL_DIR"
|
||||
npm install --silent 2>/dev/null || {
|
||||
log_warn "npm install failed (browser tools may not work)"
|
||||
return 0
|
||||
}
|
||||
log_success "Node.js dependencies installed"
|
||||
fi
|
||||
}
|
||||
|
||||
run_setup_wizard() {
|
||||
if [ "$RUN_SETUP" = false ]; then
|
||||
log_info "Skipping setup wizard (--skip-setup)"
|
||||
return 0
|
||||
fi
|
||||
|
||||
echo ""
|
||||
log_info "Starting setup wizard..."
|
||||
echo ""
|
||||
|
||||
if [ "$USE_VENV" = true ]; then
|
||||
source "$INSTALL_DIR/venv/bin/activate"
|
||||
fi
|
||||
|
||||
cd "$INSTALL_DIR"
|
||||
python -m hermes_cli.main setup
|
||||
}
|
||||
|
||||
print_success() {
|
||||
echo ""
|
||||
echo -e "${GREEN}${BOLD}"
|
||||
echo "┌─────────────────────────────────────────────────────────┐"
|
||||
echo "│ ✓ Installation Complete! │"
|
||||
echo "└─────────────────────────────────────────────────────────┘"
|
||||
echo -e "${NC}"
|
||||
echo ""
|
||||
|
||||
# Show file locations
|
||||
echo -e "${CYAN}${BOLD}📁 Your files (all in ~/.hermes/):${NC}"
|
||||
echo ""
|
||||
echo -e " ${YELLOW}Config:${NC} ~/.hermes/config.yaml"
|
||||
echo -e " ${YELLOW}API Keys:${NC} ~/.hermes/.env"
|
||||
echo -e " ${YELLOW}Data:${NC} ~/.hermes/cron/, sessions/, logs/"
|
||||
echo -e " ${YELLOW}Code:${NC} ~/.hermes/hermes-agent/"
|
||||
echo ""
|
||||
|
||||
echo -e "${CYAN}─────────────────────────────────────────────────────────${NC}"
|
||||
echo ""
|
||||
echo -e "${CYAN}${BOLD}🚀 Commands:${NC}"
|
||||
echo ""
|
||||
echo -e " ${GREEN}hermes${NC} Start chatting"
|
||||
echo -e " ${GREEN}hermes setup${NC} Configure API keys & settings"
|
||||
echo -e " ${GREEN}hermes config${NC} View/edit configuration"
|
||||
echo -e " ${GREEN}hermes config edit${NC} Open config in editor"
|
||||
echo -e " ${GREEN}hermes gateway${NC} Run messaging gateway"
|
||||
echo -e " ${GREEN}hermes update${NC} Update to latest version"
|
||||
echo ""
|
||||
|
||||
echo -e "${CYAN}─────────────────────────────────────────────────────────${NC}"
|
||||
echo ""
|
||||
echo -e "${YELLOW}⚡ Reload your shell to use 'hermes' command:${NC}"
|
||||
echo ""
|
||||
echo " source ~/.bashrc # or ~/.zshrc"
|
||||
echo ""
|
||||
|
||||
# Show Node.js warning if not installed
|
||||
if [ "$HAS_NODE" = false ]; then
|
||||
echo -e "${YELLOW}"
|
||||
echo "Note: Node.js was not found. Browser automation tools"
|
||||
echo "will have limited functionality. Install Node.js later"
|
||||
echo "if you need full browser support."
|
||||
echo -e "${NC}"
|
||||
fi
|
||||
|
||||
# Show ripgrep note if not installed
|
||||
if [ "$HAS_RIPGREP" = false ]; then
|
||||
echo -e "${YELLOW}"
|
||||
echo "Note: ripgrep (rg) was not found. File search will use"
|
||||
echo "grep as a fallback. For faster search in large codebases,"
|
||||
echo "install ripgrep: sudo apt install ripgrep (or brew install ripgrep)"
|
||||
echo -e "${NC}"
|
||||
fi
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Main
|
||||
# ============================================================================
|
||||
|
||||
main() {
|
||||
print_banner
|
||||
|
||||
detect_os
|
||||
check_python
|
||||
check_git
|
||||
check_node
|
||||
check_ripgrep
|
||||
|
||||
clone_repo
|
||||
setup_venv
|
||||
install_deps
|
||||
install_node_deps
|
||||
setup_path
|
||||
copy_config_templates
|
||||
run_setup_wizard
|
||||
|
||||
print_success
|
||||
}
|
||||
|
||||
main
|
||||
411
scripts/sample_and_compress.py
Normal file
411
scripts/sample_and_compress.py
Normal file
@@ -0,0 +1,411 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Sample and Compress HuggingFace Datasets
|
||||
|
||||
Downloads trajectories from multiple HuggingFace datasets, randomly samples them,
|
||||
and runs trajectory compression to fit within a target token budget.
|
||||
|
||||
Usage:
|
||||
python scripts/sample_and_compress.py
|
||||
|
||||
# Custom sample size
|
||||
python scripts/sample_and_compress.py --total_samples=5000
|
||||
|
||||
# Custom output name
|
||||
python scripts/sample_and_compress.py --output_name=compressed_16k
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Tuple
|
||||
import fire
|
||||
|
||||
# Load environment variables
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# Default datasets to sample from
|
||||
DEFAULT_DATASETS = [
|
||||
"NousResearch/swe-terminus-agent-glm-kimi-minimax",
|
||||
"NousResearch/hermes-agent-megascience-sft1",
|
||||
"NousResearch/Hermes-Agent-Thinking-GLM-4.7-SFT2",
|
||||
"NousResearch/Hermes-Agent-Thinking-GLM-4.7-SFT1",
|
||||
"NousResearch/terminal-tasks-glm-hermes-agent"
|
||||
]
|
||||
|
||||
|
||||
def load_dataset_from_hf(dataset_name: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Load a dataset from HuggingFace.
|
||||
|
||||
Args:
|
||||
dataset_name: HuggingFace dataset name (e.g., "NousResearch/dataset-name")
|
||||
|
||||
Returns:
|
||||
List of trajectory entries
|
||||
"""
|
||||
from datasets import load_dataset
|
||||
|
||||
print(f" Loading {dataset_name}...")
|
||||
|
||||
try:
|
||||
# Try loading with default config
|
||||
ds = load_dataset(dataset_name, split="train")
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Error loading {dataset_name}: {e}")
|
||||
return []
|
||||
|
||||
# Convert to list of dicts
|
||||
entries = []
|
||||
for item in ds:
|
||||
# Handle different possible formats
|
||||
if "conversations" in item:
|
||||
entries.append({"conversations": item["conversations"]})
|
||||
elif "messages" in item:
|
||||
# Convert messages format to conversations format if needed
|
||||
entries.append({"conversations": item["messages"]})
|
||||
else:
|
||||
# Assume the whole item is the entry
|
||||
entries.append(dict(item))
|
||||
|
||||
print(f" ✅ Loaded {len(entries):,} entries from {dataset_name}")
|
||||
return entries
|
||||
|
||||
|
||||
# Global tokenizer for multiprocessing (set in worker init)
|
||||
_TOKENIZER = None
|
||||
|
||||
|
||||
def _init_tokenizer_worker(tokenizer_name: str):
|
||||
"""Initialize tokenizer in worker process."""
|
||||
global _TOKENIZER
|
||||
from transformers import AutoTokenizer
|
||||
_TOKENIZER = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
|
||||
|
||||
|
||||
def _count_tokens_for_entry(entry: Dict) -> Tuple[Dict, int]:
|
||||
"""
|
||||
Count tokens for a single entry (used in parallel processing).
|
||||
|
||||
Args:
|
||||
entry: Trajectory entry with 'conversations' field
|
||||
|
||||
Returns:
|
||||
Tuple of (entry, token_count)
|
||||
"""
|
||||
global _TOKENIZER
|
||||
|
||||
conversations = entry.get("conversations", [])
|
||||
if not conversations:
|
||||
return entry, 0
|
||||
|
||||
total = 0
|
||||
for turn in conversations:
|
||||
value = turn.get("value", "")
|
||||
if value:
|
||||
try:
|
||||
total += len(_TOKENIZER.encode(value))
|
||||
except:
|
||||
# Fallback to character estimate
|
||||
total += len(value) // 4
|
||||
|
||||
return entry, total
|
||||
|
||||
|
||||
def sample_from_datasets(
|
||||
datasets: List[str],
|
||||
total_samples: int,
|
||||
min_tokens: int = 16000,
|
||||
tokenizer_name: str = "moonshotai/Kimi-K2-Thinking",
|
||||
seed: int = 42,
|
||||
num_proc: int = 8
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Load all datasets, filter by token count, then randomly sample from combined pool.
|
||||
|
||||
Args:
|
||||
datasets: List of HuggingFace dataset names
|
||||
total_samples: Total number of samples to collect
|
||||
min_tokens: Minimum token count to include (only sample trajectories >= this)
|
||||
tokenizer_name: HuggingFace tokenizer for counting tokens
|
||||
seed: Random seed for reproducibility
|
||||
num_proc: Number of parallel processes for tokenization
|
||||
|
||||
Returns:
|
||||
List of sampled trajectory entries
|
||||
"""
|
||||
from multiprocessing import Pool
|
||||
from functools import partial
|
||||
|
||||
random.seed(seed)
|
||||
|
||||
print(f"\n📥 Loading {len(datasets)} datasets...")
|
||||
print(f" Minimum tokens: {min_tokens:,} (filtering smaller trajectories)")
|
||||
print(f" Parallel workers: {num_proc}")
|
||||
print()
|
||||
|
||||
# Load ALL entries from all datasets into one pool
|
||||
all_entries = []
|
||||
|
||||
for dataset_name in datasets:
|
||||
entries = load_dataset_from_hf(dataset_name)
|
||||
|
||||
if not entries:
|
||||
print(f" ⚠️ Skipping {dataset_name} (no entries loaded)")
|
||||
continue
|
||||
|
||||
# Add source metadata to each entry
|
||||
for entry in entries:
|
||||
entry["_source_dataset"] = dataset_name
|
||||
|
||||
all_entries.extend(entries)
|
||||
|
||||
print(f"\n📊 Total entries loaded: {len(all_entries):,}")
|
||||
|
||||
# Filter by token count using parallel processing
|
||||
print(f"\n🔍 Filtering trajectories with >= {min_tokens:,} tokens (using {num_proc} workers)...")
|
||||
|
||||
filtered_entries = []
|
||||
token_counts = []
|
||||
|
||||
# Use multiprocessing for token counting
|
||||
with Pool(
|
||||
processes=num_proc,
|
||||
initializer=_init_tokenizer_worker,
|
||||
initargs=(tokenizer_name,)
|
||||
) as pool:
|
||||
# Process in chunks and show progress
|
||||
chunk_size = 1000
|
||||
processed = 0
|
||||
|
||||
for result in pool.imap_unordered(_count_tokens_for_entry, all_entries, chunksize=100):
|
||||
entry, token_count = result
|
||||
processed += 1
|
||||
|
||||
if processed % chunk_size == 0:
|
||||
print(f" Processed {processed:,}/{len(all_entries):,}...", end="\r")
|
||||
|
||||
if token_count >= min_tokens:
|
||||
entry["_original_tokens"] = token_count
|
||||
filtered_entries.append(entry)
|
||||
token_counts.append(token_count)
|
||||
|
||||
print(f"\n ✅ Found {len(filtered_entries):,} trajectories >= {min_tokens:,} tokens")
|
||||
|
||||
if token_counts:
|
||||
avg_tokens = sum(token_counts) / len(token_counts)
|
||||
print(f" 📈 Token stats: min={min(token_counts):,}, max={max(token_counts):,}, avg={avg_tokens:,.0f}")
|
||||
|
||||
# Random sample from the filtered pool
|
||||
if len(filtered_entries) <= total_samples:
|
||||
print(f"\n⚠️ Only {len(filtered_entries):,} trajectories available, using all of them")
|
||||
sampled = filtered_entries
|
||||
else:
|
||||
sampled = random.sample(filtered_entries, total_samples)
|
||||
print(f"\n✅ Randomly sampled {len(sampled):,} trajectories from pool of {len(filtered_entries):,}")
|
||||
|
||||
# Show source distribution
|
||||
source_counts = {}
|
||||
for entry in sampled:
|
||||
source = entry.get("_source_dataset", "unknown").split("/")[-1]
|
||||
source_counts[source] = source_counts.get(source, 0) + 1
|
||||
|
||||
print(f"\n📌 Sample distribution by source:")
|
||||
for source, count in sorted(source_counts.items()):
|
||||
print(f" {source}: {count:,}")
|
||||
|
||||
# Shuffle
|
||||
random.shuffle(sampled)
|
||||
|
||||
return sampled
|
||||
|
||||
|
||||
def save_samples_for_compression(
|
||||
samples: List[Dict[str, Any]],
|
||||
output_dir: Path,
|
||||
batch_size: int = 100
|
||||
):
|
||||
"""
|
||||
Save samples to JSONL files for trajectory compression.
|
||||
|
||||
Args:
|
||||
samples: List of trajectory entries
|
||||
output_dir: Directory to save JSONL files
|
||||
batch_size: Number of entries per file
|
||||
"""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Split into batches
|
||||
num_batches = (len(samples) + batch_size - 1) // batch_size
|
||||
|
||||
print(f"\n💾 Saving {len(samples)} samples to {output_dir}")
|
||||
print(f" Batch size: {batch_size}, Total batches: {num_batches}")
|
||||
|
||||
for i in range(num_batches):
|
||||
start_idx = i * batch_size
|
||||
end_idx = min((i + 1) * batch_size, len(samples))
|
||||
batch = samples[start_idx:end_idx]
|
||||
|
||||
output_file = output_dir / f"batch_{i}.jsonl"
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
for entry in batch:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + '\n')
|
||||
|
||||
print(f" ✅ Saved {num_batches} batch files")
|
||||
|
||||
|
||||
def run_compression(input_dir: Path, output_dir: Path, config_path: str):
|
||||
"""
|
||||
Run trajectory compression on the sampled data.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing JSONL files to compress
|
||||
output_dir: Directory for compressed output
|
||||
config_path: Path to compression config YAML
|
||||
"""
|
||||
# Import the compressor
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from trajectory_compressor import TrajectoryCompressor, CompressionConfig
|
||||
|
||||
print(f"\n🗜️ Running trajectory compression...")
|
||||
print(f" Input: {input_dir}")
|
||||
print(f" Output: {output_dir}")
|
||||
print(f" Config: {config_path}")
|
||||
|
||||
# Load config
|
||||
config = CompressionConfig.from_yaml(config_path)
|
||||
|
||||
# Initialize compressor
|
||||
compressor = TrajectoryCompressor(config)
|
||||
|
||||
# Run compression
|
||||
compressor.process_directory(input_dir, output_dir)
|
||||
|
||||
|
||||
def merge_output_to_single_jsonl(input_dir: Path, output_file: Path):
|
||||
"""
|
||||
Merge all JSONL files in a directory into a single JSONL file.
|
||||
|
||||
Args:
|
||||
input_dir: Directory containing JSONL files
|
||||
output_file: Output JSONL file path
|
||||
"""
|
||||
print(f"\n📦 Merging output files into {output_file.name}...")
|
||||
|
||||
all_entries = []
|
||||
for jsonl_file in sorted(input_dir.glob("*.jsonl")):
|
||||
if jsonl_file.name == output_file.name:
|
||||
continue
|
||||
with open(jsonl_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
all_entries.append(json.loads(line))
|
||||
|
||||
# Write merged file
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
for entry in all_entries:
|
||||
f.write(json.dumps(entry, ensure_ascii=False) + '\n')
|
||||
|
||||
print(f" ✅ Merged {len(all_entries):,} entries into {output_file.name}")
|
||||
return output_file
|
||||
|
||||
|
||||
def main(
|
||||
total_samples: int = 2500,
|
||||
output_name: str = "compressed_agentic",
|
||||
datasets: str = None,
|
||||
config: str = "configs/trajectory_compression.yaml",
|
||||
seed: int = 42,
|
||||
batch_size: int = 100,
|
||||
min_tokens: int = 16000,
|
||||
num_proc: int = 8,
|
||||
skip_download: bool = False,
|
||||
):
|
||||
"""
|
||||
Sample trajectories from HuggingFace datasets and run compression.
|
||||
|
||||
Args:
|
||||
total_samples: Total number of samples to collect (default: 2500)
|
||||
output_name: Name for output directory/file (default: "compressed_agentic")
|
||||
datasets: Comma-separated list of dataset names (uses defaults if not provided)
|
||||
config: Path to compression config YAML
|
||||
seed: Random seed for reproducibility
|
||||
batch_size: Number of entries per JSONL file during processing
|
||||
min_tokens: Minimum token count to filter trajectories (default: 16000)
|
||||
num_proc: Number of parallel workers for tokenization (default: 8)
|
||||
skip_download: Skip download and use existing sampled data
|
||||
"""
|
||||
print("=" * 70)
|
||||
print("📊 TRAJECTORY SAMPLING AND COMPRESSION")
|
||||
print("=" * 70)
|
||||
|
||||
# Parse datasets
|
||||
if datasets:
|
||||
dataset_list = [d.strip() for d in datasets.split(",")]
|
||||
else:
|
||||
dataset_list = DEFAULT_DATASETS
|
||||
|
||||
print(f"\n📋 Configuration:")
|
||||
print(f" Total samples: {total_samples:,}")
|
||||
print(f" Min tokens filter: {min_tokens:,}")
|
||||
print(f" Parallel workers: {num_proc}")
|
||||
print(f" Datasets: {len(dataset_list)}")
|
||||
for ds in dataset_list:
|
||||
print(f" - {ds}")
|
||||
print(f" Output name: {output_name}")
|
||||
print(f" Config: {config}")
|
||||
print(f" Seed: {seed}")
|
||||
|
||||
# Setup paths
|
||||
base_dir = Path(__file__).parent.parent
|
||||
sampled_dir = base_dir / "data" / f"{output_name}_raw"
|
||||
compressed_dir = base_dir / "data" / f"{output_name}_batches"
|
||||
final_output = base_dir / "data" / f"{output_name}.jsonl"
|
||||
|
||||
if not skip_download:
|
||||
# Step 1: Download, filter by token count, and sample from combined pool
|
||||
samples = sample_from_datasets(
|
||||
dataset_list,
|
||||
total_samples,
|
||||
min_tokens=min_tokens,
|
||||
seed=seed,
|
||||
num_proc=num_proc
|
||||
)
|
||||
|
||||
if not samples:
|
||||
print("❌ No samples collected. Exiting.")
|
||||
return
|
||||
|
||||
# Step 2: Save to JSONL files
|
||||
save_samples_for_compression(samples, sampled_dir, batch_size)
|
||||
else:
|
||||
print(f"\n⏭️ Skipping download, using existing data in {sampled_dir}")
|
||||
|
||||
# Step 3: Run compression
|
||||
config_path = base_dir / config
|
||||
if not config_path.exists():
|
||||
print(f"❌ Config not found: {config_path}")
|
||||
return
|
||||
|
||||
run_compression(sampled_dir, compressed_dir, str(config_path))
|
||||
|
||||
# Step 4: Merge into single JSONL file
|
||||
merge_output_to_single_jsonl(compressed_dir, final_output)
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("✅ COMPLETE!")
|
||||
print("=" * 70)
|
||||
print(f"\n📁 Raw samples: {sampled_dir}")
|
||||
print(f"📁 Compressed batches: {compressed_dir}")
|
||||
print(f"📁 Final output: {final_output}")
|
||||
print(f"\nTo upload to HuggingFace:")
|
||||
print(f" huggingface-cli upload NousResearch/{output_name} {final_output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
203
setup-hermes.sh
Executable file
203
setup-hermes.sh
Executable file
@@ -0,0 +1,203 @@
|
||||
#!/bin/bash
|
||||
# ============================================================================
|
||||
# Hermes Agent Setup Script
|
||||
# ============================================================================
|
||||
# Quick setup for developers who cloned the repo manually.
|
||||
#
|
||||
# Usage:
|
||||
# ./setup-hermes.sh
|
||||
#
|
||||
# This script:
|
||||
# 1. Creates a virtual environment (if not exists)
|
||||
# 2. Installs dependencies
|
||||
# 3. Creates .env from template (if not exists)
|
||||
# 4. Installs the 'hermes' CLI command
|
||||
# 5. Runs the setup wizard (optional)
|
||||
# ============================================================================
|
||||
|
||||
set -e
|
||||
|
||||
# Colors
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[0;33m'
|
||||
CYAN='\033[0;36m'
|
||||
NC='\033[0m'
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
cd "$SCRIPT_DIR"
|
||||
|
||||
echo ""
|
||||
echo -e "${CYAN}🦋 Hermes Agent Setup${NC}"
|
||||
echo ""
|
||||
|
||||
# ============================================================================
|
||||
# Python check
|
||||
# ============================================================================
|
||||
|
||||
echo -e "${CYAN}→${NC} Checking Python..."
|
||||
|
||||
PYTHON_CMD=""
|
||||
for cmd in python3.12 python3.11 python3.10 python3 python; do
|
||||
if command -v $cmd &> /dev/null; then
|
||||
if $cmd -c "import sys; exit(0 if sys.version_info >= (3, 10) else 1)" 2>/dev/null; then
|
||||
PYTHON_CMD=$cmd
|
||||
break
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
if [ -z "$PYTHON_CMD" ]; then
|
||||
echo -e "${YELLOW}✗${NC} Python 3.10+ required"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PYTHON_VERSION=$($PYTHON_CMD -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
|
||||
echo -e "${GREEN}✓${NC} Python $PYTHON_VERSION found"
|
||||
|
||||
# ============================================================================
|
||||
# Virtual environment
|
||||
# ============================================================================
|
||||
|
||||
echo -e "${CYAN}→${NC} Setting up virtual environment..."
|
||||
|
||||
if [ ! -d "venv" ]; then
|
||||
$PYTHON_CMD -m venv venv
|
||||
echo -e "${GREEN}✓${NC} Created venv"
|
||||
else
|
||||
echo -e "${GREEN}✓${NC} venv exists"
|
||||
fi
|
||||
|
||||
source venv/bin/activate
|
||||
pip install --upgrade pip wheel setuptools > /dev/null
|
||||
|
||||
# ============================================================================
|
||||
# Dependencies
|
||||
# ============================================================================
|
||||
|
||||
echo -e "${CYAN}→${NC} Installing dependencies..."
|
||||
|
||||
pip install -e ".[all]" > /dev/null 2>&1 || pip install -e "." > /dev/null
|
||||
|
||||
echo -e "${GREEN}✓${NC} Dependencies installed"
|
||||
|
||||
# ============================================================================
|
||||
# Optional: ripgrep (for faster file search)
|
||||
# ============================================================================
|
||||
|
||||
echo -e "${CYAN}→${NC} Checking ripgrep (optional, for faster search)..."
|
||||
|
||||
if command -v rg &> /dev/null; then
|
||||
echo -e "${GREEN}✓${NC} ripgrep found"
|
||||
else
|
||||
echo -e "${YELLOW}⚠${NC} ripgrep not found (file search will use grep fallback)"
|
||||
read -p "Install ripgrep for faster search? [Y/n] " -n 1 -r
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then
|
||||
INSTALLED=false
|
||||
|
||||
# Check if sudo is available
|
||||
if command -v sudo &> /dev/null && sudo -n true 2>/dev/null; then
|
||||
if command -v apt &> /dev/null; then
|
||||
sudo apt install -y ripgrep && INSTALLED=true
|
||||
elif command -v dnf &> /dev/null; then
|
||||
sudo dnf install -y ripgrep && INSTALLED=true
|
||||
fi
|
||||
fi
|
||||
|
||||
# Try brew (no sudo needed)
|
||||
if [ "$INSTALLED" = false ] && command -v brew &> /dev/null; then
|
||||
brew install ripgrep && INSTALLED=true
|
||||
fi
|
||||
|
||||
# Try cargo (no sudo needed)
|
||||
if [ "$INSTALLED" = false ] && command -v cargo &> /dev/null; then
|
||||
echo -e "${CYAN}→${NC} Trying cargo install (no sudo required)..."
|
||||
cargo install ripgrep && INSTALLED=true
|
||||
fi
|
||||
|
||||
if [ "$INSTALLED" = true ]; then
|
||||
echo -e "${GREEN}✓${NC} ripgrep installed"
|
||||
else
|
||||
echo -e "${YELLOW}⚠${NC} Auto-install failed. Install options:"
|
||||
echo " sudo apt install ripgrep # Debian/Ubuntu"
|
||||
echo " brew install ripgrep # macOS"
|
||||
echo " cargo install ripgrep # With Rust (no sudo)"
|
||||
echo " https://github.com/BurntSushi/ripgrep#installation"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
# ============================================================================
|
||||
# Environment file
|
||||
# ============================================================================
|
||||
|
||||
if [ ! -f ".env" ]; then
|
||||
if [ -f ".env.example" ]; then
|
||||
cp .env.example .env
|
||||
echo -e "${GREEN}✓${NC} Created .env from template"
|
||||
fi
|
||||
else
|
||||
echo -e "${GREEN}✓${NC} .env exists"
|
||||
fi
|
||||
|
||||
# ============================================================================
|
||||
# PATH setup
|
||||
# ============================================================================
|
||||
|
||||
echo -e "${CYAN}→${NC} Setting up hermes command..."
|
||||
|
||||
BIN_DIR="$SCRIPT_DIR/venv/bin"
|
||||
|
||||
# Add to shell config if not already there
|
||||
SHELL_CONFIG=""
|
||||
if [ -f "$HOME/.zshrc" ]; then
|
||||
SHELL_CONFIG="$HOME/.zshrc"
|
||||
elif [ -f "$HOME/.bashrc" ]; then
|
||||
SHELL_CONFIG="$HOME/.bashrc"
|
||||
elif [ -f "$HOME/.bash_profile" ]; then
|
||||
SHELL_CONFIG="$HOME/.bash_profile"
|
||||
fi
|
||||
|
||||
if [ -n "$SHELL_CONFIG" ]; then
|
||||
if ! grep -q "hermes-agent" "$SHELL_CONFIG" 2>/dev/null; then
|
||||
echo "" >> "$SHELL_CONFIG"
|
||||
echo "# Hermes Agent" >> "$SHELL_CONFIG"
|
||||
echo "export PATH=\"$BIN_DIR:\$PATH\"" >> "$SHELL_CONFIG"
|
||||
echo -e "${GREEN}✓${NC} Added to $SHELL_CONFIG"
|
||||
else
|
||||
echo -e "${GREEN}✓${NC} PATH already in $SHELL_CONFIG"
|
||||
fi
|
||||
fi
|
||||
|
||||
# ============================================================================
|
||||
# Done
|
||||
# ============================================================================
|
||||
|
||||
echo ""
|
||||
echo -e "${GREEN}✓ Setup complete!${NC}"
|
||||
echo ""
|
||||
echo "Next steps:"
|
||||
echo ""
|
||||
echo " 1. Reload your shell:"
|
||||
echo " source $SHELL_CONFIG"
|
||||
echo ""
|
||||
echo " 2. Run the setup wizard to configure API keys:"
|
||||
echo " hermes setup"
|
||||
echo ""
|
||||
echo " 3. Start chatting:"
|
||||
echo " hermes"
|
||||
echo ""
|
||||
echo "Other commands:"
|
||||
echo " hermes status # Check configuration"
|
||||
echo " hermes gateway # Start messaging gateway"
|
||||
echo " hermes cron daemon # Run cron daemon"
|
||||
echo " hermes doctor # Diagnose issues"
|
||||
echo ""
|
||||
|
||||
# Ask if they want to run setup wizard now
|
||||
read -p "Would you like to run the setup wizard now? [Y/n] " -n 1 -r
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then
|
||||
echo ""
|
||||
python -m hermes_cli.main setup
|
||||
fi
|
||||
3
skills/mlops/DESCRIPTION.md
Normal file
3
skills/mlops/DESCRIPTION.md
Normal file
@@ -0,0 +1,3 @@
|
||||
---
|
||||
description: Knowledge and Tools for Machine Learning Operations - tools and frameworks for training, fine-tuning, deploying, and optimizing ML/AI models
|
||||
---
|
||||
332
skills/mlops/accelerate/SKILL.md
Normal file
332
skills/mlops/accelerate/SKILL.md
Normal file
@@ -0,0 +1,332 @@
|
||||
---
|
||||
name: huggingface-accelerate
|
||||
description: Simplest distributed training API. 4 lines to add distributed support to any PyTorch script. Unified API for DeepSpeed/FSDP/Megatron/DDP. Automatic device placement, mixed precision (FP16/BF16/FP8). Interactive config, single launch command. HuggingFace ecosystem standard.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Distributed Training, HuggingFace, Accelerate, DeepSpeed, FSDP, Mixed Precision, PyTorch, DDP, Unified API, Simple]
|
||||
dependencies: [accelerate, torch, transformers]
|
||||
---
|
||||
|
||||
# HuggingFace Accelerate - Unified Distributed Training
|
||||
|
||||
## Quick start
|
||||
|
||||
Accelerate simplifies distributed training to 4 lines of code.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install accelerate
|
||||
```
|
||||
|
||||
**Convert PyTorch script** (4 lines):
|
||||
```python
|
||||
import torch
|
||||
+ from accelerate import Accelerator
|
||||
|
||||
+ accelerator = Accelerator()
|
||||
|
||||
model = torch.nn.Transformer()
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
dataloader = torch.utils.data.DataLoader(dataset)
|
||||
|
||||
+ model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
for batch in dataloader:
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch)
|
||||
- loss.backward()
|
||||
+ accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Run** (single command):
|
||||
```bash
|
||||
accelerate launch train.py
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: From single GPU to multi-GPU
|
||||
|
||||
**Original script**:
|
||||
```python
|
||||
# train.py
|
||||
import torch
|
||||
|
||||
model = torch.nn.Linear(10, 2).to('cuda')
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
|
||||
|
||||
for epoch in range(10):
|
||||
for batch in dataloader:
|
||||
batch = batch.to('cuda')
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch).mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**With Accelerate** (4 lines added):
|
||||
```python
|
||||
# train.py
|
||||
import torch
|
||||
from accelerate import Accelerator # +1
|
||||
|
||||
accelerator = Accelerator() # +2
|
||||
|
||||
model = torch.nn.Linear(10, 2)
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) # +3
|
||||
|
||||
for epoch in range(10):
|
||||
for batch in dataloader:
|
||||
# No .to('cuda') needed - automatic!
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch).mean()
|
||||
accelerator.backward(loss) # +4
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Configure** (interactive):
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
**Questions**:
|
||||
- Which machine? (single/multi GPU/TPU/CPU)
|
||||
- How many machines? (1)
|
||||
- Mixed precision? (no/fp16/bf16/fp8)
|
||||
- DeepSpeed? (no/yes)
|
||||
|
||||
**Launch** (works on any setup):
|
||||
```bash
|
||||
# Single GPU
|
||||
accelerate launch train.py
|
||||
|
||||
# Multi-GPU (8 GPUs)
|
||||
accelerate launch --multi_gpu --num_processes 8 train.py
|
||||
|
||||
# Multi-node
|
||||
accelerate launch --multi_gpu --num_processes 16 \
|
||||
--num_machines 2 --machine_rank 0 \
|
||||
--main_process_ip $MASTER_ADDR \
|
||||
train.py
|
||||
```
|
||||
|
||||
### Workflow 2: Mixed precision training
|
||||
|
||||
**Enable FP16/BF16**:
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
# FP16 (with gradient scaling)
|
||||
accelerator = Accelerator(mixed_precision='fp16')
|
||||
|
||||
# BF16 (no scaling, more stable)
|
||||
accelerator = Accelerator(mixed_precision='bf16')
|
||||
|
||||
# FP8 (H100+)
|
||||
accelerator = Accelerator(mixed_precision='fp8')
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
# Everything else is automatic!
|
||||
for batch in dataloader:
|
||||
with accelerator.autocast(): # Optional, done automatically
|
||||
loss = model(batch)
|
||||
accelerator.backward(loss)
|
||||
```
|
||||
|
||||
### Workflow 3: DeepSpeed ZeRO integration
|
||||
|
||||
**Enable DeepSpeed ZeRO-2**:
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='bf16',
|
||||
deepspeed_plugin={
|
||||
"zero_stage": 2, # ZeRO-2
|
||||
"offload_optimizer": False,
|
||||
"gradient_accumulation_steps": 4
|
||||
}
|
||||
)
|
||||
|
||||
# Same code as before!
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
```
|
||||
|
||||
**Or via config**:
|
||||
```bash
|
||||
accelerate config
|
||||
# Select: DeepSpeed → ZeRO-2
|
||||
```
|
||||
|
||||
**deepspeed_config.json**:
|
||||
```json
|
||||
{
|
||||
"fp16": {"enabled": false},
|
||||
"bf16": {"enabled": true},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {"device": "cpu"},
|
||||
"allgather_bucket_size": 5e8,
|
||||
"reduce_bucket_size": 5e8
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Launch**:
|
||||
```bash
|
||||
accelerate launch --config_file deepspeed_config.json train.py
|
||||
```
|
||||
|
||||
### Workflow 4: FSDP (Fully Sharded Data Parallel)
|
||||
|
||||
**Enable FSDP**:
|
||||
```python
|
||||
from accelerate import Accelerator, FullyShardedDataParallelPlugin
|
||||
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
sharding_strategy="FULL_SHARD", # ZeRO-3 equivalent
|
||||
auto_wrap_policy="TRANSFORMER_AUTO_WRAP",
|
||||
cpu_offload=False
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='bf16',
|
||||
fsdp_plugin=fsdp_plugin
|
||||
)
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
```
|
||||
|
||||
**Or via config**:
|
||||
```bash
|
||||
accelerate config
|
||||
# Select: FSDP → Full Shard → No CPU Offload
|
||||
```
|
||||
|
||||
### Workflow 5: Gradient accumulation
|
||||
|
||||
**Accumulate gradients**:
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
accelerator = Accelerator(gradient_accumulation_steps=4)
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
for batch in dataloader:
|
||||
with accelerator.accumulate(model): # Handles accumulation
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Effective batch size**: `batch_size * num_gpus * gradient_accumulation_steps`
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use Accelerate when**:
|
||||
- Want simplest distributed training
|
||||
- Need single script for any hardware
|
||||
- Use HuggingFace ecosystem
|
||||
- Want flexibility (DDP/DeepSpeed/FSDP/Megatron)
|
||||
- Need quick prototyping
|
||||
|
||||
**Key advantages**:
|
||||
- **4 lines**: Minimal code changes
|
||||
- **Unified API**: Same code for DDP, DeepSpeed, FSDP, Megatron
|
||||
- **Automatic**: Device placement, mixed precision, sharding
|
||||
- **Interactive config**: No manual launcher setup
|
||||
- **Single launch**: Works everywhere
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **PyTorch Lightning**: Need callbacks, high-level abstractions
|
||||
- **Ray Train**: Multi-node orchestration, hyperparameter tuning
|
||||
- **DeepSpeed**: Direct API control, advanced features
|
||||
- **Raw DDP**: Maximum control, minimal abstraction
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: Wrong device placement**
|
||||
|
||||
Don't manually move to device:
|
||||
```python
|
||||
# WRONG
|
||||
batch = batch.to('cuda')
|
||||
|
||||
# CORRECT
|
||||
# Accelerate handles it automatically after prepare()
|
||||
```
|
||||
|
||||
**Issue: Gradient accumulation not working**
|
||||
|
||||
Use context manager:
|
||||
```python
|
||||
# CORRECT
|
||||
with accelerator.accumulate(model):
|
||||
optimizer.zero_grad()
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Issue: Checkpointing in distributed**
|
||||
|
||||
Use accelerator methods:
|
||||
```python
|
||||
# Save only on main process
|
||||
if accelerator.is_main_process:
|
||||
accelerator.save_state('checkpoint/')
|
||||
|
||||
# Load on all processes
|
||||
accelerator.load_state('checkpoint/')
|
||||
```
|
||||
|
||||
**Issue: Different results with FSDP**
|
||||
|
||||
Ensure same random seed:
|
||||
```python
|
||||
from accelerate.utils import set_seed
|
||||
set_seed(42)
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Megatron integration**: See [references/megatron-integration.md](references/megatron-integration.md) for tensor parallelism, pipeline parallelism, and sequence parallelism setup.
|
||||
|
||||
**Custom plugins**: See [references/custom-plugins.md](references/custom-plugins.md) for creating custom distributed plugins and advanced configuration.
|
||||
|
||||
**Performance tuning**: See [references/performance.md](references/performance.md) for profiling, memory optimization, and best practices.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **CPU**: Works (slow)
|
||||
- **Single GPU**: Works
|
||||
- **Multi-GPU**: DDP (default), DeepSpeed, or FSDP
|
||||
- **Multi-node**: DDP, DeepSpeed, FSDP, Megatron
|
||||
- **TPU**: Supported
|
||||
- **Apple MPS**: Supported
|
||||
|
||||
**Launcher requirements**:
|
||||
- **DDP**: `torch.distributed.run` (built-in)
|
||||
- **DeepSpeed**: `deepspeed` (pip install deepspeed)
|
||||
- **FSDP**: PyTorch 1.12+ (built-in)
|
||||
- **Megatron**: Custom setup
|
||||
|
||||
## Resources
|
||||
|
||||
- Docs: https://huggingface.co/docs/accelerate
|
||||
- GitHub: https://github.com/huggingface/accelerate
|
||||
- Version: 1.11.0+
|
||||
- Tutorial: "Accelerate your scripts"
|
||||
- Examples: https://github.com/huggingface/accelerate/tree/main/examples
|
||||
- Used by: HuggingFace Transformers, TRL, PEFT, all HF libraries
|
||||
|
||||
|
||||
|
||||
453
skills/mlops/accelerate/references/custom-plugins.md
Normal file
453
skills/mlops/accelerate/references/custom-plugins.md
Normal file
@@ -0,0 +1,453 @@
|
||||
# Custom Plugins for Accelerate
|
||||
|
||||
## Overview
|
||||
|
||||
Accelerate allows creating **custom plugins** to extend distributed training strategies beyond built-in options (DDP, FSDP, DeepSpeed).
|
||||
|
||||
## Plugin Architecture
|
||||
|
||||
### Base Plugin Structure
|
||||
|
||||
```python
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
"""Custom training plugin."""
|
||||
|
||||
# Plugin configuration
|
||||
param1: int = 1
|
||||
param2: str = "default"
|
||||
|
||||
def __post_init__(self):
|
||||
# Validation logic
|
||||
if self.param1 < 1:
|
||||
raise ValueError("param1 must be >= 1")
|
||||
```
|
||||
|
||||
### Using Custom Plugin
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
# Create plugin
|
||||
custom_plugin = CustomPlugin(param1=4, param2="value")
|
||||
|
||||
# Pass to Accelerator
|
||||
accelerator = Accelerator(
|
||||
custom_plugin=custom_plugin # Not a real parameter, example only
|
||||
)
|
||||
```
|
||||
|
||||
## Built-In Plugin Examples
|
||||
|
||||
### 1. GradScalerKwargs (FP16 Configuration)
|
||||
|
||||
```python
|
||||
from accelerate.utils import GradScalerKwargs
|
||||
|
||||
# Configure gradient scaler for FP16
|
||||
scaler_kwargs = GradScalerKwargs(
|
||||
init_scale=2.**16, # Initial loss scale
|
||||
growth_factor=2.0, # Scale growth rate
|
||||
backoff_factor=0.5, # Scale backoff rate
|
||||
growth_interval=2000, # Steps between scale increases
|
||||
enabled=True # Enable scaler
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='fp16',
|
||||
kwargs_handlers=[scaler_kwargs] # Pass as kwargs handler
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Fine-tune FP16 gradient scaling behavior
|
||||
|
||||
### 2. DistributedDataParallelKwargs
|
||||
|
||||
```python
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
# Configure DDP behavior
|
||||
ddp_kwargs = DistributedDataParallelKwargs(
|
||||
bucket_cap_mb=25, # Gradient bucketing size
|
||||
find_unused_parameters=False, # Find unused params (slower)
|
||||
check_reduction=False, # Check gradient reduction
|
||||
gradient_as_bucket_view=True, # Memory optimization
|
||||
static_graph=False # Static computation graph
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
kwargs_handlers=[ddp_kwargs]
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Optimize DDP performance for specific models
|
||||
|
||||
### 3. FP8RecipeKwargs (H100 FP8)
|
||||
|
||||
```python
|
||||
from accelerate.utils import FP8RecipeKwargs
|
||||
|
||||
# Configure FP8 training (H100)
|
||||
fp8_recipe = FP8RecipeKwargs(
|
||||
backend="te", # TransformerEngine backend
|
||||
margin=0, # Scaling margin
|
||||
interval=1, # Scaling interval
|
||||
fp8_format="HYBRID", # E4M3 + E5M2 hybrid
|
||||
amax_history_len=1024, # AMAX history length
|
||||
amax_compute_algo="max" # AMAX computation algorithm
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='fp8',
|
||||
kwargs_handlers=[fp8_recipe]
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Ultra-fast training on H100 GPUs
|
||||
|
||||
## Custom DeepSpeed Configuration
|
||||
|
||||
### ZeRO-3 with CPU Offload
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import DeepSpeedPlugin
|
||||
|
||||
# Custom DeepSpeed config
|
||||
ds_plugin = DeepSpeedPlugin(
|
||||
zero_stage=3, # ZeRO-3
|
||||
offload_optimizer_device="cpu", # CPU offload optimizer
|
||||
offload_param_device="cpu", # CPU offload parameters
|
||||
zero3_init_flag=True, # ZeRO-3 initialization
|
||||
zero3_save_16bit_model=True, # Save FP16 weights
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
deepspeed_plugin=ds_plugin,
|
||||
mixed_precision='bf16'
|
||||
)
|
||||
```
|
||||
|
||||
### ZeRO-2 with NVMe Offload
|
||||
|
||||
```python
|
||||
ds_plugin = DeepSpeedPlugin(
|
||||
zero_stage=2,
|
||||
offload_optimizer_device="nvme", # NVMe offload
|
||||
offload_param_device="nvme",
|
||||
nvme_path="/local_nvme", # NVMe mount path
|
||||
)
|
||||
```
|
||||
|
||||
### Custom JSON Config
|
||||
|
||||
```python
|
||||
import json
|
||||
|
||||
# Load custom DeepSpeed config
|
||||
with open('deepspeed_config.json', 'r') as f:
|
||||
ds_config = json.load(f)
|
||||
|
||||
ds_plugin = DeepSpeedPlugin(hf_ds_config=ds_config)
|
||||
|
||||
accelerator = Accelerator(deepspeed_plugin=ds_plugin)
|
||||
```
|
||||
|
||||
**Example config** (`deepspeed_config.json`):
|
||||
```json
|
||||
{
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": 1.0,
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e9,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"stage3_prefetch_bucket_size": 5e8,
|
||||
"stage3_param_persistence_threshold": 1e6,
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"steps_per_print": 100,
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
```
|
||||
|
||||
## Custom FSDP Configuration
|
||||
|
||||
### FSDP with Custom Auto-Wrap Policy
|
||||
|
||||
```python
|
||||
from accelerate.utils import FullyShardedDataParallelPlugin
|
||||
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
|
||||
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
|
||||
import functools
|
||||
|
||||
# Custom wrap policy (size-based)
|
||||
wrap_policy = functools.partial(
|
||||
size_based_auto_wrap_policy,
|
||||
min_num_params=1e6 # Wrap layers with 1M+ params
|
||||
)
|
||||
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO-3 equivalent
|
||||
backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # Prefetch strategy
|
||||
mixed_precision_policy=None, # Use Accelerator's mixed precision
|
||||
auto_wrap_policy=wrap_policy, # Custom wrapping
|
||||
cpu_offload=False,
|
||||
ignored_modules=None, # Modules to not wrap
|
||||
state_dict_type="FULL_STATE_DICT", # Save format
|
||||
optim_state_dict_config=None,
|
||||
limit_all_gathers=False,
|
||||
use_orig_params=True, # Use original param shapes
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
fsdp_plugin=fsdp_plugin,
|
||||
mixed_precision='bf16'
|
||||
)
|
||||
```
|
||||
|
||||
### FSDP with Transformer Auto-Wrap
|
||||
|
||||
```python
|
||||
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
|
||||
|
||||
# Wrap at transformer block level
|
||||
wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
transformer_layer_cls={GPT2Block} # Wrap GPT2Block layers
|
||||
)
|
||||
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
auto_wrap_policy=wrap_policy
|
||||
)
|
||||
```
|
||||
|
||||
## Creating Custom Training Strategy
|
||||
|
||||
### Example: Custom Gradient Accumulation
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
class CustomGradientAccumulation:
|
||||
def __init__(self, steps=4, adaptive=False):
|
||||
self.steps = steps
|
||||
self.adaptive = adaptive
|
||||
self.current_step = 0
|
||||
|
||||
def should_sync(self, loss):
|
||||
"""Decide whether to sync gradients."""
|
||||
self.current_step += 1
|
||||
|
||||
# Adaptive: sync on high loss
|
||||
if self.adaptive and loss > threshold:
|
||||
self.current_step = 0
|
||||
return True
|
||||
|
||||
# Regular: sync every N steps
|
||||
if self.current_step >= self.steps:
|
||||
self.current_step = 0
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# Usage
|
||||
custom_accum = CustomGradientAccumulation(steps=8, adaptive=True)
|
||||
accelerator = Accelerator()
|
||||
|
||||
for batch in dataloader:
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
|
||||
# Scale loss
|
||||
loss = loss / custom_accum.steps
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Conditional sync
|
||||
if custom_accum.should_sync(loss.item()):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
### Example: Custom Mixed Precision
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class CustomMixedPrecision:
|
||||
"""Custom mixed precision with dynamic loss scaling."""
|
||||
|
||||
def __init__(self, init_scale=2**16, scale_window=2000):
|
||||
self.scaler = torch.cuda.amp.GradScaler(
|
||||
init_scale=init_scale,
|
||||
growth_interval=scale_window
|
||||
)
|
||||
self.scale_history = []
|
||||
|
||||
def scale_loss(self, loss):
|
||||
"""Scale loss for backward."""
|
||||
return self.scaler.scale(loss)
|
||||
|
||||
def unscale_and_clip(self, optimizer, max_norm=1.0):
|
||||
"""Unscale gradients and clip."""
|
||||
self.scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
optimizer.param_groups[0]['params'],
|
||||
max_norm
|
||||
)
|
||||
|
||||
def step(self, optimizer):
|
||||
"""Optimizer step with scaler update."""
|
||||
scale_before = self.scaler.get_scale()
|
||||
self.scaler.step(optimizer)
|
||||
self.scaler.update()
|
||||
scale_after = self.scaler.get_scale()
|
||||
|
||||
# Track scale changes
|
||||
if scale_before != scale_after:
|
||||
self.scale_history.append(scale_after)
|
||||
|
||||
# Usage
|
||||
custom_mp = CustomMixedPrecision()
|
||||
|
||||
for batch in dataloader:
|
||||
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||
loss = model(**batch).loss
|
||||
|
||||
scaled_loss = custom_mp.scale_loss(loss)
|
||||
scaled_loss.backward()
|
||||
|
||||
custom_mp.unscale_and_clip(optimizer, max_norm=1.0)
|
||||
custom_mp.step(optimizer)
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
## Advanced: Custom Distributed Backend
|
||||
|
||||
### Custom AllReduce Strategy
|
||||
|
||||
```python
|
||||
import torch.distributed as dist
|
||||
|
||||
class CustomAllReduce:
|
||||
"""Custom all-reduce with compression."""
|
||||
|
||||
def __init__(self, compression_ratio=0.1):
|
||||
self.compression_ratio = compression_ratio
|
||||
|
||||
def compress_gradients(self, tensor):
|
||||
"""Top-k gradient compression."""
|
||||
k = int(tensor.numel() * self.compression_ratio)
|
||||
values, indices = torch.topk(tensor.abs().view(-1), k)
|
||||
return values, indices
|
||||
|
||||
def all_reduce_compressed(self, tensor):
|
||||
"""All-reduce with gradient compression."""
|
||||
# Compress
|
||||
values, indices = self.compress_gradients(tensor)
|
||||
|
||||
# All-reduce compressed gradients
|
||||
dist.all_reduce(values, op=dist.ReduceOp.SUM)
|
||||
|
||||
# Decompress
|
||||
tensor_compressed = torch.zeros_like(tensor).view(-1)
|
||||
tensor_compressed[indices] = values / dist.get_world_size()
|
||||
|
||||
return tensor_compressed.view_as(tensor)
|
||||
|
||||
# Usage in training loop
|
||||
custom_ar = CustomAllReduce(compression_ratio=0.1)
|
||||
|
||||
for batch in dataloader:
|
||||
loss = model(**batch).loss
|
||||
loss.backward()
|
||||
|
||||
# Custom all-reduce
|
||||
for param in model.parameters():
|
||||
if param.grad is not None:
|
||||
param.grad.data = custom_ar.all_reduce_compressed(param.grad.data)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
## Plugin Best Practices
|
||||
|
||||
### 1. Validation in `__post_init__`
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
learning_rate: float = 1e-3
|
||||
warmup_steps: int = 1000
|
||||
|
||||
def __post_init__(self):
|
||||
# Validate parameters
|
||||
if self.learning_rate <= 0:
|
||||
raise ValueError("learning_rate must be positive")
|
||||
if self.warmup_steps < 0:
|
||||
raise ValueError("warmup_steps must be non-negative")
|
||||
|
||||
# Compute derived values
|
||||
self.min_lr = self.learning_rate * 0.1
|
||||
```
|
||||
|
||||
### 2. Compatibility Checks
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
feature_enabled: bool = True
|
||||
|
||||
def is_compatible(self, accelerator):
|
||||
"""Check if plugin is compatible with accelerator config."""
|
||||
if self.feature_enabled and accelerator.mixed_precision == 'fp8':
|
||||
raise ValueError("Custom plugin not compatible with FP8")
|
||||
return True
|
||||
```
|
||||
|
||||
### 3. State Management
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
counter: int = 0
|
||||
history: list = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.history is None:
|
||||
self.history = []
|
||||
|
||||
def update_state(self, value):
|
||||
"""Update plugin state during training."""
|
||||
self.counter += 1
|
||||
self.history.append(value)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Accelerate Plugins: https://huggingface.co/docs/accelerate/package_reference/kwargs
|
||||
- DeepSpeed Config: https://www.deepspeed.ai/docs/config-json/
|
||||
- FSDP Guide: https://pytorch.org/docs/stable/fsdp.html
|
||||
- Custom Training Loops: https://huggingface.co/docs/accelerate/usage_guides/training_tpu
|
||||
489
skills/mlops/accelerate/references/megatron-integration.md
Normal file
489
skills/mlops/accelerate/references/megatron-integration.md
Normal file
@@ -0,0 +1,489 @@
|
||||
# Megatron Integration with Accelerate
|
||||
|
||||
## Overview
|
||||
|
||||
Accelerate supports Megatron-LM for massive model training with tensor parallelism and pipeline parallelism.
|
||||
|
||||
**Megatron capabilities**:
|
||||
- **Tensor Parallelism (TP)**: Split layers across GPUs
|
||||
- **Pipeline Parallelism (PP)**: Split model depth across GPUs
|
||||
- **Data Parallelism (DP)**: Replicate model across GPU groups
|
||||
- **Sequence Parallelism**: Split sequences for long contexts
|
||||
|
||||
## Setup
|
||||
|
||||
### Install Megatron-LM
|
||||
|
||||
```bash
|
||||
# Clone Megatron-LM repository
|
||||
git clone https://github.com/NVIDIA/Megatron-LM.git
|
||||
cd Megatron-LM
|
||||
pip install -e .
|
||||
|
||||
# Install Apex (NVIDIA optimizations)
|
||||
git clone https://github.com/NVIDIA/apex
|
||||
cd apex
|
||||
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
|
||||
--config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
|
||||
```
|
||||
|
||||
### Accelerate Configuration
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
**Questions**:
|
||||
```
|
||||
In which compute environment are you running?
|
||||
> This machine
|
||||
|
||||
Which type of machine are you using?
|
||||
> Multi-GPU
|
||||
|
||||
How many different machines will you use?
|
||||
> 1
|
||||
|
||||
Do you want to use DeepSpeed/FSDP?
|
||||
> No
|
||||
|
||||
Do you want to use Megatron-LM?
|
||||
> Yes
|
||||
|
||||
What is the Tensor Parallelism degree? [1-8]
|
||||
> 2
|
||||
|
||||
Do you want to enable Sequence Parallelism?
|
||||
> No
|
||||
|
||||
What is the Pipeline Parallelism degree? [1-8]
|
||||
> 2
|
||||
|
||||
What is the Data Parallelism degree? [1-8]
|
||||
> 2
|
||||
|
||||
Where to perform activation checkpointing? ['SELECTIVE', 'FULL', 'NONE']
|
||||
> SELECTIVE
|
||||
|
||||
Where to perform activation partitioning? ['SEQUENTIAL', 'UNIFORM']
|
||||
> SEQUENTIAL
|
||||
```
|
||||
|
||||
**Generated config** (`~/.cache/huggingface/accelerate/default_config.yaml`):
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
distributed_type: MEGATRON_LM
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
megatron_lm_config:
|
||||
megatron_lm_gradient_clipping: 1.0
|
||||
megatron_lm_learning_rate_decay_iters: 320000
|
||||
megatron_lm_num_micro_batches: 1
|
||||
megatron_lm_pp_degree: 2
|
||||
megatron_lm_recompute_activations: true
|
||||
megatron_lm_sequence_parallelism: false
|
||||
megatron_lm_tp_degree: 2
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
```
|
||||
|
||||
## Parallelism Strategies
|
||||
|
||||
### Tensor Parallelism (TP)
|
||||
|
||||
**Splits each transformer layer across GPUs**:
|
||||
|
||||
```python
|
||||
# Layer split across 2 GPUs
|
||||
# GPU 0: First half of attention heads
|
||||
# GPU 1: Second half of attention heads
|
||||
|
||||
# Each GPU computes partial outputs
|
||||
# All-reduce combines results
|
||||
```
|
||||
|
||||
**TP degree recommendations**:
|
||||
- **TP=1**: No tensor parallelism (single GPU per layer)
|
||||
- **TP=2**: 2 GPUs per layer (good for 7-13B models)
|
||||
- **TP=4**: 4 GPUs per layer (good for 20-40B models)
|
||||
- **TP=8**: 8 GPUs per layer (good for 70B+ models)
|
||||
|
||||
**Benefits**:
|
||||
- Reduces memory per GPU
|
||||
- All-reduce communication (fast)
|
||||
|
||||
**Drawbacks**:
|
||||
- Requires fast inter-GPU bandwidth (NVLink)
|
||||
- Communication overhead per layer
|
||||
|
||||
### Pipeline Parallelism (PP)
|
||||
|
||||
**Splits model depth across GPUs**:
|
||||
|
||||
```python
|
||||
# 12-layer model, PP=4
|
||||
# GPU 0: Layers 0-2
|
||||
# GPU 1: Layers 3-5
|
||||
# GPU 2: Layers 6-8
|
||||
# GPU 3: Layers 9-11
|
||||
```
|
||||
|
||||
**PP degree recommendations**:
|
||||
- **PP=1**: No pipeline parallelism
|
||||
- **PP=2**: 2 pipeline stages (good for 20-40B models)
|
||||
- **PP=4**: 4 pipeline stages (good for 70B+ models)
|
||||
- **PP=8**: 8 pipeline stages (good for 175B+ models)
|
||||
|
||||
**Benefits**:
|
||||
- Linear memory reduction (4× PP = 4× less memory)
|
||||
- Works across nodes (slower interconnect OK)
|
||||
|
||||
**Drawbacks**:
|
||||
- Pipeline bubbles (idle time)
|
||||
- Requires micro-batching
|
||||
|
||||
### Data Parallelism (DP)
|
||||
|
||||
**Replicates model across GPU groups**:
|
||||
|
||||
```python
|
||||
# 8 GPUs, TP=2, PP=2, DP=2
|
||||
# Group 0 (GPUs 0-3): Full model replica
|
||||
# Group 1 (GPUs 4-7): Full model replica
|
||||
```
|
||||
|
||||
**DP degree**:
|
||||
- `DP = total_gpus / (TP × PP)`
|
||||
- Example: 8 GPUs, TP=2, PP=2 → DP=2
|
||||
|
||||
**Benefits**:
|
||||
- Increases throughput
|
||||
- Scales batch size
|
||||
|
||||
### Sequence Parallelism
|
||||
|
||||
**Splits long sequences across GPUs** (extends TP):
|
||||
|
||||
```python
|
||||
# 8K sequence, TP=2, Sequence Parallel=True
|
||||
# GPU 0: Tokens 0-4095
|
||||
# GPU 1: Tokens 4096-8191
|
||||
```
|
||||
|
||||
**Benefits**:
|
||||
- Enables very long sequences (100K+ tokens)
|
||||
- Reduces activation memory
|
||||
|
||||
**Requirements**:
|
||||
- Must use with TP > 1
|
||||
- RoPE/ALiBi position encodings work best
|
||||
|
||||
## Accelerate Code Example
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import MegatronLMPlugin
|
||||
|
||||
# Configure Megatron
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
tp_degree=2, # Tensor parallelism degree
|
||||
pp_degree=2, # Pipeline parallelism degree
|
||||
num_micro_batches=4, # Micro-batches for pipeline
|
||||
gradient_clipping=1.0, # Gradient clipping value
|
||||
sequence_parallelism=False, # Enable sequence parallelism
|
||||
recompute_activations=True, # Activation checkpointing
|
||||
use_distributed_optimizer=True, # Distributed optimizer
|
||||
custom_prepare_model_function=None, # Custom model prep
|
||||
)
|
||||
|
||||
# Initialize accelerator
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='bf16',
|
||||
megatron_lm_plugin=megatron_plugin
|
||||
)
|
||||
|
||||
# Prepare model and optimizer
|
||||
model, optimizer, train_dataloader = accelerator.prepare(
|
||||
model, optimizer, train_dataloader
|
||||
)
|
||||
|
||||
# Training loop (same as DDP!)
|
||||
for batch in train_dataloader:
|
||||
optimizer.zero_grad()
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
### Full Training Script
|
||||
|
||||
```python
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import MegatronLMPlugin
|
||||
from transformers import GPT2Config, GPT2LMHeadModel
|
||||
|
||||
def main():
|
||||
# Megatron configuration
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
tp_degree=2,
|
||||
pp_degree=2,
|
||||
num_micro_batches=4,
|
||||
gradient_clipping=1.0,
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='bf16',
|
||||
gradient_accumulation_steps=8,
|
||||
megatron_lm_plugin=megatron_plugin
|
||||
)
|
||||
|
||||
# Model
|
||||
config = GPT2Config(
|
||||
n_layer=24,
|
||||
n_head=16,
|
||||
n_embd=1024,
|
||||
)
|
||||
model = GPT2LMHeadModel(config)
|
||||
|
||||
# Optimizer
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)
|
||||
|
||||
# Prepare
|
||||
model, optimizer, train_loader = accelerator.prepare(
|
||||
model, optimizer, train_loader
|
||||
)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(num_epochs):
|
||||
for batch in train_loader:
|
||||
with accelerator.accumulate(model):
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Save checkpoint
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.save_state(f'checkpoint-epoch-{epoch}')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
```
|
||||
|
||||
### Launch Command
|
||||
|
||||
```bash
|
||||
# 8 GPUs, TP=2, PP=2, DP=2
|
||||
accelerate launch --multi_gpu --num_processes 8 train.py
|
||||
|
||||
# Multi-node (2 nodes, 8 GPUs each)
|
||||
# Node 0
|
||||
accelerate launch --multi_gpu --num_processes 16 \
|
||||
--num_machines 2 --machine_rank 0 \
|
||||
--main_process_ip $MASTER_ADDR \
|
||||
--main_process_port 29500 \
|
||||
train.py
|
||||
|
||||
# Node 1
|
||||
accelerate launch --multi_gpu --num_processes 16 \
|
||||
--num_machines 2 --machine_rank 1 \
|
||||
--main_process_ip $MASTER_ADDR \
|
||||
--main_process_port 29500 \
|
||||
train.py
|
||||
```
|
||||
|
||||
## Activation Checkpointing
|
||||
|
||||
**Reduces memory by recomputing activations**:
|
||||
|
||||
```python
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
recompute_activations=True, # Enable checkpointing
|
||||
checkpoint_num_layers=1, # Checkpoint every N layers
|
||||
distribute_checkpointed_activations=True, # Distribute across TP
|
||||
partition_activations=True, # Partition in PP
|
||||
check_for_nan_in_loss_and_grad=True, # Stability check
|
||||
)
|
||||
```
|
||||
|
||||
**Strategies**:
|
||||
- `SELECTIVE`: Checkpoint transformer blocks only
|
||||
- `FULL`: Checkpoint all layers
|
||||
- `NONE`: No checkpointing
|
||||
|
||||
**Memory savings**: 30-50% with 10-15% slowdown
|
||||
|
||||
## Distributed Optimizer
|
||||
|
||||
**Shards optimizer state across DP ranks**:
|
||||
|
||||
```python
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
use_distributed_optimizer=True, # Enable sharded optimizer
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits**:
|
||||
- Reduces optimizer memory by DP degree
|
||||
- Example: DP=4 → 4× less optimizer memory per GPU
|
||||
|
||||
**Compatible with**:
|
||||
- AdamW, Adam, SGD
|
||||
- Mixed precision training
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### Micro-Batch Size
|
||||
|
||||
```python
|
||||
# Pipeline parallelism requires micro-batching
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
pp_degree=4,
|
||||
num_micro_batches=16, # 16 micro-batches per pipeline
|
||||
)
|
||||
|
||||
# Effective batch = num_micro_batches × micro_batch_size × DP
|
||||
# Example: 16 × 2 × 4 = 128
|
||||
```
|
||||
|
||||
**Recommendations**:
|
||||
- More micro-batches → less pipeline bubble
|
||||
- Typical: 4-16 micro-batches
|
||||
|
||||
### Sequence Length
|
||||
|
||||
```python
|
||||
# For long sequences, enable sequence parallelism
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
tp_degree=4,
|
||||
sequence_parallelism=True, # Required: TP > 1
|
||||
)
|
||||
|
||||
# Enables sequences up to TP × normal limit
|
||||
# Example: TP=4, 8K normal → 32K with sequence parallel
|
||||
```
|
||||
|
||||
### GPU Topology
|
||||
|
||||
**NVLink required for TP**:
|
||||
```bash
|
||||
# Check NVLink topology
|
||||
nvidia-smi topo -m
|
||||
|
||||
# Good topology (NVLink between all GPUs)
|
||||
# GPU0 - GPU1: NV12 (fast)
|
||||
# GPU0 - GPU2: NV12 (fast)
|
||||
|
||||
# Bad topology (PCIe only)
|
||||
# GPU0 - GPU4: PHB (slow, avoid TP across these)
|
||||
```
|
||||
|
||||
**Recommendations**:
|
||||
- **TP**: Within same node (NVLink)
|
||||
- **PP**: Across nodes (slower interconnect OK)
|
||||
- **DP**: Any topology
|
||||
|
||||
## Model Size Guidelines
|
||||
|
||||
| Model Size | GPUs | TP | PP | DP | Micro-Batches |
|
||||
|------------|------|----|----|----|--------------|
|
||||
| 7B | 8 | 1 | 1 | 8 | 1 |
|
||||
| 13B | 8 | 2 | 1 | 4 | 1 |
|
||||
| 20B | 16 | 4 | 1 | 4 | 1 |
|
||||
| 40B | 32 | 4 | 2 | 4 | 4 |
|
||||
| 70B | 64 | 8 | 2 | 4 | 8 |
|
||||
| 175B | 128 | 8 | 4 | 4 | 16 |
|
||||
|
||||
**Assumptions**: BF16, 2K sequence length, A100 80GB
|
||||
|
||||
## Checkpointing
|
||||
|
||||
### Save Checkpoint
|
||||
|
||||
```python
|
||||
# Save full model state
|
||||
accelerator.save_state('checkpoint-1000')
|
||||
|
||||
# Megatron saves separate files per rank
|
||||
# checkpoint-1000/
|
||||
# pytorch_model_tp_0_pp_0.bin
|
||||
# pytorch_model_tp_0_pp_1.bin
|
||||
# pytorch_model_tp_1_pp_0.bin
|
||||
# pytorch_model_tp_1_pp_1.bin
|
||||
# optimizer_tp_0_pp_0.bin
|
||||
# ...
|
||||
```
|
||||
|
||||
### Load Checkpoint
|
||||
|
||||
```python
|
||||
# Resume training
|
||||
accelerator.load_state('checkpoint-1000')
|
||||
|
||||
# Automatically loads correct shard per rank
|
||||
```
|
||||
|
||||
### Convert to Standard PyTorch
|
||||
|
||||
```bash
|
||||
# Merge Megatron checkpoint to single file
|
||||
python merge_megatron_checkpoint.py \
|
||||
--checkpoint-dir checkpoint-1000 \
|
||||
--output pytorch_model.bin
|
||||
```
|
||||
|
||||
## Common Issues
|
||||
|
||||
### Issue: OOM with Pipeline Parallelism
|
||||
|
||||
**Solution**: Increase micro-batches
|
||||
```python
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
pp_degree=4,
|
||||
num_micro_batches=16, # Increase from 4
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Slow Training
|
||||
|
||||
**Check 1**: Pipeline bubbles (PP too high)
|
||||
```python
|
||||
# Reduce PP, increase TP
|
||||
tp_degree=4 # Increase
|
||||
pp_degree=2 # Decrease
|
||||
```
|
||||
|
||||
**Check 2**: Micro-batch size too small
|
||||
```python
|
||||
num_micro_batches=8 # Increase
|
||||
```
|
||||
|
||||
### Issue: NVLink Not Detected
|
||||
|
||||
```bash
|
||||
# Verify NVLink
|
||||
nvidia-smi nvlink -s
|
||||
|
||||
# If no NVLink, avoid TP > 1
|
||||
# Use PP or DP instead
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Megatron-LM: https://github.com/NVIDIA/Megatron-LM
|
||||
- Accelerate Megatron docs: https://huggingface.co/docs/accelerate/usage_guides/megatron_lm
|
||||
- Paper: "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism"
|
||||
- NVIDIA Apex: https://github.com/NVIDIA/apex
|
||||
525
skills/mlops/accelerate/references/performance.md
Normal file
525
skills/mlops/accelerate/references/performance.md
Normal file
@@ -0,0 +1,525 @@
|
||||
# Accelerate Performance Tuning
|
||||
|
||||
## Profiling
|
||||
|
||||
### Basic Profiling
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
import time
|
||||
|
||||
accelerator = Accelerator()
|
||||
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
batch = next(iter(dataloader))
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Profile training loop
|
||||
start = time.time()
|
||||
total_batches = 100
|
||||
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= total_batches:
|
||||
break
|
||||
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
accelerator.wait_for_everyone() # Sync all processes
|
||||
elapsed = time.time() - start
|
||||
|
||||
# Metrics
|
||||
batches_per_sec = total_batches / elapsed
|
||||
samples_per_sec = (total_batches * batch_size * accelerator.num_processes) / elapsed
|
||||
|
||||
print(f"Throughput: {samples_per_sec:.2f} samples/sec")
|
||||
print(f"Batches/sec: {batches_per_sec:.2f}")
|
||||
```
|
||||
|
||||
### PyTorch Profiler Integration
|
||||
|
||||
```python
|
||||
from torch.profiler import profile, ProfilerActivity
|
||||
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=True
|
||||
) as prof:
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= 10: # Profile first 10 batches
|
||||
break
|
||||
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Print profiling results
|
||||
print(prof.key_averages().table(
|
||||
sort_by="cuda_time_total", row_limit=20
|
||||
))
|
||||
|
||||
# Export to Chrome tracing
|
||||
prof.export_chrome_trace("trace.json")
|
||||
# View at chrome://tracing
|
||||
```
|
||||
|
||||
## Memory Optimization
|
||||
|
||||
### 1. Gradient Accumulation
|
||||
|
||||
**Problem**: Large batch size causes OOM
|
||||
|
||||
**Solution**: Accumulate gradients across micro-batches
|
||||
|
||||
```python
|
||||
accelerator = Accelerator(gradient_accumulation_steps=8)
|
||||
|
||||
# Effective batch = batch_size × accumulation_steps × num_gpus
|
||||
# Example: 4 × 8 × 8 = 256
|
||||
|
||||
for batch in dataloader:
|
||||
with accelerator.accumulate(model): # Handles accumulation logic
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
**Memory savings**: 8× less activation memory (with 8 accumulation steps)
|
||||
|
||||
### 2. Gradient Checkpointing
|
||||
|
||||
**Enable in model**:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"gpt2",
|
||||
use_cache=False # Required for gradient checkpointing
|
||||
)
|
||||
|
||||
# Enable checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Prepare with Accelerate
|
||||
model = accelerator.prepare(model)
|
||||
```
|
||||
|
||||
**Memory savings**: 30-50% with 10-15% slowdown
|
||||
|
||||
### 3. Mixed Precision
|
||||
|
||||
**BF16 (A100/H100)**:
|
||||
```python
|
||||
accelerator = Accelerator(mixed_precision='bf16')
|
||||
|
||||
# Automatic mixed precision
|
||||
for batch in dataloader:
|
||||
outputs = model(**batch) # Forward in BF16
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss) # Backward in FP32
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**FP16 (V100, older GPUs)**:
|
||||
```python
|
||||
from accelerate.utils import GradScalerKwargs
|
||||
|
||||
scaler_kwargs = GradScalerKwargs(
|
||||
init_scale=2.**16,
|
||||
growth_interval=2000
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='fp16',
|
||||
kwargs_handlers=[scaler_kwargs]
|
||||
)
|
||||
```
|
||||
|
||||
**Memory savings**: 50% compared to FP32
|
||||
|
||||
### 4. CPU Offloading (DeepSpeed)
|
||||
|
||||
```python
|
||||
from accelerate.utils import DeepSpeedPlugin
|
||||
|
||||
ds_plugin = DeepSpeedPlugin(
|
||||
zero_stage=3,
|
||||
offload_optimizer_device="cpu", # Offload optimizer to CPU
|
||||
offload_param_device="cpu", # Offload parameters to CPU
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
deepspeed_plugin=ds_plugin,
|
||||
mixed_precision='bf16'
|
||||
)
|
||||
```
|
||||
|
||||
**Memory savings**: 10-20× for optimizer state, 5-10× for parameters
|
||||
|
||||
**Trade-off**: 20-30% slower due to CPU-GPU transfers
|
||||
|
||||
### 5. Flash Attention
|
||||
|
||||
```python
|
||||
# Install flash-attn
|
||||
# pip install flash-attn
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"gpt2",
|
||||
attn_implementation="flash_attention_2" # Enable Flash Attention 2
|
||||
)
|
||||
|
||||
model = accelerator.prepare(model)
|
||||
```
|
||||
|
||||
**Memory savings**: 50% for attention, 2× faster
|
||||
|
||||
**Requirements**: A100/H100, sequence length must be multiple of 128
|
||||
|
||||
## Communication Optimization
|
||||
|
||||
### 1. Gradient Bucketing (DDP)
|
||||
|
||||
```python
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(
|
||||
bucket_cap_mb=25, # Bucket size for gradient reduction
|
||||
gradient_as_bucket_view=True, # Reduce memory copies
|
||||
static_graph=False # Set True if model doesn't change
|
||||
)
|
||||
|
||||
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
|
||||
```
|
||||
|
||||
**Recommended bucket sizes**:
|
||||
- Small models (<1B): 25 MB
|
||||
- Medium models (1-10B): 50-100 MB
|
||||
- Large models (>10B): 100-200 MB
|
||||
|
||||
### 2. Find Unused Parameters
|
||||
|
||||
```python
|
||||
# Only enable if model has unused parameters (slower!)
|
||||
ddp_kwargs = DistributedDataParallelKwargs(
|
||||
find_unused_parameters=True
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Models with conditional branches (e.g., mixture of experts)
|
||||
|
||||
**Cost**: 10-20% slower
|
||||
|
||||
### 3. NCCL Tuning
|
||||
|
||||
```bash
|
||||
# Set environment variables before launch
|
||||
export NCCL_DEBUG=INFO # Debug info
|
||||
export NCCL_IB_DISABLE=0 # Enable InfiniBand
|
||||
export NCCL_SOCKET_IFNAME=eth0 # Network interface
|
||||
export NCCL_P2P_LEVEL=NVL # Use NVLink
|
||||
|
||||
accelerate launch train.py
|
||||
```
|
||||
|
||||
**NCCL_P2P_LEVEL options**:
|
||||
- `NVL`: NVLink (fastest, within node)
|
||||
- `PIX`: PCIe (fast, within node)
|
||||
- `PHB`: PCIe host bridge (slow, cross-node)
|
||||
|
||||
## Data Loading Optimization
|
||||
|
||||
### 1. DataLoader Workers
|
||||
|
||||
```python
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
train_loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=32,
|
||||
num_workers=4, # Parallel data loading
|
||||
pin_memory=True, # Pin memory for faster GPU transfer
|
||||
prefetch_factor=2, # Prefetch batches per worker
|
||||
persistent_workers=True # Keep workers alive between epochs
|
||||
)
|
||||
|
||||
train_loader = accelerator.prepare(train_loader)
|
||||
```
|
||||
|
||||
**Recommendations**:
|
||||
- `num_workers`: 2-4 per GPU (8 GPUs → 16-32 workers)
|
||||
- `pin_memory`: Always True for GPU training
|
||||
- `prefetch_factor`: 2-4 (higher for slow data loading)
|
||||
|
||||
### 2. Data Preprocessing
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
# Bad: Preprocess during training (slow)
|
||||
dataset = load_dataset("openwebtext")
|
||||
|
||||
for batch in dataset:
|
||||
tokens = tokenizer(batch['text']) # Slow!
|
||||
...
|
||||
|
||||
# Good: Preprocess once, save
|
||||
dataset = load_dataset("openwebtext")
|
||||
tokenized = dataset.map(
|
||||
lambda x: tokenizer(x['text']),
|
||||
batched=True,
|
||||
num_proc=8, # Parallel preprocessing
|
||||
remove_columns=['text']
|
||||
)
|
||||
tokenized.save_to_disk("preprocessed_data")
|
||||
|
||||
# Load preprocessed
|
||||
dataset = load_from_disk("preprocessed_data")
|
||||
```
|
||||
|
||||
### 3. Faster Tokenization
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
# Enable Rust-based tokenizers (10× faster)
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"gpt2",
|
||||
use_fast=True # Use fast Rust tokenizer
|
||||
)
|
||||
```
|
||||
|
||||
## Compilation (PyTorch 2.0+)
|
||||
|
||||
### Compile Model
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Compile model for faster execution
|
||||
model = torch.compile(
|
||||
model,
|
||||
mode="reduce-overhead", # Options: default, reduce-overhead, max-autotune
|
||||
fullgraph=False, # Compile entire graph (stricter)
|
||||
dynamic=True # Support dynamic shapes
|
||||
)
|
||||
|
||||
model = accelerator.prepare(model)
|
||||
```
|
||||
|
||||
**Speedup**: 10-50% depending on model
|
||||
|
||||
**Compilation modes**:
|
||||
- `default`: Balanced (best for most cases)
|
||||
- `reduce-overhead`: Min overhead (best for small batches)
|
||||
- `max-autotune`: Max performance (slow compile, best for production)
|
||||
|
||||
### Compilation Best Practices
|
||||
|
||||
```python
|
||||
# Bad: Compile after prepare (won't work)
|
||||
model = accelerator.prepare(model)
|
||||
model = torch.compile(model) # Error!
|
||||
|
||||
# Good: Compile before prepare
|
||||
model = torch.compile(model)
|
||||
model = accelerator.prepare(model)
|
||||
|
||||
# Training loop
|
||||
for batch in dataloader:
|
||||
# First iteration: slow (compilation)
|
||||
# Subsequent iterations: fast (compiled)
|
||||
outputs = model(**batch)
|
||||
...
|
||||
```
|
||||
|
||||
## Benchmarking Different Strategies
|
||||
|
||||
### Script Template
|
||||
|
||||
```python
|
||||
import time
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
|
||||
def benchmark_strategy(strategy_name, accelerator_kwargs):
|
||||
"""Benchmark a specific training strategy."""
|
||||
accelerator = Accelerator(**accelerator_kwargs)
|
||||
|
||||
# Setup
|
||||
model = create_model()
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
|
||||
dataloader = create_dataloader()
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(
|
||||
model, optimizer, dataloader
|
||||
)
|
||||
|
||||
# Warmup
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= 10:
|
||||
break
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Benchmark
|
||||
accelerator.wait_for_everyone()
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
|
||||
num_batches = 100
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= num_batches:
|
||||
break
|
||||
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
torch.cuda.synchronize()
|
||||
elapsed = time.time() - start
|
||||
|
||||
# Metrics
|
||||
throughput = (num_batches * batch_size * accelerator.num_processes) / elapsed
|
||||
memory_used = torch.cuda.max_memory_allocated() / 1e9 # GB
|
||||
|
||||
if accelerator.is_main_process:
|
||||
print(f"\n{strategy_name}:")
|
||||
print(f" Throughput: {throughput:.2f} samples/sec")
|
||||
print(f" Memory: {memory_used:.2f} GB")
|
||||
print(f" Time: {elapsed:.2f} sec")
|
||||
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# Benchmark different strategies
|
||||
strategies = [
|
||||
("DDP + FP32", {}),
|
||||
("DDP + BF16", {"mixed_precision": "bf16"}),
|
||||
("DDP + BF16 + GradAccum", {"mixed_precision": "bf16", "gradient_accumulation_steps": 4}),
|
||||
("FSDP", {"fsdp_plugin": fsdp_plugin}),
|
||||
("DeepSpeed ZeRO-2", {"deepspeed_plugin": ds_plugin_stage2}),
|
||||
("DeepSpeed ZeRO-3", {"deepspeed_plugin": ds_plugin_stage3}),
|
||||
]
|
||||
|
||||
for name, kwargs in strategies:
|
||||
benchmark_strategy(name, kwargs)
|
||||
```
|
||||
|
||||
## Performance Checklist
|
||||
|
||||
**Before training**:
|
||||
- [ ] Use BF16/FP16 mixed precision
|
||||
- [ ] Enable gradient checkpointing (if OOM)
|
||||
- [ ] Set appropriate `num_workers` (2-4 per GPU)
|
||||
- [ ] Enable `pin_memory=True`
|
||||
- [ ] Preprocess data once, not during training
|
||||
- [ ] Compile model with `torch.compile` (PyTorch 2.0+)
|
||||
|
||||
**For large models**:
|
||||
- [ ] Use FSDP or DeepSpeed ZeRO-3
|
||||
- [ ] Enable CPU offloading (if still OOM)
|
||||
- [ ] Use Flash Attention
|
||||
- [ ] Increase gradient accumulation
|
||||
|
||||
**For multi-node**:
|
||||
- [ ] Check network topology (InfiniBand > Ethernet)
|
||||
- [ ] Tune NCCL settings
|
||||
- [ ] Use larger bucket sizes for DDP
|
||||
- [ ] Verify NVLink for tensor parallelism
|
||||
|
||||
**Profiling**:
|
||||
- [ ] Profile first 10-100 batches
|
||||
- [ ] Check GPU utilization (`nvidia-smi dmon`)
|
||||
- [ ] Check data loading time (should be <5% of iteration)
|
||||
- [ ] Identify communication bottlenecks
|
||||
|
||||
## Common Performance Issues
|
||||
|
||||
### Issue: Low GPU Utilization (<80%)
|
||||
|
||||
**Cause 1**: Data loading bottleneck
|
||||
```python
|
||||
# Solution: Increase workers and prefetch
|
||||
num_workers=8
|
||||
prefetch_factor=4
|
||||
```
|
||||
|
||||
**Cause 2**: Small batch size
|
||||
```python
|
||||
# Solution: Increase batch size or use gradient accumulation
|
||||
batch_size=32 # Increase
|
||||
gradient_accumulation_steps=4 # Or accumulate
|
||||
```
|
||||
|
||||
### Issue: High Memory Usage
|
||||
|
||||
**Solution 1**: Gradient checkpointing
|
||||
```python
|
||||
model.gradient_checkpointing_enable()
|
||||
```
|
||||
|
||||
**Solution 2**: Reduce batch size, increase accumulation
|
||||
```python
|
||||
batch_size=8 # Reduce from 32
|
||||
gradient_accumulation_steps=16 # Maintain effective batch
|
||||
```
|
||||
|
||||
**Solution 3**: Use FSDP or DeepSpeed ZeRO-3
|
||||
```python
|
||||
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
|
||||
```
|
||||
|
||||
### Issue: Slow Multi-GPU Training
|
||||
|
||||
**Cause**: Communication bottleneck
|
||||
|
||||
**Check 1**: Gradient bucket size
|
||||
```python
|
||||
ddp_kwargs = DistributedDataParallelKwargs(bucket_cap_mb=100)
|
||||
```
|
||||
|
||||
**Check 2**: NCCL settings
|
||||
```bash
|
||||
export NCCL_DEBUG=INFO
|
||||
# Check for "Using NVLS" (good) vs "Using PHB" (bad)
|
||||
```
|
||||
|
||||
**Check 3**: Network bandwidth
|
||||
```bash
|
||||
# Test inter-GPU bandwidth
|
||||
nvidia-smi nvlink -s
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Accelerate Performance: https://huggingface.co/docs/accelerate/usage_guides/performance
|
||||
- PyTorch Profiler: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
|
||||
- NCCL Tuning: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html
|
||||
- Flash Attention: https://github.com/Dao-AILab/flash-attention
|
||||
564
skills/mlops/audiocraft/SKILL.md
Normal file
564
skills/mlops/audiocraft/SKILL.md
Normal file
@@ -0,0 +1,564 @@
|
||||
---
|
||||
name: audiocraft-audio-generation
|
||||
description: PyTorch library for audio generation including text-to-music (MusicGen) and text-to-sound (AudioGen). Use when you need to generate music from text descriptions, create sound effects, or perform melody-conditioned music generation.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Multimodal, Audio Generation, Text-to-Music, Text-to-Audio, MusicGen]
|
||||
dependencies: [audiocraft, torch>=2.0.0, transformers>=4.30.0]
|
||||
---
|
||||
|
||||
# AudioCraft: Audio Generation
|
||||
|
||||
Comprehensive guide to using Meta's AudioCraft for text-to-music and text-to-audio generation with MusicGen, AudioGen, and EnCodec.
|
||||
|
||||
## When to use AudioCraft
|
||||
|
||||
**Use AudioCraft when:**
|
||||
- Need to generate music from text descriptions
|
||||
- Creating sound effects and environmental audio
|
||||
- Building music generation applications
|
||||
- Need melody-conditioned music generation
|
||||
- Want stereo audio output
|
||||
- Require controllable music generation with style transfer
|
||||
|
||||
**Key features:**
|
||||
- **MusicGen**: Text-to-music generation with melody conditioning
|
||||
- **AudioGen**: Text-to-sound effects generation
|
||||
- **EnCodec**: High-fidelity neural audio codec
|
||||
- **Multiple model sizes**: Small (300M) to Large (3.3B)
|
||||
- **Stereo support**: Full stereo audio generation
|
||||
- **Style conditioning**: MusicGen-Style for reference-based generation
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Stable Audio**: For longer commercial music generation
|
||||
- **Bark**: For text-to-speech with music/sound effects
|
||||
- **Riffusion**: For spectogram-based music generation
|
||||
- **OpenAI Jukebox**: For raw audio generation with lyrics
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# From PyPI
|
||||
pip install audiocraft
|
||||
|
||||
# From GitHub (latest)
|
||||
pip install git+https://github.com/facebookresearch/audiocraft.git
|
||||
|
||||
# Or use HuggingFace Transformers
|
||||
pip install transformers torch torchaudio
|
||||
```
|
||||
|
||||
### Basic text-to-music (AudioCraft)
|
||||
|
||||
```python
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Set generation parameters
|
||||
model.set_generation_params(
|
||||
duration=8, # seconds
|
||||
top_k=250,
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# Generate from text
|
||||
descriptions = ["happy upbeat electronic dance music with synths"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# Save audio
|
||||
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Using HuggingFace Transformers
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
import scipy
|
||||
|
||||
# Load model and processor
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
model.to("cuda")
|
||||
|
||||
# Generate music
|
||||
inputs = processor(
|
||||
text=["80s pop track with bassy drums and synth"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
).to("cuda")
|
||||
|
||||
audio_values = model.generate(
|
||||
**inputs,
|
||||
do_sample=True,
|
||||
guidance_scale=3,
|
||||
max_new_tokens=256
|
||||
)
|
||||
|
||||
# Save
|
||||
sampling_rate = model.config.audio_encoder.sampling_rate
|
||||
scipy.io.wavfile.write("output.wav", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())
|
||||
```
|
||||
|
||||
### Text-to-sound with AudioGen
|
||||
|
||||
```python
|
||||
from audiocraft.models import AudioGen
|
||||
|
||||
# Load AudioGen
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
|
||||
model.set_generation_params(duration=5)
|
||||
|
||||
# Generate sound effects
|
||||
descriptions = ["dog barking in a park with birds chirping"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
torchaudio.save("sound.wav", wav[0].cpu(), sample_rate=16000)
|
||||
```
|
||||
|
||||
## Core concepts
|
||||
|
||||
### Architecture overview
|
||||
|
||||
```
|
||||
AudioCraft Architecture:
|
||||
┌──────────────────────────────────────────────────────────────┐
|
||||
│ Text Encoder (T5) │
|
||||
│ │ │
|
||||
│ Text Embeddings │
|
||||
└────────────────────────┬─────────────────────────────────────┘
|
||||
│
|
||||
┌────────────────────────▼─────────────────────────────────────┐
|
||||
│ Transformer Decoder (LM) │
|
||||
│ Auto-regressively generates audio tokens │
|
||||
│ Using efficient token interleaving patterns │
|
||||
└────────────────────────┬─────────────────────────────────────┘
|
||||
│
|
||||
┌────────────────────────▼─────────────────────────────────────┐
|
||||
│ EnCodec Audio Decoder │
|
||||
│ Converts tokens back to audio waveform │
|
||||
└──────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Model variants
|
||||
|
||||
| Model | Size | Description | Use Case |
|
||||
|-------|------|-------------|----------|
|
||||
| `musicgen-small` | 300M | Text-to-music | Quick generation |
|
||||
| `musicgen-medium` | 1.5B | Text-to-music | Balanced |
|
||||
| `musicgen-large` | 3.3B | Text-to-music | Best quality |
|
||||
| `musicgen-melody` | 1.5B | Text + melody | Melody conditioning |
|
||||
| `musicgen-melody-large` | 3.3B | Text + melody | Best melody |
|
||||
| `musicgen-stereo-*` | Varies | Stereo output | Stereo generation |
|
||||
| `musicgen-style` | 1.5B | Style transfer | Reference-based |
|
||||
| `audiogen-medium` | 1.5B | Text-to-sound | Sound effects |
|
||||
|
||||
### Generation parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `duration` | 8.0 | Length in seconds (1-120) |
|
||||
| `top_k` | 250 | Top-k sampling |
|
||||
| `top_p` | 0.0 | Nucleus sampling (0 = disabled) |
|
||||
| `temperature` | 1.0 | Sampling temperature |
|
||||
| `cfg_coef` | 3.0 | Classifier-free guidance |
|
||||
|
||||
## MusicGen usage
|
||||
|
||||
### Text-to-music generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
import torchaudio
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# Configure generation
|
||||
model.set_generation_params(
|
||||
duration=30, # Up to 30 seconds
|
||||
top_k=250, # Sampling diversity
|
||||
top_p=0.0, # 0 = use top_k only
|
||||
temperature=1.0, # Creativity (higher = more varied)
|
||||
cfg_coef=3.0 # Text adherence (higher = stricter)
|
||||
)
|
||||
|
||||
# Generate multiple samples
|
||||
descriptions = [
|
||||
"epic orchestral soundtrack with strings and brass",
|
||||
"chill lo-fi hip hop beat with jazzy piano",
|
||||
"energetic rock song with electric guitar"
|
||||
]
|
||||
|
||||
# Generate (returns [batch, channels, samples])
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# Save each
|
||||
for i, audio in enumerate(wav):
|
||||
torchaudio.save(f"music_{i}.wav", audio.cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Melody-conditioned generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
import torchaudio
|
||||
|
||||
# Load melody model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
model.set_generation_params(duration=30)
|
||||
|
||||
# Load melody audio
|
||||
melody, sr = torchaudio.load("melody.wav")
|
||||
|
||||
# Generate with melody conditioning
|
||||
descriptions = ["acoustic guitar folk song"]
|
||||
wav = model.generate_with_chroma(descriptions, melody, sr)
|
||||
|
||||
torchaudio.save("melody_conditioned.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Stereo generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load stereo model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
model.set_generation_params(duration=15)
|
||||
|
||||
descriptions = ["ambient electronic music with wide stereo panning"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# wav shape: [batch, 2, samples] for stereo
|
||||
print(f"Stereo shape: {wav.shape}") # [1, 2, 480000]
|
||||
torchaudio.save("stereo.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Audio continuation
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-medium")
|
||||
|
||||
# Load audio to continue
|
||||
import torchaudio
|
||||
audio, sr = torchaudio.load("intro.wav")
|
||||
|
||||
# Process with text and audio
|
||||
inputs = processor(
|
||||
audio=audio.squeeze().numpy(),
|
||||
sampling_rate=sr,
|
||||
text=["continue with a epic chorus"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
# Generate continuation
|
||||
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=512)
|
||||
```
|
||||
|
||||
## MusicGen-Style usage
|
||||
|
||||
### Style-conditioned generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load style model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-style')
|
||||
|
||||
# Configure generation with style
|
||||
model.set_generation_params(
|
||||
duration=30,
|
||||
cfg_coef=3.0,
|
||||
cfg_coef_beta=5.0 # Style influence
|
||||
)
|
||||
|
||||
# Configure style conditioner
|
||||
model.set_style_conditioner_params(
|
||||
eval_q=3, # RVQ quantizers (1-6)
|
||||
excerpt_length=3.0 # Style excerpt length
|
||||
)
|
||||
|
||||
# Load style reference
|
||||
style_audio, sr = torchaudio.load("reference_style.wav")
|
||||
|
||||
# Generate with text + style
|
||||
descriptions = ["upbeat dance track"]
|
||||
wav = model.generate_with_style(descriptions, style_audio, sr)
|
||||
```
|
||||
|
||||
### Style-only generation (no text)
|
||||
|
||||
```python
|
||||
# Generate matching style without text prompt
|
||||
model.set_generation_params(
|
||||
duration=30,
|
||||
cfg_coef=3.0,
|
||||
cfg_coef_beta=None # Disable double CFG for style-only
|
||||
)
|
||||
|
||||
wav = model.generate_with_style([None], style_audio, sr)
|
||||
```
|
||||
|
||||
## AudioGen usage
|
||||
|
||||
### Sound effect generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import AudioGen
|
||||
import torchaudio
|
||||
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Generate various sounds
|
||||
descriptions = [
|
||||
"thunderstorm with heavy rain and lightning",
|
||||
"busy city traffic with car horns",
|
||||
"ocean waves crashing on rocks",
|
||||
"crackling campfire in forest"
|
||||
]
|
||||
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
for i, audio in enumerate(wav):
|
||||
torchaudio.save(f"sound_{i}.wav", audio.cpu(), sample_rate=16000)
|
||||
```
|
||||
|
||||
## EnCodec usage
|
||||
|
||||
### Audio compression
|
||||
|
||||
```python
|
||||
from audiocraft.models import CompressionModel
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
# Load EnCodec
|
||||
model = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
# Load audio
|
||||
wav, sr = torchaudio.load("audio.wav")
|
||||
|
||||
# Ensure correct sample rate
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
wav = resampler(wav)
|
||||
|
||||
# Encode to tokens
|
||||
with torch.no_grad():
|
||||
encoded = model.encode(wav.unsqueeze(0))
|
||||
codes = encoded[0] # Audio codes
|
||||
|
||||
# Decode back to audio
|
||||
with torch.no_grad():
|
||||
decoded = model.decode(codes)
|
||||
|
||||
torchaudio.save("reconstructed.wav", decoded[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Music generation pipeline
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
class MusicGenerator:
|
||||
def __init__(self, model_name="facebook/musicgen-medium"):
|
||||
self.model = MusicGen.get_pretrained(model_name)
|
||||
self.sample_rate = 32000
|
||||
|
||||
def generate(self, prompt, duration=30, temperature=1.0, cfg=3.0):
|
||||
self.model.set_generation_params(
|
||||
duration=duration,
|
||||
top_k=250,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate([prompt])
|
||||
|
||||
return wav[0].cpu()
|
||||
|
||||
def generate_batch(self, prompts, duration=30):
|
||||
self.model.set_generation_params(duration=duration)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate(prompts)
|
||||
|
||||
return wav.cpu()
|
||||
|
||||
def save(self, audio, path):
|
||||
torchaudio.save(path, audio, sample_rate=self.sample_rate)
|
||||
|
||||
# Usage
|
||||
generator = MusicGenerator()
|
||||
audio = generator.generate(
|
||||
"epic cinematic orchestral music",
|
||||
duration=30,
|
||||
temperature=1.0
|
||||
)
|
||||
generator.save(audio, "epic_music.wav")
|
||||
```
|
||||
|
||||
### Workflow 2: Sound design batch processing
|
||||
|
||||
```python
|
||||
import json
|
||||
from pathlib import Path
|
||||
from audiocraft.models import AudioGen
|
||||
import torchaudio
|
||||
|
||||
def batch_generate_sounds(sound_specs, output_dir):
|
||||
"""
|
||||
Generate multiple sounds from specifications.
|
||||
|
||||
Args:
|
||||
sound_specs: list of {"name": str, "description": str, "duration": float}
|
||||
output_dir: output directory path
|
||||
"""
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
results = []
|
||||
|
||||
for spec in sound_specs:
|
||||
model.set_generation_params(duration=spec.get("duration", 5))
|
||||
|
||||
wav = model.generate([spec["description"]])
|
||||
|
||||
output_path = output_dir / f"{spec['name']}.wav"
|
||||
torchaudio.save(str(output_path), wav[0].cpu(), sample_rate=16000)
|
||||
|
||||
results.append({
|
||||
"name": spec["name"],
|
||||
"path": str(output_path),
|
||||
"description": spec["description"]
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
# Usage
|
||||
sounds = [
|
||||
{"name": "explosion", "description": "massive explosion with debris", "duration": 3},
|
||||
{"name": "footsteps", "description": "footsteps on wooden floor", "duration": 5},
|
||||
{"name": "door", "description": "wooden door creaking and closing", "duration": 2}
|
||||
]
|
||||
|
||||
results = batch_generate_sounds(sounds, "sound_effects/")
|
||||
```
|
||||
|
||||
### Workflow 3: Gradio demo
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
def generate_music(prompt, duration, temperature, cfg_coef):
|
||||
model.set_generation_params(
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg_coef
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
|
||||
# Save to temp file
|
||||
path = "temp_output.wav"
|
||||
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
|
||||
return path
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate_music,
|
||||
inputs=[
|
||||
gr.Textbox(label="Music Description", placeholder="upbeat electronic dance music"),
|
||||
gr.Slider(1, 30, value=8, label="Duration (seconds)"),
|
||||
gr.Slider(0.5, 2.0, value=1.0, label="Temperature"),
|
||||
gr.Slider(1.0, 10.0, value=3.0, label="CFG Coefficient")
|
||||
],
|
||||
outputs=gr.Audio(label="Generated Music"),
|
||||
title="MusicGen Demo"
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
## Performance optimization
|
||||
|
||||
### Memory optimization
|
||||
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Clear cache between generations
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Generate shorter durations
|
||||
model.set_generation_params(duration=10) # Instead of 30
|
||||
|
||||
# Use half precision
|
||||
model = model.half()
|
||||
```
|
||||
|
||||
### Batch processing efficiency
|
||||
|
||||
```python
|
||||
# Process multiple prompts at once (more efficient)
|
||||
descriptions = ["prompt1", "prompt2", "prompt3", "prompt4"]
|
||||
wav = model.generate(descriptions) # Single batch
|
||||
|
||||
# Instead of
|
||||
for desc in descriptions:
|
||||
wav = model.generate([desc]) # Multiple batches (slower)
|
||||
```
|
||||
|
||||
### GPU memory requirements
|
||||
|
||||
| Model | FP32 VRAM | FP16 VRAM |
|
||||
|-------|-----------|-----------|
|
||||
| musicgen-small | ~4GB | ~2GB |
|
||||
| musicgen-medium | ~8GB | ~4GB |
|
||||
| musicgen-large | ~16GB | ~8GB |
|
||||
|
||||
## Common issues
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| CUDA OOM | Use smaller model, reduce duration |
|
||||
| Poor quality | Increase cfg_coef, better prompts |
|
||||
| Generation too short | Check max duration setting |
|
||||
| Audio artifacts | Try different temperature |
|
||||
| Stereo not working | Use stereo model variant |
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Training, fine-tuning, deployment
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/facebookresearch/audiocraft
|
||||
- **Paper (MusicGen)**: https://arxiv.org/abs/2306.05284
|
||||
- **Paper (AudioGen)**: https://arxiv.org/abs/2209.15352
|
||||
- **HuggingFace**: https://huggingface.co/facebook/musicgen-small
|
||||
- **Demo**: https://huggingface.co/spaces/facebook/MusicGen
|
||||
666
skills/mlops/audiocraft/references/advanced-usage.md
Normal file
666
skills/mlops/audiocraft/references/advanced-usage.md
Normal file
@@ -0,0 +1,666 @@
|
||||
# AudioCraft Advanced Usage Guide
|
||||
|
||||
## Fine-tuning MusicGen
|
||||
|
||||
### Custom dataset preparation
|
||||
|
||||
```python
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
import torchaudio
|
||||
|
||||
def prepare_dataset(audio_dir, output_dir, metadata_file):
|
||||
"""
|
||||
Prepare dataset for MusicGen fine-tuning.
|
||||
|
||||
Directory structure:
|
||||
output_dir/
|
||||
├── audio/
|
||||
│ ├── 0001.wav
|
||||
│ ├── 0002.wav
|
||||
│ └── ...
|
||||
└── metadata.json
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
audio_output = output_dir / "audio"
|
||||
audio_output.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load metadata (format: {"path": "...", "description": "..."})
|
||||
with open(metadata_file) as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
processed = []
|
||||
|
||||
for idx, item in enumerate(metadata):
|
||||
audio_path = Path(audio_dir) / item["path"]
|
||||
|
||||
# Load and resample to 32kHz
|
||||
wav, sr = torchaudio.load(str(audio_path))
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
wav = resampler(wav)
|
||||
|
||||
# Convert to mono if stereo
|
||||
if wav.shape[0] > 1:
|
||||
wav = wav.mean(dim=0, keepdim=True)
|
||||
|
||||
# Save processed audio
|
||||
output_path = audio_output / f"{idx:04d}.wav"
|
||||
torchaudio.save(str(output_path), wav, sample_rate=32000)
|
||||
|
||||
processed.append({
|
||||
"path": str(output_path.relative_to(output_dir)),
|
||||
"description": item["description"],
|
||||
"duration": wav.shape[1] / 32000
|
||||
})
|
||||
|
||||
# Save processed metadata
|
||||
with open(output_dir / "metadata.json", "w") as f:
|
||||
json.dump(processed, f, indent=2)
|
||||
|
||||
print(f"Processed {len(processed)} samples")
|
||||
return processed
|
||||
```
|
||||
|
||||
### Fine-tuning with dora
|
||||
|
||||
```bash
|
||||
# AudioCraft uses dora for experiment management
|
||||
# Install dora
|
||||
pip install dora-search
|
||||
|
||||
# Clone AudioCraft
|
||||
git clone https://github.com/facebookresearch/audiocraft.git
|
||||
cd audiocraft
|
||||
|
||||
# Create config for fine-tuning
|
||||
cat > config/solver/musicgen/finetune.yaml << 'EOF'
|
||||
defaults:
|
||||
- musicgen/musicgen_base
|
||||
- /model: lm/musicgen_lm
|
||||
- /conditioner: cond_base
|
||||
|
||||
solver: musicgen
|
||||
autocast: true
|
||||
autocast_dtype: float16
|
||||
|
||||
optim:
|
||||
epochs: 100
|
||||
batch_size: 4
|
||||
lr: 1e-4
|
||||
ema: 0.999
|
||||
optimizer: adamw
|
||||
|
||||
dataset:
|
||||
batch_size: 4
|
||||
num_workers: 4
|
||||
train:
|
||||
- dset: your_dataset
|
||||
root: /path/to/dataset
|
||||
valid:
|
||||
- dset: your_dataset
|
||||
root: /path/to/dataset
|
||||
|
||||
checkpoint:
|
||||
save_every: 10
|
||||
keep_every_states: null
|
||||
EOF
|
||||
|
||||
# Run fine-tuning
|
||||
dora run solver=musicgen/finetune
|
||||
```
|
||||
|
||||
### LoRA fine-tuning
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from audiocraft.models import MusicGen
|
||||
import torch
|
||||
|
||||
# Load base model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Get the language model component
|
||||
lm = model.lm
|
||||
|
||||
# Configure LoRA
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
lora_alpha=16,
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "out_proj"],
|
||||
lora_dropout=0.05,
|
||||
bias="none"
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
lm = get_peft_model(lm, lora_config)
|
||||
lm.print_trainable_parameters()
|
||||
```
|
||||
|
||||
## Multi-GPU Training
|
||||
|
||||
### DataParallel
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Wrap LM with DataParallel
|
||||
if torch.cuda.device_count() > 1:
|
||||
model.lm = nn.DataParallel(model.lm)
|
||||
|
||||
model.to("cuda")
|
||||
```
|
||||
|
||||
### DistributedDataParallel
|
||||
|
||||
```python
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
def setup(rank, world_size):
|
||||
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
def train(rank, world_size):
|
||||
setup(rank, world_size)
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.lm = model.lm.to(rank)
|
||||
model.lm = DDP(model.lm, device_ids=[rank])
|
||||
|
||||
# Training loop
|
||||
# ...
|
||||
|
||||
dist.destroy_process_group()
|
||||
```
|
||||
|
||||
## Custom Conditioning
|
||||
|
||||
### Adding new conditioners
|
||||
|
||||
```python
|
||||
from audiocraft.modules.conditioners import BaseConditioner
|
||||
import torch
|
||||
|
||||
class CustomConditioner(BaseConditioner):
|
||||
"""Custom conditioner for additional control signals."""
|
||||
|
||||
def __init__(self, dim, output_dim):
|
||||
super().__init__(dim, output_dim)
|
||||
self.embed = torch.nn.Linear(dim, output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
return self.embed(x)
|
||||
|
||||
def tokenize(self, x):
|
||||
# Tokenize input for conditioning
|
||||
return x
|
||||
|
||||
# Use with MusicGen
|
||||
from audiocraft.models.builders import get_lm_model
|
||||
|
||||
# Modify model config to include custom conditioner
|
||||
# This requires editing the model configuration
|
||||
```
|
||||
|
||||
### Melody conditioning internals
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider
|
||||
import torch
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# Access chroma extractor
|
||||
chroma_extractor = model.lm.condition_provider.conditioners.get('chroma')
|
||||
|
||||
# Manual chroma extraction
|
||||
def extract_chroma(audio, sr):
|
||||
"""Extract chroma features from audio."""
|
||||
import librosa
|
||||
|
||||
# Compute chroma
|
||||
chroma = librosa.feature.chroma_cqt(y=audio.numpy(), sr=sr)
|
||||
|
||||
return torch.from_numpy(chroma).float()
|
||||
|
||||
# Use extracted chroma for conditioning
|
||||
chroma = extract_chroma(melody_audio, sample_rate)
|
||||
```
|
||||
|
||||
## EnCodec Deep Dive
|
||||
|
||||
### Custom compression settings
|
||||
|
||||
```python
|
||||
from audiocraft.models import CompressionModel
|
||||
import torch
|
||||
|
||||
# Load EnCodec
|
||||
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
# Access codec parameters
|
||||
print(f"Sample rate: {encodec.sample_rate}")
|
||||
print(f"Channels: {encodec.channels}")
|
||||
print(f"Cardinality: {encodec.cardinality}") # Codebook size
|
||||
print(f"Num codebooks: {encodec.num_codebooks}")
|
||||
print(f"Frame rate: {encodec.frame_rate}")
|
||||
|
||||
# Encode with specific bandwidth
|
||||
# Lower bandwidth = more compression, lower quality
|
||||
encodec.set_target_bandwidth(6.0) # 6 kbps
|
||||
|
||||
audio = torch.randn(1, 1, 32000) # 1 second
|
||||
encoded = encodec.encode(audio)
|
||||
decoded = encodec.decode(encoded[0])
|
||||
```
|
||||
|
||||
### Streaming encoding
|
||||
|
||||
```python
|
||||
import torch
|
||||
from audiocraft.models import CompressionModel
|
||||
|
||||
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
def encode_streaming(audio_stream, chunk_size=32000):
|
||||
"""Encode audio in streaming fashion."""
|
||||
all_codes = []
|
||||
|
||||
for chunk in audio_stream:
|
||||
# Ensure chunk is right shape
|
||||
if chunk.dim() == 1:
|
||||
chunk = chunk.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
codes = encodec.encode(chunk)[0]
|
||||
all_codes.append(codes)
|
||||
|
||||
return torch.cat(all_codes, dim=-1)
|
||||
|
||||
def decode_streaming(codes_stream, output_stream):
|
||||
"""Decode codes in streaming fashion."""
|
||||
for codes in codes_stream:
|
||||
with torch.no_grad():
|
||||
audio = encodec.decode(codes)
|
||||
output_stream.write(audio.cpu().numpy())
|
||||
```
|
||||
|
||||
## MultiBand Diffusion
|
||||
|
||||
### Using MBD for enhanced quality
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen, MultiBandDiffusion
|
||||
|
||||
# Load MusicGen
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# Load MultiBand Diffusion
|
||||
mbd = MultiBandDiffusion.get_mbd_musicgen()
|
||||
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Generate with standard decoder
|
||||
descriptions = ["epic orchestral music"]
|
||||
wav_standard = model.generate(descriptions)
|
||||
|
||||
# Generate tokens and use MBD decoder
|
||||
with torch.no_grad():
|
||||
# Get tokens
|
||||
gen_tokens = model.generate_tokens(descriptions)
|
||||
|
||||
# Decode with MBD
|
||||
wav_mbd = mbd.tokens_to_wav(gen_tokens)
|
||||
|
||||
# Compare quality
|
||||
print(f"Standard shape: {wav_standard.shape}")
|
||||
print(f"MBD shape: {wav_mbd.shape}")
|
||||
```
|
||||
|
||||
## API Server Deployment
|
||||
|
||||
### FastAPI server
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
import io
|
||||
import base64
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Load model at startup
|
||||
model = None
|
||||
|
||||
@app.on_event("startup")
|
||||
async def load_model():
|
||||
global model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
prompt: str
|
||||
duration: float = 10.0
|
||||
temperature: float = 1.0
|
||||
cfg_coef: float = 3.0
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
audio_base64: str
|
||||
sample_rate: int
|
||||
duration: float
|
||||
|
||||
@app.post("/generate", response_model=GenerateResponse)
|
||||
async def generate(request: GenerateRequest):
|
||||
if model is None:
|
||||
raise HTTPException(status_code=500, detail="Model not loaded")
|
||||
|
||||
try:
|
||||
model.set_generation_params(
|
||||
duration=min(request.duration, 30),
|
||||
temperature=request.temperature,
|
||||
cfg_coef=request.cfg_coef
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([request.prompt])
|
||||
|
||||
# Convert to bytes
|
||||
buffer = io.BytesIO()
|
||||
torchaudio.save(buffer, wav[0].cpu(), sample_rate=32000, format="wav")
|
||||
buffer.seek(0)
|
||||
|
||||
audio_base64 = base64.b64encode(buffer.read()).decode()
|
||||
|
||||
return GenerateResponse(
|
||||
audio_base64=audio_base64,
|
||||
sample_rate=32000,
|
||||
duration=wav.shape[-1] / 32000
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok", "model_loaded": model is not None}
|
||||
|
||||
# Run: uvicorn server:app --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
### Batch processing service
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import torch
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
class MusicGenService:
|
||||
def __init__(self, model_name='facebook/musicgen-small', max_workers=2):
|
||||
self.model = MusicGen.get_pretrained(model_name)
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def generate_async(self, prompt, duration=10):
|
||||
"""Async generation with thread pool."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def _generate():
|
||||
with torch.no_grad():
|
||||
self.model.set_generation_params(duration=duration)
|
||||
return self.model.generate([prompt])
|
||||
|
||||
# Run in thread pool
|
||||
wav = await loop.run_in_executor(self.executor, _generate)
|
||||
return wav[0].cpu()
|
||||
|
||||
async def generate_batch_async(self, prompts, duration=10):
|
||||
"""Process multiple prompts concurrently."""
|
||||
tasks = [self.generate_async(p, duration) for p in prompts]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
# Usage
|
||||
service = MusicGenService()
|
||||
|
||||
async def main():
|
||||
prompts = ["jazz piano", "rock guitar", "electronic beats"]
|
||||
results = await service.generate_batch_async(prompts)
|
||||
return results
|
||||
```
|
||||
|
||||
## Integration Patterns
|
||||
|
||||
### LangChain tool
|
||||
|
||||
```python
|
||||
from langchain.tools import BaseTool
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
import tempfile
|
||||
|
||||
class MusicGeneratorTool(BaseTool):
|
||||
name = "music_generator"
|
||||
description = "Generate music from a text description. Input should be a detailed description of the music style, mood, and instruments."
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
self.model.set_generation_params(duration=15)
|
||||
|
||||
def _run(self, description: str) -> str:
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate([description])
|
||||
|
||||
# Save to temp file
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||
torchaudio.save(f.name, wav[0].cpu(), sample_rate=32000)
|
||||
return f"Generated music saved to: {f.name}"
|
||||
|
||||
async def _arun(self, description: str) -> str:
|
||||
return self._run(description)
|
||||
```
|
||||
|
||||
### Gradio with advanced controls
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
models = {}
|
||||
|
||||
def load_model(model_size):
|
||||
if model_size not in models:
|
||||
model_name = f"facebook/musicgen-{model_size}"
|
||||
models[model_size] = MusicGen.get_pretrained(model_name)
|
||||
return models[model_size]
|
||||
|
||||
def generate(prompt, duration, temperature, cfg_coef, top_k, model_size):
|
||||
model = load_model(model_size)
|
||||
|
||||
model.set_generation_params(
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg_coef,
|
||||
top_k=top_k
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
|
||||
# Save
|
||||
path = "output.wav"
|
||||
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
|
||||
return path
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate,
|
||||
inputs=[
|
||||
gr.Textbox(label="Prompt", lines=3),
|
||||
gr.Slider(1, 30, value=10, label="Duration (s)"),
|
||||
gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
|
||||
gr.Slider(0.5, 10.0, value=3.0, label="CFG Coefficient"),
|
||||
gr.Slider(50, 500, value=250, step=50, label="Top-K"),
|
||||
gr.Dropdown(["small", "medium", "large"], value="small", label="Model Size")
|
||||
],
|
||||
outputs=gr.Audio(label="Generated Music"),
|
||||
title="MusicGen Advanced",
|
||||
allow_flagging="never"
|
||||
)
|
||||
|
||||
demo.launch(share=True)
|
||||
```
|
||||
|
||||
## Audio Processing Pipeline
|
||||
|
||||
### Post-processing chain
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchaudio.transforms as T
|
||||
import numpy as np
|
||||
|
||||
class AudioPostProcessor:
|
||||
def __init__(self, sample_rate=32000):
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
def normalize(self, audio, target_db=-14.0):
|
||||
"""Normalize audio to target loudness."""
|
||||
rms = torch.sqrt(torch.mean(audio ** 2))
|
||||
target_rms = 10 ** (target_db / 20)
|
||||
gain = target_rms / (rms + 1e-8)
|
||||
return audio * gain
|
||||
|
||||
def fade_in_out(self, audio, fade_duration=0.1):
|
||||
"""Apply fade in/out."""
|
||||
fade_samples = int(fade_duration * self.sample_rate)
|
||||
|
||||
# Create fade curves
|
||||
fade_in = torch.linspace(0, 1, fade_samples)
|
||||
fade_out = torch.linspace(1, 0, fade_samples)
|
||||
|
||||
# Apply fades
|
||||
audio[..., :fade_samples] *= fade_in
|
||||
audio[..., -fade_samples:] *= fade_out
|
||||
|
||||
return audio
|
||||
|
||||
def apply_reverb(self, audio, decay=0.5):
|
||||
"""Apply simple reverb effect."""
|
||||
impulse = torch.zeros(int(self.sample_rate * 0.5))
|
||||
impulse[0] = 1.0
|
||||
impulse[int(self.sample_rate * 0.1)] = decay * 0.5
|
||||
impulse[int(self.sample_rate * 0.2)] = decay * 0.25
|
||||
|
||||
# Convolve
|
||||
audio = torch.nn.functional.conv1d(
|
||||
audio.unsqueeze(0),
|
||||
impulse.unsqueeze(0).unsqueeze(0),
|
||||
padding=len(impulse) // 2
|
||||
).squeeze(0)
|
||||
|
||||
return audio
|
||||
|
||||
def process(self, audio):
|
||||
"""Full processing pipeline."""
|
||||
audio = self.normalize(audio)
|
||||
audio = self.fade_in_out(audio)
|
||||
return audio
|
||||
|
||||
# Usage with MusicGen
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
wav = model.generate(["chill ambient music"])
|
||||
processor = AudioPostProcessor()
|
||||
wav_processed = processor.process(wav[0].cpu())
|
||||
|
||||
torchaudio.save("processed.wav", wav_processed, sample_rate=32000)
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Audio quality metrics
|
||||
|
||||
```python
|
||||
import torch
|
||||
from audiocraft.metrics import CLAPTextConsistencyMetric
|
||||
from audiocraft.data.audio import audio_read
|
||||
|
||||
def evaluate_generation(audio_path, text_prompt):
|
||||
"""Evaluate generated audio quality."""
|
||||
# Load audio
|
||||
wav, sr = audio_read(audio_path)
|
||||
|
||||
# CLAP consistency (text-audio alignment)
|
||||
clap_metric = CLAPTextConsistencyMetric()
|
||||
clap_score = clap_metric.compute(wav, [text_prompt])
|
||||
|
||||
return {
|
||||
"clap_score": clap_score,
|
||||
"duration": wav.shape[-1] / sr
|
||||
}
|
||||
|
||||
# Batch evaluation
|
||||
def evaluate_batch(generations):
|
||||
"""Evaluate multiple generations."""
|
||||
results = []
|
||||
for gen in generations:
|
||||
result = evaluate_generation(gen["path"], gen["prompt"])
|
||||
result["prompt"] = gen["prompt"]
|
||||
results.append(result)
|
||||
|
||||
# Aggregate
|
||||
avg_clap = sum(r["clap_score"] for r in results) / len(results)
|
||||
return {
|
||||
"individual": results,
|
||||
"average_clap": avg_clap
|
||||
}
|
||||
```
|
||||
|
||||
## Model Comparison
|
||||
|
||||
### MusicGen variants benchmark
|
||||
|
||||
| Model | CLAP Score | Generation Time (10s) | VRAM |
|
||||
|-------|------------|----------------------|------|
|
||||
| musicgen-small | 0.35 | ~5s | 2GB |
|
||||
| musicgen-medium | 0.42 | ~15s | 4GB |
|
||||
| musicgen-large | 0.48 | ~30s | 8GB |
|
||||
| musicgen-melody | 0.45 | ~15s | 4GB |
|
||||
| musicgen-stereo-medium | 0.41 | ~18s | 5GB |
|
||||
|
||||
### Prompt engineering tips
|
||||
|
||||
```python
|
||||
# Good prompts - specific and descriptive
|
||||
good_prompts = [
|
||||
"upbeat electronic dance music with synthesizer leads and punchy drums at 128 bpm",
|
||||
"melancholic piano ballad with strings, slow tempo, emotional and cinematic",
|
||||
"funky disco groove with slap bass, brass section, and rhythmic guitar"
|
||||
]
|
||||
|
||||
# Bad prompts - too vague
|
||||
bad_prompts = [
|
||||
"nice music",
|
||||
"song",
|
||||
"good beat"
|
||||
]
|
||||
|
||||
# Structure: [mood] [genre] with [instruments] at [tempo/style]
|
||||
```
|
||||
504
skills/mlops/audiocraft/references/troubleshooting.md
Normal file
504
skills/mlops/audiocraft/references/troubleshooting.md
Normal file
@@ -0,0 +1,504 @@
|
||||
# AudioCraft Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### Import errors
|
||||
|
||||
**Error**: `ModuleNotFoundError: No module named 'audiocraft'`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install from PyPI
|
||||
pip install audiocraft
|
||||
|
||||
# Or from GitHub
|
||||
pip install git+https://github.com/facebookresearch/audiocraft.git
|
||||
|
||||
# Verify installation
|
||||
python -c "from audiocraft.models import MusicGen; print('OK')"
|
||||
```
|
||||
|
||||
### FFmpeg not found
|
||||
|
||||
**Error**: `RuntimeError: ffmpeg not found`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt-get install ffmpeg
|
||||
|
||||
# macOS
|
||||
brew install ffmpeg
|
||||
|
||||
# Windows (using conda)
|
||||
conda install -c conda-forge ffmpeg
|
||||
|
||||
# Verify
|
||||
ffmpeg -version
|
||||
```
|
||||
|
||||
### PyTorch CUDA mismatch
|
||||
|
||||
**Error**: `RuntimeError: CUDA error: no kernel image is available`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check CUDA version
|
||||
nvcc --version
|
||||
python -c "import torch; print(torch.version.cuda)"
|
||||
|
||||
# Install matching PyTorch
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# For CUDA 11.8
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
### xformers issues
|
||||
|
||||
**Error**: `ImportError: xformers` related errors
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install xformers for memory efficiency
|
||||
pip install xformers
|
||||
|
||||
# Or disable xformers
|
||||
export AUDIOCRAFT_USE_XFORMERS=0
|
||||
|
||||
# In Python
|
||||
import os
|
||||
os.environ["AUDIOCRAFT_USE_XFORMERS"] = "0"
|
||||
from audiocraft.models import MusicGen
|
||||
```
|
||||
|
||||
## Model Loading Issues
|
||||
|
||||
### Out of memory during load
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError` during model loading
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Force CPU loading first
|
||||
import torch
|
||||
device = "cpu"
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small', device=device)
|
||||
model = model.to("cuda")
|
||||
|
||||
# Use HuggingFace with device_map
|
||||
from transformers import MusicgenForConditionalGeneration
|
||||
model = MusicgenForConditionalGeneration.from_pretrained(
|
||||
"facebook/musicgen-small",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
### Download failures
|
||||
|
||||
**Error**: Connection errors or incomplete downloads
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Set cache directory
|
||||
import os
|
||||
os.environ["AUDIOCRAFT_CACHE_DIR"] = "/path/to/cache"
|
||||
|
||||
# Or for HuggingFace
|
||||
os.environ["HF_HOME"] = "/path/to/hf_cache"
|
||||
|
||||
# Resume download
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download("facebook/musicgen-small", resume_download=True)
|
||||
|
||||
# Use local files
|
||||
model = MusicGen.get_pretrained('/local/path/to/model')
|
||||
```
|
||||
|
||||
### Wrong model type
|
||||
|
||||
**Error**: Loading wrong model for task
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# For text-to-music: use MusicGen
|
||||
from audiocraft.models import MusicGen
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# For text-to-sound: use AudioGen
|
||||
from audiocraft.models import AudioGen
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
|
||||
# For melody conditioning: use melody variant
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# For stereo: use stereo variant
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
```
|
||||
|
||||
## Generation Issues
|
||||
|
||||
### Empty or silent output
|
||||
|
||||
**Problem**: Generated audio is silent or very quiet
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check output
|
||||
wav = model.generate(["upbeat music"])
|
||||
print(f"Shape: {wav.shape}")
|
||||
print(f"Max amplitude: {wav.abs().max().item()}")
|
||||
print(f"Mean amplitude: {wav.abs().mean().item()}")
|
||||
|
||||
# If too quiet, normalize
|
||||
def normalize_audio(audio, target_db=-14.0):
|
||||
rms = torch.sqrt(torch.mean(audio ** 2))
|
||||
target_rms = 10 ** (target_db / 20)
|
||||
gain = target_rms / (rms + 1e-8)
|
||||
return audio * gain
|
||||
|
||||
wav_normalized = normalize_audio(wav)
|
||||
```
|
||||
|
||||
### Poor quality output
|
||||
|
||||
**Problem**: Generated music sounds bad or noisy
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use larger model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-large')
|
||||
|
||||
# Adjust generation parameters
|
||||
model.set_generation_params(
|
||||
duration=15,
|
||||
top_k=250, # Increase for more diversity
|
||||
temperature=0.8, # Lower for more focused output
|
||||
cfg_coef=4.0 # Increase for better text adherence
|
||||
)
|
||||
|
||||
# Use better prompts
|
||||
# Bad: "music"
|
||||
# Good: "upbeat electronic dance music with synthesizers and punchy drums"
|
||||
|
||||
# Try MultiBand Diffusion
|
||||
from audiocraft.models import MultiBandDiffusion
|
||||
mbd = MultiBandDiffusion.get_mbd_musicgen()
|
||||
tokens = model.generate_tokens(["prompt"])
|
||||
wav = mbd.tokens_to_wav(tokens)
|
||||
```
|
||||
|
||||
### Generation too short
|
||||
|
||||
**Problem**: Audio shorter than expected
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check duration setting
|
||||
model.set_generation_params(duration=30) # Set before generate
|
||||
|
||||
# Verify in generation
|
||||
print(f"Duration setting: {model.generation_params}")
|
||||
|
||||
# Check output shape
|
||||
wav = model.generate(["prompt"])
|
||||
actual_duration = wav.shape[-1] / 32000
|
||||
print(f"Actual duration: {actual_duration}s")
|
||||
|
||||
# Note: max duration is typically 30s
|
||||
```
|
||||
|
||||
### Melody conditioning fails
|
||||
|
||||
**Error**: Issues with melody-conditioned generation
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load melody model (not base model)
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# Load and prepare melody
|
||||
melody, sr = torchaudio.load("melody.wav")
|
||||
|
||||
# Resample to model sample rate if needed
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
melody = resampler(melody)
|
||||
|
||||
# Ensure correct shape [batch, channels, samples]
|
||||
if melody.dim() == 1:
|
||||
melody = melody.unsqueeze(0).unsqueeze(0)
|
||||
elif melody.dim() == 2:
|
||||
melody = melody.unsqueeze(0)
|
||||
|
||||
# Convert stereo to mono
|
||||
if melody.shape[1] > 1:
|
||||
melody = melody.mean(dim=1, keepdim=True)
|
||||
|
||||
# Generate with melody
|
||||
model.set_generation_params(duration=min(melody.shape[-1] / 32000, 30))
|
||||
wav = model.generate_with_chroma(["piano cover"], melody, 32000)
|
||||
```
|
||||
|
||||
## Memory Issues
|
||||
|
||||
### CUDA out of memory
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Clear cache before generation
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Reduce duration
|
||||
model.set_generation_params(duration=10) # Instead of 30
|
||||
|
||||
# Generate one at a time
|
||||
for prompt in prompts:
|
||||
wav = model.generate([prompt])
|
||||
save_audio(wav)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Use CPU for very large generations
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small', device="cpu")
|
||||
```
|
||||
|
||||
### Memory leak during batch processing
|
||||
|
||||
**Problem**: Memory grows over time
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import gc
|
||||
import torch
|
||||
|
||||
def generate_with_cleanup(model, prompts):
|
||||
results = []
|
||||
|
||||
for prompt in prompts:
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
results.append(wav.cpu())
|
||||
|
||||
# Cleanup
|
||||
del wav
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return results
|
||||
|
||||
# Use context manager
|
||||
with torch.inference_mode():
|
||||
wav = model.generate(["prompt"])
|
||||
```
|
||||
|
||||
## Audio Format Issues
|
||||
|
||||
### Wrong sample rate
|
||||
|
||||
**Problem**: Audio plays at wrong speed
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torchaudio
|
||||
|
||||
# MusicGen outputs at 32kHz
|
||||
sample_rate = 32000
|
||||
|
||||
# AudioGen outputs at 16kHz
|
||||
sample_rate = 16000
|
||||
|
||||
# Always use correct rate when saving
|
||||
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=sample_rate)
|
||||
|
||||
# Resample if needed
|
||||
resampler = torchaudio.transforms.Resample(32000, 44100)
|
||||
wav_resampled = resampler(wav)
|
||||
```
|
||||
|
||||
### Stereo/mono mismatch
|
||||
|
||||
**Problem**: Wrong number of channels
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check model type
|
||||
print(f"Audio channels: {wav.shape}")
|
||||
# Mono: [batch, 1, samples]
|
||||
# Stereo: [batch, 2, samples]
|
||||
|
||||
# Convert mono to stereo
|
||||
if wav.shape[1] == 1:
|
||||
wav_stereo = wav.repeat(1, 2, 1)
|
||||
|
||||
# Convert stereo to mono
|
||||
if wav.shape[1] == 2:
|
||||
wav_mono = wav.mean(dim=1, keepdim=True)
|
||||
|
||||
# Use stereo model for stereo output
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
```
|
||||
|
||||
### Clipping and distortion
|
||||
|
||||
**Problem**: Audio has clipping or distortion
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check for clipping
|
||||
max_val = wav.abs().max().item()
|
||||
print(f"Max amplitude: {max_val}")
|
||||
|
||||
# Normalize to prevent clipping
|
||||
if max_val > 1.0:
|
||||
wav = wav / max_val
|
||||
|
||||
# Apply soft clipping
|
||||
def soft_clip(x, threshold=0.9):
|
||||
return torch.tanh(x / threshold) * threshold
|
||||
|
||||
wav_clipped = soft_clip(wav)
|
||||
|
||||
# Lower temperature during generation
|
||||
model.set_generation_params(temperature=0.7) # More controlled
|
||||
```
|
||||
|
||||
## HuggingFace Transformers Issues
|
||||
|
||||
### Processor errors
|
||||
|
||||
**Error**: Issues with MusicgenProcessor
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
# Load matching processor and model
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
|
||||
# Ensure inputs are on same device
|
||||
inputs = processor(
|
||||
text=["prompt"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
).to("cuda")
|
||||
|
||||
# Check processor configuration
|
||||
print(processor.tokenizer)
|
||||
print(processor.feature_extractor)
|
||||
```
|
||||
|
||||
### Generation parameter errors
|
||||
|
||||
**Error**: Invalid generation parameters
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# HuggingFace uses different parameter names
|
||||
audio_values = model.generate(
|
||||
**inputs,
|
||||
do_sample=True, # Enable sampling
|
||||
guidance_scale=3.0, # CFG (not cfg_coef)
|
||||
max_new_tokens=256, # Token limit (not duration)
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# Calculate tokens from duration
|
||||
# ~50 tokens per second
|
||||
duration_seconds = 10
|
||||
max_tokens = duration_seconds * 50
|
||||
audio_values = model.generate(**inputs, max_new_tokens=max_tokens)
|
||||
```
|
||||
|
||||
## Performance Issues
|
||||
|
||||
### Slow generation
|
||||
|
||||
**Problem**: Generation takes too long
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Reduce duration
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Use GPU
|
||||
model.to("cuda")
|
||||
|
||||
# Enable flash attention if available
|
||||
# (requires compatible hardware)
|
||||
|
||||
# Batch multiple prompts
|
||||
prompts = ["prompt1", "prompt2", "prompt3"]
|
||||
wav = model.generate(prompts) # Single batch is faster than loop
|
||||
|
||||
# Use compile (PyTorch 2.0+)
|
||||
model.lm = torch.compile(model.lm)
|
||||
```
|
||||
|
||||
### CPU fallback
|
||||
|
||||
**Problem**: Generation running on CPU instead of GPU
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check CUDA availability
|
||||
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
|
||||
|
||||
# Explicitly move to GPU
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.to("cuda")
|
||||
|
||||
# Verify model device
|
||||
print(f"Model device: {next(model.lm.parameters()).device}")
|
||||
```
|
||||
|
||||
## Common Error Messages
|
||||
|
||||
| Error | Cause | Solution |
|
||||
|-------|-------|----------|
|
||||
| `CUDA out of memory` | Model too large | Use smaller model, reduce duration |
|
||||
| `ffmpeg not found` | FFmpeg not installed | Install FFmpeg |
|
||||
| `No module named 'audiocraft'` | Not installed | `pip install audiocraft` |
|
||||
| `RuntimeError: Expected 3D tensor` | Wrong input shape | Check tensor dimensions |
|
||||
| `KeyError: 'melody'` | Wrong model for melody | Use musicgen-melody |
|
||||
| `Sample rate mismatch` | Wrong audio format | Resample to model rate |
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **GitHub Issues**: https://github.com/facebookresearch/audiocraft/issues
|
||||
2. **HuggingFace Forums**: https://discuss.huggingface.co
|
||||
3. **Paper**: https://arxiv.org/abs/2306.05284
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Include:
|
||||
- Python version
|
||||
- PyTorch version
|
||||
- CUDA version
|
||||
- AudioCraft version: `pip show audiocraft`
|
||||
- Full error traceback
|
||||
- Minimal reproducible code
|
||||
- Hardware (GPU model, VRAM)
|
||||
158
skills/mlops/axolotl/SKILL.md
Normal file
158
skills/mlops/axolotl/SKILL.md
Normal file
@@ -0,0 +1,158 @@
|
||||
---
|
||||
name: axolotl
|
||||
description: Expert guidance for fine-tuning LLMs with Axolotl - YAML configs, 100+ models, LoRA/QLoRA, DPO/KTO/ORPO/GRPO, multimodal support
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Fine-Tuning, Axolotl, LLM, LoRA, QLoRA, DPO, KTO, ORPO, GRPO, YAML, HuggingFace, DeepSpeed, Multimodal]
|
||||
dependencies: [axolotl, torch, transformers, datasets, peft, accelerate, deepspeed]
|
||||
---
|
||||
|
||||
# Axolotl Skill
|
||||
|
||||
Comprehensive assistance with axolotl development, generated from official documentation.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
This skill should be triggered when:
|
||||
- Working with axolotl
|
||||
- Asking about axolotl features or APIs
|
||||
- Implementing axolotl solutions
|
||||
- Debugging axolotl code
|
||||
- Learning axolotl best practices
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Common Patterns
|
||||
|
||||
**Pattern 1:** To validate that acceptable data transfer speeds exist for your training job, running NCCL Tests can help pinpoint bottlenecks, for example:
|
||||
|
||||
```
|
||||
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
|
||||
```
|
||||
|
||||
**Pattern 2:** Configure your model to use FSDP in the Axolotl yaml. For example:
|
||||
|
||||
```
|
||||
fsdp_version: 2
|
||||
fsdp_config:
|
||||
offload_params: true
|
||||
state_dict_type: FULL_STATE_DICT
|
||||
auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
transformer_layer_cls_to_wrap: LlamaDecoderLayer
|
||||
reshard_after_forward: true
|
||||
```
|
||||
|
||||
**Pattern 3:** The context_parallel_size should be a divisor of the total number of GPUs. For example:
|
||||
|
||||
```
|
||||
context_parallel_size
|
||||
```
|
||||
|
||||
**Pattern 4:** For example: - With 8 GPUs and no sequence parallelism: 8 different batches processed per step - With 8 GPUs and context_parallel_size=4: Only 2 different batches processed per step (each split across 4 GPUs) - If your per-GPU micro_batch_size is 2, the global batch size decreases from 16 to 4
|
||||
|
||||
```
|
||||
context_parallel_size=4
|
||||
```
|
||||
|
||||
**Pattern 5:** Setting save_compressed: true in your configuration enables saving models in a compressed format, which: - Reduces disk space usage by approximately 40% - Maintains compatibility with vLLM for accelerated inference - Maintains compatibility with llmcompressor for further optimization (example: quantization)
|
||||
|
||||
```
|
||||
save_compressed: true
|
||||
```
|
||||
|
||||
**Pattern 6:** Note It is not necessary to place your integration in the integrations folder. It can be in any location, so long as it’s installed in a package in your python env. See this repo for an example: https://github.com/axolotl-ai-cloud/diff-transformer
|
||||
|
||||
```
|
||||
integrations
|
||||
```
|
||||
|
||||
**Pattern 7:** Handle both single-example and batched data. - single example: sample[‘input_ids’] is a list[int] - batched data: sample[‘input_ids’] is a list[list[int]]
|
||||
|
||||
```
|
||||
utils.trainer.drop_long_seq(sample, sequence_len=2048, min_sequence_len=2)
|
||||
```
|
||||
|
||||
### Example Code Patterns
|
||||
|
||||
**Example 1** (python):
|
||||
```python
|
||||
cli.cloud.modal_.ModalCloud(config, app=None)
|
||||
```
|
||||
|
||||
**Example 2** (python):
|
||||
```python
|
||||
cli.cloud.modal_.run_cmd(cmd, run_folder, volumes=None)
|
||||
```
|
||||
|
||||
**Example 3** (python):
|
||||
```python
|
||||
core.trainers.base.AxolotlTrainer(
|
||||
*_args,
|
||||
bench_data_collator=None,
|
||||
eval_data_collator=None,
|
||||
dataset_tags=None,
|
||||
**kwargs,
|
||||
)
|
||||
```
|
||||
|
||||
**Example 4** (python):
|
||||
```python
|
||||
core.trainers.base.AxolotlTrainer.log(logs, start_time=None)
|
||||
```
|
||||
|
||||
**Example 5** (python):
|
||||
```python
|
||||
prompt_strategies.input_output.RawInputOutputPrompter()
|
||||
```
|
||||
|
||||
## Reference Files
|
||||
|
||||
This skill includes comprehensive documentation in `references/`:
|
||||
|
||||
- **api.md** - Api documentation
|
||||
- **dataset-formats.md** - Dataset-Formats documentation
|
||||
- **other.md** - Other documentation
|
||||
|
||||
Use `view` to read specific reference files when detailed information is needed.
|
||||
|
||||
## Working with This Skill
|
||||
|
||||
### For Beginners
|
||||
Start with the getting_started or tutorials reference files for foundational concepts.
|
||||
|
||||
### For Specific Features
|
||||
Use the appropriate category reference file (api, guides, etc.) for detailed information.
|
||||
|
||||
### For Code Examples
|
||||
The quick reference section above contains common patterns extracted from the official docs.
|
||||
|
||||
## Resources
|
||||
|
||||
### references/
|
||||
Organized documentation extracted from official sources. These files contain:
|
||||
- Detailed explanations
|
||||
- Code examples with language annotations
|
||||
- Links to original documentation
|
||||
- Table of contents for quick navigation
|
||||
|
||||
### scripts/
|
||||
Add helper scripts here for common automation tasks.
|
||||
|
||||
### assets/
|
||||
Add templates, boilerplate, or example projects here.
|
||||
|
||||
## Notes
|
||||
|
||||
- This skill was automatically generated from official documentation
|
||||
- Reference files preserve the structure and examples from source docs
|
||||
- Code examples include language detection for better syntax highlighting
|
||||
- Quick reference patterns are extracted from common usage examples in the docs
|
||||
|
||||
## Updating
|
||||
|
||||
To refresh this skill with updated documentation:
|
||||
1. Re-run the scraper with the same configuration
|
||||
2. The skill will be rebuilt with the latest information
|
||||
|
||||
|
||||
5548
skills/mlops/axolotl/references/api.md
Normal file
5548
skills/mlops/axolotl/references/api.md
Normal file
File diff suppressed because it is too large
Load Diff
1029
skills/mlops/axolotl/references/dataset-formats.md
Normal file
1029
skills/mlops/axolotl/references/dataset-formats.md
Normal file
File diff suppressed because it is too large
Load Diff
15
skills/mlops/axolotl/references/index.md
Normal file
15
skills/mlops/axolotl/references/index.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# Axolotl Documentation Index
|
||||
|
||||
## Categories
|
||||
|
||||
### Api
|
||||
**File:** `api.md`
|
||||
**Pages:** 150
|
||||
|
||||
### Dataset-Formats
|
||||
**File:** `dataset-formats.md`
|
||||
**Pages:** 9
|
||||
|
||||
### Other
|
||||
**File:** `other.md`
|
||||
**Pages:** 26
|
||||
3563
skills/mlops/axolotl/references/other.md
Normal file
3563
skills/mlops/axolotl/references/other.md
Normal file
File diff suppressed because it is too large
Load Diff
406
skills/mlops/chroma/SKILL.md
Normal file
406
skills/mlops/chroma/SKILL.md
Normal file
@@ -0,0 +1,406 @@
|
||||
---
|
||||
name: chroma
|
||||
description: Open-source embedding database for AI applications. Store embeddings and metadata, perform vector and full-text search, filter by metadata. Simple 4-function API. Scales from notebooks to production clusters. Use for semantic search, RAG applications, or document retrieval. Best for local development and open-source projects.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [RAG, Chroma, Vector Database, Embeddings, Semantic Search, Open Source, Self-Hosted, Document Retrieval, Metadata Filtering]
|
||||
dependencies: [chromadb, sentence-transformers]
|
||||
---
|
||||
|
||||
# Chroma - Open-Source Embedding Database
|
||||
|
||||
The AI-native database for building LLM applications with memory.
|
||||
|
||||
## When to use Chroma
|
||||
|
||||
**Use Chroma when:**
|
||||
- Building RAG (retrieval-augmented generation) applications
|
||||
- Need local/self-hosted vector database
|
||||
- Want open-source solution (Apache 2.0)
|
||||
- Prototyping in notebooks
|
||||
- Semantic search over documents
|
||||
- Storing embeddings with metadata
|
||||
|
||||
**Metrics**:
|
||||
- **24,300+ GitHub stars**
|
||||
- **1,900+ forks**
|
||||
- **v1.3.3** (stable, weekly releases)
|
||||
- **Apache 2.0 license**
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **Pinecone**: Managed cloud, auto-scaling
|
||||
- **FAISS**: Pure similarity search, no metadata
|
||||
- **Weaviate**: Production ML-native database
|
||||
- **Qdrant**: High performance, Rust-based
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Python
|
||||
pip install chromadb
|
||||
|
||||
# JavaScript/TypeScript
|
||||
npm install chromadb @chroma-core/default-embed
|
||||
```
|
||||
|
||||
### Basic usage (Python)
|
||||
|
||||
```python
|
||||
import chromadb
|
||||
|
||||
# Create client
|
||||
client = chromadb.Client()
|
||||
|
||||
# Create collection
|
||||
collection = client.create_collection(name="my_collection")
|
||||
|
||||
# Add documents
|
||||
collection.add(
|
||||
documents=["This is document 1", "This is document 2"],
|
||||
metadatas=[{"source": "doc1"}, {"source": "doc2"}],
|
||||
ids=["id1", "id2"]
|
||||
)
|
||||
|
||||
# Query
|
||||
results = collection.query(
|
||||
query_texts=["document about topic"],
|
||||
n_results=2
|
||||
)
|
||||
|
||||
print(results)
|
||||
```
|
||||
|
||||
## Core operations
|
||||
|
||||
### 1. Create collection
|
||||
|
||||
```python
|
||||
# Simple collection
|
||||
collection = client.create_collection("my_docs")
|
||||
|
||||
# With custom embedding function
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key="your-key",
|
||||
model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
collection = client.create_collection(
|
||||
name="my_docs",
|
||||
embedding_function=openai_ef
|
||||
)
|
||||
|
||||
# Get existing collection
|
||||
collection = client.get_collection("my_docs")
|
||||
|
||||
# Delete collection
|
||||
client.delete_collection("my_docs")
|
||||
```
|
||||
|
||||
### 2. Add documents
|
||||
|
||||
```python
|
||||
# Add with auto-generated IDs
|
||||
collection.add(
|
||||
documents=["Doc 1", "Doc 2", "Doc 3"],
|
||||
metadatas=[
|
||||
{"source": "web", "category": "tutorial"},
|
||||
{"source": "pdf", "page": 5},
|
||||
{"source": "api", "timestamp": "2025-01-01"}
|
||||
],
|
||||
ids=["id1", "id2", "id3"]
|
||||
)
|
||||
|
||||
# Add with custom embeddings
|
||||
collection.add(
|
||||
embeddings=[[0.1, 0.2, ...], [0.3, 0.4, ...]],
|
||||
documents=["Doc 1", "Doc 2"],
|
||||
ids=["id1", "id2"]
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Query (similarity search)
|
||||
|
||||
```python
|
||||
# Basic query
|
||||
results = collection.query(
|
||||
query_texts=["machine learning tutorial"],
|
||||
n_results=5
|
||||
)
|
||||
|
||||
# Query with filters
|
||||
results = collection.query(
|
||||
query_texts=["Python programming"],
|
||||
n_results=3,
|
||||
where={"source": "web"}
|
||||
)
|
||||
|
||||
# Query with metadata filters
|
||||
results = collection.query(
|
||||
query_texts=["advanced topics"],
|
||||
where={
|
||||
"$and": [
|
||||
{"category": "tutorial"},
|
||||
{"difficulty": {"$gte": 3}}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
# Access results
|
||||
print(results["documents"]) # List of matching documents
|
||||
print(results["metadatas"]) # Metadata for each doc
|
||||
print(results["distances"]) # Similarity scores
|
||||
print(results["ids"]) # Document IDs
|
||||
```
|
||||
|
||||
### 4. Get documents
|
||||
|
||||
```python
|
||||
# Get by IDs
|
||||
docs = collection.get(
|
||||
ids=["id1", "id2"]
|
||||
)
|
||||
|
||||
# Get with filters
|
||||
docs = collection.get(
|
||||
where={"category": "tutorial"},
|
||||
limit=10
|
||||
)
|
||||
|
||||
# Get all documents
|
||||
docs = collection.get()
|
||||
```
|
||||
|
||||
### 5. Update documents
|
||||
|
||||
```python
|
||||
# Update document content
|
||||
collection.update(
|
||||
ids=["id1"],
|
||||
documents=["Updated content"],
|
||||
metadatas=[{"source": "updated"}]
|
||||
)
|
||||
```
|
||||
|
||||
### 6. Delete documents
|
||||
|
||||
```python
|
||||
# Delete by IDs
|
||||
collection.delete(ids=["id1", "id2"])
|
||||
|
||||
# Delete with filter
|
||||
collection.delete(
|
||||
where={"source": "outdated"}
|
||||
)
|
||||
```
|
||||
|
||||
## Persistent storage
|
||||
|
||||
```python
|
||||
# Persist to disk
|
||||
client = chromadb.PersistentClient(path="./chroma_db")
|
||||
|
||||
collection = client.create_collection("my_docs")
|
||||
collection.add(documents=["Doc 1"], ids=["id1"])
|
||||
|
||||
# Data persisted automatically
|
||||
# Reload later with same path
|
||||
client = chromadb.PersistentClient(path="./chroma_db")
|
||||
collection = client.get_collection("my_docs")
|
||||
```
|
||||
|
||||
## Embedding functions
|
||||
|
||||
### Default (Sentence Transformers)
|
||||
|
||||
```python
|
||||
# Uses sentence-transformers by default
|
||||
collection = client.create_collection("my_docs")
|
||||
# Default model: all-MiniLM-L6-v2
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
```python
|
||||
from chromadb.utils import embedding_functions
|
||||
|
||||
openai_ef = embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key="your-key",
|
||||
model_name="text-embedding-3-small"
|
||||
)
|
||||
|
||||
collection = client.create_collection(
|
||||
name="openai_docs",
|
||||
embedding_function=openai_ef
|
||||
)
|
||||
```
|
||||
|
||||
### HuggingFace
|
||||
|
||||
```python
|
||||
huggingface_ef = embedding_functions.HuggingFaceEmbeddingFunction(
|
||||
api_key="your-key",
|
||||
model_name="sentence-transformers/all-mpnet-base-v2"
|
||||
)
|
||||
|
||||
collection = client.create_collection(
|
||||
name="hf_docs",
|
||||
embedding_function=huggingface_ef
|
||||
)
|
||||
```
|
||||
|
||||
### Custom embedding function
|
||||
|
||||
```python
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
|
||||
class MyEmbeddingFunction(EmbeddingFunction):
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
# Your embedding logic
|
||||
return embeddings
|
||||
|
||||
my_ef = MyEmbeddingFunction()
|
||||
collection = client.create_collection(
|
||||
name="custom_docs",
|
||||
embedding_function=my_ef
|
||||
)
|
||||
```
|
||||
|
||||
## Metadata filtering
|
||||
|
||||
```python
|
||||
# Exact match
|
||||
results = collection.query(
|
||||
query_texts=["query"],
|
||||
where={"category": "tutorial"}
|
||||
)
|
||||
|
||||
# Comparison operators
|
||||
results = collection.query(
|
||||
query_texts=["query"],
|
||||
where={"page": {"$gt": 10}} # $gt, $gte, $lt, $lte, $ne
|
||||
)
|
||||
|
||||
# Logical operators
|
||||
results = collection.query(
|
||||
query_texts=["query"],
|
||||
where={
|
||||
"$and": [
|
||||
{"category": "tutorial"},
|
||||
{"difficulty": {"$lte": 3}}
|
||||
]
|
||||
} # Also: $or
|
||||
)
|
||||
|
||||
# Contains
|
||||
results = collection.query(
|
||||
query_texts=["query"],
|
||||
where={"tags": {"$in": ["python", "ml"]}}
|
||||
)
|
||||
```
|
||||
|
||||
## LangChain integration
|
||||
|
||||
```python
|
||||
from langchain_chroma import Chroma
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
# Split documents
|
||||
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000)
|
||||
docs = text_splitter.split_documents(documents)
|
||||
|
||||
# Create Chroma vector store
|
||||
vectorstore = Chroma.from_documents(
|
||||
documents=docs,
|
||||
embedding=OpenAIEmbeddings(),
|
||||
persist_directory="./chroma_db"
|
||||
)
|
||||
|
||||
# Query
|
||||
results = vectorstore.similarity_search("machine learning", k=3)
|
||||
|
||||
# As retriever
|
||||
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
|
||||
```
|
||||
|
||||
## LlamaIndex integration
|
||||
|
||||
```python
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.core import VectorStoreIndex, StorageContext
|
||||
import chromadb
|
||||
|
||||
# Initialize Chroma
|
||||
db = chromadb.PersistentClient(path="./chroma_db")
|
||||
collection = db.get_or_create_collection("my_collection")
|
||||
|
||||
# Create vector store
|
||||
vector_store = ChromaVectorStore(chroma_collection=collection)
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
|
||||
# Create index
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents,
|
||||
storage_context=storage_context
|
||||
)
|
||||
|
||||
# Query
|
||||
query_engine = index.as_query_engine()
|
||||
response = query_engine.query("What is machine learning?")
|
||||
```
|
||||
|
||||
## Server mode
|
||||
|
||||
```python
|
||||
# Run Chroma server
|
||||
# Terminal: chroma run --path ./chroma_db --port 8000
|
||||
|
||||
# Connect to server
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
|
||||
client = chromadb.HttpClient(
|
||||
host="localhost",
|
||||
port=8000,
|
||||
settings=Settings(anonymized_telemetry=False)
|
||||
)
|
||||
|
||||
# Use as normal
|
||||
collection = client.get_or_create_collection("my_docs")
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use persistent client** - Don't lose data on restart
|
||||
2. **Add metadata** - Enables filtering and tracking
|
||||
3. **Batch operations** - Add multiple docs at once
|
||||
4. **Choose right embedding model** - Balance speed/quality
|
||||
5. **Use filters** - Narrow search space
|
||||
6. **Unique IDs** - Avoid collisions
|
||||
7. **Regular backups** - Copy chroma_db directory
|
||||
8. **Monitor collection size** - Scale up if needed
|
||||
9. **Test embedding functions** - Ensure quality
|
||||
10. **Use server mode for production** - Better for multi-user
|
||||
|
||||
## Performance
|
||||
|
||||
| Operation | Latency | Notes |
|
||||
|-----------|---------|-------|
|
||||
| Add 100 docs | ~1-3s | With embedding |
|
||||
| Query (top 10) | ~50-200ms | Depends on collection size |
|
||||
| Metadata filter | ~10-50ms | Fast with proper indexing |
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/chroma-core/chroma ⭐ 24,300+
|
||||
- **Docs**: https://docs.trychroma.com
|
||||
- **Discord**: https://discord.gg/MMeYNTmh3x
|
||||
- **Version**: 1.3.3+
|
||||
- **License**: Apache 2.0
|
||||
|
||||
|
||||
38
skills/mlops/chroma/references/integration.md
Normal file
38
skills/mlops/chroma/references/integration.md
Normal file
@@ -0,0 +1,38 @@
|
||||
# Chroma Integration Guide
|
||||
|
||||
Integration with LangChain, LlamaIndex, and frameworks.
|
||||
|
||||
## LangChain
|
||||
|
||||
```python
|
||||
from langchain_chroma import Chroma
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
vectorstore = Chroma.from_documents(
|
||||
documents=docs,
|
||||
embedding=OpenAIEmbeddings(),
|
||||
persist_directory="./chroma_db"
|
||||
)
|
||||
|
||||
# Query
|
||||
results = vectorstore.similarity_search("query", k=3)
|
||||
|
||||
# As retriever
|
||||
retriever = vectorstore.as_retriever()
|
||||
```
|
||||
|
||||
## LlamaIndex
|
||||
|
||||
```python
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
import chromadb
|
||||
|
||||
db = chromadb.PersistentClient(path="./chroma_db")
|
||||
collection = db.get_or_create_collection("docs")
|
||||
|
||||
vector_store = ChromaVectorStore(chroma_collection=collection)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Docs**: https://docs.trychroma.com
|
||||
253
skills/mlops/clip/SKILL.md
Normal file
253
skills/mlops/clip/SKILL.md
Normal file
@@ -0,0 +1,253 @@
|
||||
---
|
||||
name: clip
|
||||
description: OpenAI's model connecting vision and language. Enables zero-shot image classification, image-text matching, and cross-modal retrieval. Trained on 400M image-text pairs. Use for image search, content moderation, or vision-language tasks without fine-tuning. Best for general-purpose image understanding.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Multimodal, CLIP, Vision-Language, Zero-Shot, Image Classification, OpenAI, Image Search, Cross-Modal Retrieval, Content Moderation]
|
||||
dependencies: [transformers, torch, pillow]
|
||||
---
|
||||
|
||||
# CLIP - Contrastive Language-Image Pre-Training
|
||||
|
||||
OpenAI's model that understands images from natural language.
|
||||
|
||||
## When to use CLIP
|
||||
|
||||
**Use when:**
|
||||
- Zero-shot image classification (no training data needed)
|
||||
- Image-text similarity/matching
|
||||
- Semantic image search
|
||||
- Content moderation (detect NSFW, violence)
|
||||
- Visual question answering
|
||||
- Cross-modal retrieval (image→text, text→image)
|
||||
|
||||
**Metrics**:
|
||||
- **25,300+ GitHub stars**
|
||||
- Trained on 400M image-text pairs
|
||||
- Matches ResNet-50 on ImageNet (zero-shot)
|
||||
- MIT License
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **BLIP-2**: Better captioning
|
||||
- **LLaVA**: Vision-language chat
|
||||
- **Segment Anything**: Image segmentation
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/openai/CLIP.git
|
||||
pip install torch torchvision ftfy regex tqdm
|
||||
```
|
||||
|
||||
### Zero-shot classification
|
||||
|
||||
```python
|
||||
import torch
|
||||
import clip
|
||||
from PIL import Image
|
||||
|
||||
# Load model
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
model, preprocess = clip.load("ViT-B/32", device=device)
|
||||
|
||||
# Load image
|
||||
image = preprocess(Image.open("photo.jpg")).unsqueeze(0).to(device)
|
||||
|
||||
# Define possible labels
|
||||
text = clip.tokenize(["a dog", "a cat", "a bird", "a car"]).to(device)
|
||||
|
||||
# Compute similarity
|
||||
with torch.no_grad():
|
||||
image_features = model.encode_image(image)
|
||||
text_features = model.encode_text(text)
|
||||
|
||||
# Cosine similarity
|
||||
logits_per_image, logits_per_text = model(image, text)
|
||||
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
|
||||
|
||||
# Print results
|
||||
labels = ["a dog", "a cat", "a bird", "a car"]
|
||||
for label, prob in zip(labels, probs[0]):
|
||||
print(f"{label}: {prob:.2%}")
|
||||
```
|
||||
|
||||
## Available models
|
||||
|
||||
```python
|
||||
# Models (sorted by size)
|
||||
models = [
|
||||
"RN50", # ResNet-50
|
||||
"RN101", # ResNet-101
|
||||
"ViT-B/32", # Vision Transformer (recommended)
|
||||
"ViT-B/16", # Better quality, slower
|
||||
"ViT-L/14", # Best quality, slowest
|
||||
]
|
||||
|
||||
model, preprocess = clip.load("ViT-B/32")
|
||||
```
|
||||
|
||||
| Model | Parameters | Speed | Quality |
|
||||
|-------|------------|-------|---------|
|
||||
| RN50 | 102M | Fast | Good |
|
||||
| ViT-B/32 | 151M | Medium | Better |
|
||||
| ViT-L/14 | 428M | Slow | Best |
|
||||
|
||||
## Image-text similarity
|
||||
|
||||
```python
|
||||
# Compute embeddings
|
||||
image_features = model.encode_image(image)
|
||||
text_features = model.encode_text(text)
|
||||
|
||||
# Normalize
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
# Cosine similarity
|
||||
similarity = (image_features @ text_features.T).item()
|
||||
print(f"Similarity: {similarity:.4f}")
|
||||
```
|
||||
|
||||
## Semantic image search
|
||||
|
||||
```python
|
||||
# Index images
|
||||
image_paths = ["img1.jpg", "img2.jpg", "img3.jpg"]
|
||||
image_embeddings = []
|
||||
|
||||
for img_path in image_paths:
|
||||
image = preprocess(Image.open(img_path)).unsqueeze(0).to(device)
|
||||
with torch.no_grad():
|
||||
embedding = model.encode_image(image)
|
||||
embedding /= embedding.norm(dim=-1, keepdim=True)
|
||||
image_embeddings.append(embedding)
|
||||
|
||||
image_embeddings = torch.cat(image_embeddings)
|
||||
|
||||
# Search with text query
|
||||
query = "a sunset over the ocean"
|
||||
text_input = clip.tokenize([query]).to(device)
|
||||
with torch.no_grad():
|
||||
text_embedding = model.encode_text(text_input)
|
||||
text_embedding /= text_embedding.norm(dim=-1, keepdim=True)
|
||||
|
||||
# Find most similar images
|
||||
similarities = (text_embedding @ image_embeddings.T).squeeze(0)
|
||||
top_k = similarities.topk(3)
|
||||
|
||||
for idx, score in zip(top_k.indices, top_k.values):
|
||||
print(f"{image_paths[idx]}: {score:.3f}")
|
||||
```
|
||||
|
||||
## Content moderation
|
||||
|
||||
```python
|
||||
# Define categories
|
||||
categories = [
|
||||
"safe for work",
|
||||
"not safe for work",
|
||||
"violent content",
|
||||
"graphic content"
|
||||
]
|
||||
|
||||
text = clip.tokenize(categories).to(device)
|
||||
|
||||
# Check image
|
||||
with torch.no_grad():
|
||||
logits_per_image, _ = model(image, text)
|
||||
probs = logits_per_image.softmax(dim=-1)
|
||||
|
||||
# Get classification
|
||||
max_idx = probs.argmax().item()
|
||||
max_prob = probs[0, max_idx].item()
|
||||
|
||||
print(f"Category: {categories[max_idx]} ({max_prob:.2%})")
|
||||
```
|
||||
|
||||
## Batch processing
|
||||
|
||||
```python
|
||||
# Process multiple images
|
||||
images = [preprocess(Image.open(f"img{i}.jpg")) for i in range(10)]
|
||||
images = torch.stack(images).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
image_features = model.encode_image(images)
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
# Batch text
|
||||
texts = ["a dog", "a cat", "a bird"]
|
||||
text_tokens = clip.tokenize(texts).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
text_features = model.encode_text(text_tokens)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
# Similarity matrix (10 images × 3 texts)
|
||||
similarities = image_features @ text_features.T
|
||||
print(similarities.shape) # (10, 3)
|
||||
```
|
||||
|
||||
## Integration with vector databases
|
||||
|
||||
```python
|
||||
# Store CLIP embeddings in Chroma/FAISS
|
||||
import chromadb
|
||||
|
||||
client = chromadb.Client()
|
||||
collection = client.create_collection("image_embeddings")
|
||||
|
||||
# Add image embeddings
|
||||
for img_path, embedding in zip(image_paths, image_embeddings):
|
||||
collection.add(
|
||||
embeddings=[embedding.cpu().numpy().tolist()],
|
||||
metadatas=[{"path": img_path}],
|
||||
ids=[img_path]
|
||||
)
|
||||
|
||||
# Query with text
|
||||
query = "a sunset"
|
||||
text_embedding = model.encode_text(clip.tokenize([query]))
|
||||
results = collection.query(
|
||||
query_embeddings=[text_embedding.cpu().numpy().tolist()],
|
||||
n_results=5
|
||||
)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use ViT-B/32 for most cases** - Good balance
|
||||
2. **Normalize embeddings** - Required for cosine similarity
|
||||
3. **Batch processing** - More efficient
|
||||
4. **Cache embeddings** - Expensive to recompute
|
||||
5. **Use descriptive labels** - Better zero-shot performance
|
||||
6. **GPU recommended** - 10-50× faster
|
||||
7. **Preprocess images** - Use provided preprocess function
|
||||
|
||||
## Performance
|
||||
|
||||
| Operation | CPU | GPU (V100) |
|
||||
|-----------|-----|------------|
|
||||
| Image encoding | ~200ms | ~20ms |
|
||||
| Text encoding | ~50ms | ~5ms |
|
||||
| Similarity compute | <1ms | <1ms |
|
||||
|
||||
## Limitations
|
||||
|
||||
1. **Not for fine-grained tasks** - Best for broad categories
|
||||
2. **Requires descriptive text** - Vague labels perform poorly
|
||||
3. **Biased on web data** - May have dataset biases
|
||||
4. **No bounding boxes** - Whole image only
|
||||
5. **Limited spatial understanding** - Position/counting weak
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/openai/CLIP ⭐ 25,300+
|
||||
- **Paper**: https://arxiv.org/abs/2103.00020
|
||||
- **Colab**: https://colab.research.google.com/github/openai/clip/
|
||||
- **License**: MIT
|
||||
|
||||
|
||||
207
skills/mlops/clip/references/applications.md
Normal file
207
skills/mlops/clip/references/applications.md
Normal file
@@ -0,0 +1,207 @@
|
||||
# CLIP Applications Guide
|
||||
|
||||
Practical applications and use cases for CLIP.
|
||||
|
||||
## Zero-shot image classification
|
||||
|
||||
```python
|
||||
import torch
|
||||
import clip
|
||||
from PIL import Image
|
||||
|
||||
model, preprocess = clip.load("ViT-B/32")
|
||||
|
||||
# Define categories
|
||||
categories = [
|
||||
"a photo of a dog",
|
||||
"a photo of a cat",
|
||||
"a photo of a bird",
|
||||
"a photo of a car",
|
||||
"a photo of a person"
|
||||
]
|
||||
|
||||
# Prepare image
|
||||
image = preprocess(Image.open("photo.jpg")).unsqueeze(0)
|
||||
text = clip.tokenize(categories)
|
||||
|
||||
# Classify
|
||||
with torch.no_grad():
|
||||
image_features = model.encode_image(image)
|
||||
text_features = model.encode_text(text)
|
||||
|
||||
logits_per_image, _ = model(image, text)
|
||||
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
|
||||
|
||||
# Print results
|
||||
for category, prob in zip(categories, probs[0]):
|
||||
print(f"{category}: {prob:.2%}")
|
||||
```
|
||||
|
||||
## Semantic image search
|
||||
|
||||
```python
|
||||
# Index images
|
||||
image_database = []
|
||||
image_paths = ["img1.jpg", "img2.jpg", "img3.jpg"]
|
||||
|
||||
for img_path in image_paths:
|
||||
image = preprocess(Image.open(img_path)).unsqueeze(0)
|
||||
with torch.no_grad():
|
||||
features = model.encode_image(image)
|
||||
features /= features.norm(dim=-1, keepdim=True)
|
||||
image_database.append((img_path, features))
|
||||
|
||||
# Search with text
|
||||
query = "a sunset over mountains"
|
||||
text_input = clip.tokenize([query])
|
||||
|
||||
with torch.no_grad():
|
||||
text_features = model.encode_text(text_input)
|
||||
text_features /= text_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
# Find matches
|
||||
similarities = []
|
||||
for img_path, img_features in image_database:
|
||||
similarity = (text_features @ img_features.T).item()
|
||||
similarities.append((img_path, similarity))
|
||||
|
||||
# Sort by similarity
|
||||
similarities.sort(key=lambda x: x[1], reverse=True)
|
||||
for img_path, score in similarities[:3]:
|
||||
print(f"{img_path}: {score:.3f}")
|
||||
```
|
||||
|
||||
## Content moderation
|
||||
|
||||
```python
|
||||
# Define safety categories
|
||||
categories = [
|
||||
"safe for work content",
|
||||
"not safe for work content",
|
||||
"violent or graphic content",
|
||||
"hate speech or offensive content",
|
||||
"spam or misleading content"
|
||||
]
|
||||
|
||||
text = clip.tokenize(categories)
|
||||
|
||||
# Check image
|
||||
with torch.no_grad():
|
||||
logits, _ = model(image, text)
|
||||
probs = logits.softmax(dim=-1)
|
||||
|
||||
# Get classification
|
||||
max_idx = probs.argmax().item()
|
||||
confidence = probs[0, max_idx].item()
|
||||
|
||||
if confidence > 0.7:
|
||||
print(f"Classified as: {categories[max_idx]} ({confidence:.2%})")
|
||||
else:
|
||||
print(f"Uncertain classification (confidence: {confidence:.2%})")
|
||||
```
|
||||
|
||||
## Image-to-text retrieval
|
||||
|
||||
```python
|
||||
# Text database
|
||||
captions = [
|
||||
"A beautiful sunset over the ocean",
|
||||
"A cute dog playing in the park",
|
||||
"A modern city skyline at night",
|
||||
"A delicious pizza with toppings"
|
||||
]
|
||||
|
||||
# Encode captions
|
||||
caption_features = []
|
||||
for caption in captions:
|
||||
text = clip.tokenize([caption])
|
||||
with torch.no_grad():
|
||||
features = model.encode_text(text)
|
||||
features /= features.norm(dim=-1, keepdim=True)
|
||||
caption_features.append(features)
|
||||
|
||||
caption_features = torch.cat(caption_features)
|
||||
|
||||
# Find matching captions for image
|
||||
with torch.no_grad():
|
||||
image_features = model.encode_image(image)
|
||||
image_features /= image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
similarities = (image_features @ caption_features.T).squeeze(0)
|
||||
top_k = similarities.topk(3)
|
||||
|
||||
for idx, score in zip(top_k.indices, top_k.values):
|
||||
print(f"{captions[idx]}: {score:.3f}")
|
||||
```
|
||||
|
||||
## Visual question answering
|
||||
|
||||
```python
|
||||
# Create yes/no questions
|
||||
image = preprocess(Image.open("photo.jpg")).unsqueeze(0)
|
||||
|
||||
questions = [
|
||||
"a photo showing people",
|
||||
"a photo showing animals",
|
||||
"a photo taken indoors",
|
||||
"a photo taken outdoors",
|
||||
"a photo taken during daytime",
|
||||
"a photo taken at night"
|
||||
]
|
||||
|
||||
text = clip.tokenize(questions)
|
||||
|
||||
with torch.no_grad():
|
||||
logits, _ = model(image, text)
|
||||
probs = logits.softmax(dim=-1)
|
||||
|
||||
# Answer questions
|
||||
for question, prob in zip(questions, probs[0]):
|
||||
answer = "Yes" if prob > 0.5 else "No"
|
||||
print(f"{question}: {answer} ({prob:.2%})")
|
||||
```
|
||||
|
||||
## Image deduplication
|
||||
|
||||
```python
|
||||
# Detect duplicate/similar images
|
||||
def compute_similarity(img1_path, img2_path):
|
||||
img1 = preprocess(Image.open(img1_path)).unsqueeze(0)
|
||||
img2 = preprocess(Image.open(img2_path)).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
feat1 = model.encode_image(img1)
|
||||
feat2 = model.encode_image(img2)
|
||||
|
||||
feat1 /= feat1.norm(dim=-1, keepdim=True)
|
||||
feat2 /= feat2.norm(dim=-1, keepdim=True)
|
||||
|
||||
similarity = (feat1 @ feat2.T).item()
|
||||
|
||||
return similarity
|
||||
|
||||
# Check for duplicates
|
||||
threshold = 0.95
|
||||
image_pairs = [("img1.jpg", "img2.jpg"), ("img1.jpg", "img3.jpg")]
|
||||
|
||||
for img1, img2 in image_pairs:
|
||||
sim = compute_similarity(img1, img2)
|
||||
if sim > threshold:
|
||||
print(f"{img1} and {img2} are duplicates (similarity: {sim:.3f})")
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use descriptive labels** - "a photo of X" works better than just "X"
|
||||
2. **Normalize embeddings** - Always normalize for cosine similarity
|
||||
3. **Batch processing** - Process multiple images/texts together
|
||||
4. **Cache embeddings** - Expensive to recompute
|
||||
5. **Set appropriate thresholds** - Test on validation data
|
||||
6. **Use GPU** - 10-50× faster than CPU
|
||||
7. **Consider model size** - ViT-B/32 good default, ViT-L/14 for best quality
|
||||
|
||||
## Resources
|
||||
|
||||
- **Paper**: https://arxiv.org/abs/2103.00020
|
||||
- **GitHub**: https://github.com/openai/CLIP
|
||||
- **Colab**: https://colab.research.google.com/github/openai/clip/
|
||||
81
skills/mlops/code-review/SKILL.md
Normal file
81
skills/mlops/code-review/SKILL.md
Normal file
@@ -0,0 +1,81 @@
|
||||
---
|
||||
name: code-review
|
||||
description: Guidelines for performing thorough code reviews with security and quality focus
|
||||
---
|
||||
|
||||
# Code Review Skill
|
||||
|
||||
Use this skill when reviewing code changes, pull requests, or auditing existing code.
|
||||
|
||||
## Review Checklist
|
||||
|
||||
### 1. Security First
|
||||
- [ ] No hardcoded secrets, API keys, or credentials
|
||||
- [ ] Input validation on all user-provided data
|
||||
- [ ] SQL queries use parameterized statements (no string concatenation)
|
||||
- [ ] File operations validate paths (no path traversal)
|
||||
- [ ] Authentication/authorization checks present where needed
|
||||
|
||||
### 2. Error Handling
|
||||
- [ ] All external calls (API, DB, file) have try/catch
|
||||
- [ ] Errors are logged with context (but no sensitive data)
|
||||
- [ ] User-facing errors are helpful but don't leak internals
|
||||
- [ ] Resources are cleaned up in finally blocks or context managers
|
||||
|
||||
### 3. Code Quality
|
||||
- [ ] Functions do one thing and are reasonably sized (<50 lines ideal)
|
||||
- [ ] Variable names are descriptive (no single letters except loops)
|
||||
- [ ] No commented-out code left behind
|
||||
- [ ] Complex logic has explanatory comments
|
||||
- [ ] No duplicate code (DRY principle)
|
||||
|
||||
### 4. Testing Considerations
|
||||
- [ ] Edge cases handled (empty inputs, nulls, boundaries)
|
||||
- [ ] Happy path and error paths both work
|
||||
- [ ] New code has corresponding tests (if test suite exists)
|
||||
|
||||
## Review Response Format
|
||||
|
||||
When providing review feedback, structure it as:
|
||||
|
||||
```
|
||||
## Summary
|
||||
[1-2 sentence overall assessment]
|
||||
|
||||
## Critical Issues (Must Fix)
|
||||
- Issue 1: [description + suggested fix]
|
||||
- Issue 2: ...
|
||||
|
||||
## Suggestions (Nice to Have)
|
||||
- Suggestion 1: [description]
|
||||
|
||||
## Questions
|
||||
- [Any clarifying questions about intent]
|
||||
```
|
||||
|
||||
## Common Patterns to Flag
|
||||
|
||||
### Python
|
||||
```python
|
||||
# Bad: SQL injection risk
|
||||
cursor.execute(f"SELECT * FROM users WHERE id = {user_id}")
|
||||
|
||||
# Good: Parameterized query
|
||||
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
|
||||
```
|
||||
|
||||
### JavaScript
|
||||
```javascript
|
||||
// Bad: XSS risk
|
||||
element.innerHTML = userInput;
|
||||
|
||||
// Good: Safe text content
|
||||
element.textContent = userInput;
|
||||
```
|
||||
|
||||
## Tone Guidelines
|
||||
|
||||
- Be constructive, not critical
|
||||
- Explain *why* something is an issue, not just *what*
|
||||
- Offer solutions, not just problems
|
||||
- Acknowledge good patterns you see
|
||||
590
skills/mlops/dspy/SKILL.md
Normal file
590
skills/mlops/dspy/SKILL.md
Normal file
@@ -0,0 +1,590 @@
|
||||
---
|
||||
name: dspy
|
||||
description: Build complex AI systems with declarative programming, optimize prompts automatically, create modular RAG systems and agents with DSPy - Stanford NLP's framework for systematic LM programming
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [Prompt Engineering, DSPy, Declarative Programming, RAG, Agents, Prompt Optimization, LM Programming, Stanford NLP, Automatic Optimization, Modular AI]
|
||||
dependencies: [dspy, openai, anthropic]
|
||||
---
|
||||
|
||||
# DSPy: Declarative Language Model Programming
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use DSPy when you need to:
|
||||
- **Build complex AI systems** with multiple components and workflows
|
||||
- **Program LMs declaratively** instead of manual prompt engineering
|
||||
- **Optimize prompts automatically** using data-driven methods
|
||||
- **Create modular AI pipelines** that are maintainable and portable
|
||||
- **Improve model outputs systematically** with optimizers
|
||||
- **Build RAG systems, agents, or classifiers** with better reliability
|
||||
|
||||
**GitHub Stars**: 22,000+ | **Created By**: Stanford NLP
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Stable release
|
||||
pip install dspy
|
||||
|
||||
# Latest development version
|
||||
pip install git+https://github.com/stanfordnlp/dspy.git
|
||||
|
||||
# With specific LM providers
|
||||
pip install dspy[openai] # OpenAI
|
||||
pip install dspy[anthropic] # Anthropic Claude
|
||||
pip install dspy[all] # All providers
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Example: Question Answering
|
||||
|
||||
```python
|
||||
import dspy
|
||||
|
||||
# Configure your language model
|
||||
lm = dspy.Claude(model="claude-sonnet-4-5-20250929")
|
||||
dspy.settings.configure(lm=lm)
|
||||
|
||||
# Define a signature (input → output)
|
||||
class QA(dspy.Signature):
|
||||
"""Answer questions with short factual answers."""
|
||||
question = dspy.InputField()
|
||||
answer = dspy.OutputField(desc="often between 1 and 5 words")
|
||||
|
||||
# Create a module
|
||||
qa = dspy.Predict(QA)
|
||||
|
||||
# Use it
|
||||
response = qa(question="What is the capital of France?")
|
||||
print(response.answer) # "Paris"
|
||||
```
|
||||
|
||||
### Chain of Thought Reasoning
|
||||
|
||||
```python
|
||||
import dspy
|
||||
|
||||
lm = dspy.Claude(model="claude-sonnet-4-5-20250929")
|
||||
dspy.settings.configure(lm=lm)
|
||||
|
||||
# Use ChainOfThought for better reasoning
|
||||
class MathProblem(dspy.Signature):
|
||||
"""Solve math word problems."""
|
||||
problem = dspy.InputField()
|
||||
answer = dspy.OutputField(desc="numerical answer")
|
||||
|
||||
# ChainOfThought generates reasoning steps automatically
|
||||
cot = dspy.ChainOfThought(MathProblem)
|
||||
|
||||
response = cot(problem="If John has 5 apples and gives 2 to Mary, how many does he have?")
|
||||
print(response.rationale) # Shows reasoning steps
|
||||
print(response.answer) # "3"
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. Signatures
|
||||
|
||||
Signatures define the structure of your AI task (inputs → outputs):
|
||||
|
||||
```python
|
||||
# Inline signature (simple)
|
||||
qa = dspy.Predict("question -> answer")
|
||||
|
||||
# Class signature (detailed)
|
||||
class Summarize(dspy.Signature):
|
||||
"""Summarize text into key points."""
|
||||
text = dspy.InputField()
|
||||
summary = dspy.OutputField(desc="bullet points, 3-5 items")
|
||||
|
||||
summarizer = dspy.ChainOfThought(Summarize)
|
||||
```
|
||||
|
||||
**When to use each:**
|
||||
- **Inline**: Quick prototyping, simple tasks
|
||||
- **Class**: Complex tasks, type hints, better documentation
|
||||
|
||||
### 2. Modules
|
||||
|
||||
Modules are reusable components that transform inputs to outputs:
|
||||
|
||||
#### dspy.Predict
|
||||
Basic prediction module:
|
||||
|
||||
```python
|
||||
predictor = dspy.Predict("context, question -> answer")
|
||||
result = predictor(context="Paris is the capital of France",
|
||||
question="What is the capital?")
|
||||
```
|
||||
|
||||
#### dspy.ChainOfThought
|
||||
Generates reasoning steps before answering:
|
||||
|
||||
```python
|
||||
cot = dspy.ChainOfThought("question -> answer")
|
||||
result = cot(question="Why is the sky blue?")
|
||||
print(result.rationale) # Reasoning steps
|
||||
print(result.answer) # Final answer
|
||||
```
|
||||
|
||||
#### dspy.ReAct
|
||||
Agent-like reasoning with tools:
|
||||
|
||||
```python
|
||||
from dspy.predict import ReAct
|
||||
|
||||
class SearchQA(dspy.Signature):
|
||||
"""Answer questions using search."""
|
||||
question = dspy.InputField()
|
||||
answer = dspy.OutputField()
|
||||
|
||||
def search_tool(query: str) -> str:
|
||||
"""Search Wikipedia."""
|
||||
# Your search implementation
|
||||
return results
|
||||
|
||||
react = ReAct(SearchQA, tools=[search_tool])
|
||||
result = react(question="When was Python created?")
|
||||
```
|
||||
|
||||
#### dspy.ProgramOfThought
|
||||
Generates and executes code for reasoning:
|
||||
|
||||
```python
|
||||
pot = dspy.ProgramOfThought("question -> answer")
|
||||
result = pot(question="What is 15% of 240?")
|
||||
# Generates: answer = 240 * 0.15
|
||||
```
|
||||
|
||||
### 3. Optimizers
|
||||
|
||||
Optimizers improve your modules automatically using training data:
|
||||
|
||||
#### BootstrapFewShot
|
||||
Learns from examples:
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import BootstrapFewShot
|
||||
|
||||
# Training data
|
||||
trainset = [
|
||||
dspy.Example(question="What is 2+2?", answer="4").with_inputs("question"),
|
||||
dspy.Example(question="What is 3+5?", answer="8").with_inputs("question"),
|
||||
]
|
||||
|
||||
# Define metric
|
||||
def validate_answer(example, pred, trace=None):
|
||||
return example.answer == pred.answer
|
||||
|
||||
# Optimize
|
||||
optimizer = BootstrapFewShot(metric=validate_answer, max_bootstrapped_demos=3)
|
||||
optimized_qa = optimizer.compile(qa, trainset=trainset)
|
||||
|
||||
# Now optimized_qa performs better!
|
||||
```
|
||||
|
||||
#### MIPRO (Most Important Prompt Optimization)
|
||||
Iteratively improves prompts:
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import MIPRO
|
||||
|
||||
optimizer = MIPRO(
|
||||
metric=validate_answer,
|
||||
num_candidates=10,
|
||||
init_temperature=1.0
|
||||
)
|
||||
|
||||
optimized_cot = optimizer.compile(
|
||||
cot,
|
||||
trainset=trainset,
|
||||
num_trials=100
|
||||
)
|
||||
```
|
||||
|
||||
#### BootstrapFinetune
|
||||
Creates datasets for model fine-tuning:
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import BootstrapFinetune
|
||||
|
||||
optimizer = BootstrapFinetune(metric=validate_answer)
|
||||
optimized_module = optimizer.compile(qa, trainset=trainset)
|
||||
|
||||
# Exports training data for fine-tuning
|
||||
```
|
||||
|
||||
### 4. Building Complex Systems
|
||||
|
||||
#### Multi-Stage Pipeline
|
||||
|
||||
```python
|
||||
import dspy
|
||||
|
||||
class MultiHopQA(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=3)
|
||||
self.generate_query = dspy.ChainOfThought("question -> search_query")
|
||||
self.generate_answer = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
# Stage 1: Generate search query
|
||||
search_query = self.generate_query(question=question).search_query
|
||||
|
||||
# Stage 2: Retrieve context
|
||||
passages = self.retrieve(search_query).passages
|
||||
context = "\n".join(passages)
|
||||
|
||||
# Stage 3: Generate answer
|
||||
answer = self.generate_answer(context=context, question=question).answer
|
||||
return dspy.Prediction(answer=answer, context=context)
|
||||
|
||||
# Use the pipeline
|
||||
qa_system = MultiHopQA()
|
||||
result = qa_system(question="Who wrote the book that inspired the movie Blade Runner?")
|
||||
```
|
||||
|
||||
#### RAG System with Optimization
|
||||
|
||||
```python
|
||||
import dspy
|
||||
from dspy.retrieve.chromadb_rm import ChromadbRM
|
||||
|
||||
# Configure retriever
|
||||
retriever = ChromadbRM(
|
||||
collection_name="documents",
|
||||
persist_directory="./chroma_db"
|
||||
)
|
||||
|
||||
class RAG(dspy.Module):
|
||||
def __init__(self, num_passages=3):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=num_passages)
|
||||
self.generate = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
context = self.retrieve(question).passages
|
||||
return self.generate(context=context, question=question)
|
||||
|
||||
# Create and optimize
|
||||
rag = RAG()
|
||||
|
||||
# Optimize with training data
|
||||
from dspy.teleprompt import BootstrapFewShot
|
||||
|
||||
optimizer = BootstrapFewShot(metric=validate_answer)
|
||||
optimized_rag = optimizer.compile(rag, trainset=trainset)
|
||||
```
|
||||
|
||||
## LM Provider Configuration
|
||||
|
||||
### Anthropic Claude
|
||||
|
||||
```python
|
||||
import dspy
|
||||
|
||||
lm = dspy.Claude(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
api_key="your-api-key", # Or set ANTHROPIC_API_KEY env var
|
||||
max_tokens=1000,
|
||||
temperature=0.7
|
||||
)
|
||||
dspy.settings.configure(lm=lm)
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
```python
|
||||
lm = dspy.OpenAI(
|
||||
model="gpt-4",
|
||||
api_key="your-api-key",
|
||||
max_tokens=1000
|
||||
)
|
||||
dspy.settings.configure(lm=lm)
|
||||
```
|
||||
|
||||
### Local Models (Ollama)
|
||||
|
||||
```python
|
||||
lm = dspy.OllamaLocal(
|
||||
model="llama3.1",
|
||||
base_url="http://localhost:11434"
|
||||
)
|
||||
dspy.settings.configure(lm=lm)
|
||||
```
|
||||
|
||||
### Multiple Models
|
||||
|
||||
```python
|
||||
# Different models for different tasks
|
||||
cheap_lm = dspy.OpenAI(model="gpt-3.5-turbo")
|
||||
strong_lm = dspy.Claude(model="claude-sonnet-4-5-20250929")
|
||||
|
||||
# Use cheap model for retrieval, strong model for reasoning
|
||||
with dspy.settings.context(lm=cheap_lm):
|
||||
context = retriever(question)
|
||||
|
||||
with dspy.settings.context(lm=strong_lm):
|
||||
answer = generator(context=context, question=question)
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Pattern 1: Structured Output
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class PersonInfo(BaseModel):
|
||||
name: str = Field(description="Full name")
|
||||
age: int = Field(description="Age in years")
|
||||
occupation: str = Field(description="Current job")
|
||||
|
||||
class ExtractPerson(dspy.Signature):
|
||||
"""Extract person information from text."""
|
||||
text = dspy.InputField()
|
||||
person: PersonInfo = dspy.OutputField()
|
||||
|
||||
extractor = dspy.TypedPredictor(ExtractPerson)
|
||||
result = extractor(text="John Doe is a 35-year-old software engineer.")
|
||||
print(result.person.name) # "John Doe"
|
||||
print(result.person.age) # 35
|
||||
```
|
||||
|
||||
### Pattern 2: Assertion-Driven Optimization
|
||||
|
||||
```python
|
||||
import dspy
|
||||
from dspy.primitives.assertions import assert_transform_module, backtrack_handler
|
||||
|
||||
class MathQA(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.solve = dspy.ChainOfThought("problem -> solution: float")
|
||||
|
||||
def forward(self, problem):
|
||||
solution = self.solve(problem=problem).solution
|
||||
|
||||
# Assert solution is numeric
|
||||
dspy.Assert(
|
||||
isinstance(float(solution), float),
|
||||
"Solution must be a number",
|
||||
backtrack=backtrack_handler
|
||||
)
|
||||
|
||||
return dspy.Prediction(solution=solution)
|
||||
```
|
||||
|
||||
### Pattern 3: Self-Consistency
|
||||
|
||||
```python
|
||||
import dspy
|
||||
from collections import Counter
|
||||
|
||||
class ConsistentQA(dspy.Module):
|
||||
def __init__(self, num_samples=5):
|
||||
super().__init__()
|
||||
self.qa = dspy.ChainOfThought("question -> answer")
|
||||
self.num_samples = num_samples
|
||||
|
||||
def forward(self, question):
|
||||
# Generate multiple answers
|
||||
answers = []
|
||||
for _ in range(self.num_samples):
|
||||
result = self.qa(question=question)
|
||||
answers.append(result.answer)
|
||||
|
||||
# Return most common answer
|
||||
most_common = Counter(answers).most_common(1)[0][0]
|
||||
return dspy.Prediction(answer=most_common)
|
||||
```
|
||||
|
||||
### Pattern 4: Retrieval with Reranking
|
||||
|
||||
```python
|
||||
class RerankedRAG(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=10)
|
||||
self.rerank = dspy.Predict("question, passage -> relevance_score: float")
|
||||
self.answer = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
# Retrieve candidates
|
||||
passages = self.retrieve(question).passages
|
||||
|
||||
# Rerank passages
|
||||
scored = []
|
||||
for passage in passages:
|
||||
score = float(self.rerank(question=question, passage=passage).relevance_score)
|
||||
scored.append((score, passage))
|
||||
|
||||
# Take top 3
|
||||
top_passages = [p for _, p in sorted(scored, reverse=True)[:3]]
|
||||
context = "\n\n".join(top_passages)
|
||||
|
||||
# Generate answer
|
||||
return self.answer(context=context, question=question)
|
||||
```
|
||||
|
||||
## Evaluation and Metrics
|
||||
|
||||
### Custom Metrics
|
||||
|
||||
```python
|
||||
def exact_match(example, pred, trace=None):
|
||||
"""Exact match metric."""
|
||||
return example.answer.lower() == pred.answer.lower()
|
||||
|
||||
def f1_score(example, pred, trace=None):
|
||||
"""F1 score for text overlap."""
|
||||
pred_tokens = set(pred.answer.lower().split())
|
||||
gold_tokens = set(example.answer.lower().split())
|
||||
|
||||
if not pred_tokens:
|
||||
return 0.0
|
||||
|
||||
precision = len(pred_tokens & gold_tokens) / len(pred_tokens)
|
||||
recall = len(pred_tokens & gold_tokens) / len(gold_tokens)
|
||||
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
|
||||
return 2 * (precision * recall) / (precision + recall)
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
|
||||
```python
|
||||
from dspy.evaluate import Evaluate
|
||||
|
||||
# Create evaluator
|
||||
evaluator = Evaluate(
|
||||
devset=testset,
|
||||
metric=exact_match,
|
||||
num_threads=4,
|
||||
display_progress=True
|
||||
)
|
||||
|
||||
# Evaluate model
|
||||
score = evaluator(qa_system)
|
||||
print(f"Accuracy: {score}")
|
||||
|
||||
# Compare optimized vs unoptimized
|
||||
score_before = evaluator(qa)
|
||||
score_after = evaluator(optimized_qa)
|
||||
print(f"Improvement: {score_after - score_before:.2%}")
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Start Simple, Iterate
|
||||
|
||||
```python
|
||||
# Start with Predict
|
||||
qa = dspy.Predict("question -> answer")
|
||||
|
||||
# Add reasoning if needed
|
||||
qa = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
# Add optimization when you have data
|
||||
optimized_qa = optimizer.compile(qa, trainset=data)
|
||||
```
|
||||
|
||||
### 2. Use Descriptive Signatures
|
||||
|
||||
```python
|
||||
# ❌ Bad: Vague
|
||||
class Task(dspy.Signature):
|
||||
input = dspy.InputField()
|
||||
output = dspy.OutputField()
|
||||
|
||||
# ✅ Good: Descriptive
|
||||
class SummarizeArticle(dspy.Signature):
|
||||
"""Summarize news articles into 3-5 key points."""
|
||||
article = dspy.InputField(desc="full article text")
|
||||
summary = dspy.OutputField(desc="bullet points, 3-5 items")
|
||||
```
|
||||
|
||||
### 3. Optimize with Representative Data
|
||||
|
||||
```python
|
||||
# Create diverse training examples
|
||||
trainset = [
|
||||
dspy.Example(question="factual", answer="...).with_inputs("question"),
|
||||
dspy.Example(question="reasoning", answer="...").with_inputs("question"),
|
||||
dspy.Example(question="calculation", answer="...").with_inputs("question"),
|
||||
]
|
||||
|
||||
# Use validation set for metric
|
||||
def metric(example, pred, trace=None):
|
||||
return example.answer in pred.answer
|
||||
```
|
||||
|
||||
### 4. Save and Load Optimized Models
|
||||
|
||||
```python
|
||||
# Save
|
||||
optimized_qa.save("models/qa_v1.json")
|
||||
|
||||
# Load
|
||||
loaded_qa = dspy.ChainOfThought("question -> answer")
|
||||
loaded_qa.load("models/qa_v1.json")
|
||||
```
|
||||
|
||||
### 5. Monitor and Debug
|
||||
|
||||
```python
|
||||
# Enable tracing
|
||||
dspy.settings.configure(lm=lm, trace=[])
|
||||
|
||||
# Run prediction
|
||||
result = qa(question="...")
|
||||
|
||||
# Inspect trace
|
||||
for call in dspy.settings.trace:
|
||||
print(f"Prompt: {call['prompt']}")
|
||||
print(f"Response: {call['response']}")
|
||||
```
|
||||
|
||||
## Comparison to Other Approaches
|
||||
|
||||
| Feature | Manual Prompting | LangChain | DSPy |
|
||||
|---------|-----------------|-----------|------|
|
||||
| Prompt Engineering | Manual | Manual | Automatic |
|
||||
| Optimization | Trial & error | None | Data-driven |
|
||||
| Modularity | Low | Medium | High |
|
||||
| Type Safety | No | Limited | Yes (Signatures) |
|
||||
| Portability | Low | Medium | High |
|
||||
| Learning Curve | Low | Medium | Medium-High |
|
||||
|
||||
**When to choose DSPy:**
|
||||
- You have training data or can generate it
|
||||
- You need systematic prompt improvement
|
||||
- You're building complex multi-stage systems
|
||||
- You want to optimize across different LMs
|
||||
|
||||
**When to choose alternatives:**
|
||||
- Quick prototypes (manual prompting)
|
||||
- Simple chains with existing tools (LangChain)
|
||||
- Custom optimization logic needed
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://dspy.ai
|
||||
- **GitHub**: https://github.com/stanfordnlp/dspy (22k+ stars)
|
||||
- **Discord**: https://discord.gg/XCGy2WDCQB
|
||||
- **Twitter**: @DSPyOSS
|
||||
- **Paper**: "DSPy: Compiling Declarative Language Model Calls into Self-Improving Pipelines"
|
||||
|
||||
## See Also
|
||||
|
||||
- `references/modules.md` - Detailed module guide (Predict, ChainOfThought, ReAct, ProgramOfThought)
|
||||
- `references/optimizers.md` - Optimization algorithms (BootstrapFewShot, MIPRO, BootstrapFinetune)
|
||||
- `references/examples.md` - Real-world examples (RAG, agents, classifiers)
|
||||
|
||||
|
||||
663
skills/mlops/dspy/references/examples.md
Normal file
663
skills/mlops/dspy/references/examples.md
Normal file
@@ -0,0 +1,663 @@
|
||||
# DSPy Real-World Examples
|
||||
|
||||
Practical examples of building production systems with DSPy.
|
||||
|
||||
## Table of Contents
|
||||
- RAG Systems
|
||||
- Agent Systems
|
||||
- Classification
|
||||
- Data Processing
|
||||
- Multi-Stage Pipelines
|
||||
|
||||
## RAG Systems
|
||||
|
||||
### Basic RAG
|
||||
|
||||
```python
|
||||
import dspy
|
||||
|
||||
class BasicRAG(dspy.Module):
|
||||
def __init__(self, num_passages=3):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=num_passages)
|
||||
self.generate = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
passages = self.retrieve(question).passages
|
||||
context = "\n\n".join(passages)
|
||||
return self.generate(context=context, question=question)
|
||||
|
||||
# Configure retriever (example with Chroma)
|
||||
from dspy.retrieve.chromadb_rm import ChromadbRM
|
||||
|
||||
retriever = ChromadbRM(
|
||||
collection_name="my_docs",
|
||||
persist_directory="./chroma_db",
|
||||
k=3
|
||||
)
|
||||
dspy.settings.configure(rm=retriever)
|
||||
|
||||
# Use RAG
|
||||
rag = BasicRAG()
|
||||
result = rag(question="What is DSPy?")
|
||||
print(result.answer)
|
||||
```
|
||||
|
||||
### Optimized RAG
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import BootstrapFewShot
|
||||
|
||||
# Training data with question-answer pairs
|
||||
trainset = [
|
||||
dspy.Example(
|
||||
question="What is retrieval augmented generation?",
|
||||
answer="RAG combines retrieval of relevant documents with generation..."
|
||||
).with_inputs("question"),
|
||||
# ... more examples
|
||||
]
|
||||
|
||||
# Define metric
|
||||
def answer_correctness(example, pred, trace=None):
|
||||
# Check if answer contains key information
|
||||
return example.answer.lower() in pred.answer.lower()
|
||||
|
||||
# Optimize RAG
|
||||
optimizer = BootstrapFewShot(metric=answer_correctness)
|
||||
optimized_rag = optimizer.compile(rag, trainset=trainset)
|
||||
|
||||
# Optimized RAG performs better on similar questions
|
||||
result = optimized_rag(question="Explain RAG systems")
|
||||
```
|
||||
|
||||
### Multi-Hop RAG
|
||||
|
||||
```python
|
||||
class MultiHopRAG(dspy.Module):
|
||||
"""RAG that follows chains of reasoning across documents."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=3)
|
||||
self.generate_query = dspy.ChainOfThought("question -> search_query")
|
||||
self.generate_answer = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
# First retrieval
|
||||
query1 = self.generate_query(question=question).search_query
|
||||
passages1 = self.retrieve(query1).passages
|
||||
|
||||
# Generate follow-up query based on first results
|
||||
context1 = "\n".join(passages1)
|
||||
query2 = self.generate_query(
|
||||
question=f"Based on: {context1}\nFollow-up: {question}"
|
||||
).search_query
|
||||
|
||||
# Second retrieval
|
||||
passages2 = self.retrieve(query2).passages
|
||||
|
||||
# Combine all context
|
||||
all_context = "\n\n".join(passages1 + passages2)
|
||||
|
||||
# Generate final answer
|
||||
return self.generate_answer(context=all_context, question=question)
|
||||
|
||||
# Use multi-hop RAG
|
||||
multi_rag = MultiHopRAG()
|
||||
result = multi_rag(question="Who wrote the book that inspired Blade Runner?")
|
||||
# Hop 1: Find "Blade Runner was based on..."
|
||||
# Hop 2: Find author of that book
|
||||
```
|
||||
|
||||
### RAG with Reranking
|
||||
|
||||
```python
|
||||
class RerankedRAG(dspy.Module):
|
||||
"""RAG with learned reranking of retrieved passages."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=10) # Get more candidates
|
||||
self.rerank = dspy.Predict("question, passage -> relevance_score: float")
|
||||
self.answer = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
# Retrieve candidates
|
||||
passages = self.retrieve(question).passages
|
||||
|
||||
# Rerank passages
|
||||
scored_passages = []
|
||||
for passage in passages:
|
||||
score = float(self.rerank(
|
||||
question=question,
|
||||
passage=passage
|
||||
).relevance_score)
|
||||
scored_passages.append((score, passage))
|
||||
|
||||
# Take top 3 after reranking
|
||||
top_passages = [p for _, p in sorted(scored_passages, reverse=True)[:3]]
|
||||
context = "\n\n".join(top_passages)
|
||||
|
||||
# Generate answer from reranked context
|
||||
return self.answer(context=context, question=question)
|
||||
```
|
||||
|
||||
## Agent Systems
|
||||
|
||||
### ReAct Agent
|
||||
|
||||
```python
|
||||
from dspy.predict import ReAct
|
||||
|
||||
# Define tools
|
||||
def search_wikipedia(query: str) -> str:
|
||||
"""Search Wikipedia for information."""
|
||||
import wikipedia
|
||||
try:
|
||||
return wikipedia.summary(query, sentences=3)
|
||||
except:
|
||||
return "No results found"
|
||||
|
||||
def calculate(expression: str) -> str:
|
||||
"""Evaluate mathematical expression safely."""
|
||||
try:
|
||||
# Use safe eval
|
||||
result = eval(expression, {"__builtins__": {}}, {})
|
||||
return str(result)
|
||||
except:
|
||||
return "Invalid expression"
|
||||
|
||||
def search_web(query: str) -> str:
|
||||
"""Search the web."""
|
||||
# Your web search implementation
|
||||
return results
|
||||
|
||||
# Create agent signature
|
||||
class ResearchAgent(dspy.Signature):
|
||||
"""Answer questions using available tools."""
|
||||
question = dspy.InputField()
|
||||
answer = dspy.OutputField()
|
||||
|
||||
# Create ReAct agent
|
||||
agent = ReAct(ResearchAgent, tools=[search_wikipedia, calculate, search_web])
|
||||
|
||||
# Agent decides which tools to use
|
||||
result = agent(question="What is the population of France divided by 10?")
|
||||
# Agent:
|
||||
# 1. Thinks: "Need population of France"
|
||||
# 2. Acts: search_wikipedia("France population")
|
||||
# 3. Thinks: "Got 67 million, need to divide"
|
||||
# 4. Acts: calculate("67000000 / 10")
|
||||
# 5. Returns: "6,700,000"
|
||||
```
|
||||
|
||||
### Multi-Agent System
|
||||
|
||||
```python
|
||||
class MultiAgentSystem(dspy.Module):
|
||||
"""System with specialized agents for different tasks."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Router agent
|
||||
self.router = dspy.Predict("question -> agent_type: str")
|
||||
|
||||
# Specialized agents
|
||||
self.research_agent = ReAct(
|
||||
ResearchAgent,
|
||||
tools=[search_wikipedia, search_web]
|
||||
)
|
||||
self.math_agent = dspy.ProgramOfThought("problem -> answer")
|
||||
self.reasoning_agent = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
# Route to appropriate agent
|
||||
agent_type = self.router(question=question).agent_type
|
||||
|
||||
if agent_type == "research":
|
||||
return self.research_agent(question=question)
|
||||
elif agent_type == "math":
|
||||
return self.math_agent(problem=question)
|
||||
else:
|
||||
return self.reasoning_agent(question=question)
|
||||
|
||||
# Use multi-agent system
|
||||
mas = MultiAgentSystem()
|
||||
result = mas(question="What is 15% of the GDP of France?")
|
||||
# Routes to research_agent for GDP, then to math_agent for calculation
|
||||
```
|
||||
|
||||
## Classification
|
||||
|
||||
### Binary Classifier
|
||||
|
||||
```python
|
||||
class SentimentClassifier(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.classify = dspy.Predict("text -> sentiment: str")
|
||||
|
||||
def forward(self, text):
|
||||
return self.classify(text=text)
|
||||
|
||||
# Training data
|
||||
trainset = [
|
||||
dspy.Example(text="I love this!", sentiment="positive").with_inputs("text"),
|
||||
dspy.Example(text="Terrible experience", sentiment="negative").with_inputs("text"),
|
||||
# ... more examples
|
||||
]
|
||||
|
||||
# Optimize
|
||||
def accuracy(example, pred, trace=None):
|
||||
return example.sentiment == pred.sentiment
|
||||
|
||||
optimizer = BootstrapFewShot(metric=accuracy, max_bootstrapped_demos=5)
|
||||
classifier = SentimentClassifier()
|
||||
optimized_classifier = optimizer.compile(classifier, trainset=trainset)
|
||||
|
||||
# Use classifier
|
||||
result = optimized_classifier(text="This product is amazing!")
|
||||
print(result.sentiment) # "positive"
|
||||
```
|
||||
|
||||
### Multi-Class Classifier
|
||||
|
||||
```python
|
||||
class TopicClassifier(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.classify = dspy.ChainOfThought(
|
||||
"text -> category: str, confidence: float"
|
||||
)
|
||||
|
||||
def forward(self, text):
|
||||
result = self.classify(text=text)
|
||||
return dspy.Prediction(
|
||||
category=result.category,
|
||||
confidence=float(result.confidence)
|
||||
)
|
||||
|
||||
# Define categories in signature
|
||||
class TopicSignature(dspy.Signature):
|
||||
"""Classify text into one of: technology, sports, politics, entertainment."""
|
||||
text = dspy.InputField()
|
||||
category = dspy.OutputField(desc="one of: technology, sports, politics, entertainment")
|
||||
confidence = dspy.OutputField(desc="0.0 to 1.0")
|
||||
|
||||
classifier = dspy.ChainOfThought(TopicSignature)
|
||||
result = classifier(text="The Lakers won the championship")
|
||||
print(result.category) # "sports"
|
||||
print(result.confidence) # 0.95
|
||||
```
|
||||
|
||||
### Hierarchical Classifier
|
||||
|
||||
```python
|
||||
class HierarchicalClassifier(dspy.Module):
|
||||
"""Two-stage classification: coarse then fine-grained."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.coarse = dspy.Predict("text -> broad_category: str")
|
||||
self.fine_tech = dspy.Predict("text -> tech_subcategory: str")
|
||||
self.fine_sports = dspy.Predict("text -> sports_subcategory: str")
|
||||
|
||||
def forward(self, text):
|
||||
# Stage 1: Broad category
|
||||
broad = self.coarse(text=text).broad_category
|
||||
|
||||
# Stage 2: Fine-grained based on broad
|
||||
if broad == "technology":
|
||||
fine = self.fine_tech(text=text).tech_subcategory
|
||||
elif broad == "sports":
|
||||
fine = self.fine_sports(text=text).sports_subcategory
|
||||
else:
|
||||
fine = "other"
|
||||
|
||||
return dspy.Prediction(broad_category=broad, fine_category=fine)
|
||||
```
|
||||
|
||||
## Data Processing
|
||||
|
||||
### Text Summarization
|
||||
|
||||
```python
|
||||
class AdaptiveSummarizer(dspy.Module):
|
||||
"""Summarizes text to target length."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.summarize = dspy.ChainOfThought("text, target_length -> summary")
|
||||
|
||||
def forward(self, text, target_length="3 sentences"):
|
||||
return self.summarize(text=text, target_length=target_length)
|
||||
|
||||
# Use summarizer
|
||||
summarizer = AdaptiveSummarizer()
|
||||
long_text = "..." # Long article
|
||||
|
||||
short_summary = summarizer(long_text, target_length="1 sentence")
|
||||
medium_summary = summarizer(long_text, target_length="3 sentences")
|
||||
detailed_summary = summarizer(long_text, target_length="1 paragraph")
|
||||
```
|
||||
|
||||
### Information Extraction
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class PersonInfo(BaseModel):
|
||||
name: str = Field(description="Full name")
|
||||
age: int = Field(description="Age in years")
|
||||
occupation: str = Field(description="Job title")
|
||||
location: str = Field(description="City and country")
|
||||
|
||||
class ExtractPerson(dspy.Signature):
|
||||
"""Extract person information from text."""
|
||||
text = dspy.InputField()
|
||||
person: PersonInfo = dspy.OutputField()
|
||||
|
||||
extractor = dspy.TypedPredictor(ExtractPerson)
|
||||
|
||||
text = "Dr. Jane Smith, 42, is a neuroscientist at Stanford University in Palo Alto, California."
|
||||
result = extractor(text=text)
|
||||
|
||||
print(result.person.name) # "Dr. Jane Smith"
|
||||
print(result.person.age) # 42
|
||||
print(result.person.occupation) # "neuroscientist"
|
||||
print(result.person.location) # "Palo Alto, California"
|
||||
```
|
||||
|
||||
### Batch Processing
|
||||
|
||||
```python
|
||||
class BatchProcessor(dspy.Module):
|
||||
"""Process large datasets efficiently."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.process = dspy.Predict("text -> processed_text")
|
||||
|
||||
def forward(self, texts):
|
||||
# Batch processing for efficiency
|
||||
return self.process.batch([{"text": t} for t in texts])
|
||||
|
||||
# Process 1000 documents
|
||||
processor = BatchProcessor()
|
||||
results = processor(texts=large_dataset)
|
||||
|
||||
# Results are returned in order
|
||||
for original, result in zip(large_dataset, results):
|
||||
print(f"{original} -> {result.processed_text}")
|
||||
```
|
||||
|
||||
## Multi-Stage Pipelines
|
||||
|
||||
### Document Processing Pipeline
|
||||
|
||||
```python
|
||||
class DocumentPipeline(dspy.Module):
|
||||
"""Multi-stage document processing."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.extract = dspy.Predict("document -> key_points")
|
||||
self.classify = dspy.Predict("key_points -> category")
|
||||
self.summarize = dspy.ChainOfThought("key_points, category -> summary")
|
||||
self.tag = dspy.Predict("summary -> tags")
|
||||
|
||||
def forward(self, document):
|
||||
# Stage 1: Extract key points
|
||||
key_points = self.extract(document=document).key_points
|
||||
|
||||
# Stage 2: Classify
|
||||
category = self.classify(key_points=key_points).category
|
||||
|
||||
# Stage 3: Summarize
|
||||
summary = self.summarize(
|
||||
key_points=key_points,
|
||||
category=category
|
||||
).summary
|
||||
|
||||
# Stage 4: Generate tags
|
||||
tags = self.tag(summary=summary).tags
|
||||
|
||||
return dspy.Prediction(
|
||||
key_points=key_points,
|
||||
category=category,
|
||||
summary=summary,
|
||||
tags=tags
|
||||
)
|
||||
```
|
||||
|
||||
### Quality Control Pipeline
|
||||
|
||||
```python
|
||||
class QualityControlPipeline(dspy.Module):
|
||||
"""Generate output and verify quality."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.generate = dspy.ChainOfThought("prompt -> output")
|
||||
self.verify = dspy.Predict("output -> is_valid: bool, issues: str")
|
||||
self.improve = dspy.ChainOfThought("output, issues -> improved_output")
|
||||
|
||||
def forward(self, prompt, max_iterations=3):
|
||||
output = self.generate(prompt=prompt).output
|
||||
|
||||
for _ in range(max_iterations):
|
||||
# Verify output
|
||||
verification = self.verify(output=output)
|
||||
|
||||
if verification.is_valid:
|
||||
return dspy.Prediction(output=output, iterations=_ + 1)
|
||||
|
||||
# Improve based on issues
|
||||
output = self.improve(
|
||||
output=output,
|
||||
issues=verification.issues
|
||||
).improved_output
|
||||
|
||||
return dspy.Prediction(output=output, iterations=max_iterations)
|
||||
```
|
||||
|
||||
## Production Tips
|
||||
|
||||
### 1. Caching for Performance
|
||||
|
||||
```python
|
||||
from functools import lru_cache
|
||||
|
||||
class CachedRAG(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=3)
|
||||
self.generate = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
@lru_cache(maxsize=1000)
|
||||
def forward(self, question):
|
||||
passages = self.retrieve(question).passages
|
||||
context = "\n".join(passages)
|
||||
return self.generate(context=context, question=question).answer
|
||||
```
|
||||
|
||||
### 2. Error Handling
|
||||
|
||||
```python
|
||||
class RobustModule(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.process = dspy.ChainOfThought("input -> output")
|
||||
|
||||
def forward(self, input):
|
||||
try:
|
||||
result = self.process(input=input)
|
||||
return result
|
||||
except Exception as e:
|
||||
# Log error
|
||||
print(f"Error processing {input}: {e}")
|
||||
# Return fallback
|
||||
return dspy.Prediction(output="Error: could not process input")
|
||||
```
|
||||
|
||||
### 3. Monitoring
|
||||
|
||||
```python
|
||||
class MonitoredModule(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.process = dspy.ChainOfThought("input -> output")
|
||||
self.call_count = 0
|
||||
self.errors = 0
|
||||
|
||||
def forward(self, input):
|
||||
self.call_count += 1
|
||||
|
||||
try:
|
||||
result = self.process(input=input)
|
||||
return result
|
||||
except Exception as e:
|
||||
self.errors += 1
|
||||
raise
|
||||
|
||||
def get_stats(self):
|
||||
return {
|
||||
"calls": self.call_count,
|
||||
"errors": self.errors,
|
||||
"error_rate": self.errors / max(self.call_count, 1)
|
||||
}
|
||||
```
|
||||
|
||||
### 4. A/B Testing
|
||||
|
||||
```python
|
||||
class ABTestModule(dspy.Module):
|
||||
"""Run two variants and compare."""
|
||||
|
||||
def __init__(self, variant_a, variant_b):
|
||||
super().__init__()
|
||||
self.variant_a = variant_a
|
||||
self.variant_b = variant_b
|
||||
self.a_calls = 0
|
||||
self.b_calls = 0
|
||||
|
||||
def forward(self, input, variant="a"):
|
||||
if variant == "a":
|
||||
self.a_calls += 1
|
||||
return self.variant_a(input=input)
|
||||
else:
|
||||
self.b_calls += 1
|
||||
return self.variant_b(input=input)
|
||||
|
||||
# Compare two optimizers
|
||||
baseline = dspy.ChainOfThought("question -> answer")
|
||||
optimized = BootstrapFewShot(...).compile(baseline, trainset=trainset)
|
||||
|
||||
ab_test = ABTestModule(variant_a=baseline, variant_b=optimized)
|
||||
|
||||
# Route 50% to each
|
||||
import random
|
||||
variant = "a" if random.random() < 0.5 else "b"
|
||||
result = ab_test(input=question, variant=variant)
|
||||
```
|
||||
|
||||
## Complete Example: Customer Support Bot
|
||||
|
||||
```python
|
||||
import dspy
|
||||
from dspy.teleprompt import BootstrapFewShot
|
||||
|
||||
class CustomerSupportBot(dspy.Module):
|
||||
"""Complete customer support system."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# Classify intent
|
||||
self.classify_intent = dspy.Predict("message -> intent: str")
|
||||
|
||||
# Specialized handlers
|
||||
self.technical_handler = dspy.ChainOfThought("message, history -> response")
|
||||
self.billing_handler = dspy.ChainOfThought("message, history -> response")
|
||||
self.general_handler = dspy.Predict("message, history -> response")
|
||||
|
||||
# Retrieve relevant docs
|
||||
self.retrieve = dspy.Retrieve(k=3)
|
||||
|
||||
# Conversation history
|
||||
self.history = []
|
||||
|
||||
def forward(self, message):
|
||||
# Classify intent
|
||||
intent = self.classify_intent(message=message).intent
|
||||
|
||||
# Retrieve relevant documentation
|
||||
docs = self.retrieve(message).passages
|
||||
context = "\n".join(docs)
|
||||
|
||||
# Add context to history
|
||||
history_str = "\n".join(self.history)
|
||||
full_message = f"Context: {context}\n\nMessage: {message}"
|
||||
|
||||
# Route to appropriate handler
|
||||
if intent == "technical":
|
||||
response = self.technical_handler(
|
||||
message=full_message,
|
||||
history=history_str
|
||||
).response
|
||||
elif intent == "billing":
|
||||
response = self.billing_handler(
|
||||
message=full_message,
|
||||
history=history_str
|
||||
).response
|
||||
else:
|
||||
response = self.general_handler(
|
||||
message=full_message,
|
||||
history=history_str
|
||||
).response
|
||||
|
||||
# Update history
|
||||
self.history.append(f"User: {message}")
|
||||
self.history.append(f"Bot: {response}")
|
||||
|
||||
return dspy.Prediction(response=response, intent=intent)
|
||||
|
||||
# Training data
|
||||
trainset = [
|
||||
dspy.Example(
|
||||
message="My account isn't working",
|
||||
intent="technical",
|
||||
response="I'd be happy to help. What error are you seeing?"
|
||||
).with_inputs("message"),
|
||||
# ... more examples
|
||||
]
|
||||
|
||||
# Define metric
|
||||
def response_quality(example, pred, trace=None):
|
||||
# Check if response is helpful
|
||||
if len(pred.response) < 20:
|
||||
return 0.0
|
||||
if example.intent != pred.intent:
|
||||
return 0.3
|
||||
return 1.0
|
||||
|
||||
# Optimize
|
||||
optimizer = BootstrapFewShot(metric=response_quality)
|
||||
bot = CustomerSupportBot()
|
||||
optimized_bot = optimizer.compile(bot, trainset=trainset)
|
||||
|
||||
# Use in production
|
||||
optimized_bot.save("models/support_bot_v1.json")
|
||||
|
||||
# Later, load and use
|
||||
loaded_bot = CustomerSupportBot()
|
||||
loaded_bot.load("models/support_bot_v1.json")
|
||||
response = loaded_bot(message="I can't log in")
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://dspy.ai
|
||||
- **Examples Repo**: https://github.com/stanfordnlp/dspy/tree/main/examples
|
||||
- **Discord**: https://discord.gg/XCGy2WDCQB
|
||||
475
skills/mlops/dspy/references/modules.md
Normal file
475
skills/mlops/dspy/references/modules.md
Normal file
@@ -0,0 +1,475 @@
|
||||
# DSPy Modules
|
||||
|
||||
Complete guide to DSPy's built-in modules for language model programming.
|
||||
|
||||
## Module Basics
|
||||
|
||||
DSPy modules are composable building blocks inspired by PyTorch's NN modules:
|
||||
- Have learnable parameters (prompts, few-shot examples)
|
||||
- Can be composed using Python control flow
|
||||
- Generalized to handle any signature
|
||||
- Optimizable with DSPy optimizers
|
||||
|
||||
### Base Module Pattern
|
||||
|
||||
```python
|
||||
import dspy
|
||||
|
||||
class CustomModule(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# Initialize sub-modules
|
||||
self.predictor = dspy.Predict("input -> output")
|
||||
|
||||
def forward(self, input):
|
||||
# Module logic
|
||||
result = self.predictor(input=input)
|
||||
return result
|
||||
```
|
||||
|
||||
## Core Modules
|
||||
|
||||
### dspy.Predict
|
||||
|
||||
**Basic prediction module** - Makes LM calls without reasoning steps.
|
||||
|
||||
```python
|
||||
# Inline signature
|
||||
qa = dspy.Predict("question -> answer")
|
||||
result = qa(question="What is 2+2?")
|
||||
|
||||
# Class signature
|
||||
class QA(dspy.Signature):
|
||||
"""Answer questions concisely."""
|
||||
question = dspy.InputField()
|
||||
answer = dspy.OutputField(desc="short, factual answer")
|
||||
|
||||
qa = dspy.Predict(QA)
|
||||
result = qa(question="What is the capital of France?")
|
||||
print(result.answer) # "Paris"
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Simple, direct predictions
|
||||
- No reasoning steps needed
|
||||
- Fast responses required
|
||||
|
||||
### dspy.ChainOfThought
|
||||
|
||||
**Step-by-step reasoning** - Generates rationale before answer.
|
||||
|
||||
**Parameters:**
|
||||
- `signature`: Task signature
|
||||
- `rationale_field`: Custom reasoning field (optional)
|
||||
- `rationale_field_type`: Type for rationale (default: `str`)
|
||||
|
||||
```python
|
||||
# Basic usage
|
||||
cot = dspy.ChainOfThought("question -> answer")
|
||||
result = cot(question="If I have 5 apples and give away 2, how many remain?")
|
||||
print(result.rationale) # "Let's think step by step..."
|
||||
print(result.answer) # "3"
|
||||
|
||||
# Custom rationale field
|
||||
cot = dspy.ChainOfThought(
|
||||
signature="problem -> solution",
|
||||
rationale_field=dspy.OutputField(
|
||||
prefix="Reasoning: Let's break this down step by step to"
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Complex reasoning tasks
|
||||
- Math word problems
|
||||
- Logical deduction
|
||||
- Quality > speed
|
||||
|
||||
**Performance:**
|
||||
- ~2x slower than Predict
|
||||
- Significantly better accuracy on reasoning tasks
|
||||
|
||||
### dspy.ProgramOfThought
|
||||
|
||||
**Code-based reasoning** - Generates and executes Python code.
|
||||
|
||||
```python
|
||||
pot = dspy.ProgramOfThought("question -> answer")
|
||||
|
||||
result = pot(question="What is 15% of 240?")
|
||||
# Internally generates: answer = 240 * 0.15
|
||||
# Executes code and returns result
|
||||
print(result.answer) # 36.0
|
||||
|
||||
result = pot(question="If a train travels 60 mph for 2.5 hours, how far does it go?")
|
||||
# Generates: distance = 60 * 2.5
|
||||
print(result.answer) # 150.0
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Arithmetic calculations
|
||||
- Symbolic math
|
||||
- Data transformations
|
||||
- Deterministic computations
|
||||
|
||||
**Benefits:**
|
||||
- More reliable than text-based math
|
||||
- Handles complex calculations
|
||||
- Transparent (shows generated code)
|
||||
|
||||
### dspy.ReAct
|
||||
|
||||
**Reasoning + Acting** - Agent that uses tools iteratively.
|
||||
|
||||
```python
|
||||
from dspy.predict import ReAct
|
||||
|
||||
# Define tools
|
||||
def search_wikipedia(query: str) -> str:
|
||||
"""Search Wikipedia for information."""
|
||||
# Your search implementation
|
||||
return search_results
|
||||
|
||||
def calculate(expression: str) -> float:
|
||||
"""Evaluate a mathematical expression."""
|
||||
return eval(expression)
|
||||
|
||||
# Create ReAct agent
|
||||
class ResearchQA(dspy.Signature):
|
||||
"""Answer questions using available tools."""
|
||||
question = dspy.InputField()
|
||||
answer = dspy.OutputField()
|
||||
|
||||
react = ReAct(ResearchQA, tools=[search_wikipedia, calculate])
|
||||
|
||||
# Agent decides which tools to use
|
||||
result = react(question="How old was Einstein when he published special relativity?")
|
||||
# Internally:
|
||||
# 1. Thinks: "Need birth year and publication year"
|
||||
# 2. Acts: search_wikipedia("Albert Einstein")
|
||||
# 3. Acts: search_wikipedia("Special relativity 1905")
|
||||
# 4. Acts: calculate("1905 - 1879")
|
||||
# 5. Returns: "26 years old"
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Multi-step research tasks
|
||||
- Tool-using agents
|
||||
- Complex information retrieval
|
||||
- Tasks requiring multiple API calls
|
||||
|
||||
**Best practices:**
|
||||
- Keep tool descriptions clear and specific
|
||||
- Limit to 5-7 tools (too many = confusion)
|
||||
- Provide tool usage examples in docstrings
|
||||
|
||||
### dspy.MultiChainComparison
|
||||
|
||||
**Generate multiple outputs and compare** - Self-consistency pattern.
|
||||
|
||||
```python
|
||||
mcc = dspy.MultiChainComparison("question -> answer", M=5)
|
||||
|
||||
result = mcc(question="What is the capital of France?")
|
||||
# Generates 5 candidate answers
|
||||
# Compares and selects most consistent
|
||||
print(result.answer) # "Paris"
|
||||
print(result.candidates) # All 5 generated answers
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `M`: Number of candidates to generate (default: 5)
|
||||
- `temperature`: Sampling temperature for diversity
|
||||
|
||||
**When to use:**
|
||||
- High-stakes decisions
|
||||
- Ambiguous questions
|
||||
- When single answer may be unreliable
|
||||
|
||||
**Tradeoff:**
|
||||
- M times slower (M parallel calls)
|
||||
- Higher accuracy on ambiguous tasks
|
||||
|
||||
### dspy.majority
|
||||
|
||||
**Majority voting over multiple predictions.**
|
||||
|
||||
```python
|
||||
from dspy.primitives import majority
|
||||
|
||||
# Generate multiple predictions
|
||||
predictor = dspy.Predict("question -> answer")
|
||||
predictions = [predictor(question="What is 2+2?") for _ in range(5)]
|
||||
|
||||
# Take majority vote
|
||||
answer = majority([p.answer for p in predictions])
|
||||
print(answer) # "4"
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Combining multiple model outputs
|
||||
- Reducing variance in predictions
|
||||
- Ensemble approaches
|
||||
|
||||
## Advanced Modules
|
||||
|
||||
### dspy.TypedPredictor
|
||||
|
||||
**Structured output with Pydantic models.**
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class PersonInfo(BaseModel):
|
||||
name: str = Field(description="Full name")
|
||||
age: int = Field(description="Age in years")
|
||||
occupation: str = Field(description="Current job")
|
||||
|
||||
class ExtractPerson(dspy.Signature):
|
||||
"""Extract person information from text."""
|
||||
text = dspy.InputField()
|
||||
person: PersonInfo = dspy.OutputField()
|
||||
|
||||
extractor = dspy.TypedPredictor(ExtractPerson)
|
||||
result = extractor(text="John Doe is a 35-year-old software engineer.")
|
||||
|
||||
print(result.person.name) # "John Doe"
|
||||
print(result.person.age) # 35
|
||||
print(result.person.occupation) # "software engineer"
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Type safety
|
||||
- Automatic validation
|
||||
- JSON schema generation
|
||||
- IDE autocomplete
|
||||
|
||||
### dspy.Retry
|
||||
|
||||
**Automatic retry with validation.**
|
||||
|
||||
```python
|
||||
from dspy.primitives import Retry
|
||||
|
||||
def validate_number(example, pred, trace=None):
|
||||
"""Validate output is a number."""
|
||||
try:
|
||||
float(pred.answer)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
# Retry up to 3 times if validation fails
|
||||
qa = Retry(
|
||||
dspy.ChainOfThought("question -> answer"),
|
||||
validate=validate_number,
|
||||
max_retries=3
|
||||
)
|
||||
|
||||
result = qa(question="What is 15% of 80?")
|
||||
# If first attempt returns non-numeric, retries automatically
|
||||
```
|
||||
|
||||
### dspy.Assert
|
||||
|
||||
**Assertion-driven optimization.**
|
||||
|
||||
```python
|
||||
import dspy
|
||||
from dspy.primitives.assertions import assert_transform_module, backtrack_handler
|
||||
|
||||
class ValidatedQA(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.qa = dspy.ChainOfThought("question -> answer: float")
|
||||
|
||||
def forward(self, question):
|
||||
answer = self.qa(question=question).answer
|
||||
|
||||
# Assert answer is numeric
|
||||
dspy.Assert(
|
||||
isinstance(float(answer), float),
|
||||
"Answer must be a number",
|
||||
backtrack=backtrack_handler
|
||||
)
|
||||
|
||||
return dspy.Prediction(answer=answer)
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Catches errors during optimization
|
||||
- Guides LM toward valid outputs
|
||||
- Better than post-hoc filtering
|
||||
|
||||
## Module Composition
|
||||
|
||||
### Sequential Pipeline
|
||||
|
||||
```python
|
||||
class Pipeline(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.stage1 = dspy.Predict("input -> intermediate")
|
||||
self.stage2 = dspy.ChainOfThought("intermediate -> output")
|
||||
|
||||
def forward(self, input):
|
||||
intermediate = self.stage1(input=input).intermediate
|
||||
output = self.stage2(intermediate=intermediate).output
|
||||
return dspy.Prediction(output=output)
|
||||
```
|
||||
|
||||
### Conditional Logic
|
||||
|
||||
```python
|
||||
class ConditionalModule(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.router = dspy.Predict("question -> category: str")
|
||||
self.simple_qa = dspy.Predict("question -> answer")
|
||||
self.complex_qa = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
category = self.router(question=question).category
|
||||
|
||||
if category == "simple":
|
||||
return self.simple_qa(question=question)
|
||||
else:
|
||||
return self.complex_qa(question=question)
|
||||
```
|
||||
|
||||
### Parallel Execution
|
||||
|
||||
```python
|
||||
class ParallelModule(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.approach1 = dspy.ChainOfThought("question -> answer")
|
||||
self.approach2 = dspy.ProgramOfThought("question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
# Run both approaches
|
||||
answer1 = self.approach1(question=question).answer
|
||||
answer2 = self.approach2(question=question).answer
|
||||
|
||||
# Compare or combine results
|
||||
if answer1 == answer2:
|
||||
return dspy.Prediction(answer=answer1, confidence="high")
|
||||
else:
|
||||
return dspy.Prediction(answer=answer1, confidence="low")
|
||||
```
|
||||
|
||||
## Batch Processing
|
||||
|
||||
All modules support batch processing for efficiency:
|
||||
|
||||
```python
|
||||
cot = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
questions = [
|
||||
"What is 2+2?",
|
||||
"What is 3+3?",
|
||||
"What is 4+4?"
|
||||
]
|
||||
|
||||
# Process all at once
|
||||
results = cot.batch([{"question": q} for q in questions])
|
||||
|
||||
for result in results:
|
||||
print(result.answer)
|
||||
```
|
||||
|
||||
## Saving and Loading
|
||||
|
||||
```python
|
||||
# Save module
|
||||
qa = dspy.ChainOfThought("question -> answer")
|
||||
qa.save("models/qa_v1.json")
|
||||
|
||||
# Load module
|
||||
loaded_qa = dspy.ChainOfThought("question -> answer")
|
||||
loaded_qa.load("models/qa_v1.json")
|
||||
```
|
||||
|
||||
**What gets saved:**
|
||||
- Few-shot examples
|
||||
- Prompt instructions
|
||||
- Module configuration
|
||||
|
||||
**What doesn't get saved:**
|
||||
- Model weights (DSPy doesn't fine-tune by default)
|
||||
- LM provider configuration
|
||||
|
||||
## Module Selection Guide
|
||||
|
||||
| Task | Module | Reason |
|
||||
|------|--------|--------|
|
||||
| Simple classification | Predict | Fast, direct |
|
||||
| Math word problems | ProgramOfThought | Reliable calculations |
|
||||
| Logical reasoning | ChainOfThought | Better with steps |
|
||||
| Multi-step research | ReAct | Tool usage |
|
||||
| High-stakes decisions | MultiChainComparison | Self-consistency |
|
||||
| Structured extraction | TypedPredictor | Type safety |
|
||||
| Ambiguous questions | MultiChainComparison | Multiple perspectives |
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Start with Predict**, add reasoning only if needed
|
||||
2. **Use batch processing** for multiple inputs
|
||||
3. **Cache predictions** for repeated queries
|
||||
4. **Profile token usage** with `track_usage=True`
|
||||
5. **Optimize after prototyping** with teleprompters
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Pattern: Retrieval + Generation
|
||||
|
||||
```python
|
||||
class RAG(dspy.Module):
|
||||
def __init__(self, k=3):
|
||||
super().__init__()
|
||||
self.retrieve = dspy.Retrieve(k=k)
|
||||
self.generate = dspy.ChainOfThought("context, question -> answer")
|
||||
|
||||
def forward(self, question):
|
||||
context = self.retrieve(question).passages
|
||||
return self.generate(context=context, question=question)
|
||||
```
|
||||
|
||||
### Pattern: Verification Loop
|
||||
|
||||
```python
|
||||
class VerifiedQA(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.answer = dspy.ChainOfThought("question -> answer")
|
||||
self.verify = dspy.Predict("question, answer -> is_correct: bool")
|
||||
|
||||
def forward(self, question, max_attempts=3):
|
||||
for _ in range(max_attempts):
|
||||
answer = self.answer(question=question).answer
|
||||
is_correct = self.verify(question=question, answer=answer).is_correct
|
||||
|
||||
if is_correct:
|
||||
return dspy.Prediction(answer=answer)
|
||||
|
||||
return dspy.Prediction(answer="Unable to verify answer")
|
||||
```
|
||||
|
||||
### Pattern: Multi-Turn Dialog
|
||||
|
||||
```python
|
||||
class DialogAgent(dspy.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.respond = dspy.Predict("history, user_message -> assistant_message")
|
||||
self.history = []
|
||||
|
||||
def forward(self, user_message):
|
||||
history_str = "\n".join(self.history)
|
||||
response = self.respond(history=history_str, user_message=user_message)
|
||||
|
||||
self.history.append(f"User: {user_message}")
|
||||
self.history.append(f"Assistant: {response.assistant_message}")
|
||||
|
||||
return response
|
||||
```
|
||||
566
skills/mlops/dspy/references/optimizers.md
Normal file
566
skills/mlops/dspy/references/optimizers.md
Normal file
@@ -0,0 +1,566 @@
|
||||
# DSPy Optimizers (Teleprompters)
|
||||
|
||||
Complete guide to DSPy's optimization algorithms for improving prompts and model weights.
|
||||
|
||||
## What are Optimizers?
|
||||
|
||||
DSPy optimizers (called "teleprompters") automatically improve your modules by:
|
||||
- **Synthesizing few-shot examples** from training data
|
||||
- **Proposing better instructions** through search
|
||||
- **Fine-tuning model weights** (optional)
|
||||
|
||||
**Key idea**: Instead of manually tuning prompts, define a metric and let DSPy optimize.
|
||||
|
||||
## Optimizer Selection Guide
|
||||
|
||||
| Optimizer | Best For | Speed | Quality | Data Needed |
|
||||
|-----------|----------|-------|---------|-------------|
|
||||
| BootstrapFewShot | General purpose | Fast | Good | 10-50 examples |
|
||||
| MIPRO | Instruction tuning | Medium | Excellent | 50-200 examples |
|
||||
| BootstrapFinetune | Fine-tuning | Slow | Excellent | 100+ examples |
|
||||
| COPRO | Prompt optimization | Medium | Good | 20-100 examples |
|
||||
| KNNFewShot | Quick baseline | Very fast | Fair | 10+ examples |
|
||||
|
||||
## Core Optimizers
|
||||
|
||||
### BootstrapFewShot
|
||||
|
||||
**Most popular optimizer** - Generates few-shot demonstrations from training data.
|
||||
|
||||
**How it works:**
|
||||
1. Takes your training examples
|
||||
2. Uses your module to generate predictions
|
||||
3. Selects high-quality predictions (based on metric)
|
||||
4. Uses these as few-shot examples in future prompts
|
||||
|
||||
**Parameters:**
|
||||
- `metric`: Function that scores predictions (required)
|
||||
- `max_bootstrapped_demos`: Max demonstrations to generate (default: 4)
|
||||
- `max_labeled_demos`: Max labeled examples to use (default: 16)
|
||||
- `max_rounds`: Optimization iterations (default: 1)
|
||||
- `metric_threshold`: Minimum score to accept (optional)
|
||||
|
||||
```python
|
||||
import dspy
|
||||
from dspy.teleprompt import BootstrapFewShot
|
||||
|
||||
# Define metric
|
||||
def validate_answer(example, pred, trace=None):
|
||||
"""Return True if prediction matches gold answer."""
|
||||
return example.answer.lower() == pred.answer.lower()
|
||||
|
||||
# Training data
|
||||
trainset = [
|
||||
dspy.Example(question="What is 2+2?", answer="4").with_inputs("question"),
|
||||
dspy.Example(question="What is 3+5?", answer="8").with_inputs("question"),
|
||||
dspy.Example(question="What is 10-3?", answer="7").with_inputs("question"),
|
||||
]
|
||||
|
||||
# Create module
|
||||
qa = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
# Optimize
|
||||
optimizer = BootstrapFewShot(
|
||||
metric=validate_answer,
|
||||
max_bootstrapped_demos=3,
|
||||
max_rounds=2
|
||||
)
|
||||
|
||||
optimized_qa = optimizer.compile(qa, trainset=trainset)
|
||||
|
||||
# Now optimized_qa has learned few-shot examples!
|
||||
result = optimized_qa(question="What is 5+7?")
|
||||
```
|
||||
|
||||
**Best practices:**
|
||||
- Start with 10-50 training examples
|
||||
- Use diverse examples covering edge cases
|
||||
- Set `max_bootstrapped_demos=3-5` for most tasks
|
||||
- Increase `max_rounds=2-3` for better quality
|
||||
|
||||
**When to use:**
|
||||
- First optimizer to try
|
||||
- You have 10+ labeled examples
|
||||
- Want quick improvements
|
||||
- General-purpose tasks
|
||||
|
||||
### MIPRO (Most Important Prompt Optimization)
|
||||
|
||||
**State-of-the-art optimizer** - Iteratively searches for better instructions.
|
||||
|
||||
**How it works:**
|
||||
1. Generates candidate instructions
|
||||
2. Tests each on validation set
|
||||
3. Selects best-performing instructions
|
||||
4. Iterates to refine further
|
||||
|
||||
**Parameters:**
|
||||
- `metric`: Evaluation metric (required)
|
||||
- `num_candidates`: Instructions to try per iteration (default: 10)
|
||||
- `init_temperature`: Sampling temperature (default: 1.0)
|
||||
- `verbose`: Show progress (default: False)
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import MIPRO
|
||||
|
||||
# Define metric with more nuance
|
||||
def answer_quality(example, pred, trace=None):
|
||||
"""Score answer quality 0-1."""
|
||||
if example.answer.lower() in pred.answer.lower():
|
||||
return 1.0
|
||||
# Partial credit for similar answers
|
||||
return 0.5 if len(set(example.answer.split()) & set(pred.answer.split())) > 0 else 0.0
|
||||
|
||||
# Larger training set (MIPRO benefits from more data)
|
||||
trainset = [...] # 50-200 examples
|
||||
valset = [...] # 20-50 examples
|
||||
|
||||
# Create module
|
||||
qa = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
# Optimize with MIPRO
|
||||
optimizer = MIPRO(
|
||||
metric=answer_quality,
|
||||
num_candidates=10,
|
||||
init_temperature=1.0,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
optimized_qa = optimizer.compile(
|
||||
student=qa,
|
||||
trainset=trainset,
|
||||
valset=valset, # MIPRO uses separate validation set
|
||||
num_trials=100 # More trials = better quality
|
||||
)
|
||||
```
|
||||
|
||||
**Best practices:**
|
||||
- Use 50-200 training examples
|
||||
- Separate validation set (20-50 examples)
|
||||
- Run 100-200 trials for best results
|
||||
- Takes 10-30 minutes typically
|
||||
|
||||
**When to use:**
|
||||
- You have 50+ labeled examples
|
||||
- Want state-of-the-art performance
|
||||
- Willing to wait for optimization
|
||||
- Complex reasoning tasks
|
||||
|
||||
### BootstrapFinetune
|
||||
|
||||
**Fine-tune model weights** - Creates training dataset for fine-tuning.
|
||||
|
||||
**How it works:**
|
||||
1. Generates synthetic training data
|
||||
2. Exports data in fine-tuning format
|
||||
3. You fine-tune model separately
|
||||
4. Load fine-tuned model back
|
||||
|
||||
**Parameters:**
|
||||
- `metric`: Evaluation metric (required)
|
||||
- `max_bootstrapped_demos`: Demonstrations to generate (default: 4)
|
||||
- `max_rounds`: Data generation rounds (default: 1)
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import BootstrapFinetune
|
||||
|
||||
# Training data
|
||||
trainset = [...] # 100+ examples recommended
|
||||
|
||||
# Define metric
|
||||
def validate(example, pred, trace=None):
|
||||
return example.answer == pred.answer
|
||||
|
||||
# Create module
|
||||
qa = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
# Generate fine-tuning data
|
||||
optimizer = BootstrapFinetune(metric=validate)
|
||||
optimized_qa = optimizer.compile(qa, trainset=trainset)
|
||||
|
||||
# Exports training data to file
|
||||
# You then fine-tune using your LM provider's API
|
||||
|
||||
# After fine-tuning, load your model:
|
||||
finetuned_lm = dspy.OpenAI(model="ft:gpt-3.5-turbo:your-model-id")
|
||||
dspy.settings.configure(lm=finetuned_lm)
|
||||
```
|
||||
|
||||
**Best practices:**
|
||||
- Use 100+ training examples
|
||||
- Validate on held-out test set
|
||||
- Monitor for overfitting
|
||||
- Compare with prompt-based methods first
|
||||
|
||||
**When to use:**
|
||||
- You have 100+ examples
|
||||
- Latency is critical (fine-tuned models faster)
|
||||
- Task is narrow and well-defined
|
||||
- Prompt optimization isn't enough
|
||||
|
||||
### COPRO (Coordinate Prompt Optimization)
|
||||
|
||||
**Optimize prompts via gradient-free search.**
|
||||
|
||||
**How it works:**
|
||||
1. Generates prompt variants
|
||||
2. Evaluates each variant
|
||||
3. Selects best prompts
|
||||
4. Iterates to refine
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import COPRO
|
||||
|
||||
# Training data
|
||||
trainset = [...]
|
||||
|
||||
# Define metric
|
||||
def metric(example, pred, trace=None):
|
||||
return example.answer == pred.answer
|
||||
|
||||
# Create module
|
||||
qa = dspy.ChainOfThought("question -> answer")
|
||||
|
||||
# Optimize with COPRO
|
||||
optimizer = COPRO(
|
||||
metric=metric,
|
||||
breadth=10, # Candidates per iteration
|
||||
depth=3 # Optimization rounds
|
||||
)
|
||||
|
||||
optimized_qa = optimizer.compile(qa, trainset=trainset)
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Want prompt optimization
|
||||
- Have 20-100 examples
|
||||
- MIPRO too slow
|
||||
|
||||
### KNNFewShot
|
||||
|
||||
**Simple k-nearest neighbors** - Selects similar examples for each query.
|
||||
|
||||
**How it works:**
|
||||
1. Embeds all training examples
|
||||
2. For each query, finds k most similar examples
|
||||
3. Uses these as few-shot demonstrations
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import KNNFewShot
|
||||
|
||||
trainset = [...]
|
||||
|
||||
# No metric needed - just selects similar examples
|
||||
optimizer = KNNFewShot(k=3)
|
||||
optimized_qa = optimizer.compile(qa, trainset=trainset)
|
||||
|
||||
# For each query, uses 3 most similar examples from trainset
|
||||
```
|
||||
|
||||
**When to use:**
|
||||
- Quick baseline
|
||||
- Have diverse training examples
|
||||
- Similarity is good proxy for helpfulness
|
||||
|
||||
## Writing Metrics
|
||||
|
||||
Metrics are functions that score predictions. They're critical for optimization.
|
||||
|
||||
### Binary Metrics
|
||||
|
||||
```python
|
||||
def exact_match(example, pred, trace=None):
|
||||
"""Return True if prediction exactly matches gold."""
|
||||
return example.answer == pred.answer
|
||||
|
||||
def contains_answer(example, pred, trace=None):
|
||||
"""Return True if prediction contains gold answer."""
|
||||
return example.answer.lower() in pred.answer.lower()
|
||||
```
|
||||
|
||||
### Continuous Metrics
|
||||
|
||||
```python
|
||||
def f1_score(example, pred, trace=None):
|
||||
"""F1 score between prediction and gold."""
|
||||
pred_tokens = set(pred.answer.lower().split())
|
||||
gold_tokens = set(example.answer.lower().split())
|
||||
|
||||
if not pred_tokens:
|
||||
return 0.0
|
||||
|
||||
precision = len(pred_tokens & gold_tokens) / len(pred_tokens)
|
||||
recall = len(pred_tokens & gold_tokens) / len(gold_tokens)
|
||||
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
|
||||
return 2 * (precision * recall) / (precision + recall)
|
||||
|
||||
def semantic_similarity(example, pred, trace=None):
|
||||
"""Embedding similarity between prediction and gold."""
|
||||
from sentence_transformers import SentenceTransformer
|
||||
model = SentenceTransformer('all-MiniLM-L6-v2')
|
||||
|
||||
emb1 = model.encode(example.answer)
|
||||
emb2 = model.encode(pred.answer)
|
||||
|
||||
similarity = cosine_similarity(emb1, emb2)
|
||||
return similarity
|
||||
```
|
||||
|
||||
### Multi-Factor Metrics
|
||||
|
||||
```python
|
||||
def comprehensive_metric(example, pred, trace=None):
|
||||
"""Combine multiple factors."""
|
||||
score = 0.0
|
||||
|
||||
# Correctness (50%)
|
||||
if example.answer.lower() in pred.answer.lower():
|
||||
score += 0.5
|
||||
|
||||
# Conciseness (25%)
|
||||
if len(pred.answer.split()) <= 20:
|
||||
score += 0.25
|
||||
|
||||
# Citation (25%)
|
||||
if "source:" in pred.answer.lower():
|
||||
score += 0.25
|
||||
|
||||
return score
|
||||
```
|
||||
|
||||
### Using Trace for Debugging
|
||||
|
||||
```python
|
||||
def metric_with_trace(example, pred, trace=None):
|
||||
"""Metric that uses trace for debugging."""
|
||||
is_correct = example.answer == pred.answer
|
||||
|
||||
if trace is not None and not is_correct:
|
||||
# Log failures for analysis
|
||||
print(f"Failed on: {example.question}")
|
||||
print(f"Expected: {example.answer}")
|
||||
print(f"Got: {pred.answer}")
|
||||
|
||||
return is_correct
|
||||
```
|
||||
|
||||
## Evaluation Best Practices
|
||||
|
||||
### Train/Val/Test Split
|
||||
|
||||
```python
|
||||
# Split data
|
||||
trainset = data[:100] # 70%
|
||||
valset = data[100:120] # 15%
|
||||
testset = data[120:] # 15%
|
||||
|
||||
# Optimize on train
|
||||
optimized = optimizer.compile(module, trainset=trainset)
|
||||
|
||||
# Validate during optimization (for MIPRO)
|
||||
optimized = optimizer.compile(module, trainset=trainset, valset=valset)
|
||||
|
||||
# Evaluate on test
|
||||
from dspy.evaluate import Evaluate
|
||||
evaluator = Evaluate(devset=testset, metric=metric)
|
||||
score = evaluator(optimized)
|
||||
```
|
||||
|
||||
### Cross-Validation
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import KFold
|
||||
|
||||
kfold = KFold(n_splits=5)
|
||||
scores = []
|
||||
|
||||
for train_idx, val_idx in kfold.split(data):
|
||||
trainset = [data[i] for i in train_idx]
|
||||
valset = [data[i] for i in val_idx]
|
||||
|
||||
optimized = optimizer.compile(module, trainset=trainset)
|
||||
score = evaluator(optimized, devset=valset)
|
||||
scores.append(score)
|
||||
|
||||
print(f"Average score: {sum(scores) / len(scores):.2f}")
|
||||
```
|
||||
|
||||
### Comparing Optimizers
|
||||
|
||||
```python
|
||||
results = {}
|
||||
|
||||
for opt_name, optimizer in [
|
||||
("baseline", None),
|
||||
("fewshot", BootstrapFewShot(metric=metric)),
|
||||
("mipro", MIPRO(metric=metric)),
|
||||
]:
|
||||
if optimizer is None:
|
||||
module_opt = module
|
||||
else:
|
||||
module_opt = optimizer.compile(module, trainset=trainset)
|
||||
|
||||
score = evaluator(module_opt, devset=testset)
|
||||
results[opt_name] = score
|
||||
|
||||
print(results)
|
||||
# {'baseline': 0.65, 'fewshot': 0.78, 'mipro': 0.85}
|
||||
```
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### Custom Optimizer
|
||||
|
||||
```python
|
||||
from dspy.teleprompt import Teleprompter
|
||||
|
||||
class CustomOptimizer(Teleprompter):
|
||||
def __init__(self, metric):
|
||||
self.metric = metric
|
||||
|
||||
def compile(self, student, trainset, **kwargs):
|
||||
# Your optimization logic here
|
||||
# Return optimized student module
|
||||
return student
|
||||
```
|
||||
|
||||
### Multi-Stage Optimization
|
||||
|
||||
```python
|
||||
# Stage 1: Bootstrap few-shot
|
||||
stage1 = BootstrapFewShot(metric=metric, max_bootstrapped_demos=3)
|
||||
optimized1 = stage1.compile(module, trainset=trainset)
|
||||
|
||||
# Stage 2: Instruction tuning
|
||||
stage2 = MIPRO(metric=metric, num_candidates=10)
|
||||
optimized2 = stage2.compile(optimized1, trainset=trainset, valset=valset)
|
||||
|
||||
# Final optimized module
|
||||
final_module = optimized2
|
||||
```
|
||||
|
||||
### Ensemble Optimization
|
||||
|
||||
```python
|
||||
class EnsembleModule(dspy.Module):
|
||||
def __init__(self, modules):
|
||||
super().__init__()
|
||||
self.modules = modules
|
||||
|
||||
def forward(self, question):
|
||||
predictions = [m(question=question).answer for m in self.modules]
|
||||
# Vote or average
|
||||
return dspy.Prediction(answer=max(set(predictions), key=predictions.count))
|
||||
|
||||
# Optimize multiple modules
|
||||
opt1 = BootstrapFewShot(metric=metric).compile(module, trainset=trainset)
|
||||
opt2 = MIPRO(metric=metric).compile(module, trainset=trainset)
|
||||
opt3 = COPRO(metric=metric).compile(module, trainset=trainset)
|
||||
|
||||
# Ensemble
|
||||
ensemble = EnsembleModule([opt1, opt2, opt3])
|
||||
```
|
||||
|
||||
## Optimization Workflow
|
||||
|
||||
### 1. Start with Baseline
|
||||
|
||||
```python
|
||||
# No optimization
|
||||
baseline = dspy.ChainOfThought("question -> answer")
|
||||
baseline_score = evaluator(baseline, devset=testset)
|
||||
print(f"Baseline: {baseline_score}")
|
||||
```
|
||||
|
||||
### 2. Try BootstrapFewShot
|
||||
|
||||
```python
|
||||
# Quick optimization
|
||||
fewshot = BootstrapFewShot(metric=metric, max_bootstrapped_demos=3)
|
||||
optimized = fewshot.compile(baseline, trainset=trainset)
|
||||
fewshot_score = evaluator(optimized, devset=testset)
|
||||
print(f"Few-shot: {fewshot_score} (+{fewshot_score - baseline_score:.2f})")
|
||||
```
|
||||
|
||||
### 3. If More Data Available, Try MIPRO
|
||||
|
||||
```python
|
||||
# State-of-the-art optimization
|
||||
mipro = MIPRO(metric=metric, num_candidates=10)
|
||||
optimized_mipro = mipro.compile(baseline, trainset=trainset, valset=valset)
|
||||
mipro_score = evaluator(optimized_mipro, devset=testset)
|
||||
print(f"MIPRO: {mipro_score} (+{mipro_score - baseline_score:.2f})")
|
||||
```
|
||||
|
||||
### 4. Save Best Model
|
||||
|
||||
```python
|
||||
if mipro_score > fewshot_score:
|
||||
optimized_mipro.save("models/best_model.json")
|
||||
else:
|
||||
optimized.save("models/best_model.json")
|
||||
```
|
||||
|
||||
## Common Pitfalls
|
||||
|
||||
### 1. Overfitting to Training Data
|
||||
|
||||
```python
|
||||
# ❌ Bad: Too many demos
|
||||
optimizer = BootstrapFewShot(max_bootstrapped_demos=20) # Overfits!
|
||||
|
||||
# ✅ Good: Moderate demos
|
||||
optimizer = BootstrapFewShot(max_bootstrapped_demos=3-5)
|
||||
```
|
||||
|
||||
### 2. Metric Doesn't Match Task
|
||||
|
||||
```python
|
||||
# ❌ Bad: Binary metric for nuanced task
|
||||
def bad_metric(example, pred, trace=None):
|
||||
return example.answer == pred.answer # Too strict!
|
||||
|
||||
# ✅ Good: Graded metric
|
||||
def good_metric(example, pred, trace=None):
|
||||
return f1_score(example.answer, pred.answer) # Allows partial credit
|
||||
```
|
||||
|
||||
### 3. Insufficient Training Data
|
||||
|
||||
```python
|
||||
# ❌ Bad: Too little data
|
||||
trainset = data[:5] # Not enough!
|
||||
|
||||
# ✅ Good: Sufficient data
|
||||
trainset = data[:50] # Better
|
||||
```
|
||||
|
||||
### 4. No Validation Set
|
||||
|
||||
```python
|
||||
# ❌ Bad: Optimizing on test set
|
||||
optimizer.compile(module, trainset=testset) # Cheating!
|
||||
|
||||
# ✅ Good: Proper splits
|
||||
optimizer.compile(module, trainset=trainset, valset=valset)
|
||||
evaluator(optimized, devset=testset)
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Start simple**: BootstrapFewShot first
|
||||
2. **Use representative data**: Cover edge cases
|
||||
3. **Monitor overfitting**: Validate on held-out set
|
||||
4. **Iterate metrics**: Refine based on failures
|
||||
5. **Save checkpoints**: Don't lose progress
|
||||
6. **Compare to baseline**: Measure improvement
|
||||
7. **Test multiple optimizers**: Find best fit
|
||||
|
||||
## Resources
|
||||
|
||||
- **Paper**: "DSPy: Compiling Declarative Language Model Calls into Self-Improving Pipelines"
|
||||
- **GitHub**: https://github.com/stanfordnlp/dspy
|
||||
- **Discord**: https://discord.gg/XCGy2WDCQB
|
||||
221
skills/mlops/faiss/SKILL.md
Normal file
221
skills/mlops/faiss/SKILL.md
Normal file
@@ -0,0 +1,221 @@
|
||||
---
|
||||
name: faiss
|
||||
description: Facebook's library for efficient similarity search and clustering of dense vectors. Supports billions of vectors, GPU acceleration, and various index types (Flat, IVF, HNSW). Use for fast k-NN search, large-scale vector retrieval, or when you need pure similarity search without metadata. Best for high-performance applications.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
tags: [RAG, FAISS, Similarity Search, Vector Search, Facebook AI, GPU Acceleration, Billion-Scale, K-NN, HNSW, High Performance, Large Scale]
|
||||
dependencies: [faiss-cpu, faiss-gpu, numpy]
|
||||
---
|
||||
|
||||
# FAISS - Efficient Similarity Search
|
||||
|
||||
Facebook AI's library for billion-scale vector similarity search.
|
||||
|
||||
## When to use FAISS
|
||||
|
||||
**Use FAISS when:**
|
||||
- Need fast similarity search on large vector datasets (millions/billions)
|
||||
- GPU acceleration required
|
||||
- Pure vector similarity (no metadata filtering needed)
|
||||
- High throughput, low latency critical
|
||||
- Offline/batch processing of embeddings
|
||||
|
||||
**Metrics**:
|
||||
- **31,700+ GitHub stars**
|
||||
- Meta/Facebook AI Research
|
||||
- **Handles billions of vectors**
|
||||
- **C++** with Python bindings
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **Chroma/Pinecone**: Need metadata filtering
|
||||
- **Weaviate**: Need full database features
|
||||
- **Annoy**: Simpler, fewer features
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# CPU only
|
||||
pip install faiss-cpu
|
||||
|
||||
# GPU support
|
||||
pip install faiss-gpu
|
||||
```
|
||||
|
||||
### Basic usage
|
||||
|
||||
```python
|
||||
import faiss
|
||||
import numpy as np
|
||||
|
||||
# Create sample data (1000 vectors, 128 dimensions)
|
||||
d = 128
|
||||
nb = 1000
|
||||
vectors = np.random.random((nb, d)).astype('float32')
|
||||
|
||||
# Create index
|
||||
index = faiss.IndexFlatL2(d) # L2 distance
|
||||
index.add(vectors) # Add vectors
|
||||
|
||||
# Search
|
||||
k = 5 # Find 5 nearest neighbors
|
||||
query = np.random.random((1, d)).astype('float32')
|
||||
distances, indices = index.search(query, k)
|
||||
|
||||
print(f"Nearest neighbors: {indices}")
|
||||
print(f"Distances: {distances}")
|
||||
```
|
||||
|
||||
## Index types
|
||||
|
||||
### 1. Flat (exact search)
|
||||
|
||||
```python
|
||||
# L2 (Euclidean) distance
|
||||
index = faiss.IndexFlatL2(d)
|
||||
|
||||
# Inner product (cosine similarity if normalized)
|
||||
index = faiss.IndexFlatIP(d)
|
||||
|
||||
# Slowest, most accurate
|
||||
```
|
||||
|
||||
### 2. IVF (inverted file) - Fast approximate
|
||||
|
||||
```python
|
||||
# Create quantizer
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
|
||||
# IVF index with 100 clusters
|
||||
nlist = 100
|
||||
index = faiss.IndexIVFFlat(quantizer, d, nlist)
|
||||
|
||||
# Train on data
|
||||
index.train(vectors)
|
||||
|
||||
# Add vectors
|
||||
index.add(vectors)
|
||||
|
||||
# Search (nprobe = clusters to search)
|
||||
index.nprobe = 10
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
### 3. HNSW (Hierarchical NSW) - Best quality/speed
|
||||
|
||||
```python
|
||||
# HNSW index
|
||||
M = 32 # Number of connections per layer
|
||||
index = faiss.IndexHNSWFlat(d, M)
|
||||
|
||||
# No training needed
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
### 4. Product Quantization - Memory efficient
|
||||
|
||||
```python
|
||||
# PQ reduces memory by 16-32×
|
||||
m = 8 # Number of subquantizers
|
||||
nbits = 8
|
||||
index = faiss.IndexPQ(d, m, nbits)
|
||||
|
||||
# Train and add
|
||||
index.train(vectors)
|
||||
index.add(vectors)
|
||||
```
|
||||
|
||||
## Save and load
|
||||
|
||||
```python
|
||||
# Save index
|
||||
faiss.write_index(index, "large.index")
|
||||
|
||||
# Load index
|
||||
index = faiss.read_index("large.index")
|
||||
|
||||
# Continue using
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
## GPU acceleration
|
||||
|
||||
```python
|
||||
# Single GPU
|
||||
res = faiss.StandardGpuResources()
|
||||
index_cpu = faiss.IndexFlatL2(d)
|
||||
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu) # GPU 0
|
||||
|
||||
# Multi-GPU
|
||||
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
|
||||
|
||||
# 10-100× faster than CPU
|
||||
```
|
||||
|
||||
## LangChain integration
|
||||
|
||||
```python
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
# Create FAISS vector store
|
||||
vectorstore = FAISS.from_documents(docs, OpenAIEmbeddings())
|
||||
|
||||
# Save
|
||||
vectorstore.save_local("faiss_index")
|
||||
|
||||
# Load
|
||||
vectorstore = FAISS.load_local(
|
||||
"faiss_index",
|
||||
OpenAIEmbeddings(),
|
||||
allow_dangerous_deserialization=True
|
||||
)
|
||||
|
||||
# Search
|
||||
results = vectorstore.similarity_search("query", k=5)
|
||||
```
|
||||
|
||||
## LlamaIndex integration
|
||||
|
||||
```python
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
import faiss
|
||||
|
||||
# Create FAISS index
|
||||
d = 1536
|
||||
faiss_index = faiss.IndexFlatL2(d)
|
||||
|
||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Choose right index type** - Flat for <10K, IVF for 10K-1M, HNSW for quality
|
||||
2. **Normalize for cosine** - Use IndexFlatIP with normalized vectors
|
||||
3. **Use GPU for large datasets** - 10-100× faster
|
||||
4. **Save trained indices** - Training is expensive
|
||||
5. **Tune nprobe/ef_search** - Balance speed/accuracy
|
||||
6. **Monitor memory** - PQ for large datasets
|
||||
7. **Batch queries** - Better GPU utilization
|
||||
|
||||
## Performance
|
||||
|
||||
| Index Type | Build Time | Search Time | Memory | Accuracy |
|
||||
|------------|------------|-------------|--------|----------|
|
||||
| Flat | Fast | Slow | High | 100% |
|
||||
| IVF | Medium | Fast | Medium | 95-99% |
|
||||
| HNSW | Slow | Fastest | High | 99% |
|
||||
| PQ | Medium | Fast | Low | 90-95% |
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/facebookresearch/faiss ⭐ 31,700+
|
||||
- **Wiki**: https://github.com/facebookresearch/faiss/wiki
|
||||
- **License**: MIT
|
||||
|
||||
|
||||
280
skills/mlops/faiss/references/index_types.md
Normal file
280
skills/mlops/faiss/references/index_types.md
Normal file
@@ -0,0 +1,280 @@
|
||||
# FAISS Index Types Guide
|
||||
|
||||
Complete guide to choosing and using FAISS index types.
|
||||
|
||||
## Index selection guide
|
||||
|
||||
| Dataset Size | Index Type | Training | Accuracy | Speed |
|
||||
|--------------|------------|----------|----------|-------|
|
||||
| < 10K | Flat | No | 100% | Slow |
|
||||
| 10K-1M | IVF | Yes | 95-99% | Fast |
|
||||
| 1M-10M | HNSW | No | 99% | Fastest |
|
||||
| > 10M | IVF+PQ | Yes | 90-95% | Fast, low memory |
|
||||
|
||||
## Flat indices (exact search)
|
||||
|
||||
### IndexFlatL2 - L2 (Euclidean) distance
|
||||
|
||||
```python
|
||||
import faiss
|
||||
import numpy as np
|
||||
|
||||
d = 128 # Dimension
|
||||
index = faiss.IndexFlatL2(d)
|
||||
|
||||
# Add vectors
|
||||
vectors = np.random.random((1000, d)).astype('float32')
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
k = 5
|
||||
query = np.random.random((1, d)).astype('float32')
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Dataset < 10,000 vectors
|
||||
- Need 100% accuracy
|
||||
- Serving as baseline
|
||||
|
||||
### IndexFlatIP - Inner product (cosine similarity)
|
||||
|
||||
```python
|
||||
# For cosine similarity, normalize vectors first
|
||||
import faiss
|
||||
|
||||
d = 128
|
||||
index = faiss.IndexFlatIP(d)
|
||||
|
||||
# Normalize vectors (required for cosine similarity)
|
||||
faiss.normalize_L2(vectors)
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
faiss.normalize_L2(query)
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Need cosine similarity
|
||||
- Recommendation systems
|
||||
- Text embeddings
|
||||
|
||||
## IVF indices (inverted file)
|
||||
|
||||
### IndexIVFFlat - Cluster-based search
|
||||
|
||||
```python
|
||||
# Create quantizer
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
|
||||
# Create IVF index with 100 clusters
|
||||
nlist = 100 # Number of clusters
|
||||
index = faiss.IndexIVFFlat(quantizer, d, nlist)
|
||||
|
||||
# Train on data (required!)
|
||||
index.train(vectors)
|
||||
|
||||
# Add vectors
|
||||
index.add(vectors)
|
||||
|
||||
# Search (nprobe = clusters to search)
|
||||
index.nprobe = 10 # Search 10 closest clusters
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `nlist`: Number of clusters (√N to 4√N recommended)
|
||||
- `nprobe`: Clusters to search (1-nlist, higher = more accurate)
|
||||
|
||||
**Use when:**
|
||||
- Dataset 10K-1M vectors
|
||||
- Need fast approximate search
|
||||
- Can afford training time
|
||||
|
||||
### Tuning nprobe
|
||||
|
||||
```python
|
||||
# Test different nprobe values
|
||||
for nprobe in [1, 5, 10, 20, 50]:
|
||||
index.nprobe = nprobe
|
||||
distances, indices = index.search(query, k)
|
||||
# Measure recall/speed trade-off
|
||||
```
|
||||
|
||||
**Guidelines:**
|
||||
- `nprobe=1`: Fastest, ~50% recall
|
||||
- `nprobe=10`: Good balance, ~95% recall
|
||||
- `nprobe=nlist`: Exact search (same as Flat)
|
||||
|
||||
## HNSW indices (graph-based)
|
||||
|
||||
### IndexHNSWFlat - Hierarchical NSW
|
||||
|
||||
```python
|
||||
# HNSW index
|
||||
M = 32 # Number of connections per layer (16-64)
|
||||
index = faiss.IndexHNSWFlat(d, M)
|
||||
|
||||
# Optional: Set ef_construction (build time parameter)
|
||||
index.hnsw.efConstruction = 40 # Higher = better quality, slower build
|
||||
|
||||
# Add vectors (no training needed!)
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
index.hnsw.efSearch = 16 # Search time parameter
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `M`: Connections per layer (16-64, default 32)
|
||||
- `efConstruction`: Build quality (40-200, higher = better)
|
||||
- `efSearch`: Search quality (16-512, higher = more accurate)
|
||||
|
||||
**Use when:**
|
||||
- Need best quality approximate search
|
||||
- Can afford higher memory (more connections)
|
||||
- Dataset 1M-10M vectors
|
||||
|
||||
## PQ indices (product quantization)
|
||||
|
||||
### IndexPQ - Memory-efficient
|
||||
|
||||
```python
|
||||
# PQ reduces memory by 16-32×
|
||||
m = 8 # Number of subquantizers (divides d)
|
||||
nbits = 8 # Bits per subquantizer
|
||||
|
||||
index = faiss.IndexPQ(d, m, nbits)
|
||||
|
||||
# Train (required!)
|
||||
index.train(vectors)
|
||||
|
||||
# Add vectors
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `m`: Subquantizers (d must be divisible by m)
|
||||
- `nbits`: Bits per code (8 or 16)
|
||||
|
||||
**Memory savings:**
|
||||
- Original: d × 4 bytes (float32)
|
||||
- PQ: m bytes
|
||||
- Compression ratio: 4d/m
|
||||
|
||||
**Use when:**
|
||||
- Limited memory
|
||||
- Large datasets (> 10M vectors)
|
||||
- Can accept ~90-95% accuracy
|
||||
|
||||
### IndexIVFPQ - IVF + PQ combined
|
||||
|
||||
```python
|
||||
# Best for very large datasets
|
||||
nlist = 4096
|
||||
m = 8
|
||||
nbits = 8
|
||||
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits)
|
||||
|
||||
# Train
|
||||
index.train(vectors)
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
index.nprobe = 32
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Dataset > 10M vectors
|
||||
- Need fast search + low memory
|
||||
- Can accept 90-95% accuracy
|
||||
|
||||
## GPU indices
|
||||
|
||||
### Single GPU
|
||||
|
||||
```python
|
||||
import faiss
|
||||
|
||||
# Create CPU index
|
||||
index_cpu = faiss.IndexFlatL2(d)
|
||||
|
||||
# Move to GPU
|
||||
res = faiss.StandardGpuResources() # GPU resources
|
||||
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu) # GPU 0
|
||||
|
||||
# Use normally
|
||||
index_gpu.add(vectors)
|
||||
distances, indices = index_gpu.search(query, k)
|
||||
```
|
||||
|
||||
### Multi-GPU
|
||||
|
||||
```python
|
||||
# Use all available GPUs
|
||||
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
|
||||
|
||||
# Or specific GPUs
|
||||
gpus = [0, 1, 2, 3] # Use GPUs 0-3
|
||||
index_gpu = faiss.index_cpu_to_gpus_list(index_cpu, gpus)
|
||||
```
|
||||
|
||||
**Speedup:**
|
||||
- Single GPU: 10-50× faster than CPU
|
||||
- Multi-GPU: Near-linear scaling
|
||||
|
||||
## Index factory
|
||||
|
||||
```python
|
||||
# Easy index creation with string descriptors
|
||||
index = faiss.index_factory(d, "IVF100,Flat")
|
||||
index = faiss.index_factory(d, "HNSW32")
|
||||
index = faiss.index_factory(d, "IVF4096,PQ8")
|
||||
|
||||
# Train and use
|
||||
index.train(vectors)
|
||||
index.add(vectors)
|
||||
```
|
||||
|
||||
**Common descriptors:**
|
||||
- `"Flat"`: Exact search
|
||||
- `"IVF100,Flat"`: IVF with 100 clusters
|
||||
- `"HNSW32"`: HNSW with M=32
|
||||
- `"IVF4096,PQ8"`: IVF + PQ compression
|
||||
|
||||
## Performance comparison
|
||||
|
||||
### Search speed (1M vectors, k=10)
|
||||
|
||||
| Index | Build Time | Search Time | Memory | Recall |
|
||||
|-------|------------|-------------|--------|--------|
|
||||
| Flat | 0s | 50ms | 512 MB | 100% |
|
||||
| IVF100 | 5s | 2ms | 512 MB | 95% |
|
||||
| HNSW32 | 60s | 1ms | 1GB | 99% |
|
||||
| IVF4096+PQ8 | 30s | 3ms | 32 MB | 90% |
|
||||
|
||||
*CPU (16 cores), 128-dim vectors*
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Start with Flat** - Baseline for comparison
|
||||
2. **Use IVF for medium datasets** - Good balance
|
||||
3. **Use HNSW for best quality** - If memory allows
|
||||
4. **Add PQ for memory savings** - Large datasets
|
||||
5. **GPU for > 100K vectors** - 10-50× speedup
|
||||
6. **Tune nprobe/efSearch** - Trade-off speed/accuracy
|
||||
7. **Train on representative data** - Better clustering
|
||||
8. **Save trained indices** - Avoid retraining
|
||||
|
||||
## Resources
|
||||
|
||||
- **Wiki**: https://github.com/facebookresearch/faiss/wiki
|
||||
- **Paper**: https://arxiv.org/abs/1702.08734
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user