Compare commits
210 Commits
feat/insig
...
feat/head-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
67cf37fc26 | ||
|
|
a2d0d07109 | ||
|
|
aedb773f0d | ||
|
|
aaf8f2d2d2 | ||
|
|
12f4800631 | ||
|
|
57b48a81ca | ||
|
|
7af33accf1 | ||
|
|
3214c05e82 | ||
|
|
4608a7fe4e | ||
|
|
af67ea8800 | ||
|
|
37c3dcf551 | ||
|
|
6a49fbb7da | ||
|
|
eb0b01de7b | ||
|
|
5b1528519c | ||
|
|
52f92eb689 | ||
|
|
7f9dd60c15 | ||
|
|
77da3bbc95 | ||
|
|
bb489a3903 | ||
|
|
167eb824cb | ||
|
|
efb64aee5a | ||
|
|
3045e29232 | ||
|
|
5d7d76025a | ||
|
|
e6c829384e | ||
|
|
5c658a416c | ||
|
|
a130aa8165 | ||
|
|
35d57ed752 | ||
|
|
5785bd3272 | ||
|
|
cf9482984e | ||
|
|
67275641f8 | ||
|
|
3ffaac00dd | ||
|
|
816a3ef6f1 | ||
|
|
a8bf414f4a | ||
|
|
3b312d45c5 | ||
|
|
fcd899f888 | ||
|
|
315f3ea429 | ||
|
|
b7d6eae64c | ||
|
|
b3765c28d0 | ||
|
|
4cfb66bac2 | ||
|
|
0c4cff352a | ||
|
|
503269b85a | ||
|
|
161436cfdd | ||
|
|
24f549a692 | ||
|
|
7a8778ac73 | ||
|
|
4d7d9d9715 | ||
|
|
a9c35f9175 | ||
|
|
31b84213e4 | ||
|
|
2036c22f88 | ||
|
|
7185a66b96 | ||
|
|
2394e18729 | ||
|
|
99f7582175 | ||
|
|
93c5997290 | ||
|
|
2d1a1c1c47 | ||
|
|
71e81728ac | ||
|
|
ebe60646db | ||
|
|
f996d7950b | ||
|
|
ae4a674c84 | ||
|
|
169615abc8 | ||
|
|
7c30ac2141 | ||
|
|
192501528f | ||
|
|
5ae0b731d0 | ||
|
|
d9f373654b | ||
|
|
0efbb137e8 | ||
|
|
cf63b2471f | ||
|
|
f88343a6da | ||
|
|
491605cfea | ||
|
|
3aded1d4e5 | ||
|
|
4f0402ed3a | ||
|
|
ecac6321c4 | ||
|
|
20c6573e0a | ||
|
|
97b1c76b14 | ||
|
|
24a37032fa | ||
|
|
c0520223fd | ||
|
|
1f1caa836a | ||
|
|
b3ea7714f5 | ||
|
|
a7f9721785 | ||
|
|
a5461e07bf | ||
|
|
2e73a9e893 | ||
|
|
26bb56b775 | ||
|
|
95b1130485 | ||
|
|
3fb8938cd3 | ||
|
|
c5e8166c8b | ||
|
|
2b88568653 | ||
|
|
34b4fe495e | ||
|
|
4fdd6c0dac | ||
|
|
60b6abefd9 | ||
|
|
4d53b7ccaa | ||
|
|
081079da62 | ||
|
|
333e4abe30 | ||
|
|
cd77c7100c | ||
|
|
cf810c2950 | ||
|
|
a23bcb81ce | ||
|
|
d07d867718 | ||
|
|
666f2dd486 | ||
|
|
34792dd907 | ||
|
|
7ad6fc8a40 | ||
|
|
f824c10429 | ||
|
|
132e5ec179 | ||
|
|
66d3e6a0c2 | ||
|
|
4a09ae2985 | ||
|
|
8c734f2f27 | ||
|
|
245d174359 | ||
|
|
77f47768dd | ||
|
|
90fa9e54ca | ||
|
|
9d3a44e0e8 | ||
|
|
932d596466 | ||
|
|
d518f40e8b | ||
|
|
f016cfca46 | ||
|
|
b8120df860 | ||
|
|
0df7df52f3 | ||
|
|
bfa27d0a68 | ||
|
|
5a20c486e3 | ||
|
|
78e19ebc95 | ||
|
|
b383cafc44 | ||
|
|
b10ff83566 | ||
|
|
daa1f542f9 | ||
|
|
d507f593d0 | ||
|
|
f210510276 | ||
|
|
19b6f81ee7 | ||
|
|
76545ab365 | ||
|
|
b8c3bc7841 | ||
|
|
a680367568 | ||
|
|
dfd37a4b31 | ||
|
|
5ee9b67d9b | ||
|
|
542faf225f | ||
|
|
5684c68121 | ||
|
|
4be783446a | ||
|
|
8d719b180a | ||
|
|
bf048c8aec | ||
|
|
c5a9d1ef9d | ||
|
|
c7b6f423c7 | ||
|
|
6d34207167 | ||
|
|
fcde9be10d | ||
|
|
3830bbda41 | ||
|
|
4447e7d71a | ||
|
|
7bccd904c7 | ||
|
|
313d522b61 | ||
|
|
9ee4fe41fe | ||
|
|
39ee3512cb | ||
|
|
42673556af | ||
|
|
faab73ad58 | ||
|
|
7e36468511 | ||
|
|
9ba5d399e5 | ||
|
|
306d92a9d7 | ||
|
|
5baae0df88 | ||
|
|
24f6a193e7 | ||
|
|
8c0f8baf32 | ||
|
|
d80c30cc92 | ||
|
|
e64d646bad | ||
|
|
b84f9e410c | ||
|
|
ee5daba061 | ||
|
|
23e84de830 | ||
|
|
48e0dc8791 | ||
|
|
fb0f579b16 | ||
|
|
5a711f32b1 | ||
|
|
4d34427cc7 | ||
|
|
41877183bc | ||
|
|
451a007fb1 | ||
|
|
0a82396718 | ||
|
|
5da55ea1e3 | ||
|
|
064c009deb | ||
|
|
caab1cf453 | ||
|
|
55c70f3508 | ||
|
|
d29249b8fa | ||
|
|
f668e9fc75 | ||
|
|
74fe1e2254 | ||
|
|
348936752a | ||
|
|
69a36a3361 | ||
|
|
8712dd6d1c | ||
|
|
55a21fe37b | ||
|
|
f55f625277 | ||
|
|
9dac85b069 | ||
|
|
99bd69baa8 | ||
|
|
a62a137a4f | ||
|
|
82b18e8ac2 | ||
|
|
0111c9848d | ||
|
|
ab9cadfeee | ||
|
|
8bf28e1441 | ||
|
|
ce28f847ce | ||
|
|
5609117882 | ||
|
|
b4fbb6fe10 | ||
|
|
82d7e9429e | ||
|
|
e2821effb5 | ||
|
|
9742f11fda | ||
|
|
388dd4789c | ||
|
|
fdebca4573 | ||
|
|
479dfc096a | ||
|
|
3c6c11b7c9 | ||
|
|
bc091eb7ef | ||
|
|
f75b1d21b4 | ||
|
|
94053d75a6 | ||
|
|
2a68099675 | ||
|
|
6cd3bc6640 | ||
|
|
211b55815e | ||
|
|
8ae4a6f824 | ||
|
|
b98301677a | ||
|
|
f2fdde5ba4 | ||
|
|
4f56e31dc7 | ||
|
|
6d3804770c | ||
|
|
ab0f4126cf | ||
|
|
68fbae5692 | ||
|
|
32636ecf8a | ||
|
|
3221818b6e | ||
|
|
f90a627f9a | ||
|
|
6a51fd23df | ||
|
|
a1c25046a9 | ||
|
|
d10108f8ca | ||
|
|
8b520f9848 | ||
|
|
a718aed1be | ||
|
|
5f29e7b63c | ||
|
|
ec97f9ad1a |
36
.env.example
36
.env.example
@@ -13,6 +13,38 @@ OPENROUTER_API_KEY=
|
||||
# Examples: anthropic/claude-opus-4.6, openai/gpt-4o, google/gemini-3-flash-preview, zhipuai/glm-4-plus
|
||||
LLM_MODEL=anthropic/claude-opus-4.6
|
||||
|
||||
# =============================================================================
|
||||
# LLM PROVIDER (z.ai / GLM)
|
||||
# =============================================================================
|
||||
# z.ai provides access to ZhipuAI GLM models (GLM-4-Plus, etc.)
|
||||
# Get your key at: https://z.ai or https://open.bigmodel.cn
|
||||
GLM_API_KEY=
|
||||
# GLM_BASE_URL=https://api.z.ai/api/paas/v4 # Override default base URL
|
||||
|
||||
# =============================================================================
|
||||
# LLM PROVIDER (Kimi / Moonshot)
|
||||
# =============================================================================
|
||||
# Kimi Code provides access to Moonshot AI coding models (kimi-k2.5, etc.)
|
||||
# Get your key at: https://platform.kimi.ai (Kimi Code console)
|
||||
# Keys prefixed sk-kimi- use the Kimi Code API (api.kimi.com) by default.
|
||||
# Legacy keys from platform.moonshot.ai need KIMI_BASE_URL override below.
|
||||
KIMI_API_KEY=
|
||||
# KIMI_BASE_URL=https://api.kimi.com/coding/v1 # Default for sk-kimi- keys
|
||||
# KIMI_BASE_URL=https://api.moonshot.ai/v1 # For legacy Moonshot keys
|
||||
# KIMI_BASE_URL=https://api.moonshot.cn/v1 # For Moonshot China keys
|
||||
|
||||
# =============================================================================
|
||||
# LLM PROVIDER (MiniMax)
|
||||
# =============================================================================
|
||||
# MiniMax provides access to MiniMax models (global endpoint)
|
||||
# Get your key at: https://www.minimax.io
|
||||
MINIMAX_API_KEY=
|
||||
# MINIMAX_BASE_URL=https://api.minimax.io/v1 # Override default base URL
|
||||
|
||||
# MiniMax China endpoint (for users in mainland China)
|
||||
MINIMAX_CN_API_KEY=
|
||||
# MINIMAX_CN_BASE_URL=https://api.minimaxi.com/v1 # Override default base URL
|
||||
|
||||
# =============================================================================
|
||||
# TOOL API KEYS
|
||||
# =============================================================================
|
||||
@@ -21,10 +53,6 @@ LLM_MODEL=anthropic/claude-opus-4.6
|
||||
# Get at: https://firecrawl.dev/
|
||||
FIRECRAWL_API_KEY=
|
||||
|
||||
# Nous Research API Key - Vision analysis and multi-model reasoning
|
||||
# Get at: https://inference-api.nousresearch.com/
|
||||
NOUS_API_KEY=
|
||||
|
||||
# FAL.ai API Key - Image generation
|
||||
# Get at: https://fal.ai/
|
||||
FAL_KEY=
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -47,4 +47,5 @@ cli-config.yaml
|
||||
|
||||
# Skills Hub state (lives in ~/.hermes/skills/.hub/ at runtime, but just in case)
|
||||
skills/.hub/
|
||||
ignored/
|
||||
ignored/
|
||||
.worktrees/
|
||||
|
||||
731
AGENTS.md
731
AGENTS.md
@@ -1,78 +1,60 @@
|
||||
# Hermes Agent - Development Guide
|
||||
|
||||
Instructions for AI coding assistants (GitHub Copilot, Cursor, etc.) and human developers.
|
||||
|
||||
Hermes Agent is an AI agent harness with tool-calling capabilities, interactive CLI, messaging integrations, and scheduled tasks.
|
||||
Instructions for AI coding assistants and developers working on the hermes-agent codebase.
|
||||
|
||||
## Development Environment
|
||||
|
||||
**IMPORTANT**: Always use the virtual environment if it exists:
|
||||
```bash
|
||||
source venv/bin/activate # Before running any Python commands
|
||||
source .venv/bin/activate # ALWAYS activate before running Python
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
hermes-agent/
|
||||
├── agent/ # Agent internals (extracted from run_agent.py)
|
||||
│ ├── model_metadata.py # Model context lengths, token estimation
|
||||
├── run_agent.py # AIAgent class — core conversation loop
|
||||
├── model_tools.py # Tool orchestration, _discover_tools(), handle_function_call()
|
||||
├── toolsets.py # Toolset definitions, _HERMES_CORE_TOOLS list
|
||||
├── cli.py # HermesCLI class — interactive CLI orchestrator
|
||||
├── hermes_state.py # SessionDB — SQLite session store (FTS5 search)
|
||||
├── agent/ # Agent internals
|
||||
│ ├── prompt_builder.py # System prompt assembly
|
||||
│ ├── context_compressor.py # Auto context compression
|
||||
│ ├── prompt_caching.py # Anthropic prompt caching
|
||||
│ ├── prompt_builder.py # System prompt assembly (identity, skills index, context files)
|
||||
│ ├── auxiliary_client.py # Auxiliary LLM client (vision, summarization)
|
||||
│ ├── model_metadata.py # Model context lengths, token estimation
|
||||
│ ├── display.py # KawaiiSpinner, tool preview formatting
|
||||
│ ├── skill_commands.py # Skill slash commands (shared CLI/gateway)
|
||||
│ └── trajectory.py # Trajectory saving helpers
|
||||
├── hermes_cli/ # CLI implementation
|
||||
│ ├── main.py # Entry point, command dispatcher
|
||||
│ ├── banner.py # Welcome banner, ASCII art, skills summary
|
||||
│ ├── commands.py # Slash command definitions + autocomplete
|
||||
│ ├── callbacks.py # Interactive prompt callbacks (clarify, sudo, approval)
|
||||
│ ├── setup.py # Interactive setup wizard
|
||||
│ ├── config.py # Config management & migration
|
||||
│ ├── status.py # Status display
|
||||
│ ├── doctor.py # Diagnostics
|
||||
│ ├── gateway.py # Gateway management
|
||||
│ ├── uninstall.py # Uninstaller
|
||||
│ ├── cron.py # Cron job management
|
||||
│ └── skills_hub.py # Skills Hub CLI + /skills slash command
|
||||
├── tools/ # Tool implementations
|
||||
│ ├── registry.py # Central tool registry (schemas, handlers, dispatch)
|
||||
│ ├── approval.py # Dangerous command detection + per-session approval
|
||||
│ ├── environments/ # Terminal execution backends
|
||||
│ │ ├── base.py # BaseEnvironment ABC
|
||||
│ │ ├── local.py # Local execution with interrupt support
|
||||
│ │ ├── docker.py # Docker container execution
|
||||
│ │ ├── ssh.py # SSH remote execution
|
||||
│ │ ├── singularity.py # Singularity/Apptainer + SIF management
|
||||
│ │ ├── modal.py # Modal cloud execution
|
||||
│ │ └── daytona.py # Daytona cloud sandboxes
|
||||
│ ├── terminal_tool.py # Terminal orchestration (sudo, lifecycle, factory)
|
||||
│ ├── todo_tool.py # Planning & task management
|
||||
│ ├── process_registry.py # Background process management
|
||||
│ └── ... # Other tool files
|
||||
├── gateway/ # Messaging platform adapters
|
||||
│ ├── platforms/ # Platform-specific adapters (telegram, discord, slack, whatsapp)
|
||||
│ └── ...
|
||||
├── cron/ # Scheduler implementation
|
||||
├── environments/ # RL training environments (Atropos integration)
|
||||
├── skills/ # Bundled skill sources
|
||||
├── optional-skills/ # Official optional skills (not activated by default)
|
||||
├── cli.py # Interactive CLI orchestrator (HermesCLI class)
|
||||
├── run_agent.py # AIAgent class (core conversation loop)
|
||||
├── model_tools.py # Tool orchestration (thin layer over tools/registry.py)
|
||||
├── toolsets.py # Tool groupings
|
||||
├── toolset_distributions.py # Probability-based tool selection
|
||||
├── hermes_cli/ # CLI subcommands and setup
|
||||
│ ├── main.py # Entry point — all `hermes` subcommands
|
||||
│ ├── config.py # DEFAULT_CONFIG, OPTIONAL_ENV_VARS, migration
|
||||
│ ├── commands.py # Slash command definitions + SlashCommandCompleter
|
||||
│ ├── callbacks.py # Terminal callbacks (clarify, sudo, approval)
|
||||
│ └── setup.py # Interactive setup wizard
|
||||
├── tools/ # Tool implementations (one file per tool)
|
||||
│ ├── registry.py # Central tool registry (schemas, handlers, dispatch)
|
||||
│ ├── approval.py # Dangerous command detection
|
||||
│ ├── terminal_tool.py # Terminal orchestration
|
||||
│ ├── process_registry.py # Background process management
|
||||
│ ├── file_tools.py # File read/write/search/patch
|
||||
│ ├── web_tools.py # Firecrawl search/extract
|
||||
│ ├── browser_tool.py # Browserbase browser automation
|
||||
│ ├── code_execution_tool.py # execute_code sandbox
|
||||
│ ├── delegate_tool.py # Subagent delegation
|
||||
│ ├── mcp_tool.py # MCP client (~1050 lines)
|
||||
│ └── environments/ # Terminal backends (local, docker, ssh, modal, daytona, singularity)
|
||||
├── gateway/ # Messaging platform gateway
|
||||
│ ├── run.py # Main loop, slash commands, message dispatch
|
||||
│ ├── session.py # SessionStore — conversation persistence
|
||||
│ └── platforms/ # Adapters: telegram, discord, slack, whatsapp, homeassistant, signal
|
||||
├── cron/ # Scheduler (jobs.py, scheduler.py)
|
||||
├── environments/ # RL training environments (Atropos)
|
||||
├── tests/ # Pytest suite (~2500+ tests)
|
||||
└── batch_runner.py # Parallel batch processing
|
||||
```
|
||||
|
||||
**User Configuration** (stored in `~/.hermes/`):
|
||||
- `~/.hermes/config.yaml` - Settings (model, terminal, toolsets, etc.)
|
||||
- `~/.hermes/.env` - API keys and secrets
|
||||
- `~/.hermes/pairing/` - DM pairing data
|
||||
- `~/.hermes/hooks/` - Custom event hooks
|
||||
- `~/.hermes/image_cache/` - Cached user images
|
||||
- `~/.hermes/audio_cache/` - Cached user voice messages
|
||||
- `~/.hermes/sticker_cache.json` - Telegram sticker descriptions
|
||||
**User config:** `~/.hermes/config.yaml` (settings), `~/.hermes/.env` (API keys)
|
||||
|
||||
## File Dependency Chain
|
||||
|
||||
@@ -86,600 +68,175 @@ model_tools.py (imports tools/registry + triggers tool discovery)
|
||||
run_agent.py, cli.py, batch_runner.py, environments/
|
||||
```
|
||||
|
||||
Each tool file co-locates its schema, handler, and registration. `model_tools.py` is a thin orchestration layer.
|
||||
|
||||
---
|
||||
|
||||
## AIAgent Class
|
||||
|
||||
The main agent is implemented in `run_agent.py`:
|
||||
## AIAgent Class (run_agent.py)
|
||||
|
||||
```python
|
||||
class AIAgent:
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "anthropic/claude-sonnet-4",
|
||||
api_key: str = None,
|
||||
base_url: str = "https://openrouter.ai/api/v1",
|
||||
max_iterations: int = 60, # Max tool-calling loops
|
||||
def __init__(self,
|
||||
model: str = "anthropic/claude-opus-4.6",
|
||||
max_iterations: int = 90,
|
||||
enabled_toolsets: list = None,
|
||||
disabled_toolsets: list = None,
|
||||
verbose_logging: bool = False,
|
||||
quiet_mode: bool = False, # Suppress progress output
|
||||
tool_progress_callback: callable = None, # Called on each tool use
|
||||
):
|
||||
# Initialize OpenAI client, load tools based on toolsets
|
||||
...
|
||||
|
||||
def chat(self, user_message: str, task_id: str = None) -> str:
|
||||
# Main entry point - runs the agent loop
|
||||
...
|
||||
quiet_mode: bool = False,
|
||||
save_trajectories: bool = False,
|
||||
platform: str = None, # "cli", "telegram", etc.
|
||||
session_id: str = None,
|
||||
skip_context_files: bool = False,
|
||||
skip_memory: bool = False,
|
||||
# ... plus provider, api_mode, callbacks, routing params
|
||||
): ...
|
||||
|
||||
def chat(self, message: str) -> str:
|
||||
"""Simple interface — returns final response string."""
|
||||
|
||||
def run_conversation(self, user_message: str, system_message: str = None,
|
||||
conversation_history: list = None, task_id: str = None) -> dict:
|
||||
"""Full interface — returns dict with final_response + messages."""
|
||||
```
|
||||
|
||||
### Agent Loop
|
||||
|
||||
The core loop in `_run_agent_loop()`:
|
||||
|
||||
```
|
||||
1. Add user message to conversation
|
||||
2. Call LLM with tools
|
||||
3. If LLM returns tool calls:
|
||||
- Execute each tool
|
||||
- Add tool results to conversation
|
||||
- Go to step 2
|
||||
4. If LLM returns text response:
|
||||
- Return response to user
|
||||
```
|
||||
The core loop is inside `run_conversation()` — entirely synchronous:
|
||||
|
||||
```python
|
||||
while turns < max_turns:
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tool_schemas,
|
||||
)
|
||||
|
||||
while api_call_count < self.max_iterations and self.iteration_budget.remaining > 0:
|
||||
response = client.chat.completions.create(model=model, messages=messages, tools=tool_schemas)
|
||||
if response.tool_calls:
|
||||
for tool_call in response.tool_calls:
|
||||
result = await execute_tool(tool_call)
|
||||
result = handle_function_call(tool_call.name, tool_call.args, task_id)
|
||||
messages.append(tool_result_message(result))
|
||||
turns += 1
|
||||
api_call_count += 1
|
||||
else:
|
||||
return response.content
|
||||
```
|
||||
|
||||
### Conversation Management
|
||||
|
||||
Messages are stored as a list of dicts following OpenAI format:
|
||||
|
||||
```python
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant..."},
|
||||
{"role": "user", "content": "Search for Python tutorials"},
|
||||
{"role": "assistant", "content": None, "tool_calls": [...]},
|
||||
{"role": "tool", "tool_call_id": "...", "content": "..."},
|
||||
{"role": "assistant", "content": "Here's what I found..."},
|
||||
]
|
||||
```
|
||||
|
||||
### Reasoning Model Support
|
||||
|
||||
For models that support chain-of-thought reasoning:
|
||||
- Extract `reasoning_content` from API responses
|
||||
- Store in `assistant_msg["reasoning"]` for trajectory export
|
||||
- Pass back via `reasoning_content` field on subsequent turns
|
||||
Messages follow OpenAI format: `{"role": "system/user/assistant/tool", ...}`. Reasoning content is stored in `assistant_msg["reasoning"]`.
|
||||
|
||||
---
|
||||
|
||||
## CLI Architecture (cli.py)
|
||||
|
||||
The interactive CLI uses:
|
||||
- **Rich** - For the welcome banner and styled panels
|
||||
- **prompt_toolkit** - For fixed input area with history, `patch_stdout`, slash command autocomplete, and floating completion menus
|
||||
- **KawaiiSpinner** (in run_agent.py) - Animated kawaii faces during API calls; clean `┊` activity feed for tool execution results
|
||||
|
||||
Key components:
|
||||
- `HermesCLI` class - Main CLI controller with commands and conversation loop
|
||||
- `SlashCommandCompleter` - Autocomplete dropdown for `/commands` (type `/` to see all)
|
||||
- `agent/skill_commands.py` - Scans skills and builds invocation messages (shared with gateway)
|
||||
- `load_cli_config()` - Loads config, sets environment variables for terminal
|
||||
- `build_welcome_banner()` - Displays ASCII art logo, tools, and skills summary
|
||||
|
||||
CLI UX notes:
|
||||
- Thinking spinner (during LLM API call) shows animated kawaii face + verb (`(⌐■_■) deliberating...`)
|
||||
- When LLM returns tool calls, the spinner clears silently (no "got it!" noise)
|
||||
- Tool execution results appear as a clean activity feed: `┊ {emoji} {verb} {detail} {duration}`
|
||||
- "got it!" only appears when the LLM returns a final text response (`⚕ ready`)
|
||||
- The prompt shows `⚕ ❯` when the agent is working, `❯` when idle
|
||||
- Pasting 5+ lines auto-saves to `~/.hermes/pastes/` and collapses to a reference
|
||||
- Multi-line input via Alt+Enter or Ctrl+J
|
||||
- `/commands` - Process user commands like `/help`, `/clear`, `/personality`, etc.
|
||||
- `/skill-name` - Invoke installed skills directly (e.g., `/axolotl`, `/gif-search`)
|
||||
|
||||
CLI uses `quiet_mode=True` when creating AIAgent to suppress verbose logging.
|
||||
|
||||
### Skill Slash Commands
|
||||
|
||||
Every installed skill in `~/.hermes/skills/` is automatically registered as a slash command.
|
||||
The skill name (from frontmatter or folder name) becomes the command: `axolotl` → `/axolotl`.
|
||||
|
||||
Implementation (`agent/skill_commands.py`, shared between CLI and gateway):
|
||||
1. `scan_skill_commands()` scans all SKILL.md files at startup
|
||||
2. `build_skill_invocation_message()` loads the SKILL.md content and builds a user-turn message
|
||||
3. The message includes the full skill content, a list of supporting files (not loaded), and the user's instruction
|
||||
4. Supporting files can be loaded on demand via the `skill_view` tool
|
||||
5. Injected as a **user message** (not system prompt) to preserve prompt caching
|
||||
- **Rich** for banner/panels, **prompt_toolkit** for input with autocomplete
|
||||
- **KawaiiSpinner** (`agent/display.py`) — animated faces during API calls, `┊` activity feed for tool results
|
||||
- `load_cli_config()` in cli.py merges hardcoded defaults + user config YAML
|
||||
- `process_command()` is a method on `HermesCLI` (not in commands.py)
|
||||
- Skill slash commands: `agent/skill_commands.py` scans `~/.hermes/skills/`, injects as **user message** (not system prompt) to preserve prompt caching
|
||||
|
||||
### Adding CLI Commands
|
||||
|
||||
1. Add to `COMMANDS` dict with description
|
||||
2. Add handler in `process_command()` method
|
||||
3. For persistent settings, use `save_config_value()` to update config
|
||||
|
||||
---
|
||||
|
||||
## Hermes CLI Commands
|
||||
|
||||
The unified `hermes` command provides all functionality:
|
||||
|
||||
| Command | Description |
|
||||
|---------|-------------|
|
||||
| `hermes` | Interactive chat (default) |
|
||||
| `hermes chat -q "..."` | Single query mode |
|
||||
| `hermes setup` | Configure API keys and settings |
|
||||
| `hermes config` | View current configuration |
|
||||
| `hermes config edit` | Open config in editor |
|
||||
| `hermes config set KEY VAL` | Set a specific value |
|
||||
| `hermes config check` | Check for missing config |
|
||||
| `hermes config migrate` | Prompt for missing config interactively |
|
||||
| `hermes status` | Show configuration status |
|
||||
| `hermes doctor` | Diagnose issues |
|
||||
| `hermes update` | Update to latest (checks for new config) |
|
||||
| `hermes uninstall` | Uninstall (can keep configs for reinstall) |
|
||||
| `hermes gateway` | Start gateway (messaging + cron scheduler) |
|
||||
| `hermes gateway setup` | Configure messaging platforms interactively |
|
||||
| `hermes gateway install` | Install gateway as system service |
|
||||
| `hermes cron list` | View scheduled jobs |
|
||||
| `hermes cron status` | Check if cron scheduler is running |
|
||||
| `hermes version` | Show version info |
|
||||
| `hermes pairing list/approve/revoke` | Manage DM pairing codes |
|
||||
|
||||
---
|
||||
|
||||
## Messaging Gateway
|
||||
|
||||
The gateway connects Hermes to Telegram, Discord, Slack, and WhatsApp.
|
||||
|
||||
### Setup
|
||||
|
||||
The interactive setup wizard handles platform configuration:
|
||||
|
||||
```bash
|
||||
hermes gateway setup # Arrow-key menu of all platforms, configure tokens/allowlists/home channels
|
||||
```
|
||||
|
||||
This is the recommended way to configure messaging. It shows which platforms are already set up, walks through each one interactively, and offers to start/restart the gateway service at the end.
|
||||
|
||||
Platforms can also be configured manually in `~/.hermes/.env`:
|
||||
|
||||
### Configuration (in `~/.hermes/.env`):
|
||||
|
||||
```bash
|
||||
# Telegram
|
||||
TELEGRAM_BOT_TOKEN=123456:ABC-DEF... # From @BotFather
|
||||
TELEGRAM_ALLOWED_USERS=123456789,987654 # Comma-separated user IDs (from @userinfobot)
|
||||
|
||||
# Discord
|
||||
DISCORD_BOT_TOKEN=MTIz... # From Developer Portal
|
||||
DISCORD_ALLOWED_USERS=123456789012345678 # Comma-separated user IDs
|
||||
|
||||
# Agent Behavior
|
||||
HERMES_MAX_ITERATIONS=60 # Max tool-calling iterations
|
||||
MESSAGING_CWD=/home/myuser # Terminal working directory for messaging
|
||||
|
||||
# Tool progress is configured in config.yaml (display.tool_progress: off|new|all|verbose)
|
||||
```
|
||||
|
||||
### Working Directory Behavior
|
||||
|
||||
- **CLI (`hermes` command)**: Uses current directory (`.` → `os.getcwd()`)
|
||||
- **Messaging (Telegram/Discord)**: Uses `MESSAGING_CWD` (default: home directory)
|
||||
|
||||
This is intentional: CLI users are in a terminal and expect the agent to work in their current directory, while messaging users need a consistent starting location.
|
||||
|
||||
### Security (User Allowlists):
|
||||
|
||||
**IMPORTANT**: By default, the gateway denies all users who are not in an allowlist or paired via DM.
|
||||
|
||||
The gateway checks `{PLATFORM}_ALLOWED_USERS` environment variables:
|
||||
- If set: Only listed user IDs can interact with the bot
|
||||
- If unset: All users are denied unless `GATEWAY_ALLOW_ALL_USERS=true` is set
|
||||
|
||||
Users can find their IDs:
|
||||
- **Telegram**: Message [@userinfobot](https://t.me/userinfobot)
|
||||
- **Discord**: Enable Developer Mode, right-click name → Copy ID
|
||||
|
||||
### DM Pairing System
|
||||
|
||||
Instead of static allowlists, users can pair via one-time codes:
|
||||
1. Unknown user DMs the bot → receives pairing code
|
||||
2. Owner runs `hermes pairing approve <platform> <code>`
|
||||
3. User is permanently authorized
|
||||
|
||||
Security: 8-char codes, 1-hour expiry, rate-limited (1/10min/user), max 3 pending per platform, lockout after 5 failed attempts, `chmod 0600` on data files.
|
||||
|
||||
Files: `gateway/pairing.py`, `hermes_cli/pairing.py`
|
||||
|
||||
### Event Hooks
|
||||
|
||||
Hooks fire at lifecycle points. Place hook directories in `~/.hermes/hooks/`:
|
||||
|
||||
```
|
||||
~/.hermes/hooks/my-hook/
|
||||
├── HOOK.yaml # name, description, events list
|
||||
└── handler.py # async def handle(event_type, context): ...
|
||||
```
|
||||
|
||||
Events: `gateway:startup`, `session:start`, `session:reset`, `agent:start`, `agent:step`, `agent:end`, `command:*`
|
||||
|
||||
The `agent:step` event fires each iteration of the tool-calling loop with tool names and results.
|
||||
|
||||
Files: `gateway/hooks.py`
|
||||
|
||||
### Tool Progress Notifications
|
||||
|
||||
When `tool_progress` is enabled in `config.yaml`, the bot sends status messages as it works:
|
||||
- `💻 \`ls -la\`...` (terminal commands show the actual command)
|
||||
- `🔍 web_search...`
|
||||
- `📄 web_extract...`
|
||||
- `🐍 execute_code...` (programmatic tool calling sandbox)
|
||||
- `🔀 delegate_task...` (subagent delegation)
|
||||
- `❓ clarify...` (user question, CLI-only)
|
||||
|
||||
Modes:
|
||||
- `new`: Only when switching to a different tool (less spam)
|
||||
- `all`: Every single tool call
|
||||
|
||||
### Typing Indicator
|
||||
|
||||
The gateway keeps the "typing..." indicator active throughout processing, refreshing every 4 seconds. This lets users know the bot is working even during long tool-calling sequences.
|
||||
|
||||
### Platform Toolsets:
|
||||
|
||||
Each platform has a dedicated toolset in `toolsets.py`:
|
||||
- `hermes-telegram`: Full tools including terminal (with safety checks)
|
||||
- `hermes-discord`: Full tools including terminal
|
||||
- `hermes-whatsapp`: Full tools including terminal
|
||||
|
||||
---
|
||||
|
||||
## Configuration System
|
||||
|
||||
Configuration files are stored in `~/.hermes/` for easy user access:
|
||||
- `~/.hermes/config.yaml` - All settings (model, terminal, compression, etc.)
|
||||
- `~/.hermes/.env` - API keys and secrets
|
||||
|
||||
### Adding New Configuration Options
|
||||
|
||||
When adding new configuration variables, you MUST follow this process:
|
||||
|
||||
#### For config.yaml options:
|
||||
|
||||
1. Add to `DEFAULT_CONFIG` in `hermes_cli/config.py`
|
||||
2. **CRITICAL**: Bump `_config_version` in `DEFAULT_CONFIG` when adding required fields
|
||||
3. This triggers migration prompts for existing users on next `hermes update` or `hermes setup`
|
||||
|
||||
Example:
|
||||
```python
|
||||
DEFAULT_CONFIG = {
|
||||
# ... existing config ...
|
||||
|
||||
"new_feature": {
|
||||
"enabled": True,
|
||||
"option": "default_value",
|
||||
},
|
||||
|
||||
# BUMP THIS when adding required fields
|
||||
"_config_version": 2, # Was 1, now 2
|
||||
}
|
||||
```
|
||||
|
||||
#### For .env variables (API keys/secrets):
|
||||
|
||||
1. Add to `REQUIRED_ENV_VARS` or `OPTIONAL_ENV_VARS` in `hermes_cli/config.py`
|
||||
2. Include metadata for the migration system:
|
||||
|
||||
```python
|
||||
OPTIONAL_ENV_VARS = {
|
||||
# ... existing vars ...
|
||||
"NEW_API_KEY": {
|
||||
"description": "What this key is for",
|
||||
"prompt": "Display name in prompts",
|
||||
"url": "https://where-to-get-it.com/",
|
||||
"tools": ["tools_it_enables"], # What tools need this
|
||||
"password": True, # Mask input
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
#### Update related files:
|
||||
|
||||
- `hermes_cli/setup.py` - Add prompts in the setup wizard
|
||||
- `cli-config.yaml.example` - Add example with comments
|
||||
- Update README.md if user-facing
|
||||
|
||||
### Config Version Migration
|
||||
|
||||
The system uses `_config_version` to detect outdated configs:
|
||||
|
||||
1. `check_for_missing_config()` compares user config to `DEFAULT_CONFIG`
|
||||
2. `migrate_config()` interactively prompts for missing values
|
||||
3. Called automatically by `hermes update` and optionally by `hermes setup`
|
||||
|
||||
---
|
||||
|
||||
## Environment Variables
|
||||
|
||||
API keys are loaded from `~/.hermes/.env`:
|
||||
- `OPENROUTER_API_KEY` - Main LLM API access (primary provider)
|
||||
- `FIRECRAWL_API_KEY` - Web search/extract tools
|
||||
- `FIRECRAWL_API_URL` - Self-hosted Firecrawl endpoint (optional)
|
||||
- `BROWSERBASE_API_KEY` / `BROWSERBASE_PROJECT_ID` - Browser automation
|
||||
- `FAL_KEY` - Image generation (FLUX model)
|
||||
- `NOUS_API_KEY` - Vision and Mixture-of-Agents tools
|
||||
|
||||
Terminal tool configuration (in `~/.hermes/config.yaml`):
|
||||
- `terminal.backend` - Backend: local, docker, singularity, modal, daytona, or ssh
|
||||
- `terminal.cwd` - Working directory ("." = host CWD for local only; for remote backends set an absolute path inside the target, or omit to use the backend's default)
|
||||
- `terminal.docker_image` - Image for Docker backend
|
||||
- `terminal.singularity_image` - Image for Singularity backend
|
||||
- `terminal.modal_image` - Image for Modal backend
|
||||
- `terminal.daytona_image` - Image for Daytona backend
|
||||
- `DAYTONA_API_KEY` - API key for Daytona backend (in .env)
|
||||
- SSH: `TERMINAL_SSH_HOST`, `TERMINAL_SSH_USER`, `TERMINAL_SSH_KEY` in .env
|
||||
|
||||
Agent behavior (in `~/.hermes/.env`):
|
||||
- `HERMES_MAX_ITERATIONS` - Max tool-calling iterations (default: 60)
|
||||
- `MESSAGING_CWD` - Working directory for messaging platforms (default: ~)
|
||||
- `display.tool_progress` in config.yaml - Tool progress: `off`, `new`, `all`, `verbose`
|
||||
- `OPENAI_API_KEY` - Voice transcription (Whisper STT)
|
||||
- `SLACK_BOT_TOKEN` / `SLACK_APP_TOKEN` - Slack integration (Socket Mode)
|
||||
- `SLACK_ALLOWED_USERS` - Comma-separated Slack user IDs
|
||||
- `HERMES_HUMAN_DELAY_MODE` - Response pacing: off/natural/custom
|
||||
- `HERMES_HUMAN_DELAY_MIN_MS` / `HERMES_HUMAN_DELAY_MAX_MS` - Custom delay range
|
||||
|
||||
### Dangerous Command Approval
|
||||
|
||||
The terminal tool includes safety checks for potentially destructive commands (e.g., `rm -rf`, `DROP TABLE`, `chmod 777`, etc.):
|
||||
|
||||
**Behavior by Backend:**
|
||||
- **Docker/Singularity/Modal**: Commands run unrestricted (isolated containers)
|
||||
- **Local/SSH**: Dangerous commands trigger approval flow
|
||||
|
||||
**Approval Flow (CLI):**
|
||||
```
|
||||
⚠️ Potentially dangerous command detected: recursive delete
|
||||
rm -rf /tmp/test
|
||||
|
||||
[o]nce | [s]ession | [a]lways | [d]eny
|
||||
Choice [o/s/a/D]:
|
||||
```
|
||||
|
||||
**Approval Flow (Messaging):**
|
||||
- Command is blocked with explanation
|
||||
- Agent explains the command was blocked for safety
|
||||
- User must add the pattern to their allowlist via `hermes config edit` or run the command directly on their machine
|
||||
|
||||
**Configuration:**
|
||||
- `command_allowlist` in `~/.hermes/config.yaml` stores permanently allowed patterns
|
||||
- Add patterns via "always" approval or edit directly
|
||||
|
||||
**Sudo Handling (Messaging):**
|
||||
- If sudo fails over messaging, output includes tip to add `SUDO_PASSWORD` to `~/.hermes/.env`
|
||||
|
||||
---
|
||||
|
||||
## Background Process Management
|
||||
|
||||
The `process` tool works alongside `terminal` for managing long-running background processes:
|
||||
|
||||
**Starting a background process:**
|
||||
```python
|
||||
terminal(command="pytest -v tests/", background=true)
|
||||
# Returns: {"session_id": "proc_abc123", "pid": 12345, ...}
|
||||
```
|
||||
|
||||
**Managing it with the process tool:**
|
||||
- `process(action="list")` -- show all running/recent processes
|
||||
- `process(action="poll", session_id="proc_abc123")` -- check status + new output
|
||||
- `process(action="log", session_id="proc_abc123")` -- full output with pagination
|
||||
- `process(action="wait", session_id="proc_abc123", timeout=600)` -- block until done
|
||||
- `process(action="kill", session_id="proc_abc123")` -- terminate
|
||||
- `process(action="write", session_id="proc_abc123", data="y")` -- send stdin
|
||||
- `process(action="submit", session_id="proc_abc123", data="yes")` -- send + Enter
|
||||
|
||||
**Key behaviors:**
|
||||
- Background processes execute through the configured terminal backend (local/Docker/Modal/Daytona/SSH/Singularity) -- never directly on the host unless `TERMINAL_ENV=local`
|
||||
- The `wait` action blocks the tool call until the process finishes, times out, or is interrupted by a new user message
|
||||
- PTY mode (`pty=true` on terminal) enables interactive CLI tools (Codex, Claude Code)
|
||||
- In RL training, background processes are auto-killed when the episode ends (`tool_context.cleanup()`)
|
||||
- In the gateway, sessions with active background processes are exempt from idle reset
|
||||
- The process registry checkpoints to `~/.hermes/processes.json` for crash recovery
|
||||
|
||||
Files: `tools/process_registry.py` (registry + handler), `tools/terminal_tool.py` (spawn integration)
|
||||
1. Add to `COMMANDS` dict in `hermes_cli/commands.py`
|
||||
2. Add handler in `HermesCLI.process_command()` in `cli.py`
|
||||
3. For persistent settings, use `save_config_value()` in `cli.py`
|
||||
|
||||
---
|
||||
|
||||
## Adding New Tools
|
||||
|
||||
Adding a tool requires changes in **2 files** (the tool file and `toolsets.py`):
|
||||
|
||||
1. **Create `tools/your_tool.py`** with handler, schema, check function, and registry call:
|
||||
Requires changes in **3 files**:
|
||||
|
||||
**1. Create `tools/your_tool.py`:**
|
||||
```python
|
||||
# tools/example_tool.py
|
||||
import json
|
||||
import os
|
||||
import json, os
|
||||
from tools.registry import registry
|
||||
|
||||
def check_example_requirements() -> bool:
|
||||
"""Check if required API keys/dependencies are available."""
|
||||
def check_requirements() -> bool:
|
||||
return bool(os.getenv("EXAMPLE_API_KEY"))
|
||||
|
||||
def example_tool(param: str, task_id: str = None) -> str:
|
||||
"""Execute the tool and return JSON string result."""
|
||||
try:
|
||||
result = {"success": True, "data": "..."}
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
|
||||
EXAMPLE_SCHEMA = {
|
||||
"name": "example_tool",
|
||||
"description": "Does something useful.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param": {"type": "string", "description": "The parameter"}
|
||||
},
|
||||
"required": ["param"]
|
||||
}
|
||||
}
|
||||
return json.dumps({"success": True, "data": "..."})
|
||||
|
||||
registry.register(
|
||||
name="example_tool",
|
||||
toolset="example",
|
||||
schema=EXAMPLE_SCHEMA,
|
||||
handler=lambda args, **kw: example_tool(
|
||||
param=args.get("param", ""), task_id=kw.get("task_id")),
|
||||
check_fn=check_example_requirements,
|
||||
schema={"name": "example_tool", "description": "...", "parameters": {...}},
|
||||
handler=lambda args, **kw: example_tool(param=args.get("param", ""), task_id=kw.get("task_id")),
|
||||
check_fn=check_requirements,
|
||||
requires_env=["EXAMPLE_API_KEY"],
|
||||
)
|
||||
```
|
||||
|
||||
2. **Add to `toolsets.py`**: Add `"example_tool"` to `_HERMES_CORE_TOOLS` if it should be in all platform toolsets, or create a new toolset entry.
|
||||
**2. Add import** in `model_tools.py` `_discover_tools()` list.
|
||||
|
||||
3. **Add discovery import** in `model_tools.py`'s `_discover_tools()` list: `"tools.example_tool"`.
|
||||
**3. Add to `toolsets.py`** — either `_HERMES_CORE_TOOLS` (all platforms) or a new toolset.
|
||||
|
||||
That's it. The registry handles schema collection, dispatch, availability checking, and error wrapping automatically. No edits to `TOOLSET_REQUIREMENTS`, `handle_function_call()`, `get_all_tool_names()`, or any other data structure.
|
||||
The registry handles schema collection, dispatch, availability checking, and error wrapping. All handlers MUST return a JSON string.
|
||||
|
||||
**Optional:** Add to `OPTIONAL_ENV_VARS` in `hermes_cli/config.py` for the setup wizard, and to `toolset_distributions.py` for batch processing.
|
||||
|
||||
**Special case: tools that need agent-level state** (like `todo`, `memory`):
|
||||
These are intercepted by `run_agent.py`'s tool dispatch loop *before* `handle_function_call()`. The registry still holds their schemas, but dispatch returns a stub error as a safety fallback. See `todo_tool.py` for the pattern.
|
||||
|
||||
All tool handlers MUST return a JSON string. The registry's `dispatch()` wraps all exceptions in `{"error": "..."}` automatically.
|
||||
|
||||
### Dynamic Tool Availability
|
||||
|
||||
Tools declare their requirements at registration time via `check_fn` and `requires_env`. The registry checks `check_fn()` when building tool definitions -- tools whose check fails are silently excluded.
|
||||
|
||||
### Stateful Tools
|
||||
|
||||
Tools that maintain state (terminal, browser) require:
|
||||
- `task_id` parameter for session isolation between concurrent tasks
|
||||
- `cleanup_*()` function to release resources
|
||||
- Cleanup is called automatically in run_agent.py after conversation completes
|
||||
**Agent-level tools** (todo, memory): intercepted by `run_agent.py` before `handle_function_call()`. See `todo_tool.py` for the pattern.
|
||||
|
||||
---
|
||||
|
||||
## Trajectory Format
|
||||
## Adding Configuration
|
||||
|
||||
Conversations are saved in ShareGPT format for training:
|
||||
```json
|
||||
{"from": "system", "value": "System prompt with <tools>...</tools>"}
|
||||
{"from": "human", "value": "User message"}
|
||||
{"from": "gpt", "value": "<think>reasoning</think>\n<tool_call>{...}</tool_call>"}
|
||||
{"from": "tool", "value": "<tool_response>{...}</tool_response>"}
|
||||
{"from": "gpt", "value": "Final response"}
|
||||
```
|
||||
|
||||
Tool calls use `<tool_call>` XML tags, responses use `<tool_response>` tags, reasoning uses `<think>` tags.
|
||||
|
||||
### Trajectory Export
|
||||
### config.yaml options:
|
||||
1. Add to `DEFAULT_CONFIG` in `hermes_cli/config.py`
|
||||
2. Bump `_config_version` (currently 5) to trigger migration for existing users
|
||||
|
||||
### .env variables:
|
||||
1. Add to `OPTIONAL_ENV_VARS` in `hermes_cli/config.py` with metadata:
|
||||
```python
|
||||
agent = AIAgent(save_trajectories=True)
|
||||
agent.chat("Do something")
|
||||
# Saves to trajectories/*.jsonl in ShareGPT format
|
||||
"NEW_API_KEY": {
|
||||
"description": "What it's for",
|
||||
"prompt": "Display name",
|
||||
"url": "https://...",
|
||||
"password": True,
|
||||
"category": "tool", # provider, tool, messaging, setting
|
||||
},
|
||||
```
|
||||
|
||||
### Config loaders (two separate systems):
|
||||
|
||||
| Loader | Used by | Location |
|
||||
|--------|---------|----------|
|
||||
| `load_cli_config()` | CLI mode | `cli.py` |
|
||||
| `load_config()` | `hermes tools`, `hermes setup` | `hermes_cli/config.py` |
|
||||
| Direct YAML load | Gateway | `gateway/run.py` |
|
||||
|
||||
---
|
||||
|
||||
## Batch Processing (batch_runner.py)
|
||||
## Important Policies
|
||||
|
||||
For processing multiple prompts:
|
||||
- Parallel execution with multiprocessing
|
||||
- Content-based resume for fault tolerance (matches on prompt text, not indices)
|
||||
- Toolset distributions control probabilistic tool availability per prompt
|
||||
- Output: `data/<run_name>/trajectories.jsonl` (combined) + individual batch files
|
||||
### Prompt Caching Must Not Break
|
||||
|
||||
Hermes-Agent ensures caching remains valid throughout a conversation. **Do NOT implement changes that would:**
|
||||
- Alter past context mid-conversation
|
||||
- Change toolsets mid-conversation
|
||||
- Reload memories or rebuild system prompts mid-conversation
|
||||
|
||||
Cache-breaking forces dramatically higher costs. The ONLY time we alter context is during context compression.
|
||||
|
||||
### Working Directory Behavior
|
||||
- **CLI**: Uses current directory (`.` → `os.getcwd()`)
|
||||
- **Messaging**: Uses `MESSAGING_CWD` env var (default: home directory)
|
||||
|
||||
---
|
||||
|
||||
## Known Pitfalls
|
||||
|
||||
### DO NOT use `simple_term_menu` for interactive menus
|
||||
Rendering bugs in tmux/iTerm2 — ghosting on scroll. Use `curses` (stdlib) instead. See `hermes_cli/tools_config.py` for the pattern.
|
||||
|
||||
### DO NOT use `\033[K` (ANSI erase-to-EOL) in spinner/display code
|
||||
Leaks as literal `?[K` text under `prompt_toolkit`'s `patch_stdout`. Use space-padding: `f"\r{line}{' ' * pad}"`.
|
||||
|
||||
### `_last_resolved_tool_names` is a process-global in `model_tools.py`
|
||||
When subagents overwrite this global, `execute_code` calls after delegation may fail with missing tool imports. Known bug.
|
||||
|
||||
### Tests must not write to `~/.hermes/`
|
||||
The `_isolate_hermes_home` autouse fixture in `tests/conftest.py` redirects `HERMES_HOME` to a temp dir. Never hardcode `~/.hermes/` paths in tests.
|
||||
|
||||
---
|
||||
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
python batch_runner.py \
|
||||
--dataset_file=prompts.jsonl \
|
||||
--batch_size=20 \
|
||||
--num_workers=4 \
|
||||
--run_name=my_run
|
||||
source .venv/bin/activate
|
||||
python -m pytest tests/ -q # Full suite (~2500 tests, ~2 min)
|
||||
python -m pytest tests/test_model_tools.py -q # Toolset resolution
|
||||
python -m pytest tests/test_cli_init.py -q # CLI config loading
|
||||
python -m pytest tests/gateway/ -q # Gateway tests
|
||||
python -m pytest tests/tools/ -q # Tool-level tests
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Skills System
|
||||
|
||||
Skills are on-demand knowledge documents the agent can load. Compatible with the [agentskills.io](https://agentskills.io/specification) open standard.
|
||||
|
||||
```
|
||||
skills/
|
||||
├── mlops/ # Category folder
|
||||
│ ├── axolotl/ # Skill folder
|
||||
│ │ ├── SKILL.md # Main instructions (required)
|
||||
│ │ ├── references/ # Additional docs, API specs
|
||||
│ │ ├── templates/ # Output formats, configs
|
||||
│ │ └── assets/ # Supplementary files (agentskills.io)
|
||||
│ └── vllm/
|
||||
│ └── SKILL.md
|
||||
├── .hub/ # Skills Hub state (gitignored)
|
||||
│ ├── lock.json # Installed skill provenance
|
||||
│ ├── quarantine/ # Pending security review
|
||||
│ ├── audit.log # Security scan history
|
||||
│ ├── taps.json # Custom source repos
|
||||
│ └── index-cache/ # Cached remote indexes
|
||||
```
|
||||
|
||||
**Progressive disclosure** (token-efficient):
|
||||
1. `skills_categories()` - List category names (~50 tokens)
|
||||
2. `skills_list(category)` - Name + description per skill (~3k tokens)
|
||||
3. `skill_view(name)` - Full content + tags + linked files
|
||||
|
||||
SKILL.md files use YAML frontmatter (agentskills.io format):
|
||||
```yaml
|
||||
---
|
||||
name: skill-name
|
||||
description: Brief description for listing
|
||||
version: 1.0.0
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [tag1, tag2]
|
||||
related_skills: [other-skill]
|
||||
---
|
||||
# Skill Content...
|
||||
```
|
||||
|
||||
**Skills Hub** — user-driven skill search/install from online registries and official optional skills. Sources: official optional skills (shipped with repo, labeled "official"), GitHub (openai/skills, anthropics/skills, custom taps), ClawHub, Claude marketplace, LobeHub. Not exposed as an agent tool — the model cannot search for or install skills. Users manage skills via `hermes skills browse/search/install` CLI commands or the `/skills` slash command in chat.
|
||||
|
||||
Key files:
|
||||
- `tools/skills_tool.py` — Agent-facing skill list/view (progressive disclosure)
|
||||
- `tools/skills_guard.py` — Security scanner (regex + LLM audit, trust-aware install policy)
|
||||
- `tools/skills_hub.py` — Source adapters (OptionalSkillSource, GitHub, ClawHub, Claude marketplace, LobeHub), lock file, auth
|
||||
- `hermes_cli/skills_hub.py` — CLI subcommands + `/skills` slash command handler
|
||||
|
||||
---
|
||||
|
||||
## Testing Changes
|
||||
|
||||
After making changes:
|
||||
|
||||
1. Run `hermes doctor` to check setup
|
||||
2. Run `hermes config check` to verify config
|
||||
3. Test with `hermes chat -q "test message"`
|
||||
4. For new config options, test fresh install: `rm -rf ~/.hermes && hermes setup`
|
||||
Always run the full suite before pushing changes.
|
||||
|
||||
@@ -118,7 +118,7 @@ hermes-agent/
|
||||
├── cli.py # HermesCLI class — interactive TUI, prompt_toolkit integration
|
||||
├── model_tools.py # Tool orchestration (thin layer over tools/registry.py)
|
||||
├── toolsets.py # Tool groupings and presets (hermes-cli, hermes-telegram, etc.)
|
||||
├── hermes_state.py # SQLite session database with FTS5 full-text search
|
||||
├── hermes_state.py # SQLite session database with FTS5 full-text search, session titles
|
||||
├── batch_runner.py # Parallel batch processing for trajectory generation
|
||||
│
|
||||
├── agent/ # Agent internals (extracted modules)
|
||||
@@ -218,7 +218,7 @@ User message → AIAgent._run_agent_loop()
|
||||
|
||||
- **Self-registering tools**: Each tool file calls `registry.register()` at import time. `model_tools.py` triggers discovery by importing all tool modules.
|
||||
- **Toolset grouping**: Tools are grouped into toolsets (`web`, `terminal`, `file`, `browser`, etc.) that can be enabled/disabled per platform.
|
||||
- **Session persistence**: All conversations are stored in SQLite (`hermes_state.py`) with full-text search. JSON logs go to `~/.hermes/sessions/`.
|
||||
- **Session persistence**: All conversations are stored in SQLite (`hermes_state.py`) with full-text search and unique session titles. JSON logs go to `~/.hermes/sessions/`.
|
||||
- **Ephemeral injection**: System prompts and prefill messages are injected at API call time, never persisted to the database or logs.
|
||||
- **Provider abstraction**: The agent works with any OpenAI-compatible API. Provider resolution happens at init time (Nous Portal OAuth, OpenRouter API key, or custom endpoint).
|
||||
- **Provider routing**: When using OpenRouter, `provider_routing` in config.yaml controls provider selection (sort by throughput/latency/price, allow/ignore specific providers, data retention policies). These are injected as `extra_body.provider` in API requests.
|
||||
@@ -325,6 +325,9 @@ description: Brief description (shown in skill search results)
|
||||
version: 1.0.0
|
||||
author: Your Name
|
||||
license: MIT
|
||||
platforms: [macos, linux] # Optional — restrict to specific OS platforms
|
||||
# Valid: macos, linux, windows
|
||||
# Omit to load on all platforms (default)
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Category, Subcategory, Keywords]
|
||||
@@ -351,6 +354,18 @@ Known failure modes and how to handle them.
|
||||
How the agent confirms it worked.
|
||||
```
|
||||
|
||||
### Platform-specific skills
|
||||
|
||||
Skills can declare which OS platforms they support via the `platforms` frontmatter field. Skills with this field are automatically hidden from the system prompt, `skills_list()`, and slash commands on incompatible platforms.
|
||||
|
||||
```yaml
|
||||
platforms: [macos] # macOS only (e.g., iMessage, Apple Reminders)
|
||||
platforms: [macos, linux] # macOS and Linux
|
||||
platforms: [windows] # Windows only
|
||||
```
|
||||
|
||||
If the field is omitted or empty, the skill loads on all platforms (backward compatible). See `skills/apple/` for examples of macOS-only skills.
|
||||
|
||||
### Skill guidelines
|
||||
|
||||
- **No external dependencies unless absolutely necessary.** Prefer stdlib Python, curl, and existing Hermes tools (`web_extract`, `terminal`, `read_file`).
|
||||
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 Nous Research
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -13,11 +13,11 @@
|
||||
|
||||
**The self-improving AI agent built by [Nous Research](https://nousresearch.com).** It's the only agent with a built-in learning loop — it creates skills from experience, improves them during use, nudges itself to persist knowledge, searches its own past conversations, and builds a deepening model of who you are across sessions. Run it on a $5 VPS, a GPU cluster, or serverless infrastructure that costs nearly nothing when idle. It's not tied to your laptop — talk to it from Telegram while it works on a cloud VM.
|
||||
|
||||
Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.
|
||||
Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), [z.ai/GLM](https://z.ai), [Kimi/Moonshot](https://platform.moonshot.ai), [MiniMax](https://www.minimax.io), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.
|
||||
|
||||
<table>
|
||||
<tr><td><b>A real terminal interface</b></td><td>Full TUI with multiline editing, slash-command autocomplete, conversation history, interrupt-and-redirect, and streaming tool output.</td></tr>
|
||||
<tr><td><b>Lives where you do</b></td><td>Telegram, Discord, Slack, WhatsApp, and CLI — all from a single gateway process. Voice memo transcription, cross-platform conversation continuity.</td></tr>
|
||||
<tr><td><b>Lives where you do</b></td><td>Telegram, Discord, Slack, WhatsApp, Signal, and CLI — all from a single gateway process. Voice memo transcription, cross-platform conversation continuity.</td></tr>
|
||||
<tr><td><b>A closed learning loop</b></td><td>Agent-curated memory with periodic nudges. Autonomous skill creation after complex tasks. Skills self-improve during use. FTS5 session search with LLM summarization for cross-session recall. <a href="https://github.com/plastic-labs/honcho">Honcho</a> dialectic user modeling. Compatible with the <a href="https://agentskills.io">agentskills.io</a> open standard.</td></tr>
|
||||
<tr><td><b>Scheduled automations</b></td><td>Built-in cron scheduler with delivery to any platform. Daily reports, nightly backups, weekly audits — all in natural language, running unattended.</td></tr>
|
||||
<tr><td><b>Delegates and parallelizes</b></td><td>Spawn isolated subagents for parallel workstreams. Write Python scripts that call tools via RPC, collapsing multi-step pipelines into zero-context-cost turns.</td></tr>
|
||||
@@ -71,7 +71,7 @@ All documentation lives at **[hermes-agent.nousresearch.com/docs](https://hermes
|
||||
| [Quickstart](https://hermes-agent.nousresearch.com/docs/getting-started/quickstart) | Install → setup → first conversation in 2 minutes |
|
||||
| [CLI Usage](https://hermes-agent.nousresearch.com/docs/user-guide/cli) | Commands, keybindings, personalities, sessions |
|
||||
| [Configuration](https://hermes-agent.nousresearch.com/docs/user-guide/configuration) | Config file, providers, models, all options |
|
||||
| [Messaging Gateway](https://hermes-agent.nousresearch.com/docs/user-guide/messaging) | Telegram, Discord, Slack, WhatsApp, Home Assistant |
|
||||
| [Messaging Gateway](https://hermes-agent.nousresearch.com/docs/user-guide/messaging) | Telegram, Discord, Slack, WhatsApp, Signal, Home Assistant |
|
||||
| [Security](https://hermes-agent.nousresearch.com/docs/user-guide/security) | Command approval, DM pairing, container isolation |
|
||||
| [Tools & Toolsets](https://hermes-agent.nousresearch.com/docs/user-guide/features/tools) | 40+ tools, toolset system, terminal backends |
|
||||
| [Skills System](https://hermes-agent.nousresearch.com/docs/user-guide/features/skills) | Procedural memory, Skills Hub, creating skills |
|
||||
|
||||
129
TODO.md
129
TODO.md
@@ -1,129 +0,0 @@
|
||||
# Hermes Agent - Future Improvements
|
||||
|
||||
---
|
||||
|
||||
|
||||
|
||||
## 3. Local Browser Control via CDP 🌐
|
||||
|
||||
**Status:** Not started (currently Browserbase cloud only)
|
||||
**Priority:** Medium
|
||||
|
||||
Support local Chrome/Chromium via Chrome DevTools Protocol alongside existing Browserbase cloud backend.
|
||||
|
||||
**What other agents do:**
|
||||
- **OpenClaw**: Full CDP-based Chrome control with snapshots, actions, uploads, profiles, file chooser, PDF save, console messages, tab management. Uses local Chrome for persistent login sessions.
|
||||
- **Cline**: Headless browser with Computer Use (click, type, scroll, screenshot, console logs)
|
||||
|
||||
**Our approach:**
|
||||
- Add a `local` backend option to `browser_tool.py` using Playwright or raw CDP
|
||||
- Config toggle: `browser.backend: local | browserbase | auto`
|
||||
- `auto` mode: try local first, fall back to Browserbase
|
||||
- Local advantages: free, persistent login sessions, no API key needed
|
||||
- Local disadvantages: no CAPTCHA solving, no stealth mode, requires Chrome installed
|
||||
- Reuse the same 10-tool interface -- just swap the backend
|
||||
- Later: Chrome profile management for persistent sessions across restarts
|
||||
|
||||
---
|
||||
|
||||
## 4. Signal Integration 📡
|
||||
|
||||
**Status:** Not started
|
||||
**Priority:** Low
|
||||
|
||||
New platform adapter using signal-cli daemon (JSON-RPC HTTP + SSE). Requires Java runtime and phone number registration.
|
||||
|
||||
**Reference:** OpenClaw has Signal support via signal-cli.
|
||||
|
||||
---
|
||||
|
||||
## 5. Plugin/Extension System 🔌
|
||||
|
||||
**Status:** Partially implemented (event hooks exist in `gateway/hooks.py`)
|
||||
**Priority:** Medium
|
||||
|
||||
Full Python plugin interface that goes beyond the current hook system.
|
||||
|
||||
**What other agents do:**
|
||||
- **OpenClaw**: Plugin SDK with tool-send capabilities, lifecycle phase hooks (before-agent-start, after-tool-call, model-override), plugin registry with install/uninstall.
|
||||
- **Pi**: Extensions are TypeScript modules that can register tools, commands, keyboard shortcuts, custom UI widgets, overlays, status lines, dialogs, compaction hooks, raw terminal input listeners. Extremely comprehensive.
|
||||
- **OpenCode**: MCP client support (stdio, SSE, StreamableHTTP), OAuth auth for MCP servers. Also has Copilot/Codex plugins.
|
||||
- **Codex**: Full MCP integration with skill dependencies.
|
||||
- **Cline**: MCP integration + lifecycle hooks with cancellation support.
|
||||
|
||||
**Our approach (phased):**
|
||||
|
||||
### Phase 1: Enhanced hooks
|
||||
- Expand the existing `gateway/hooks.py` to support more events: `before-tool-call`, `after-tool-call`, `before-response`, `context-compress`, `session-end`
|
||||
- Allow hooks to modify tool results (e.g., filter sensitive output)
|
||||
|
||||
### Phase 2: Plugin interface
|
||||
- `~/.hermes/plugins/<name>/plugin.yaml` + `handler.py`
|
||||
- Plugins can: register new tools, add CLI commands, subscribe to events, inject system prompt sections
|
||||
- `hermes plugin list|install|uninstall|create` CLI commands
|
||||
- Plugin discovery and validation on startup
|
||||
|
||||
### Phase 3: MCP support (industry standard) ✅ DONE
|
||||
- ✅ MCP client that connects to external MCP servers (stdio + HTTP/StreamableHTTP)
|
||||
- ✅ Config: `mcp_servers` in config.yaml with connection details
|
||||
- ✅ Each MCP server's tools auto-registered as a dynamic toolset
|
||||
- Future: Resources, Prompts, Progress notifications, `hermes mcp` CLI command
|
||||
|
||||
---
|
||||
|
||||
## 6. MCP (Model Context Protocol) Support 🔗 ✅ DONE
|
||||
|
||||
**Status:** Implemented (PR #301)
|
||||
**Priority:** Complete
|
||||
|
||||
Native MCP client support with stdio and HTTP/StreamableHTTP transports, auto-discovery, reconnection with exponential backoff, env var filtering, and credential stripping. See `docs/mcp.md` for full documentation.
|
||||
|
||||
**Still TODO:**
|
||||
- `hermes mcp` CLI subcommand (list/test/status)
|
||||
- `hermes tools` UI integration for MCP toolsets
|
||||
- MCP Resources and Prompts support
|
||||
- OAuth authentication for remote servers
|
||||
- Progress notifications for long-running tools
|
||||
|
||||
---
|
||||
|
||||
## 8. Filesystem Checkpointing / Rollback 🔄
|
||||
|
||||
**Status:** Not started
|
||||
**Priority:** Low-Medium
|
||||
|
||||
Automatic filesystem snapshots after each agent loop iteration so the user can roll back destructive changes to their project.
|
||||
|
||||
**What other agents do:**
|
||||
- **Cline**: Workspace checkpoints at each step with Compare/Restore UI
|
||||
- **OpenCode**: Git-backed workspace snapshots per step, with weekly gc
|
||||
- **Codex**: Sandboxed execution with commit-per-step, rollback on failure
|
||||
|
||||
**Our approach:**
|
||||
- After each tool call (or batch of tool calls in a single turn) that modifies files, create a lightweight checkpoint of the affected files
|
||||
- Git-based when the project is a repo: auto-commit to a detached/temporary branch (`hermes/checkpoints/<session>`) after each agent turn, squash or discard on session end
|
||||
- Non-git fallback: tar snapshots of changed files in `~/.hermes/checkpoints/<session_id>/`
|
||||
- `hermes rollback` CLI command to restore to a previous checkpoint
|
||||
- Agent-accessible via a `checkpoint` tool: `list` (show available restore points), `restore` (roll back to a named point), `diff` (show what changed since a checkpoint)
|
||||
- Configurable: off by default (opt-in via `config.yaml`), since auto-committing can be surprising
|
||||
- Cleanup: checkpoints expire after session ends (or configurable retention period)
|
||||
- Integration with the terminal backend: works with local, SSH, and Docker backends (snapshots happen on the execution host)
|
||||
|
||||
---
|
||||
|
||||
## Implementation Priority Order
|
||||
|
||||
### Tier 1: Next Up
|
||||
|
||||
1. ~~MCP Support -- #6~~ ✅ Done (PR #301)
|
||||
|
||||
### Tier 2: Quality of Life
|
||||
|
||||
3. Local Browser Control via CDP -- #3
|
||||
4. Plugin/Extension System -- #5
|
||||
|
||||
### Tier 3: Nice to Have
|
||||
|
||||
5. Session Branching / Checkpoints -- #7
|
||||
6. Filesystem Checkpointing / Rollback -- #8
|
||||
7. Signal Integration -- #4
|
||||
@@ -4,18 +4,29 @@ Provides a single resolution chain so every consumer (context compression,
|
||||
session search, web extraction, vision analysis, browser vision) picks up
|
||||
the best available backend without duplicating fallback logic.
|
||||
|
||||
Resolution order for text tasks:
|
||||
Resolution order for text tasks (auto mode):
|
||||
1. OpenRouter (OPENROUTER_API_KEY)
|
||||
2. Nous Portal (~/.hermes/auth.json active provider)
|
||||
3. Custom endpoint (OPENAI_BASE_URL + OPENAI_API_KEY)
|
||||
4. Codex OAuth (Responses API via chatgpt.com with gpt-5.3-codex,
|
||||
wrapped to look like a chat.completions client)
|
||||
5. None
|
||||
5. Direct API-key providers (z.ai/GLM, Kimi/Moonshot, MiniMax, MiniMax-CN)
|
||||
— checked via PROVIDER_REGISTRY entries with auth_type='api_key'
|
||||
6. None
|
||||
|
||||
Resolution order for vision/multimodal tasks:
|
||||
Resolution order for vision/multimodal tasks (auto mode):
|
||||
1. OpenRouter
|
||||
2. Nous Portal
|
||||
3. None (custom endpoints can't substitute for Gemini multimodal)
|
||||
3. None (steps 3-5 are skipped — they may not support multimodal)
|
||||
|
||||
Per-task provider overrides (e.g. AUXILIARY_VISION_PROVIDER,
|
||||
CONTEXT_COMPRESSION_PROVIDER) can force a specific provider for each task:
|
||||
"openrouter", "nous", "codex", or "main" (= steps 3-5).
|
||||
Default "auto" follows the chains above.
|
||||
|
||||
Per-task model overrides (e.g. AUXILIARY_VISION_MODEL,
|
||||
AUXILIARY_WEB_EXTRACT_MODEL) let callers use a different model slug
|
||||
than the provider's default.
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -31,6 +42,14 @@ from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default auxiliary models for direct API-key providers (cheap/fast for side tasks)
|
||||
_API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = {
|
||||
"zai": "glm-4.5-flash",
|
||||
"kimi-coding": "kimi-k2-turbo-preview",
|
||||
"minimax": "MiniMax-M2.5-highspeed",
|
||||
"minimax-cn": "MiniMax-M2.5-highspeed",
|
||||
}
|
||||
|
||||
# OpenRouter app attribution headers
|
||||
_OR_HEADERS = {
|
||||
"HTTP-Referer": "https://github.com/NousResearch/hermes-agent",
|
||||
@@ -63,6 +82,55 @@ _CODEX_AUX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
# read response.choices[0].message.content. This adapter translates those
|
||||
# calls to the Codex Responses API so callers don't need any changes.
|
||||
|
||||
|
||||
def _convert_content_for_responses(content: Any) -> Any:
|
||||
"""Convert chat.completions content to Responses API format.
|
||||
|
||||
chat.completions uses:
|
||||
{"type": "text", "text": "..."}
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
|
||||
|
||||
Responses API uses:
|
||||
{"type": "input_text", "text": "..."}
|
||||
{"type": "input_image", "image_url": "data:image/png;base64,..."}
|
||||
|
||||
If content is a plain string, it's returned as-is (the Responses API
|
||||
accepts strings directly for text-only messages).
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if not isinstance(content, list):
|
||||
return str(content) if content else ""
|
||||
|
||||
converted: List[Dict[str, Any]] = []
|
||||
for part in content:
|
||||
if not isinstance(part, dict):
|
||||
continue
|
||||
ptype = part.get("type", "")
|
||||
if ptype == "text":
|
||||
converted.append({"type": "input_text", "text": part.get("text", "")})
|
||||
elif ptype == "image_url":
|
||||
# chat.completions nests the URL: {"image_url": {"url": "..."}}
|
||||
image_data = part.get("image_url", {})
|
||||
url = image_data.get("url", "") if isinstance(image_data, dict) else str(image_data)
|
||||
entry: Dict[str, Any] = {"type": "input_image", "image_url": url}
|
||||
# Preserve detail if specified
|
||||
detail = image_data.get("detail") if isinstance(image_data, dict) else None
|
||||
if detail:
|
||||
entry["detail"] = detail
|
||||
converted.append(entry)
|
||||
elif ptype in ("input_text", "input_image"):
|
||||
# Already in Responses format — pass through
|
||||
converted.append(part)
|
||||
else:
|
||||
# Unknown content type — try to preserve as text
|
||||
text = part.get("text", "")
|
||||
if text:
|
||||
converted.append({"type": "input_text", "text": text})
|
||||
|
||||
return converted or ""
|
||||
|
||||
|
||||
class _CodexCompletionsAdapter:
|
||||
"""Drop-in shim that accepts chat.completions.create() kwargs and
|
||||
routes them through the Codex Responses streaming API."""
|
||||
@@ -76,30 +144,31 @@ class _CodexCompletionsAdapter:
|
||||
model = kwargs.get("model", self._model)
|
||||
temperature = kwargs.get("temperature")
|
||||
|
||||
# Separate system/instructions from conversation messages
|
||||
# Separate system/instructions from conversation messages.
|
||||
# Convert chat.completions multimodal content blocks to Responses
|
||||
# API format (input_text / input_image instead of text / image_url).
|
||||
instructions = "You are a helpful assistant."
|
||||
input_msgs: List[Dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content") or ""
|
||||
if role == "system":
|
||||
instructions = content
|
||||
instructions = content if isinstance(content, str) else str(content)
|
||||
else:
|
||||
input_msgs.append({"role": role, "content": content})
|
||||
input_msgs.append({
|
||||
"role": role,
|
||||
"content": _convert_content_for_responses(content),
|
||||
})
|
||||
|
||||
resp_kwargs: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"instructions": instructions,
|
||||
"input": input_msgs or [{"role": "user", "content": ""}],
|
||||
"stream": True,
|
||||
"store": False,
|
||||
}
|
||||
|
||||
max_tokens = kwargs.get("max_output_tokens") or kwargs.get("max_completion_tokens") or kwargs.get("max_tokens")
|
||||
if max_tokens is not None:
|
||||
resp_kwargs["max_output_tokens"] = int(max_tokens)
|
||||
if temperature is not None:
|
||||
resp_kwargs["temperature"] = temperature
|
||||
# Note: the Codex endpoint (chatgpt.com/backend-api/codex) does NOT
|
||||
# support max_output_tokens or temperature — omit to avoid 400 errors.
|
||||
|
||||
# Tools support for flush_memories and similar callers
|
||||
tools = kwargs.get("tools")
|
||||
@@ -282,53 +351,173 @@ def _read_codex_access_token() -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
# ── Public API ──────────────────────────────────────────────────────────────
|
||||
def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Try each API-key provider in PROVIDER_REGISTRY order.
|
||||
|
||||
def get_text_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Return (client, model_slug) for text-only auxiliary tasks.
|
||||
|
||||
Falls through OpenRouter -> Nous Portal -> custom endpoint -> Codex OAuth -> (None, None).
|
||||
Returns (client, model) for the first provider whose env var is set,
|
||||
or (None, None) if none are configured.
|
||||
"""
|
||||
# 1. OpenRouter
|
||||
or_key = os.getenv("OPENROUTER_API_KEY")
|
||||
if or_key:
|
||||
logger.debug("Auxiliary text client: OpenRouter")
|
||||
return OpenAI(api_key=or_key, base_url=OPENROUTER_BASE_URL,
|
||||
default_headers=_OR_HEADERS), _OPENROUTER_MODEL
|
||||
try:
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY
|
||||
except ImportError:
|
||||
logger.debug("Could not import PROVIDER_REGISTRY for API-key fallback")
|
||||
return None, None
|
||||
|
||||
# 2. Nous Portal
|
||||
nous = _read_nous_auth()
|
||||
if nous:
|
||||
global auxiliary_is_nous
|
||||
auxiliary_is_nous = True
|
||||
logger.debug("Auxiliary text client: Nous Portal")
|
||||
return (
|
||||
OpenAI(api_key=_nous_api_key(nous), base_url=_nous_base_url()),
|
||||
_NOUS_MODEL,
|
||||
)
|
||||
for provider_id, pconfig in PROVIDER_REGISTRY.items():
|
||||
if pconfig.auth_type != "api_key":
|
||||
continue
|
||||
# Check if any of the provider's env vars are set
|
||||
api_key = ""
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
val = os.getenv(env_var, "").strip()
|
||||
if val:
|
||||
api_key = val
|
||||
break
|
||||
if not api_key:
|
||||
continue
|
||||
# Resolve base URL (with optional env-var override)
|
||||
# Kimi Code keys (sk-kimi-) need api.kimi.com/coding/v1
|
||||
env_url = ""
|
||||
if pconfig.base_url_env_var:
|
||||
env_url = os.getenv(pconfig.base_url_env_var, "").strip()
|
||||
if env_url:
|
||||
base_url = env_url.rstrip("/")
|
||||
elif provider_id == "kimi-coding" and api_key.startswith("sk-kimi-"):
|
||||
base_url = "https://api.kimi.com/coding/v1"
|
||||
else:
|
||||
base_url = pconfig.inference_base_url
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id, "default")
|
||||
logger.debug("Auxiliary text client: %s (%s)", pconfig.name, model)
|
||||
extra = {}
|
||||
if "api.kimi.com" in base_url.lower():
|
||||
extra["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
|
||||
return OpenAI(api_key=api_key, base_url=base_url, **extra), model
|
||||
|
||||
# 3. Custom endpoint (both base URL and key must be set)
|
||||
custom_base = os.getenv("OPENAI_BASE_URL")
|
||||
custom_key = os.getenv("OPENAI_API_KEY")
|
||||
if custom_base and custom_key:
|
||||
model = os.getenv("OPENAI_MODEL") or os.getenv("LLM_MODEL") or "gpt-4o-mini"
|
||||
logger.debug("Auxiliary text client: custom endpoint (%s)", model)
|
||||
return OpenAI(api_key=custom_key, base_url=custom_base), model
|
||||
|
||||
# 4. Codex OAuth -- uses the Responses API (only endpoint the token
|
||||
# can access), wrapped to look like a chat.completions client.
|
||||
codex_token = _read_codex_access_token()
|
||||
if codex_token:
|
||||
logger.debug("Auxiliary text client: Codex OAuth (%s via Responses API)", _CODEX_AUX_MODEL)
|
||||
real_client = OpenAI(api_key=codex_token, base_url=_CODEX_AUX_BASE_URL)
|
||||
return CodexAuxiliaryClient(real_client, _CODEX_AUX_MODEL), _CODEX_AUX_MODEL
|
||||
|
||||
# 5. Nothing available
|
||||
logger.debug("Auxiliary text client: none available")
|
||||
return None, None
|
||||
|
||||
|
||||
def get_async_text_auxiliary_client():
|
||||
# ── Provider resolution helpers ─────────────────────────────────────────────
|
||||
|
||||
def _get_auxiliary_provider(task: str = "") -> str:
|
||||
"""Read the provider override for a specific auxiliary task.
|
||||
|
||||
Checks AUXILIARY_{TASK}_PROVIDER first (e.g. AUXILIARY_VISION_PROVIDER),
|
||||
then CONTEXT_{TASK}_PROVIDER (for the compression section's summary_provider),
|
||||
then falls back to "auto". Returns one of: "auto", "openrouter", "nous", "main".
|
||||
"""
|
||||
if task:
|
||||
for prefix in ("AUXILIARY_", "CONTEXT_"):
|
||||
val = os.getenv(f"{prefix}{task.upper()}_PROVIDER", "").strip().lower()
|
||||
if val and val != "auto":
|
||||
return val
|
||||
return "auto"
|
||||
|
||||
|
||||
def _try_openrouter() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
or_key = os.getenv("OPENROUTER_API_KEY")
|
||||
if not or_key:
|
||||
return None, None
|
||||
logger.debug("Auxiliary client: OpenRouter")
|
||||
return OpenAI(api_key=or_key, base_url=OPENROUTER_BASE_URL,
|
||||
default_headers=_OR_HEADERS), _OPENROUTER_MODEL
|
||||
|
||||
|
||||
def _try_nous() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
nous = _read_nous_auth()
|
||||
if not nous:
|
||||
return None, None
|
||||
global auxiliary_is_nous
|
||||
auxiliary_is_nous = True
|
||||
logger.debug("Auxiliary client: Nous Portal")
|
||||
return (
|
||||
OpenAI(api_key=_nous_api_key(nous), base_url=_nous_base_url()),
|
||||
_NOUS_MODEL,
|
||||
)
|
||||
|
||||
|
||||
def _try_custom_endpoint() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
custom_base = os.getenv("OPENAI_BASE_URL")
|
||||
custom_key = os.getenv("OPENAI_API_KEY")
|
||||
if not custom_base or not custom_key:
|
||||
return None, None
|
||||
model = os.getenv("OPENAI_MODEL") or os.getenv("LLM_MODEL") or "gpt-4o-mini"
|
||||
logger.debug("Auxiliary client: custom endpoint (%s)", model)
|
||||
return OpenAI(api_key=custom_key, base_url=custom_base), model
|
||||
|
||||
|
||||
def _try_codex() -> Tuple[Optional[Any], Optional[str]]:
|
||||
codex_token = _read_codex_access_token()
|
||||
if not codex_token:
|
||||
return None, None
|
||||
logger.debug("Auxiliary client: Codex OAuth (%s via Responses API)", _CODEX_AUX_MODEL)
|
||||
real_client = OpenAI(api_key=codex_token, base_url=_CODEX_AUX_BASE_URL)
|
||||
return CodexAuxiliaryClient(real_client, _CODEX_AUX_MODEL), _CODEX_AUX_MODEL
|
||||
|
||||
|
||||
def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Resolve a specific forced provider. Returns (None, None) if creds missing."""
|
||||
if forced == "openrouter":
|
||||
client, model = _try_openrouter()
|
||||
if client is None:
|
||||
logger.warning("auxiliary.provider=openrouter but OPENROUTER_API_KEY not set")
|
||||
return client, model
|
||||
|
||||
if forced == "nous":
|
||||
client, model = _try_nous()
|
||||
if client is None:
|
||||
logger.warning("auxiliary.provider=nous but Nous Portal not configured (run: hermes login)")
|
||||
return client, model
|
||||
|
||||
if forced == "codex":
|
||||
client, model = _try_codex()
|
||||
if client is None:
|
||||
logger.warning("auxiliary.provider=codex but no Codex OAuth token found (run: hermes model)")
|
||||
return client, model
|
||||
|
||||
if forced == "main":
|
||||
# "main" = skip OpenRouter/Nous, use the main chat model's credentials.
|
||||
for try_fn in (_try_custom_endpoint, _try_codex, _resolve_api_key_provider):
|
||||
client, model = try_fn()
|
||||
if client is not None:
|
||||
return client, model
|
||||
logger.warning("auxiliary.provider=main but no main endpoint credentials found")
|
||||
return None, None
|
||||
|
||||
# Unknown provider name — fall through to auto
|
||||
logger.warning("Unknown auxiliary.provider=%r, falling back to auto", forced)
|
||||
return None, None
|
||||
|
||||
|
||||
def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Full auto-detection chain: OpenRouter → Nous → custom → Codex → API-key → None."""
|
||||
for try_fn in (_try_openrouter, _try_nous, _try_custom_endpoint,
|
||||
_try_codex, _resolve_api_key_provider):
|
||||
client, model = try_fn()
|
||||
if client is not None:
|
||||
return client, model
|
||||
logger.debug("Auxiliary client: none available")
|
||||
return None, None
|
||||
|
||||
|
||||
# ── Public API ──────────────────────────────────────────────────────────────
|
||||
|
||||
def get_text_auxiliary_client(task: str = "") -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Return (client, default_model_slug) for text-only auxiliary tasks.
|
||||
|
||||
Args:
|
||||
task: Optional task name ("compression", "web_extract") to check
|
||||
for a task-specific provider override.
|
||||
|
||||
Callers may override the returned model with a per-task env var
|
||||
(e.g. CONTEXT_COMPRESSION_MODEL, AUXILIARY_WEB_EXTRACT_MODEL).
|
||||
"""
|
||||
forced = _get_auxiliary_provider(task)
|
||||
if forced != "auto":
|
||||
return _resolve_forced_provider(forced)
|
||||
return _resolve_auto()
|
||||
|
||||
|
||||
def get_async_text_auxiliary_client(task: str = ""):
|
||||
"""Return (async_client, model_slug) for async consumers.
|
||||
|
||||
For standard providers returns (AsyncOpenAI, model). For Codex returns
|
||||
@@ -337,7 +526,7 @@ def get_async_text_auxiliary_client():
|
||||
"""
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
sync_client, model = get_text_auxiliary_client()
|
||||
sync_client, model = get_text_auxiliary_client(task)
|
||||
if sync_client is None:
|
||||
return None, None
|
||||
|
||||
@@ -350,33 +539,33 @@ def get_async_text_auxiliary_client():
|
||||
}
|
||||
if "openrouter" in str(sync_client.base_url).lower():
|
||||
async_kwargs["default_headers"] = dict(_OR_HEADERS)
|
||||
elif "api.kimi.com" in str(sync_client.base_url).lower():
|
||||
async_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
|
||||
return AsyncOpenAI(**async_kwargs), model
|
||||
|
||||
|
||||
def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Return (client, model_slug) for vision/multimodal auxiliary tasks.
|
||||
"""Return (client, default_model_slug) for vision/multimodal auxiliary tasks.
|
||||
|
||||
Only OpenRouter and Nous Portal qualify — custom endpoints cannot
|
||||
substitute for Gemini multimodal.
|
||||
Checks AUXILIARY_VISION_PROVIDER for a forced provider, otherwise
|
||||
auto-detects. Callers may override the returned model with
|
||||
AUXILIARY_VISION_MODEL.
|
||||
|
||||
In auto mode, only providers known to support multimodal are tried:
|
||||
OpenRouter, Nous Portal, and Codex OAuth (gpt-5.3-codex supports
|
||||
vision via the Responses API). Custom endpoints and API-key
|
||||
providers are skipped — they may not handle vision input. To use
|
||||
them, set AUXILIARY_VISION_PROVIDER explicitly.
|
||||
"""
|
||||
# 1. OpenRouter
|
||||
or_key = os.getenv("OPENROUTER_API_KEY")
|
||||
if or_key:
|
||||
logger.debug("Auxiliary vision client: OpenRouter")
|
||||
return OpenAI(api_key=or_key, base_url=OPENROUTER_BASE_URL,
|
||||
default_headers=_OR_HEADERS), _OPENROUTER_MODEL
|
||||
|
||||
# 2. Nous Portal
|
||||
nous = _read_nous_auth()
|
||||
if nous:
|
||||
logger.debug("Auxiliary vision client: Nous Portal")
|
||||
return (
|
||||
OpenAI(api_key=_nous_api_key(nous), base_url=_nous_base_url()),
|
||||
_NOUS_MODEL,
|
||||
)
|
||||
|
||||
# 3. Nothing suitable
|
||||
logger.debug("Auxiliary vision client: none available")
|
||||
forced = _get_auxiliary_provider("vision")
|
||||
if forced != "auto":
|
||||
return _resolve_forced_provider(forced)
|
||||
# Auto: only multimodal-capable providers
|
||||
for try_fn in (_try_openrouter, _try_nous, _try_codex):
|
||||
client, model = try_fn()
|
||||
if client is not None:
|
||||
return client, model
|
||||
logger.debug("Auxiliary vision client: none available (auto only tries OpenRouter/Nous/Codex)")
|
||||
return None, None
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ protecting head and tail context.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.auxiliary_client import get_text_auxiliary_client
|
||||
from agent.model_metadata import (
|
||||
@@ -53,7 +53,7 @@ class ContextCompressor:
|
||||
self.last_completion_tokens = 0
|
||||
self.last_total_tokens = 0
|
||||
|
||||
self.client, default_model = get_text_auxiliary_client()
|
||||
self.client, default_model = get_text_auxiliary_client("compression")
|
||||
self.summary_model = summary_model_override or default_model
|
||||
|
||||
def update_from_response(self, usage: Dict[str, Any]):
|
||||
@@ -82,11 +82,14 @@ class ContextCompressor:
|
||||
"compression_count": self.compression_count,
|
||||
}
|
||||
|
||||
def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> str:
|
||||
"""Generate a concise summary of conversation turns using a fast model."""
|
||||
if not self.client:
|
||||
return "[CONTEXT SUMMARY]: Previous conversation turns have been compressed to save space. The assistant performed various actions and received responses."
|
||||
def _generate_summary(self, turns_to_summarize: List[Dict[str, Any]]) -> Optional[str]:
|
||||
"""Generate a concise summary of conversation turns.
|
||||
|
||||
Tries the auxiliary model first, then falls back to the user's main
|
||||
model. Returns None if all attempts fail — the caller should drop
|
||||
the middle turns without a summary rather than inject a useless
|
||||
placeholder.
|
||||
"""
|
||||
parts = []
|
||||
for msg in turns_to_summarize:
|
||||
role = msg.get("role", "unknown")
|
||||
@@ -117,28 +120,28 @@ TURNS TO SUMMARIZE:
|
||||
|
||||
Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
|
||||
try:
|
||||
return self._call_summary_model(self.client, self.summary_model, prompt)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to generate context summary with auxiliary model: {e}")
|
||||
# 1. Try the auxiliary model (cheap/fast)
|
||||
if self.client:
|
||||
try:
|
||||
return self._call_summary_model(self.client, self.summary_model, prompt)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to generate context summary with auxiliary model: {e}")
|
||||
|
||||
# Fallback: try the main model's endpoint. This handles the common
|
||||
# case where the user switched providers (e.g. OpenRouter → local LLM)
|
||||
# but a stale API key causes the auxiliary client to pick the old
|
||||
# provider which then fails (402, auth error, etc.).
|
||||
fallback_client, fallback_model = self._get_fallback_client()
|
||||
if fallback_client is not None:
|
||||
try:
|
||||
logger.info("Retrying context summary with fallback client (%s)", fallback_model)
|
||||
summary = self._call_summary_model(fallback_client, fallback_model, prompt)
|
||||
# Success — swap in the working client for future compressions
|
||||
self.client = fallback_client
|
||||
self.summary_model = fallback_model
|
||||
return summary
|
||||
except Exception as fallback_err:
|
||||
logging.warning(f"Fallback summary model also failed: {fallback_err}")
|
||||
# 2. Fallback: try the user's main model endpoint
|
||||
fallback_client, fallback_model = self._get_fallback_client()
|
||||
if fallback_client is not None:
|
||||
try:
|
||||
logger.info("Retrying context summary with main model (%s)", fallback_model)
|
||||
summary = self._call_summary_model(fallback_client, fallback_model, prompt)
|
||||
self.client = fallback_client
|
||||
self.summary_model = fallback_model
|
||||
return summary
|
||||
except Exception as fallback_err:
|
||||
logging.warning(f"Main model summary also failed: {fallback_err}")
|
||||
|
||||
return "[CONTEXT SUMMARY]: Previous conversation turns have been compressed. The assistant performed tool calls and received responses."
|
||||
# 3. All models failed — return None so the caller drops turns without a summary
|
||||
logging.warning("Context compression: no model available for summary. Middle turns will be dropped without summary.")
|
||||
return None
|
||||
|
||||
def _call_summary_model(self, client, model: str, prompt: str) -> str:
|
||||
"""Make the actual LLM call to generate a summary. Raises on failure."""
|
||||
@@ -196,10 +199,111 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
logger.debug("Could not build fallback auxiliary client: %s", exc)
|
||||
return None, None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool-call / tool-result pair integrity helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_call_id(tc) -> str:
|
||||
"""Extract the call ID from a tool_call entry (dict or SimpleNamespace)."""
|
||||
if isinstance(tc, dict):
|
||||
return tc.get("id", "")
|
||||
return getattr(tc, "id", "") or ""
|
||||
|
||||
def _sanitize_tool_pairs(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Fix orphaned tool_call / tool_result pairs after compression.
|
||||
|
||||
Two failure modes:
|
||||
1. A tool *result* references a call_id whose assistant tool_call was
|
||||
removed (summarized/truncated). The API rejects this with
|
||||
"No tool call found for function call output with call_id ...".
|
||||
2. An assistant message has tool_calls whose results were dropped.
|
||||
The API rejects this because every tool_call must be followed by
|
||||
a tool result with the matching call_id.
|
||||
|
||||
This method removes orphaned results and inserts stub results for
|
||||
orphaned calls so the message list is always well-formed.
|
||||
"""
|
||||
surviving_call_ids: set = set()
|
||||
for msg in messages:
|
||||
if msg.get("role") == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
cid = self._get_tool_call_id(tc)
|
||||
if cid:
|
||||
surviving_call_ids.add(cid)
|
||||
|
||||
result_call_ids: set = set()
|
||||
for msg in messages:
|
||||
if msg.get("role") == "tool":
|
||||
cid = msg.get("tool_call_id")
|
||||
if cid:
|
||||
result_call_ids.add(cid)
|
||||
|
||||
# 1. Remove tool results whose call_id has no matching assistant tool_call
|
||||
orphaned_results = result_call_ids - surviving_call_ids
|
||||
if orphaned_results:
|
||||
messages = [
|
||||
m for m in messages
|
||||
if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results)
|
||||
]
|
||||
if not self.quiet_mode:
|
||||
logger.info("Compression sanitizer: removed %d orphaned tool result(s)", len(orphaned_results))
|
||||
|
||||
# 2. Add stub results for assistant tool_calls whose results were dropped
|
||||
missing_results = surviving_call_ids - result_call_ids
|
||||
if missing_results:
|
||||
patched: List[Dict[str, Any]] = []
|
||||
for msg in messages:
|
||||
patched.append(msg)
|
||||
if msg.get("role") == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
cid = self._get_tool_call_id(tc)
|
||||
if cid in missing_results:
|
||||
patched.append({
|
||||
"role": "tool",
|
||||
"content": "[Result from earlier conversation — see context summary above]",
|
||||
"tool_call_id": cid,
|
||||
})
|
||||
messages = patched
|
||||
if not self.quiet_mode:
|
||||
logger.info("Compression sanitizer: added %d stub tool result(s)", len(missing_results))
|
||||
|
||||
return messages
|
||||
|
||||
def _align_boundary_forward(self, messages: List[Dict[str, Any]], idx: int) -> int:
|
||||
"""Push a compress-start boundary forward past any orphan tool results.
|
||||
|
||||
If ``messages[idx]`` is a tool result, slide forward until we hit a
|
||||
non-tool message so we don't start the summarised region mid-group.
|
||||
"""
|
||||
while idx < len(messages) and messages[idx].get("role") == "tool":
|
||||
idx += 1
|
||||
return idx
|
||||
|
||||
def _align_boundary_backward(self, messages: List[Dict[str, Any]], idx: int) -> int:
|
||||
"""Pull a compress-end boundary backward to avoid splitting a
|
||||
tool_call / result group.
|
||||
|
||||
If the message just before ``idx`` is an assistant message with
|
||||
tool_calls, those tool results will start at ``idx`` and would be
|
||||
separated from their parent. Move backwards to include the whole
|
||||
group in the summarised region.
|
||||
"""
|
||||
if idx <= 0 or idx >= len(messages):
|
||||
return idx
|
||||
prev = messages[idx - 1]
|
||||
if prev.get("role") == "assistant" and prev.get("tool_calls"):
|
||||
# The results for this assistant turn sit at idx..idx+k.
|
||||
# Include the assistant message in the summarised region too.
|
||||
idx -= 1
|
||||
return idx
|
||||
|
||||
def compress(self, messages: List[Dict[str, Any]], current_tokens: int = None) -> List[Dict[str, Any]]:
|
||||
"""Compress conversation messages by summarizing middle turns.
|
||||
|
||||
Keeps first N + last N turns, summarizes everything in between.
|
||||
After compression, orphaned tool_call / tool_result pairs are cleaned
|
||||
up so the API never receives mismatched IDs.
|
||||
"""
|
||||
n_messages = len(messages)
|
||||
if n_messages <= self.protect_first_n + self.protect_last_n + 1:
|
||||
@@ -212,6 +316,12 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
if compress_start >= compress_end:
|
||||
return messages
|
||||
|
||||
# Adjust boundaries to avoid splitting tool_call/result groups.
|
||||
compress_start = self._align_boundary_forward(messages, compress_start)
|
||||
compress_end = self._align_boundary_backward(messages, compress_end)
|
||||
if compress_start >= compress_end:
|
||||
return messages
|
||||
|
||||
turns_to_summarize = messages[compress_start:compress_end]
|
||||
display_tokens = current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)
|
||||
|
||||
@@ -219,24 +329,6 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
print(f"\n📦 Context compression triggered ({display_tokens:,} tokens ≥ {self.threshold_tokens:,} threshold)")
|
||||
print(f" 📊 Model context limit: {self.context_length:,} tokens ({self.threshold_percent*100:.0f}% = {self.threshold_tokens:,})")
|
||||
|
||||
# Truncation fallback when no auxiliary model is available
|
||||
if self.client is None:
|
||||
print("⚠️ Context compression: no auxiliary model available. Falling back to message truncation.")
|
||||
# Keep system message(s) at the front and the protected tail;
|
||||
# simply drop the oldest non-system messages until under threshold.
|
||||
kept = []
|
||||
for msg in messages:
|
||||
if msg.get("role") == "system":
|
||||
kept.append(msg.copy())
|
||||
else:
|
||||
break
|
||||
tail = messages[-self.protect_last_n:]
|
||||
kept.extend(m.copy() for m in tail)
|
||||
self.compression_count += 1
|
||||
if not self.quiet_mode:
|
||||
print(f" ✂️ Truncated: {len(messages)} → {len(kept)} messages (dropped middle turns)")
|
||||
return kept
|
||||
|
||||
if not self.quiet_mode:
|
||||
print(f" 🗜️ Summarizing turns {compress_start+1}-{compress_end} ({len(turns_to_summarize)} turns)")
|
||||
|
||||
@@ -249,13 +341,21 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||
msg["content"] = (msg.get("content") or "") + "\n\n[Note: Some earlier conversation turns may be summarized to preserve context space.]"
|
||||
compressed.append(msg)
|
||||
|
||||
compressed.append({"role": "user", "content": summary})
|
||||
if summary:
|
||||
last_head_role = messages[compress_start - 1].get("role", "user") if compress_start > 0 else "user"
|
||||
summary_role = "user" if last_head_role in ("assistant", "tool") else "assistant"
|
||||
compressed.append({"role": summary_role, "content": summary})
|
||||
else:
|
||||
if not self.quiet_mode:
|
||||
print(" ⚠️ No summary model available — middle turns dropped without summary")
|
||||
|
||||
for i in range(compress_end, n_messages):
|
||||
compressed.append(messages[i].copy())
|
||||
|
||||
self.compression_count += 1
|
||||
|
||||
compressed = self._sanitize_tool_pairs(compressed)
|
||||
|
||||
if not self.quiet_mode:
|
||||
new_estimate = estimate_messages_tokens_rough(compressed)
|
||||
saved_estimate = display_tokens - new_estimate
|
||||
|
||||
@@ -55,6 +55,20 @@ MODEL_PRICING = {
|
||||
# Meta (via providers)
|
||||
"llama-4-maverick": {"input": 0.50, "output": 0.70},
|
||||
"llama-4-scout": {"input": 0.20, "output": 0.30},
|
||||
# Z.AI / GLM (direct provider — pricing not published externally, treat as local)
|
||||
"glm-5": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.7": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.5": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.5-flash": {"input": 0.0, "output": 0.0},
|
||||
# Kimi / Moonshot (direct provider — pricing not published externally, treat as local)
|
||||
"kimi-k2.5": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-thinking": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-turbo-preview": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-0905-preview": {"input": 0.0, "output": 0.0},
|
||||
# MiniMax (direct provider — pricing not published externally, treat as local)
|
||||
"MiniMax-M2.5": {"input": 0.0, "output": 0.0},
|
||||
"MiniMax-M2.5-highspeed": {"input": 0.0, "output": 0.0},
|
||||
"MiniMax-M2.1": {"input": 0.0, "output": 0.0},
|
||||
}
|
||||
|
||||
# Fallback: unknown/custom models get zero cost (we can't assume pricing
|
||||
|
||||
@@ -49,6 +49,17 @@ DEFAULT_CONTEXT_LENGTHS = {
|
||||
"meta-llama/llama-3.3-70b-instruct": 131072,
|
||||
"deepseek/deepseek-chat-v3": 65536,
|
||||
"qwen/qwen-2.5-72b-instruct": 32768,
|
||||
"glm-4.7": 202752,
|
||||
"glm-5": 202752,
|
||||
"glm-4.5": 131072,
|
||||
"glm-4.5-flash": 131072,
|
||||
"kimi-k2.5": 262144,
|
||||
"kimi-k2-thinking": 262144,
|
||||
"kimi-k2-turbo-preview": 262144,
|
||||
"kimi-k2-0905-preview": 131072,
|
||||
"MiniMax-M2.5": 204800,
|
||||
"MiniMax-M2.5-highspeed": 204800,
|
||||
"MiniMax-M2.1": 204800,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -66,7 +66,8 @@ DEFAULT_AGENT_IDENTITY = (
|
||||
"range of tasks including answering questions, writing and editing code, "
|
||||
"analyzing information, creative work, and executing actions via your tools. "
|
||||
"You communicate clearly, admit uncertainty when appropriate, and prioritize "
|
||||
"being genuinely useful over being verbose unless otherwise directed below."
|
||||
"being genuinely useful over being verbose unless otherwise directed below. "
|
||||
"Be targeted and efficient in your exploration and investigations."
|
||||
)
|
||||
|
||||
MEMORY_GUIDANCE = (
|
||||
@@ -102,12 +103,33 @@ PLATFORM_HINTS = {
|
||||
"You are on a text messaging communication platform, Telegram. "
|
||||
"Please do not use markdown as it does not render. "
|
||||
"You can send media files natively: to deliver a file to the user, "
|
||||
"include MEDIA:/absolute/path/to/file in your response. Audio "
|
||||
"(.ogg) sends as voice bubbles. You can also include image URLs "
|
||||
"in markdown format  and they will be sent as native photos."
|
||||
"include MEDIA:/absolute/path/to/file in your response. Images "
|
||||
"(.png, .jpg, .webp) appear as photos, audio (.ogg) sends as voice "
|
||||
"bubbles, and videos (.mp4) play inline. You can also include image "
|
||||
"URLs in markdown format  and they will be sent as native photos."
|
||||
),
|
||||
"discord": (
|
||||
"You are in a Discord server or group chat communicating with your user."
|
||||
"You are in a Discord server or group chat communicating with your user. "
|
||||
"You can send media files natively: include MEDIA:/absolute/path/to/file "
|
||||
"in your response. Images (.png, .jpg, .webp) are sent as photo "
|
||||
"attachments, audio as file attachments. You can also include image URLs "
|
||||
"in markdown format  and they will be sent as attachments."
|
||||
),
|
||||
"slack": (
|
||||
"You are in a Slack workspace communicating with your user. "
|
||||
"You can send media files natively: include MEDIA:/absolute/path/to/file "
|
||||
"in your response. Images (.png, .jpg, .webp) are uploaded as photo "
|
||||
"attachments, audio as file attachments. You can also include image URLs "
|
||||
"in markdown format  and they will be uploaded as attachments."
|
||||
),
|
||||
"signal": (
|
||||
"You are on a text messaging communication platform, Signal. "
|
||||
"Please do not use markdown as it does not render. "
|
||||
"You can send media files natively: to deliver a file to the user, "
|
||||
"include MEDIA:/absolute/path/to/file in your response. Images "
|
||||
"(.png, .jpg, .webp) appear as photos, audio as attachments, and other "
|
||||
"files arrive as downloadable documents. You can also include image "
|
||||
"URLs in markdown format  and they will be sent as photos."
|
||||
),
|
||||
"cli": (
|
||||
"You are a CLI AI Agent. Try not to use markdown but simple text "
|
||||
@@ -142,12 +164,28 @@ def _read_skill_description(skill_file: Path, max_chars: int = 60) -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def _skill_is_platform_compatible(skill_file: Path) -> bool:
|
||||
"""Quick check if a SKILL.md is compatible with the current OS platform.
|
||||
|
||||
Reads just enough to parse the ``platforms`` frontmatter field.
|
||||
Skills without the field (the vast majority) are always compatible.
|
||||
"""
|
||||
try:
|
||||
from tools.skills_tool import _parse_frontmatter, skill_matches_platform
|
||||
raw = skill_file.read_text(encoding="utf-8")[:2000]
|
||||
frontmatter, _ = _parse_frontmatter(raw)
|
||||
return skill_matches_platform(frontmatter)
|
||||
except Exception:
|
||||
return True # Err on the side of showing the skill
|
||||
|
||||
|
||||
def build_skills_system_prompt() -> str:
|
||||
"""Build a compact skill index for the system prompt.
|
||||
|
||||
Scans ~/.hermes/skills/ for SKILL.md files grouped by category.
|
||||
Includes per-skill descriptions from frontmatter so the model can
|
||||
match skills by meaning, not just name.
|
||||
Filters out skills incompatible with the current OS platform.
|
||||
"""
|
||||
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
skills_dir = hermes_home / "skills"
|
||||
@@ -159,6 +197,9 @@ def build_skills_system_prompt() -> str:
|
||||
# Each entry: (skill_name, description)
|
||||
skills_by_category: dict[str, list[tuple[str, str]]] = {}
|
||||
for skill_file in skills_dir.rglob("SKILL.md"):
|
||||
# Skip skills incompatible with the current OS platform
|
||||
if not _skill_is_platform_compatible(skill_file):
|
||||
continue
|
||||
rel_path = skill_file.relative_to(skills_dir)
|
||||
parts = rel_path.parts
|
||||
if len(parts) >= 2:
|
||||
|
||||
@@ -8,6 +8,7 @@ the first 6 and last 4 characters for debuggability.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
@@ -15,7 +16,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Known API key prefixes -- match the prefix + contiguous token chars
|
||||
_PREFIX_PATTERNS = [
|
||||
r"sk-[A-Za-z0-9_-]{10,}", # OpenAI / OpenRouter
|
||||
r"sk-[A-Za-z0-9_-]{10,}", # OpenAI / OpenRouter / Anthropic (sk-ant-*)
|
||||
r"ghp_[A-Za-z0-9]{10,}", # GitHub PAT (classic)
|
||||
r"github_pat_[A-Za-z0-9_]{10,}", # GitHub PAT (fine-grained)
|
||||
r"xox[baprs]-[A-Za-z0-9-]{10,}", # Slack tokens
|
||||
@@ -25,6 +26,18 @@ _PREFIX_PATTERNS = [
|
||||
r"fc-[A-Za-z0-9]{10,}", # Firecrawl
|
||||
r"bb_live_[A-Za-z0-9_-]{10,}", # BrowserBase
|
||||
r"gAAAA[A-Za-z0-9_=-]{20,}", # Codex encrypted tokens
|
||||
r"AKIA[A-Z0-9]{16}", # AWS Access Key ID
|
||||
r"sk_live_[A-Za-z0-9]{10,}", # Stripe secret key (live)
|
||||
r"sk_test_[A-Za-z0-9]{10,}", # Stripe secret key (test)
|
||||
r"rk_live_[A-Za-z0-9]{10,}", # Stripe restricted key
|
||||
r"SG\.[A-Za-z0-9_-]{10,}", # SendGrid API key
|
||||
r"hf_[A-Za-z0-9]{10,}", # HuggingFace token
|
||||
r"r8_[A-Za-z0-9]{10,}", # Replicate API token
|
||||
r"npm_[A-Za-z0-9]{10,}", # npm access token
|
||||
r"pypi-[A-Za-z0-9_-]{10,}", # PyPI API token
|
||||
r"dop_v1_[A-Za-z0-9]{10,}", # DigitalOcean PAT
|
||||
r"doo_v1_[A-Za-z0-9]{10,}", # DigitalOcean OAuth
|
||||
r"am_[A-Za-z0-9_-]{10,}", # AgentMail API key
|
||||
]
|
||||
|
||||
# ENV assignment patterns: KEY=value where KEY contains a secret-like name
|
||||
@@ -52,6 +65,22 @@ _TELEGRAM_RE = re.compile(
|
||||
r"(bot)?(\d{8,}):([-A-Za-z0-9_]{30,})",
|
||||
)
|
||||
|
||||
# Private key blocks: -----BEGIN RSA PRIVATE KEY----- ... -----END RSA PRIVATE KEY-----
|
||||
_PRIVATE_KEY_RE = re.compile(
|
||||
r"-----BEGIN[A-Z ]*PRIVATE KEY-----[\s\S]*?-----END[A-Z ]*PRIVATE KEY-----"
|
||||
)
|
||||
|
||||
# Database connection strings: protocol://user:PASSWORD@host
|
||||
# Catches postgres, mysql, mongodb, redis, amqp URLs and redacts the password
|
||||
_DB_CONNSTR_RE = re.compile(
|
||||
r"((?:postgres(?:ql)?|mysql|mongodb(?:\+srv)?|redis|amqp)://[^:]+:)([^@]+)(@)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# E.164 phone numbers: +<country><number>, 7-15 digits
|
||||
# Negative lookahead prevents matching hex strings or identifiers
|
||||
_SIGNAL_PHONE_RE = re.compile(r"(\+[1-9]\d{6,14})(?![A-Za-z0-9])")
|
||||
|
||||
# Compile known prefix patterns into one alternation
|
||||
_PREFIX_RE = re.compile(
|
||||
r"(?<![A-Za-z0-9_-])(" + "|".join(_PREFIX_PATTERNS) + r")(?![A-Za-z0-9_-])"
|
||||
@@ -69,9 +98,12 @@ def redact_sensitive_text(text: str) -> str:
|
||||
"""Apply all redaction patterns to a block of text.
|
||||
|
||||
Safe to call on any string -- non-matching text passes through unchanged.
|
||||
Disabled when security.redact_secrets is false in config.yaml.
|
||||
"""
|
||||
if not text:
|
||||
return text
|
||||
if os.getenv("HERMES_REDACT_SECRETS", "").lower() in ("0", "false", "no", "off"):
|
||||
return text
|
||||
|
||||
# Known prefixes (sk-, ghp_, etc.)
|
||||
text = _PREFIX_RE.sub(lambda m: _mask_token(m.group(1)), text)
|
||||
@@ -101,6 +133,20 @@ def redact_sensitive_text(text: str) -> str:
|
||||
return f"{prefix}{digits}:***"
|
||||
text = _TELEGRAM_RE.sub(_redact_telegram, text)
|
||||
|
||||
# Private key blocks
|
||||
text = _PRIVATE_KEY_RE.sub("[REDACTED PRIVATE KEY]", text)
|
||||
|
||||
# Database connection string passwords
|
||||
text = _DB_CONNSTR_RE.sub(lambda m: f"{m.group(1)}***{m.group(3)}", text)
|
||||
|
||||
# E.164 phone numbers (Signal, WhatsApp)
|
||||
def _redact_phone(m):
|
||||
phone = m.group(1)
|
||||
if len(phone) <= 8:
|
||||
return phone[:2] + "****" + phone[-2:]
|
||||
return phone[:4] + "****" + phone[-4:]
|
||||
text = _SIGNAL_PHONE_RE.sub(_redact_phone, text)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
|
||||
global _skill_commands
|
||||
_skill_commands = {}
|
||||
try:
|
||||
from tools.skills_tool import SKILLS_DIR, _parse_frontmatter
|
||||
from tools.skills_tool import SKILLS_DIR, _parse_frontmatter, skill_matches_platform
|
||||
if not SKILLS_DIR.exists():
|
||||
return _skill_commands
|
||||
for skill_md in SKILLS_DIR.rglob("SKILL.md"):
|
||||
@@ -31,6 +31,9 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
|
||||
try:
|
||||
content = skill_md.read_text(encoding='utf-8')
|
||||
frontmatter, body = _parse_frontmatter(content)
|
||||
# Skip skills incompatible with the current OS platform
|
||||
if not skill_matches_platform(frontmatter):
|
||||
continue
|
||||
name = frontmatter.get('name', skill_md.parent.name)
|
||||
description = frontmatter.get('description', '')
|
||||
if not description:
|
||||
|
||||
@@ -1112,7 +1112,7 @@ def main(
|
||||
batch_size: int = None,
|
||||
run_name: str = None,
|
||||
distribution: str = "default",
|
||||
model: str = "anthropic/claude-sonnet-4-20250514",
|
||||
model: str = "anthropic/claude-sonnet-4.6",
|
||||
api_key: str = None,
|
||||
base_url: str = "https://openrouter.ai/api/v1",
|
||||
max_turns: int = 10,
|
||||
@@ -1155,7 +1155,7 @@ def main(
|
||||
providers_order (str): Comma-separated list of OpenRouter providers to try in order (e.g. "anthropic,openai,google")
|
||||
provider_sort (str): Sort providers by "price", "throughput", or "latency" (OpenRouter only)
|
||||
max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
|
||||
reasoning_effort (str): OpenRouter reasoning effort level: "xhigh", "high", "medium", "low", "minimal", "none" (default: "xhigh")
|
||||
reasoning_effort (str): OpenRouter reasoning effort level: "xhigh", "high", "medium", "low", "minimal", "none" (default: "medium")
|
||||
reasoning_disabled (bool): Completely disable reasoning/thinking tokens (default: False)
|
||||
prefill_messages_file (str): Path to JSON file containing prefill messages (list of {role, content} dicts)
|
||||
max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)
|
||||
@@ -1216,7 +1216,7 @@ def main(
|
||||
providers_order_list = [p.strip() for p in providers_order.split(",")] if providers_order else None
|
||||
|
||||
# Build reasoning_config from CLI flags
|
||||
# --reasoning_disabled takes priority, then --reasoning_effort, then default (xhigh)
|
||||
# --reasoning_disabled takes priority, then --reasoning_effort, then default (medium)
|
||||
reasoning_config = None
|
||||
if reasoning_disabled:
|
||||
# Completely disable reasoning/thinking tokens
|
||||
|
||||
@@ -13,6 +13,10 @@ model:
|
||||
# "auto" - Use Nous Portal if logged in, otherwise OpenRouter/env vars (default)
|
||||
# "openrouter" - Always use OpenRouter API key from OPENROUTER_API_KEY
|
||||
# "nous" - Always use Nous Portal (requires: hermes login)
|
||||
# "zai" - Use z.ai / ZhipuAI GLM models (requires: GLM_API_KEY)
|
||||
# "kimi-coding"- Use Kimi / Moonshot AI models (requires: KIMI_API_KEY)
|
||||
# "minimax" - Use MiniMax global endpoint (requires: MINIMAX_API_KEY)
|
||||
# "minimax-cn" - Use MiniMax China endpoint (requires: MINIMAX_CN_API_KEY)
|
||||
# Can also be overridden with --provider flag or HERMES_INFERENCE_PROVIDER env var.
|
||||
provider: "auto"
|
||||
|
||||
@@ -46,6 +50,16 @@ model:
|
||||
# # Data policy: "allow" (default) or "deny" to exclude providers that may store data
|
||||
# # data_collection: "deny"
|
||||
|
||||
# =============================================================================
|
||||
# Git Worktree Isolation
|
||||
# =============================================================================
|
||||
# When enabled, each CLI session creates an isolated git worktree so multiple
|
||||
# agents can work on the same repo concurrently without file collisions.
|
||||
# Equivalent to always passing --worktree / -w on the command line.
|
||||
#
|
||||
# worktree: true # Always create a worktree when in a git repo
|
||||
# worktree: false # Default — only create when -w flag is passed
|
||||
|
||||
# =============================================================================
|
||||
# Terminal Tool Configuration
|
||||
# =============================================================================
|
||||
@@ -195,8 +209,58 @@ compression:
|
||||
threshold: 0.85
|
||||
|
||||
# Model to use for generating summaries (fast/cheap recommended)
|
||||
# This model compresses the middle turns into a concise summary
|
||||
# This model compresses the middle turns into a concise summary.
|
||||
# IMPORTANT: it receives the full middle section of the conversation, so it
|
||||
# MUST support a context length at least as large as your main model's.
|
||||
summary_model: "google/gemini-3-flash-preview"
|
||||
|
||||
# Provider for the summary model (default: "auto")
|
||||
# Options: "auto", "openrouter", "nous", "main"
|
||||
# summary_provider: "auto"
|
||||
|
||||
# =============================================================================
|
||||
# Auxiliary Models (Advanced — Experimental)
|
||||
# =============================================================================
|
||||
# Hermes uses lightweight "auxiliary" models for side tasks: image analysis,
|
||||
# browser screenshot analysis, web page summarization, and context compression.
|
||||
#
|
||||
# By default these use Gemini Flash via OpenRouter or Nous Portal and are
|
||||
# auto-detected from your credentials. You do NOT need to change anything
|
||||
# here for normal usage.
|
||||
#
|
||||
# WARNING: Overriding these with providers other than OpenRouter or Nous Portal
|
||||
# is EXPERIMENTAL and may not work. Not all models/providers support vision,
|
||||
# produce usable summaries, or accept the same API format. Change at your own
|
||||
# risk — if things break, reset to "auto" / empty values.
|
||||
#
|
||||
# Each task has its own provider + model pair so you can mix providers.
|
||||
# For example: OpenRouter for vision (needs multimodal), but your main
|
||||
# local endpoint for compression (just needs text).
|
||||
#
|
||||
# Provider options:
|
||||
# "auto" - Best available: OpenRouter → Nous Portal → main endpoint (default)
|
||||
# "openrouter" - Force OpenRouter (requires OPENROUTER_API_KEY)
|
||||
# "nous" - Force Nous Portal (requires: hermes login)
|
||||
# "codex" - Force Codex OAuth (requires: hermes model → Codex).
|
||||
# Uses gpt-5.3-codex which supports vision.
|
||||
# "main" - Use your custom endpoint (OPENAI_BASE_URL + OPENAI_API_KEY).
|
||||
# Works with OpenAI API, local models, or any OpenAI-compatible
|
||||
# endpoint. Also falls back to Codex OAuth and API-key providers.
|
||||
#
|
||||
# Model: leave empty to use the provider's default. When empty, OpenRouter
|
||||
# uses "google/gemini-3-flash-preview" and Nous uses "gemini-3-flash".
|
||||
# Other providers pick a sensible default automatically.
|
||||
#
|
||||
# auxiliary:
|
||||
# # Image analysis: vision_analyze tool + browser screenshots
|
||||
# vision:
|
||||
# provider: "auto"
|
||||
# model: "" # e.g. "google/gemini-2.5-flash", "openai/gpt-4o"
|
||||
#
|
||||
# # Web page scraping / summarization + browser page text extraction
|
||||
# web_extract:
|
||||
# provider: "auto"
|
||||
# model: ""
|
||||
|
||||
# =============================================================================
|
||||
# Persistent Memory
|
||||
@@ -281,7 +345,7 @@ agent:
|
||||
# Reasoning effort level (OpenRouter and Nous Portal)
|
||||
# Controls how much "thinking" the model does before responding.
|
||||
# Options: "xhigh" (max), "high", "medium", "low", "minimal", "none" (disable)
|
||||
reasoning_effort: "xhigh"
|
||||
reasoning_effort: "medium"
|
||||
|
||||
# Predefined personalities (use with /personality command)
|
||||
personalities:
|
||||
@@ -571,3 +635,8 @@ display:
|
||||
# verbose: Full args, results, and debug logs (same as /verbose)
|
||||
# Toggle at runtime with /verbose in the CLI
|
||||
tool_progress: all
|
||||
|
||||
# Play terminal bell when agent finishes a response.
|
||||
# Useful for long-running tasks — your terminal will ding when the agent is done.
|
||||
# Works over SSH. Most terminals can be configured to flash the taskbar or play a sound.
|
||||
bell_on_complete: false
|
||||
|
||||
45
cron/jobs.py
45
cron/jobs.py
@@ -14,6 +14,8 @@ from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, List, Any
|
||||
|
||||
from hermes_time import now as _hermes_now
|
||||
|
||||
try:
|
||||
from croniter import croniter
|
||||
HAS_CRONITER = True
|
||||
@@ -128,7 +130,7 @@ def parse_schedule(schedule: str) -> Dict[str, Any]:
|
||||
# Duration like "30m", "2h", "1d" → one-shot from now
|
||||
try:
|
||||
minutes = parse_duration(schedule)
|
||||
run_at = datetime.now() + timedelta(minutes=minutes)
|
||||
run_at = _hermes_now() + timedelta(minutes=minutes)
|
||||
return {
|
||||
"kind": "once",
|
||||
"run_at": run_at.isoformat(),
|
||||
@@ -146,37 +148,50 @@ def parse_schedule(schedule: str) -> Dict[str, Any]:
|
||||
)
|
||||
|
||||
|
||||
def _ensure_aware(dt: datetime) -> datetime:
|
||||
"""Make a naive datetime tz-aware using the configured timezone.
|
||||
|
||||
Handles backward compatibility: timestamps stored before timezone support
|
||||
are naive (server-local). We assume they were in the same timezone as
|
||||
the current configuration so comparisons work without crashing.
|
||||
"""
|
||||
if dt.tzinfo is None:
|
||||
tz = _hermes_now().tzinfo
|
||||
return dt.replace(tzinfo=tz)
|
||||
return dt
|
||||
|
||||
|
||||
def compute_next_run(schedule: Dict[str, Any], last_run_at: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Compute the next run time for a schedule.
|
||||
|
||||
|
||||
Returns ISO timestamp string, or None if no more runs.
|
||||
"""
|
||||
now = datetime.now()
|
||||
|
||||
now = _hermes_now()
|
||||
|
||||
if schedule["kind"] == "once":
|
||||
run_at = datetime.fromisoformat(schedule["run_at"])
|
||||
run_at = _ensure_aware(datetime.fromisoformat(schedule["run_at"]))
|
||||
# If in the future, return it; if in the past, no more runs
|
||||
return schedule["run_at"] if run_at > now else None
|
||||
|
||||
|
||||
elif schedule["kind"] == "interval":
|
||||
minutes = schedule["minutes"]
|
||||
if last_run_at:
|
||||
# Next run is last_run + interval
|
||||
last = datetime.fromisoformat(last_run_at)
|
||||
last = _ensure_aware(datetime.fromisoformat(last_run_at))
|
||||
next_run = last + timedelta(minutes=minutes)
|
||||
else:
|
||||
# First run is now + interval
|
||||
next_run = now + timedelta(minutes=minutes)
|
||||
return next_run.isoformat()
|
||||
|
||||
|
||||
elif schedule["kind"] == "cron":
|
||||
if not HAS_CRONITER:
|
||||
return None
|
||||
cron = croniter(schedule["expr"], now)
|
||||
next_run = cron.get_next(datetime)
|
||||
return next_run.isoformat()
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -204,7 +219,7 @@ def save_jobs(jobs: List[Dict[str, Any]]):
|
||||
fd, tmp_path = tempfile.mkstemp(dir=str(JOBS_FILE.parent), suffix='.tmp', prefix='.jobs_')
|
||||
try:
|
||||
with os.fdopen(fd, 'w', encoding='utf-8') as f:
|
||||
json.dump({"jobs": jobs, "updated_at": datetime.now().isoformat()}, f, indent=2)
|
||||
json.dump({"jobs": jobs, "updated_at": _hermes_now().isoformat()}, f, indent=2)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.replace(tmp_path, JOBS_FILE)
|
||||
@@ -249,7 +264,7 @@ def create_job(
|
||||
deliver = "origin" if origin else "local"
|
||||
|
||||
job_id = uuid.uuid4().hex[:12]
|
||||
now = datetime.now().isoformat()
|
||||
now = _hermes_now().isoformat()
|
||||
|
||||
job = {
|
||||
"id": job_id,
|
||||
@@ -328,7 +343,7 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
|
||||
jobs = load_jobs()
|
||||
for i, job in enumerate(jobs):
|
||||
if job["id"] == job_id:
|
||||
now = datetime.now().isoformat()
|
||||
now = _hermes_now().isoformat()
|
||||
job["last_run_at"] = now
|
||||
job["last_status"] = "ok" if success else "error"
|
||||
job["last_error"] = error if not success else None
|
||||
@@ -361,7 +376,7 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
|
||||
|
||||
def get_due_jobs() -> List[Dict[str, Any]]:
|
||||
"""Get all jobs that are due to run now."""
|
||||
now = datetime.now()
|
||||
now = _hermes_now()
|
||||
jobs = load_jobs()
|
||||
due = []
|
||||
|
||||
@@ -373,7 +388,7 @@ def get_due_jobs() -> List[Dict[str, Any]]:
|
||||
if not next_run:
|
||||
continue
|
||||
|
||||
next_run_dt = datetime.fromisoformat(next_run)
|
||||
next_run_dt = _ensure_aware(datetime.fromisoformat(next_run))
|
||||
if next_run_dt <= now:
|
||||
due.append(job)
|
||||
|
||||
@@ -386,7 +401,7 @@ def save_job_output(job_id: str, output: str):
|
||||
job_output_dir = OUTPUT_DIR / job_id
|
||||
job_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
timestamp = _hermes_now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
output_file = job_output_dir / f"{timestamp}.md"
|
||||
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
|
||||
@@ -27,6 +27,8 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from hermes_time import now as _hermes_now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add parent directory to path for imports
|
||||
@@ -96,6 +98,7 @@ def _deliver_result(job: dict, content: str) -> None:
|
||||
"discord": Platform.DISCORD,
|
||||
"slack": Platform.SLACK,
|
||||
"whatsapp": Platform.WHATSAPP,
|
||||
"signal": Platform.SIGNAL,
|
||||
}
|
||||
platform = platform_map.get(platform_name.lower())
|
||||
if not platform:
|
||||
@@ -174,6 +177,8 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
|
||||
model = os.getenv("HERMES_MODEL") or os.getenv("LLM_MODEL") or "anthropic/claude-opus-4.6"
|
||||
|
||||
# Load config.yaml for model, reasoning, prefill, toolsets, provider routing
|
||||
_cfg = {}
|
||||
try:
|
||||
import yaml
|
||||
_cfg_path = str(_hermes_home / "config.yaml")
|
||||
@@ -188,6 +193,41 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Reasoning config from env or config.yaml
|
||||
reasoning_config = None
|
||||
effort = os.getenv("HERMES_REASONING_EFFORT", "")
|
||||
if not effort:
|
||||
effort = str(_cfg.get("agent", {}).get("reasoning_effort", "")).strip()
|
||||
if effort and effort.lower() != "none":
|
||||
valid = ("xhigh", "high", "medium", "low", "minimal")
|
||||
if effort.lower() in valid:
|
||||
reasoning_config = {"enabled": True, "effort": effort.lower()}
|
||||
elif effort.lower() == "none":
|
||||
reasoning_config = {"enabled": False}
|
||||
|
||||
# Prefill messages from env or config.yaml
|
||||
prefill_messages = None
|
||||
prefill_file = os.getenv("HERMES_PREFILL_MESSAGES_FILE", "") or _cfg.get("prefill_messages_file", "")
|
||||
if prefill_file:
|
||||
import json as _json
|
||||
pfpath = Path(prefill_file).expanduser()
|
||||
if not pfpath.is_absolute():
|
||||
pfpath = _hermes_home / pfpath
|
||||
if pfpath.exists():
|
||||
try:
|
||||
with open(pfpath, "r", encoding="utf-8") as _pf:
|
||||
prefill_messages = _json.load(_pf)
|
||||
if not isinstance(prefill_messages, list):
|
||||
prefill_messages = None
|
||||
except Exception:
|
||||
prefill_messages = None
|
||||
|
||||
# Max iterations
|
||||
max_iterations = _cfg.get("agent", {}).get("max_turns") or _cfg.get("max_turns") or 90
|
||||
|
||||
# Provider routing
|
||||
pr = _cfg.get("provider_routing", {})
|
||||
|
||||
from hermes_cli.runtime_provider import (
|
||||
resolve_runtime_provider,
|
||||
format_runtime_provider_error,
|
||||
@@ -206,8 +246,15 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
base_url=runtime.get("base_url"),
|
||||
provider=runtime.get("provider"),
|
||||
api_mode=runtime.get("api_mode"),
|
||||
max_iterations=max_iterations,
|
||||
reasoning_config=reasoning_config,
|
||||
prefill_messages=prefill_messages,
|
||||
providers_allowed=pr.get("only"),
|
||||
providers_ignored=pr.get("ignore"),
|
||||
providers_order=pr.get("order"),
|
||||
provider_sort=pr.get("sort"),
|
||||
quiet_mode=True,
|
||||
session_id=f"cron_{job_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
session_id=f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}"
|
||||
)
|
||||
|
||||
result = agent.run_conversation(prompt)
|
||||
@@ -219,7 +266,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
output = f"""# Cron Job: {job_name}
|
||||
|
||||
**Job ID:** {job_id}
|
||||
**Run Time:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
**Schedule:** {job.get('schedule_display', 'N/A')}
|
||||
|
||||
## Prompt
|
||||
@@ -241,7 +288,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
output = f"""# Cron Job: {job_name} (FAILED)
|
||||
|
||||
**Job ID:** {job_id}
|
||||
**Run Time:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
**Schedule:** {job.get('schedule_display', 'N/A')}
|
||||
|
||||
## Prompt
|
||||
@@ -297,11 +344,11 @@ def tick(verbose: bool = True) -> int:
|
||||
due_jobs = get_due_jobs()
|
||||
|
||||
if verbose and not due_jobs:
|
||||
logger.info("%s - No jobs due", datetime.now().strftime('%H:%M:%S'))
|
||||
logger.info("%s - No jobs due", _hermes_now().strftime('%H:%M:%S'))
|
||||
return 0
|
||||
|
||||
if verbose:
|
||||
logger.info("%s - %s job(s) due", datetime.now().strftime('%H:%M:%S'), len(due_jobs))
|
||||
logger.info("%s - %s job(s) due", _hermes_now().strftime('%H:%M:%S'), len(due_jobs))
|
||||
|
||||
executed = 0
|
||||
for job in due_jobs:
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
# Documentation
|
||||
|
||||
All documentation has moved to the website:
|
||||
|
||||
**📖 [hermes-agent.nousresearch.com/docs](https://hermes-agent.nousresearch.com/docs/)**
|
||||
|
||||
The documentation source files live in [`website/docs/`](../website/docs/).
|
||||
@@ -1,344 +0,0 @@
|
||||
# send_file Integration Map — Hermes Agent Codebase Deep Dive
|
||||
|
||||
## 1. environments/tool_context.py — Base64 File Transfer Implementation
|
||||
|
||||
### upload_file() (lines 153-205)
|
||||
- Reads local file as raw bytes, base64-encodes to ASCII string
|
||||
- Creates parent dirs in sandbox via `self.terminal(f"mkdir -p {parent}")`
|
||||
- **Chunk size:** 60,000 chars (~60KB per shell command)
|
||||
- **Small files (<=60KB b64):** Single `printf '%s' '{b64}' | base64 -d > {remote_path}`
|
||||
- **Large files:** Writes chunks to `/tmp/_hermes_upload.b64` via `printf >> append`, then `base64 -d` to target
|
||||
- **Error handling:** Checks local file exists; returns `{exit_code, output}`
|
||||
- **Size limits:** No explicit limit, but shell arg limit ~2MB means chunking is necessary for files >~45KB raw
|
||||
- **No theoretical max** — but very large files would be slow (many terminal round trips)
|
||||
|
||||
### download_file() (lines 234-278)
|
||||
- Runs `base64 {remote_path}` inside sandbox, captures stdout
|
||||
- Strips output, base64-decodes to raw bytes
|
||||
- Writes to host filesystem with parent dir creation
|
||||
- **Error handling:** Checks exit code, empty output, decode errors
|
||||
- Returns `{success: bool, bytes: int}` or `{success: false, error: str}`
|
||||
- **Size limit:** Bounded by terminal output buffer (practical limit ~few MB via base64 terminal output)
|
||||
|
||||
### Promotion potential:
|
||||
- These methods work via `self.terminal()` — they're environment-agnostic
|
||||
- Could be directly lifted into a new tool that operates on the agent's current sandbox
|
||||
- For send_file, this `download_file()` pattern is the key: it extracts files from sandbox → host
|
||||
|
||||
## 2. tools/environments/base.py — BaseEnvironment Interface
|
||||
|
||||
### Current methods:
|
||||
- `execute(command, cwd, timeout, stdin_data)` → `{output, returncode}`
|
||||
- `cleanup()` — release resources
|
||||
- `stop()` — alias for cleanup
|
||||
- `_prepare_command()` — sudo transformation
|
||||
- `_build_run_kwargs()` — subprocess kwargs
|
||||
- `_timeout_result()` — standard timeout dict
|
||||
|
||||
### What would need to be added for file transfer:
|
||||
- **Nothing required at this level.** File transfer can be implemented via `execute()` (base64 over terminal, like ToolContext does) or via environment-specific methods.
|
||||
- Optional: `upload_file(local_path, remote_path)` and `download_file(remote_path, local_path)` methods could be added to BaseEnvironment for optimized per-backend transfers, but the base64-over-terminal approach already works universally.
|
||||
|
||||
## 3. tools/environments/docker.py — Docker Container Details
|
||||
|
||||
### Container ID tracking:
|
||||
- `self._container_id` stored at init from `self._inner.container_id`
|
||||
- Inner is `minisweagent.environments.docker.DockerEnvironment`
|
||||
- Container ID is a standard Docker container hash
|
||||
|
||||
### docker cp feasibility:
|
||||
- **YES**, `docker cp` could be used for optimized file transfer:
|
||||
- `docker cp {container_id}:{remote_path} {local_path}` (download)
|
||||
- `docker cp {local_path} {container_id}:{remote_path}` (upload)
|
||||
- Much faster than base64-over-terminal for large files
|
||||
- Container ID is directly accessible via `env._container_id` or `env._inner.container_id`
|
||||
|
||||
### Volumes mounted:
|
||||
- **Persistent mode:** Bind mounts at `~/.hermes/sandboxes/docker/{task_id}/workspace` → `/workspace` and `.../home` → `/root`
|
||||
- **Ephemeral mode:** tmpfs at `/workspace` (10GB), `/home` (1GB), `/root` (1GB)
|
||||
- **User volumes:** From `config.yaml docker_volumes` (arbitrary `-v` mounts)
|
||||
- **Security tmpfs:** `/tmp` (512MB), `/var/tmp` (256MB), `/run` (64MB)
|
||||
|
||||
### Direct host access for persistent mode:
|
||||
- If persistent, files at `/workspace/foo.txt` are just `~/.hermes/sandboxes/docker/{task_id}/workspace/foo.txt` on host — no transfer needed!
|
||||
|
||||
## 4. tools/environments/ssh.py — SSH Connection Management
|
||||
|
||||
### Connection management:
|
||||
- Uses SSH ControlMaster for persistent connection
|
||||
- Control socket at `/tmp/hermes-ssh/{user}@{host}:{port}.sock`
|
||||
- ControlPersist=300 (5 min keepalive)
|
||||
- BatchMode=yes (non-interactive)
|
||||
- Stores: `self.host`, `self.user`, `self.port`, `self.key_path`
|
||||
|
||||
### SCP/SFTP feasibility:
|
||||
- **YES**, SCP can piggyback on the ControlMaster socket:
|
||||
- `scp -o ControlPath={socket} {user}@{host}:{remote} {local}` (download)
|
||||
- `scp -o ControlPath={socket} {local} {user}@{host}:{remote}` (upload)
|
||||
- Same SSH key and connection reuse — zero additional auth
|
||||
- Would be much faster than base64-over-terminal for large files
|
||||
|
||||
## 5. tools/environments/modal.py — Modal Sandbox Filesystem
|
||||
|
||||
### Filesystem API exposure:
|
||||
- **Not directly.** The inner `SwerexModalEnvironment` wraps Modal's sandbox
|
||||
- The sandbox object is accessible at: `env._inner.deployment._sandbox`
|
||||
- Modal's Python SDK exposes `sandbox.open()` for file I/O — but only via async API
|
||||
- Currently only used for `snapshot_filesystem()` during cleanup
|
||||
- **Could use:** `sandbox.open(path, "rb")` to read files or `sandbox.open(path, "wb")` to write
|
||||
- **Alternative:** Base64-over-terminal already works via `execute()` — simpler, no SDK dependency
|
||||
|
||||
## 6. gateway/platforms/base.py — MEDIA: Tag Flow (Complete)
|
||||
|
||||
### extract_media() (lines 587-620):
|
||||
- **Pattern:** `MEDIA:\S+` — extracts file paths after MEDIA: prefix
|
||||
- **Voice flag:** `[[audio_as_voice]]` global directive sets `is_voice=True` for all media in message
|
||||
- Returns `List[Tuple[str, bool]]` (path, is_voice) and cleaned content
|
||||
|
||||
### _process_message_background() media routing (lines 752-786):
|
||||
- After extracting MEDIA tags, routes by file extension:
|
||||
- `.ogg .opus .mp3 .wav .m4a` → `send_voice()`
|
||||
- `.mp4 .mov .avi .mkv .3gp` → `send_video()`
|
||||
- `.jpg .jpeg .png .webp .gif` → `send_image_file()`
|
||||
- **Everything else** → `send_document()`
|
||||
- This routing already supports arbitrary files!
|
||||
|
||||
### send_* method inventory (base class):
|
||||
- `send(chat_id, content, reply_to, metadata)` — ABSTRACT, text
|
||||
- `send_image(chat_id, image_url, caption, reply_to)` — URL-based images
|
||||
- `send_animation(chat_id, animation_url, caption, reply_to)` — GIF animations
|
||||
- `send_voice(chat_id, audio_path, caption, reply_to)` — voice messages
|
||||
- `send_video(chat_id, video_path, caption, reply_to)` — video files
|
||||
- `send_document(chat_id, file_path, caption, file_name, reply_to)` — generic files
|
||||
- `send_image_file(chat_id, image_path, caption, reply_to)` — local image files
|
||||
- `send_typing(chat_id)` — typing indicator
|
||||
- `edit_message(chat_id, message_id, content)` — edit sent messages
|
||||
|
||||
### What's missing:
|
||||
- **Telegram:** No override for `send_document` or `send_image_file` — falls back to text!
|
||||
- **Discord:** No override for `send_document` — falls back to text!
|
||||
- **WhatsApp:** Has `send_document` and `send_image_file` via bridge — COMPLETE.
|
||||
- The base class defaults just send "📎 File: /path" as text — useless for actual file delivery.
|
||||
|
||||
## 7. gateway/platforms/telegram.py — Send Method Analysis
|
||||
|
||||
### Implemented send methods:
|
||||
- `send()` — MarkdownV2 text with fallback to plain
|
||||
- `send_voice()` — `.ogg`/`.opus` as `send_voice()`, others as `send_audio()`
|
||||
- `send_image()` — URL-based via `send_photo()`
|
||||
- `send_animation()` — GIF via `send_animation()`
|
||||
- `send_typing()` — "typing" chat action
|
||||
- `edit_message()` — edit text messages
|
||||
|
||||
### MISSING:
|
||||
- **`send_document()` NOT overridden** — Need to add `self._bot.send_document(chat_id, document=open(file_path, 'rb'), ...)`
|
||||
- **`send_image_file()` NOT overridden** — Need to add `self._bot.send_photo(chat_id, photo=open(path, 'rb'), ...)`
|
||||
- **`send_video()` NOT overridden** — Need to add `self._bot.send_video(...)`
|
||||
|
||||
## 8. gateway/platforms/discord.py — Send Method Analysis
|
||||
|
||||
### Implemented send methods:
|
||||
- `send()` — text messages with chunking
|
||||
- `send_voice()` — discord.File attachment
|
||||
- `send_image()` — downloads URL, creates discord.File attachment
|
||||
- `send_typing()` — channel.typing()
|
||||
- `edit_message()` — edit text messages
|
||||
|
||||
### MISSING:
|
||||
- **`send_document()` NOT overridden** — Need to add discord.File attachment
|
||||
- **`send_image_file()` NOT overridden** — Need to add discord.File from local path
|
||||
- **`send_video()` NOT overridden** — Need to add discord.File attachment
|
||||
|
||||
## 9. gateway/run.py — User File Attachment Handling
|
||||
|
||||
### Current attachment flow:
|
||||
1. **Telegram photos** (line 509-529): Download via `photo.get_file()` → `cache_image_from_bytes()` → vision auto-analysis
|
||||
2. **Telegram voice** (line 532-541): Download → `cache_audio_from_bytes()` → STT transcription
|
||||
3. **Telegram audio** (line 542-551): Same pattern
|
||||
4. **Telegram documents** (line 553-617): Extension validation against `SUPPORTED_DOCUMENT_TYPES`, 20MB limit, content injection for text files
|
||||
5. **Discord attachments** (line 717-751): Content-type detection, image/audio caching, URL fallback for other types
|
||||
6. **Gateway run.py** (lines 818-883): Auto-analyzes images with vision, transcribes audio, enriches document messages with context notes
|
||||
|
||||
### Key insight: Files are always cached to host filesystem first, then processed. The agent sees local file paths.
|
||||
|
||||
## 10. tools/terminal_tool.py — Terminal Tool & Environment Interaction
|
||||
|
||||
### How it manages environments:
|
||||
- Global dict `_active_environments: Dict[str, Any]` keyed by task_id
|
||||
- Per-task creation locks prevent duplicate sandbox creation
|
||||
- Auto-cleanup thread kills idle environments after `TERMINAL_LIFETIME_SECONDS`
|
||||
- `_get_env_config()` reads all TERMINAL_* env vars for backend selection
|
||||
- `_create_environment()` factory creates the right backend type
|
||||
|
||||
### Could send_file piggyback?
|
||||
- **YES.** send_file needs access to the same environment to extract files from sandboxes.
|
||||
- It can reuse `_active_environments[task_id]` to get the environment, then:
|
||||
- Docker: Use `docker cp` via `env._container_id`
|
||||
- SSH: Use `scp` via `env.control_socket`
|
||||
- Local: Just read the file directly
|
||||
- Modal: Use base64-over-terminal via `env.execute()`
|
||||
- The file_tools.py module already does this with `ShellFileOperations` — read_file/write_file/search/patch all share the same env instance.
|
||||
|
||||
## 11. tools/tts_tool.py — Working Example of File Delivery
|
||||
|
||||
### Flow:
|
||||
1. Generate audio file to `~/.hermes/audio_cache/tts_TIMESTAMP.{ogg,mp3}`
|
||||
2. Return JSON with `media_tag: "MEDIA:/path/to/file"`
|
||||
3. For Telegram voice: prepend `[[audio_as_voice]]` directive
|
||||
4. The LLM includes the MEDIA tag in its response text
|
||||
5. `BasePlatformAdapter._process_message_background()` calls `extract_media()` to find the tag
|
||||
6. Routes by extension → `send_voice()` for audio files
|
||||
7. Platform adapter sends the file natively
|
||||
|
||||
### Key pattern: Tool saves file to host → returns MEDIA: path → LLM echoes it → gateway extracts → platform delivers
|
||||
|
||||
## 12. tools/image_generation_tool.py — Working Example of Image Delivery
|
||||
|
||||
### Flow:
|
||||
1. Call FAL.ai API → get image URL
|
||||
2. Return JSON with `image: "https://fal.media/..."` URL
|
||||
3. The LLM includes the URL in markdown: ``
|
||||
4. `BasePlatformAdapter.extract_images()` finds `` patterns
|
||||
5. Routes through `send_image()` (URL) or `send_animation()` (GIF)
|
||||
6. Platform downloads and sends natively
|
||||
|
||||
### Key difference from TTS: Images are URL-based, not local files. The gateway downloads at send time.
|
||||
|
||||
---
|
||||
|
||||
# INTEGRATION MAP: Where send_file Hooks In
|
||||
|
||||
## Architecture Decision: MEDIA: Tag Protocol vs. New Tool
|
||||
|
||||
The MEDIA: tag protocol is already the established pattern for file delivery. Two options:
|
||||
|
||||
### Option A: Pure MEDIA: Tag (Minimal Change)
|
||||
- No new tool needed
|
||||
- Agent downloads file from sandbox to host using terminal (base64)
|
||||
- Saves to known location (e.g., `~/.hermes/file_cache/`)
|
||||
- Includes `MEDIA:/path` in response text
|
||||
- Existing routing in `_process_message_background()` handles delivery
|
||||
- **Problem:** Agent has to manually do base64 dance + know about MEDIA: convention
|
||||
|
||||
### Option B: Dedicated send_file Tool (Recommended)
|
||||
- New tool that the agent calls with `(file_path, caption?)`
|
||||
- Tool handles the sandbox → host extraction automatically
|
||||
- Returns MEDIA: tag that gets routed through existing pipeline
|
||||
- Much cleaner agent experience
|
||||
|
||||
## Implementation Plan for Option B
|
||||
|
||||
### Files to CREATE:
|
||||
|
||||
1. **`tools/send_file_tool.py`** — The new tool
|
||||
- Accepts: `file_path` (path in sandbox), `caption` (optional)
|
||||
- Detects environment backend from `_active_environments`
|
||||
- Extracts file from sandbox:
|
||||
- **local:** `shutil.copy()` or direct path
|
||||
- **docker:** `docker cp {container_id}:{path} {local_cache}/`
|
||||
- **ssh:** `scp -o ControlPath=... {user}@{host}:{path} {local_cache}/`
|
||||
- **modal:** base64-over-terminal via `env.execute("base64 {path}")`
|
||||
- Saves to `~/.hermes/file_cache/{uuid}_{filename}`
|
||||
- Returns: `MEDIA:/cached/path` in response for gateway to pick up
|
||||
- Register with `registry.register(name="send_file", toolset="file", ...)`
|
||||
|
||||
### Files to MODIFY:
|
||||
|
||||
2. **`gateway/platforms/telegram.py`** — Add missing send methods:
|
||||
```python
|
||||
async def send_document(self, chat_id, file_path, caption=None, file_name=None, reply_to=None):
|
||||
with open(file_path, "rb") as f:
|
||||
msg = await self._bot.send_document(
|
||||
chat_id=int(chat_id), document=f,
|
||||
caption=caption, filename=file_name or os.path.basename(file_path))
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
|
||||
async def send_image_file(self, chat_id, image_path, caption=None, reply_to=None):
|
||||
with open(image_path, "rb") as f:
|
||||
msg = await self._bot.send_photo(chat_id=int(chat_id), photo=f, caption=caption)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
|
||||
async def send_video(self, chat_id, video_path, caption=None, reply_to=None):
|
||||
with open(video_path, "rb") as f:
|
||||
msg = await self._bot.send_video(chat_id=int(chat_id), video=f, caption=caption)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
```
|
||||
|
||||
3. **`gateway/platforms/discord.py`** — Add missing send methods:
|
||||
```python
|
||||
async def send_document(self, chat_id, file_path, caption=None, file_name=None, reply_to=None):
|
||||
channel = self._client.get_channel(int(chat_id)) or await self._client.fetch_channel(int(chat_id))
|
||||
with open(file_path, "rb") as f:
|
||||
file = discord.File(io.BytesIO(f.read()), filename=file_name or os.path.basename(file_path))
|
||||
msg = await channel.send(content=caption, file=file)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
async def send_image_file(self, chat_id, image_path, caption=None, reply_to=None):
|
||||
# Same pattern as send_document with image filename
|
||||
|
||||
async def send_video(self, chat_id, video_path, caption=None, reply_to=None):
|
||||
# Same pattern, discord renders video attachments inline
|
||||
```
|
||||
|
||||
4. **`toolsets.py`** — Add `"send_file"` to `_HERMES_CORE_TOOLS` list
|
||||
|
||||
5. **`agent/prompt_builder.py`** — Update platform hints to mention send_file tool
|
||||
|
||||
### Code that can be REUSED (zero rewrite):
|
||||
|
||||
- `BasePlatformAdapter.extract_media()` — Already extracts MEDIA: tags
|
||||
- `BasePlatformAdapter._process_message_background()` — Already routes by extension
|
||||
- `ToolContext.download_file()` — Base64-over-terminal extraction pattern
|
||||
- `tools/terminal_tool.py` _active_environments dict — Environment access
|
||||
- `tools/registry.py` — Tool registration infrastructure
|
||||
- `gateway/platforms/base.py` send_document/send_image_file/send_video signatures — Already defined
|
||||
|
||||
### Code that needs to be WRITTEN from scratch:
|
||||
|
||||
1. `tools/send_file_tool.py` (~150 lines):
|
||||
- File extraction from each environment backend type
|
||||
- Local file cache management
|
||||
- Registry registration
|
||||
|
||||
2. Telegram `send_document` + `send_image_file` + `send_video` overrides (~40 lines)
|
||||
3. Discord `send_document` + `send_image_file` + `send_video` overrides (~50 lines)
|
||||
|
||||
### Total effort: ~240 lines of new code, ~5 lines of config changes
|
||||
|
||||
## Key Environment-Specific Extract Strategies
|
||||
|
||||
| Backend | Extract Method | Speed | Complexity |
|
||||
|------------|-------------------------------|----------|------------|
|
||||
| local | shutil.copy / direct path | Instant | None |
|
||||
| docker | `docker cp container:path .` | Fast | Low |
|
||||
| docker+vol | Direct host path access | Instant | None |
|
||||
| ssh | `scp -o ControlPath=...` | Fast | Low |
|
||||
| modal | base64-over-terminal | Moderate | Medium |
|
||||
| singularity| Direct path (overlay mount) | Fast | Low |
|
||||
|
||||
## Data Flow Summary
|
||||
|
||||
```
|
||||
Agent calls send_file(file_path="/workspace/output.pdf", caption="Here's the report")
|
||||
│
|
||||
▼
|
||||
send_file_tool.py:
|
||||
1. Get environment from _active_environments[task_id]
|
||||
2. Detect backend type (docker/ssh/modal/local)
|
||||
3. Extract file to ~/.hermes/file_cache/{uuid}_{filename}
|
||||
4. Return: '{"success": true, "media_tag": "MEDIA:/home/user/.hermes/file_cache/abc123_output.pdf"}'
|
||||
│
|
||||
▼
|
||||
LLM includes MEDIA: tag in its response text
|
||||
│
|
||||
▼
|
||||
BasePlatformAdapter._process_message_background():
|
||||
1. extract_media(response) → finds MEDIA:/path
|
||||
2. Checks extension: .pdf → send_document()
|
||||
3. Calls platform-specific send_document(chat_id, file_path, caption)
|
||||
│
|
||||
▼
|
||||
TelegramAdapter.send_document() / DiscordAdapter.send_document():
|
||||
Opens file, sends via platform API as native document attachment
|
||||
User receives downloadable file in chat
|
||||
```
|
||||
@@ -195,8 +195,12 @@ environments/
|
||||
│ └── hermes_swe_env.py
|
||||
│
|
||||
└── benchmarks/ # Evaluation benchmarks
|
||||
└── terminalbench_2/
|
||||
└── terminalbench2_env.py
|
||||
├── terminalbench_2/ # 89 terminal tasks, Modal sandboxes
|
||||
│ └── terminalbench2_env.py
|
||||
├── tblite/ # 100 calibrated tasks (fast TB2 proxy)
|
||||
│ └── tblite_env.py
|
||||
└── yc_bench/ # Long-horizon strategic benchmark
|
||||
└── yc_bench_env.py
|
||||
```
|
||||
|
||||
## Concrete Environments
|
||||
|
||||
115
environments/benchmarks/yc_bench/README.md
Normal file
115
environments/benchmarks/yc_bench/README.md
Normal file
@@ -0,0 +1,115 @@
|
||||
# YC-Bench: Long-Horizon Agent Benchmark
|
||||
|
||||
[YC-Bench](https://github.com/collinear-ai/yc-bench) by [Collinear AI](https://collinear.ai/) is a deterministic, long-horizon benchmark that tests LLM agents' ability to act as a tech startup CEO. The agent manages a simulated company over 1-3 years, making compounding decisions about resource allocation, cash flow, task management, and prestige specialisation across 4 skill domains.
|
||||
|
||||
Unlike TerminalBench2 (which evaluates per-task coding ability with binary pass/fail), YC-Bench measures **long-term strategic coherence** — whether an agent can maintain consistent strategy, manage compounding consequences, and adapt plans over hundreds of turns.
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
# Install yc-bench (optional dependency)
|
||||
pip install "hermes-agent[yc-bench]"
|
||||
|
||||
# Or install from source
|
||||
git clone https://github.com/collinear-ai/yc-bench
|
||||
cd yc-bench && pip install -e .
|
||||
|
||||
# Verify
|
||||
yc-bench --help
|
||||
```
|
||||
|
||||
## Running
|
||||
|
||||
```bash
|
||||
# From the repo root:
|
||||
bash environments/benchmarks/yc_bench/run_eval.sh
|
||||
|
||||
# Or directly:
|
||||
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
--config environments/benchmarks/yc_bench/default.yaml
|
||||
|
||||
# Override model:
|
||||
bash environments/benchmarks/yc_bench/run_eval.sh \
|
||||
--openai.model_name anthropic/claude-opus-4-20250514
|
||||
|
||||
# Quick single-preset test:
|
||||
bash environments/benchmarks/yc_bench/run_eval.sh \
|
||||
--env.presets '["fast_test"]' --env.seeds '[1]'
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
### Architecture
|
||||
|
||||
```
|
||||
HermesAgentLoop (our agent)
|
||||
-> terminal tool -> subprocess("yc-bench company status") -> JSON output
|
||||
-> terminal tool -> subprocess("yc-bench task accept --task-id X") -> JSON
|
||||
-> terminal tool -> subprocess("yc-bench sim resume") -> JSON (advance time)
|
||||
-> ... (100-500 turns per run)
|
||||
```
|
||||
|
||||
The environment initialises the simulation via `yc-bench sim init` (NOT `yc-bench run`, which would start yc-bench's own built-in agent loop). Our `HermesAgentLoop` then drives all interaction through CLI commands.
|
||||
|
||||
### Simulation Mechanics
|
||||
|
||||
- **4 skill domains**: research, inference, data_environment, training
|
||||
- **Prestige system** (1.0-10.0): Gates access to higher-paying tasks
|
||||
- **Employee management**: Junior/Mid/Senior with domain-specific skill rates
|
||||
- **Throughput splitting**: `effective_rate = base_rate / N` active tasks per employee
|
||||
- **Financial pressure**: Monthly payroll, bankruptcy = game over
|
||||
- **Deterministic**: SHA256-based RNG — same seed + preset = same world
|
||||
|
||||
### Difficulty Presets
|
||||
|
||||
| Preset | Employees | Tasks | Focus |
|
||||
|-----------|-----------|-------|-------|
|
||||
| tutorial | 3 | 50 | Basic loop mechanics |
|
||||
| easy | 5 | 100 | Throughput awareness |
|
||||
| **medium**| 5 | 150 | Prestige climbing + domain specialisation |
|
||||
| **hard** | 7 | 200 | Precise ETA reasoning |
|
||||
| nightmare | 8 | 300 | Sustained perfection under payroll pressure |
|
||||
| fast_test | (varies) | (varies) | Quick validation (~50 turns) |
|
||||
|
||||
Default eval runs **fast_test + medium + hard** × 3 seeds = 9 runs.
|
||||
|
||||
### Scoring
|
||||
|
||||
```
|
||||
composite = 0.5 × survival + 0.5 × normalised_funds
|
||||
```
|
||||
|
||||
- **Survival** (binary): Did the company avoid bankruptcy?
|
||||
- **Normalised funds** (0.0-1.0): Log-scale relative to initial $250K capital
|
||||
|
||||
## Configuration
|
||||
|
||||
Key fields in `default.yaml`:
|
||||
|
||||
| Field | Default | Description |
|
||||
|-------|---------|-------------|
|
||||
| `presets` | `["fast_test", "medium", "hard"]` | Which presets to evaluate |
|
||||
| `seeds` | `[1, 2, 3]` | RNG seeds per preset |
|
||||
| `max_agent_turns` | 200 | Max LLM calls per run |
|
||||
| `run_timeout` | 3600 | Wall-clock timeout per run (seconds) |
|
||||
| `survival_weight` | 0.5 | Weight of survival in composite score |
|
||||
| `funds_weight` | 0.5 | Weight of normalised funds in composite |
|
||||
| `horizon_years` | null | Override horizon (null = auto from preset) |
|
||||
|
||||
## Cost & Time Estimates
|
||||
|
||||
Each run is 100-500 LLM turns. Approximate costs per run at typical API rates:
|
||||
|
||||
| Preset | Turns | Time | Est. Cost |
|
||||
|--------|-------|------|-----------|
|
||||
| fast_test | ~50 | 5-10 min | $1-5 |
|
||||
| medium | ~200 | 20-40 min | $5-15 |
|
||||
| hard | ~300 | 30-60 min | $10-25 |
|
||||
|
||||
Full default eval (9 runs): ~3-6 hours, $50-200 depending on model.
|
||||
|
||||
## References
|
||||
|
||||
- [collinear-ai/yc-bench](https://github.com/collinear-ai/yc-bench) — Official repository
|
||||
- [Collinear AI](https://collinear.ai/) — Company behind yc-bench
|
||||
- [TerminalBench2](../terminalbench_2/) — Per-task coding benchmark (complementary)
|
||||
0
environments/benchmarks/yc_bench/__init__.py
Normal file
0
environments/benchmarks/yc_bench/__init__.py
Normal file
43
environments/benchmarks/yc_bench/default.yaml
Normal file
43
environments/benchmarks/yc_bench/default.yaml
Normal file
@@ -0,0 +1,43 @@
|
||||
# YC-Bench Evaluation -- Default Configuration
|
||||
#
|
||||
# Long-horizon agent benchmark: agent plays CEO of an AI startup over
|
||||
# a simulated 1-3 year run, interacting via yc-bench CLI subcommands.
|
||||
#
|
||||
# Requires: pip install "hermes-agent[yc-bench]"
|
||||
#
|
||||
# Usage:
|
||||
# python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
# --config environments/benchmarks/yc_bench/default.yaml
|
||||
#
|
||||
# # Override model:
|
||||
# python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
# --config environments/benchmarks/yc_bench/default.yaml \
|
||||
# --openai.model_name anthropic/claude-opus-4-20250514
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal"]
|
||||
max_agent_turns: 200
|
||||
max_token_length: 32000
|
||||
agent_temperature: 0.0
|
||||
terminal_backend: "local"
|
||||
terminal_timeout: 60
|
||||
presets: ["fast_test", "medium", "hard"]
|
||||
seeds: [1, 2, 3]
|
||||
run_timeout: 3600 # 60 min wall-clock per run, auto-FAIL if exceeded
|
||||
survival_weight: 0.5 # weight of binary survival in composite score
|
||||
funds_weight: 0.5 # weight of normalised final funds in composite score
|
||||
db_dir: "/tmp/yc_bench_dbs"
|
||||
company_name: "BenchCo"
|
||||
start_date: "01/01/2025" # MM/DD/YYYY (yc-bench convention)
|
||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
use_wandb: true
|
||||
wandb_name: "yc-bench"
|
||||
ensure_scores_are_not_same: false
|
||||
data_dir_to_save_evals: "environments/benchmarks/evals/yc-bench"
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-sonnet-4.6"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
34
environments/benchmarks/yc_bench/run_eval.sh
Executable file
34
environments/benchmarks/yc_bench/run_eval.sh
Executable file
@@ -0,0 +1,34 @@
|
||||
#!/bin/bash
|
||||
|
||||
# YC-Bench Evaluation
|
||||
#
|
||||
# Requires: pip install "hermes-agent[yc-bench]"
|
||||
#
|
||||
# Run from repo root:
|
||||
# bash environments/benchmarks/yc_bench/run_eval.sh
|
||||
#
|
||||
# Override model:
|
||||
# bash environments/benchmarks/yc_bench/run_eval.sh \
|
||||
# --openai.model_name anthropic/claude-opus-4-20250514
|
||||
#
|
||||
# Run a single preset:
|
||||
# bash environments/benchmarks/yc_bench/run_eval.sh \
|
||||
# --env.presets '["fast_test"]' --env.seeds '[1]'
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
mkdir -p logs evals/yc-bench
|
||||
LOG_FILE="logs/yc_bench_$(date +%Y%m%d_%H%M%S).log"
|
||||
|
||||
echo "YC-Bench Evaluation"
|
||||
echo "Log: $LOG_FILE"
|
||||
echo ""
|
||||
|
||||
PYTHONUNBUFFERED=1 LOGLEVEL="${LOGLEVEL:-INFO}" \
|
||||
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
--config environments/benchmarks/yc_bench/default.yaml \
|
||||
"$@" \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
echo ""
|
||||
echo "Log saved to: $LOG_FILE"
|
||||
847
environments/benchmarks/yc_bench/yc_bench_env.py
Normal file
847
environments/benchmarks/yc_bench/yc_bench_env.py
Normal file
@@ -0,0 +1,847 @@
|
||||
"""
|
||||
YCBenchEvalEnv -- YC-Bench Long-Horizon Agent Benchmark Environment
|
||||
|
||||
Evaluates agentic LLMs on YC-Bench: a deterministic, long-horizon benchmark
|
||||
where the agent acts as CEO of an AI startup over a simulated 1-3 year run.
|
||||
The agent manages cash flow, employees, tasks, and prestige across 4 domains,
|
||||
interacting exclusively via CLI subprocess calls against a SQLite-backed
|
||||
discrete-event simulation.
|
||||
|
||||
Unlike TerminalBench2 (per-task binary pass/fail), YC-Bench measures sustained
|
||||
multi-turn strategic coherence -- whether an agent can manage compounding
|
||||
decisions over hundreds of turns without going bankrupt.
|
||||
|
||||
This is an eval-only environment. Run via:
|
||||
|
||||
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
||||
--config environments/benchmarks/yc_bench/default.yaml
|
||||
|
||||
The evaluate flow:
|
||||
1. setup() -- Verifies yc-bench installed, builds eval matrix (preset x seed)
|
||||
2. evaluate() -- Iterates over all runs sequentially through:
|
||||
a. rollout_and_score_eval() -- Per-run agent loop
|
||||
- Initialises a fresh yc-bench simulation via `sim init` (NOT `run`)
|
||||
- Runs HermesAgentLoop with terminal tool only
|
||||
- Reads final SQLite DB to extract score
|
||||
- Returns survival (0/1) + normalised funds score
|
||||
b. Aggregates per-preset and overall metrics
|
||||
c. Logs results via evaluate_log() and wandb
|
||||
|
||||
Key features:
|
||||
- CLI-only interface: agent calls yc-bench subcommands via terminal tool
|
||||
- Deterministic: same seed + preset = same world (SHA256-based RNG)
|
||||
- Multi-dimensional scoring: survival + normalised final funds
|
||||
- Per-preset difficulty breakdown in results
|
||||
- Isolated SQLite DB per run (no cross-run state leakage)
|
||||
|
||||
Requires: pip install hermes-agent[yc-bench]
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sqlite3
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import EvalHandlingEnum
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
|
||||
from environments.agent_loop import HermesAgentLoop
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# =============================================================================
|
||||
# System prompt
|
||||
# =============================================================================
|
||||
|
||||
YC_BENCH_SYSTEM_PROMPT = """\
|
||||
You are the autonomous CEO of an early-stage AI startup in a deterministic
|
||||
business simulation. You manage the company exclusively through the `yc-bench`
|
||||
CLI tool. Your primary goal is to **survive** until the simulation horizon ends
|
||||
without going bankrupt, while **maximising final funds**.
|
||||
|
||||
## Simulation Mechanics
|
||||
|
||||
- **Funds**: You start with $250,000 seed capital. Revenue comes from completing
|
||||
tasks. Rewards scale with your prestige: `base × (1 + scale × (prestige − 1))`.
|
||||
- **Domains**: There are 4 skill domains: **research**, **inference**,
|
||||
**data_environment**, and **training**. Each has its own prestige level
|
||||
(1.0-10.0). Higher prestige unlocks better-paying tasks.
|
||||
- **Employees**: You have employees (Junior/Mid/Senior) with domain-specific
|
||||
skill rates. **Throughput splits**: `effective_rate = base_rate / N` where N
|
||||
is the number of active tasks assigned to that employee. Focus beats breadth.
|
||||
- **Payroll**: Deducted automatically on the first business day of each month.
|
||||
Running out of funds = bankruptcy = game over.
|
||||
- **Time**: The simulation runs on business days (Mon-Fri), 09:00-18:00.
|
||||
Time only advances when you call `yc-bench sim resume`.
|
||||
|
||||
## Task Lifecycle
|
||||
|
||||
1. Browse market tasks with `market browse`
|
||||
2. Accept a task with `task accept` (this sets its deadline)
|
||||
3. Assign employees with `task assign`
|
||||
4. Dispatch with `task dispatch` to start work
|
||||
5. Call `sim resume` to advance time and let employees make progress
|
||||
6. Tasks complete when all domain requirements are fulfilled
|
||||
|
||||
**Penalties for failure vary by difficulty preset.** Completing a task on time
|
||||
earns full reward + prestige gain. Missing a deadline or cancelling a task
|
||||
incurs prestige penalties -- cancelling is always more costly than letting a
|
||||
task fail, so cancel only as a last resort.
|
||||
|
||||
## CLI Commands
|
||||
|
||||
### Observe
|
||||
- `yc-bench company status` -- funds, prestige, runway
|
||||
- `yc-bench employee list` -- skills, salary, active tasks
|
||||
- `yc-bench market browse [--domain D] [--required-prestige-lte N]` -- available tasks
|
||||
- `yc-bench task list [--status active|planned]` -- your tasks
|
||||
- `yc-bench task inspect --task-id UUID` -- progress, deadline, assignments
|
||||
- `yc-bench finance ledger [--category monthly_payroll|task_reward]` -- transaction history
|
||||
- `yc-bench report monthly` -- monthly P&L
|
||||
|
||||
### Act
|
||||
- `yc-bench task accept --task-id UUID` -- accept from market
|
||||
- `yc-bench task assign --task-id UUID --employee-id UUID` -- assign employee
|
||||
- `yc-bench task dispatch --task-id UUID` -- start work (needs >=1 assignment)
|
||||
- `yc-bench task cancel --task-id UUID --reason "text"` -- cancel (prestige penalty)
|
||||
- `yc-bench sim resume` -- advance simulation clock
|
||||
|
||||
### Memory (persists across context truncation)
|
||||
- `yc-bench scratchpad read` -- read your persistent notes
|
||||
- `yc-bench scratchpad write --content "text"` -- overwrite notes
|
||||
- `yc-bench scratchpad append --content "text"` -- append to notes
|
||||
- `yc-bench scratchpad clear` -- clear notes
|
||||
|
||||
## Strategy Guidelines
|
||||
|
||||
1. **Specialise in 2-3 domains** to climb the prestige ladder faster and unlock
|
||||
high-reward tasks. Don't spread thin across all 4 domains early on.
|
||||
2. **Focus employees** -- assigning one employee to many tasks halves their
|
||||
throughput per additional task. Keep assignments concentrated.
|
||||
3. **Use the scratchpad** to track your strategy, upcoming deadlines, and
|
||||
employee assignments. This persists even if conversation context is truncated.
|
||||
4. **Monitor runway** -- always know how many months of payroll you can cover.
|
||||
Accept high-reward tasks before payroll dates.
|
||||
5. **Don't over-accept** -- taking too many tasks and missing deadlines cascades
|
||||
into prestige loss, locking you out of profitable contracts.
|
||||
6. Use `finance ledger` and `report monthly` to track revenue trends.
|
||||
|
||||
## Your Turn
|
||||
|
||||
Each turn:
|
||||
1. Call `yc-bench company status` and `yc-bench task list` to orient yourself.
|
||||
2. Check for completed tasks and pending deadlines.
|
||||
3. Browse market for profitable tasks within your prestige level.
|
||||
4. Accept, assign, and dispatch tasks strategically.
|
||||
5. Call `yc-bench sim resume` to advance time.
|
||||
6. Repeat until the simulation ends.
|
||||
|
||||
Think step by step before acting."""
|
||||
|
||||
# Starting funds in cents ($250,000)
|
||||
INITIAL_FUNDS_CENTS = 25_000_000
|
||||
|
||||
# Default horizon per preset (years)
|
||||
_PRESET_HORIZONS = {
|
||||
"tutorial": 1,
|
||||
"easy": 1,
|
||||
"medium": 1,
|
||||
"hard": 1,
|
||||
"nightmare": 1,
|
||||
"fast_test": 1,
|
||||
"default": 3,
|
||||
"high_reward": 1,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Configuration
|
||||
# =============================================================================
|
||||
|
||||
class YCBenchEvalConfig(HermesAgentEnvConfig):
|
||||
"""
|
||||
Configuration for the YC-Bench evaluation environment.
|
||||
|
||||
Extends HermesAgentEnvConfig with YC-Bench-specific settings for
|
||||
preset selection, seed control, scoring, and simulation parameters.
|
||||
"""
|
||||
|
||||
presets: List[str] = Field(
|
||||
default=["fast_test", "medium", "hard"],
|
||||
description="YC-Bench preset names to evaluate.",
|
||||
)
|
||||
seeds: List[int] = Field(
|
||||
default=[1, 2, 3],
|
||||
description="Random seeds -- each preset x seed = one run.",
|
||||
)
|
||||
run_timeout: int = Field(
|
||||
default=3600,
|
||||
description="Maximum wall-clock seconds per run. Default 60 minutes.",
|
||||
)
|
||||
survival_weight: float = Field(
|
||||
default=0.5,
|
||||
description="Weight of survival (0/1) in composite score.",
|
||||
)
|
||||
funds_weight: float = Field(
|
||||
default=0.5,
|
||||
description="Weight of normalised final funds in composite score.",
|
||||
)
|
||||
db_dir: str = Field(
|
||||
default="/tmp/yc_bench_dbs",
|
||||
description="Directory for per-run SQLite databases.",
|
||||
)
|
||||
horizon_years: Optional[int] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Simulation horizon in years. If None (default), inferred from "
|
||||
"preset name (1 year for most, 3 for 'default')."
|
||||
),
|
||||
)
|
||||
company_name: str = Field(
|
||||
default="BenchCo",
|
||||
description="Name of the simulated company.",
|
||||
)
|
||||
start_date: str = Field(
|
||||
default="01/01/2025",
|
||||
description="Simulation start date in MM/DD/YYYY format (yc-bench convention).",
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Scoring helpers
|
||||
# =============================================================================
|
||||
|
||||
def _read_final_score(db_path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Read final game state from a YC-Bench SQLite database.
|
||||
|
||||
Returns dict with final_funds_cents (int), survived (bool),
|
||||
terminal_reason (str).
|
||||
|
||||
Note: yc-bench table names are plural -- 'companies' not 'company',
|
||||
'sim_events' not 'simulation_log'.
|
||||
"""
|
||||
if not os.path.exists(db_path):
|
||||
logger.warning("DB not found at %s", db_path)
|
||||
return {
|
||||
"final_funds_cents": 0,
|
||||
"survived": False,
|
||||
"terminal_reason": "db_missing",
|
||||
}
|
||||
|
||||
conn = None
|
||||
try:
|
||||
conn = sqlite3.connect(db_path)
|
||||
cur = conn.cursor()
|
||||
|
||||
# Read final funds from the 'companies' table
|
||||
cur.execute("SELECT funds_cents FROM companies LIMIT 1")
|
||||
row = cur.fetchone()
|
||||
funds = row[0] if row else 0
|
||||
|
||||
# Determine terminal reason from 'sim_events' table
|
||||
terminal_reason = "unknown"
|
||||
try:
|
||||
cur.execute(
|
||||
"SELECT event_type FROM sim_events "
|
||||
"WHERE event_type IN ('bankruptcy', 'horizon_end') "
|
||||
"ORDER BY scheduled_at DESC LIMIT 1"
|
||||
)
|
||||
event_row = cur.fetchone()
|
||||
if event_row:
|
||||
terminal_reason = event_row[0]
|
||||
except sqlite3.OperationalError:
|
||||
# Table may not exist if simulation didn't progress
|
||||
pass
|
||||
|
||||
survived = funds >= 0 and terminal_reason != "bankruptcy"
|
||||
return {
|
||||
"final_funds_cents": funds,
|
||||
"survived": survived,
|
||||
"terminal_reason": terminal_reason,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to read DB %s: %s", db_path, e)
|
||||
return {
|
||||
"final_funds_cents": 0,
|
||||
"survived": False,
|
||||
"terminal_reason": f"db_error: {e}",
|
||||
}
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
|
||||
def _compute_composite_score(
|
||||
final_funds_cents: int,
|
||||
survived: bool,
|
||||
survival_weight: float = 0.5,
|
||||
funds_weight: float = 0.5,
|
||||
initial_funds_cents: int = INITIAL_FUNDS_CENTS,
|
||||
) -> float:
|
||||
"""
|
||||
Compute composite score from survival and final funds.
|
||||
|
||||
Score = survival_weight * survival_score
|
||||
+ funds_weight * normalised_funds_score
|
||||
|
||||
Normalised funds uses log-scale relative to initial capital:
|
||||
- funds <= 0: 0.0
|
||||
- funds == initial: ~0.15
|
||||
- funds == 10x: ~0.52
|
||||
- funds == 100x: 1.0
|
||||
"""
|
||||
survival_score = 1.0 if survived else 0.0
|
||||
|
||||
if final_funds_cents <= 0:
|
||||
funds_score = 0.0
|
||||
else:
|
||||
max_ratio = 100.0
|
||||
ratio = final_funds_cents / max(initial_funds_cents, 1)
|
||||
funds_score = min(math.log1p(ratio) / math.log1p(max_ratio), 1.0)
|
||||
|
||||
return survival_weight * survival_score + funds_weight * funds_score
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main Environment
|
||||
# =============================================================================
|
||||
|
||||
class YCBenchEvalEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
YC-Bench long-horizon agent benchmark environment (eval-only).
|
||||
|
||||
Each eval item is a (preset, seed) pair. The environment initialises the
|
||||
simulation via ``yc-bench sim init`` (NOT ``yc-bench run`` which would start
|
||||
a competing built-in agent loop). The HermesAgentLoop then drives the
|
||||
interaction by calling individual yc-bench CLI commands via the terminal tool.
|
||||
|
||||
After the agent loop ends, the SQLite DB is read to extract the final score.
|
||||
|
||||
Scoring:
|
||||
composite = 0.5 * survival + 0.5 * normalised_funds
|
||||
"""
|
||||
|
||||
name = "yc-bench"
|
||||
env_config_cls = YCBenchEvalConfig
|
||||
|
||||
@classmethod
|
||||
def config_init(cls) -> Tuple[YCBenchEvalConfig, List[APIServerConfig]]:
|
||||
env_config = YCBenchEvalConfig(
|
||||
enabled_toolsets=["terminal"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
max_agent_turns=200,
|
||||
max_token_length=32000,
|
||||
agent_temperature=0.0,
|
||||
system_prompt=YC_BENCH_SYSTEM_PROMPT,
|
||||
terminal_backend="local",
|
||||
terminal_timeout=60,
|
||||
presets=["fast_test", "medium", "hard"],
|
||||
seeds=[1, 2, 3],
|
||||
run_timeout=3600,
|
||||
survival_weight=0.5,
|
||||
funds_weight=0.5,
|
||||
db_dir="/tmp/yc_bench_dbs",
|
||||
eval_handling=EvalHandlingEnum.STOP_TRAIN,
|
||||
group_size=1,
|
||||
steps_per_eval=1,
|
||||
total_steps=1,
|
||||
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
use_wandb=True,
|
||||
wandb_name="yc-bench",
|
||||
ensure_scores_are_not_same=False,
|
||||
)
|
||||
|
||||
server_configs = [
|
||||
APIServerConfig(
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model_name="anthropic/claude-sonnet-4.6",
|
||||
server_type="openai",
|
||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
||||
health_check=False,
|
||||
)
|
||||
]
|
||||
|
||||
return env_config, server_configs
|
||||
|
||||
# =========================================================================
|
||||
# Setup
|
||||
# =========================================================================
|
||||
|
||||
async def setup(self):
|
||||
"""Verify yc-bench is installed and build the eval matrix."""
|
||||
# Verify yc-bench CLI is available
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["yc-bench", "--help"], capture_output=True, text=True, timeout=10
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise FileNotFoundError
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
raise RuntimeError(
|
||||
"yc-bench CLI not found. Install with:\n"
|
||||
' pip install "hermes-agent[yc-bench]"\n'
|
||||
"Or: git clone https://github.com/collinear-ai/yc-bench "
|
||||
"&& cd yc-bench && pip install -e ."
|
||||
)
|
||||
print("yc-bench CLI verified.")
|
||||
|
||||
# Build eval matrix: preset x seed
|
||||
self.all_eval_items = [
|
||||
{"preset": preset, "seed": seed}
|
||||
for preset in self.config.presets
|
||||
for seed in self.config.seeds
|
||||
]
|
||||
self.iter = 0
|
||||
|
||||
os.makedirs(self.config.db_dir, exist_ok=True)
|
||||
self.eval_metrics: List[Tuple[str, float]] = []
|
||||
|
||||
# Streaming JSONL log for crash-safe result persistence
|
||||
log_dir = os.path.join(os.path.dirname(__file__), "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
run_ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
self._streaming_path = os.path.join(log_dir, f"samples_{run_ts}.jsonl")
|
||||
self._streaming_file = open(self._streaming_path, "w")
|
||||
self._streaming_lock = threading.Lock()
|
||||
|
||||
print(f"\nYC-Bench eval matrix: {len(self.all_eval_items)} runs")
|
||||
for item in self.all_eval_items:
|
||||
print(f" preset={item['preset']!r} seed={item['seed']}")
|
||||
print(f"Streaming results to: {self._streaming_path}\n")
|
||||
|
||||
def _save_result(self, result: Dict[str, Any]):
|
||||
"""Write a single run result to the streaming JSONL file immediately."""
|
||||
if not hasattr(self, "_streaming_file") or self._streaming_file.closed:
|
||||
return
|
||||
with self._streaming_lock:
|
||||
self._streaming_file.write(
|
||||
json.dumps(result, ensure_ascii=False, default=str) + "\n"
|
||||
)
|
||||
self._streaming_file.flush()
|
||||
|
||||
# =========================================================================
|
||||
# Training pipeline stubs (eval-only -- not used)
|
||||
# =========================================================================
|
||||
|
||||
async def get_next_item(self):
|
||||
item = self.all_eval_items[self.iter % len(self.all_eval_items)]
|
||||
self.iter += 1
|
||||
return item
|
||||
|
||||
def format_prompt(self, item: Dict[str, Any]) -> str:
|
||||
preset = item["preset"]
|
||||
seed = item["seed"]
|
||||
return (
|
||||
f"A new YC-Bench simulation has been initialized "
|
||||
f"(preset='{preset}', seed={seed}).\n"
|
||||
f"Your company '{self.config.company_name}' is ready.\n\n"
|
||||
"Begin by calling:\n"
|
||||
"1. `yc-bench company status` -- see your starting funds and prestige\n"
|
||||
"2. `yc-bench employee list` -- see your team and their skills\n"
|
||||
"3. `yc-bench market browse --required-prestige-lte 1` -- find tasks "
|
||||
"you can take\n\n"
|
||||
"Then accept 2-3 tasks, assign employees, dispatch them, and call "
|
||||
"`yc-bench sim resume` to advance time. Repeat this loop until the "
|
||||
"simulation ends (horizon reached or bankruptcy)."
|
||||
)
|
||||
|
||||
async def compute_reward(self, item, result, ctx) -> float:
|
||||
return 0.0
|
||||
|
||||
async def collect_trajectories(self, item):
|
||||
return None, []
|
||||
|
||||
async def score(self, rollout_group_data):
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Per-run evaluation
|
||||
# =========================================================================
|
||||
|
||||
async def rollout_and_score_eval(self, eval_item: Dict[str, Any]) -> Dict:
|
||||
"""
|
||||
Evaluate a single (preset, seed) run.
|
||||
|
||||
1. Sets DATABASE_URL and YC_BENCH_EXPERIMENT env vars
|
||||
2. Initialises the simulation via ``yc-bench sim init`` (NOT ``run``)
|
||||
3. Runs HermesAgentLoop with terminal tool
|
||||
4. Reads SQLite DB to compute final score
|
||||
5. Returns result dict with survival, funds, and composite score
|
||||
"""
|
||||
preset = eval_item["preset"]
|
||||
seed = eval_item["seed"]
|
||||
run_id = str(uuid.uuid4())[:8]
|
||||
run_key = f"{preset}_seed{seed}_{run_id}"
|
||||
|
||||
from tqdm import tqdm
|
||||
tqdm.write(f" [START] preset={preset!r} seed={seed} (run_id={run_id})")
|
||||
run_start = time.time()
|
||||
|
||||
# Isolated DB per run -- prevents cross-run state leakage
|
||||
db_path = os.path.join(self.config.db_dir, f"yc_bench_{run_key}.db")
|
||||
os.environ["DATABASE_URL"] = f"sqlite:///{db_path}"
|
||||
os.environ["YC_BENCH_EXPERIMENT"] = preset
|
||||
|
||||
# Determine horizon: explicit config override > preset lookup > default 1
|
||||
horizon = self.config.horizon_years or _PRESET_HORIZONS.get(preset, 1)
|
||||
|
||||
try:
|
||||
# ----------------------------------------------------------
|
||||
# Step 1: Initialise the simulation via CLI
|
||||
# IMPORTANT: We use `sim init`, NOT `yc-bench run`.
|
||||
# `yc-bench run` starts yc-bench's own LLM agent loop (via
|
||||
# LiteLLM), which would compete with our HermesAgentLoop.
|
||||
# `sim init` just sets up the world and returns.
|
||||
# ----------------------------------------------------------
|
||||
init_cmd = [
|
||||
"yc-bench", "sim", "init",
|
||||
"--seed", str(seed),
|
||||
"--start-date", self.config.start_date,
|
||||
"--company-name", self.config.company_name,
|
||||
"--horizon-years", str(horizon),
|
||||
]
|
||||
init_result = subprocess.run(
|
||||
init_cmd, capture_output=True, text=True, timeout=30,
|
||||
)
|
||||
if init_result.returncode != 0:
|
||||
error_msg = (init_result.stderr or init_result.stdout).strip()
|
||||
raise RuntimeError(f"yc-bench sim init failed: {error_msg}")
|
||||
|
||||
tqdm.write(f" Simulation initialized (horizon={horizon}yr)")
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# Step 2: Run the HermesAgentLoop
|
||||
# ----------------------------------------------------------
|
||||
tools, valid_names = self._resolve_tools_for_group()
|
||||
|
||||
messages: List[Dict[str, Any]] = [
|
||||
{"role": "system", "content": YC_BENCH_SYSTEM_PROMPT},
|
||||
{"role": "user", "content": self.format_prompt(eval_item)},
|
||||
]
|
||||
|
||||
agent = HermesAgentLoop(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=run_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# Step 3: Read final score from the simulation DB
|
||||
# ----------------------------------------------------------
|
||||
score_data = _read_final_score(db_path)
|
||||
final_funds = score_data["final_funds_cents"]
|
||||
survived = score_data["survived"]
|
||||
terminal_reason = score_data["terminal_reason"]
|
||||
|
||||
composite = _compute_composite_score(
|
||||
final_funds_cents=final_funds,
|
||||
survived=survived,
|
||||
survival_weight=self.config.survival_weight,
|
||||
funds_weight=self.config.funds_weight,
|
||||
)
|
||||
|
||||
elapsed = time.time() - run_start
|
||||
status = "SURVIVED" if survived else "BANKRUPT"
|
||||
if final_funds >= 0:
|
||||
funds_str = f"${final_funds / 100:,.0f}"
|
||||
else:
|
||||
funds_str = f"-${abs(final_funds) / 100:,.0f}"
|
||||
|
||||
tqdm.write(
|
||||
f" [{status}] preset={preset!r} seed={seed} "
|
||||
f"funds={funds_str} score={composite:.3f} "
|
||||
f"turns={result.turns_used} ({elapsed:.0f}s)"
|
||||
)
|
||||
|
||||
out = {
|
||||
"preset": preset,
|
||||
"seed": seed,
|
||||
"survived": survived,
|
||||
"final_funds_cents": final_funds,
|
||||
"final_funds_usd": final_funds / 100,
|
||||
"terminal_reason": terminal_reason,
|
||||
"composite_score": composite,
|
||||
"turns_used": result.turns_used,
|
||||
"finished_naturally": result.finished_naturally,
|
||||
"elapsed_seconds": elapsed,
|
||||
"db_path": db_path,
|
||||
"messages": result.messages,
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
except Exception as e:
|
||||
elapsed = time.time() - run_start
|
||||
logger.error("Run %s failed: %s", run_key, e, exc_info=True)
|
||||
tqdm.write(
|
||||
f" [ERROR] preset={preset!r} seed={seed}: {e} ({elapsed:.0f}s)"
|
||||
)
|
||||
out = {
|
||||
"preset": preset,
|
||||
"seed": seed,
|
||||
"survived": False,
|
||||
"final_funds_cents": 0,
|
||||
"final_funds_usd": 0.0,
|
||||
"terminal_reason": f"error: {e}",
|
||||
"composite_score": 0.0,
|
||||
"turns_used": 0,
|
||||
"error": str(e),
|
||||
"elapsed_seconds": elapsed,
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
# =========================================================================
|
||||
# Evaluate
|
||||
# =========================================================================
|
||||
|
||||
async def _run_with_timeout(self, item: Dict[str, Any]) -> Dict:
|
||||
"""Wrap a single rollout with a wall-clock timeout."""
|
||||
preset = item["preset"]
|
||||
seed = item["seed"]
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self.rollout_and_score_eval(item),
|
||||
timeout=self.config.run_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
from tqdm import tqdm
|
||||
tqdm.write(
|
||||
f" [TIMEOUT] preset={preset!r} seed={seed} "
|
||||
f"(exceeded {self.config.run_timeout}s)"
|
||||
)
|
||||
out = {
|
||||
"preset": preset,
|
||||
"seed": seed,
|
||||
"survived": False,
|
||||
"final_funds_cents": 0,
|
||||
"final_funds_usd": 0.0,
|
||||
"terminal_reason": f"timeout ({self.config.run_timeout}s)",
|
||||
"composite_score": 0.0,
|
||||
"turns_used": 0,
|
||||
"error": "timeout",
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
|
||||
async def evaluate(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
Run YC-Bench evaluation over all (preset, seed) combinations.
|
||||
|
||||
Runs sequentially -- each run is 100-500 turns, parallelising would
|
||||
be prohibitively expensive and cause env var conflicts.
|
||||
"""
|
||||
start_time = time.time()
|
||||
from tqdm import tqdm
|
||||
|
||||
# --- tqdm-compatible logging handler (TB2 pattern) ---
|
||||
class _TqdmHandler(logging.Handler):
|
||||
def emit(self, record):
|
||||
try:
|
||||
tqdm.write(self.format(record))
|
||||
except Exception:
|
||||
self.handleError(record)
|
||||
|
||||
root = logging.getLogger()
|
||||
handler = _TqdmHandler()
|
||||
handler.setFormatter(
|
||||
logging.Formatter("%(levelname)s %(name)s: %(message)s")
|
||||
)
|
||||
root.handlers = [handler]
|
||||
for noisy in ("httpx", "openai"):
|
||||
logging.getLogger(noisy).setLevel(logging.WARNING)
|
||||
|
||||
# --- Print config summary ---
|
||||
print(f"\n{'='*60}")
|
||||
print("Starting YC-Bench Evaluation")
|
||||
print(f"{'='*60}")
|
||||
print(f" Presets: {self.config.presets}")
|
||||
print(f" Seeds: {self.config.seeds}")
|
||||
print(f" Total runs: {len(self.all_eval_items)}")
|
||||
print(f" Max turns/run: {self.config.max_agent_turns}")
|
||||
print(f" Run timeout: {self.config.run_timeout}s")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
results = []
|
||||
pbar = tqdm(
|
||||
total=len(self.all_eval_items), desc="YC-Bench", dynamic_ncols=True
|
||||
)
|
||||
|
||||
try:
|
||||
for item in self.all_eval_items:
|
||||
result = await self._run_with_timeout(item)
|
||||
results.append(result)
|
||||
survived_count = sum(1 for r in results if r.get("survived"))
|
||||
pbar.set_postfix_str(
|
||||
f"survived={survived_count}/{len(results)}"
|
||||
)
|
||||
pbar.update(1)
|
||||
|
||||
except (KeyboardInterrupt, asyncio.CancelledError):
|
||||
tqdm.write("\n[INTERRUPTED] Stopping evaluation...")
|
||||
pbar.close()
|
||||
try:
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
cleanup_all_environments()
|
||||
except Exception:
|
||||
pass
|
||||
if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
|
||||
self._streaming_file.close()
|
||||
return
|
||||
|
||||
pbar.close()
|
||||
end_time = time.time()
|
||||
|
||||
# --- Compute metrics ---
|
||||
valid = [r for r in results if r is not None]
|
||||
if not valid:
|
||||
print("Warning: No valid results.")
|
||||
return
|
||||
|
||||
total = len(valid)
|
||||
survived_total = sum(1 for r in valid if r.get("survived"))
|
||||
survival_rate = survived_total / total if total else 0.0
|
||||
avg_score = (
|
||||
sum(r.get("composite_score", 0) for r in valid) / total
|
||||
if total
|
||||
else 0.0
|
||||
)
|
||||
|
||||
preset_results: Dict[str, List[Dict]] = defaultdict(list)
|
||||
for r in valid:
|
||||
preset_results[r["preset"]].append(r)
|
||||
|
||||
eval_metrics = {
|
||||
"eval/survival_rate": survival_rate,
|
||||
"eval/avg_composite_score": avg_score,
|
||||
"eval/total_runs": total,
|
||||
"eval/survived_runs": survived_total,
|
||||
"eval/evaluation_time_seconds": end_time - start_time,
|
||||
}
|
||||
|
||||
for preset, items in sorted(preset_results.items()):
|
||||
ps = sum(1 for r in items if r.get("survived"))
|
||||
pt = len(items)
|
||||
pa = (
|
||||
sum(r.get("composite_score", 0) for r in items) / pt
|
||||
if pt
|
||||
else 0
|
||||
)
|
||||
key = preset.replace("-", "_")
|
||||
eval_metrics[f"eval/survival_rate_{key}"] = ps / pt if pt else 0
|
||||
eval_metrics[f"eval/avg_score_{key}"] = pa
|
||||
|
||||
self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]
|
||||
|
||||
# --- Print summary ---
|
||||
print(f"\n{'='*60}")
|
||||
print("YC-Bench Evaluation Results")
|
||||
print(f"{'='*60}")
|
||||
print(
|
||||
f"Overall survival rate: {survival_rate:.1%} "
|
||||
f"({survived_total}/{total})"
|
||||
)
|
||||
print(f"Average composite score: {avg_score:.4f}")
|
||||
print(f"Evaluation time: {end_time - start_time:.1f}s")
|
||||
|
||||
print("\nPer-preset breakdown:")
|
||||
for preset, items in sorted(preset_results.items()):
|
||||
ps = sum(1 for r in items if r.get("survived"))
|
||||
pt = len(items)
|
||||
pa = (
|
||||
sum(r.get("composite_score", 0) for r in items) / pt
|
||||
if pt
|
||||
else 0
|
||||
)
|
||||
print(f" {preset}: {ps}/{pt} survived avg_score={pa:.4f}")
|
||||
for r in items:
|
||||
status = "SURVIVED" if r.get("survived") else "BANKRUPT"
|
||||
funds = r.get("final_funds_usd", 0)
|
||||
print(
|
||||
f" seed={r['seed']} [{status}] "
|
||||
f"${funds:,.0f} "
|
||||
f"score={r.get('composite_score', 0):.3f}"
|
||||
)
|
||||
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# --- Log results ---
|
||||
samples = [
|
||||
{k: v for k, v in r.items() if k != "messages"} for r in valid
|
||||
]
|
||||
|
||||
try:
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
generation_parameters={
|
||||
"temperature": self.config.agent_temperature,
|
||||
"max_tokens": self.config.max_token_length,
|
||||
"max_agent_turns": self.config.max_agent_turns,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error logging results: {e}")
|
||||
|
||||
# --- Cleanup (TB2 pattern) ---
|
||||
if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
|
||||
self._streaming_file.close()
|
||||
print(f"Results saved to: {self._streaming_path}")
|
||||
|
||||
try:
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
cleanup_all_environments()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
from environments.agent_loop import _tool_executor
|
||||
_tool_executor.shutdown(wait=False, cancel_futures=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# =========================================================================
|
||||
# Wandb logging
|
||||
# =========================================================================
|
||||
|
||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
||||
"""Log YC-Bench-specific metrics to wandb."""
|
||||
if wandb_metrics is None:
|
||||
wandb_metrics = {}
|
||||
for k, v in self.eval_metrics:
|
||||
wandb_metrics[k] = v
|
||||
self.eval_metrics = []
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
YCBenchEvalEnv.cli()
|
||||
@@ -40,8 +40,8 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
|
||||
except Exception as e:
|
||||
logger.warning("Channel directory: failed to build %s: %s", platform.value, e)
|
||||
|
||||
# Telegram & WhatsApp can't enumerate chats -- pull from session history
|
||||
for plat_name in ("telegram", "whatsapp"):
|
||||
# Telegram, WhatsApp & Signal can't enumerate chats -- pull from session history
|
||||
for plat_name in ("telegram", "whatsapp", "signal"):
|
||||
if plat_name not in platforms:
|
||||
platforms[plat_name] = _build_from_sessions(plat_name)
|
||||
|
||||
@@ -52,7 +52,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
|
||||
|
||||
try:
|
||||
DIRECTORY_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(DIRECTORY_PATH, "w") as f:
|
||||
with open(DIRECTORY_PATH, "w", encoding="utf-8") as f:
|
||||
json.dump(directory, f, indent=2, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.warning("Channel directory: failed to write: %s", e)
|
||||
@@ -115,7 +115,7 @@ def _build_from_sessions(platform_name: str) -> List[Dict[str, str]]:
|
||||
|
||||
entries = []
|
||||
try:
|
||||
with open(sessions_path) as f:
|
||||
with open(sessions_path, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
seen_ids = set()
|
||||
@@ -147,7 +147,7 @@ def load_directory() -> Dict[str, Any]:
|
||||
if not DIRECTORY_PATH.exists():
|
||||
return {"updated_at": None, "platforms": {}}
|
||||
try:
|
||||
with open(DIRECTORY_PATH) as f:
|
||||
with open(DIRECTORY_PATH, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception:
|
||||
return {"updated_at": None, "platforms": {}}
|
||||
|
||||
@@ -26,6 +26,7 @@ class Platform(Enum):
|
||||
DISCORD = "discord"
|
||||
WHATSAPP = "whatsapp"
|
||||
SLACK = "slack"
|
||||
SIGNAL = "signal"
|
||||
HOMEASSISTANT = "homeassistant"
|
||||
|
||||
|
||||
@@ -155,7 +156,16 @@ class GatewayConfig:
|
||||
"""Return list of platforms that are enabled and configured."""
|
||||
connected = []
|
||||
for platform, config in self.platforms.items():
|
||||
if config.enabled and (config.token or config.api_key):
|
||||
if not config.enabled:
|
||||
continue
|
||||
# Platforms that use token/api_key auth
|
||||
if config.token or config.api_key:
|
||||
connected.append(platform)
|
||||
# WhatsApp uses enabled flag only (bridge handles auth)
|
||||
elif platform == Platform.WHATSAPP:
|
||||
connected.append(platform)
|
||||
# Signal uses extra dict for config (http_url + account)
|
||||
elif platform == Platform.SIGNAL and config.extra.get("http_url"):
|
||||
connected.append(platform)
|
||||
return connected
|
||||
|
||||
@@ -379,6 +389,26 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
name=os.getenv("SLACK_HOME_CHANNEL_NAME", ""),
|
||||
)
|
||||
|
||||
# Signal
|
||||
signal_url = os.getenv("SIGNAL_HTTP_URL")
|
||||
signal_account = os.getenv("SIGNAL_ACCOUNT")
|
||||
if signal_url and signal_account:
|
||||
if Platform.SIGNAL not in config.platforms:
|
||||
config.platforms[Platform.SIGNAL] = PlatformConfig()
|
||||
config.platforms[Platform.SIGNAL].enabled = True
|
||||
config.platforms[Platform.SIGNAL].extra.update({
|
||||
"http_url": signal_url,
|
||||
"account": signal_account,
|
||||
"ignore_stories": os.getenv("SIGNAL_IGNORE_STORIES", "true").lower() in ("true", "1", "yes"),
|
||||
})
|
||||
signal_home = os.getenv("SIGNAL_HOME_CHANNEL")
|
||||
if signal_home:
|
||||
config.platforms[Platform.SIGNAL].home_channel = HomeChannel(
|
||||
platform=Platform.SIGNAL,
|
||||
chat_id=signal_home,
|
||||
name=os.getenv("SIGNAL_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# Home Assistant
|
||||
hass_token = os.getenv("HASS_TOKEN")
|
||||
if hass_token:
|
||||
|
||||
@@ -73,7 +73,7 @@ def _find_session_id(platform: str, chat_id: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(_SESSIONS_INDEX) as f:
|
||||
with open(_SESSIONS_INDEX, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
except Exception:
|
||||
return None
|
||||
@@ -103,7 +103,7 @@ def _append_to_jsonl(session_id: str, message: dict) -> None:
|
||||
"""Append a message to the JSONL transcript file."""
|
||||
transcript_path = _SESSIONS_DIR / f"{session_id}.jsonl"
|
||||
try:
|
||||
with open(transcript_path, "a") as f:
|
||||
with open(transcript_path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(message, ensure_ascii=False) + "\n")
|
||||
except Exception as e:
|
||||
logger.debug("Mirror JSONL write failed: %s", e)
|
||||
|
||||
313
gateway/platforms/ADDING_A_PLATFORM.md
Normal file
313
gateway/platforms/ADDING_A_PLATFORM.md
Normal file
@@ -0,0 +1,313 @@
|
||||
# Adding a New Messaging Platform
|
||||
|
||||
Checklist for integrating a new messaging platform into the Hermes gateway.
|
||||
Use this as a reference when building a new adapter — every item here is a
|
||||
real integration point that exists in the codebase. Missing any of them will
|
||||
cause broken functionality, missing features, or inconsistent behavior.
|
||||
|
||||
---
|
||||
|
||||
## 1. Core Adapter (`gateway/platforms/<platform>.py`)
|
||||
|
||||
The adapter is a subclass of `BasePlatformAdapter` from `gateway/platforms/base.py`.
|
||||
|
||||
### Required methods
|
||||
|
||||
| Method | Purpose |
|
||||
|--------|---------|
|
||||
| `__init__(self, config)` | Parse config, init state. Call `super().__init__(config, Platform.YOUR_PLATFORM)` |
|
||||
| `connect() -> bool` | Connect to the platform, start listeners. Return True on success |
|
||||
| `disconnect()` | Stop listeners, close connections, cancel tasks |
|
||||
| `send(chat_id, text, ...) -> SendResult` | Send a text message |
|
||||
| `send_typing(chat_id)` | Send typing indicator |
|
||||
| `send_image(chat_id, image_url, caption) -> SendResult` | Send an image |
|
||||
| `get_chat_info(chat_id) -> dict` | Return `{name, type, chat_id}` for a chat |
|
||||
|
||||
### Optional methods (have default stubs in base)
|
||||
|
||||
| Method | Purpose |
|
||||
|--------|---------|
|
||||
| `send_document(chat_id, path, caption)` | Send a file attachment |
|
||||
| `send_voice(chat_id, path)` | Send a voice message |
|
||||
| `send_video(chat_id, path, caption)` | Send a video |
|
||||
| `send_animation(chat_id, path, caption)` | Send a GIF/animation |
|
||||
| `send_image_file(chat_id, path, caption)` | Send image from local file |
|
||||
|
||||
### Required function
|
||||
|
||||
```python
|
||||
def check_<platform>_requirements() -> bool:
|
||||
"""Check if this platform's dependencies are available."""
|
||||
```
|
||||
|
||||
### Key patterns to follow
|
||||
|
||||
- Use `self.build_source(...)` to construct `SessionSource` objects
|
||||
- Call `self.handle_message(event)` to dispatch inbound messages to the gateway
|
||||
- Use `MessageEvent`, `MessageType`, `SendResult` from base
|
||||
- Use `cache_image_from_bytes`, `cache_audio_from_bytes`, `cache_document_from_bytes` for attachments
|
||||
- Filter self-messages (prevent reply loops)
|
||||
- Filter sync/echo messages if the platform has them
|
||||
- Redact sensitive identifiers (phone numbers, tokens) in all log output
|
||||
- Implement reconnection with exponential backoff + jitter for streaming connections
|
||||
- Set `MAX_MESSAGE_LENGTH` if the platform has message size limits
|
||||
|
||||
---
|
||||
|
||||
## 2. Platform Enum (`gateway/config.py`)
|
||||
|
||||
Add the platform to the `Platform` enum:
|
||||
|
||||
```python
|
||||
class Platform(Enum):
|
||||
...
|
||||
YOUR_PLATFORM = "your_platform"
|
||||
```
|
||||
|
||||
Add env var loading in `_apply_env_overrides()`:
|
||||
|
||||
```python
|
||||
# Your Platform
|
||||
your_token = os.getenv("YOUR_PLATFORM_TOKEN")
|
||||
if your_token:
|
||||
if Platform.YOUR_PLATFORM not in config.platforms:
|
||||
config.platforms[Platform.YOUR_PLATFORM] = PlatformConfig()
|
||||
config.platforms[Platform.YOUR_PLATFORM].enabled = True
|
||||
config.platforms[Platform.YOUR_PLATFORM].token = your_token
|
||||
```
|
||||
|
||||
Update `get_connected_platforms()` if your platform doesn't use token/api_key
|
||||
(e.g., WhatsApp uses `enabled` flag, Signal uses `extra` dict).
|
||||
|
||||
---
|
||||
|
||||
## 3. Adapter Factory (`gateway/run.py`)
|
||||
|
||||
Add to `_create_adapter()`:
|
||||
|
||||
```python
|
||||
elif platform == Platform.YOUR_PLATFORM:
|
||||
from gateway.platforms.your_platform import YourAdapter, check_your_requirements
|
||||
if not check_your_requirements():
|
||||
logger.warning("Your Platform: dependencies not met")
|
||||
return None
|
||||
return YourAdapter(config)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. Authorization Maps (`gateway/run.py`)
|
||||
|
||||
Add to BOTH dicts in `_is_user_authorized()`:
|
||||
|
||||
```python
|
||||
platform_env_map = {
|
||||
...
|
||||
Platform.YOUR_PLATFORM: "YOUR_PLATFORM_ALLOWED_USERS",
|
||||
}
|
||||
platform_allow_all_map = {
|
||||
...
|
||||
Platform.YOUR_PLATFORM: "YOUR_PLATFORM_ALLOW_ALL_USERS",
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. Session Source (`gateway/session.py`)
|
||||
|
||||
If your platform needs extra identity fields (e.g., Signal's UUID alongside
|
||||
phone number), add them to the `SessionSource` dataclass with `Optional` defaults,
|
||||
and update `to_dict()`, `from_dict()`, and `build_source()` in base.py.
|
||||
|
||||
---
|
||||
|
||||
## 6. System Prompt Hints (`agent/prompt_builder.py`)
|
||||
|
||||
Add a `PLATFORM_HINTS` entry so the agent knows what platform it's on:
|
||||
|
||||
```python
|
||||
PLATFORM_HINTS = {
|
||||
...
|
||||
"your_platform": (
|
||||
"You are on Your Platform. "
|
||||
"Describe formatting capabilities, media support, etc."
|
||||
),
|
||||
}
|
||||
```
|
||||
|
||||
Without this, the agent won't know it's on your platform and may use
|
||||
inappropriate formatting (e.g., markdown on platforms that don't render it).
|
||||
|
||||
---
|
||||
|
||||
## 7. Toolset (`toolsets.py`)
|
||||
|
||||
Add a named toolset for your platform:
|
||||
|
||||
```python
|
||||
"hermes-your-platform": {
|
||||
"description": "Your Platform bot toolset",
|
||||
"tools": _HERMES_CORE_TOOLS,
|
||||
"includes": []
|
||||
},
|
||||
```
|
||||
|
||||
And add it to the `hermes-gateway` composite:
|
||||
|
||||
```python
|
||||
"hermes-gateway": {
|
||||
"includes": [..., "hermes-your-platform"]
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 8. Cron Delivery (`cron/scheduler.py`)
|
||||
|
||||
Add to `platform_map` in `_deliver_result()`:
|
||||
|
||||
```python
|
||||
platform_map = {
|
||||
...
|
||||
"your_platform": Platform.YOUR_PLATFORM,
|
||||
}
|
||||
```
|
||||
|
||||
Without this, `schedule_cronjob(deliver="your_platform")` silently fails.
|
||||
|
||||
---
|
||||
|
||||
## 9. Send Message Tool (`tools/send_message_tool.py`)
|
||||
|
||||
Add to `platform_map` in `send_message_tool()`:
|
||||
|
||||
```python
|
||||
platform_map = {
|
||||
...
|
||||
"your_platform": Platform.YOUR_PLATFORM,
|
||||
}
|
||||
```
|
||||
|
||||
Add routing in `_send_to_platform()`:
|
||||
|
||||
```python
|
||||
elif platform == Platform.YOUR_PLATFORM:
|
||||
return await _send_your_platform(pconfig, chat_id, message)
|
||||
```
|
||||
|
||||
Implement `_send_your_platform()` — a standalone async function that sends
|
||||
a single message without requiring the full adapter (for use by cron jobs
|
||||
and the send_message tool outside the gateway process).
|
||||
|
||||
Update the tool schema `target` description to include your platform example.
|
||||
|
||||
---
|
||||
|
||||
## 10. Cronjob Tool Schema (`tools/cronjob_tools.py`)
|
||||
|
||||
Update the `deliver` parameter description and docstring to mention your
|
||||
platform as a delivery option.
|
||||
|
||||
---
|
||||
|
||||
## 11. Channel Directory (`gateway/channel_directory.py`)
|
||||
|
||||
If your platform can't enumerate chats (most can't), add it to the
|
||||
session-based discovery list:
|
||||
|
||||
```python
|
||||
for plat_name in ("telegram", "whatsapp", "signal", "your_platform"):
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 12. Status Display (`hermes_cli/status.py`)
|
||||
|
||||
Add to the `platforms` dict in the Messaging Platforms section:
|
||||
|
||||
```python
|
||||
platforms = {
|
||||
...
|
||||
"Your Platform": ("YOUR_PLATFORM_TOKEN", "YOUR_PLATFORM_HOME_CHANNEL"),
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 13. Gateway Setup Wizard (`hermes_cli/gateway.py`)
|
||||
|
||||
Add to the `_PLATFORMS` list:
|
||||
|
||||
```python
|
||||
{
|
||||
"key": "your_platform",
|
||||
"label": "Your Platform",
|
||||
"emoji": "📱",
|
||||
"token_var": "YOUR_PLATFORM_TOKEN",
|
||||
"setup_instructions": [...],
|
||||
"vars": [...],
|
||||
}
|
||||
```
|
||||
|
||||
If your platform needs custom setup logic (connectivity testing, QR codes,
|
||||
policy choices), add a `_setup_your_platform()` function and route to it
|
||||
in the platform selection switch.
|
||||
|
||||
Update `_platform_status()` if your platform's "configured" check differs
|
||||
from the standard `bool(get_env_value(token_var))`.
|
||||
|
||||
---
|
||||
|
||||
## 14. Phone/ID Redaction (`agent/redact.py`)
|
||||
|
||||
If your platform uses sensitive identifiers (phone numbers, etc.), add a
|
||||
regex pattern and redaction function to `agent/redact.py`. This ensures
|
||||
identifiers are masked in ALL log output, not just your adapter's logs.
|
||||
|
||||
---
|
||||
|
||||
## 15. Documentation
|
||||
|
||||
| File | What to update |
|
||||
|------|---------------|
|
||||
| `README.md` | Platform list in feature table + documentation table |
|
||||
| `AGENTS.md` | Gateway description + env var config section |
|
||||
| `website/docs/user-guide/messaging/<platform>.md` | **NEW** — Full setup guide (see existing platform docs for template) |
|
||||
| `website/docs/user-guide/messaging/index.md` | Architecture diagram, toolset table, security examples, Next Steps links |
|
||||
| `website/docs/reference/environment-variables.md` | All env vars for the platform |
|
||||
|
||||
---
|
||||
|
||||
## 16. Tests (`tests/gateway/test_<platform>.py`)
|
||||
|
||||
Recommended test coverage:
|
||||
|
||||
- Platform enum exists with correct value
|
||||
- Config loading from env vars via `_apply_env_overrides`
|
||||
- Adapter init (config parsing, allowlist handling, default values)
|
||||
- Helper functions (redaction, parsing, file type detection)
|
||||
- Session source round-trip (to_dict → from_dict)
|
||||
- Authorization integration (platform in allowlist maps)
|
||||
- Send message tool routing (platform in platform_map)
|
||||
|
||||
Optional but valuable:
|
||||
- Async tests for message handling flow (mock the platform API)
|
||||
- SSE/WebSocket reconnection logic
|
||||
- Attachment processing
|
||||
- Group message filtering
|
||||
|
||||
---
|
||||
|
||||
## Quick Verification
|
||||
|
||||
After implementing everything, verify with:
|
||||
|
||||
```bash
|
||||
# All tests pass
|
||||
python -m pytest tests/ -q
|
||||
|
||||
# Grep for your platform name to find any missed integration points
|
||||
grep -r "telegram\|discord\|whatsapp\|slack" gateway/ tools/ agent/ cron/ hermes_cli/ toolsets.py \
|
||||
--include="*.py" -l | sort -u
|
||||
# Check each file in the output — if it mentions other platforms but not yours, you missed it
|
||||
```
|
||||
@@ -701,6 +701,8 @@ class BasePlatformAdapter(ABC):
|
||||
|
||||
# Extract image URLs and send them as native platform attachments
|
||||
images, text_content = self.extract_images(response)
|
||||
if images:
|
||||
logger.info("[%s] extract_images found %d image(s) in response (%d chars)", self.name, len(images), len(response))
|
||||
|
||||
# Send the text portion first (if any remains after extractions)
|
||||
if text_content:
|
||||
@@ -727,10 +729,13 @@ class BasePlatformAdapter(ABC):
|
||||
human_delay = self._get_human_delay()
|
||||
|
||||
# Send extracted images as native attachments
|
||||
if images:
|
||||
logger.info("[%s] Extracted %d image(s) to send as attachments", self.name, len(images))
|
||||
for image_url, alt_text in images:
|
||||
if human_delay > 0:
|
||||
await asyncio.sleep(human_delay)
|
||||
try:
|
||||
logger.info("[%s] Sending image: %s (alt=%s)", self.name, image_url[:80], alt_text[:30] if alt_text else "")
|
||||
# Route animated GIFs through send_animation for proper playback
|
||||
if self._is_animation_url(image_url):
|
||||
img_result = await self.send_animation(
|
||||
@@ -745,9 +750,9 @@ class BasePlatformAdapter(ABC):
|
||||
caption=alt_text if alt_text else None,
|
||||
)
|
||||
if not img_result.success:
|
||||
print(f"[{self.name}] Failed to send image: {img_result.error}")
|
||||
logger.error("[%s] Failed to send image: %s", self.name, img_result.error)
|
||||
except Exception as img_err:
|
||||
print(f"[{self.name}] Error sending image: {img_err}")
|
||||
logger.error("[%s] Error sending image: %s", self.name, img_err, exc_info=True)
|
||||
|
||||
# Send extracted media files — route by file type
|
||||
_AUDIO_EXTS = {'.ogg', '.opus', '.mp3', '.wav', '.m4a'}
|
||||
@@ -833,6 +838,8 @@ class BasePlatformAdapter(ABC):
|
||||
user_name: Optional[str] = None,
|
||||
thread_id: Optional[str] = None,
|
||||
chat_topic: Optional[str] = None,
|
||||
user_id_alt: Optional[str] = None,
|
||||
chat_id_alt: Optional[str] = None,
|
||||
) -> SessionSource:
|
||||
"""Helper to build a SessionSource for this platform."""
|
||||
# Normalize empty topic to None
|
||||
@@ -847,6 +854,8 @@ class BasePlatformAdapter(ABC):
|
||||
user_name=user_name,
|
||||
thread_id=str(thread_id) if thread_id else None,
|
||||
chat_topic=chat_topic.strip() if chat_topic else None,
|
||||
user_id_alt=user_id_alt,
|
||||
chat_id_alt=chat_id_alt,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -267,6 +267,43 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
print(f"[{self.name}] Failed to send audio: {e}")
|
||||
return await super().send_voice(chat_id, audio_path, caption, reply_to)
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively as a Discord file attachment."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import io
|
||||
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
filename = os.path.basename(image_path)
|
||||
|
||||
with open(image_path, "rb") as f:
|
||||
file = discord.File(io.BytesIO(f.read()), filename=filename)
|
||||
msg = await channel.send(
|
||||
content=caption if caption else None,
|
||||
file=file,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.id))
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send local image: {e}")
|
||||
return await super().send_image_file(chat_id, image_path, caption, reply_to)
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -555,6 +592,89 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="compress", description="Compress conversation context")
|
||||
async def slash_compress(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, "/compress")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="title", description="Set or show the session title")
|
||||
@discord.app_commands.describe(name="Session title. Leave empty to show current.")
|
||||
async def slash_title(interaction: discord.Interaction, name: str = ""):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, f"/title {name}".strip())
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="resume", description="Resume a previously-named session")
|
||||
@discord.app_commands.describe(name="Session name to resume. Leave empty to list sessions.")
|
||||
async def slash_resume(interaction: discord.Interaction, name: str = ""):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, f"/resume {name}".strip())
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="usage", description="Show token usage for this session")
|
||||
async def slash_usage(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, "/usage")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="provider", description="Show available providers")
|
||||
async def slash_provider(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, "/provider")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="help", description="Show available commands")
|
||||
async def slash_help(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, "/help")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="insights", description="Show usage insights and analytics")
|
||||
@discord.app_commands.describe(days="Number of days to analyze (default: 7)")
|
||||
async def slash_insights(interaction: discord.Interaction, days: int = 7):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, f"/insights {days}")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="reload-mcp", description="Reload MCP servers from config")
|
||||
async def slash_reload_mcp(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
event = self._build_slash_event(interaction, "/reload-mcp")
|
||||
await self.handle_message(event)
|
||||
try:
|
||||
await interaction.followup.send("Done~", ephemeral=True)
|
||||
except Exception as e:
|
||||
logger.debug("Discord followup failed: %s", e)
|
||||
|
||||
@tree.command(name="update", description="Update Hermes Agent to the latest version")
|
||||
async def slash_update(interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
|
||||
716
gateway/platforms/signal.py
Normal file
716
gateway/platforms/signal.py
Normal file
@@ -0,0 +1,716 @@
|
||||
"""Signal messenger platform adapter.
|
||||
|
||||
Connects to a signal-cli daemon running in HTTP mode.
|
||||
Inbound messages arrive via SSE (Server-Sent Events) streaming.
|
||||
Outbound messages and actions use JSON-RPC 2.0 over HTTP.
|
||||
|
||||
Based on PR #268 by ibhagwan, rebuilt with bug fixes.
|
||||
|
||||
Requires:
|
||||
- signal-cli installed and running: signal-cli daemon --http 127.0.0.1:8080
|
||||
- SIGNAL_HTTP_URL and SIGNAL_ACCOUNT environment variables set
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from urllib.parse import unquote
|
||||
|
||||
import httpx
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
SendResult,
|
||||
cache_image_from_bytes,
|
||||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
cache_image_from_url,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
SIGNAL_MAX_ATTACHMENT_SIZE = 100 * 1024 * 1024 # 100 MB
|
||||
MAX_MESSAGE_LENGTH = 8000 # Signal message size limit
|
||||
TYPING_INTERVAL = 8.0 # seconds between typing indicator refreshes
|
||||
SSE_RETRY_DELAY_INITIAL = 2.0
|
||||
SSE_RETRY_DELAY_MAX = 60.0
|
||||
HEALTH_CHECK_INTERVAL = 30.0 # seconds between health checks
|
||||
HEALTH_CHECK_STALE_THRESHOLD = 120.0 # seconds without SSE activity before concern
|
||||
|
||||
# E.164 phone number pattern for redaction
|
||||
_PHONE_RE = re.compile(r"\+[1-9]\d{6,14}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _redact_phone(phone: str) -> str:
|
||||
"""Redact a phone number for logging: +15551234567 -> +155****4567."""
|
||||
if not phone:
|
||||
return "<none>"
|
||||
if len(phone) <= 8:
|
||||
return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****"
|
||||
return phone[:4] + "****" + phone[-4:]
|
||||
|
||||
|
||||
def _parse_comma_list(value: str) -> List[str]:
|
||||
"""Split a comma-separated string into a list, stripping whitespace."""
|
||||
return [v.strip() for v in value.split(",") if v.strip()]
|
||||
|
||||
|
||||
def _guess_extension(data: bytes) -> str:
|
||||
"""Guess file extension from magic bytes."""
|
||||
if data[:4] == b"\x89PNG":
|
||||
return ".png"
|
||||
if data[:2] == b"\xff\xd8":
|
||||
return ".jpg"
|
||||
if data[:4] == b"GIF8":
|
||||
return ".gif"
|
||||
if len(data) >= 12 and data[:4] == b"RIFF" and data[8:12] == b"WEBP":
|
||||
return ".webp"
|
||||
if data[:4] == b"%PDF":
|
||||
return ".pdf"
|
||||
if len(data) >= 8 and data[4:8] == b"ftyp":
|
||||
return ".mp4"
|
||||
if data[:4] == b"OggS":
|
||||
return ".ogg"
|
||||
if len(data) >= 2 and data[0] == 0xFF and (data[1] & 0xE0) == 0xE0:
|
||||
return ".mp3"
|
||||
if data[:2] == b"PK":
|
||||
return ".zip"
|
||||
return ".bin"
|
||||
|
||||
|
||||
def _is_image_ext(ext: str) -> bool:
|
||||
return ext.lower() in (".jpg", ".jpeg", ".png", ".gif", ".webp")
|
||||
|
||||
|
||||
def _is_audio_ext(ext: str) -> bool:
|
||||
return ext.lower() in (".mp3", ".wav", ".ogg", ".m4a", ".aac")
|
||||
|
||||
|
||||
def _render_mentions(text: str, mentions: list) -> str:
|
||||
"""Replace Signal mention placeholders (\\uFFFC) with readable @identifiers.
|
||||
|
||||
Signal encodes @mentions as the Unicode object replacement character
|
||||
with out-of-band metadata containing the mentioned user's UUID/number.
|
||||
"""
|
||||
if not mentions or "\uFFFC" not in text:
|
||||
return text
|
||||
# Sort mentions by start position (reverse) to replace from end to start
|
||||
# so indices don't shift as we replace
|
||||
sorted_mentions = sorted(mentions, key=lambda m: m.get("start", 0), reverse=True)
|
||||
for mention in sorted_mentions:
|
||||
start = mention.get("start", 0)
|
||||
length = mention.get("length", 1)
|
||||
# Use the mention's number or UUID as the replacement
|
||||
identifier = mention.get("number") or mention.get("uuid") or "user"
|
||||
replacement = f"@{identifier}"
|
||||
text = text[:start] + replacement + text[start + length:]
|
||||
return text
|
||||
|
||||
|
||||
def check_signal_requirements() -> bool:
|
||||
"""Check if Signal is configured (has URL and account)."""
|
||||
return bool(os.getenv("SIGNAL_HTTP_URL") and os.getenv("SIGNAL_ACCOUNT"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Signal Adapter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class SignalAdapter(BasePlatformAdapter):
|
||||
"""Signal messenger adapter using signal-cli HTTP daemon."""
|
||||
|
||||
platform = Platform.SIGNAL
|
||||
|
||||
def __init__(self, config: PlatformConfig):
|
||||
super().__init__(config, Platform.SIGNAL)
|
||||
|
||||
extra = config.extra or {}
|
||||
self.http_url = extra.get("http_url", "http://127.0.0.1:8080").rstrip("/")
|
||||
self.account = extra.get("account", "")
|
||||
self.ignore_stories = extra.get("ignore_stories", True)
|
||||
|
||||
# Parse allowlists — group policy is derived from presence of group allowlist
|
||||
group_allowed_str = os.getenv("SIGNAL_GROUP_ALLOWED_USERS", "")
|
||||
self.group_allow_from = set(_parse_comma_list(group_allowed_str))
|
||||
|
||||
# HTTP client
|
||||
self.client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
# Background tasks
|
||||
self._sse_task: Optional[asyncio.Task] = None
|
||||
self._health_monitor_task: Optional[asyncio.Task] = None
|
||||
self._typing_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._running = False
|
||||
self._last_sse_activity = 0.0
|
||||
self._sse_response: Optional[httpx.Response] = None
|
||||
|
||||
# Normalize account for self-message filtering
|
||||
self._account_normalized = self.account.strip()
|
||||
|
||||
logger.info("Signal adapter initialized: url=%s account=%s groups=%s",
|
||||
self.http_url, _redact_phone(self.account),
|
||||
"enabled" if self.group_allow_from else "disabled")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to signal-cli daemon and start SSE listener."""
|
||||
if not self.http_url or not self.account:
|
||||
logger.error("Signal: SIGNAL_HTTP_URL and SIGNAL_ACCOUNT are required")
|
||||
return False
|
||||
|
||||
self.client = httpx.AsyncClient(timeout=30.0)
|
||||
|
||||
# Health check — verify signal-cli daemon is reachable
|
||||
try:
|
||||
resp = await self.client.get(f"{self.http_url}/api/v1/check", timeout=10.0)
|
||||
if resp.status_code != 200:
|
||||
logger.error("Signal: health check failed (status %d)", resp.status_code)
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error("Signal: cannot reach signal-cli at %s: %s", self.http_url, e)
|
||||
return False
|
||||
|
||||
self._running = True
|
||||
self._last_sse_activity = time.time()
|
||||
self._sse_task = asyncio.create_task(self._sse_listener())
|
||||
self._health_monitor_task = asyncio.create_task(self._health_monitor())
|
||||
|
||||
logger.info("Signal: connected to %s", self.http_url)
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Stop SSE listener and clean up."""
|
||||
self._running = False
|
||||
|
||||
if self._sse_task:
|
||||
self._sse_task.cancel()
|
||||
try:
|
||||
await self._sse_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._health_monitor_task:
|
||||
self._health_monitor_task.cancel()
|
||||
try:
|
||||
await self._health_monitor_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Cancel all typing tasks
|
||||
for task in self._typing_tasks.values():
|
||||
task.cancel()
|
||||
self._typing_tasks.clear()
|
||||
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
self.client = None
|
||||
|
||||
logger.info("Signal: disconnected")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SSE Streaming (inbound messages)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _sse_listener(self) -> None:
|
||||
"""Listen for SSE events from signal-cli daemon."""
|
||||
url = f"{self.http_url}/api/v1/events?account={self.account}"
|
||||
backoff = SSE_RETRY_DELAY_INITIAL
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
logger.debug("Signal SSE: connecting to %s", url)
|
||||
async with self.client.stream(
|
||||
"GET", url,
|
||||
headers={"Accept": "text/event-stream"},
|
||||
timeout=None,
|
||||
) as response:
|
||||
self._sse_response = response
|
||||
backoff = SSE_RETRY_DELAY_INITIAL # Reset on successful connection
|
||||
self._last_sse_activity = time.time()
|
||||
logger.info("Signal SSE: connected")
|
||||
|
||||
buffer = ""
|
||||
async for chunk in response.aiter_text():
|
||||
if not self._running:
|
||||
break
|
||||
buffer += chunk
|
||||
while "\n" in buffer:
|
||||
line, buffer = buffer.split("\n", 1)
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# Parse SSE data lines
|
||||
if line.startswith("data:"):
|
||||
data_str = line[5:].strip()
|
||||
if not data_str:
|
||||
continue
|
||||
self._last_sse_activity = time.time()
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
await self._handle_envelope(data)
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Signal SSE: invalid JSON: %s", data_str[:100])
|
||||
except Exception:
|
||||
logger.exception("Signal SSE: error handling event")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except httpx.HTTPError as e:
|
||||
if self._running:
|
||||
logger.warning("Signal SSE: HTTP error: %s (reconnecting in %.0fs)", e, backoff)
|
||||
except Exception as e:
|
||||
if self._running:
|
||||
logger.warning("Signal SSE: error: %s (reconnecting in %.0fs)", e, backoff)
|
||||
|
||||
if self._running:
|
||||
# Add 20% jitter to prevent thundering herd on reconnection
|
||||
jitter = backoff * 0.2 * random.random()
|
||||
await asyncio.sleep(backoff + jitter)
|
||||
backoff = min(backoff * 2, SSE_RETRY_DELAY_MAX)
|
||||
|
||||
self._sse_response = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Health Monitor
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _health_monitor(self) -> None:
|
||||
"""Monitor SSE connection health and force reconnect if stale."""
|
||||
while self._running:
|
||||
await asyncio.sleep(HEALTH_CHECK_INTERVAL)
|
||||
if not self._running:
|
||||
break
|
||||
|
||||
elapsed = time.time() - self._last_sse_activity
|
||||
if elapsed > HEALTH_CHECK_STALE_THRESHOLD:
|
||||
logger.warning("Signal: SSE idle for %.0fs, checking daemon health", elapsed)
|
||||
try:
|
||||
resp = await self.client.get(
|
||||
f"{self.http_url}/api/v1/check", timeout=10.0
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
# Daemon is alive but SSE is idle — update activity to
|
||||
# avoid repeated warnings (connection may just be quiet)
|
||||
self._last_sse_activity = time.time()
|
||||
logger.debug("Signal: daemon healthy, SSE idle")
|
||||
else:
|
||||
logger.warning("Signal: health check failed (%d), forcing reconnect", resp.status_code)
|
||||
self._force_reconnect()
|
||||
except Exception as e:
|
||||
logger.warning("Signal: health check error: %s, forcing reconnect", e)
|
||||
self._force_reconnect()
|
||||
|
||||
def _force_reconnect(self) -> None:
|
||||
"""Force SSE reconnection by closing the current response."""
|
||||
if self._sse_response and not self._sse_response.is_stream_consumed:
|
||||
try:
|
||||
asyncio.create_task(self._sse_response.aclose())
|
||||
except Exception:
|
||||
pass
|
||||
self._sse_response = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Message Handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _handle_envelope(self, envelope: dict) -> None:
|
||||
"""Process an incoming signal-cli envelope."""
|
||||
# Unwrap nested envelope if present
|
||||
envelope_data = envelope.get("envelope", envelope)
|
||||
|
||||
# Filter syncMessage envelopes (sent transcripts, read receipts, etc.)
|
||||
# signal-cli may set syncMessage to null vs omitting it, so check key existence
|
||||
if "syncMessage" in envelope_data:
|
||||
return
|
||||
|
||||
# Extract sender info
|
||||
sender = (
|
||||
envelope_data.get("sourceNumber")
|
||||
or envelope_data.get("sourceUuid")
|
||||
or envelope_data.get("source")
|
||||
)
|
||||
sender_name = envelope_data.get("sourceName", "")
|
||||
sender_uuid = envelope_data.get("sourceUuid", "")
|
||||
|
||||
if not sender:
|
||||
logger.debug("Signal: ignoring envelope with no sender")
|
||||
return
|
||||
|
||||
# Self-message filtering — prevent reply loops
|
||||
if self._account_normalized and sender == self._account_normalized:
|
||||
return
|
||||
|
||||
# Filter stories
|
||||
if self.ignore_stories and envelope_data.get("storyMessage"):
|
||||
return
|
||||
|
||||
# Get data message — also check editMessage (edited messages contain
|
||||
# their updated dataMessage inside editMessage.dataMessage)
|
||||
data_message = (
|
||||
envelope_data.get("dataMessage")
|
||||
or (envelope_data.get("editMessage") or {}).get("dataMessage")
|
||||
)
|
||||
if not data_message:
|
||||
return
|
||||
|
||||
# Check for group message
|
||||
group_info = data_message.get("groupInfo")
|
||||
group_id = group_info.get("groupId") if group_info else None
|
||||
is_group = bool(group_id)
|
||||
|
||||
# Group message filtering — derived from SIGNAL_GROUP_ALLOWED_USERS:
|
||||
# - No env var set → groups disabled (default safe behavior)
|
||||
# - Env var set with group IDs → only those groups allowed
|
||||
# - Env var set with "*" → all groups allowed
|
||||
# DM auth is fully handled by run.py (_is_user_authorized)
|
||||
if is_group:
|
||||
if not self.group_allow_from:
|
||||
logger.debug("Signal: ignoring group message (no SIGNAL_GROUP_ALLOWED_USERS)")
|
||||
return
|
||||
if "*" not in self.group_allow_from and group_id not in self.group_allow_from:
|
||||
logger.debug("Signal: group %s not in allowlist", group_id[:8] if group_id else "?")
|
||||
return
|
||||
|
||||
# Build chat info
|
||||
chat_id = sender if not is_group else f"group:{group_id}"
|
||||
chat_type = "group" if is_group else "dm"
|
||||
|
||||
# Extract text and render mentions
|
||||
text = data_message.get("message", "")
|
||||
mentions = data_message.get("mentions", [])
|
||||
if text and mentions:
|
||||
text = _render_mentions(text, mentions)
|
||||
|
||||
# Process attachments
|
||||
attachments_data = data_message.get("attachments", [])
|
||||
image_paths = []
|
||||
audio_path = None
|
||||
document_paths = []
|
||||
|
||||
if attachments_data and not getattr(self, "ignore_attachments", False):
|
||||
for att in attachments_data:
|
||||
att_id = att.get("id")
|
||||
att_size = att.get("size", 0)
|
||||
if not att_id:
|
||||
continue
|
||||
if att_size > SIGNAL_MAX_ATTACHMENT_SIZE:
|
||||
logger.warning("Signal: attachment too large (%d bytes), skipping", att_size)
|
||||
continue
|
||||
try:
|
||||
cached_path, ext = await self._fetch_attachment(att_id)
|
||||
if cached_path:
|
||||
if _is_image_ext(ext):
|
||||
image_paths.append(cached_path)
|
||||
elif _is_audio_ext(ext):
|
||||
audio_path = cached_path
|
||||
else:
|
||||
document_paths.append(cached_path)
|
||||
except Exception:
|
||||
logger.exception("Signal: failed to fetch attachment %s", att_id)
|
||||
|
||||
# Build session source
|
||||
source = self.build_source(
|
||||
chat_id=chat_id,
|
||||
chat_name=group_info.get("groupName") if group_info else sender_name,
|
||||
chat_type=chat_type,
|
||||
user_id=sender,
|
||||
user_name=sender_name or sender,
|
||||
user_id_alt=sender_uuid if sender_uuid else None,
|
||||
chat_id_alt=group_id if is_group else None,
|
||||
)
|
||||
|
||||
# Determine message type
|
||||
msg_type = MessageType.TEXT
|
||||
if audio_path:
|
||||
msg_type = MessageType.VOICE
|
||||
elif image_paths:
|
||||
msg_type = MessageType.IMAGE
|
||||
|
||||
# Parse timestamp from envelope data (milliseconds since epoch)
|
||||
ts_ms = envelope_data.get("timestamp", 0)
|
||||
if ts_ms:
|
||||
try:
|
||||
timestamp = datetime.fromtimestamp(ts_ms / 1000, tz=timezone.utc)
|
||||
except (ValueError, OSError):
|
||||
timestamp = datetime.now(tz=timezone.utc)
|
||||
else:
|
||||
timestamp = datetime.now(tz=timezone.utc)
|
||||
|
||||
# Build and dispatch event
|
||||
event = MessageEvent(
|
||||
source=source,
|
||||
text=text or "",
|
||||
message_type=msg_type,
|
||||
image_paths=image_paths,
|
||||
audio_path=audio_path,
|
||||
document_paths=document_paths,
|
||||
timestamp=timestamp,
|
||||
)
|
||||
|
||||
logger.debug("Signal: message from %s in %s: %s",
|
||||
_redact_phone(sender), chat_id[:20], (text or "")[:50])
|
||||
|
||||
await self.handle_message(event)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Attachment Handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _fetch_attachment(self, attachment_id: str) -> tuple:
|
||||
"""Fetch an attachment via JSON-RPC and cache it. Returns (path, ext)."""
|
||||
result = await self._rpc("getAttachment", {
|
||||
"account": self.account,
|
||||
"attachmentId": attachment_id,
|
||||
})
|
||||
|
||||
if not result:
|
||||
return None, ""
|
||||
|
||||
# Result is base64-encoded file content
|
||||
raw_data = base64.b64decode(result)
|
||||
ext = _guess_extension(raw_data)
|
||||
|
||||
if _is_image_ext(ext):
|
||||
path = cache_image_from_bytes(raw_data, ext)
|
||||
elif _is_audio_ext(ext):
|
||||
path = cache_audio_from_bytes(raw_data, ext)
|
||||
else:
|
||||
path = cache_document_from_bytes(raw_data, ext)
|
||||
|
||||
return path, ext
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# JSON-RPC Communication
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _rpc(self, method: str, params: dict, rpc_id: str = None) -> Any:
|
||||
"""Send a JSON-RPC 2.0 request to signal-cli daemon."""
|
||||
if not self.client:
|
||||
logger.warning("Signal: RPC called but client not connected")
|
||||
return None
|
||||
|
||||
if rpc_id is None:
|
||||
rpc_id = f"{method}_{int(time.time() * 1000)}"
|
||||
|
||||
payload = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"params": params,
|
||||
"id": rpc_id,
|
||||
}
|
||||
|
||||
try:
|
||||
resp = await self.client.post(
|
||||
f"{self.http_url}/api/v1/rpc",
|
||||
json=payload,
|
||||
timeout=30.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
if "error" in data:
|
||||
logger.warning("Signal RPC error (%s): %s", method, data["error"])
|
||||
return None
|
||||
|
||||
return data.get("result")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Signal RPC %s failed: %s", method, e)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Sending
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send(
|
||||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to_message_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a text message."""
|
||||
await self._stop_typing_indicator(chat_id)
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
"account": self.account,
|
||||
"message": text,
|
||||
}
|
||||
|
||||
if chat_id.startswith("group:"):
|
||||
params["groupId"] = chat_id[6:]
|
||||
else:
|
||||
params["recipient"] = [chat_id]
|
||||
|
||||
result = await self._rpc("send", params)
|
||||
|
||||
if result is not None:
|
||||
return SendResult(success=True)
|
||||
return SendResult(success=False, error="RPC send failed")
|
||||
|
||||
async def send_typing(self, chat_id: str) -> None:
|
||||
"""Send a typing indicator."""
|
||||
params: Dict[str, Any] = {
|
||||
"account": self.account,
|
||||
}
|
||||
|
||||
if chat_id.startswith("group:"):
|
||||
params["groupId"] = chat_id[6:]
|
||||
else:
|
||||
params["recipient"] = [chat_id]
|
||||
|
||||
await self._rpc("sendTyping", params, rpc_id="typing")
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_url: str,
|
||||
caption: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send an image. Supports http(s):// and file:// URLs."""
|
||||
await self._stop_typing_indicator(chat_id)
|
||||
|
||||
# Resolve image to local path
|
||||
if image_url.startswith("file://"):
|
||||
file_path = unquote(image_url[7:])
|
||||
else:
|
||||
# Download remote image to cache
|
||||
try:
|
||||
file_path = await cache_image_from_url(image_url)
|
||||
except Exception as e:
|
||||
logger.warning("Signal: failed to download image: %s", e)
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
if not file_path or not Path(file_path).exists():
|
||||
return SendResult(success=False, error="Image file not found")
|
||||
|
||||
# Validate size
|
||||
file_size = Path(file_path).stat().st_size
|
||||
if file_size > SIGNAL_MAX_ATTACHMENT_SIZE:
|
||||
return SendResult(success=False, error=f"Image too large ({file_size} bytes)")
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
"account": self.account,
|
||||
"message": caption or "",
|
||||
"attachments": [file_path],
|
||||
}
|
||||
|
||||
if chat_id.startswith("group:"):
|
||||
params["groupId"] = chat_id[6:]
|
||||
else:
|
||||
params["recipient"] = [chat_id]
|
||||
|
||||
result = await self._rpc("send", params)
|
||||
if result is not None:
|
||||
return SendResult(success=True)
|
||||
return SendResult(success=False, error="RPC send with attachment failed")
|
||||
|
||||
async def send_document(
|
||||
self,
|
||||
chat_id: str,
|
||||
file_path: str,
|
||||
caption: Optional[str] = None,
|
||||
filename: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a document/file attachment."""
|
||||
await self._stop_typing_indicator(chat_id)
|
||||
|
||||
if not Path(file_path).exists():
|
||||
return SendResult(success=False, error="File not found")
|
||||
|
||||
params: Dict[str, Any] = {
|
||||
"account": self.account,
|
||||
"message": caption or "",
|
||||
"attachments": [file_path],
|
||||
}
|
||||
|
||||
if chat_id.startswith("group:"):
|
||||
params["groupId"] = chat_id[6:]
|
||||
else:
|
||||
params["recipient"] = [chat_id]
|
||||
|
||||
result = await self._rpc("send", params)
|
||||
if result is not None:
|
||||
return SendResult(success=True)
|
||||
return SendResult(success=False, error="RPC send document failed")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Typing Indicators
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _start_typing_indicator(self, chat_id: str) -> None:
|
||||
"""Start a typing indicator loop for a chat."""
|
||||
if chat_id in self._typing_tasks:
|
||||
return # Already running
|
||||
|
||||
async def _typing_loop():
|
||||
try:
|
||||
while True:
|
||||
await self.send_typing(chat_id)
|
||||
await asyncio.sleep(TYPING_INTERVAL)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._typing_tasks[chat_id] = asyncio.create_task(_typing_loop())
|
||||
|
||||
async def _stop_typing_indicator(self, chat_id: str) -> None:
|
||||
"""Stop a typing indicator loop for a chat."""
|
||||
task = self._typing_tasks.pop(chat_id, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Chat Info
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Get information about a chat/contact."""
|
||||
if chat_id.startswith("group:"):
|
||||
return {
|
||||
"name": chat_id,
|
||||
"type": "group",
|
||||
"chat_id": chat_id,
|
||||
}
|
||||
|
||||
# Try to resolve contact name
|
||||
result = await self._rpc("getContact", {
|
||||
"account": self.account,
|
||||
"contactAddress": chat_id,
|
||||
})
|
||||
|
||||
name = chat_id
|
||||
if result and isinstance(result, dict):
|
||||
name = result.get("name") or result.get("profileName") or chat_id
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"type": "dm",
|
||||
"chat_id": chat_id,
|
||||
}
|
||||
@@ -179,6 +179,35 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
"""Slack doesn't have a direct typing indicator API for bots."""
|
||||
pass
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file to Slack by uploading it."""
|
||||
if not self._app:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import os
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
result = await self._app.client.files_upload_v2(
|
||||
channel=chat_id,
|
||||
file=image_path,
|
||||
filename=os.path.basename(image_path),
|
||||
initial_comment=caption or "",
|
||||
thread_ts=reply_to,
|
||||
)
|
||||
return SendResult(success=True, raw_response=result)
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send local image: {e}")
|
||||
return await super().send_image_file(chat_id, image_path, caption, reply_to)
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
|
||||
@@ -8,10 +8,13 @@ Uses python-telegram-bot library for:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from telegram import Update, Bot, Message
|
||||
from telegram.ext import (
|
||||
@@ -73,6 +76,19 @@ def _escape_mdv2(text: str) -> str:
|
||||
return _MDV2_ESCAPE_RE.sub(r'\\\1', text)
|
||||
|
||||
|
||||
def _strip_mdv2(text: str) -> str:
|
||||
"""Strip MarkdownV2 escape backslashes to produce clean plain text.
|
||||
|
||||
Also removes MarkdownV2 bold markers (*text* -> text) so the fallback
|
||||
doesn't show stray asterisks from header/bold conversion.
|
||||
"""
|
||||
# Remove escape backslashes before special characters
|
||||
cleaned = re.sub(r'\\([_*\[\]()~`>#\+\-=|{}.!\\])', r'\1', text)
|
||||
# Remove MarkdownV2 bold markers that format_message converted from **bold**
|
||||
cleaned = re.sub(r'\*([^*]+)\*', r'\1', cleaned)
|
||||
return cleaned
|
||||
|
||||
|
||||
class TelegramAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
Telegram bot adapter.
|
||||
@@ -139,6 +155,14 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
BotCommand("status", "Show session info"),
|
||||
BotCommand("stop", "Stop the running agent"),
|
||||
BotCommand("sethome", "Set this chat as the home channel"),
|
||||
BotCommand("compress", "Compress conversation context"),
|
||||
BotCommand("title", "Set or show the session title"),
|
||||
BotCommand("resume", "Resume a previously-named session"),
|
||||
BotCommand("usage", "Show token usage for this session"),
|
||||
BotCommand("provider", "Show available providers"),
|
||||
BotCommand("insights", "Show usage insights and analytics"),
|
||||
BotCommand("update", "Update Hermes to the latest version"),
|
||||
BotCommand("reload_mcp", "Reload MCP servers from config"),
|
||||
BotCommand("help", "Show available commands"),
|
||||
])
|
||||
except Exception as e:
|
||||
@@ -199,9 +223,13 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
except Exception as md_error:
|
||||
# Markdown parsing failed, try plain text
|
||||
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower():
|
||||
logger.warning("[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error)
|
||||
# Strip MDV2 escape backslashes so the user doesn't
|
||||
# see raw backslashes littered through the message.
|
||||
plain_chunk = _strip_mdv2(chunk)
|
||||
msg = await self._bot.send_message(
|
||||
chat_id=int(chat_id),
|
||||
text=chunk,
|
||||
text=plain_chunk,
|
||||
parse_mode=None, # Plain text
|
||||
reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
@@ -286,6 +314,34 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
print(f"[{self.name}] Failed to send voice/audio: {e}")
|
||||
return await super().send_voice(chat_id, audio_path, caption, reply_to)
|
||||
|
||||
async def send_image_file(
|
||||
self,
|
||||
chat_id: str,
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively as a Telegram photo."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
import os
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
with open(image_path, "rb") as image_file:
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
photo=image_file,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send local image: {e}")
|
||||
return await super().send_image_file(chat_id, image_path, caption, reply_to)
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -293,12 +349,16 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
) -> SendResult:
|
||||
"""Send an image natively as a Telegram photo."""
|
||||
"""Send an image natively as a Telegram photo.
|
||||
|
||||
Tries URL-based send first (fast, works for <5MB images).
|
||||
Falls back to downloading and uploading as file (supports up to 10MB).
|
||||
"""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
# Telegram can send photos directly from URLs
|
||||
# Telegram can send photos directly from URLs (up to ~5MB)
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
photo=image_url,
|
||||
@@ -307,9 +367,26 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Failed to send photo, falling back to URL: {e}")
|
||||
# Fallback: send as text link
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
logger.warning("[%s] URL-based send_photo failed (%s), trying file upload", self.name, e)
|
||||
# Fallback: download and upload as file (supports up to 10MB)
|
||||
try:
|
||||
import httpx
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.get(image_url)
|
||||
resp.raise_for_status()
|
||||
image_data = resp.content
|
||||
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
photo=image_data,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e2:
|
||||
logger.error("[%s] File upload send_photo also failed: %s", self.name, e2)
|
||||
# Final fallback: send URL as text
|
||||
return await super().send_image(chat_id, image_url, caption, reply_to)
|
||||
|
||||
async def send_animation(
|
||||
self,
|
||||
|
||||
745
gateway/run.py
745
gateway/run.py
@@ -75,6 +75,7 @@ if _config_path.exists():
|
||||
"container_memory": "TERMINAL_CONTAINER_MEMORY",
|
||||
"container_disk": "TERMINAL_CONTAINER_DISK",
|
||||
"container_persistent": "TERMINAL_CONTAINER_PERSISTENT",
|
||||
"sandbox_dir": "TERMINAL_SANDBOX_DIR",
|
||||
}
|
||||
for _cfg_key, _env_var in _terminal_env_map.items():
|
||||
if _cfg_key in _terminal_cfg:
|
||||
@@ -85,14 +86,44 @@ if _config_path.exists():
|
||||
"enabled": "CONTEXT_COMPRESSION_ENABLED",
|
||||
"threshold": "CONTEXT_COMPRESSION_THRESHOLD",
|
||||
"summary_model": "CONTEXT_COMPRESSION_MODEL",
|
||||
"summary_provider": "CONTEXT_COMPRESSION_PROVIDER",
|
||||
}
|
||||
for _cfg_key, _env_var in _compression_env_map.items():
|
||||
if _cfg_key in _compression_cfg:
|
||||
os.environ[_env_var] = str(_compression_cfg[_cfg_key])
|
||||
# Auxiliary model overrides (vision, web_extract).
|
||||
# Each task has provider + model; bridge non-default values to env vars.
|
||||
_auxiliary_cfg = _cfg.get("auxiliary", {})
|
||||
if _auxiliary_cfg and isinstance(_auxiliary_cfg, dict):
|
||||
_aux_task_env = {
|
||||
"vision": ("AUXILIARY_VISION_PROVIDER", "AUXILIARY_VISION_MODEL"),
|
||||
"web_extract": ("AUXILIARY_WEB_EXTRACT_PROVIDER", "AUXILIARY_WEB_EXTRACT_MODEL"),
|
||||
}
|
||||
for _task_key, (_prov_env, _model_env) in _aux_task_env.items():
|
||||
_task_cfg = _auxiliary_cfg.get(_task_key, {})
|
||||
if not isinstance(_task_cfg, dict):
|
||||
continue
|
||||
_prov = str(_task_cfg.get("provider", "")).strip()
|
||||
_model = str(_task_cfg.get("model", "")).strip()
|
||||
if _prov and _prov != "auto":
|
||||
os.environ[_prov_env] = _prov
|
||||
if _model:
|
||||
os.environ[_model_env] = _model
|
||||
_agent_cfg = _cfg.get("agent", {})
|
||||
if _agent_cfg and isinstance(_agent_cfg, dict):
|
||||
if "max_turns" in _agent_cfg:
|
||||
os.environ["HERMES_MAX_ITERATIONS"] = str(_agent_cfg["max_turns"])
|
||||
# Timezone: bridge config.yaml → HERMES_TIMEZONE env var.
|
||||
# HERMES_TIMEZONE from .env takes precedence (already in os.environ).
|
||||
_tz_cfg = _cfg.get("timezone", "")
|
||||
if _tz_cfg and isinstance(_tz_cfg, str) and "HERMES_TIMEZONE" not in os.environ:
|
||||
os.environ["HERMES_TIMEZONE"] = _tz_cfg.strip()
|
||||
# Security settings
|
||||
_security_cfg = _cfg.get("security", {})
|
||||
if isinstance(_security_cfg, dict):
|
||||
_redact = _security_cfg.get("redact_secrets")
|
||||
if _redact is not None:
|
||||
os.environ["HERMES_REDACT_SECRETS"] = str(_redact).lower()
|
||||
except Exception:
|
||||
pass # Non-fatal; gateway can still run with .env values
|
||||
|
||||
@@ -102,11 +133,13 @@ os.environ["HERMES_QUIET"] = "1"
|
||||
# Enable interactive exec approval for dangerous commands on messaging platforms
|
||||
os.environ["HERMES_EXEC_ASK"] = "1"
|
||||
|
||||
# Set terminal working directory for messaging platforms
|
||||
# Uses MESSAGING_CWD if set, otherwise defaults to home directory
|
||||
# This is separate from CLI which uses the directory where `hermes` is run
|
||||
messaging_cwd = os.getenv("MESSAGING_CWD") or str(Path.home())
|
||||
os.environ["TERMINAL_CWD"] = messaging_cwd
|
||||
# Set terminal working directory for messaging platforms.
|
||||
# If the user set an explicit path in config.yaml (not "." or "auto"),
|
||||
# respect it. Otherwise use MESSAGING_CWD or default to home directory.
|
||||
_configured_cwd = os.environ.get("TERMINAL_CWD", "")
|
||||
if not _configured_cwd or _configured_cwd in (".", "auto", "cwd"):
|
||||
messaging_cwd = os.getenv("MESSAGING_CWD") or str(Path.home())
|
||||
os.environ["TERMINAL_CWD"] = messaging_cwd
|
||||
|
||||
from gateway.config import (
|
||||
Platform,
|
||||
@@ -167,13 +200,13 @@ class GatewayRunner:
|
||||
self._ephemeral_system_prompt = self._load_ephemeral_system_prompt()
|
||||
self._reasoning_config = self._load_reasoning_config()
|
||||
self._provider_routing = self._load_provider_routing()
|
||||
self._fallback_model = self._load_fallback_model()
|
||||
|
||||
# Wire process registry into session store for reset protection
|
||||
from tools.process_registry import process_registry
|
||||
self.session_store = SessionStore(
|
||||
self.config.sessions_dir, self.config,
|
||||
has_active_processes_fn=lambda key: process_registry.has_active_for_session(key),
|
||||
on_auto_reset=self._flush_memories_before_reset,
|
||||
)
|
||||
self.delivery_router = DeliveryRouter(self.config)
|
||||
self._running = False
|
||||
@@ -204,15 +237,14 @@ class GatewayRunner:
|
||||
from gateway.hooks import HookRegistry
|
||||
self.hooks = HookRegistry()
|
||||
|
||||
def _flush_memories_before_reset(self, old_entry):
|
||||
"""Prompt the agent to save memories/skills before an auto-reset.
|
||||
|
||||
Called synchronously by SessionStore before destroying an expired session.
|
||||
Loads the transcript, gives the agent a real turn with memory + skills
|
||||
tools, and explicitly asks it to preserve anything worth keeping.
|
||||
def _flush_memories_for_session(self, old_session_id: str):
|
||||
"""Prompt the agent to save memories/skills before context is lost.
|
||||
|
||||
Synchronous worker — meant to be called via run_in_executor from
|
||||
an async context so it doesn't block the event loop.
|
||||
"""
|
||||
try:
|
||||
history = self.session_store.load_transcript(old_entry.session_id)
|
||||
history = self.session_store.load_transcript(old_session_id)
|
||||
if not history or len(history) < 4:
|
||||
return
|
||||
|
||||
@@ -226,7 +258,7 @@ class GatewayRunner:
|
||||
max_iterations=8,
|
||||
quiet_mode=True,
|
||||
enabled_toolsets=["memory", "skills"],
|
||||
session_id=old_entry.session_id,
|
||||
session_id=old_session_id,
|
||||
)
|
||||
|
||||
# Build conversation history from transcript
|
||||
@@ -255,9 +287,14 @@ class GatewayRunner:
|
||||
user_message=flush_prompt,
|
||||
conversation_history=msgs,
|
||||
)
|
||||
logger.info("Pre-reset save completed for session %s", old_entry.session_id)
|
||||
logger.info("Pre-reset memory flush completed for session %s", old_session_id)
|
||||
except Exception as e:
|
||||
logger.debug("Pre-reset save failed for session %s: %s", old_entry.session_id, e)
|
||||
logger.debug("Pre-reset memory flush failed for session %s: %s", old_session_id, e)
|
||||
|
||||
async def _async_flush_memories(self, old_session_id: str):
|
||||
"""Run the sync memory flush in a thread pool so it won't block the event loop."""
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, self._flush_memories_for_session, old_session_id)
|
||||
|
||||
@staticmethod
|
||||
def _load_prefill_messages() -> List[Dict[str, Any]]:
|
||||
@@ -325,7 +362,7 @@ class GatewayRunner:
|
||||
|
||||
Checks HERMES_REASONING_EFFORT env var first, then agent.reasoning_effort
|
||||
in config.yaml. Valid: "xhigh", "high", "medium", "low", "minimal", "none".
|
||||
Returns None to use default (xhigh).
|
||||
Returns None to use default (medium).
|
||||
"""
|
||||
effort = os.getenv("HERMES_REASONING_EFFORT", "")
|
||||
if not effort:
|
||||
@@ -346,7 +383,7 @@ class GatewayRunner:
|
||||
valid = ("xhigh", "high", "medium", "low", "minimal")
|
||||
if effort in valid:
|
||||
return {"enabled": True, "effort": effort}
|
||||
logger.warning("Unknown reasoning_effort '%s', using default (xhigh)", effort)
|
||||
logger.warning("Unknown reasoning_effort '%s', using default (medium)", effort)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
@@ -363,6 +400,26 @@ class GatewayRunner:
|
||||
pass
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _load_fallback_model() -> dict | None:
|
||||
"""Load fallback model config from config.yaml.
|
||||
|
||||
Returns a dict with 'provider' and 'model' keys, or None if
|
||||
not configured / both fields empty.
|
||||
"""
|
||||
try:
|
||||
import yaml as _y
|
||||
cfg_path = _hermes_home / "config.yaml"
|
||||
if cfg_path.exists():
|
||||
with open(cfg_path) as _f:
|
||||
cfg = _y.safe_load(_f) or {}
|
||||
fb = cfg.get("fallback_model", {}) or {}
|
||||
if fb.get("provider") and fb.get("model"):
|
||||
return fb
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
async def start(self) -> bool:
|
||||
"""
|
||||
Start the gateway and all configured platform adapters.
|
||||
@@ -459,10 +516,50 @@ class GatewayRunner:
|
||||
# Check if we're restarting after a /update command
|
||||
await self._send_update_notification()
|
||||
|
||||
# Start background session expiry watcher for proactive memory flushing
|
||||
asyncio.create_task(self._session_expiry_watcher())
|
||||
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
return True
|
||||
|
||||
async def _session_expiry_watcher(self, interval: int = 300):
|
||||
"""Background task that proactively flushes memories for expired sessions.
|
||||
|
||||
Runs every `interval` seconds (default 5 min). For each session that
|
||||
has expired according to its reset policy, flushes memories in a thread
|
||||
pool and marks the session so it won't be flushed again.
|
||||
|
||||
This means memories are already saved by the time the user sends their
|
||||
next message, so there's no blocking delay.
|
||||
"""
|
||||
await asyncio.sleep(60) # initial delay — let the gateway fully start
|
||||
while self._running:
|
||||
try:
|
||||
self.session_store._ensure_loaded()
|
||||
for key, entry in list(self.session_store._entries.items()):
|
||||
if entry.session_id in self.session_store._pre_flushed_sessions:
|
||||
continue # already flushed this session
|
||||
if not self.session_store._is_session_expired(entry):
|
||||
continue # session still active
|
||||
# Session has expired — flush memories in the background
|
||||
logger.info(
|
||||
"Session %s expired (key=%s), flushing memories proactively",
|
||||
entry.session_id, key,
|
||||
)
|
||||
try:
|
||||
await self._async_flush_memories(entry.session_id)
|
||||
self.session_store._pre_flushed_sessions.add(entry.session_id)
|
||||
except Exception as e:
|
||||
logger.debug("Proactive memory flush failed for %s: %s", entry.session_id, e)
|
||||
except Exception as e:
|
||||
logger.debug("Session expiry watcher error: %s", e)
|
||||
# Sleep in small increments so we can stop quickly
|
||||
for _ in range(interval):
|
||||
if not self._running:
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the gateway and disconnect all adapters."""
|
||||
logger.info("Stopping gateway...")
|
||||
@@ -521,6 +618,13 @@ class GatewayRunner:
|
||||
return None
|
||||
return SlackAdapter(config)
|
||||
|
||||
elif platform == Platform.SIGNAL:
|
||||
from gateway.platforms.signal import SignalAdapter, check_signal_requirements
|
||||
if not check_signal_requirements():
|
||||
logger.warning("Signal: SIGNAL_HTTP_URL or SIGNAL_ACCOUNT not configured")
|
||||
return None
|
||||
return SignalAdapter(config)
|
||||
|
||||
elif platform == Platform.HOMEASSISTANT:
|
||||
from gateway.platforms.homeassistant import HomeAssistantAdapter, check_ha_requirements
|
||||
if not check_ha_requirements():
|
||||
@@ -556,12 +660,14 @@ class GatewayRunner:
|
||||
Platform.DISCORD: "DISCORD_ALLOWED_USERS",
|
||||
Platform.WHATSAPP: "WHATSAPP_ALLOWED_USERS",
|
||||
Platform.SLACK: "SLACK_ALLOWED_USERS",
|
||||
Platform.SIGNAL: "SIGNAL_ALLOWED_USERS",
|
||||
}
|
||||
platform_allow_all_map = {
|
||||
Platform.TELEGRAM: "TELEGRAM_ALLOW_ALL_USERS",
|
||||
Platform.DISCORD: "DISCORD_ALLOW_ALL_USERS",
|
||||
Platform.WHATSAPP: "WHATSAPP_ALLOW_ALL_USERS",
|
||||
Platform.SLACK: "SLACK_ALLOW_ALL_USERS",
|
||||
Platform.SIGNAL: "SIGNAL_ALLOW_ALL_USERS",
|
||||
}
|
||||
|
||||
# Per-platform allow-all flag (e.g., DISCORD_ALLOW_ALL_USERS=true)
|
||||
@@ -659,7 +765,8 @@ class GatewayRunner:
|
||||
# Emit command:* hook for any recognized slash command
|
||||
_known_commands = {"new", "reset", "help", "status", "stop", "model",
|
||||
"personality", "retry", "undo", "sethome", "set-home",
|
||||
"compress", "usage", "insights", "reload-mcp", "update"}
|
||||
"compress", "usage", "insights", "reload-mcp", "reload_mcp",
|
||||
"update", "title", "resume", "provider"}
|
||||
if command and command in _known_commands:
|
||||
await self.hooks.emit(f"command:{command}", {
|
||||
"platform": source.platform.value if source.platform else "",
|
||||
@@ -683,6 +790,9 @@ class GatewayRunner:
|
||||
if command == "model":
|
||||
return await self._handle_model_command(event)
|
||||
|
||||
if command == "provider":
|
||||
return await self._handle_provider_command(event)
|
||||
|
||||
if command == "personality":
|
||||
return await self._handle_personality_command(event)
|
||||
|
||||
@@ -704,11 +814,17 @@ class GatewayRunner:
|
||||
if command == "insights":
|
||||
return await self._handle_insights_command(event)
|
||||
|
||||
if command == "reload-mcp":
|
||||
if command in ("reload-mcp", "reload_mcp"):
|
||||
return await self._handle_reload_mcp_command(event)
|
||||
|
||||
if command == "update":
|
||||
return await self._handle_update_command(event)
|
||||
|
||||
if command == "title":
|
||||
return await self._handle_title_command(event)
|
||||
|
||||
if command == "resume":
|
||||
return await self._handle_resume_command(event)
|
||||
|
||||
# Skill slash commands: /skill-name loads the skill and sends to agent
|
||||
if command:
|
||||
@@ -783,6 +899,195 @@ class GatewayRunner:
|
||||
# Load conversation history from transcript
|
||||
history = self.session_store.load_transcript(session_entry.session_id)
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Session hygiene: auto-compress pathologically large transcripts
|
||||
#
|
||||
# Long-lived gateway sessions can accumulate enough history that
|
||||
# every new message rehydrates an oversized transcript, causing
|
||||
# repeated truncation/context failures. Detect this early and
|
||||
# compress proactively — before the agent even starts. (#628)
|
||||
#
|
||||
# Thresholds are derived from the SAME compression config the
|
||||
# agent uses (compression.threshold × model context length) so
|
||||
# CLI and messaging platforms behave identically.
|
||||
# -----------------------------------------------------------------
|
||||
if history and len(history) >= 4:
|
||||
from agent.model_metadata import (
|
||||
estimate_messages_tokens_rough,
|
||||
get_model_context_length,
|
||||
)
|
||||
|
||||
# Read model + compression config from config.yaml — same
|
||||
# source of truth the agent itself uses.
|
||||
_hyg_model = "anthropic/claude-sonnet-4.6"
|
||||
_hyg_threshold_pct = 0.85
|
||||
_hyg_compression_enabled = True
|
||||
try:
|
||||
_hyg_cfg_path = _hermes_home / "config.yaml"
|
||||
if _hyg_cfg_path.exists():
|
||||
import yaml as _hyg_yaml
|
||||
with open(_hyg_cfg_path) as _hyg_f:
|
||||
_hyg_data = _hyg_yaml.safe_load(_hyg_f) or {}
|
||||
|
||||
# Resolve model name (same logic as run_sync)
|
||||
_model_cfg = _hyg_data.get("model", {})
|
||||
if isinstance(_model_cfg, str):
|
||||
_hyg_model = _model_cfg
|
||||
elif isinstance(_model_cfg, dict):
|
||||
_hyg_model = _model_cfg.get("default", _hyg_model)
|
||||
|
||||
# Read compression settings
|
||||
_comp_cfg = _hyg_data.get("compression", {})
|
||||
if isinstance(_comp_cfg, dict):
|
||||
_hyg_threshold_pct = float(
|
||||
_comp_cfg.get("threshold", _hyg_threshold_pct)
|
||||
)
|
||||
_hyg_compression_enabled = str(
|
||||
_comp_cfg.get("enabled", True)
|
||||
).lower() in ("true", "1", "yes")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Also check env overrides (same as run_agent.py)
|
||||
_hyg_threshold_pct = float(
|
||||
os.getenv("CONTEXT_COMPRESSION_THRESHOLD", str(_hyg_threshold_pct))
|
||||
)
|
||||
if os.getenv("CONTEXT_COMPRESSION_ENABLED", "").lower() in ("false", "0", "no"):
|
||||
_hyg_compression_enabled = False
|
||||
|
||||
if _hyg_compression_enabled:
|
||||
_hyg_context_length = get_model_context_length(_hyg_model)
|
||||
_compress_token_threshold = int(
|
||||
_hyg_context_length * _hyg_threshold_pct
|
||||
)
|
||||
# Warn if still huge after compression (95% of context)
|
||||
_warn_token_threshold = int(_hyg_context_length * 0.95)
|
||||
|
||||
_msg_count = len(history)
|
||||
_approx_tokens = estimate_messages_tokens_rough(history)
|
||||
|
||||
_needs_compress = _approx_tokens >= _compress_token_threshold
|
||||
|
||||
if _needs_compress:
|
||||
logger.info(
|
||||
"Session hygiene: %s messages, ~%s tokens — auto-compressing "
|
||||
"(threshold: %s%% of %s = %s tokens)",
|
||||
_msg_count, f"{_approx_tokens:,}",
|
||||
int(_hyg_threshold_pct * 100),
|
||||
f"{_hyg_context_length:,}",
|
||||
f"{_compress_token_threshold:,}",
|
||||
)
|
||||
|
||||
_hyg_adapter = self.adapters.get(source.platform)
|
||||
if _hyg_adapter:
|
||||
try:
|
||||
await _hyg_adapter.send(
|
||||
source.chat_id,
|
||||
f"🗜️ Session is large ({_msg_count} messages, "
|
||||
f"~{_approx_tokens:,} tokens). Auto-compressing..."
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
from run_agent import AIAgent
|
||||
|
||||
_hyg_runtime = _resolve_runtime_agent_kwargs()
|
||||
if _hyg_runtime.get("api_key"):
|
||||
_hyg_msgs = [
|
||||
{"role": m.get("role"), "content": m.get("content")}
|
||||
for m in history
|
||||
if m.get("role") in ("user", "assistant")
|
||||
and m.get("content")
|
||||
]
|
||||
|
||||
if len(_hyg_msgs) >= 4:
|
||||
_hyg_agent = AIAgent(
|
||||
**_hyg_runtime,
|
||||
max_iterations=4,
|
||||
quiet_mode=True,
|
||||
enabled_toolsets=["memory"],
|
||||
session_id=session_entry.session_id,
|
||||
)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
_compressed, _ = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: _hyg_agent._compress_context(
|
||||
_hyg_msgs, "",
|
||||
approx_tokens=_approx_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
self.session_store.rewrite_transcript(
|
||||
session_entry.session_id, _compressed
|
||||
)
|
||||
history = _compressed
|
||||
_new_count = len(_compressed)
|
||||
_new_tokens = estimate_messages_tokens_rough(
|
||||
_compressed
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Session hygiene: compressed %s → %s msgs, "
|
||||
"~%s → ~%s tokens",
|
||||
_msg_count, _new_count,
|
||||
f"{_approx_tokens:,}", f"{_new_tokens:,}",
|
||||
)
|
||||
|
||||
if _hyg_adapter:
|
||||
try:
|
||||
await _hyg_adapter.send(
|
||||
source.chat_id,
|
||||
f"🗜️ Compressed: {_msg_count} → "
|
||||
f"{_new_count} messages, "
|
||||
f"~{_approx_tokens:,} → "
|
||||
f"~{_new_tokens:,} tokens"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Still too large after compression — warn user
|
||||
if _new_tokens >= _warn_token_threshold:
|
||||
logger.warning(
|
||||
"Session hygiene: still ~%s tokens after "
|
||||
"compression — suggesting /reset",
|
||||
f"{_new_tokens:,}",
|
||||
)
|
||||
if _hyg_adapter:
|
||||
try:
|
||||
await _hyg_adapter.send(
|
||||
source.chat_id,
|
||||
"⚠️ Session is still very large "
|
||||
"after compression "
|
||||
f"(~{_new_tokens:,} tokens). "
|
||||
"Consider using /reset to start "
|
||||
"fresh if you experience issues."
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Session hygiene auto-compress failed: %s", e
|
||||
)
|
||||
# Compression failed and session is dangerously large
|
||||
if _approx_tokens >= _warn_token_threshold:
|
||||
_hyg_adapter = self.adapters.get(source.platform)
|
||||
if _hyg_adapter:
|
||||
try:
|
||||
await _hyg_adapter.send(
|
||||
source.chat_id,
|
||||
f"⚠️ Session is very large "
|
||||
f"({_msg_count} messages, "
|
||||
f"~{_approx_tokens:,} tokens) and "
|
||||
"auto-compression failed. Consider "
|
||||
"using /compress or /reset to avoid "
|
||||
"issues."
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# First-message onboarding -- only on the very first interaction ever
|
||||
if not history and not self.session_store.has_any_sessions():
|
||||
context_prompt += (
|
||||
@@ -1007,33 +1312,12 @@ class GatewayRunner:
|
||||
# Get existing session key
|
||||
session_key = self.session_store._generate_session_key(source)
|
||||
|
||||
# Memory flush before reset: load the old transcript and let a
|
||||
# temporary agent save memories before the session is wiped.
|
||||
# Flush memories in the background (fire-and-forget) so the user
|
||||
# gets the "Session reset!" response immediately.
|
||||
try:
|
||||
old_entry = self.session_store._entries.get(session_key)
|
||||
if old_entry:
|
||||
old_history = self.session_store.load_transcript(old_entry.session_id)
|
||||
if old_history:
|
||||
from run_agent import AIAgent
|
||||
loop = asyncio.get_event_loop()
|
||||
_flush_kwargs = _resolve_runtime_agent_kwargs()
|
||||
def _do_flush():
|
||||
tmp_agent = AIAgent(
|
||||
**_flush_kwargs,
|
||||
max_iterations=5,
|
||||
quiet_mode=True,
|
||||
enabled_toolsets=["memory"],
|
||||
session_id=old_entry.session_id,
|
||||
)
|
||||
# Build simple message list from transcript
|
||||
msgs = []
|
||||
for m in old_history:
|
||||
role = m.get("role")
|
||||
content = m.get("content")
|
||||
if role in ("user", "assistant") and content:
|
||||
msgs.append({"role": role, "content": content})
|
||||
tmp_agent.flush_memories(msgs)
|
||||
await loop.run_in_executor(None, _do_flush)
|
||||
asyncio.create_task(self._async_flush_memories(old_entry.session_id))
|
||||
except Exception as e:
|
||||
logger.debug("Gateway memory flush on reset failed: %s", e)
|
||||
|
||||
@@ -1100,12 +1384,15 @@ class GatewayRunner:
|
||||
"`/reset` — Reset conversation history",
|
||||
"`/status` — Show session info",
|
||||
"`/stop` — Interrupt the running agent",
|
||||
"`/model [name]` — Show or change the model",
|
||||
"`/model [provider:model]` — Show/change model (or switch provider)",
|
||||
"`/provider` — Show available providers and auth status",
|
||||
"`/personality [name]` — Set a personality",
|
||||
"`/retry` — Retry your last message",
|
||||
"`/undo` — Remove the last exchange",
|
||||
"`/sethome` — Set this chat as the home channel",
|
||||
"`/compress` — Compress conversation context",
|
||||
"`/title [name]` — Set or show the session title",
|
||||
"`/resume [name]` — Resume a previously-named session",
|
||||
"`/usage` — Show token usage for this session",
|
||||
"`/insights [days]` — Show usage insights and analytics",
|
||||
"`/reload-mcp` — Reload MCP servers from config",
|
||||
@@ -1126,13 +1413,20 @@ class GatewayRunner:
|
||||
async def _handle_model_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /model command - show or change the current model."""
|
||||
import yaml
|
||||
from hermes_cli.models import (
|
||||
parse_model_input,
|
||||
validate_requested_model,
|
||||
curated_models_for_provider,
|
||||
normalize_provider,
|
||||
_PROVIDER_LABELS,
|
||||
)
|
||||
|
||||
args = event.get_command_args().strip()
|
||||
config_path = _hermes_home / 'config.yaml'
|
||||
|
||||
# Resolve current model the same way the agent init does:
|
||||
# env vars first, then config.yaml always overrides.
|
||||
# Resolve current model and provider from config
|
||||
current = os.getenv("HERMES_MODEL") or os.getenv("LLM_MODEL") or "anthropic/claude-opus-4.6"
|
||||
current_provider = "openrouter"
|
||||
try:
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
@@ -1142,39 +1436,164 @@ class GatewayRunner:
|
||||
current = model_cfg
|
||||
elif isinstance(model_cfg, dict):
|
||||
current = model_cfg.get("default", current)
|
||||
current_provider = model_cfg.get("provider", current_provider)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Resolve "auto" to the actual provider using credential detection
|
||||
current_provider = normalize_provider(current_provider)
|
||||
if current_provider == "auto":
|
||||
try:
|
||||
from hermes_cli.auth import resolve_provider as _resolve_provider
|
||||
current_provider = _resolve_provider(current_provider)
|
||||
except Exception:
|
||||
current_provider = "openrouter"
|
||||
|
||||
if not args:
|
||||
return f"🤖 **Current model:** `{current}`\n\nTo change: `/model provider/model-name`"
|
||||
provider_label = _PROVIDER_LABELS.get(current_provider, current_provider)
|
||||
lines = [
|
||||
f"🤖 **Current model:** `{current}`",
|
||||
f"**Provider:** {provider_label}",
|
||||
"",
|
||||
]
|
||||
curated = curated_models_for_provider(current_provider)
|
||||
if curated:
|
||||
lines.append(f"**Available models ({provider_label}):**")
|
||||
for mid, desc in curated:
|
||||
marker = " ←" if mid == current else ""
|
||||
label = f" _{desc}_" if desc else ""
|
||||
lines.append(f"• `{mid}`{label}{marker}")
|
||||
lines.append("")
|
||||
lines.append("To change: `/model model-name`")
|
||||
lines.append("Switch provider: `/model provider:model-name`")
|
||||
return "\n".join(lines)
|
||||
|
||||
if "/" not in args:
|
||||
return (
|
||||
f"🤖 Invalid model format: `{args}`\n\n"
|
||||
f"Use `provider/model-name` format, e.g.:\n"
|
||||
f"• `anthropic/claude-sonnet-4`\n"
|
||||
f"• `google/gemini-2.5-pro`\n"
|
||||
f"• `openai/gpt-4o`"
|
||||
)
|
||||
# Parse provider:model syntax
|
||||
target_provider, new_model = parse_model_input(args, current_provider)
|
||||
provider_changed = target_provider != current_provider
|
||||
|
||||
# Write to config.yaml (source of truth), same pattern as CLI save_config_value.
|
||||
# Resolve credentials for the target provider (for API probe)
|
||||
api_key = os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
|
||||
base_url = "https://openrouter.ai/api/v1"
|
||||
if provider_changed:
|
||||
try:
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
runtime = resolve_runtime_provider(requested=target_provider)
|
||||
api_key = runtime.get("api_key", "")
|
||||
base_url = runtime.get("base_url", "")
|
||||
except Exception as e:
|
||||
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
|
||||
return f"⚠️ Could not resolve credentials for provider '{provider_label}': {e}"
|
||||
else:
|
||||
# Use current provider's base_url from config or registry
|
||||
try:
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
runtime = resolve_runtime_provider(requested=current_provider)
|
||||
api_key = runtime.get("api_key", "")
|
||||
base_url = runtime.get("base_url", "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Validate the model against the live API
|
||||
try:
|
||||
validation = validate_requested_model(
|
||||
new_model,
|
||||
target_provider,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
except Exception:
|
||||
validation = {"accepted": True, "persist": True, "recognized": False, "message": None}
|
||||
|
||||
if not validation.get("accepted"):
|
||||
msg = validation.get("message", "Invalid model")
|
||||
tip = "\n\nUse `/model` to see available models, `/provider` to see providers" if "Did you mean" not in msg else ""
|
||||
return f"⚠️ {msg}{tip}"
|
||||
|
||||
# Persist to config only if validation approves
|
||||
if validation.get("persist"):
|
||||
try:
|
||||
user_config = {}
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
user_config = yaml.safe_load(f) or {}
|
||||
if "model" not in user_config or not isinstance(user_config["model"], dict):
|
||||
user_config["model"] = {}
|
||||
user_config["model"]["default"] = new_model
|
||||
if provider_changed:
|
||||
user_config["model"]["provider"] = target_provider
|
||||
with open(config_path, 'w') as f:
|
||||
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
|
||||
except Exception as e:
|
||||
return f"⚠️ Failed to save model change: {e}"
|
||||
|
||||
# Set env vars so the next agent run picks up the change
|
||||
os.environ["HERMES_MODEL"] = new_model
|
||||
if provider_changed:
|
||||
os.environ["HERMES_INFERENCE_PROVIDER"] = target_provider
|
||||
|
||||
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
|
||||
provider_note = f"\n**Provider:** {provider_label}" if provider_changed else ""
|
||||
|
||||
warning = ""
|
||||
if validation.get("message"):
|
||||
warning = f"\n⚠️ {validation['message']}"
|
||||
|
||||
if validation.get("persist"):
|
||||
persist_note = "saved to config"
|
||||
else:
|
||||
persist_note = "this session only — will revert on restart"
|
||||
return f"🤖 Model changed to `{new_model}` ({persist_note}){provider_note}{warning}\n_(takes effect on next message)_"
|
||||
|
||||
async def _handle_provider_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /provider command - show available providers."""
|
||||
import yaml
|
||||
from hermes_cli.models import (
|
||||
list_available_providers,
|
||||
normalize_provider,
|
||||
_PROVIDER_LABELS,
|
||||
)
|
||||
|
||||
# Resolve current provider from config
|
||||
current_provider = "openrouter"
|
||||
config_path = _hermes_home / 'config.yaml'
|
||||
try:
|
||||
user_config = {}
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
user_config = yaml.safe_load(f) or {}
|
||||
if "model" not in user_config or not isinstance(user_config["model"], dict):
|
||||
user_config["model"] = {}
|
||||
user_config["model"]["default"] = args
|
||||
with open(config_path, 'w') as f:
|
||||
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
|
||||
except Exception as e:
|
||||
return f"⚠️ Failed to save model change: {e}"
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
model_cfg = cfg.get("model", {})
|
||||
if isinstance(model_cfg, dict):
|
||||
current_provider = model_cfg.get("provider", current_provider)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Also set env var so code reading it before the next agent init sees the update.
|
||||
os.environ["HERMES_MODEL"] = args
|
||||
current_provider = normalize_provider(current_provider)
|
||||
if current_provider == "auto":
|
||||
try:
|
||||
from hermes_cli.auth import resolve_provider as _resolve_provider
|
||||
current_provider = _resolve_provider(current_provider)
|
||||
except Exception:
|
||||
current_provider = "openrouter"
|
||||
|
||||
return f"🤖 Model changed to `{args}`\n_(takes effect on next message)_"
|
||||
current_label = _PROVIDER_LABELS.get(current_provider, current_provider)
|
||||
|
||||
lines = [
|
||||
f"🔌 **Current provider:** {current_label} (`{current_provider}`)",
|
||||
"",
|
||||
"**Available providers:**",
|
||||
]
|
||||
|
||||
providers = list_available_providers()
|
||||
for p in providers:
|
||||
marker = " ← active" if p["id"] == current_provider else ""
|
||||
auth = "✅" if p["authenticated"] else "❌"
|
||||
aliases = f" _(also: {', '.join(p['aliases'])})_" if p["aliases"] else ""
|
||||
lines.append(f"{auth} `{p['id']}` — {p['label']}{aliases}{marker}")
|
||||
|
||||
lines.append("")
|
||||
lines.append("Switch: `/model provider:model-name`")
|
||||
lines.append("Setup: `hermes setup`")
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _handle_personality_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /personality command - list or set a personality."""
|
||||
@@ -1364,6 +1783,113 @@ class GatewayRunner:
|
||||
logger.warning("Manual compress failed: %s", e)
|
||||
return f"Compression failed: {e}"
|
||||
|
||||
async def _handle_title_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /title command — set or show the current session's title."""
|
||||
source = event.source
|
||||
session_entry = self.session_store.get_or_create_session(source)
|
||||
session_id = session_entry.session_id
|
||||
|
||||
if not self._session_db:
|
||||
return "Session database not available."
|
||||
|
||||
title_arg = event.get_command_args().strip()
|
||||
if title_arg:
|
||||
# Sanitize the title before setting
|
||||
try:
|
||||
sanitized = self._session_db.sanitize_title(title_arg)
|
||||
except ValueError as e:
|
||||
return f"⚠️ {e}"
|
||||
if not sanitized:
|
||||
return "⚠️ Title is empty after cleanup. Please use printable characters."
|
||||
# Set the title
|
||||
try:
|
||||
if self._session_db.set_session_title(session_id, sanitized):
|
||||
return f"✏️ Session title set: **{sanitized}**"
|
||||
else:
|
||||
return "Session not found in database."
|
||||
except ValueError as e:
|
||||
return f"⚠️ {e}"
|
||||
else:
|
||||
# Show the current title
|
||||
title = self._session_db.get_session_title(session_id)
|
||||
if title:
|
||||
return f"📌 Session title: **{title}**"
|
||||
else:
|
||||
return "No title set. Usage: `/title My Session Name`"
|
||||
|
||||
async def _handle_resume_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /resume command — switch to a previously-named session."""
|
||||
if not self._session_db:
|
||||
return "Session database not available."
|
||||
|
||||
source = event.source
|
||||
session_key = build_session_key(source)
|
||||
name = event.get_command_args().strip()
|
||||
|
||||
if not name:
|
||||
# List recent titled sessions for this user/platform
|
||||
try:
|
||||
user_source = source.platform.value if source.platform else None
|
||||
sessions = self._session_db.list_sessions_rich(
|
||||
source=user_source, limit=10
|
||||
)
|
||||
titled = [s for s in sessions if s.get("title")]
|
||||
if not titled:
|
||||
return (
|
||||
"No named sessions found.\n"
|
||||
"Use `/title My Session` to name your current session, "
|
||||
"then `/resume My Session` to return to it later."
|
||||
)
|
||||
lines = ["📋 **Named Sessions**\n"]
|
||||
for s in titled[:10]:
|
||||
title = s["title"]
|
||||
preview = s.get("preview", "")[:40]
|
||||
preview_part = f" — _{preview}_" if preview else ""
|
||||
lines.append(f"• **{title}**{preview_part}")
|
||||
lines.append("\nUsage: `/resume <session name>`")
|
||||
return "\n".join(lines)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to list titled sessions: %s", e)
|
||||
return f"Could not list sessions: {e}"
|
||||
|
||||
# Resolve the name to a session ID
|
||||
target_id = self._session_db.resolve_session_by_title(name)
|
||||
if not target_id:
|
||||
return (
|
||||
f"No session found matching '**{name}**'.\n"
|
||||
"Use `/resume` with no arguments to see available sessions."
|
||||
)
|
||||
|
||||
# Check if already on that session
|
||||
current_entry = self.session_store.get_or_create_session(source)
|
||||
if current_entry.session_id == target_id:
|
||||
return f"📌 Already on session **{name}**."
|
||||
|
||||
# Flush memories for current session before switching
|
||||
try:
|
||||
asyncio.create_task(self._async_flush_memories(current_entry.session_id))
|
||||
except Exception as e:
|
||||
logger.debug("Memory flush on resume failed: %s", e)
|
||||
|
||||
# Clear any running agent for this session key
|
||||
if session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
|
||||
# Switch the session entry to point at the old session
|
||||
new_entry = self.session_store.switch_session(session_key, target_id)
|
||||
if not new_entry:
|
||||
return "Failed to switch session."
|
||||
|
||||
# Get the title for confirmation
|
||||
title = self._session_db.get_session_title(target_id) or name
|
||||
|
||||
# Count messages for context
|
||||
history = self.session_store.load_transcript(target_id)
|
||||
msg_count = len([m for m in history if m.get("role") == "user"]) if history else 0
|
||||
msg_part = f" ({msg_count} message{'s' if msg_count != 1 else ''})" if msg_count else ""
|
||||
|
||||
return f"↻ Resumed session **{title}**{msg_part}. Conversation restored."
|
||||
|
||||
async def _handle_usage_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /usage command -- show token usage for the session's last agent run."""
|
||||
source = event.source
|
||||
@@ -2092,7 +2618,7 @@ class GatewayRunner:
|
||||
os.environ["HERMES_SESSION_KEY"] = session_key or ""
|
||||
|
||||
# Read from env var or use default (same as CLI)
|
||||
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "60"))
|
||||
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "90"))
|
||||
|
||||
# Map platform enum to the platform hint key the agent understands.
|
||||
# Platform.LOCAL ("local") maps to "cli"; others pass through as-is.
|
||||
@@ -2161,6 +2687,7 @@ class GatewayRunner:
|
||||
platform=platform_key,
|
||||
honcho_session_key=session_key,
|
||||
session_db=self._session_db,
|
||||
fallback_model=self._fallback_model,
|
||||
)
|
||||
|
||||
# Store agent reference for interrupt support
|
||||
@@ -2432,34 +2959,84 @@ def _start_cron_ticker(stop_event: threading.Event, adapters=None, interval: int
|
||||
logger.info("Cron ticker stopped")
|
||||
|
||||
|
||||
async def start_gateway(config: Optional[GatewayConfig] = None) -> bool:
|
||||
async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = False) -> bool:
|
||||
"""
|
||||
Start the gateway and run until interrupted.
|
||||
|
||||
This is the main entry point for running the gateway.
|
||||
Returns True if the gateway ran successfully, False if it failed to start.
|
||||
A False return causes a non-zero exit code so systemd can auto-restart.
|
||||
|
||||
Args:
|
||||
config: Optional gateway configuration override.
|
||||
replace: If True, kill any existing gateway instance before starting.
|
||||
Useful for systemd services to avoid restart-loop deadlocks
|
||||
when the previous process hasn't fully exited yet.
|
||||
"""
|
||||
# ── Duplicate-instance guard ──────────────────────────────────────
|
||||
# Prevent two gateways from running under the same HERMES_HOME.
|
||||
# The PID file is scoped to HERMES_HOME, so future multi-profile
|
||||
# setups (each profile using a distinct HERMES_HOME) will naturally
|
||||
# allow concurrent instances without tripping this guard.
|
||||
from gateway.status import get_running_pid
|
||||
import time as _time
|
||||
from gateway.status import get_running_pid, remove_pid_file
|
||||
existing_pid = get_running_pid()
|
||||
if existing_pid is not None and existing_pid != os.getpid():
|
||||
hermes_home = os.getenv("HERMES_HOME", "~/.hermes")
|
||||
logger.error(
|
||||
"Another gateway instance is already running (PID %d, HERMES_HOME=%s). "
|
||||
"Use 'hermes gateway restart' to replace it, or 'hermes gateway stop' first.",
|
||||
existing_pid, hermes_home,
|
||||
)
|
||||
print(
|
||||
f"\n❌ Gateway already running (PID {existing_pid}).\n"
|
||||
f" Use 'hermes gateway restart' to replace it,\n"
|
||||
f" or 'hermes gateway stop' to kill it first.\n"
|
||||
)
|
||||
return False
|
||||
if replace:
|
||||
logger.info(
|
||||
"Replacing existing gateway instance (PID %d) with --replace.",
|
||||
existing_pid,
|
||||
)
|
||||
try:
|
||||
os.kill(existing_pid, signal.SIGTERM)
|
||||
except ProcessLookupError:
|
||||
pass # Already gone
|
||||
except PermissionError:
|
||||
logger.error(
|
||||
"Permission denied killing PID %d. Cannot replace.",
|
||||
existing_pid,
|
||||
)
|
||||
return False
|
||||
# Wait up to 10 seconds for the old process to exit
|
||||
for _ in range(20):
|
||||
try:
|
||||
os.kill(existing_pid, 0)
|
||||
_time.sleep(0.5)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
break # Process is gone
|
||||
else:
|
||||
# Still alive after 10s — force kill
|
||||
logger.warning(
|
||||
"Old gateway (PID %d) did not exit after SIGTERM, sending SIGKILL.",
|
||||
existing_pid,
|
||||
)
|
||||
try:
|
||||
os.kill(existing_pid, signal.SIGKILL)
|
||||
_time.sleep(0.5)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
pass
|
||||
remove_pid_file()
|
||||
else:
|
||||
hermes_home = os.getenv("HERMES_HOME", "~/.hermes")
|
||||
logger.error(
|
||||
"Another gateway instance is already running (PID %d, HERMES_HOME=%s). "
|
||||
"Use 'hermes gateway restart' to replace it, or 'hermes gateway stop' first.",
|
||||
existing_pid, hermes_home,
|
||||
)
|
||||
print(
|
||||
f"\n❌ Gateway already running (PID {existing_pid}).\n"
|
||||
f" Use 'hermes gateway restart' to replace it,\n"
|
||||
f" or 'hermes gateway stop' to kill it first.\n"
|
||||
f" Or use 'hermes gateway run --replace' to auto-replace.\n"
|
||||
)
|
||||
return False
|
||||
|
||||
# Sync bundled skills on gateway start (fast -- skips unchanged)
|
||||
try:
|
||||
from tools.skills_sync import sync_skills
|
||||
sync_skills(quiet=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Configure rotating file log so gateway output is persisted for debugging
|
||||
log_dir = _hermes_home / 'logs'
|
||||
|
||||
@@ -45,6 +45,8 @@ class SessionSource:
|
||||
user_name: Optional[str] = None
|
||||
thread_id: Optional[str] = None # For forum topics, Discord threads, etc.
|
||||
chat_topic: Optional[str] = None # Channel topic/description (Discord, Slack)
|
||||
user_id_alt: Optional[str] = None # Signal UUID (alternative to phone number)
|
||||
chat_id_alt: Optional[str] = None # Signal group internal ID
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
@@ -68,7 +70,7 @@ class SessionSource:
|
||||
return ", ".join(parts)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
d = {
|
||||
"platform": self.platform.value,
|
||||
"chat_id": self.chat_id,
|
||||
"chat_name": self.chat_name,
|
||||
@@ -78,6 +80,11 @@ class SessionSource:
|
||||
"thread_id": self.thread_id,
|
||||
"chat_topic": self.chat_topic,
|
||||
}
|
||||
if self.user_id_alt:
|
||||
d["user_id_alt"] = self.user_id_alt
|
||||
if self.chat_id_alt:
|
||||
d["chat_id_alt"] = self.chat_id_alt
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SessionSource":
|
||||
@@ -90,6 +97,8 @@ class SessionSource:
|
||||
user_name=data.get("user_name"),
|
||||
thread_id=data.get("thread_id"),
|
||||
chat_topic=data.get("chat_topic"),
|
||||
user_id_alt=data.get("user_id_alt"),
|
||||
chat_id_alt=data.get("chat_id_alt"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -311,7 +320,9 @@ class SessionStore:
|
||||
self._entries: Dict[str, SessionEntry] = {}
|
||||
self._loaded = False
|
||||
self._has_active_processes_fn = has_active_processes_fn
|
||||
self._on_auto_reset = on_auto_reset # callback(old_entry) before auto-reset
|
||||
# on_auto_reset is deprecated — memory flush now runs proactively
|
||||
# via the background session expiry watcher in GatewayRunner.
|
||||
self._pre_flushed_sessions: set = set() # session_ids already flushed by watcher
|
||||
|
||||
# Initialize SQLite session database
|
||||
self._db = None
|
||||
@@ -331,7 +342,7 @@ class SessionStore:
|
||||
|
||||
if sessions_file.exists():
|
||||
try:
|
||||
with open(sessions_file, "r") as f:
|
||||
with open(sessions_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
for key, entry_data in data.items():
|
||||
self._entries[key] = SessionEntry.from_dict(entry_data)
|
||||
@@ -346,13 +357,51 @@ class SessionStore:
|
||||
sessions_file = self.sessions_dir / "sessions.json"
|
||||
|
||||
data = {key: entry.to_dict() for key, entry in self._entries.items()}
|
||||
with open(sessions_file, "w") as f:
|
||||
with open(sessions_file, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
def _generate_session_key(self, source: SessionSource) -> str:
|
||||
"""Generate a session key from a source."""
|
||||
return build_session_key(source)
|
||||
|
||||
def _is_session_expired(self, entry: SessionEntry) -> bool:
|
||||
"""Check if a session has expired based on its reset policy.
|
||||
|
||||
Works from the entry alone — no SessionSource needed.
|
||||
Used by the background expiry watcher to proactively flush memories.
|
||||
Sessions with active background processes are never considered expired.
|
||||
"""
|
||||
if self._has_active_processes_fn:
|
||||
if self._has_active_processes_fn(entry.session_key):
|
||||
return False
|
||||
|
||||
policy = self.config.get_reset_policy(
|
||||
platform=entry.platform,
|
||||
session_type=entry.chat_type,
|
||||
)
|
||||
|
||||
if policy.mode == "none":
|
||||
return False
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
if policy.mode in ("idle", "both"):
|
||||
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
|
||||
if now > idle_deadline:
|
||||
return True
|
||||
|
||||
if policy.mode in ("daily", "both"):
|
||||
today_reset = now.replace(
|
||||
hour=policy.at_hour,
|
||||
minute=0, second=0, microsecond=0,
|
||||
)
|
||||
if now.hour < policy.at_hour:
|
||||
today_reset -= timedelta(days=1)
|
||||
if entry.updated_at < today_reset:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool:
|
||||
"""
|
||||
Check if a session should be reset based on policy.
|
||||
@@ -439,13 +488,11 @@ class SessionStore:
|
||||
self._save()
|
||||
return entry
|
||||
else:
|
||||
# Session is being auto-reset — flush memories before destroying
|
||||
# Session is being auto-reset. The background expiry watcher
|
||||
# should have already flushed memories proactively; discard
|
||||
# the marker so it doesn't accumulate.
|
||||
was_auto_reset = True
|
||||
if self._on_auto_reset:
|
||||
try:
|
||||
self._on_auto_reset(entry)
|
||||
except Exception as e:
|
||||
logger.debug("Auto-reset callback failed: %s", e)
|
||||
self._pre_flushed_sessions.discard(entry.session_id)
|
||||
if self._db:
|
||||
try:
|
||||
self._db.end_session(entry.session_id, "session_reset")
|
||||
@@ -555,7 +602,49 @@ class SessionStore:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
|
||||
return new_entry
|
||||
|
||||
|
||||
def switch_session(self, session_key: str, target_session_id: str) -> Optional[SessionEntry]:
|
||||
"""Switch a session key to point at an existing session ID.
|
||||
|
||||
Used by ``/resume`` to restore a previously-named session.
|
||||
Ends the current session in SQLite (like reset), but instead of
|
||||
generating a fresh session ID, re-uses ``target_session_id`` so the
|
||||
old transcript is loaded on the next message.
|
||||
"""
|
||||
self._ensure_loaded()
|
||||
|
||||
if session_key not in self._entries:
|
||||
return None
|
||||
|
||||
old_entry = self._entries[session_key]
|
||||
|
||||
# Don't switch if already on that session
|
||||
if old_entry.session_id == target_session_id:
|
||||
return old_entry
|
||||
|
||||
# End the current session in SQLite
|
||||
if self._db:
|
||||
try:
|
||||
self._db.end_session(old_entry.session_id, "session_switch")
|
||||
except Exception as e:
|
||||
logger.debug("Session DB end_session failed: %s", e)
|
||||
|
||||
now = datetime.now()
|
||||
new_entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id=target_session_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
origin=old_entry.origin,
|
||||
display_name=old_entry.display_name,
|
||||
platform=old_entry.platform,
|
||||
chat_type=old_entry.chat_type,
|
||||
)
|
||||
|
||||
self._entries[session_key] = new_entry
|
||||
self._save()
|
||||
return new_entry
|
||||
|
||||
def list_sessions(self, active_minutes: Optional[int] = None) -> List[SessionEntry]:
|
||||
"""List all sessions, optionally filtered by activity."""
|
||||
self._ensure_loaded()
|
||||
@@ -592,7 +681,7 @@ class SessionStore:
|
||||
|
||||
# Also write legacy JSONL (keeps existing tooling working during transition)
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
with open(transcript_path, "a") as f:
|
||||
with open(transcript_path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(message, ensure_ascii=False) + "\n")
|
||||
|
||||
def rewrite_transcript(self, session_id: str, messages: List[Dict[str, Any]]) -> None:
|
||||
@@ -619,7 +708,7 @@ class SessionStore:
|
||||
|
||||
# JSONL: overwrite the file
|
||||
transcript_path = self.get_transcript_path(session_id)
|
||||
with open(transcript_path, "w") as f:
|
||||
with open(transcript_path, "w", encoding="utf-8") as f:
|
||||
for msg in messages:
|
||||
f.write(json.dumps(msg, ensure_ascii=False) + "\n")
|
||||
|
||||
@@ -641,7 +730,7 @@ class SessionStore:
|
||||
return []
|
||||
|
||||
messages = []
|
||||
with open(transcript_path, "r") as f:
|
||||
with open(transcript_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
|
||||
@@ -72,15 +72,19 @@ CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
|
||||
|
||||
@dataclass
|
||||
class ProviderConfig:
|
||||
"""Describes a known OAuth provider."""
|
||||
"""Describes a known inference provider."""
|
||||
id: str
|
||||
name: str
|
||||
auth_type: str # "oauth_device_code" or "api_key"
|
||||
auth_type: str # "oauth_device_code", "oauth_external", or "api_key"
|
||||
portal_base_url: str = ""
|
||||
inference_base_url: str = ""
|
||||
client_id: str = ""
|
||||
scope: str = ""
|
||||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
# For API-key providers: env vars to check (in priority order)
|
||||
api_key_env_vars: tuple = ()
|
||||
# Optional env var for base URL override
|
||||
base_url_env_var: str = ""
|
||||
|
||||
|
||||
PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
@@ -99,9 +103,118 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
auth_type="oauth_external",
|
||||
inference_base_url=DEFAULT_CODEX_BASE_URL,
|
||||
),
|
||||
"zai": ProviderConfig(
|
||||
id="zai",
|
||||
name="Z.AI / GLM",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://api.z.ai/api/paas/v4",
|
||||
api_key_env_vars=("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
|
||||
base_url_env_var="GLM_BASE_URL",
|
||||
),
|
||||
"kimi-coding": ProviderConfig(
|
||||
id="kimi-coding",
|
||||
name="Kimi / Moonshot",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://api.moonshot.ai/v1",
|
||||
api_key_env_vars=("KIMI_API_KEY",),
|
||||
base_url_env_var="KIMI_BASE_URL",
|
||||
),
|
||||
"minimax": ProviderConfig(
|
||||
id="minimax",
|
||||
name="MiniMax",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://api.minimax.io/v1",
|
||||
api_key_env_vars=("MINIMAX_API_KEY",),
|
||||
base_url_env_var="MINIMAX_BASE_URL",
|
||||
),
|
||||
"minimax-cn": ProviderConfig(
|
||||
id="minimax-cn",
|
||||
name="MiniMax (China)",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://api.minimaxi.com/v1",
|
||||
api_key_env_vars=("MINIMAX_CN_API_KEY",),
|
||||
base_url_env_var="MINIMAX_CN_BASE_URL",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Kimi Code Endpoint Detection
|
||||
# =============================================================================
|
||||
|
||||
# Kimi Code (platform.kimi.ai) issues keys prefixed "sk-kimi-" that only work
|
||||
# on api.kimi.com/coding/v1. Legacy keys from platform.moonshot.ai work on
|
||||
# api.moonshot.ai/v1 (the default). Auto-detect when user hasn't set
|
||||
# KIMI_BASE_URL explicitly.
|
||||
KIMI_CODE_BASE_URL = "https://api.kimi.com/coding/v1"
|
||||
|
||||
|
||||
def _resolve_kimi_base_url(api_key: str, default_url: str, env_override: str) -> str:
|
||||
"""Return the correct Kimi base URL based on the API key prefix.
|
||||
|
||||
If the user has explicitly set KIMI_BASE_URL, that always wins.
|
||||
Otherwise, sk-kimi- prefixed keys route to api.kimi.com/coding/v1.
|
||||
"""
|
||||
if env_override:
|
||||
return env_override
|
||||
if api_key.startswith("sk-kimi-"):
|
||||
return KIMI_CODE_BASE_URL
|
||||
return default_url
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Z.AI Endpoint Detection
|
||||
# =============================================================================
|
||||
|
||||
# Z.AI has separate billing for general vs coding plans, and global vs China
|
||||
# endpoints. A key that works on one may return "Insufficient balance" on
|
||||
# another. We probe at setup time and store the working endpoint.
|
||||
|
||||
ZAI_ENDPOINTS = [
|
||||
# (id, base_url, default_model, label)
|
||||
("global", "https://api.z.ai/api/paas/v4", "glm-5", "Global"),
|
||||
("cn", "https://open.bigmodel.cn/api/paas/v4", "glm-5", "China"),
|
||||
("coding-global", "https://api.z.ai/api/coding/paas/v4", "glm-4.7", "Global (Coding Plan)"),
|
||||
("coding-cn", "https://open.bigmodel.cn/api/coding/paas/v4", "glm-4.7", "China (Coding Plan)"),
|
||||
]
|
||||
|
||||
|
||||
def detect_zai_endpoint(api_key: str, timeout: float = 8.0) -> Optional[Dict[str, str]]:
|
||||
"""Probe z.ai endpoints to find one that accepts this API key.
|
||||
|
||||
Returns {"id": ..., "base_url": ..., "model": ..., "label": ...} for the
|
||||
first working endpoint, or None if all fail.
|
||||
"""
|
||||
for ep_id, base_url, model, label in ZAI_ENDPOINTS:
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{base_url}/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"model": model,
|
||||
"stream": False,
|
||||
"max_tokens": 1,
|
||||
"messages": [{"role": "user", "content": "ping"}],
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
if resp.status_code == 200:
|
||||
logger.debug("Z.AI endpoint probe: %s (%s) OK", ep_id, base_url)
|
||||
return {
|
||||
"id": ep_id,
|
||||
"base_url": base_url,
|
||||
"model": model,
|
||||
"label": label,
|
||||
}
|
||||
logger.debug("Z.AI endpoint probe: %s returned %s", ep_id, resp.status_code)
|
||||
except Exception as exc:
|
||||
logger.debug("Z.AI endpoint probe: %s failed: %s", ep_id, exc)
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Error Types
|
||||
# =============================================================================
|
||||
@@ -355,10 +468,19 @@ def resolve_provider(
|
||||
1. active_provider in auth.json with valid credentials
|
||||
2. Explicit CLI api_key/base_url -> "openrouter"
|
||||
3. OPENAI_API_KEY or OPENROUTER_API_KEY env vars -> "openrouter"
|
||||
4. Fallback: "openrouter"
|
||||
4. Provider-specific API keys (GLM, Kimi, MiniMax) -> that provider
|
||||
5. Fallback: "openrouter"
|
||||
"""
|
||||
normalized = (requested or "auto").strip().lower()
|
||||
|
||||
# Normalize provider aliases
|
||||
_PROVIDER_ALIASES = {
|
||||
"glm": "zai", "z-ai": "zai", "z.ai": "zai", "zhipu": "zai",
|
||||
"kimi": "kimi-coding", "moonshot": "kimi-coding",
|
||||
"minimax-china": "minimax-cn", "minimax_cn": "minimax-cn",
|
||||
}
|
||||
normalized = _PROVIDER_ALIASES.get(normalized, normalized)
|
||||
|
||||
if normalized in {"openrouter", "custom"}:
|
||||
return "openrouter"
|
||||
if normalized in PROVIDER_REGISTRY:
|
||||
@@ -387,6 +509,14 @@ def resolve_provider(
|
||||
if os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY"):
|
||||
return "openrouter"
|
||||
|
||||
# Auto-detect API-key providers by checking their env vars
|
||||
for pid, pconfig in PROVIDER_REGISTRY.items():
|
||||
if pconfig.auth_type != "api_key":
|
||||
continue
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
if os.getenv(env_var, "").strip():
|
||||
return pid
|
||||
|
||||
return "openrouter"
|
||||
|
||||
|
||||
@@ -1230,6 +1360,42 @@ def get_codex_auth_status() -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def get_api_key_provider_status(provider_id: str) -> Dict[str, Any]:
|
||||
"""Status snapshot for API-key providers (z.ai, Kimi, MiniMax)."""
|
||||
pconfig = PROVIDER_REGISTRY.get(provider_id)
|
||||
if not pconfig or pconfig.auth_type != "api_key":
|
||||
return {"configured": False}
|
||||
|
||||
api_key = ""
|
||||
key_source = ""
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
val = os.getenv(env_var, "").strip()
|
||||
if val:
|
||||
api_key = val
|
||||
key_source = env_var
|
||||
break
|
||||
|
||||
env_url = ""
|
||||
if pconfig.base_url_env_var:
|
||||
env_url = os.getenv(pconfig.base_url_env_var, "").strip()
|
||||
|
||||
if provider_id == "kimi-coding":
|
||||
base_url = _resolve_kimi_base_url(api_key, pconfig.inference_base_url, env_url)
|
||||
elif env_url:
|
||||
base_url = env_url
|
||||
else:
|
||||
base_url = pconfig.inference_base_url
|
||||
|
||||
return {
|
||||
"configured": bool(api_key),
|
||||
"provider": provider_id,
|
||||
"name": pconfig.name,
|
||||
"key_source": key_source,
|
||||
"base_url": base_url,
|
||||
"logged_in": bool(api_key), # compat with OAuth status shape
|
||||
}
|
||||
|
||||
|
||||
def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Generic auth status dispatcher."""
|
||||
target = provider_id or get_active_provider()
|
||||
@@ -1237,9 +1403,54 @@ def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
return get_nous_auth_status()
|
||||
if target == "openai-codex":
|
||||
return get_codex_auth_status()
|
||||
# API-key providers
|
||||
pconfig = PROVIDER_REGISTRY.get(target)
|
||||
if pconfig and pconfig.auth_type == "api_key":
|
||||
return get_api_key_provider_status(target)
|
||||
return {"logged_in": False}
|
||||
|
||||
|
||||
def resolve_api_key_provider_credentials(provider_id: str) -> Dict[str, Any]:
|
||||
"""Resolve API key and base URL for an API-key provider.
|
||||
|
||||
Returns dict with: provider, api_key, base_url, source.
|
||||
"""
|
||||
pconfig = PROVIDER_REGISTRY.get(provider_id)
|
||||
if not pconfig or pconfig.auth_type != "api_key":
|
||||
raise AuthError(
|
||||
f"Provider '{provider_id}' is not an API-key provider.",
|
||||
provider=provider_id,
|
||||
code="invalid_provider",
|
||||
)
|
||||
|
||||
api_key = ""
|
||||
key_source = ""
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
val = os.getenv(env_var, "").strip()
|
||||
if val:
|
||||
api_key = val
|
||||
key_source = env_var
|
||||
break
|
||||
|
||||
env_url = ""
|
||||
if pconfig.base_url_env_var:
|
||||
env_url = os.getenv(pconfig.base_url_env_var, "").strip()
|
||||
|
||||
if provider_id == "kimi-coding":
|
||||
base_url = _resolve_kimi_base_url(api_key, pconfig.inference_base_url, env_url)
|
||||
elif env_url:
|
||||
base_url = env_url.rstrip("/")
|
||||
else:
|
||||
base_url = pconfig.inference_base_url
|
||||
|
||||
return {
|
||||
"provider": provider_id,
|
||||
"api_key": api_key,
|
||||
"base_url": base_url.rstrip("/"),
|
||||
"source": key_source or "default",
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# External credential detection
|
||||
# =============================================================================
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
"""Welcome banner, ASCII art, and skills summary for the CLI.
|
||||
"""Welcome banner, ASCII art, skills summary, and update check for the CLI.
|
||||
|
||||
Pure display functions with no HermesCLI state dependency.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Any
|
||||
from typing import Dict, List, Any, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
@@ -13,6 +18,8 @@ from rich.table import Table
|
||||
from prompt_toolkit import print_formatted_text as _pt_print
|
||||
from prompt_toolkit.formatted_text import ANSI as _PT_ANSI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# ANSI building blocks for conversation display
|
||||
@@ -95,6 +102,72 @@ def get_available_skills() -> Dict[str, List[str]]:
|
||||
return skills_by_category
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Update check
|
||||
# =========================================================================
|
||||
|
||||
# Cache update check results for 6 hours to avoid repeated git fetches
|
||||
_UPDATE_CHECK_CACHE_SECONDS = 6 * 3600
|
||||
|
||||
|
||||
def check_for_updates() -> Optional[int]:
|
||||
"""Check how many commits behind origin/main the local repo is.
|
||||
|
||||
Does a ``git fetch`` at most once every 6 hours (cached to
|
||||
``~/.hermes/.update_check``). Returns the number of commits behind,
|
||||
or ``None`` if the check fails or isn't applicable.
|
||||
"""
|
||||
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
repo_dir = hermes_home / "hermes-agent"
|
||||
cache_file = hermes_home / ".update_check"
|
||||
|
||||
# Must be a git repo
|
||||
if not (repo_dir / ".git").exists():
|
||||
return None
|
||||
|
||||
# Read cache
|
||||
now = time.time()
|
||||
try:
|
||||
if cache_file.exists():
|
||||
cached = json.loads(cache_file.read_text())
|
||||
if now - cached.get("ts", 0) < _UPDATE_CHECK_CACHE_SECONDS:
|
||||
return cached.get("behind")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fetch latest refs (fast — only downloads ref metadata, no files)
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "fetch", "origin", "--quiet"],
|
||||
capture_output=True, timeout=10,
|
||||
cwd=str(repo_dir),
|
||||
)
|
||||
except Exception:
|
||||
pass # Offline or timeout — use stale refs, that's fine
|
||||
|
||||
# Count commits behind
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "rev-list", "--count", "HEAD..origin/main"],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
cwd=str(repo_dir),
|
||||
)
|
||||
if result.returncode == 0:
|
||||
behind = int(result.stdout.strip())
|
||||
else:
|
||||
behind = None
|
||||
except Exception:
|
||||
behind = None
|
||||
|
||||
# Write cache
|
||||
try:
|
||||
cache_file.write_text(json.dumps({"ts": now, "behind": behind}))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return behind
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Welcome banner
|
||||
# =========================================================================
|
||||
@@ -259,6 +332,18 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
|
||||
summary_parts.append("/help for commands")
|
||||
right_lines.append(f"[dim #B8860B]{' · '.join(summary_parts)}[/]")
|
||||
|
||||
# Update check — show if behind origin/main
|
||||
try:
|
||||
behind = check_for_updates()
|
||||
if behind and behind > 0:
|
||||
commits_word = "commit" if behind == 1 else "commits"
|
||||
right_lines.append(
|
||||
f"[bold yellow]⚠ {behind} {commits_word} behind[/]"
|
||||
f"[dim yellow] — run [bold]hermes update[/bold] to update[/]"
|
||||
)
|
||||
except Exception:
|
||||
pass # Never break the banner over an update check
|
||||
|
||||
right_content = "\n".join(right_lines)
|
||||
layout_table.add_row(left_content, right_content)
|
||||
|
||||
|
||||
@@ -285,8 +285,8 @@ def _convert_to_png(path: Path) -> bool:
|
||||
logger.debug("Pillow BMP→PNG conversion failed: %s", e)
|
||||
|
||||
# Fall back to ImageMagick convert
|
||||
tmp = path.with_suffix(".bmp")
|
||||
try:
|
||||
tmp = path.with_suffix(".bmp")
|
||||
path.rename(tmp)
|
||||
r = subprocess.run(
|
||||
["convert", str(tmp), "png:" + str(path)],
|
||||
@@ -297,8 +297,12 @@ def _convert_to_png(path: Path) -> bool:
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
logger.debug("ImageMagick not installed — cannot convert BMP to PNG")
|
||||
if tmp.exists() and not path.exists():
|
||||
tmp.rename(path)
|
||||
except Exception as e:
|
||||
logger.debug("ImageMagick BMP→PNG conversion failed: %s", e)
|
||||
if tmp.exists() and not path.exists():
|
||||
tmp.rename(path)
|
||||
|
||||
# Can't convert — BMP is still usable as-is for most APIs
|
||||
return path.exists() and path.stat().st_size > 0
|
||||
|
||||
@@ -94,8 +94,6 @@ def _read_cache_models(codex_home: Path) -> List[str]:
|
||||
if not isinstance(slug, str) or not slug.strip():
|
||||
continue
|
||||
slug = slug.strip()
|
||||
if "codex" not in slug.lower():
|
||||
continue
|
||||
if item.get("supported_in_api") is False:
|
||||
continue
|
||||
visibility = item.get("visibility")
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
"""Slash command definitions and autocomplete for the Hermes CLI.
|
||||
|
||||
Contains the COMMANDS dict and the SlashCommandCompleter class.
|
||||
These are pure data/UI with no HermesCLI state dependency.
|
||||
Contains the shared built-in ``COMMANDS`` dict and ``SlashCommandCompleter``.
|
||||
The completer can optionally include dynamic skill slash commands supplied by the
|
||||
interactive CLI.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import Any
|
||||
|
||||
from prompt_toolkit.completion import Completer, Completion
|
||||
|
||||
|
||||
@@ -12,6 +18,7 @@ COMMANDS = {
|
||||
"/tools": "List available tools",
|
||||
"/toolsets": "List available toolsets",
|
||||
"/model": "Show or change the current model",
|
||||
"/provider": "Show available providers and current provider",
|
||||
"/prompt": "View/set custom system prompt",
|
||||
"/personality": "Set a predefined personality",
|
||||
"/clear": "Clear screen and reset conversation (fresh start)",
|
||||
@@ -27,26 +34,68 @@ COMMANDS = {
|
||||
"/platforms": "Show gateway/messaging platform status",
|
||||
"/verbose": "Cycle tool progress display: off → new → all → verbose",
|
||||
"/compress": "Manually compress conversation context (flush memories + summarize)",
|
||||
"/title": "Set a title for the current session (usage: /title My Session Name)",
|
||||
"/usage": "Show token usage for the current session",
|
||||
"/insights": "Show usage insights and analytics (last 30 days)",
|
||||
"/paste": "Check clipboard for an image and attach it",
|
||||
"/reload-mcp": "Reload MCP servers from config.yaml",
|
||||
"/quit": "Exit the CLI (also: /exit, /q)",
|
||||
}
|
||||
|
||||
|
||||
class SlashCommandCompleter(Completer):
|
||||
"""Autocomplete for /commands in the input area."""
|
||||
"""Autocomplete for built-in slash commands and optional skill commands."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
skill_commands_provider: Callable[[], Mapping[str, dict[str, Any]]] | None = None,
|
||||
) -> None:
|
||||
self._skill_commands_provider = skill_commands_provider
|
||||
|
||||
def _iter_skill_commands(self) -> Mapping[str, dict[str, Any]]:
|
||||
if self._skill_commands_provider is None:
|
||||
return {}
|
||||
try:
|
||||
return self._skill_commands_provider() or {}
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _completion_text(cmd_name: str, word: str) -> str:
|
||||
"""Return replacement text for a completion.
|
||||
|
||||
When the user has already typed the full command exactly (``/help``),
|
||||
returning ``help`` would be a no-op and prompt_toolkit suppresses the
|
||||
menu. Appending a trailing space keeps the dropdown visible and makes
|
||||
backspacing retrigger it naturally.
|
||||
"""
|
||||
return f"{cmd_name} " if cmd_name == word else cmd_name
|
||||
|
||||
def get_completions(self, document, complete_event):
|
||||
text = document.text_before_cursor
|
||||
if not text.startswith("/"):
|
||||
return
|
||||
|
||||
word = text[1:]
|
||||
|
||||
for cmd, desc in COMMANDS.items():
|
||||
cmd_name = cmd[1:]
|
||||
if cmd_name.startswith(word):
|
||||
yield Completion(
|
||||
cmd_name,
|
||||
self._completion_text(cmd_name, word),
|
||||
start_position=-len(word),
|
||||
display=cmd,
|
||||
display_meta=desc,
|
||||
)
|
||||
|
||||
for cmd, info in self._iter_skill_commands().items():
|
||||
cmd_name = cmd[1:]
|
||||
if cmd_name.startswith(word):
|
||||
description = str(info.get("description", "Skill command"))
|
||||
short_desc = description[:50] + ("..." if len(description) > 50 else "")
|
||||
yield Completion(
|
||||
self._completion_text(cmd_name, word),
|
||||
start_position=-len(word),
|
||||
display=cmd,
|
||||
display_meta=f"⚡ {short_desc}",
|
||||
)
|
||||
|
||||
@@ -81,17 +81,34 @@ DEFAULT_CONFIG = {
|
||||
|
||||
"browser": {
|
||||
"inactivity_timeout": 120,
|
||||
"record_sessions": False, # Auto-record browser sessions as WebM videos
|
||||
},
|
||||
|
||||
"compression": {
|
||||
"enabled": True,
|
||||
"threshold": 0.85,
|
||||
"summary_model": "google/gemini-3-flash-preview",
|
||||
"summary_provider": "auto",
|
||||
},
|
||||
|
||||
# Auxiliary model overrides (advanced). By default Hermes auto-selects
|
||||
# the provider and model for each side task. Set these to override.
|
||||
"auxiliary": {
|
||||
"vision": {
|
||||
"provider": "auto", # auto | openrouter | nous | main
|
||||
"model": "", # e.g. "google/gemini-2.5-flash", "gpt-4o"
|
||||
},
|
||||
"web_extract": {
|
||||
"provider": "auto",
|
||||
"model": "",
|
||||
},
|
||||
},
|
||||
|
||||
"display": {
|
||||
"compact": False,
|
||||
"personality": "kawaii",
|
||||
"resume_display": "full", # "full" (show previous messages) | "minimal" (one-liner only)
|
||||
"bell_on_complete": False, # Play terminal bell (\a) when agent finishes a response
|
||||
},
|
||||
|
||||
# Text-to-speech configuration
|
||||
@@ -141,9 +158,13 @@ DEFAULT_CONFIG = {
|
||||
# (apiKey, workspace, peerName, sessions, enabled) comes from the global config.
|
||||
"honcho": {},
|
||||
|
||||
# IANA timezone (e.g. "Asia/Kolkata", "America/New_York").
|
||||
# Empty string means use server-local time.
|
||||
"timezone": "",
|
||||
|
||||
# Permanently allowed dangerous command patterns (added via "always" approval)
|
||||
"command_allowlist": [],
|
||||
|
||||
|
||||
# Config schema version - bump this when adding new required fields
|
||||
"_config_version": 5,
|
||||
}
|
||||
@@ -152,6 +173,15 @@ DEFAULT_CONFIG = {
|
||||
# Config Migration System
|
||||
# =============================================================================
|
||||
|
||||
# Track which env vars were introduced in each config version.
|
||||
# Migration only mentions vars new since the user's previous version.
|
||||
ENV_VARS_BY_VERSION: Dict[int, List[str]] = {
|
||||
3: ["FIRECRAWL_API_KEY", "BROWSERBASE_API_KEY", "BROWSERBASE_PROJECT_ID", "FAL_KEY"],
|
||||
4: ["VOICE_TOOLS_OPENAI_KEY", "ELEVENLABS_API_KEY"],
|
||||
5: ["WHATSAPP_ENABLED", "WHATSAPP_MODE", "WHATSAPP_ALLOWED_USERS",
|
||||
"SLACK_BOT_TOKEN", "SLACK_APP_TOKEN", "SLACK_ALLOWED_USERS"],
|
||||
}
|
||||
|
||||
# Required environment variables with metadata for migration prompts.
|
||||
# LLM provider is required but handled in the setup wizard's provider
|
||||
# selection step (Nous Portal / OpenRouter / Custom endpoint), so this
|
||||
@@ -170,6 +200,86 @@ OPTIONAL_ENV_VARS = {
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"GLM_API_KEY": {
|
||||
"description": "Z.AI / GLM API key (also recognized as ZAI_API_KEY / Z_AI_API_KEY)",
|
||||
"prompt": "Z.AI / GLM API key",
|
||||
"url": "https://z.ai/",
|
||||
"password": True,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"ZAI_API_KEY": {
|
||||
"description": "Z.AI API key (alias for GLM_API_KEY)",
|
||||
"prompt": "Z.AI API key",
|
||||
"url": "https://z.ai/",
|
||||
"password": True,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"Z_AI_API_KEY": {
|
||||
"description": "Z.AI API key (alias for GLM_API_KEY)",
|
||||
"prompt": "Z.AI API key",
|
||||
"url": "https://z.ai/",
|
||||
"password": True,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"GLM_BASE_URL": {
|
||||
"description": "Z.AI / GLM base URL override",
|
||||
"prompt": "Z.AI / GLM base URL (leave empty for default)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"KIMI_API_KEY": {
|
||||
"description": "Kimi / Moonshot API key",
|
||||
"prompt": "Kimi API key",
|
||||
"url": "https://platform.moonshot.cn/",
|
||||
"password": True,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"KIMI_BASE_URL": {
|
||||
"description": "Kimi / Moonshot base URL override",
|
||||
"prompt": "Kimi base URL (leave empty for default)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"MINIMAX_API_KEY": {
|
||||
"description": "MiniMax API key (international)",
|
||||
"prompt": "MiniMax API key",
|
||||
"url": "https://www.minimax.io/",
|
||||
"password": True,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"MINIMAX_BASE_URL": {
|
||||
"description": "MiniMax base URL override",
|
||||
"prompt": "MiniMax base URL (leave empty for default)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"MINIMAX_CN_API_KEY": {
|
||||
"description": "MiniMax API key (China endpoint)",
|
||||
"prompt": "MiniMax (China) API key",
|
||||
"url": "https://www.minimaxi.com/",
|
||||
"password": True,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"MINIMAX_CN_BASE_URL": {
|
||||
"description": "MiniMax (China) base URL override",
|
||||
"prompt": "MiniMax (China) base URL (leave empty for default)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
|
||||
# ── Tool API keys ──
|
||||
"FIRECRAWL_API_KEY": {
|
||||
@@ -189,7 +299,7 @@ OPTIONAL_ENV_VARS = {
|
||||
"advanced": True,
|
||||
},
|
||||
"BROWSERBASE_API_KEY": {
|
||||
"description": "Browserbase API key for browser automation",
|
||||
"description": "Browserbase API key for cloud browser (optional — local browser works without this)",
|
||||
"prompt": "Browserbase API key",
|
||||
"url": "https://browserbase.com/",
|
||||
"tools": ["browser_navigate", "browser_click"],
|
||||
@@ -197,7 +307,7 @@ OPTIONAL_ENV_VARS = {
|
||||
"category": "tool",
|
||||
},
|
||||
"BROWSERBASE_PROJECT_ID": {
|
||||
"description": "Browserbase project ID",
|
||||
"description": "Browserbase project ID (optional — only needed for cloud browser)",
|
||||
"prompt": "Browserbase project ID",
|
||||
"url": "https://browserbase.com/",
|
||||
"tools": ["browser_navigate", "browser_click"],
|
||||
@@ -329,7 +439,7 @@ OPTIONAL_ENV_VARS = {
|
||||
"category": "setting",
|
||||
},
|
||||
"HERMES_MAX_ITERATIONS": {
|
||||
"description": "Maximum tool-calling iterations per conversation (default: 60)",
|
||||
"description": "Maximum tool-calling iterations per conversation (default: 90)",
|
||||
"prompt": "Max iterations",
|
||||
"url": None,
|
||||
"password": False,
|
||||
@@ -485,6 +595,22 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A
|
||||
if not quiet:
|
||||
print(f" ✓ Migrated tool progress to config.yaml: {display['tool_progress']}")
|
||||
|
||||
# ── Version 4 → 5: add timezone field ──
|
||||
if current_ver < 5:
|
||||
config = load_config()
|
||||
if "timezone" not in config:
|
||||
old_tz = os.getenv("HERMES_TIMEZONE", "")
|
||||
if old_tz and old_tz.strip():
|
||||
config["timezone"] = old_tz.strip()
|
||||
results["config_added"].append(f"timezone={old_tz.strip()} (from HERMES_TIMEZONE)")
|
||||
else:
|
||||
config["timezone"] = ""
|
||||
results["config_added"].append("timezone= (empty, uses server-local)")
|
||||
save_config(config)
|
||||
if not quiet:
|
||||
tz_display = config["timezone"] or "(server-local)"
|
||||
print(f" ✓ Added timezone to config.yaml: {tz_display}")
|
||||
|
||||
if current_ver < latest_ver and not quiet:
|
||||
print(f"Config version: {current_ver} → {latest_ver}")
|
||||
|
||||
@@ -525,34 +651,47 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A
|
||||
if v["name"] not in required_names and not v.get("advanced")
|
||||
]
|
||||
|
||||
if interactive and missing_optional:
|
||||
print(" Would you like to configure any optional keys now?")
|
||||
try:
|
||||
answer = input(" Configure optional keys? [y/N]: ").strip().lower()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
answer = "n"
|
||||
|
||||
if answer in ("y", "yes"):
|
||||
# Only offer to configure env vars that are NEW since the user's previous version
|
||||
new_var_names = set()
|
||||
for ver in range(current_ver + 1, latest_ver + 1):
|
||||
new_var_names.update(ENV_VARS_BY_VERSION.get(ver, []))
|
||||
|
||||
if new_var_names and interactive and not quiet:
|
||||
new_and_unset = [
|
||||
(name, OPTIONAL_ENV_VARS[name])
|
||||
for name in sorted(new_var_names)
|
||||
if not get_env_value(name) and name in OPTIONAL_ENV_VARS
|
||||
]
|
||||
if new_and_unset:
|
||||
print(f"\n {len(new_and_unset)} new optional key(s) in this update:")
|
||||
for name, info in new_and_unset:
|
||||
print(f" • {name} — {info.get('description', '')}")
|
||||
print()
|
||||
for var in missing_optional:
|
||||
desc = var.get("description", "")
|
||||
if var.get("url"):
|
||||
print(f" {desc}")
|
||||
print(f" Get your key at: {var['url']}")
|
||||
else:
|
||||
print(f" {desc}")
|
||||
|
||||
if var.get("password"):
|
||||
import getpass
|
||||
value = getpass.getpass(f" {var['prompt']} (Enter to skip): ")
|
||||
else:
|
||||
value = input(f" {var['prompt']} (Enter to skip): ").strip()
|
||||
|
||||
if value:
|
||||
save_env_value(var["name"], value)
|
||||
results["env_added"].append(var["name"])
|
||||
print(f" ✓ Saved {var['name']}")
|
||||
try:
|
||||
answer = input(" Configure new keys? [y/N]: ").strip().lower()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
answer = "n"
|
||||
|
||||
if answer in ("y", "yes"):
|
||||
print()
|
||||
for name, info in new_and_unset:
|
||||
if info.get("url"):
|
||||
print(f" {info.get('description', name)}")
|
||||
print(f" Get your key at: {info['url']}")
|
||||
else:
|
||||
print(f" {info.get('description', name)}")
|
||||
if info.get("password"):
|
||||
import getpass
|
||||
value = getpass.getpass(f" {info.get('prompt', name)} (Enter to skip): ")
|
||||
else:
|
||||
value = input(f" {info.get('prompt', name)} (Enter to skip): ").strip()
|
||||
if value:
|
||||
save_env_value(name, value)
|
||||
results["env_added"].append(name)
|
||||
print(f" ✓ Saved {name}")
|
||||
print()
|
||||
else:
|
||||
print(" Set later with: hermes config set KEY VALUE")
|
||||
|
||||
# Check for missing config fields
|
||||
missing_config = get_missing_config_fields()
|
||||
@@ -620,6 +759,36 @@ def load_config() -> Dict[str, Any]:
|
||||
return config
|
||||
|
||||
|
||||
_COMMENTED_SECTIONS = """
|
||||
# ── Security ──────────────────────────────────────────────────────────
|
||||
# API keys, tokens, and passwords are redacted from tool output by default.
|
||||
# Set to false to see full values (useful for debugging auth issues).
|
||||
#
|
||||
# security:
|
||||
# redact_secrets: false
|
||||
|
||||
# ── Fallback Model ────────────────────────────────────────────────────
|
||||
# Automatic provider failover when primary is unavailable.
|
||||
# Uncomment and configure to enable. Triggers on rate limits (429),
|
||||
# overload (529), service errors (503), or connection failures.
|
||||
#
|
||||
# Supported providers:
|
||||
# openrouter (OPENROUTER_API_KEY) — routes to any model
|
||||
# openai-codex (OAuth — hermes login) — OpenAI Codex
|
||||
# nous (OAuth — hermes login) — Nous Portal
|
||||
# zai (ZAI_API_KEY) — Z.AI / GLM
|
||||
# kimi-coding (KIMI_API_KEY) — Kimi / Moonshot
|
||||
# minimax (MINIMAX_API_KEY) — MiniMax
|
||||
# minimax-cn (MINIMAX_CN_API_KEY) — MiniMax (China)
|
||||
#
|
||||
# For custom OpenAI-compatible endpoints, add base_url and api_key_env.
|
||||
#
|
||||
# fallback_model:
|
||||
# provider: openrouter
|
||||
# model: anthropic/claude-sonnet-4
|
||||
"""
|
||||
|
||||
|
||||
def save_config(config: Dict[str, Any]):
|
||||
"""Save configuration to ~/.hermes/config.yaml."""
|
||||
ensure_hermes_home()
|
||||
@@ -627,6 +796,18 @@ def save_config(config: Dict[str, Any]):
|
||||
|
||||
with open(config_path, 'w') as f:
|
||||
yaml.dump(config, f, default_flow_style=False, sort_keys=False)
|
||||
# Append commented-out sections for features that are off by default
|
||||
# or only relevant when explicitly configured. Skip sections the
|
||||
# user has already uncommented and configured.
|
||||
sections = []
|
||||
sec = config.get("security", {})
|
||||
if not sec or sec.get("redact_secrets") is None:
|
||||
sections.append("security")
|
||||
fb = config.get("fallback_model", {})
|
||||
if not fb or not (fb.get("provider") and fb.get("model")):
|
||||
sections.append("fallback")
|
||||
if sections:
|
||||
f.write(_COMMENTED_SECTIONS)
|
||||
|
||||
|
||||
def load_env() -> Dict[str, str]:
|
||||
@@ -772,6 +953,15 @@ def show_config():
|
||||
print(f" SSH host: {ssh_host or '(not set)'}")
|
||||
print(f" SSH user: {ssh_user or '(not set)'}")
|
||||
|
||||
# Timezone
|
||||
print()
|
||||
print(color("◆ Timezone", Colors.CYAN, Colors.BOLD))
|
||||
tz = config.get('timezone', '')
|
||||
if tz:
|
||||
print(f" Timezone: {tz}")
|
||||
else:
|
||||
print(f" Timezone: {color('(server-local)', Colors.DIM)}")
|
||||
|
||||
# Compression
|
||||
print()
|
||||
print(color("◆ Context Compression", Colors.CYAN, Colors.BOLD))
|
||||
@@ -781,6 +971,31 @@ def show_config():
|
||||
if enabled:
|
||||
print(f" Threshold: {compression.get('threshold', 0.85) * 100:.0f}%")
|
||||
print(f" Model: {compression.get('summary_model', 'google/gemini-3-flash-preview')}")
|
||||
comp_provider = compression.get('summary_provider', 'auto')
|
||||
if comp_provider != 'auto':
|
||||
print(f" Provider: {comp_provider}")
|
||||
|
||||
# Auxiliary models
|
||||
auxiliary = config.get('auxiliary', {})
|
||||
aux_tasks = {
|
||||
"Vision": auxiliary.get('vision', {}),
|
||||
"Web extract": auxiliary.get('web_extract', {}),
|
||||
}
|
||||
has_overrides = any(
|
||||
t.get('provider', 'auto') != 'auto' or t.get('model', '')
|
||||
for t in aux_tasks.values()
|
||||
)
|
||||
if has_overrides:
|
||||
print()
|
||||
print(color("◆ Auxiliary Models (overrides)", Colors.CYAN, Colors.BOLD))
|
||||
for label, task_cfg in aux_tasks.items():
|
||||
prov = task_cfg.get('provider', 'auto')
|
||||
mdl = task_cfg.get('model', '')
|
||||
if prov != 'auto' or mdl:
|
||||
parts = [f"provider={prov}"]
|
||||
if mdl:
|
||||
parts.append(f"model={mdl}")
|
||||
print(f" {label:12s} {', '.join(parts)}")
|
||||
|
||||
# Messaging
|
||||
print()
|
||||
@@ -838,7 +1053,7 @@ def set_config_value(key: str, value: str):
|
||||
'FAL_KEY', 'TELEGRAM_BOT_TOKEN', 'DISCORD_BOT_TOKEN',
|
||||
'TERMINAL_SSH_HOST', 'TERMINAL_SSH_USER', 'TERMINAL_SSH_KEY',
|
||||
'SUDO_PASSWORD', 'SLACK_BOT_TOKEN', 'SLACK_APP_TOKEN',
|
||||
'GITHUB_TOKEN', 'HONCHO_API_KEY', 'NOUS_API_KEY', 'WANDB_API_KEY',
|
||||
'GITHUB_TOKEN', 'HONCHO_API_KEY', 'WANDB_API_KEY',
|
||||
'TINKER_API_KEY',
|
||||
]
|
||||
|
||||
@@ -895,6 +1110,7 @@ def set_config_value(key: str, value: str):
|
||||
"terminal.daytona_image": "TERMINAL_DAYTONA_IMAGE",
|
||||
"terminal.cwd": "TERMINAL_CWD",
|
||||
"terminal.timeout": "TERMINAL_TIMEOUT",
|
||||
"terminal.sandbox_dir": "TERMINAL_SANDBOX_DIR",
|
||||
}
|
||||
if key in _config_to_env_sync:
|
||||
save_env_value(_config_to_env_sync[key], str(value))
|
||||
|
||||
@@ -33,6 +33,26 @@ os.environ.setdefault("MSWEA_SILENT_STARTUP", "1")
|
||||
from hermes_cli.colors import Colors, color
|
||||
from hermes_constants import OPENROUTER_MODELS_URL
|
||||
|
||||
|
||||
_PROVIDER_ENV_HINTS = (
|
||||
"OPENROUTER_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"ANTHROPIC_API_KEY",
|
||||
"OPENAI_BASE_URL",
|
||||
"GLM_API_KEY",
|
||||
"ZAI_API_KEY",
|
||||
"Z_AI_API_KEY",
|
||||
"KIMI_API_KEY",
|
||||
"MINIMAX_API_KEY",
|
||||
"MINIMAX_CN_API_KEY",
|
||||
)
|
||||
|
||||
|
||||
def _has_provider_env_config(content: str) -> bool:
|
||||
"""Return True when ~/.hermes/.env contains provider auth/base URL settings."""
|
||||
return any(key in content for key in _PROVIDER_ENV_HINTS)
|
||||
|
||||
|
||||
def check_ok(text: str, detail: str = ""):
|
||||
print(f" {color('✓', Colors.GREEN)} {text}" + (f" {color(detail, Colors.DIM)}" if detail else ""))
|
||||
|
||||
@@ -132,8 +152,8 @@ def run_doctor(args):
|
||||
|
||||
# Check for common issues
|
||||
content = env_path.read_text()
|
||||
if "OPENROUTER_API_KEY" in content or "ANTHROPIC_API_KEY" in content:
|
||||
check_ok("API key configured")
|
||||
if _has_provider_env_config(content):
|
||||
check_ok("API key or custom endpoint configured")
|
||||
else:
|
||||
check_warn("No API key found in ~/.hermes/.env")
|
||||
issues.append("Run 'hermes setup' to configure API keys")
|
||||
@@ -468,7 +488,48 @@ def run_doctor(args):
|
||||
print(f"\r {color('⚠', Colors.YELLOW)} Anthropic API {color(msg, Colors.DIM)} ")
|
||||
except Exception as e:
|
||||
print(f"\r {color('⚠', Colors.YELLOW)} Anthropic API {color(f'({e})', Colors.DIM)} ")
|
||||
|
||||
|
||||
# -- API-key providers (Z.AI/GLM, Kimi, MiniMax, MiniMax-CN) --
|
||||
_apikey_providers = [
|
||||
("Z.AI / GLM", ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"), "https://api.z.ai/api/paas/v4/models", "GLM_BASE_URL"),
|
||||
("Kimi / Moonshot", ("KIMI_API_KEY",), "https://api.moonshot.ai/v1/models", "KIMI_BASE_URL"),
|
||||
("MiniMax", ("MINIMAX_API_KEY",), "https://api.minimax.io/v1/models", "MINIMAX_BASE_URL"),
|
||||
("MiniMax (China)", ("MINIMAX_CN_API_KEY",), "https://api.minimaxi.com/v1/models", "MINIMAX_CN_BASE_URL"),
|
||||
]
|
||||
for _pname, _env_vars, _default_url, _base_env in _apikey_providers:
|
||||
_key = ""
|
||||
for _ev in _env_vars:
|
||||
_key = os.getenv(_ev, "")
|
||||
if _key:
|
||||
break
|
||||
if _key:
|
||||
_label = _pname.ljust(20)
|
||||
print(f" Checking {_pname} API...", end="", flush=True)
|
||||
try:
|
||||
import httpx
|
||||
_base = os.getenv(_base_env, "")
|
||||
# Auto-detect Kimi Code keys (sk-kimi-) → api.kimi.com
|
||||
if not _base and _key.startswith("sk-kimi-"):
|
||||
_base = "https://api.kimi.com/coding/v1"
|
||||
_url = (_base.rstrip("/") + "/models") if _base else _default_url
|
||||
_headers = {"Authorization": f"Bearer {_key}"}
|
||||
if "api.kimi.com" in _url.lower():
|
||||
_headers["User-Agent"] = "KimiCLI/1.0"
|
||||
_resp = httpx.get(
|
||||
_url,
|
||||
headers=_headers,
|
||||
timeout=10,
|
||||
)
|
||||
if _resp.status_code == 200:
|
||||
print(f"\r {color('✓', Colors.GREEN)} {_label} ")
|
||||
elif _resp.status_code == 401:
|
||||
print(f"\r {color('✗', Colors.RED)} {_label} {color('(invalid API key)', Colors.DIM)} ")
|
||||
issues.append(f"Check {_env_vars[0]} in .env")
|
||||
else:
|
||||
print(f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'(HTTP {_resp.status_code})', Colors.DIM)} ")
|
||||
except Exception as _e:
|
||||
print(f"\r {color('⚠', Colors.YELLOW)} {_label} {color(f'({_e})', Colors.DIM)} ")
|
||||
|
||||
# =========================================================================
|
||||
# Check: Submodules
|
||||
# =========================================================================
|
||||
|
||||
@@ -154,19 +154,33 @@ def get_hermes_cli_path() -> str:
|
||||
# =============================================================================
|
||||
|
||||
def generate_systemd_unit() -> str:
|
||||
import shutil
|
||||
python_path = get_python_path()
|
||||
working_dir = str(PROJECT_ROOT)
|
||||
venv_dir = str(PROJECT_ROOT / "venv")
|
||||
venv_bin = str(PROJECT_ROOT / "venv" / "bin")
|
||||
node_bin = str(PROJECT_ROOT / "node_modules" / ".bin")
|
||||
|
||||
# Build a PATH that includes the venv, node_modules, and standard system dirs
|
||||
sane_path = f"{venv_bin}:{node_bin}:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
|
||||
hermes_cli = shutil.which("hermes") or f"{python_path} -m hermes_cli.main"
|
||||
return f"""[Unit]
|
||||
Description={SERVICE_DESCRIPTION}
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
ExecStart={python_path} -m hermes_cli.main gateway run
|
||||
ExecStart={python_path} -m hermes_cli.main gateway run --replace
|
||||
ExecStop={hermes_cli} gateway stop
|
||||
WorkingDirectory={working_dir}
|
||||
Environment="PATH={sane_path}"
|
||||
Environment="VIRTUAL_ENV={venv_dir}"
|
||||
Restart=on-failure
|
||||
RestartSec=10
|
||||
KillMode=mixed
|
||||
KillSignal=SIGTERM
|
||||
TimeoutStopSec=15
|
||||
StandardOutput=journal
|
||||
StandardError=journal
|
||||
|
||||
@@ -377,8 +391,15 @@ def launchd_status(deep: bool = False):
|
||||
# Gateway Runner
|
||||
# =============================================================================
|
||||
|
||||
def run_gateway(verbose: bool = False):
|
||||
"""Run the gateway in foreground."""
|
||||
def run_gateway(verbose: bool = False, replace: bool = False):
|
||||
"""Run the gateway in foreground.
|
||||
|
||||
Args:
|
||||
verbose: Enable verbose logging output.
|
||||
replace: If True, kill any existing gateway instance before starting.
|
||||
This prevents systemd restart loops when the old process
|
||||
hasn't fully exited yet.
|
||||
"""
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from gateway.run import start_gateway
|
||||
@@ -393,7 +414,7 @@ def run_gateway(verbose: bool = False):
|
||||
|
||||
# Exit with code 1 if gateway fails to connect any platform,
|
||||
# so systemd Restart=on-failure will retry on transient errors
|
||||
success = asyncio.run(start_gateway())
|
||||
success = asyncio.run(start_gateway(replace=replace))
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
|
||||
@@ -486,6 +507,12 @@ _PLATFORMS = [
|
||||
"emoji": "📲",
|
||||
"token_var": "WHATSAPP_ENABLED",
|
||||
},
|
||||
{
|
||||
"key": "signal",
|
||||
"label": "Signal",
|
||||
"emoji": "📡",
|
||||
"token_var": "SIGNAL_HTTP_URL",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@@ -504,6 +531,13 @@ def _platform_status(platform: dict) -> str:
|
||||
return "configured + paired"
|
||||
return "enabled, not paired"
|
||||
return "not configured"
|
||||
if platform.get("key") == "signal":
|
||||
account = get_env_value("SIGNAL_ACCOUNT")
|
||||
if val and account:
|
||||
return "configured"
|
||||
if val or account:
|
||||
return "partially configured"
|
||||
return "not configured"
|
||||
if val:
|
||||
return "configured"
|
||||
return "not configured"
|
||||
@@ -629,6 +663,121 @@ def _is_service_running() -> bool:
|
||||
return len(find_gateway_pids()) > 0
|
||||
|
||||
|
||||
def _setup_signal():
|
||||
"""Interactive setup for Signal messenger."""
|
||||
import shutil
|
||||
|
||||
print()
|
||||
print(color(" ─── 📡 Signal Setup ───", Colors.CYAN))
|
||||
|
||||
existing_url = get_env_value("SIGNAL_HTTP_URL")
|
||||
existing_account = get_env_value("SIGNAL_ACCOUNT")
|
||||
if existing_url and existing_account:
|
||||
print()
|
||||
print_success("Signal is already configured.")
|
||||
if not prompt_yes_no(" Reconfigure Signal?", False):
|
||||
return
|
||||
|
||||
# Check if signal-cli is available
|
||||
print()
|
||||
if shutil.which("signal-cli"):
|
||||
print_success("signal-cli found on PATH.")
|
||||
else:
|
||||
print_warning("signal-cli not found on PATH.")
|
||||
print_info(" Signal requires signal-cli running as an HTTP daemon.")
|
||||
print_info(" Install options:")
|
||||
print_info(" Linux: sudo apt install signal-cli")
|
||||
print_info(" or download from https://github.com/AsamK/signal-cli")
|
||||
print_info(" macOS: brew install signal-cli")
|
||||
print_info(" Docker: bbernhard/signal-cli-rest-api")
|
||||
print()
|
||||
print_info(" After installing, link your account and start the daemon:")
|
||||
print_info(" signal-cli link -n \"HermesAgent\"")
|
||||
print_info(" signal-cli --account +YOURNUMBER daemon --http 127.0.0.1:8080")
|
||||
print()
|
||||
|
||||
# HTTP URL
|
||||
print()
|
||||
print_info(" Enter the URL where signal-cli HTTP daemon is running.")
|
||||
default_url = existing_url or "http://127.0.0.1:8080"
|
||||
try:
|
||||
url = input(f" HTTP URL [{default_url}]: ").strip() or default_url
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\n Setup cancelled.")
|
||||
return
|
||||
|
||||
# Test connectivity
|
||||
print_info(" Testing connection...")
|
||||
try:
|
||||
import httpx
|
||||
resp = httpx.get(f"{url.rstrip('/')}/api/v1/check", timeout=10.0)
|
||||
if resp.status_code == 200:
|
||||
print_success(" signal-cli daemon is reachable!")
|
||||
else:
|
||||
print_warning(f" signal-cli responded with status {resp.status_code}.")
|
||||
if not prompt_yes_no(" Continue anyway?", False):
|
||||
return
|
||||
except Exception as e:
|
||||
print_warning(f" Could not reach signal-cli at {url}: {e}")
|
||||
if not prompt_yes_no(" Save this URL anyway? (you can start signal-cli later)", True):
|
||||
return
|
||||
|
||||
save_env_value("SIGNAL_HTTP_URL", url)
|
||||
|
||||
# Account phone number
|
||||
print()
|
||||
print_info(" Enter your Signal account phone number in E.164 format.")
|
||||
print_info(" Example: +15551234567")
|
||||
default_account = existing_account or ""
|
||||
try:
|
||||
account = input(f" Account number{f' [{default_account}]' if default_account else ''}: ").strip()
|
||||
if not account:
|
||||
account = default_account
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\n Setup cancelled.")
|
||||
return
|
||||
|
||||
if not account:
|
||||
print_error(" Account number is required.")
|
||||
return
|
||||
|
||||
save_env_value("SIGNAL_ACCOUNT", account)
|
||||
|
||||
# Allowed users
|
||||
print()
|
||||
print_info(" The gateway DENIES all users by default for security.")
|
||||
print_info(" Enter phone numbers or UUIDs of allowed users (comma-separated).")
|
||||
existing_allowed = get_env_value("SIGNAL_ALLOWED_USERS") or ""
|
||||
default_allowed = existing_allowed or account
|
||||
try:
|
||||
allowed = input(f" Allowed users [{default_allowed}]: ").strip() or default_allowed
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\n Setup cancelled.")
|
||||
return
|
||||
|
||||
save_env_value("SIGNAL_ALLOWED_USERS", allowed)
|
||||
|
||||
# Group messaging
|
||||
print()
|
||||
if prompt_yes_no(" Enable group messaging? (disabled by default for security)", False):
|
||||
print()
|
||||
print_info(" Enter group IDs to allow, or * for all groups.")
|
||||
existing_groups = get_env_value("SIGNAL_GROUP_ALLOWED_USERS") or ""
|
||||
try:
|
||||
groups = input(f" Group IDs [{existing_groups or '*'}]: ").strip() or existing_groups or "*"
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\n Setup cancelled.")
|
||||
return
|
||||
save_env_value("SIGNAL_GROUP_ALLOWED_USERS", groups)
|
||||
|
||||
print()
|
||||
print_success("Signal configured!")
|
||||
print_info(f" URL: {url}")
|
||||
print_info(f" Account: {account}")
|
||||
print_info(f" DM auth: via SIGNAL_ALLOWED_USERS + DM pairing")
|
||||
print_info(f" Groups: {'enabled' if get_env_value('SIGNAL_GROUP_ALLOWED_USERS') else 'disabled'}")
|
||||
|
||||
|
||||
def gateway_setup():
|
||||
"""Interactive setup for messaging platforms + gateway service."""
|
||||
|
||||
@@ -681,6 +830,8 @@ def gateway_setup():
|
||||
|
||||
if platform["key"] == "whatsapp":
|
||||
_setup_whatsapp()
|
||||
elif platform["key"] == "signal":
|
||||
_setup_signal()
|
||||
else:
|
||||
_setup_standard_platform(platform)
|
||||
|
||||
@@ -765,7 +916,8 @@ def gateway_command(args):
|
||||
# Default to run if no subcommand
|
||||
if subcmd is None or subcmd == "run":
|
||||
verbose = getattr(args, 'verbose', False)
|
||||
run_gateway(verbose)
|
||||
replace = getattr(args, 'replace', False)
|
||||
run_gateway(verbose, replace=replace)
|
||||
return
|
||||
|
||||
if subcmd == "setup":
|
||||
|
||||
@@ -21,6 +21,7 @@ Usage:
|
||||
hermes version # Show version
|
||||
hermes update # Update to latest version
|
||||
hermes uninstall # Uninstall Hermes Agent
|
||||
hermes sessions browse # Interactive session picker with search
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -64,7 +65,13 @@ def _has_any_provider_configured() -> bool:
|
||||
# Check env vars (may be set by .env or shell).
|
||||
# OPENAI_BASE_URL alone counts — local models (vLLM, llama.cpp, etc.)
|
||||
# often don't require an API key.
|
||||
provider_env_vars = ("OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "OPENAI_BASE_URL")
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY
|
||||
|
||||
# Collect all provider env vars
|
||||
provider_env_vars = {"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "OPENAI_BASE_URL"}
|
||||
for pconfig in PROVIDER_REGISTRY.values():
|
||||
if pconfig.auth_type == "api_key":
|
||||
provider_env_vars.update(pconfig.api_key_env_vars)
|
||||
if any(os.getenv(v) for v in provider_env_vars):
|
||||
return True
|
||||
|
||||
@@ -100,6 +107,279 @@ def _has_any_provider_configured() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _session_browse_picker(sessions: list) -> Optional[str]:
|
||||
"""Interactive curses-based session browser with live search filtering.
|
||||
|
||||
Returns the selected session ID, or None if cancelled.
|
||||
Uses curses (not simple_term_menu) to avoid the ghost-duplication rendering
|
||||
bug in tmux/iTerm when arrow keys are used.
|
||||
"""
|
||||
if not sessions:
|
||||
print("No sessions found.")
|
||||
return None
|
||||
|
||||
# Try curses-based picker first
|
||||
try:
|
||||
import curses
|
||||
import time as _time
|
||||
from datetime import datetime
|
||||
|
||||
result_holder = [None]
|
||||
|
||||
def _relative_time(ts):
|
||||
if not ts:
|
||||
return "?"
|
||||
delta = _time.time() - ts
|
||||
if delta < 60:
|
||||
return "just now"
|
||||
elif delta < 3600:
|
||||
return f"{int(delta / 60)}m ago"
|
||||
elif delta < 86400:
|
||||
return f"{int(delta / 3600)}h ago"
|
||||
elif delta < 172800:
|
||||
return "yesterday"
|
||||
elif delta < 604800:
|
||||
return f"{int(delta / 86400)}d ago"
|
||||
else:
|
||||
return datetime.fromtimestamp(ts).strftime("%Y-%m-%d")
|
||||
|
||||
def _format_row(s, max_x):
|
||||
"""Format a session row for display."""
|
||||
title = (s.get("title") or "").strip()
|
||||
preview = (s.get("preview") or "").strip()
|
||||
source = s.get("source", "")[:6]
|
||||
last_active = _relative_time(s.get("last_active"))
|
||||
sid = s["id"][:18]
|
||||
|
||||
# Adaptive column widths based on terminal width
|
||||
# Layout: [arrow 3] [title/preview flexible] [active 12] [src 6] [id 18]
|
||||
fixed_cols = 3 + 12 + 6 + 18 + 6 # arrow + active + src + id + padding
|
||||
name_width = max(20, max_x - fixed_cols)
|
||||
|
||||
if title:
|
||||
name = title[:name_width]
|
||||
elif preview:
|
||||
name = preview[:name_width]
|
||||
else:
|
||||
name = sid
|
||||
|
||||
return f"{name:<{name_width}} {last_active:<10} {source:<5} {sid}"
|
||||
|
||||
def _match(s, query):
|
||||
"""Check if a session matches the search query (case-insensitive)."""
|
||||
q = query.lower()
|
||||
return (
|
||||
q in (s.get("title") or "").lower()
|
||||
or q in (s.get("preview") or "").lower()
|
||||
or q in s.get("id", "").lower()
|
||||
or q in (s.get("source") or "").lower()
|
||||
)
|
||||
|
||||
def _curses_browse(stdscr):
|
||||
curses.curs_set(0)
|
||||
if curses.has_colors():
|
||||
curses.start_color()
|
||||
curses.use_default_colors()
|
||||
curses.init_pair(1, curses.COLOR_GREEN, -1) # selected
|
||||
curses.init_pair(2, curses.COLOR_YELLOW, -1) # header
|
||||
curses.init_pair(3, curses.COLOR_CYAN, -1) # search
|
||||
curses.init_pair(4, 8, -1) # dim
|
||||
|
||||
cursor = 0
|
||||
scroll_offset = 0
|
||||
search_text = ""
|
||||
filtered = list(sessions)
|
||||
|
||||
while True:
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
if max_y < 5 or max_x < 40:
|
||||
# Terminal too small
|
||||
try:
|
||||
stdscr.addstr(0, 0, "Terminal too small")
|
||||
except curses.error:
|
||||
pass
|
||||
stdscr.refresh()
|
||||
stdscr.getch()
|
||||
return
|
||||
|
||||
# Header line
|
||||
if search_text:
|
||||
header = f" Browse sessions — filter: {search_text}█"
|
||||
header_attr = curses.A_BOLD
|
||||
if curses.has_colors():
|
||||
header_attr |= curses.color_pair(3)
|
||||
else:
|
||||
header = " Browse sessions — ↑↓ navigate Enter select Type to filter Esc quit"
|
||||
header_attr = curses.A_BOLD
|
||||
if curses.has_colors():
|
||||
header_attr |= curses.color_pair(2)
|
||||
try:
|
||||
stdscr.addnstr(0, 0, header, max_x - 1, header_attr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
# Column header line
|
||||
fixed_cols = 3 + 12 + 6 + 18 + 6
|
||||
name_width = max(20, max_x - fixed_cols)
|
||||
col_header = f" {'Title / Preview':<{name_width}} {'Active':<10} {'Src':<5} {'ID'}"
|
||||
try:
|
||||
dim_attr = curses.color_pair(4) if curses.has_colors() else curses.A_DIM
|
||||
stdscr.addnstr(1, 0, col_header, max_x - 1, dim_attr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
# Compute visible area
|
||||
visible_rows = max_y - 4 # header + col header + blank + footer
|
||||
if visible_rows < 1:
|
||||
visible_rows = 1
|
||||
|
||||
# Clamp cursor and scroll
|
||||
if not filtered:
|
||||
try:
|
||||
msg = " No sessions match the filter."
|
||||
stdscr.addnstr(3, 0, msg, max_x - 1, curses.A_DIM)
|
||||
except curses.error:
|
||||
pass
|
||||
else:
|
||||
if cursor >= len(filtered):
|
||||
cursor = len(filtered) - 1
|
||||
if cursor < 0:
|
||||
cursor = 0
|
||||
if cursor < scroll_offset:
|
||||
scroll_offset = cursor
|
||||
elif cursor >= scroll_offset + visible_rows:
|
||||
scroll_offset = cursor - visible_rows + 1
|
||||
|
||||
for draw_i, i in enumerate(range(
|
||||
scroll_offset,
|
||||
min(len(filtered), scroll_offset + visible_rows)
|
||||
)):
|
||||
y = draw_i + 3
|
||||
if y >= max_y - 1:
|
||||
break
|
||||
s = filtered[i]
|
||||
arrow = " → " if i == cursor else " "
|
||||
row = arrow + _format_row(s, max_x - 3)
|
||||
attr = curses.A_NORMAL
|
||||
if i == cursor:
|
||||
attr = curses.A_BOLD
|
||||
if curses.has_colors():
|
||||
attr |= curses.color_pair(1)
|
||||
try:
|
||||
stdscr.addnstr(y, 0, row, max_x - 1, attr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
# Footer
|
||||
footer_y = max_y - 1
|
||||
if filtered:
|
||||
footer = f" {cursor + 1}/{len(filtered)} sessions"
|
||||
if len(filtered) < len(sessions):
|
||||
footer += f" (filtered from {len(sessions)})"
|
||||
else:
|
||||
footer = f" 0/{len(sessions)} sessions"
|
||||
try:
|
||||
stdscr.addnstr(footer_y, 0, footer, max_x - 1,
|
||||
curses.color_pair(4) if curses.has_colors() else curses.A_DIM)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
|
||||
if key in (curses.KEY_UP, ):
|
||||
if filtered:
|
||||
cursor = (cursor - 1) % len(filtered)
|
||||
elif key in (curses.KEY_DOWN, ):
|
||||
if filtered:
|
||||
cursor = (cursor + 1) % len(filtered)
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
if filtered:
|
||||
result_holder[0] = filtered[cursor]["id"]
|
||||
return
|
||||
elif key == 27: # Esc
|
||||
if search_text:
|
||||
# First Esc clears the search
|
||||
search_text = ""
|
||||
filtered = list(sessions)
|
||||
cursor = 0
|
||||
scroll_offset = 0
|
||||
else:
|
||||
# Second Esc exits
|
||||
return
|
||||
elif key in (curses.KEY_BACKSPACE, 127, 8):
|
||||
if search_text:
|
||||
search_text = search_text[:-1]
|
||||
if search_text:
|
||||
filtered = [s for s in sessions if _match(s, search_text)]
|
||||
else:
|
||||
filtered = list(sessions)
|
||||
cursor = 0
|
||||
scroll_offset = 0
|
||||
elif key == ord('q') and not search_text:
|
||||
return
|
||||
elif 32 <= key <= 126:
|
||||
# Printable character → add to search filter
|
||||
search_text += chr(key)
|
||||
filtered = [s for s in sessions if _match(s, search_text)]
|
||||
cursor = 0
|
||||
scroll_offset = 0
|
||||
|
||||
curses.wrapper(_curses_browse)
|
||||
return result_holder[0]
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback: numbered list (Windows without curses, etc.)
|
||||
import time as _time
|
||||
from datetime import datetime
|
||||
|
||||
def _relative_time_fb(ts):
|
||||
if not ts:
|
||||
return "?"
|
||||
delta = _time.time() - ts
|
||||
if delta < 60:
|
||||
return "just now"
|
||||
elif delta < 3600:
|
||||
return f"{int(delta / 60)}m ago"
|
||||
elif delta < 86400:
|
||||
return f"{int(delta / 3600)}h ago"
|
||||
elif delta < 172800:
|
||||
return "yesterday"
|
||||
elif delta < 604800:
|
||||
return f"{int(delta / 86400)}d ago"
|
||||
else:
|
||||
return datetime.fromtimestamp(ts).strftime("%Y-%m-%d")
|
||||
|
||||
print("\n Browse sessions (enter number to resume, q to cancel)\n")
|
||||
for i, s in enumerate(sessions):
|
||||
title = (s.get("title") or "").strip()
|
||||
preview = (s.get("preview") or "").strip()
|
||||
label = title or preview or s["id"]
|
||||
if len(label) > 50:
|
||||
label = label[:47] + "..."
|
||||
last_active = _relative_time_fb(s.get("last_active"))
|
||||
src = s.get("source", "")[:6]
|
||||
print(f" {i + 1:>3}. {label:<50} {last_active:<10} {src}")
|
||||
|
||||
while True:
|
||||
try:
|
||||
val = input(f"\n Select [1-{len(sessions)}]: ").strip()
|
||||
if not val or val.lower() in ("q", "quit", "exit"):
|
||||
return None
|
||||
idx = int(val) - 1
|
||||
if 0 <= idx < len(sessions):
|
||||
return sessions[idx]["id"]
|
||||
print(f" Invalid selection. Enter 1-{len(sessions)} or q to cancel.")
|
||||
except ValueError:
|
||||
print(f" Invalid input. Enter a number or q to cancel.")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_last_cli_session() -> Optional[str]:
|
||||
"""Look up the most recent CLI session ID from SQLite. Returns None if unavailable."""
|
||||
try:
|
||||
@@ -114,16 +394,63 @@ def _resolve_last_cli_session() -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_session_by_name_or_id(name_or_id: str) -> Optional[str]:
|
||||
"""Resolve a session name (title) or ID to a session ID.
|
||||
|
||||
- If it looks like a session ID (contains underscore + hex), try direct lookup first.
|
||||
- Otherwise, treat it as a title and use resolve_session_by_title (auto-latest).
|
||||
- Falls back to the other method if the first doesn't match.
|
||||
"""
|
||||
try:
|
||||
from hermes_state import SessionDB
|
||||
db = SessionDB()
|
||||
|
||||
# Try as exact session ID first
|
||||
session = db.get_session(name_or_id)
|
||||
if session:
|
||||
db.close()
|
||||
return session["id"]
|
||||
|
||||
# Try as title (with auto-latest for lineage)
|
||||
session_id = db.resolve_session_by_title(name_or_id)
|
||||
db.close()
|
||||
return session_id
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def cmd_chat(args):
|
||||
"""Run interactive chat CLI."""
|
||||
# Resolve --continue into --resume with the latest CLI session
|
||||
if getattr(args, "continue_last", False) and not getattr(args, "resume", None):
|
||||
last_id = _resolve_last_cli_session()
|
||||
if last_id:
|
||||
args.resume = last_id
|
||||
# Resolve --continue into --resume with the latest CLI session or by name
|
||||
continue_val = getattr(args, "continue_last", None)
|
||||
if continue_val and not getattr(args, "resume", None):
|
||||
if isinstance(continue_val, str):
|
||||
# -c "session name" — resolve by title or ID
|
||||
resolved = _resolve_session_by_name_or_id(continue_val)
|
||||
if resolved:
|
||||
args.resume = resolved
|
||||
else:
|
||||
print(f"No session found matching '{continue_val}'.")
|
||||
print("Use 'hermes sessions list' to see available sessions.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("No previous CLI session found to continue.")
|
||||
sys.exit(1)
|
||||
# -c with no argument — continue the most recent session
|
||||
last_id = _resolve_last_cli_session()
|
||||
if last_id:
|
||||
args.resume = last_id
|
||||
else:
|
||||
print("No previous CLI session found to continue.")
|
||||
sys.exit(1)
|
||||
|
||||
# Resolve --resume by title if it's not a direct session ID
|
||||
resume_val = getattr(args, "resume", None)
|
||||
if resume_val:
|
||||
resolved = _resolve_session_by_name_or_id(resume_val)
|
||||
if resolved:
|
||||
args.resume = resolved
|
||||
# If resolution fails, keep the original value — _init_agent will
|
||||
# report "Session not found" with the original input
|
||||
|
||||
# First-run guard: check if any provider is configured before launching
|
||||
if not _has_any_provider_configured():
|
||||
@@ -143,6 +470,13 @@ def cmd_chat(args):
|
||||
print("You can run 'hermes setup' at any time to configure.")
|
||||
sys.exit(1)
|
||||
|
||||
# Sync bundled skills on every CLI launch (fast -- skips unchanged skills)
|
||||
try:
|
||||
from tools.skills_sync import sync_skills
|
||||
sync_skills(quiet=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Import and run the CLI
|
||||
from cli import main as cli_main
|
||||
|
||||
@@ -154,6 +488,7 @@ def cmd_chat(args):
|
||||
"verbose": args.verbose,
|
||||
"query": args.query,
|
||||
"resume": getattr(args, "resume", None),
|
||||
"worktree": getattr(args, "worktree", False),
|
||||
}
|
||||
# Filter out None values
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
@@ -404,6 +739,10 @@ def cmd_model(args):
|
||||
"openrouter": "OpenRouter",
|
||||
"nous": "Nous Portal",
|
||||
"openai-codex": "OpenAI Codex",
|
||||
"zai": "Z.AI / GLM",
|
||||
"kimi-coding": "Kimi / Moonshot",
|
||||
"minimax": "MiniMax",
|
||||
"minimax-cn": "MiniMax (China)",
|
||||
"custom": "Custom endpoint",
|
||||
}
|
||||
active_label = provider_labels.get(active, active)
|
||||
@@ -418,11 +757,16 @@ def cmd_model(args):
|
||||
("openrouter", "OpenRouter (100+ models, pay-per-use)"),
|
||||
("nous", "Nous Portal (Nous Research subscription)"),
|
||||
("openai-codex", "OpenAI Codex"),
|
||||
("zai", "Z.AI / GLM (Zhipu AI direct API)"),
|
||||
("kimi-coding", "Kimi / Moonshot (Moonshot AI direct API)"),
|
||||
("minimax", "MiniMax (global direct API)"),
|
||||
("minimax-cn", "MiniMax China (domestic direct API)"),
|
||||
("custom", "Custom endpoint (self-hosted / VLLM / etc.)"),
|
||||
]
|
||||
|
||||
# Reorder so the active provider is at the top
|
||||
active_key = active if active in ("openrouter", "nous", "openai-codex") else "custom"
|
||||
known_keys = {k for k, _ in providers}
|
||||
active_key = active if active in known_keys else "custom"
|
||||
ordered = []
|
||||
for key, label in providers:
|
||||
if key == active_key:
|
||||
@@ -447,6 +791,8 @@ def cmd_model(args):
|
||||
_model_flow_openai_codex(config, current_model)
|
||||
elif selected_provider == "custom":
|
||||
_model_flow_custom(config)
|
||||
elif selected_provider in ("zai", "kimi-coding", "minimax", "minimax-cn"):
|
||||
_model_flow_api_key_provider(config, selected_provider, current_model)
|
||||
|
||||
|
||||
def _prompt_provider_choice(choices):
|
||||
@@ -716,6 +1062,117 @@ def _model_flow_custom(config):
|
||||
print("Endpoint saved. Use `/model` in chat or `hermes model` to set a model.")
|
||||
|
||||
|
||||
# Curated model lists for direct API-key providers
|
||||
_PROVIDER_MODELS = {
|
||||
"zai": [
|
||||
"glm-5",
|
||||
"glm-4.7",
|
||||
"glm-4.5",
|
||||
"glm-4.5-flash",
|
||||
],
|
||||
"kimi-coding": [
|
||||
"kimi-k2.5",
|
||||
"kimi-k2-thinking",
|
||||
"kimi-k2-turbo-preview",
|
||||
"kimi-k2-0905-preview",
|
||||
],
|
||||
"minimax": [
|
||||
"MiniMax-M2.5",
|
||||
"MiniMax-M2.5-highspeed",
|
||||
"MiniMax-M2.1",
|
||||
],
|
||||
"minimax-cn": [
|
||||
"MiniMax-M2.5",
|
||||
"MiniMax-M2.5-highspeed",
|
||||
"MiniMax-M2.1",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _model_flow_api_key_provider(config, provider_id, current_model=""):
|
||||
"""Generic flow for API-key providers (z.ai, Kimi, MiniMax)."""
|
||||
from hermes_cli.auth import (
|
||||
PROVIDER_REGISTRY, _prompt_model_selection, _save_model_choice,
|
||||
_update_config_for_provider, deactivate_provider,
|
||||
)
|
||||
from hermes_cli.config import get_env_value, save_env_value, load_config, save_config
|
||||
|
||||
pconfig = PROVIDER_REGISTRY[provider_id]
|
||||
key_env = pconfig.api_key_env_vars[0] if pconfig.api_key_env_vars else ""
|
||||
base_url_env = pconfig.base_url_env_var or ""
|
||||
|
||||
# Check / prompt for API key
|
||||
existing_key = ""
|
||||
for ev in pconfig.api_key_env_vars:
|
||||
existing_key = get_env_value(ev) or os.getenv(ev, "")
|
||||
if existing_key:
|
||||
break
|
||||
|
||||
if not existing_key:
|
||||
print(f"No {pconfig.name} API key configured.")
|
||||
if key_env:
|
||||
try:
|
||||
new_key = input(f"{key_env} (or Enter to cancel): ").strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return
|
||||
if not new_key:
|
||||
print("Cancelled.")
|
||||
return
|
||||
save_env_value(key_env, new_key)
|
||||
print("API key saved.")
|
||||
print()
|
||||
else:
|
||||
print(f" {pconfig.name} API key: {existing_key[:8]}... ✓")
|
||||
print()
|
||||
|
||||
# Optional base URL override
|
||||
current_base = ""
|
||||
if base_url_env:
|
||||
current_base = get_env_value(base_url_env) or os.getenv(base_url_env, "")
|
||||
effective_base = current_base or pconfig.inference_base_url
|
||||
|
||||
try:
|
||||
override = input(f"Base URL [{effective_base}]: ").strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
override = ""
|
||||
if override and base_url_env:
|
||||
save_env_value(base_url_env, override)
|
||||
effective_base = override
|
||||
|
||||
# Model selection
|
||||
model_list = _PROVIDER_MODELS.get(provider_id, [])
|
||||
if model_list:
|
||||
selected = _prompt_model_selection(model_list, current_model=current_model)
|
||||
else:
|
||||
try:
|
||||
selected = input("Model name: ").strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
selected = None
|
||||
|
||||
if selected:
|
||||
# Clear custom endpoint if set (avoid confusion)
|
||||
if get_env_value("OPENAI_BASE_URL"):
|
||||
save_env_value("OPENAI_BASE_URL", "")
|
||||
save_env_value("OPENAI_API_KEY", "")
|
||||
|
||||
_save_model_choice(selected)
|
||||
|
||||
# Update config with provider and base URL
|
||||
cfg = load_config()
|
||||
model = cfg.get("model")
|
||||
if isinstance(model, dict):
|
||||
model["provider"] = provider_id
|
||||
model["base_url"] = effective_base
|
||||
save_config(cfg)
|
||||
deactivate_provider()
|
||||
|
||||
print(f"Default model set to: {selected} (via {pconfig.name})")
|
||||
else:
|
||||
print("No change.")
|
||||
|
||||
|
||||
def cmd_login(args):
|
||||
"""Authenticate Hermes CLI with a provider."""
|
||||
from hermes_cli.auth import login_command
|
||||
@@ -851,11 +1308,17 @@ def _update_via_zip(args):
|
||||
# Sync skills
|
||||
try:
|
||||
from tools.skills_sync import sync_skills
|
||||
print("→ Checking for new bundled skills...")
|
||||
print("→ Syncing bundled skills...")
|
||||
result = sync_skills(quiet=True)
|
||||
if result["copied"]:
|
||||
print(f" + {len(result['copied'])} new skill(s): {', '.join(result['copied'])}")
|
||||
else:
|
||||
print(f" + {len(result['copied'])} new: {', '.join(result['copied'])}")
|
||||
if result.get("updated"):
|
||||
print(f" ↑ {len(result['updated'])} updated: {', '.join(result['updated'])}")
|
||||
if result.get("user_modified"):
|
||||
print(f" ~ {len(result['user_modified'])} user-modified (kept)")
|
||||
if result.get("cleaned"):
|
||||
print(f" − {len(result['cleaned'])} removed from manifest")
|
||||
if not result["copied"] and not result.get("updated"):
|
||||
print(" ✓ Skills are up to date")
|
||||
except Exception:
|
||||
pass
|
||||
@@ -961,15 +1424,21 @@ def cmd_update(args):
|
||||
print()
|
||||
print("✓ Code updated!")
|
||||
|
||||
# Sync any new bundled skills (manifest-based -- won't overwrite or re-add deleted skills)
|
||||
# Sync bundled skills (copies new, updates changed, respects user deletions)
|
||||
try:
|
||||
from tools.skills_sync import sync_skills
|
||||
print()
|
||||
print("→ Checking for new bundled skills...")
|
||||
print("→ Syncing bundled skills...")
|
||||
result = sync_skills(quiet=True)
|
||||
if result["copied"]:
|
||||
print(f" + {len(result['copied'])} new skill(s): {', '.join(result['copied'])}")
|
||||
else:
|
||||
print(f" + {len(result['copied'])} new: {', '.join(result['copied'])}")
|
||||
if result.get("updated"):
|
||||
print(f" ↑ {len(result['updated'])} updated: {', '.join(result['updated'])}")
|
||||
if result.get("user_modified"):
|
||||
print(f" ~ {len(result['user_modified'])} user-modified (kept)")
|
||||
if result.get("cleaned"):
|
||||
print(f" − {len(result['cleaned'])} removed from manifest")
|
||||
if not result["copied"] and not result.get("updated"):
|
||||
print(" ✓ Skills are up to date")
|
||||
except Exception as e:
|
||||
logger.debug("Skills sync during update failed: %s", e)
|
||||
@@ -1061,8 +1530,9 @@ def main():
|
||||
Examples:
|
||||
hermes Start interactive chat
|
||||
hermes chat -q "Hello" Single query mode
|
||||
hermes --continue Resume the most recent session
|
||||
hermes --resume <session_id> Resume a specific session
|
||||
hermes -c Resume the most recent session
|
||||
hermes -c "my project" Resume a session by name (latest in lineage)
|
||||
hermes --resume <session_id> Resume a specific session by ID
|
||||
hermes setup Run setup wizard
|
||||
hermes logout Clear stored authentication
|
||||
hermes model Select default model
|
||||
@@ -1070,8 +1540,11 @@ Examples:
|
||||
hermes config edit Edit config in $EDITOR
|
||||
hermes config set model gpt-4 Set a config value
|
||||
hermes gateway Run messaging gateway
|
||||
hermes -w Start in isolated git worktree
|
||||
hermes gateway install Install as system service
|
||||
hermes sessions list List past sessions
|
||||
hermes sessions browse Interactive session picker
|
||||
hermes sessions rename ID T Rename/title a session
|
||||
hermes update Update to latest version
|
||||
|
||||
For more help on a command:
|
||||
@@ -1086,16 +1559,24 @@ For more help on a command:
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume", "-r",
|
||||
metavar="SESSION_ID",
|
||||
metavar="SESSION",
|
||||
default=None,
|
||||
help="Resume a previous session by ID (shortcut for: hermes chat --resume ID)"
|
||||
help="Resume a previous session by ID or title"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--continue", "-c",
|
||||
dest="continue_last",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=None,
|
||||
metavar="SESSION_NAME",
|
||||
help="Resume a session by name, or the most recent if no name given"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--worktree", "-w",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Resume the most recent CLI session"
|
||||
help="Run in an isolated git worktree (for parallel agents)"
|
||||
)
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command", help="Command to run")
|
||||
@@ -1122,7 +1603,7 @@ For more help on a command:
|
||||
)
|
||||
chat_parser.add_argument(
|
||||
"--provider",
|
||||
choices=["auto", "openrouter", "nous", "openai-codex"],
|
||||
choices=["auto", "openrouter", "nous", "openai-codex", "zai", "kimi-coding", "minimax", "minimax-cn"],
|
||||
default=None,
|
||||
help="Inference provider (default: auto)"
|
||||
)
|
||||
@@ -1139,9 +1620,17 @@ For more help on a command:
|
||||
chat_parser.add_argument(
|
||||
"--continue", "-c",
|
||||
dest="continue_last",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=None,
|
||||
metavar="SESSION_NAME",
|
||||
help="Resume a session by name, or the most recent if no name given"
|
||||
)
|
||||
chat_parser.add_argument(
|
||||
"--worktree", "-w",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Resume the most recent CLI session"
|
||||
help="Run in an isolated git worktree (for parallel agents on the same repo)"
|
||||
)
|
||||
chat_parser.set_defaults(func=cmd_chat)
|
||||
|
||||
@@ -1168,6 +1657,8 @@ For more help on a command:
|
||||
# gateway run (default)
|
||||
gateway_run = gateway_subparsers.add_parser("run", help="Run gateway in foreground")
|
||||
gateway_run.add_argument("-v", "--verbose", action="store_true")
|
||||
gateway_run.add_argument("--replace", action="store_true",
|
||||
help="Replace any existing gateway instance (useful for systemd)")
|
||||
|
||||
# gateway start
|
||||
gateway_start = gateway_subparsers.add_parser("start", help="Start gateway service")
|
||||
@@ -1200,7 +1691,15 @@ For more help on a command:
|
||||
setup_parser = subparsers.add_parser(
|
||||
"setup",
|
||||
help="Interactive setup wizard",
|
||||
description="Configure Hermes Agent with an interactive wizard"
|
||||
description="Configure Hermes Agent with an interactive wizard. "
|
||||
"Run a specific section: hermes setup model|terminal|gateway|tools|agent"
|
||||
)
|
||||
setup_parser.add_argument(
|
||||
"section",
|
||||
nargs="?",
|
||||
choices=["model", "terminal", "gateway", "tools", "agent"],
|
||||
default=None,
|
||||
help="Run a specific setup section instead of the full wizard"
|
||||
)
|
||||
setup_parser.add_argument(
|
||||
"--non-interactive",
|
||||
@@ -1500,7 +1999,7 @@ For more help on a command:
|
||||
# =========================================================================
|
||||
sessions_parser = subparsers.add_parser(
|
||||
"sessions",
|
||||
help="Manage session history (list, export, prune, delete)",
|
||||
help="Manage session history (list, rename, export, prune, delete)",
|
||||
description="View and manage the SQLite session store"
|
||||
)
|
||||
sessions_subparsers = sessions_parser.add_subparsers(dest="sessions_action")
|
||||
@@ -1525,6 +2024,17 @@ For more help on a command:
|
||||
|
||||
sessions_stats = sessions_subparsers.add_parser("stats", help="Show session store statistics")
|
||||
|
||||
sessions_rename = sessions_subparsers.add_parser("rename", help="Set or change a session's title")
|
||||
sessions_rename.add_argument("session_id", help="Session ID to rename")
|
||||
sessions_rename.add_argument("title", nargs="+", help="New title for the session")
|
||||
|
||||
sessions_browse = sessions_subparsers.add_parser(
|
||||
"browse",
|
||||
help="Interactive session picker — browse, search, and resume sessions",
|
||||
)
|
||||
sessions_browse.add_argument("--source", help="Filter by source (cli, telegram, discord, etc.)")
|
||||
sessions_browse.add_argument("--limit", type=int, default=50, help="Max sessions to load (default: 50)")
|
||||
|
||||
def cmd_sessions(args):
|
||||
import json as _json
|
||||
try:
|
||||
@@ -1537,18 +2047,51 @@ For more help on a command:
|
||||
action = args.sessions_action
|
||||
|
||||
if action == "list":
|
||||
sessions = db.search_sessions(source=args.source, limit=args.limit)
|
||||
sessions = db.list_sessions_rich(source=args.source, limit=args.limit)
|
||||
if not sessions:
|
||||
print("No sessions found.")
|
||||
return
|
||||
print(f"{'ID':<30} {'Source':<12} {'Model':<30} {'Messages':>8} {'Started'}")
|
||||
print("─" * 100)
|
||||
from datetime import datetime
|
||||
import time as _time
|
||||
|
||||
def _relative_time(ts):
|
||||
"""Format a timestamp as relative time (e.g., '2h ago', 'yesterday')."""
|
||||
if not ts:
|
||||
return "?"
|
||||
delta = _time.time() - ts
|
||||
if delta < 60:
|
||||
return "just now"
|
||||
elif delta < 3600:
|
||||
mins = int(delta / 60)
|
||||
return f"{mins}m ago"
|
||||
elif delta < 86400:
|
||||
hours = int(delta / 3600)
|
||||
return f"{hours}h ago"
|
||||
elif delta < 172800:
|
||||
return "yesterday"
|
||||
elif delta < 604800:
|
||||
days = int(delta / 86400)
|
||||
return f"{days}d ago"
|
||||
else:
|
||||
return datetime.fromtimestamp(ts).strftime("%Y-%m-%d")
|
||||
|
||||
has_titles = any(s.get("title") for s in sessions)
|
||||
if has_titles:
|
||||
print(f"{'Title':<22} {'Preview':<40} {'Last Active':<13} {'ID'}")
|
||||
print("─" * 100)
|
||||
else:
|
||||
print(f"{'Preview':<50} {'Last Active':<13} {'Src':<6} {'ID'}")
|
||||
print("─" * 90)
|
||||
for s in sessions:
|
||||
started = datetime.fromtimestamp(s["started_at"]).strftime("%Y-%m-%d %H:%M") if s["started_at"] else "?"
|
||||
model = (s.get("model") or "?")[:28]
|
||||
ended = " (ended)" if s.get("ended_at") else ""
|
||||
print(f"{s['id']:<30} {s['source']:<12} {model:<30} {s['message_count']:>8} {started}{ended}")
|
||||
last_active = _relative_time(s.get("last_active"))
|
||||
preview = s.get("preview", "")[:38] if has_titles else s.get("preview", "")[:48]
|
||||
if has_titles:
|
||||
title = (s.get("title") or "—")[:20]
|
||||
sid = s["id"][:20]
|
||||
print(f"{title:<22} {preview:<40} {last_active:<13} {sid}")
|
||||
else:
|
||||
sid = s["id"][:20]
|
||||
print(f"{preview:<50} {last_active:<13} {s['source']:<6} {sid}")
|
||||
|
||||
elif action == "export":
|
||||
if args.session_id:
|
||||
@@ -1588,6 +2131,44 @@ For more help on a command:
|
||||
count = db.prune_sessions(older_than_days=days, source=args.source)
|
||||
print(f"Pruned {count} session(s).")
|
||||
|
||||
elif action == "rename":
|
||||
title = " ".join(args.title)
|
||||
try:
|
||||
if db.set_session_title(args.session_id, title):
|
||||
print(f"Session '{args.session_id}' renamed to: {title}")
|
||||
else:
|
||||
print(f"Session '{args.session_id}' not found.")
|
||||
except ValueError as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
elif action == "browse":
|
||||
limit = getattr(args, "limit", 50) or 50
|
||||
source = getattr(args, "source", None)
|
||||
sessions = db.list_sessions_rich(source=source, limit=limit)
|
||||
db.close()
|
||||
if not sessions:
|
||||
print("No sessions found.")
|
||||
return
|
||||
|
||||
selected_id = _session_browse_picker(sessions)
|
||||
if not selected_id:
|
||||
print("Cancelled.")
|
||||
return
|
||||
|
||||
# Launch hermes --resume <id> by replacing the current process
|
||||
print(f"Resuming session: {selected_id}")
|
||||
import shutil
|
||||
hermes_bin = shutil.which("hermes")
|
||||
if hermes_bin:
|
||||
os.execvp(hermes_bin, ["hermes", "--resume", selected_id])
|
||||
else:
|
||||
# Fallback: re-invoke via python -m
|
||||
os.execvp(
|
||||
sys.executable,
|
||||
[sys.executable, "-m", "hermes_cli.main", "--resume", selected_id],
|
||||
)
|
||||
return # won't reach here after execvp
|
||||
|
||||
elif action == "stats":
|
||||
total = db.session_count()
|
||||
msgs = db.message_count()
|
||||
@@ -1597,7 +2178,6 @@ For more help on a command:
|
||||
c = db.session_count(source=src)
|
||||
if c > 0:
|
||||
print(f" {src}: {c} sessions")
|
||||
import os
|
||||
db_path = db.db_path
|
||||
if db_path.exists():
|
||||
size_mb = os.path.getsize(db_path) / (1024 * 1024)
|
||||
@@ -1693,6 +2273,8 @@ For more help on a command:
|
||||
args.provider = None
|
||||
args.toolsets = None
|
||||
args.verbose = False
|
||||
if not hasattr(args, "worktree"):
|
||||
args.worktree = False
|
||||
cmd_chat(args)
|
||||
return
|
||||
|
||||
@@ -1704,7 +2286,9 @@ For more help on a command:
|
||||
args.toolsets = None
|
||||
args.verbose = False
|
||||
args.resume = None
|
||||
args.continue_last = False
|
||||
args.continue_last = None
|
||||
if not hasattr(args, "worktree"):
|
||||
args.worktree = False
|
||||
cmd_chat(args)
|
||||
return
|
||||
|
||||
|
||||
@@ -1,10 +1,18 @@
|
||||
"""
|
||||
Canonical list of OpenRouter models offered in CLI and setup wizards.
|
||||
Canonical model catalogs and lightweight validation helpers.
|
||||
|
||||
Add, remove, or reorder entries here — both `hermes setup` and
|
||||
`hermes` provider-selection will pick up the change automatically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from difflib import get_close_matches
|
||||
from typing import Any, Optional
|
||||
|
||||
# (model_id, display description shown in menus)
|
||||
OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("anthropic/claude-opus-4.6", "recommended"),
|
||||
@@ -14,17 +22,64 @@ OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("openai/gpt-5.3-codex", ""),
|
||||
("google/gemini-3-pro-preview", ""),
|
||||
("google/gemini-3-flash-preview", ""),
|
||||
("qwen/qwen3.5-plus-02-15", ""),
|
||||
("qwen/qwen3.5-35b-a3b", ""),
|
||||
("qwen/qwen3.5-plus-02-15", ""),
|
||||
("qwen/qwen3.5-35b-a3b", ""),
|
||||
("stepfun/step-3.5-flash", ""),
|
||||
("z-ai/glm-5", ""),
|
||||
("moonshotai/kimi-k2.5", ""),
|
||||
("minimax/minimax-m2.5", ""),
|
||||
]
|
||||
|
||||
_PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"zai": [
|
||||
"glm-5",
|
||||
"glm-4.7",
|
||||
"glm-4.5",
|
||||
"glm-4.5-flash",
|
||||
],
|
||||
"kimi-coding": [
|
||||
"kimi-k2.5",
|
||||
"kimi-k2-thinking",
|
||||
"kimi-k2-turbo-preview",
|
||||
"kimi-k2-0905-preview",
|
||||
],
|
||||
"minimax": [
|
||||
"MiniMax-M2.5",
|
||||
"MiniMax-M2.5-highspeed",
|
||||
"MiniMax-M2.1",
|
||||
],
|
||||
"minimax-cn": [
|
||||
"MiniMax-M2.5",
|
||||
"MiniMax-M2.5-highspeed",
|
||||
"MiniMax-M2.1",
|
||||
],
|
||||
}
|
||||
|
||||
_PROVIDER_LABELS = {
|
||||
"openrouter": "OpenRouter",
|
||||
"openai-codex": "OpenAI Codex",
|
||||
"nous": "Nous Portal",
|
||||
"zai": "Z.AI / GLM",
|
||||
"kimi-coding": "Kimi / Moonshot",
|
||||
"minimax": "MiniMax",
|
||||
"minimax-cn": "MiniMax (China)",
|
||||
"custom": "custom endpoint",
|
||||
}
|
||||
|
||||
_PROVIDER_ALIASES = {
|
||||
"glm": "zai",
|
||||
"z-ai": "zai",
|
||||
"z.ai": "zai",
|
||||
"zhipu": "zai",
|
||||
"kimi": "kimi-coding",
|
||||
"moonshot": "kimi-coding",
|
||||
"minimax-china": "minimax-cn",
|
||||
"minimax_cn": "minimax-cn",
|
||||
}
|
||||
|
||||
|
||||
def model_ids() -> list[str]:
|
||||
"""Return just the model-id strings (convenience helper)."""
|
||||
"""Return just the OpenRouter model-id strings."""
|
||||
return [mid for mid, _ in OPENROUTER_MODELS]
|
||||
|
||||
|
||||
@@ -34,3 +89,231 @@ def menu_labels() -> list[str]:
|
||||
for mid, desc in OPENROUTER_MODELS:
|
||||
labels.append(f"{mid} ({desc})" if desc else mid)
|
||||
return labels
|
||||
|
||||
|
||||
# All provider IDs and aliases that are valid for the provider:model syntax.
|
||||
_KNOWN_PROVIDER_NAMES: set[str] = (
|
||||
set(_PROVIDER_LABELS.keys())
|
||||
| set(_PROVIDER_ALIASES.keys())
|
||||
| {"openrouter", "custom"}
|
||||
)
|
||||
|
||||
|
||||
def list_available_providers() -> list[dict[str, str]]:
|
||||
"""Return info about all providers the user could use with ``provider:model``.
|
||||
|
||||
Each dict has ``id``, ``label``, and ``aliases``.
|
||||
Checks which providers have valid credentials configured.
|
||||
"""
|
||||
# Canonical providers in display order
|
||||
_PROVIDER_ORDER = [
|
||||
"openrouter", "nous", "openai-codex",
|
||||
"zai", "kimi-coding", "minimax", "minimax-cn",
|
||||
]
|
||||
# Build reverse alias map
|
||||
aliases_for: dict[str, list[str]] = {}
|
||||
for alias, canonical in _PROVIDER_ALIASES.items():
|
||||
aliases_for.setdefault(canonical, []).append(alias)
|
||||
|
||||
result = []
|
||||
for pid in _PROVIDER_ORDER:
|
||||
label = _PROVIDER_LABELS.get(pid, pid)
|
||||
alias_list = aliases_for.get(pid, [])
|
||||
# Check if this provider has credentials available
|
||||
has_creds = False
|
||||
try:
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
runtime = resolve_runtime_provider(requested=pid)
|
||||
has_creds = bool(runtime.get("api_key"))
|
||||
except Exception:
|
||||
pass
|
||||
result.append({
|
||||
"id": pid,
|
||||
"label": label,
|
||||
"aliases": alias_list,
|
||||
"authenticated": has_creds,
|
||||
})
|
||||
return result
|
||||
|
||||
|
||||
def parse_model_input(raw: str, current_provider: str) -> tuple[str, str]:
|
||||
"""Parse ``/model`` input into ``(provider, model)``.
|
||||
|
||||
Supports ``provider:model`` syntax to switch providers at runtime::
|
||||
|
||||
openrouter:anthropic/claude-sonnet-4.5 → ("openrouter", "anthropic/claude-sonnet-4.5")
|
||||
nous:hermes-3 → ("nous", "hermes-3")
|
||||
anthropic/claude-sonnet-4.5 → (current_provider, "anthropic/claude-sonnet-4.5")
|
||||
gpt-5.4 → (current_provider, "gpt-5.4")
|
||||
|
||||
The colon is only treated as a provider delimiter if the left side is a
|
||||
recognized provider name or alias. This avoids misinterpreting model names
|
||||
that happen to contain colons (e.g. ``anthropic/claude-3.5-sonnet:beta``).
|
||||
|
||||
Returns ``(provider, model)`` where *provider* is either the explicit
|
||||
provider from the input or *current_provider* if none was specified.
|
||||
"""
|
||||
stripped = raw.strip()
|
||||
colon = stripped.find(":")
|
||||
if colon > 0:
|
||||
provider_part = stripped[:colon].strip().lower()
|
||||
model_part = stripped[colon + 1:].strip()
|
||||
if provider_part and model_part and provider_part in _KNOWN_PROVIDER_NAMES:
|
||||
return (normalize_provider(provider_part), model_part)
|
||||
return (current_provider, stripped)
|
||||
|
||||
|
||||
def curated_models_for_provider(provider: Optional[str]) -> list[tuple[str, str]]:
|
||||
"""Return ``(model_id, description)`` tuples for a provider's curated list."""
|
||||
normalized = normalize_provider(provider)
|
||||
if normalized == "openrouter":
|
||||
return list(OPENROUTER_MODELS)
|
||||
models = _PROVIDER_MODELS.get(normalized, [])
|
||||
return [(m, "") for m in models]
|
||||
|
||||
|
||||
def normalize_provider(provider: Optional[str]) -> str:
|
||||
"""Normalize provider aliases to Hermes' canonical provider ids.
|
||||
|
||||
Note: ``"auto"`` passes through unchanged — use
|
||||
``hermes_cli.auth.resolve_provider()`` to resolve it to a concrete
|
||||
provider based on credentials and environment.
|
||||
"""
|
||||
normalized = (provider or "openrouter").strip().lower()
|
||||
return _PROVIDER_ALIASES.get(normalized, normalized)
|
||||
|
||||
|
||||
def provider_model_ids(provider: Optional[str]) -> list[str]:
|
||||
"""Return the best known model catalog for a provider."""
|
||||
normalized = normalize_provider(provider)
|
||||
if normalized == "openrouter":
|
||||
return model_ids()
|
||||
if normalized == "openai-codex":
|
||||
from hermes_cli.codex_models import get_codex_model_ids
|
||||
|
||||
return get_codex_model_ids()
|
||||
return list(_PROVIDER_MODELS.get(normalized, []))
|
||||
|
||||
|
||||
def fetch_api_models(
|
||||
api_key: Optional[str],
|
||||
base_url: Optional[str],
|
||||
timeout: float = 5.0,
|
||||
) -> Optional[list[str]]:
|
||||
"""Fetch the list of available model IDs from the provider's ``/models`` endpoint.
|
||||
|
||||
Returns a list of model ID strings, or ``None`` if the endpoint could not
|
||||
be reached (network error, timeout, auth failure, etc.).
|
||||
"""
|
||||
if not base_url:
|
||||
return None
|
||||
|
||||
url = base_url.rstrip("/") + "/models"
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
req = urllib.request.Request(url, headers=headers)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
data = json.loads(resp.read().decode())
|
||||
# Standard OpenAI format: {"data": [{"id": "model-name", ...}, ...]}
|
||||
return [m.get("id", "") for m in data.get("data", [])]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def validate_requested_model(
|
||||
model_name: str,
|
||||
provider: Optional[str],
|
||||
*,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Validate a ``/model`` value for the active provider.
|
||||
|
||||
Performs format checks first, then probes the live API to confirm
|
||||
the model actually exists.
|
||||
|
||||
Returns a dict with:
|
||||
- accepted: whether the CLI should switch to the requested model now
|
||||
- persist: whether it is safe to save to config
|
||||
- recognized: whether it matched a known provider catalog
|
||||
- message: optional warning / guidance for the user
|
||||
"""
|
||||
requested = (model_name or "").strip()
|
||||
normalized = normalize_provider(provider)
|
||||
if normalized == "openrouter" and base_url and "openrouter.ai" not in base_url:
|
||||
normalized = "custom"
|
||||
|
||||
if not requested:
|
||||
return {
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": "Model name cannot be empty.",
|
||||
}
|
||||
|
||||
if any(ch.isspace() for ch in requested):
|
||||
return {
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": "Model names cannot contain spaces.",
|
||||
}
|
||||
|
||||
# Probe the live API to check if the model actually exists
|
||||
api_models = fetch_api_models(api_key, base_url)
|
||||
|
||||
if api_models is not None:
|
||||
if requested in set(api_models):
|
||||
# API confirmed the model exists
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"recognized": True,
|
||||
"message": None,
|
||||
}
|
||||
else:
|
||||
# API responded but model is not listed
|
||||
suggestions = get_close_matches(requested, api_models, n=3, cutoff=0.5)
|
||||
suggestion_text = ""
|
||||
if suggestions:
|
||||
suggestion_text = "\n Did you mean: " + ", ".join(f"`{s}`" for s in suggestions)
|
||||
|
||||
return {
|
||||
"accepted": False,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": (
|
||||
f"Error: `{requested}` is not a valid model for this provider."
|
||||
f"{suggestion_text}"
|
||||
),
|
||||
}
|
||||
|
||||
# api_models is None — couldn't reach API, fall back to catalog check
|
||||
provider_label = _PROVIDER_LABELS.get(normalized, normalized)
|
||||
known_models = provider_model_ids(normalized)
|
||||
|
||||
if requested in known_models:
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"recognized": True,
|
||||
"message": None,
|
||||
}
|
||||
|
||||
# Can't validate — accept for session only
|
||||
suggestion = get_close_matches(requested, known_models, n=1, cutoff=0.6)
|
||||
suggestion_text = f" Did you mean `{suggestion[0]}`?" if suggestion else ""
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": False,
|
||||
"recognized": False,
|
||||
"message": (
|
||||
f"Could not validate `{requested}` against the live {provider_label} API. "
|
||||
"Using it for this session only; config unchanged."
|
||||
f"{suggestion_text}"
|
||||
),
|
||||
}
|
||||
|
||||
@@ -7,10 +7,12 @@ from typing import Any, Dict, Optional
|
||||
|
||||
from hermes_cli.auth import (
|
||||
AuthError,
|
||||
PROVIDER_REGISTRY,
|
||||
format_auth_error,
|
||||
resolve_provider,
|
||||
resolve_nous_runtime_credentials,
|
||||
resolve_codex_runtime_credentials,
|
||||
resolve_api_key_provider_credentials,
|
||||
)
|
||||
from hermes_cli.config import load_config
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
@@ -72,12 +74,26 @@ def _resolve_openrouter_runtime(
|
||||
or OPENROUTER_BASE_URL
|
||||
).rstrip("/")
|
||||
|
||||
api_key = (
|
||||
explicit_api_key
|
||||
or os.getenv("OPENROUTER_API_KEY")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or ""
|
||||
)
|
||||
# Choose API key based on whether the resolved base_url targets OpenRouter.
|
||||
# When hitting OpenRouter, prefer OPENROUTER_API_KEY (issue #289).
|
||||
# When hitting a custom endpoint (e.g. Z.ai, local LLM), prefer
|
||||
# OPENAI_API_KEY so the OpenRouter key doesn't leak to an unrelated
|
||||
# provider (issues #420, #560).
|
||||
_is_openrouter_url = "openrouter.ai" in base_url
|
||||
if _is_openrouter_url:
|
||||
api_key = (
|
||||
explicit_api_key
|
||||
or os.getenv("OPENROUTER_API_KEY")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or ""
|
||||
)
|
||||
else:
|
||||
api_key = (
|
||||
explicit_api_key
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or os.getenv("OPENROUTER_API_KEY")
|
||||
or ""
|
||||
)
|
||||
|
||||
source = "explicit" if (explicit_api_key or explicit_base_url) else "env/config"
|
||||
|
||||
@@ -132,6 +148,19 @@ def resolve_runtime_provider(
|
||||
"requested_provider": requested_provider,
|
||||
}
|
||||
|
||||
# API-key providers (z.ai/GLM, Kimi, MiniMax, MiniMax-CN)
|
||||
pconfig = PROVIDER_REGISTRY.get(provider)
|
||||
if pconfig and pconfig.auth_type == "api_key":
|
||||
creds = resolve_api_key_provider_credentials(provider)
|
||||
return {
|
||||
"provider": provider,
|
||||
"api_mode": "chat_completions",
|
||||
"base_url": creds.get("base_url", "").rstrip("/"),
|
||||
"api_key": creds.get("api_key", ""),
|
||||
"source": creds.get("source", "env"),
|
||||
"requested_provider": requested_provider,
|
||||
}
|
||||
|
||||
runtime = _resolve_openrouter_runtime(
|
||||
requested_provider=requested_provider,
|
||||
explicit_api_key=explicit_api_key,
|
||||
|
||||
1679
hermes_cli/setup.py
1679
hermes_cli/setup.py
File diff suppressed because it is too large
Load Diff
@@ -408,10 +408,11 @@ def do_inspect(identifier: str, console: Optional[Console] = None) -> None:
|
||||
|
||||
def do_list(source_filter: str = "all", console: Optional[Console] = None) -> None:
|
||||
"""List installed skills, distinguishing builtins from hub-installed."""
|
||||
from tools.skills_hub import HubLockFile, SKILLS_DIR
|
||||
from tools.skills_hub import HubLockFile, ensure_hub_dirs
|
||||
from tools.skills_tool import _find_all_skills
|
||||
|
||||
c = console or _console
|
||||
ensure_hub_dirs()
|
||||
lock = HubLockFile()
|
||||
hub_installed = {e["name"]: e for e in lock.list_installed()}
|
||||
|
||||
|
||||
@@ -79,8 +79,12 @@ def show_status(args):
|
||||
"OpenRouter": "OPENROUTER_API_KEY",
|
||||
"Anthropic": "ANTHROPIC_API_KEY",
|
||||
"OpenAI": "OPENAI_API_KEY",
|
||||
"Z.AI/GLM": "GLM_API_KEY",
|
||||
"Kimi": "KIMI_API_KEY",
|
||||
"MiniMax": "MINIMAX_API_KEY",
|
||||
"MiniMax-CN": "MINIMAX_CN_API_KEY",
|
||||
"Firecrawl": "FIRECRAWL_API_KEY",
|
||||
"Browserbase": "BROWSERBASE_API_KEY",
|
||||
"Browserbase": "BROWSERBASE_API_KEY", # Optional — local browser works without this
|
||||
"FAL": "FAL_KEY",
|
||||
"Tinker": "TINKER_API_KEY",
|
||||
"WandB": "WANDB_API_KEY",
|
||||
@@ -137,6 +141,28 @@ def show_status(args):
|
||||
if codex_status.get("error") and not codex_logged_in:
|
||||
print(f" Error: {codex_status.get('error')}")
|
||||
|
||||
# =========================================================================
|
||||
# API-Key Providers
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ API-Key Providers", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
apikey_providers = {
|
||||
"Z.AI / GLM": ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
|
||||
"Kimi / Moonshot": ("KIMI_API_KEY",),
|
||||
"MiniMax": ("MINIMAX_API_KEY",),
|
||||
"MiniMax (China)": ("MINIMAX_CN_API_KEY",),
|
||||
}
|
||||
for pname, env_vars in apikey_providers.items():
|
||||
key_val = ""
|
||||
for ev in env_vars:
|
||||
key_val = get_env_value(ev) or ""
|
||||
if key_val:
|
||||
break
|
||||
configured = bool(key_val)
|
||||
label = "configured" if configured else "not configured (run: hermes model)"
|
||||
print(f" {pname:<16} {check_mark(configured)} {label}")
|
||||
|
||||
# =========================================================================
|
||||
# Terminal Configuration
|
||||
# =========================================================================
|
||||
@@ -180,6 +206,8 @@ def show_status(args):
|
||||
"Telegram": ("TELEGRAM_BOT_TOKEN", "TELEGRAM_HOME_CHANNEL"),
|
||||
"Discord": ("DISCORD_BOT_TOKEN", "DISCORD_HOME_CHANNEL"),
|
||||
"WhatsApp": ("WHATSAPP_ENABLED", None),
|
||||
"Signal": ("SIGNAL_HTTP_URL", "SIGNAL_HOME_CHANNEL"),
|
||||
"Slack": ("SLACK_BOT_TOKEN", None),
|
||||
}
|
||||
|
||||
for name, (token_var, home_var) in platforms.items():
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
"""
|
||||
Interactive tool configuration for Hermes Agent.
|
||||
Unified tool configuration for Hermes Agent.
|
||||
|
||||
`hermes tools` and `hermes setup tools` both enter this module.
|
||||
Select a platform → toggle toolsets on/off → for newly enabled tools
|
||||
that need API keys, run through provider-aware configuration.
|
||||
|
||||
`hermes tools` — select a platform, then toggle toolsets on/off via checklist.
|
||||
Saves per-platform tool configuration to ~/.hermes/config.yaml under
|
||||
the `platform_toolsets` key.
|
||||
"""
|
||||
@@ -12,9 +15,63 @@ from typing import Dict, List, Set
|
||||
|
||||
import os
|
||||
|
||||
from hermes_cli.config import load_config, save_config, get_env_value, save_env_value
|
||||
from hermes_cli.config import (
|
||||
load_config, save_config, get_env_value, save_env_value,
|
||||
get_hermes_home,
|
||||
)
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
|
||||
|
||||
|
||||
# ─── UI Helpers (shared with setup.py) ────────────────────────────────────────
|
||||
|
||||
def _print_info(text: str):
|
||||
print(color(f" {text}", Colors.DIM))
|
||||
|
||||
def _print_success(text: str):
|
||||
print(color(f"✓ {text}", Colors.GREEN))
|
||||
|
||||
def _print_warning(text: str):
|
||||
print(color(f"⚠ {text}", Colors.YELLOW))
|
||||
|
||||
def _print_error(text: str):
|
||||
print(color(f"✗ {text}", Colors.RED))
|
||||
|
||||
def _prompt(question: str, default: str = None, password: bool = False) -> str:
|
||||
if default:
|
||||
display = f"{question} [{default}]: "
|
||||
else:
|
||||
display = f"{question}: "
|
||||
try:
|
||||
if password:
|
||||
import getpass
|
||||
value = getpass.getpass(color(display, Colors.YELLOW))
|
||||
else:
|
||||
value = input(color(display, Colors.YELLOW))
|
||||
return value.strip() or default or ""
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default or ""
|
||||
|
||||
def _prompt_yes_no(question: str, default: bool = True) -> bool:
|
||||
default_str = "Y/n" if default else "y/N"
|
||||
while True:
|
||||
try:
|
||||
value = input(color(f"{question} [{default_str}]: ", Colors.YELLOW)).strip().lower()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default
|
||||
if not value:
|
||||
return default
|
||||
if value in ('y', 'yes'):
|
||||
return True
|
||||
if value in ('n', 'no'):
|
||||
return False
|
||||
|
||||
|
||||
# ─── Toolset Registry ─────────────────────────────────────────────────────────
|
||||
|
||||
# Toolsets shown in the configurator, grouped for display.
|
||||
# Each entry: (toolset_name, label, description)
|
||||
# These map to keys in toolsets.py TOOLSETS dict.
|
||||
@@ -39,6 +96,11 @@ CONFIGURABLE_TOOLSETS = [
|
||||
("homeassistant", "🏠 Home Assistant", "smart home device control"),
|
||||
]
|
||||
|
||||
# Toolsets that are OFF by default for new installs.
|
||||
# They're still in _HERMES_CORE_TOOLS (available at runtime if enabled),
|
||||
# but the setup checklist won't pre-select them for first-time users.
|
||||
_DEFAULT_OFF_TOOLSETS = {"moa", "homeassistant", "rl"}
|
||||
|
||||
# Platform display config
|
||||
PLATFORMS = {
|
||||
"cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"},
|
||||
@@ -49,6 +111,189 @@ PLATFORMS = {
|
||||
}
|
||||
|
||||
|
||||
# ─── Tool Categories (provider-aware configuration) ──────────────────────────
|
||||
# Maps toolset keys to their provider options. When a toolset is newly enabled,
|
||||
# we use this to show provider selection and prompt for the right API keys.
|
||||
# Toolsets not in this map either need no config or use the simple fallback.
|
||||
|
||||
TOOL_CATEGORIES = {
|
||||
"tts": {
|
||||
"name": "Text-to-Speech",
|
||||
"icon": "🔊",
|
||||
"providers": [
|
||||
{
|
||||
"name": "Microsoft Edge TTS",
|
||||
"tag": "Free - no API key needed",
|
||||
"env_vars": [],
|
||||
"tts_provider": "edge",
|
||||
},
|
||||
{
|
||||
"name": "OpenAI TTS",
|
||||
"tag": "Premium - high quality voices",
|
||||
"env_vars": [
|
||||
{"key": "VOICE_TOOLS_OPENAI_KEY", "prompt": "OpenAI API key", "url": "https://platform.openai.com/api-keys"},
|
||||
],
|
||||
"tts_provider": "openai",
|
||||
},
|
||||
{
|
||||
"name": "ElevenLabs",
|
||||
"tag": "Premium - most natural voices",
|
||||
"env_vars": [
|
||||
{"key": "ELEVENLABS_API_KEY", "prompt": "ElevenLabs API key", "url": "https://elevenlabs.io/app/settings/api-keys"},
|
||||
],
|
||||
"tts_provider": "elevenlabs",
|
||||
},
|
||||
],
|
||||
},
|
||||
"web": {
|
||||
"name": "Web Search & Extract",
|
||||
"setup_title": "Select Search Provider",
|
||||
"setup_note": "A free DuckDuckGo search skill is also included — skip this if you don't need Firecrawl.",
|
||||
"icon": "🔍",
|
||||
"providers": [
|
||||
{
|
||||
"name": "Firecrawl Cloud",
|
||||
"tag": "Recommended - hosted service",
|
||||
"env_vars": [
|
||||
{"key": "FIRECRAWL_API_KEY", "prompt": "Firecrawl API key", "url": "https://firecrawl.dev"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Firecrawl Self-Hosted",
|
||||
"tag": "Free - run your own instance",
|
||||
"env_vars": [
|
||||
{"key": "FIRECRAWL_API_URL", "prompt": "Your Firecrawl instance URL (e.g., http://localhost:3002)"},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
"image_gen": {
|
||||
"name": "Image Generation",
|
||||
"icon": "🎨",
|
||||
"providers": [
|
||||
{
|
||||
"name": "FAL.ai",
|
||||
"tag": "FLUX 2 Pro with auto-upscaling",
|
||||
"env_vars": [
|
||||
{"key": "FAL_KEY", "prompt": "FAL API key", "url": "https://fal.ai/dashboard/keys"},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
"browser": {
|
||||
"name": "Browser Automation",
|
||||
"icon": "🌐",
|
||||
"providers": [
|
||||
{
|
||||
"name": "Local Browser",
|
||||
"tag": "Free headless Chromium (no API key needed)",
|
||||
"env_vars": [],
|
||||
"post_setup": "browserbase", # Same npm install for agent-browser
|
||||
},
|
||||
{
|
||||
"name": "Browserbase",
|
||||
"tag": "Cloud browser with stealth & proxies",
|
||||
"env_vars": [
|
||||
{"key": "BROWSERBASE_API_KEY", "prompt": "Browserbase API key", "url": "https://browserbase.com"},
|
||||
{"key": "BROWSERBASE_PROJECT_ID", "prompt": "Browserbase project ID"},
|
||||
],
|
||||
"post_setup": "browserbase",
|
||||
},
|
||||
],
|
||||
},
|
||||
"homeassistant": {
|
||||
"name": "Smart Home",
|
||||
"icon": "🏠",
|
||||
"providers": [
|
||||
{
|
||||
"name": "Home Assistant",
|
||||
"tag": "REST API integration",
|
||||
"env_vars": [
|
||||
{"key": "HASS_TOKEN", "prompt": "Home Assistant Long-Lived Access Token"},
|
||||
{"key": "HASS_URL", "prompt": "Home Assistant URL", "default": "http://homeassistant.local:8123"},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
"rl": {
|
||||
"name": "RL Training",
|
||||
"icon": "🧪",
|
||||
"requires_python": (3, 11),
|
||||
"providers": [
|
||||
{
|
||||
"name": "Tinker / Atropos",
|
||||
"tag": "RL training platform",
|
||||
"env_vars": [
|
||||
{"key": "TINKER_API_KEY", "prompt": "Tinker API key", "url": "https://tinker-console.thinkingmachines.ai/keys"},
|
||||
{"key": "WANDB_API_KEY", "prompt": "WandB API key", "url": "https://wandb.ai/authorize"},
|
||||
],
|
||||
"post_setup": "rl_training",
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
# Simple env-var requirements for toolsets NOT in TOOL_CATEGORIES.
|
||||
# Used as a fallback for tools like vision/moa that just need an API key.
|
||||
TOOLSET_ENV_REQUIREMENTS = {
|
||||
"vision": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
|
||||
"moa": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
|
||||
}
|
||||
|
||||
|
||||
# ─── Post-Setup Hooks ─────────────────────────────────────────────────────────
|
||||
|
||||
def _run_post_setup(post_setup_key: str):
|
||||
"""Run post-setup hooks for tools that need extra installation steps."""
|
||||
import shutil
|
||||
if post_setup_key == "browserbase":
|
||||
node_modules = PROJECT_ROOT / "node_modules" / "agent-browser"
|
||||
if not node_modules.exists() and shutil.which("npm"):
|
||||
_print_info(" Installing Node.js dependencies for browser tools...")
|
||||
import subprocess
|
||||
result = subprocess.run(
|
||||
["npm", "install", "--silent"],
|
||||
capture_output=True, text=True, cwd=str(PROJECT_ROOT)
|
||||
)
|
||||
if result.returncode == 0:
|
||||
_print_success(" Node.js dependencies installed")
|
||||
else:
|
||||
_print_warning(" npm install failed - run manually: cd ~/.hermes/hermes-agent && npm install")
|
||||
elif not node_modules.exists():
|
||||
_print_warning(" Node.js not found - browser tools require: npm install (in hermes-agent directory)")
|
||||
|
||||
elif post_setup_key == "rl_training":
|
||||
try:
|
||||
__import__("tinker_atropos")
|
||||
except ImportError:
|
||||
tinker_dir = PROJECT_ROOT / "tinker-atropos"
|
||||
if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists():
|
||||
_print_info(" Installing tinker-atropos submodule...")
|
||||
import subprocess
|
||||
uv_bin = shutil.which("uv")
|
||||
if uv_bin:
|
||||
result = subprocess.run(
|
||||
[uv_bin, "pip", "install", "--python", sys.executable, "-e", str(tinker_dir)],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "pip", "install", "-e", str(tinker_dir)],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.returncode == 0:
|
||||
_print_success(" tinker-atropos installed")
|
||||
else:
|
||||
_print_warning(" tinker-atropos install failed - run manually:")
|
||||
_print_info(' uv pip install -e "./tinker-atropos"')
|
||||
else:
|
||||
_print_warning(" tinker-atropos submodule not found - run:")
|
||||
_print_info(" git submodule update --init --recursive")
|
||||
_print_info(' uv pip install -e "./tinker-atropos"')
|
||||
|
||||
|
||||
# ─── Platform / Toolset Helpers ───────────────────────────────────────────────
|
||||
|
||||
def _get_enabled_platforms() -> List[str]:
|
||||
"""Return platform keys that are configured (have tokens or are CLI)."""
|
||||
enabled = ["cli"]
|
||||
@@ -70,7 +315,7 @@ def _get_platform_tools(config: dict, platform: str) -> Set[str]:
|
||||
platform_toolsets = config.get("platform_toolsets", {})
|
||||
toolset_names = platform_toolsets.get(platform)
|
||||
|
||||
if not toolset_names or not isinstance(toolset_names, list):
|
||||
if toolset_names is None or not isinstance(toolset_names, list):
|
||||
default_ts = PLATFORMS[platform]["default_toolset"]
|
||||
toolset_names = [default_ts]
|
||||
|
||||
@@ -97,61 +342,117 @@ def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: Set[
|
||||
save_config(config)
|
||||
|
||||
|
||||
def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
"""Single-select menu (arrow keys)."""
|
||||
print(color(question, Colors.YELLOW))
|
||||
|
||||
try:
|
||||
from simple_term_menu import TerminalMenu
|
||||
menu = TerminalMenu(
|
||||
[f" {c}" for c in choices],
|
||||
cursor_index=default,
|
||||
menu_cursor="→ ",
|
||||
menu_cursor_style=("fg_green", "bold"),
|
||||
menu_highlight_style=("fg_green",),
|
||||
cycle_cursor=True,
|
||||
clear_screen=False,
|
||||
)
|
||||
idx = menu.show()
|
||||
if idx is None:
|
||||
sys.exit(0)
|
||||
print()
|
||||
return idx
|
||||
except (ImportError, NotImplementedError):
|
||||
for i, c in enumerate(choices):
|
||||
marker = "●" if i == default else "○"
|
||||
style = Colors.GREEN if i == default else ""
|
||||
print(color(f" {marker} {c}", style) if style else f" {marker} {c}")
|
||||
while True:
|
||||
try:
|
||||
val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM))
|
||||
if not val:
|
||||
return default
|
||||
idx = int(val) - 1
|
||||
if 0 <= idx < len(choices):
|
||||
return idx
|
||||
except (ValueError, KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def _toolset_has_keys(ts_key: str) -> bool:
|
||||
"""Check if a toolset's required API keys are configured."""
|
||||
# Check TOOL_CATEGORIES first (provider-aware)
|
||||
cat = TOOL_CATEGORIES.get(ts_key)
|
||||
if cat:
|
||||
for provider in cat["providers"]:
|
||||
env_vars = provider.get("env_vars", [])
|
||||
if not env_vars:
|
||||
return True # Free provider (e.g., Edge TTS)
|
||||
if all(get_env_value(v["key"]) for v in env_vars):
|
||||
return True
|
||||
return False
|
||||
|
||||
# Fallback to simple requirements
|
||||
requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
|
||||
if not requirements:
|
||||
return True
|
||||
return all(get_env_value(var) for var, _ in requirements)
|
||||
|
||||
|
||||
# ─── Menu Helpers ─────────────────────────────────────────────────────────────
|
||||
|
||||
def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
"""Single-select menu (arrow keys). Uses curses to avoid simple_term_menu
|
||||
rendering bugs in tmux, iTerm, and other non-standard terminals."""
|
||||
|
||||
# Curses-based single-select — works in tmux, iTerm, and standard terminals
|
||||
try:
|
||||
import curses
|
||||
result_holder = [default]
|
||||
|
||||
def _curses_menu(stdscr):
|
||||
curses.curs_set(0)
|
||||
if curses.has_colors():
|
||||
curses.start_color()
|
||||
curses.use_default_colors()
|
||||
curses.init_pair(1, curses.COLOR_GREEN, -1)
|
||||
curses.init_pair(2, curses.COLOR_YELLOW, -1)
|
||||
cursor = default
|
||||
|
||||
while True:
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
try:
|
||||
stdscr.addnstr(0, 0, question, max_x - 1,
|
||||
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0))
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
for i, c in enumerate(choices):
|
||||
y = i + 2
|
||||
if y >= max_y - 1:
|
||||
break
|
||||
arrow = "→" if i == cursor else " "
|
||||
line = f" {arrow} {c}"
|
||||
attr = curses.A_NORMAL
|
||||
if i == cursor:
|
||||
attr = curses.A_BOLD
|
||||
if curses.has_colors():
|
||||
attr |= curses.color_pair(1)
|
||||
try:
|
||||
stdscr.addnstr(y, 0, line, max_x - 1, attr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
|
||||
if key in (curses.KEY_UP, ord('k')):
|
||||
cursor = (cursor - 1) % len(choices)
|
||||
elif key in (curses.KEY_DOWN, ord('j')):
|
||||
cursor = (cursor + 1) % len(choices)
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
result_holder[0] = cursor
|
||||
return
|
||||
elif key in (27, ord('q')):
|
||||
return
|
||||
|
||||
curses.wrapper(_curses_menu)
|
||||
return result_holder[0]
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback: numbered input (Windows without curses, etc.)
|
||||
print(color(question, Colors.YELLOW))
|
||||
for i, c in enumerate(choices):
|
||||
marker = "●" if i == default else "○"
|
||||
style = Colors.GREEN if i == default else ""
|
||||
print(color(f" {marker} {i+1}. {c}", style) if style else f" {marker} {i+1}. {c}")
|
||||
while True:
|
||||
try:
|
||||
val = input(color(f" Select [1-{len(choices)}] ({default + 1}): ", Colors.DIM))
|
||||
if not val:
|
||||
return default
|
||||
idx = int(val) - 1
|
||||
if 0 <= idx < len(choices):
|
||||
return idx
|
||||
except (ValueError, KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default
|
||||
|
||||
|
||||
def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str]:
|
||||
"""Multi-select checklist of toolsets. Returns set of selected toolset keys."""
|
||||
import platform as _platform
|
||||
|
||||
labels = []
|
||||
for ts_key, ts_label, ts_desc in CONFIGURABLE_TOOLSETS:
|
||||
suffix = ""
|
||||
if not _toolset_has_keys(ts_key) and TOOLSET_ENV_REQUIREMENTS.get(ts_key):
|
||||
suffix = " ⚠ no API key"
|
||||
if not _toolset_has_keys(ts_key) and (TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key)):
|
||||
suffix = " [no API key]"
|
||||
labels.append(f"{ts_label} ({ts_desc}){suffix}")
|
||||
|
||||
pre_selected_indices = [
|
||||
@@ -159,48 +460,8 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
|
||||
if ts_key in enabled
|
||||
]
|
||||
|
||||
# simple_term_menu multi-select has rendering bugs on macOS terminals,
|
||||
# so we use a curses-based fallback there.
|
||||
use_term_menu = _platform.system() != "Darwin"
|
||||
|
||||
if use_term_menu:
|
||||
try:
|
||||
from simple_term_menu import TerminalMenu
|
||||
|
||||
print(color(f"Tools for {platform_label}", Colors.YELLOW))
|
||||
print(color(" SPACE to toggle, ENTER to confirm.", Colors.DIM))
|
||||
print()
|
||||
|
||||
menu_items = [f" {label}" for label in labels]
|
||||
menu = TerminalMenu(
|
||||
menu_items,
|
||||
multi_select=True,
|
||||
show_multi_select_hint=False,
|
||||
multi_select_cursor="[✓] ",
|
||||
multi_select_select_on_accept=False,
|
||||
multi_select_empty_ok=True,
|
||||
preselected_entries=pre_selected_indices if pre_selected_indices else None,
|
||||
menu_cursor="→ ",
|
||||
menu_cursor_style=("fg_green", "bold"),
|
||||
menu_highlight_style=("fg_green",),
|
||||
cycle_cursor=True,
|
||||
clear_screen=False,
|
||||
clear_menu_on_exit=False,
|
||||
)
|
||||
|
||||
menu.show()
|
||||
|
||||
if menu.chosen_menu_entries is None:
|
||||
return enabled
|
||||
|
||||
selected_indices = list(menu.chosen_menu_indices or [])
|
||||
return {CONFIGURABLE_TOOLSETS[i][0] for i in selected_indices}
|
||||
|
||||
except (ImportError, NotImplementedError):
|
||||
pass # fall through to curses/numbered fallback
|
||||
|
||||
# Curses-based multi-select — arrow keys + space to toggle + enter to confirm.
|
||||
# Used on macOS (where simple_term_menu ghosts) and as a fallback.
|
||||
# simple_term_menu has rendering bugs in tmux, iTerm, and other terminals.
|
||||
try:
|
||||
import curses
|
||||
selected = set(pre_selected_indices)
|
||||
@@ -302,100 +563,400 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
|
||||
return {CONFIGURABLE_TOOLSETS[i][0] for i in selected}
|
||||
|
||||
|
||||
# Map toolset keys to the env vars they require and where to get them
|
||||
TOOLSET_ENV_REQUIREMENTS = {
|
||||
"web": [("FIRECRAWL_API_KEY", "https://firecrawl.dev/")],
|
||||
"browser": [("BROWSERBASE_API_KEY", "https://browserbase.com/"),
|
||||
("BROWSERBASE_PROJECT_ID", None)],
|
||||
"vision": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
|
||||
"image_gen": [("FAL_KEY", "https://fal.ai/")],
|
||||
"moa": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
|
||||
"tts": [], # Edge TTS is free, no key needed
|
||||
"rl": [("TINKER_API_KEY", "https://tinker-console.thinkingmachines.ai/keys"),
|
||||
("WANDB_API_KEY", "https://wandb.ai/authorize")],
|
||||
"homeassistant": [("HASS_TOKEN", "Home Assistant > Profile > Long-Lived Access Tokens"),
|
||||
("HASS_URL", None)],
|
||||
}
|
||||
# ─── Provider-Aware Configuration ────────────────────────────────────────────
|
||||
|
||||
def _configure_toolset(ts_key: str, config: dict):
|
||||
"""Configure a toolset - provider selection + API keys.
|
||||
|
||||
Uses TOOL_CATEGORIES for provider-aware config, falls back to simple
|
||||
env var prompts for toolsets not in TOOL_CATEGORIES.
|
||||
"""
|
||||
cat = TOOL_CATEGORIES.get(ts_key)
|
||||
|
||||
if cat:
|
||||
_configure_tool_category(ts_key, cat, config)
|
||||
else:
|
||||
# Simple fallback for vision, moa, etc.
|
||||
_configure_simple_requirements(ts_key)
|
||||
|
||||
|
||||
def _check_and_prompt_requirements(newly_enabled: Set[str]):
|
||||
"""Check if newly enabled toolsets have missing API keys and offer to set them up."""
|
||||
for ts_key in sorted(newly_enabled):
|
||||
requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
|
||||
if not requirements:
|
||||
continue
|
||||
def _configure_tool_category(ts_key: str, cat: dict, config: dict):
|
||||
"""Configure a tool category with provider selection."""
|
||||
icon = cat.get("icon", "")
|
||||
name = cat["name"]
|
||||
providers = cat["providers"]
|
||||
|
||||
missing = [(var, url) for var, url in requirements if not get_env_value(var)]
|
||||
if not missing:
|
||||
continue
|
||||
|
||||
ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
|
||||
print()
|
||||
print(color(f" ⚠ {ts_label} requires configuration:", Colors.YELLOW))
|
||||
|
||||
for var, url in missing:
|
||||
if url:
|
||||
print(color(f" {var}", Colors.CYAN) + color(f" ({url})", Colors.DIM))
|
||||
else:
|
||||
print(color(f" {var}", Colors.CYAN))
|
||||
|
||||
print()
|
||||
try:
|
||||
response = input(color(" Set up now? [Y/n] ", Colors.YELLOW)).strip().lower()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
# Check Python version requirement
|
||||
if cat.get("requires_python"):
|
||||
req = cat["requires_python"]
|
||||
if sys.version_info < req:
|
||||
print()
|
||||
continue
|
||||
_print_error(f" {name} requires Python {req[0]}.{req[1]}+ (current: {sys.version_info.major}.{sys.version_info.minor})")
|
||||
_print_info(" Upgrade Python and reinstall to enable this tool.")
|
||||
return
|
||||
|
||||
if response in ("", "y", "yes"):
|
||||
for var, url in missing:
|
||||
if url:
|
||||
print(color(f" Get key at: {url}", Colors.DIM))
|
||||
try:
|
||||
import getpass
|
||||
value = getpass.getpass(color(f" {var}: ", Colors.YELLOW))
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
break
|
||||
if value.strip():
|
||||
save_env_value(var, value.strip())
|
||||
print(color(f" ✓ Saved", Colors.GREEN))
|
||||
if len(providers) == 1:
|
||||
# Single provider - configure directly
|
||||
provider = providers[0]
|
||||
print()
|
||||
print(color(f" --- {icon} {name} ({provider['name']}) ---", Colors.CYAN))
|
||||
if provider.get("tag"):
|
||||
_print_info(f" {provider['tag']}")
|
||||
# For single-provider tools, show a note if available
|
||||
if cat.get("setup_note"):
|
||||
_print_info(f" {cat['setup_note']}")
|
||||
_configure_provider(provider, config)
|
||||
else:
|
||||
# Multiple providers - let user choose
|
||||
print()
|
||||
# Use custom title if provided (e.g. "Select Search Provider")
|
||||
title = cat.get("setup_title", f"Choose a provider")
|
||||
print(color(f" --- {icon} {name} - {title} ---", Colors.CYAN))
|
||||
if cat.get("setup_note"):
|
||||
_print_info(f" {cat['setup_note']}")
|
||||
print()
|
||||
|
||||
# Plain text labels only (no ANSI codes in menu items)
|
||||
provider_choices = []
|
||||
for p in providers:
|
||||
tag = f" ({p['tag']})" if p.get("tag") else ""
|
||||
configured = ""
|
||||
env_vars = p.get("env_vars", [])
|
||||
if not env_vars or all(get_env_value(v["key"]) for v in env_vars):
|
||||
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
|
||||
configured = " [active]"
|
||||
elif not env_vars:
|
||||
configured = " [active]" if config.get("tts", {}).get("provider", "edge") == p.get("tts_provider", "") else ""
|
||||
else:
|
||||
print(color(f" Skipped", Colors.DIM))
|
||||
configured = " [configured]"
|
||||
provider_choices.append(f"{p['name']}{tag}{configured}")
|
||||
|
||||
# Add skip option
|
||||
provider_choices.append("Skip — keep defaults / configure later")
|
||||
|
||||
# Detect current provider as default
|
||||
default_idx = 0
|
||||
for i, p in enumerate(providers):
|
||||
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
|
||||
default_idx = i
|
||||
break
|
||||
env_vars = p.get("env_vars", [])
|
||||
if env_vars and all(get_env_value(v["key"]) for v in env_vars):
|
||||
default_idx = i
|
||||
break
|
||||
|
||||
provider_idx = _prompt_choice(f" {title}:", provider_choices, default_idx)
|
||||
|
||||
# Skip selected
|
||||
if provider_idx >= len(providers):
|
||||
_print_info(f" Skipped {name}")
|
||||
return
|
||||
|
||||
_configure_provider(providers[provider_idx], config)
|
||||
|
||||
|
||||
def _configure_provider(provider: dict, config: dict):
|
||||
"""Configure a single provider - prompt for API keys and set config."""
|
||||
env_vars = provider.get("env_vars", [])
|
||||
|
||||
# Set TTS provider in config if applicable
|
||||
if provider.get("tts_provider"):
|
||||
config.setdefault("tts", {})["provider"] = provider["tts_provider"]
|
||||
|
||||
if not env_vars:
|
||||
_print_success(f" {provider['name']} - no configuration needed!")
|
||||
return
|
||||
|
||||
# Prompt for each required env var
|
||||
all_configured = True
|
||||
for var in env_vars:
|
||||
existing = get_env_value(var["key"])
|
||||
if existing:
|
||||
_print_success(f" {var['key']}: already configured")
|
||||
# Don't ask to update - this is a new enable flow.
|
||||
# Reconfigure is handled separately.
|
||||
else:
|
||||
print(color(" Skipped — configure later with 'hermes setup'", Colors.DIM))
|
||||
url = var.get("url", "")
|
||||
if url:
|
||||
_print_info(f" Get yours at: {url}")
|
||||
|
||||
default_val = var.get("default", "")
|
||||
if default_val:
|
||||
value = _prompt(f" {var.get('prompt', var['key'])}", default_val)
|
||||
else:
|
||||
value = _prompt(f" {var.get('prompt', var['key'])}", password=True)
|
||||
|
||||
if value:
|
||||
save_env_value(var["key"], value)
|
||||
_print_success(f" Saved")
|
||||
else:
|
||||
_print_warning(f" Skipped")
|
||||
all_configured = False
|
||||
|
||||
# Run post-setup hooks if needed
|
||||
if provider.get("post_setup") and all_configured:
|
||||
_run_post_setup(provider["post_setup"])
|
||||
|
||||
if all_configured:
|
||||
_print_success(f" {provider['name']} configured!")
|
||||
|
||||
|
||||
def tools_command(args):
|
||||
"""Entry point for `hermes tools`."""
|
||||
config = load_config()
|
||||
def _configure_simple_requirements(ts_key: str):
|
||||
"""Simple fallback for toolsets that just need env vars (no provider selection)."""
|
||||
requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
|
||||
if not requirements:
|
||||
return
|
||||
|
||||
missing = [(var, url) for var, url in requirements if not get_env_value(var)]
|
||||
if not missing:
|
||||
return
|
||||
|
||||
ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
|
||||
print()
|
||||
print(color(f" {ts_label} requires configuration:", Colors.YELLOW))
|
||||
|
||||
for var, url in missing:
|
||||
if url:
|
||||
_print_info(f" Get key at: {url}")
|
||||
value = _prompt(f" {var}", password=True)
|
||||
if value and value.strip():
|
||||
save_env_value(var, value.strip())
|
||||
_print_success(f" Saved")
|
||||
else:
|
||||
_print_warning(f" Skipped")
|
||||
|
||||
|
||||
def _reconfigure_tool(config: dict):
|
||||
"""Let user reconfigure an existing tool's provider or API key."""
|
||||
# Build list of configurable tools that are currently set up
|
||||
configurable = []
|
||||
for ts_key, ts_label, _ in CONFIGURABLE_TOOLSETS:
|
||||
cat = TOOL_CATEGORIES.get(ts_key)
|
||||
reqs = TOOLSET_ENV_REQUIREMENTS.get(ts_key)
|
||||
if cat or reqs:
|
||||
if _toolset_has_keys(ts_key):
|
||||
configurable.append((ts_key, ts_label))
|
||||
|
||||
if not configurable:
|
||||
_print_info("No configured tools to reconfigure.")
|
||||
return
|
||||
|
||||
choices = [label for _, label in configurable]
|
||||
choices.append("Cancel")
|
||||
|
||||
idx = _prompt_choice(" Which tool would you like to reconfigure?", choices, len(choices) - 1)
|
||||
|
||||
if idx >= len(configurable):
|
||||
return # Cancel
|
||||
|
||||
ts_key, ts_label = configurable[idx]
|
||||
cat = TOOL_CATEGORIES.get(ts_key)
|
||||
|
||||
if cat:
|
||||
_configure_tool_category_for_reconfig(ts_key, cat, config)
|
||||
else:
|
||||
_reconfigure_simple_requirements(ts_key)
|
||||
|
||||
save_config(config)
|
||||
|
||||
|
||||
def _configure_tool_category_for_reconfig(ts_key: str, cat: dict, config: dict):
|
||||
"""Reconfigure a tool category - provider selection + API key update."""
|
||||
icon = cat.get("icon", "")
|
||||
name = cat["name"]
|
||||
providers = cat["providers"]
|
||||
|
||||
if len(providers) == 1:
|
||||
provider = providers[0]
|
||||
print()
|
||||
print(color(f" --- {icon} {name} ({provider['name']}) ---", Colors.CYAN))
|
||||
_reconfigure_provider(provider, config)
|
||||
else:
|
||||
print()
|
||||
print(color(f" --- {icon} {name} - Choose a provider ---", Colors.CYAN))
|
||||
print()
|
||||
|
||||
provider_choices = []
|
||||
for p in providers:
|
||||
tag = f" ({p['tag']})" if p.get("tag") else ""
|
||||
configured = ""
|
||||
env_vars = p.get("env_vars", [])
|
||||
if not env_vars or all(get_env_value(v["key"]) for v in env_vars):
|
||||
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
|
||||
configured = " [active]"
|
||||
elif not env_vars:
|
||||
configured = ""
|
||||
else:
|
||||
configured = " [configured]"
|
||||
provider_choices.append(f"{p['name']}{tag}{configured}")
|
||||
|
||||
default_idx = 0
|
||||
for i, p in enumerate(providers):
|
||||
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
|
||||
default_idx = i
|
||||
break
|
||||
env_vars = p.get("env_vars", [])
|
||||
if env_vars and all(get_env_value(v["key"]) for v in env_vars):
|
||||
default_idx = i
|
||||
break
|
||||
|
||||
provider_idx = _prompt_choice(" Select provider:", provider_choices, default_idx)
|
||||
_reconfigure_provider(providers[provider_idx], config)
|
||||
|
||||
|
||||
def _reconfigure_provider(provider: dict, config: dict):
|
||||
"""Reconfigure a provider - update API keys."""
|
||||
env_vars = provider.get("env_vars", [])
|
||||
|
||||
if provider.get("tts_provider"):
|
||||
config.setdefault("tts", {})["provider"] = provider["tts_provider"]
|
||||
_print_success(f" TTS provider set to: {provider['tts_provider']}")
|
||||
|
||||
if not env_vars:
|
||||
_print_success(f" {provider['name']} - no configuration needed!")
|
||||
return
|
||||
|
||||
for var in env_vars:
|
||||
existing = get_env_value(var["key"])
|
||||
if existing:
|
||||
_print_info(f" {var['key']}: configured ({existing[:8]}...)")
|
||||
url = var.get("url", "")
|
||||
if url:
|
||||
_print_info(f" Get yours at: {url}")
|
||||
default_val = var.get("default", "")
|
||||
value = _prompt(f" {var.get('prompt', var['key'])} (Enter to keep current)", password=not default_val)
|
||||
if value and value.strip():
|
||||
save_env_value(var["key"], value.strip())
|
||||
_print_success(f" Updated")
|
||||
else:
|
||||
_print_info(f" Kept current")
|
||||
|
||||
|
||||
def _reconfigure_simple_requirements(ts_key: str):
|
||||
"""Reconfigure simple env var requirements."""
|
||||
requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
|
||||
if not requirements:
|
||||
return
|
||||
|
||||
ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
|
||||
print()
|
||||
print(color(f" {ts_label}:", Colors.CYAN))
|
||||
|
||||
for var, url in requirements:
|
||||
existing = get_env_value(var)
|
||||
if existing:
|
||||
_print_info(f" {var}: configured ({existing[:8]}...)")
|
||||
if url:
|
||||
_print_info(f" Get key at: {url}")
|
||||
value = _prompt(f" {var} (Enter to keep current)", password=True)
|
||||
if value and value.strip():
|
||||
save_env_value(var, value.strip())
|
||||
_print_success(f" Updated")
|
||||
else:
|
||||
_print_info(f" Kept current")
|
||||
|
||||
|
||||
# ─── Main Entry Point ─────────────────────────────────────────────────────────
|
||||
|
||||
def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
"""Entry point for `hermes tools` and `hermes setup tools`.
|
||||
|
||||
Args:
|
||||
first_install: When True (set by the setup wizard on fresh installs),
|
||||
skip the platform menu, go straight to the CLI checklist, and
|
||||
prompt for API keys on all enabled tools that need them.
|
||||
config: Optional config dict to use. When called from the setup
|
||||
wizard, the wizard passes its own dict so that platform_toolsets
|
||||
are written into it and survive the wizard's final save_config().
|
||||
"""
|
||||
if config is None:
|
||||
config = load_config()
|
||||
enabled_platforms = _get_enabled_platforms()
|
||||
|
||||
print()
|
||||
print(color("⚕ Hermes Tool Configuration", Colors.CYAN, Colors.BOLD))
|
||||
print(color(" Enable or disable tools per platform.", Colors.DIM))
|
||||
print(color(" Tools that need API keys will be configured when enabled.", Colors.DIM))
|
||||
print()
|
||||
|
||||
# ── First-time install: linear flow, no platform menu ──
|
||||
if first_install:
|
||||
for pkey in enabled_platforms:
|
||||
pinfo = PLATFORMS[pkey]
|
||||
current_enabled = _get_platform_tools(config, pkey)
|
||||
|
||||
# Uncheck toolsets that should be off by default
|
||||
checklist_preselected = current_enabled - _DEFAULT_OFF_TOOLSETS
|
||||
|
||||
# Show checklist
|
||||
new_enabled = _prompt_toolset_checklist(pinfo["label"], checklist_preselected)
|
||||
|
||||
added = new_enabled - current_enabled
|
||||
removed = current_enabled - new_enabled
|
||||
if added:
|
||||
for ts in sorted(added):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
print(color(f" + {label}", Colors.GREEN))
|
||||
if removed:
|
||||
for ts in sorted(removed):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
print(color(f" - {label}", Colors.RED))
|
||||
|
||||
# Walk through ALL selected tools that have provider options or
|
||||
# need API keys. This ensures browser (Local vs Browserbase),
|
||||
# TTS (Edge vs OpenAI vs ElevenLabs), etc. are shown even when
|
||||
# a free provider exists.
|
||||
to_configure = [
|
||||
ts_key for ts_key in sorted(new_enabled)
|
||||
if TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key)
|
||||
]
|
||||
|
||||
if to_configure:
|
||||
print()
|
||||
print(color(f" Configuring {len(to_configure)} tool(s):", Colors.YELLOW))
|
||||
for ts_key in to_configure:
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
|
||||
print(color(f" • {label}", Colors.DIM))
|
||||
print(color(" You can skip any tool you don't need right now.", Colors.DIM))
|
||||
print()
|
||||
for ts_key in to_configure:
|
||||
_configure_toolset(ts_key, config)
|
||||
|
||||
_save_platform_tools(config, pkey, new_enabled)
|
||||
save_config(config)
|
||||
print(color(f" ✓ Saved {pinfo['label']} tool configuration", Colors.GREEN))
|
||||
print()
|
||||
|
||||
return
|
||||
|
||||
# ── Returning user: platform menu loop ──
|
||||
# Build platform choices
|
||||
platform_choices = []
|
||||
platform_keys = []
|
||||
for pkey in enabled_platforms:
|
||||
pinfo = PLATFORMS[pkey]
|
||||
# Count currently enabled toolsets
|
||||
current = _get_platform_tools(config, pkey)
|
||||
count = len(current)
|
||||
total = len(CONFIGURABLE_TOOLSETS)
|
||||
platform_choices.append(f"Configure {pinfo['label']} ({count}/{total} enabled)")
|
||||
platform_keys.append(pkey)
|
||||
|
||||
platform_choices.append("Done — save and exit")
|
||||
platform_choices.append("Reconfigure an existing tool's provider or API key")
|
||||
platform_choices.append("Done")
|
||||
|
||||
while True:
|
||||
idx = _prompt_choice("Select a platform to configure:", platform_choices, default=0)
|
||||
idx = _prompt_choice("Select an option:", platform_choices, default=0)
|
||||
|
||||
# "Done" selected
|
||||
if idx == len(platform_keys):
|
||||
if idx == len(platform_keys) + 1:
|
||||
break
|
||||
|
||||
# "Reconfigure" selected
|
||||
if idx == len(platform_keys):
|
||||
_reconfigure_tool(config)
|
||||
print()
|
||||
continue
|
||||
|
||||
pkey = platform_keys[idx]
|
||||
pinfo = PLATFORMS[pkey]
|
||||
|
||||
@@ -418,11 +979,14 @@ def tools_command(args):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
print(color(f" - {label}", Colors.RED))
|
||||
|
||||
# Prompt for missing API keys on newly enabled toolsets
|
||||
if added:
|
||||
_check_and_prompt_requirements(added)
|
||||
# Configure newly enabled toolsets that need API keys
|
||||
for ts_key in sorted(added):
|
||||
if (TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key)):
|
||||
if not _toolset_has_keys(ts_key):
|
||||
_configure_toolset(ts_key, config)
|
||||
|
||||
_save_platform_tools(config, pkey, new_enabled)
|
||||
save_config(config)
|
||||
print(color(f" ✓ Saved {pinfo['label']} configuration", Colors.GREEN))
|
||||
else:
|
||||
print(color(f" No changes to {pinfo['label']}", Colors.DIM))
|
||||
|
||||
233
hermes_state.py
233
hermes_state.py
@@ -24,7 +24,7 @@ from typing import Dict, Any, List, Optional
|
||||
|
||||
DEFAULT_DB_PATH = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "state.db"
|
||||
|
||||
SCHEMA_VERSION = 2
|
||||
SCHEMA_VERSION = 4
|
||||
|
||||
SCHEMA_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
@@ -46,6 +46,7 @@ CREATE TABLE IF NOT EXISTS sessions (
|
||||
tool_call_count INTEGER DEFAULT 0,
|
||||
input_tokens INTEGER DEFAULT 0,
|
||||
output_tokens INTEGER DEFAULT 0,
|
||||
title TEXT,
|
||||
FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
|
||||
);
|
||||
|
||||
@@ -133,7 +134,33 @@ class SessionDB:
|
||||
except sqlite3.OperationalError:
|
||||
pass # Column already exists
|
||||
cursor.execute("UPDATE schema_version SET version = 2")
|
||||
if current_version < 3:
|
||||
# v3: add title column to sessions
|
||||
try:
|
||||
cursor.execute("ALTER TABLE sessions ADD COLUMN title TEXT")
|
||||
except sqlite3.OperationalError:
|
||||
pass # Column already exists
|
||||
cursor.execute("UPDATE schema_version SET version = 3")
|
||||
if current_version < 4:
|
||||
# v4: add unique index on title (NULLs allowed, only non-NULL must be unique)
|
||||
try:
|
||||
cursor.execute(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique "
|
||||
"ON sessions(title) WHERE title IS NOT NULL"
|
||||
)
|
||||
except sqlite3.OperationalError:
|
||||
pass # Index already exists
|
||||
cursor.execute("UPDATE schema_version SET version = 4")
|
||||
|
||||
# Unique title index — always ensure it exists (safe to run after migrations
|
||||
# since the title column is guaranteed to exist at this point)
|
||||
try:
|
||||
cursor.execute(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_title_unique "
|
||||
"ON sessions(title) WHERE title IS NOT NULL"
|
||||
)
|
||||
except sqlite3.OperationalError:
|
||||
pass # Index already exists
|
||||
|
||||
# FTS5 setup (separate because CREATE VIRTUAL TABLE can't be in executescript with IF NOT EXISTS reliably)
|
||||
try:
|
||||
@@ -219,6 +246,210 @@ class SessionDB:
|
||||
row = cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
# Maximum length for session titles
|
||||
MAX_TITLE_LENGTH = 100
|
||||
|
||||
@staticmethod
|
||||
def sanitize_title(title: Optional[str]) -> Optional[str]:
|
||||
"""Validate and sanitize a session title.
|
||||
|
||||
- Strips leading/trailing whitespace
|
||||
- Removes ASCII control characters (0x00-0x1F, 0x7F) and problematic
|
||||
Unicode control chars (zero-width, RTL/LTR overrides, etc.)
|
||||
- Collapses internal whitespace runs to single spaces
|
||||
- Normalizes empty/whitespace-only strings to None
|
||||
- Enforces MAX_TITLE_LENGTH
|
||||
|
||||
Returns the cleaned title string or None.
|
||||
Raises ValueError if the title exceeds MAX_TITLE_LENGTH after cleaning.
|
||||
"""
|
||||
if not title:
|
||||
return None
|
||||
|
||||
import re
|
||||
|
||||
# Remove ASCII control characters (0x00-0x1F, 0x7F) but keep
|
||||
# whitespace chars (\t=0x09, \n=0x0A, \r=0x0D) so they can be
|
||||
# normalized to spaces by the whitespace collapsing step below
|
||||
cleaned = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', title)
|
||||
|
||||
# Remove problematic Unicode control characters:
|
||||
# - Zero-width chars (U+200B-U+200F, U+FEFF)
|
||||
# - Directional overrides (U+202A-U+202E, U+2066-U+2069)
|
||||
# - Object replacement (U+FFFC), interlinear annotation (U+FFF9-U+FFFB)
|
||||
cleaned = re.sub(
|
||||
r'[\u200b-\u200f\u2028-\u202e\u2060-\u2069\ufeff\ufffc\ufff9-\ufffb]',
|
||||
'', cleaned,
|
||||
)
|
||||
|
||||
# Collapse internal whitespace runs and strip
|
||||
cleaned = re.sub(r'\s+', ' ', cleaned).strip()
|
||||
|
||||
if not cleaned:
|
||||
return None
|
||||
|
||||
if len(cleaned) > SessionDB.MAX_TITLE_LENGTH:
|
||||
raise ValueError(
|
||||
f"Title too long ({len(cleaned)} chars, max {SessionDB.MAX_TITLE_LENGTH})"
|
||||
)
|
||||
|
||||
return cleaned
|
||||
|
||||
def set_session_title(self, session_id: str, title: str) -> bool:
|
||||
"""Set or update a session's title.
|
||||
|
||||
Returns True if session was found and title was set.
|
||||
Raises ValueError if title is already in use by another session,
|
||||
or if the title fails validation (too long, invalid characters).
|
||||
Empty/whitespace-only strings are normalized to None (clearing the title).
|
||||
"""
|
||||
title = self.sanitize_title(title)
|
||||
if title:
|
||||
# Check uniqueness (allow the same session to keep its own title)
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id FROM sessions WHERE title = ? AND id != ?",
|
||||
(title, session_id),
|
||||
)
|
||||
conflict = cursor.fetchone()
|
||||
if conflict:
|
||||
raise ValueError(
|
||||
f"Title '{title}' is already in use by session {conflict['id']}"
|
||||
)
|
||||
cursor = self._conn.execute(
|
||||
"UPDATE sessions SET title = ? WHERE id = ?",
|
||||
(title, session_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def get_session_title(self, session_id: str) -> Optional[str]:
|
||||
"""Get the title for a session, or None."""
|
||||
cursor = self._conn.execute(
|
||||
"SELECT title FROM sessions WHERE id = ?", (session_id,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return row["title"] if row else None
|
||||
|
||||
def get_session_by_title(self, title: str) -> Optional[Dict[str, Any]]:
|
||||
"""Look up a session by exact title. Returns session dict or None."""
|
||||
cursor = self._conn.execute(
|
||||
"SELECT * FROM sessions WHERE title = ?", (title,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
def resolve_session_by_title(self, title: str) -> Optional[str]:
|
||||
"""Resolve a title to a session ID, preferring the latest in a lineage.
|
||||
|
||||
If the exact title exists, returns that session's ID.
|
||||
If not, searches for "title #N" variants and returns the latest one.
|
||||
If the exact title exists AND numbered variants exist, returns the
|
||||
latest numbered variant (the most recent continuation).
|
||||
"""
|
||||
# First try exact match
|
||||
exact = self.get_session_by_title(title)
|
||||
|
||||
# Also search for numbered variants: "title #2", "title #3", etc.
|
||||
# Escape SQL LIKE wildcards (%, _) in the title to prevent false matches
|
||||
escaped = title.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id, title, started_at FROM sessions "
|
||||
"WHERE title LIKE ? ESCAPE '\\' ORDER BY started_at DESC",
|
||||
(f"{escaped} #%",),
|
||||
)
|
||||
numbered = cursor.fetchall()
|
||||
|
||||
if numbered:
|
||||
# Return the most recent numbered variant
|
||||
return numbered[0]["id"]
|
||||
elif exact:
|
||||
return exact["id"]
|
||||
return None
|
||||
|
||||
def get_next_title_in_lineage(self, base_title: str) -> str:
|
||||
"""Generate the next title in a lineage (e.g., "my session" → "my session #2").
|
||||
|
||||
Strips any existing " #N" suffix to find the base name, then finds
|
||||
the highest existing number and increments.
|
||||
"""
|
||||
import re
|
||||
# Strip existing #N suffix to find the true base
|
||||
match = re.match(r'^(.*?) #(\d+)$', base_title)
|
||||
if match:
|
||||
base = match.group(1)
|
||||
else:
|
||||
base = base_title
|
||||
|
||||
# Find all existing numbered variants
|
||||
# Escape SQL LIKE wildcards (%, _) in the base to prevent false matches
|
||||
escaped = base.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
cursor = self._conn.execute(
|
||||
"SELECT title FROM sessions WHERE title = ? OR title LIKE ? ESCAPE '\\'",
|
||||
(base, f"{escaped} #%"),
|
||||
)
|
||||
existing = [row["title"] for row in cursor.fetchall()]
|
||||
|
||||
if not existing:
|
||||
return base # No conflict, use the base name as-is
|
||||
|
||||
# Find the highest number
|
||||
max_num = 1 # The unnumbered original counts as #1
|
||||
for t in existing:
|
||||
m = re.match(r'^.* #(\d+)$', t)
|
||||
if m:
|
||||
max_num = max(max_num, int(m.group(1)))
|
||||
|
||||
return f"{base} #{max_num + 1}"
|
||||
|
||||
def list_sessions_rich(
|
||||
self,
|
||||
source: str = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List sessions with preview (first user message) and last active timestamp.
|
||||
|
||||
Returns dicts with keys: id, source, model, title, started_at, ended_at,
|
||||
message_count, preview (first 60 chars of first user message),
|
||||
last_active (timestamp of last message).
|
||||
|
||||
Uses a single query with correlated subqueries instead of N+2 queries.
|
||||
"""
|
||||
source_clause = "WHERE s.source = ?" if source else ""
|
||||
query = f"""
|
||||
SELECT s.*,
|
||||
COALESCE(
|
||||
(SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63)
|
||||
FROM messages m
|
||||
WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL
|
||||
ORDER BY m.timestamp, m.id LIMIT 1),
|
||||
''
|
||||
) AS _preview_raw,
|
||||
COALESCE(
|
||||
(SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id),
|
||||
s.started_at
|
||||
) AS last_active
|
||||
FROM sessions s
|
||||
{source_clause}
|
||||
ORDER BY s.started_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
params = (source, limit, offset) if source else (limit, offset)
|
||||
cursor = self._conn.execute(query, params)
|
||||
sessions = []
|
||||
for row in cursor.fetchall():
|
||||
s = dict(row)
|
||||
# Build the preview from the raw substring
|
||||
raw = s.pop("_preview_raw", "").strip()
|
||||
if raw:
|
||||
text = raw[:60]
|
||||
s["preview"] = text + ("..." if len(raw) > 60 else "")
|
||||
else:
|
||||
s["preview"] = ""
|
||||
sessions.append(s)
|
||||
|
||||
return sessions
|
||||
|
||||
# =========================================================================
|
||||
# Message storage
|
||||
# =========================================================================
|
||||
|
||||
119
hermes_time.py
Normal file
119
hermes_time.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
Timezone-aware clock for Hermes.
|
||||
|
||||
Provides a single ``now()`` helper that returns a timezone-aware datetime
|
||||
based on the user's configured IANA timezone (e.g. ``Asia/Kolkata``).
|
||||
|
||||
Resolution order:
|
||||
1. ``HERMES_TIMEZONE`` environment variable
|
||||
2. ``timezone`` key in ``~/.hermes/config.yaml``
|
||||
3. Falls back to the server's local time (``datetime.now().astimezone()``)
|
||||
|
||||
Invalid timezone values log a warning and fall back safely — Hermes never
|
||||
crashes due to a bad timezone string.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone as _tz
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from zoneinfo import ZoneInfo
|
||||
except ImportError:
|
||||
# Python 3.8 fallback (shouldn't be needed — Hermes requires 3.9+)
|
||||
from backports.zoneinfo import ZoneInfo # type: ignore[no-redef]
|
||||
|
||||
# Cached state — resolved once, reused on every call.
|
||||
# Call reset_cache() to force re-resolution (e.g. after config changes).
|
||||
_cached_tz: Optional[ZoneInfo] = None
|
||||
_cached_tz_name: Optional[str] = None
|
||||
_cache_resolved: bool = False
|
||||
|
||||
|
||||
def _resolve_timezone_name() -> str:
|
||||
"""Read the configured IANA timezone string (or empty string).
|
||||
|
||||
This does file I/O when falling through to config.yaml, so callers
|
||||
should cache the result rather than calling on every ``now()``.
|
||||
"""
|
||||
# 1. Environment variable (highest priority — set by Supervisor, etc.)
|
||||
tz_env = os.getenv("HERMES_TIMEZONE", "").strip()
|
||||
if tz_env:
|
||||
return tz_env
|
||||
|
||||
# 2. config.yaml ``timezone`` key
|
||||
try:
|
||||
import yaml
|
||||
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
config_path = hermes_home / "config.yaml"
|
||||
if config_path.exists():
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
tz_cfg = cfg.get("timezone", "")
|
||||
if isinstance(tz_cfg, str) and tz_cfg.strip():
|
||||
return tz_cfg.strip()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def _get_zoneinfo(name: str) -> Optional[ZoneInfo]:
|
||||
"""Validate and return a ZoneInfo, or None if invalid."""
|
||||
if not name:
|
||||
return None
|
||||
try:
|
||||
return ZoneInfo(name)
|
||||
except (KeyError, Exception) as exc:
|
||||
logger.warning(
|
||||
"Invalid timezone '%s': %s. Falling back to server local time.",
|
||||
name, exc,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def get_timezone() -> Optional[ZoneInfo]:
|
||||
"""Return the user's configured ZoneInfo, or None (meaning server-local).
|
||||
|
||||
Resolved once and cached. Call ``reset_cache()`` after config changes.
|
||||
"""
|
||||
global _cached_tz, _cached_tz_name, _cache_resolved
|
||||
if not _cache_resolved:
|
||||
_cached_tz_name = _resolve_timezone_name()
|
||||
_cached_tz = _get_zoneinfo(_cached_tz_name)
|
||||
_cache_resolved = True
|
||||
return _cached_tz
|
||||
|
||||
|
||||
def get_timezone_name() -> str:
|
||||
"""Return the IANA name of the configured timezone, or empty string."""
|
||||
global _cached_tz_name, _cache_resolved
|
||||
if not _cache_resolved:
|
||||
get_timezone() # populates cache
|
||||
return _cached_tz_name or ""
|
||||
|
||||
|
||||
def now() -> datetime:
|
||||
"""
|
||||
Return the current time as a timezone-aware datetime.
|
||||
|
||||
If a valid timezone is configured, returns wall-clock time in that zone.
|
||||
Otherwise returns the server's local time (via ``astimezone()``).
|
||||
"""
|
||||
tz = get_timezone()
|
||||
if tz is not None:
|
||||
return datetime.now(tz)
|
||||
# No timezone configured — use server-local (still tz-aware)
|
||||
return datetime.now().astimezone()
|
||||
|
||||
|
||||
def reset_cache() -> None:
|
||||
"""Clear the cached timezone. Used by tests and after config changes."""
|
||||
global _cached_tz, _cached_tz_name, _cache_resolved
|
||||
_cached_tz = None
|
||||
_cached_tz_name = None
|
||||
_cache_resolved = False
|
||||
@@ -149,7 +149,7 @@ class MiniSWERunner:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "anthropic/claude-sonnet-4-20250514",
|
||||
model: str = "anthropic/claude-sonnet-4.6",
|
||||
base_url: str = None,
|
||||
api_key: str = None,
|
||||
env_type: str = "local",
|
||||
@@ -200,13 +200,7 @@ class MiniSWERunner:
|
||||
else:
|
||||
client_kwargs["base_url"] = "https://openrouter.ai/api/v1"
|
||||
|
||||
if base_url and "api.anthropic.com" in base_url.strip().lower():
|
||||
raise ValueError(
|
||||
"Anthropic's native /v1/messages API is not supported yet (planned for a future release). "
|
||||
"Hermes currently requires OpenAI-compatible /chat/completions endpoints. "
|
||||
"To use Claude models now, route through OpenRouter (OPENROUTER_API_KEY) "
|
||||
"or any OpenAI-compatible proxy that wraps the Anthropic API."
|
||||
)
|
||||
|
||||
|
||||
# Handle API key - OpenRouter is the primary provider
|
||||
if api_key:
|
||||
|
||||
@@ -225,6 +225,18 @@ def get_tool_definitions(
|
||||
# Ask the registry for schemas (only returns tools whose check_fn passes)
|
||||
filtered_tools = registry.get_definitions(tools_to_include, quiet=quiet_mode)
|
||||
|
||||
# Rebuild execute_code schema to only list sandbox tools that are actually
|
||||
# enabled. Without this, the model sees "web_search is available in
|
||||
# execute_code" even when the user disabled the web toolset (#560-discord).
|
||||
if "execute_code" in tools_to_include:
|
||||
from tools.code_execution_tool import SANDBOX_ALLOWED_TOOLS, build_execute_code_schema
|
||||
sandbox_enabled = SANDBOX_ALLOWED_TOOLS & tools_to_include
|
||||
dynamic_schema = build_execute_code_schema(sandbox_enabled)
|
||||
for i, td in enumerate(filtered_tools):
|
||||
if td.get("function", {}).get("name") == "execute_code":
|
||||
filtered_tools[i] = {"type": "function", "function": dynamic_schema}
|
||||
break
|
||||
|
||||
if not quiet_mode:
|
||||
if filtered_tools:
|
||||
tool_names = [t["function"]["name"] for t in filtered_tools]
|
||||
|
||||
207
optional-skills/blockchain/solana/SKILL.md
Normal file
207
optional-skills/blockchain/solana/SKILL.md
Normal file
@@ -0,0 +1,207 @@
|
||||
---
|
||||
name: solana
|
||||
description: Query Solana blockchain data with USD pricing — wallet balances, token portfolios with values, transaction details, NFTs, whale detection, and live network stats. Uses Solana RPC + CoinGecko. No API key required.
|
||||
version: 0.2.0
|
||||
author: Deniz Alagoz (gizdusum), enhanced by Hermes Agent
|
||||
license: MIT
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Solana, Blockchain, Crypto, Web3, RPC, DeFi, NFT]
|
||||
related_skills: []
|
||||
---
|
||||
|
||||
# Solana Blockchain Skill
|
||||
|
||||
Query Solana on-chain data enriched with USD pricing via CoinGecko.
|
||||
8 commands: wallet portfolio, token info, transactions, activity, NFTs,
|
||||
whale detection, network stats, and price lookup.
|
||||
|
||||
No API key needed. Uses only Python standard library (urllib, json, argparse).
|
||||
|
||||
---
|
||||
|
||||
## When to Use
|
||||
|
||||
- User asks for a Solana wallet balance, token holdings, or portfolio value
|
||||
- User wants to inspect a specific transaction by signature
|
||||
- User wants SPL token metadata, price, supply, or top holders
|
||||
- User wants recent transaction history for an address
|
||||
- User wants NFTs owned by a wallet
|
||||
- User wants to find large SOL transfers (whale detection)
|
||||
- User wants Solana network health, TPS, epoch, or SOL price
|
||||
- User asks "what's the price of BONK/JUP/SOL?"
|
||||
|
||||
---
|
||||
|
||||
## Prerequisites
|
||||
|
||||
The helper script uses only Python standard library (urllib, json, argparse).
|
||||
No external packages required.
|
||||
|
||||
Pricing data comes from CoinGecko's free API (no key needed, rate-limited
|
||||
to ~10-30 requests/minute). For faster lookups, use `--no-prices` flag.
|
||||
|
||||
---
|
||||
|
||||
## Quick Reference
|
||||
|
||||
RPC endpoint (default): https://api.mainnet-beta.solana.com
|
||||
Override: export SOLANA_RPC_URL=https://your-private-rpc.com
|
||||
|
||||
Helper script path: ~/.hermes/skills/blockchain/solana/scripts/solana_client.py
|
||||
|
||||
```
|
||||
python3 solana_client.py wallet <address> [--limit N] [--all] [--no-prices]
|
||||
python3 solana_client.py tx <signature>
|
||||
python3 solana_client.py token <mint_address>
|
||||
python3 solana_client.py activity <address> [--limit N]
|
||||
python3 solana_client.py nft <address>
|
||||
python3 solana_client.py whales [--min-sol N]
|
||||
python3 solana_client.py stats
|
||||
python3 solana_client.py price <mint_or_symbol>
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Procedure
|
||||
|
||||
### 0. Setup Check
|
||||
|
||||
```bash
|
||||
python3 --version
|
||||
|
||||
# Optional: set a private RPC for better rate limits
|
||||
export SOLANA_RPC_URL="https://api.mainnet-beta.solana.com"
|
||||
|
||||
# Confirm connectivity
|
||||
python3 ~/.hermes/skills/blockchain/solana/scripts/solana_client.py stats
|
||||
```
|
||||
|
||||
### 1. Wallet Portfolio
|
||||
|
||||
Get SOL balance, SPL token holdings with USD values, NFT count, and
|
||||
portfolio total. Tokens sorted by value, dust filtered, known tokens
|
||||
labeled by name (BONK, JUP, USDC, etc.).
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/blockchain/solana/scripts/solana_client.py \
|
||||
wallet 9WzDXwBbmkg8ZTbNMqUxvQRAyrZzDsGYdLVL9zYtAWWM
|
||||
```
|
||||
|
||||
Flags:
|
||||
- `--limit N` — show top N tokens (default: 20)
|
||||
- `--all` — show all tokens, no dust filter, no limit
|
||||
- `--no-prices` — skip CoinGecko price lookups (faster, RPC-only)
|
||||
|
||||
Output includes: SOL balance + USD value, token list with prices sorted
|
||||
by value, dust count, NFT summary, total portfolio value in USD.
|
||||
|
||||
### 2. Transaction Details
|
||||
|
||||
Inspect a full transaction by its base58 signature. Shows balance changes
|
||||
in both SOL and USD.
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/blockchain/solana/scripts/solana_client.py \
|
||||
tx 5j7s8K...your_signature_here
|
||||
```
|
||||
|
||||
Output: slot, timestamp, fee, status, balance changes (SOL + USD),
|
||||
program invocations.
|
||||
|
||||
### 3. Token Info
|
||||
|
||||
Get SPL token metadata, current price, market cap, supply, decimals,
|
||||
mint/freeze authorities, and top 5 holders.
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/blockchain/solana/scripts/solana_client.py \
|
||||
token DezXAZ8z7PnrnRJjz3wXBoRgixCa6xjnB7YaB1pPB263
|
||||
```
|
||||
|
||||
Output: name, symbol, decimals, supply, price, market cap, top 5
|
||||
holders with percentages.
|
||||
|
||||
### 4. Recent Activity
|
||||
|
||||
List recent transactions for an address (default: last 10, max: 25).
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/blockchain/solana/scripts/solana_client.py \
|
||||
activity 9WzDXwBbmkg8ZTbNMqUxvQRAyrZzDsGYdLVL9zYtAWWM --limit 25
|
||||
```
|
||||
|
||||
### 5. NFT Portfolio
|
||||
|
||||
List NFTs owned by a wallet (heuristic: SPL tokens with amount=1, decimals=0).
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/blockchain/solana/scripts/solana_client.py \
|
||||
nft 9WzDXwBbmkg8ZTbNMqUxvQRAyrZzDsGYdLVL9zYtAWWM
|
||||
```
|
||||
|
||||
Note: Compressed NFTs (cNFTs) are not detected by this heuristic.
|
||||
|
||||
### 6. Whale Detector
|
||||
|
||||
Scan the most recent block for large SOL transfers with USD values.
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/blockchain/solana/scripts/solana_client.py \
|
||||
whales --min-sol 500
|
||||
```
|
||||
|
||||
Note: scans the latest block only — point-in-time snapshot, not historical.
|
||||
|
||||
### 7. Network Stats
|
||||
|
||||
Live Solana network health: current slot, epoch, TPS, supply, validator
|
||||
version, SOL price, and market cap.
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/blockchain/solana/scripts/solana_client.py stats
|
||||
```
|
||||
|
||||
### 8. Price Lookup
|
||||
|
||||
Quick price check for any token by mint address or known symbol.
|
||||
|
||||
```bash
|
||||
python3 ~/.hermes/skills/blockchain/solana/scripts/solana_client.py price BONK
|
||||
python3 ~/.hermes/skills/blockchain/solana/scripts/solana_client.py price JUP
|
||||
python3 ~/.hermes/skills/blockchain/solana/scripts/solana_client.py price SOL
|
||||
python3 ~/.hermes/skills/blockchain/solana/scripts/solana_client.py price DezXAZ8z7PnrnRJjz3wXBoRgixCa6xjnB7YaB1pPB263
|
||||
```
|
||||
|
||||
Known symbols: SOL, USDC, USDT, BONK, JUP, WETH, JTO, mSOL, stSOL,
|
||||
PYTH, HNT, RNDR, WEN, W, TNSR, DRIFT, bSOL, JLP, WIF, MEW, BOME, PENGU.
|
||||
|
||||
---
|
||||
|
||||
## Pitfalls
|
||||
|
||||
- **CoinGecko rate-limits** — free tier allows ~10-30 requests/minute.
|
||||
Price lookups use 1 request per token. Wallets with many tokens may
|
||||
not get prices for all of them. Use `--no-prices` for speed.
|
||||
- **Public RPC rate-limits** — Solana mainnet public RPC limits requests.
|
||||
For production use, set SOLANA_RPC_URL to a private endpoint
|
||||
(Helius, QuickNode, Triton).
|
||||
- **NFT detection is heuristic** — amount=1 + decimals=0. Compressed
|
||||
NFTs (cNFTs) and Token-2022 NFTs won't appear.
|
||||
- **Whale detector scans latest block only** — not historical. Results
|
||||
vary by the moment you query.
|
||||
- **Transaction history** — public RPC keeps ~2 days. Older transactions
|
||||
may not be available.
|
||||
- **Token names** — ~25 well-known tokens are labeled by name. Others
|
||||
show abbreviated mint addresses. Use the `token` command for full info.
|
||||
- **Retry on 429** — both RPC and CoinGecko calls retry up to 2 times
|
||||
with exponential backoff on rate-limit errors.
|
||||
|
||||
---
|
||||
|
||||
## Verification
|
||||
|
||||
```bash
|
||||
# Should print current Solana slot, TPS, and SOL price
|
||||
python3 ~/.hermes/skills/blockchain/solana/scripts/solana_client.py stats
|
||||
```
|
||||
698
optional-skills/blockchain/solana/scripts/solana_client.py
Normal file
698
optional-skills/blockchain/solana/scripts/solana_client.py
Normal file
@@ -0,0 +1,698 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Solana Blockchain CLI Tool for Hermes Agent
|
||||
--------------------------------------------
|
||||
Queries the Solana JSON-RPC API and CoinGecko for enriched on-chain data.
|
||||
Uses only Python standard library — no external packages required.
|
||||
|
||||
Usage:
|
||||
python3 solana_client.py stats
|
||||
python3 solana_client.py wallet <address> [--limit N] [--all] [--no-prices]
|
||||
python3 solana_client.py tx <signature>
|
||||
python3 solana_client.py token <mint_address>
|
||||
python3 solana_client.py activity <address> [--limit N]
|
||||
python3 solana_client.py nft <address>
|
||||
python3 solana_client.py whales [--min-sol N]
|
||||
python3 solana_client.py price <mint_address_or_symbol>
|
||||
|
||||
Environment:
|
||||
SOLANA_RPC_URL Override the default RPC endpoint (default: mainnet-beta public)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
RPC_URL = os.environ.get(
|
||||
"SOLANA_RPC_URL",
|
||||
"https://api.mainnet-beta.solana.com",
|
||||
)
|
||||
|
||||
LAMPORTS_PER_SOL = 1_000_000_000
|
||||
|
||||
# Well-known Solana token names — avoids API calls for common tokens.
|
||||
# Maps mint address → (symbol, name).
|
||||
KNOWN_TOKENS: Dict[str, tuple] = {
|
||||
"So11111111111111111111111111111111111111112": ("SOL", "Solana"),
|
||||
"EPjFWdd5AufqSSqeM2qN1xzybapC8G4wEGGkZwyTDt1v": ("USDC", "USD Coin"),
|
||||
"Es9vMFrzaCERmJfrF4H2FYD4KCoNkY11McCe8BenwNYB": ("USDT", "Tether"),
|
||||
"DezXAZ8z7PnrnRJjz3wXBoRgixCa6xjnB7YaB1pPB263": ("BONK", "Bonk"),
|
||||
"JUPyiwrYJFskUPiHa7hkeR8VUtAeFoSYbKedZNsDvCN": ("JUP", "Jupiter"),
|
||||
"7vfCXTUXx5WJV5JADk17DUJ4ksgau7utNKj4b963voxs": ("WETH", "Wrapped Ether"),
|
||||
"jtojtomepa8beP8AuQc6eXt5FriJwfFMwQx2v2f9mCL": ("JTO", "Jito"),
|
||||
"mSoLzYCxHdYgdzU16g5QSh3i5K3z3KZK7ytfqcJm7So": ("mSOL", "Marinade Staked SOL"),
|
||||
"7dHbWXmci3dT8UFYWYZweBLXgycu7Y3iL6trKn1Y7ARj": ("stSOL", "Lido Staked SOL"),
|
||||
"HZ1JovNiVvGrGNiiYvEozEVgZ58xaU3RKwX8eACQBCt3": ("PYTH", "Pyth Network"),
|
||||
"RLBxxFkseAZ4RgJH3Sqn8jXxhmGoz9jWxDNJMh8pL7a": ("RLBB", "Rollbit"),
|
||||
"hntyVP6YFm1Hg25TN9WGLqM12b8TQmcknKrdu1oxWux": ("HNT", "Helium"),
|
||||
"rndrizKT3MK1iimdxRdWabcF7Zg7AR5T4nud4EkHBof": ("RNDR", "Render"),
|
||||
"WENWENvqqNya429ubCdR81ZmD69brwQaaBYY6p91oHQQ": ("WEN", "Wen"),
|
||||
"85VBFQZC9TZkfaptBWjvUw7YbZjy52A6mjtPGjstQAmQ": ("W", "Wormhole"),
|
||||
"TNSRxcUxoT9xBG3de7PiJyTDYu7kskLqcpddxnEJAS6": ("TNSR", "Tensor"),
|
||||
"DriFtupJYLTosbwoN8koMbEYSx54aFAVLddWsbksjwg7": ("DRIFT", "Drift"),
|
||||
"bSo13r4TkiE4KumL71LsHTPpL2euBYLFx6h9HP3piy1": ("bSOL", "BlazeStake Staked SOL"),
|
||||
"27G8MtK7VtTcCHkpASjSDdkWWYfoqT6ggEuKidVJidD4": ("JLP", "Jupiter LP"),
|
||||
"EKpQGSJtjMFqKZ9KQanSqYXRcF8fBopzLHYxdM65zcjm": ("WIF", "dogwifhat"),
|
||||
"MEW1gQWJ3nEXg2qgERiKu7FAFj79PHvQVREQUzScPP5": ("MEW", "cat in a dogs world"),
|
||||
"ukHH6c7mMyiWCf1b9pnWe25TSpkDDt3H5pQZgZ74J82": ("BOME", "Book of Meme"),
|
||||
"A8C3xuqscfmyLrte3VwJvtPHXvcSN3FjDbUaSMAkQrCS": ("PENGU", "Pudgy Penguins"),
|
||||
}
|
||||
|
||||
# Reverse lookup: symbol → mint (for the `price` command).
|
||||
_SYMBOL_TO_MINT = {v[0].upper(): k for k, v in KNOWN_TOKENS.items()}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP / RPC helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _http_get_json(url: str, timeout: int = 10, retries: int = 2) -> Any:
|
||||
"""GET JSON from a URL with retry on 429 rate-limit. Returns parsed JSON or None."""
|
||||
for attempt in range(retries + 1):
|
||||
req = urllib.request.Request(
|
||||
url, headers={"Accept": "application/json", "User-Agent": "HermesAgent/1.0"},
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
return json.load(resp)
|
||||
except urllib.error.HTTPError as exc:
|
||||
if exc.code == 429 and attempt < retries:
|
||||
time.sleep(2.0 * (attempt + 1))
|
||||
continue
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _rpc_call(method: str, params: list = None, retries: int = 2) -> Any:
|
||||
"""Send a JSON-RPC request with retry on 429 rate-limit."""
|
||||
payload = json.dumps({
|
||||
"jsonrpc": "2.0", "id": 1,
|
||||
"method": method, "params": params or [],
|
||||
}).encode()
|
||||
|
||||
for attempt in range(retries + 1):
|
||||
req = urllib.request.Request(
|
||||
RPC_URL, data=payload,
|
||||
headers={"Content-Type": "application/json"}, method="POST",
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=20) as resp:
|
||||
body = json.load(resp)
|
||||
if "error" in body:
|
||||
err = body["error"]
|
||||
# Rate-limit: retry after delay
|
||||
if isinstance(err, dict) and err.get("code") == 429:
|
||||
if attempt < retries:
|
||||
time.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
sys.exit(f"RPC error: {err}")
|
||||
return body.get("result")
|
||||
except urllib.error.HTTPError as exc:
|
||||
if exc.code == 429 and attempt < retries:
|
||||
time.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
sys.exit(f"RPC HTTP error: {exc}")
|
||||
except urllib.error.URLError as exc:
|
||||
sys.exit(f"RPC connection error: {exc}")
|
||||
return None
|
||||
|
||||
|
||||
# Keep backward compat — the rest of the code uses `rpc()`.
|
||||
rpc = _rpc_call
|
||||
|
||||
|
||||
def rpc_batch(calls: list) -> list:
|
||||
"""Send a batch of JSON-RPC requests (with retry on 429)."""
|
||||
payload = json.dumps([
|
||||
{"jsonrpc": "2.0", "id": i, "method": c["method"], "params": c.get("params", [])}
|
||||
for i, c in enumerate(calls)
|
||||
]).encode()
|
||||
|
||||
for attempt in range(3):
|
||||
req = urllib.request.Request(
|
||||
RPC_URL, data=payload,
|
||||
headers={"Content-Type": "application/json"}, method="POST",
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=20) as resp:
|
||||
return json.load(resp)
|
||||
except urllib.error.HTTPError as exc:
|
||||
if exc.code == 429 and attempt < 2:
|
||||
time.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
sys.exit(f"RPC batch HTTP error: {exc}")
|
||||
except urllib.error.URLError as exc:
|
||||
sys.exit(f"RPC batch error: {exc}")
|
||||
return []
|
||||
|
||||
|
||||
def lamports_to_sol(lamports: int) -> float:
|
||||
return lamports / LAMPORTS_PER_SOL
|
||||
|
||||
|
||||
def print_json(obj: Any) -> None:
|
||||
print(json.dumps(obj, indent=2))
|
||||
|
||||
|
||||
def _short_mint(mint: str) -> str:
|
||||
"""Abbreviate a mint address for display: first 4 + last 4."""
|
||||
if len(mint) <= 12:
|
||||
return mint
|
||||
return f"{mint[:4]}...{mint[-4:]}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Price & token name helpers (CoinGecko — free, no API key)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def fetch_prices(mints: List[str], max_lookups: int = 20) -> Dict[str, float]:
|
||||
"""Fetch USD prices for mint addresses via CoinGecko (one per request).
|
||||
|
||||
CoinGecko free tier doesn't support batch Solana token lookups,
|
||||
so we do individual calls — capped at *max_lookups* to stay within
|
||||
rate limits. Returns {mint: usd_price}.
|
||||
"""
|
||||
prices: Dict[str, float] = {}
|
||||
for i, mint in enumerate(mints[:max_lookups]):
|
||||
url = (
|
||||
f"https://api.coingecko.com/api/v3/simple/token_price/solana"
|
||||
f"?contract_addresses={mint}&vs_currencies=usd"
|
||||
)
|
||||
data = _http_get_json(url, timeout=10)
|
||||
if data and isinstance(data, dict):
|
||||
for addr, info in data.items():
|
||||
if isinstance(info, dict) and "usd" in info:
|
||||
prices[mint] = info["usd"]
|
||||
break
|
||||
# Pause between calls to respect CoinGecko free-tier rate-limits
|
||||
if i < len(mints[:max_lookups]) - 1:
|
||||
time.sleep(1.0)
|
||||
return prices
|
||||
|
||||
|
||||
def fetch_sol_price() -> Optional[float]:
|
||||
"""Fetch current SOL price in USD via CoinGecko."""
|
||||
data = _http_get_json(
|
||||
"https://api.coingecko.com/api/v3/simple/price?ids=solana&vs_currencies=usd"
|
||||
)
|
||||
if data and "solana" in data:
|
||||
return data["solana"].get("usd")
|
||||
return None
|
||||
|
||||
|
||||
def resolve_token_name(mint: str) -> Optional[Dict[str, str]]:
|
||||
"""Look up token name and symbol from CoinGecko by mint address.
|
||||
|
||||
Returns {"name": ..., "symbol": ...} or None.
|
||||
"""
|
||||
if mint in KNOWN_TOKENS:
|
||||
sym, name = KNOWN_TOKENS[mint]
|
||||
return {"symbol": sym, "name": name}
|
||||
url = f"https://api.coingecko.com/api/v3/coins/solana/contract/{mint}"
|
||||
data = _http_get_json(url, timeout=10)
|
||||
if data and "symbol" in data:
|
||||
return {"symbol": data["symbol"].upper(), "name": data.get("name", "")}
|
||||
return None
|
||||
|
||||
|
||||
def _token_label(mint: str) -> str:
|
||||
"""Return a human-readable label for a mint: symbol if known, else abbreviated address."""
|
||||
if mint in KNOWN_TOKENS:
|
||||
return KNOWN_TOKENS[mint][0]
|
||||
return _short_mint(mint)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. Network Stats
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def cmd_stats(_args):
|
||||
"""Live Solana network: slot, epoch, TPS, supply, version, SOL price."""
|
||||
results = rpc_batch([
|
||||
{"method": "getSlot"},
|
||||
{"method": "getEpochInfo"},
|
||||
{"method": "getRecentPerformanceSamples", "params": [1]},
|
||||
{"method": "getSupply"},
|
||||
{"method": "getVersion"},
|
||||
])
|
||||
|
||||
by_id = {r["id"]: r.get("result") for r in results}
|
||||
|
||||
slot = by_id.get(0)
|
||||
epoch_info = by_id.get(1)
|
||||
perf_samples = by_id.get(2)
|
||||
supply = by_id.get(3)
|
||||
version = by_id.get(4)
|
||||
|
||||
tps = None
|
||||
if perf_samples:
|
||||
s = perf_samples[0]
|
||||
tps = round(s["numTransactions"] / s["samplePeriodSecs"], 1)
|
||||
|
||||
total_supply = lamports_to_sol(supply["value"]["total"]) if supply else None
|
||||
circ_supply = lamports_to_sol(supply["value"]["circulating"]) if supply else None
|
||||
|
||||
sol_price = fetch_sol_price()
|
||||
|
||||
out = {
|
||||
"slot": slot,
|
||||
"epoch": epoch_info.get("epoch") if epoch_info else None,
|
||||
"slot_in_epoch": epoch_info.get("slotIndex") if epoch_info else None,
|
||||
"tps": tps,
|
||||
"total_supply_SOL": round(total_supply, 2) if total_supply else None,
|
||||
"circulating_supply_SOL": round(circ_supply, 2) if circ_supply else None,
|
||||
"validator_version": version.get("solana-core") if version else None,
|
||||
}
|
||||
if sol_price is not None:
|
||||
out["sol_price_usd"] = sol_price
|
||||
if circ_supply:
|
||||
out["market_cap_usd"] = round(sol_price * circ_supply, 0)
|
||||
print_json(out)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Wallet Info (enhanced with prices, sorting, filtering)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def cmd_wallet(args):
|
||||
"""SOL balance + SPL token holdings with USD values."""
|
||||
address = args.address
|
||||
show_all = getattr(args, "all", False)
|
||||
limit = getattr(args, "limit", 20) or 20
|
||||
skip_prices = getattr(args, "no_prices", False)
|
||||
|
||||
# Fetch SOL balance
|
||||
balance_result = rpc("getBalance", [address])
|
||||
sol_balance = lamports_to_sol(balance_result["value"])
|
||||
|
||||
# Fetch all SPL token accounts
|
||||
token_result = rpc("getTokenAccountsByOwner", [
|
||||
address,
|
||||
{"programId": "TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA"},
|
||||
{"encoding": "jsonParsed"},
|
||||
])
|
||||
|
||||
raw_tokens = []
|
||||
for acct in (token_result.get("value") or []):
|
||||
info = acct["account"]["data"]["parsed"]["info"]
|
||||
ta = info["tokenAmount"]
|
||||
amount = float(ta.get("uiAmountString") or 0)
|
||||
if amount > 0:
|
||||
raw_tokens.append({
|
||||
"mint": info["mint"],
|
||||
"amount": amount,
|
||||
"decimals": ta["decimals"],
|
||||
})
|
||||
|
||||
# Separate NFTs (amount=1, decimals=0) from fungible tokens
|
||||
nfts = [t for t in raw_tokens if t["decimals"] == 0 and t["amount"] == 1]
|
||||
fungible = [t for t in raw_tokens if not (t["decimals"] == 0 and t["amount"] == 1)]
|
||||
|
||||
# Fetch prices for fungible tokens (cap lookups to avoid API abuse)
|
||||
sol_price = None
|
||||
prices: Dict[str, float] = {}
|
||||
if not skip_prices and fungible:
|
||||
sol_price = fetch_sol_price()
|
||||
# Prioritize known tokens, then a small sample of unknowns.
|
||||
# CoinGecko free tier = 1 request per mint, so we cap lookups.
|
||||
known_mints = [t["mint"] for t in fungible if t["mint"] in KNOWN_TOKENS]
|
||||
other_mints = [t["mint"] for t in fungible if t["mint"] not in KNOWN_TOKENS][:15]
|
||||
mints_to_price = known_mints + other_mints
|
||||
if mints_to_price:
|
||||
prices = fetch_prices(mints_to_price, max_lookups=30)
|
||||
|
||||
# Enrich tokens with labels and USD values
|
||||
enriched = []
|
||||
dust_count = 0
|
||||
dust_value = 0.0
|
||||
for t in fungible:
|
||||
mint = t["mint"]
|
||||
label = _token_label(mint)
|
||||
usd_price = prices.get(mint)
|
||||
usd_value = round(usd_price * t["amount"], 2) if usd_price else None
|
||||
|
||||
# Filter dust (< $0.01) unless --all
|
||||
if not show_all and usd_value is not None and usd_value < 0.01:
|
||||
dust_count += 1
|
||||
dust_value += usd_value
|
||||
continue
|
||||
|
||||
entry = {"token": label, "mint": mint, "amount": t["amount"]}
|
||||
if usd_price is not None:
|
||||
entry["price_usd"] = usd_price
|
||||
entry["value_usd"] = usd_value
|
||||
enriched.append(entry)
|
||||
|
||||
# Sort: tokens with known USD value first (highest→lowest), then unknowns
|
||||
enriched.sort(key=lambda x: (x.get("value_usd") is not None, x.get("value_usd") or 0), reverse=True)
|
||||
|
||||
# Apply limit unless --all
|
||||
total_tokens = len(enriched)
|
||||
if not show_all and len(enriched) > limit:
|
||||
enriched = enriched[:limit]
|
||||
|
||||
# Compute portfolio total
|
||||
total_usd = sum(t.get("value_usd", 0) for t in enriched)
|
||||
sol_value_usd = round(sol_price * sol_balance, 2) if sol_price else None
|
||||
if sol_value_usd:
|
||||
total_usd += sol_value_usd
|
||||
total_usd += dust_value
|
||||
|
||||
output = {
|
||||
"address": address,
|
||||
"sol_balance": round(sol_balance, 9),
|
||||
}
|
||||
if sol_price:
|
||||
output["sol_price_usd"] = sol_price
|
||||
output["sol_value_usd"] = sol_value_usd
|
||||
output["tokens_shown"] = len(enriched)
|
||||
if total_tokens > len(enriched):
|
||||
output["tokens_hidden"] = total_tokens - len(enriched)
|
||||
output["spl_tokens"] = enriched
|
||||
if dust_count > 0:
|
||||
output["dust_filtered"] = {"count": dust_count, "total_value_usd": round(dust_value, 4)}
|
||||
output["nft_count"] = len(nfts)
|
||||
if nfts:
|
||||
output["nfts"] = [_token_label(n["mint"]) + f" ({_short_mint(n['mint'])})" for n in nfts[:10]]
|
||||
if len(nfts) > 10:
|
||||
output["nfts"].append(f"... and {len(nfts) - 10} more")
|
||||
if total_usd > 0:
|
||||
output["portfolio_total_usd"] = round(total_usd, 2)
|
||||
|
||||
print_json(output)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Transaction Details
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def cmd_tx(args):
|
||||
"""Full transaction details by signature."""
|
||||
result = rpc("getTransaction", [
|
||||
args.signature,
|
||||
{"encoding": "jsonParsed", "maxSupportedTransactionVersion": 0},
|
||||
])
|
||||
|
||||
if result is None:
|
||||
sys.exit("Transaction not found (may be too old for public RPC history).")
|
||||
|
||||
meta = result.get("meta", {}) or {}
|
||||
msg = result.get("transaction", {}).get("message", {})
|
||||
account_keys = msg.get("accountKeys", [])
|
||||
|
||||
pre = meta.get("preBalances", [])
|
||||
post = meta.get("postBalances", [])
|
||||
|
||||
balance_changes = []
|
||||
for i, key in enumerate(account_keys):
|
||||
acct_key = key["pubkey"] if isinstance(key, dict) else key
|
||||
if i < len(pre) and i < len(post):
|
||||
change = lamports_to_sol(post[i] - pre[i])
|
||||
if change != 0:
|
||||
balance_changes.append({"account": acct_key, "change_SOL": round(change, 9)})
|
||||
|
||||
programs = []
|
||||
for ix in msg.get("instructions", []):
|
||||
prog = ix.get("programId")
|
||||
if prog is None and "programIdIndex" in ix:
|
||||
k = account_keys[ix["programIdIndex"]]
|
||||
prog = k["pubkey"] if isinstance(k, dict) else k
|
||||
if prog:
|
||||
programs.append(prog)
|
||||
|
||||
# Add USD value for SOL changes
|
||||
sol_price = fetch_sol_price()
|
||||
if sol_price and balance_changes:
|
||||
for bc in balance_changes:
|
||||
bc["change_USD"] = round(bc["change_SOL"] * sol_price, 2)
|
||||
|
||||
print_json({
|
||||
"signature": args.signature,
|
||||
"slot": result.get("slot"),
|
||||
"block_time": result.get("blockTime"),
|
||||
"fee_SOL": lamports_to_sol(meta.get("fee", 0)),
|
||||
"status": "success" if meta.get("err") is None else "failed",
|
||||
"balance_changes": balance_changes,
|
||||
"programs_invoked": list(dict.fromkeys(programs)),
|
||||
})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Token Info (enhanced with name + price)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def cmd_token(args):
|
||||
"""SPL token metadata, supply, decimals, price, top holders."""
|
||||
mint = args.mint
|
||||
|
||||
mint_info = rpc("getAccountInfo", [mint, {"encoding": "jsonParsed"}])
|
||||
if mint_info is None or mint_info.get("value") is None:
|
||||
sys.exit("Mint account not found.")
|
||||
|
||||
parsed = mint_info["value"]["data"]["parsed"]["info"]
|
||||
decimals = parsed.get("decimals", 0)
|
||||
supply_raw = int(parsed.get("supply", 0))
|
||||
supply_human = supply_raw / (10 ** decimals) if decimals else supply_raw
|
||||
|
||||
largest = rpc("getTokenLargestAccounts", [mint])
|
||||
holders = []
|
||||
for acct in (largest.get("value") or [])[:5]:
|
||||
amount = float(acct.get("uiAmountString") or 0)
|
||||
pct = round((amount / supply_human * 100), 4) if supply_human > 0 else 0
|
||||
holders.append({
|
||||
"account": acct["address"],
|
||||
"amount": amount,
|
||||
"percent": pct,
|
||||
})
|
||||
|
||||
# Resolve name + price
|
||||
token_meta = resolve_token_name(mint)
|
||||
price_data = fetch_prices([mint])
|
||||
|
||||
out = {"mint": mint}
|
||||
if token_meta:
|
||||
out["name"] = token_meta["name"]
|
||||
out["symbol"] = token_meta["symbol"]
|
||||
out["decimals"] = decimals
|
||||
out["supply"] = round(supply_human, min(decimals, 6))
|
||||
out["mint_authority"] = parsed.get("mintAuthority")
|
||||
out["freeze_authority"] = parsed.get("freezeAuthority")
|
||||
if mint in price_data:
|
||||
out["price_usd"] = price_data[mint]
|
||||
out["market_cap_usd"] = round(price_data[mint] * supply_human, 0)
|
||||
out["top_5_holders"] = holders
|
||||
|
||||
print_json(out)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Recent Activity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def cmd_activity(args):
|
||||
"""Recent transaction signatures for an address."""
|
||||
limit = min(args.limit, 25)
|
||||
result = rpc("getSignaturesForAddress", [args.address, {"limit": limit}])
|
||||
|
||||
txs = [
|
||||
{
|
||||
"signature": item["signature"],
|
||||
"slot": item.get("slot"),
|
||||
"block_time": item.get("blockTime"),
|
||||
"err": item.get("err"),
|
||||
}
|
||||
for item in (result or [])
|
||||
]
|
||||
|
||||
print_json({"address": args.address, "transactions": txs})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. NFT Portfolio
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def cmd_nft(args):
|
||||
"""NFTs owned by a wallet (amount=1 && decimals=0 heuristic)."""
|
||||
result = rpc("getTokenAccountsByOwner", [
|
||||
args.address,
|
||||
{"programId": "TokenkegQfeZyiNwAJbNbGKPFXCWuBvf9Ss623VQ5DA"},
|
||||
{"encoding": "jsonParsed"},
|
||||
])
|
||||
|
||||
nfts = [
|
||||
acct["account"]["data"]["parsed"]["info"]["mint"]
|
||||
for acct in (result.get("value") or [])
|
||||
if acct["account"]["data"]["parsed"]["info"]["tokenAmount"]["decimals"] == 0
|
||||
and int(acct["account"]["data"]["parsed"]["info"]["tokenAmount"]["amount"]) == 1
|
||||
]
|
||||
|
||||
print_json({
|
||||
"address": args.address,
|
||||
"nft_count": len(nfts),
|
||||
"nfts": nfts,
|
||||
"note": "Heuristic only. Compressed NFTs (cNFTs) are not detected.",
|
||||
})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. Whale Detector (enhanced with USD values)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def cmd_whales(args):
|
||||
"""Scan the latest block for large SOL transfers."""
|
||||
min_lamports = int(args.min_sol * LAMPORTS_PER_SOL)
|
||||
|
||||
slot = rpc("getSlot")
|
||||
block = rpc("getBlock", [
|
||||
slot,
|
||||
{
|
||||
"encoding": "jsonParsed",
|
||||
"transactionDetails": "full",
|
||||
"maxSupportedTransactionVersion": 0,
|
||||
"rewards": False,
|
||||
},
|
||||
])
|
||||
|
||||
if block is None:
|
||||
sys.exit("Could not retrieve latest block.")
|
||||
|
||||
sol_price = fetch_sol_price()
|
||||
|
||||
whales = []
|
||||
for tx in (block.get("transactions") or []):
|
||||
meta = tx.get("meta", {}) or {}
|
||||
if meta.get("err") is not None:
|
||||
continue
|
||||
|
||||
msg = tx["transaction"].get("message", {})
|
||||
account_keys = msg.get("accountKeys", [])
|
||||
pre = meta.get("preBalances", [])
|
||||
post = meta.get("postBalances", [])
|
||||
|
||||
for i in range(len(pre)):
|
||||
change = post[i] - pre[i]
|
||||
if change >= min_lamports:
|
||||
k = account_keys[i]
|
||||
receiver = k["pubkey"] if isinstance(k, dict) else k
|
||||
sender = None
|
||||
for j in range(len(pre)):
|
||||
if pre[j] - post[j] >= min_lamports:
|
||||
sk = account_keys[j]
|
||||
sender = sk["pubkey"] if isinstance(sk, dict) else sk
|
||||
break
|
||||
entry = {
|
||||
"sender": sender,
|
||||
"receiver": receiver,
|
||||
"amount_SOL": round(lamports_to_sol(change), 4),
|
||||
}
|
||||
if sol_price:
|
||||
entry["amount_USD"] = round(lamports_to_sol(change) * sol_price, 2)
|
||||
whales.append(entry)
|
||||
|
||||
out = {
|
||||
"slot": slot,
|
||||
"min_threshold_SOL": args.min_sol,
|
||||
"large_transfers": whales,
|
||||
"note": "Scans latest block only — point-in-time snapshot.",
|
||||
}
|
||||
if sol_price:
|
||||
out["sol_price_usd"] = sol_price
|
||||
print_json(out)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. Price Lookup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def cmd_price(args):
|
||||
"""Quick price lookup for a token by mint address or known symbol."""
|
||||
query = args.token
|
||||
|
||||
# Check if it's a known symbol
|
||||
mint = _SYMBOL_TO_MINT.get(query.upper(), query)
|
||||
|
||||
# Try to resolve name
|
||||
token_meta = resolve_token_name(mint)
|
||||
|
||||
# Fetch price
|
||||
prices = fetch_prices([mint])
|
||||
|
||||
out = {"query": query, "mint": mint}
|
||||
if token_meta:
|
||||
out["name"] = token_meta["name"]
|
||||
out["symbol"] = token_meta["symbol"]
|
||||
if mint in prices:
|
||||
out["price_usd"] = prices[mint]
|
||||
else:
|
||||
out["price_usd"] = None
|
||||
out["note"] = "Price not available — token may not be listed on CoinGecko."
|
||||
print_json(out)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="solana_client.py",
|
||||
description="Solana blockchain query tool for Hermes Agent",
|
||||
)
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
sub.add_parser("stats", help="Network stats: slot, epoch, TPS, supply, SOL price")
|
||||
|
||||
p_wallet = sub.add_parser("wallet", help="SOL balance + SPL tokens with USD values")
|
||||
p_wallet.add_argument("address")
|
||||
p_wallet.add_argument("--limit", type=int, default=20,
|
||||
help="Max tokens to display (default: 20)")
|
||||
p_wallet.add_argument("--all", action="store_true",
|
||||
help="Show all tokens (no limit, no dust filter)")
|
||||
p_wallet.add_argument("--no-prices", action="store_true",
|
||||
help="Skip price lookups (faster, RPC-only)")
|
||||
|
||||
p_tx = sub.add_parser("tx", help="Transaction details by signature")
|
||||
p_tx.add_argument("signature")
|
||||
|
||||
p_token = sub.add_parser("token", help="SPL token metadata, price, and top holders")
|
||||
p_token.add_argument("mint")
|
||||
|
||||
p_activity = sub.add_parser("activity", help="Recent transactions for an address")
|
||||
p_activity.add_argument("address")
|
||||
p_activity.add_argument("--limit", type=int, default=10,
|
||||
help="Number of transactions (max 25, default 10)")
|
||||
|
||||
p_nft = sub.add_parser("nft", help="NFT portfolio for a wallet")
|
||||
p_nft.add_argument("address")
|
||||
|
||||
p_whales = sub.add_parser("whales", help="Large SOL transfers in the latest block")
|
||||
p_whales.add_argument("--min-sol", type=float, default=1000.0,
|
||||
help="Minimum SOL transfer size (default: 1000)")
|
||||
|
||||
p_price = sub.add_parser("price", help="Quick price lookup by mint or symbol")
|
||||
p_price.add_argument("token", help="Mint address or known symbol (SOL, BONK, JUP, ...)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
dispatch = {
|
||||
"stats": cmd_stats,
|
||||
"wallet": cmd_wallet,
|
||||
"tx": cmd_tx,
|
||||
"token": cmd_token,
|
||||
"activity": cmd_activity,
|
||||
"nft": cmd_nft,
|
||||
"whales": cmd_whales,
|
||||
"price": cmd_price,
|
||||
}
|
||||
dispatch[args.command](args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
125
optional-skills/email/agentmail/SKILL.md
Normal file
125
optional-skills/email/agentmail/SKILL.md
Normal file
@@ -0,0 +1,125 @@
|
||||
---
|
||||
name: agentmail
|
||||
description: Give the agent its own dedicated email inbox via AgentMail. Send, receive, and manage email autonomously using agent-owned email addresses (e.g. hermes-agent@agentmail.to).
|
||||
version: 1.0.0
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [email, communication, agentmail, mcp]
|
||||
category: email
|
||||
---
|
||||
|
||||
# AgentMail — Agent-Owned Email Inboxes
|
||||
|
||||
## Requirements
|
||||
|
||||
- **AgentMail API key** (required) — sign up at https://console.agentmail.to (free tier: 3 inboxes, 3,000 emails/month; paid plans from $20/mo)
|
||||
- Node.js 18+ (for the MCP server)
|
||||
|
||||
## When to Use
|
||||
Use this skill when you need to:
|
||||
- Give the agent its own dedicated email address
|
||||
- Send emails autonomously on behalf of the agent
|
||||
- Receive and read incoming emails
|
||||
- Manage email threads and conversations
|
||||
- Sign up for services or authenticate via email
|
||||
- Communicate with other agents or humans via email
|
||||
|
||||
This is NOT for reading the user's personal email (use himalaya or Gmail for that).
|
||||
AgentMail gives the agent its own identity and inbox.
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. Get an API Key
|
||||
- Go to https://console.agentmail.to
|
||||
- Create an account and generate an API key (starts with `am_`)
|
||||
|
||||
### 2. Configure MCP Server
|
||||
Add to `~/.hermes/config.yaml` (paste your actual key — MCP env vars are not expanded from .env):
|
||||
```yaml
|
||||
mcp_servers:
|
||||
agentmail:
|
||||
command: "npx"
|
||||
args: ["-y", "agentmail-mcp"]
|
||||
env:
|
||||
AGENTMAIL_API_KEY: "am_your_key_here"
|
||||
```
|
||||
|
||||
### 3. Restart Hermes
|
||||
```bash
|
||||
hermes
|
||||
```
|
||||
All 11 AgentMail tools are now available automatically.
|
||||
|
||||
## Available Tools (via MCP)
|
||||
|
||||
| Tool | Description |
|
||||
|------|-------------|
|
||||
| `list_inboxes` | List all agent inboxes |
|
||||
| `get_inbox` | Get details of a specific inbox |
|
||||
| `create_inbox` | Create a new inbox (gets a real email address) |
|
||||
| `delete_inbox` | Delete an inbox |
|
||||
| `list_threads` | List email threads in an inbox |
|
||||
| `get_thread` | Get a specific email thread |
|
||||
| `send_message` | Send a new email |
|
||||
| `reply_to_message` | Reply to an existing email |
|
||||
| `forward_message` | Forward an email |
|
||||
| `update_message` | Update message labels/status |
|
||||
| `get_attachment` | Download an email attachment |
|
||||
|
||||
## Procedure
|
||||
|
||||
### Create an inbox and send an email
|
||||
1. Create a dedicated inbox:
|
||||
- Use `create_inbox` with a username (e.g. `hermes-agent`)
|
||||
- The agent gets address: `hermes-agent@agentmail.to`
|
||||
2. Send an email:
|
||||
- Use `send_message` with `inbox_id`, `to`, `subject`, `text`
|
||||
3. Check for replies:
|
||||
- Use `list_threads` to see incoming conversations
|
||||
- Use `get_thread` to read a specific thread
|
||||
|
||||
### Check incoming email
|
||||
1. Use `list_inboxes` to find your inbox ID
|
||||
2. Use `list_threads` with the inbox ID to see conversations
|
||||
3. Use `get_thread` to read a thread and its messages
|
||||
|
||||
### Reply to an email
|
||||
1. Get the thread with `get_thread`
|
||||
2. Use `reply_to_message` with the message ID and your reply text
|
||||
|
||||
## Example Workflows
|
||||
|
||||
**Sign up for a service:**
|
||||
```
|
||||
1. create_inbox (username: "signup-bot")
|
||||
2. Use the inbox address to register on the service
|
||||
3. list_threads to check for verification email
|
||||
4. get_thread to read the verification code
|
||||
```
|
||||
|
||||
**Agent-to-human outreach:**
|
||||
```
|
||||
1. create_inbox (username: "hermes-outreach")
|
||||
2. send_message (to: user@example.com, subject: "Hello", text: "...")
|
||||
3. list_threads to check for replies
|
||||
```
|
||||
|
||||
## Pitfalls
|
||||
- Free tier limited to 3 inboxes and 3,000 emails/month
|
||||
- Emails come from `@agentmail.to` domain on free tier (custom domains on paid plans)
|
||||
- Node.js (18+) is required for the MCP server (`npx -y agentmail-mcp`)
|
||||
- The `mcp` Python package must be installed: `pip install mcp`
|
||||
- Real-time inbound email (webhooks) requires a public server — use `list_threads` polling via cronjob instead for personal use
|
||||
|
||||
## Verification
|
||||
After setup, test with:
|
||||
```
|
||||
hermes --toolsets mcp -q "Create an AgentMail inbox called test-agent and tell me its email address"
|
||||
```
|
||||
You should see the new inbox address returned.
|
||||
|
||||
## References
|
||||
- AgentMail docs: https://docs.agentmail.to/
|
||||
- AgentMail console: https://console.agentmail.to
|
||||
- AgentMail MCP repo: https://github.com/agentmail-to/agentmail-mcp
|
||||
- Pricing: https://www.agentmail.to/pricing
|
||||
441
optional-skills/research/qmd/SKILL.md
Normal file
441
optional-skills/research/qmd/SKILL.md
Normal file
@@ -0,0 +1,441 @@
|
||||
---
|
||||
name: qmd
|
||||
description: Search personal knowledge bases, notes, docs, and meeting transcripts locally using qmd — a hybrid retrieval engine with BM25, vector search, and LLM reranking. Supports CLI and MCP integration.
|
||||
version: 1.0.0
|
||||
author: Hermes Agent + Teknium
|
||||
license: MIT
|
||||
platforms: [macos, linux]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Search, Knowledge-Base, RAG, Notes, MCP, Local-AI]
|
||||
related_skills: [obsidian, native-mcp, arxiv]
|
||||
---
|
||||
|
||||
# QMD — Query Markup Documents
|
||||
|
||||
Local, on-device search engine for personal knowledge bases. Indexes markdown
|
||||
notes, meeting transcripts, documentation, and any text-based files, then
|
||||
provides hybrid search combining keyword matching, semantic understanding, and
|
||||
LLM-powered reranking — all running locally with no cloud dependencies.
|
||||
|
||||
Created by [Tobi Lütke](https://github.com/tobi/qmd). MIT licensed.
|
||||
|
||||
## When to Use
|
||||
|
||||
- User asks to search their notes, docs, knowledge base, or meeting transcripts
|
||||
- User wants to find something across a large collection of markdown/text files
|
||||
- User wants semantic search ("find notes about X concept") not just keyword grep
|
||||
- User has already set up qmd collections and wants to query them
|
||||
- User asks to set up a local knowledge base or document search system
|
||||
- Keywords: "search my notes", "find in my docs", "knowledge base", "qmd"
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### Node.js >= 22 (required)
|
||||
|
||||
```bash
|
||||
# Check version
|
||||
node --version # must be >= 22
|
||||
|
||||
# macOS — install or upgrade via Homebrew
|
||||
brew install node@22
|
||||
|
||||
# Linux — use NodeSource or nvm
|
||||
curl -fsSL https://deb.nodesource.com/setup_22.x | sudo -E bash -
|
||||
sudo apt-get install -y nodejs
|
||||
# or with nvm:
|
||||
nvm install 22 && nvm use 22
|
||||
```
|
||||
|
||||
### SQLite with Extension Support (macOS only)
|
||||
|
||||
macOS system SQLite lacks extension loading. Install via Homebrew:
|
||||
|
||||
```bash
|
||||
brew install sqlite
|
||||
```
|
||||
|
||||
### Install qmd
|
||||
|
||||
```bash
|
||||
npm install -g @tobilu/qmd
|
||||
# or with Bun:
|
||||
bun install -g @tobilu/qmd
|
||||
```
|
||||
|
||||
First run auto-downloads 3 local GGUF models (~2GB total):
|
||||
|
||||
| Model | Purpose | Size |
|
||||
|-------|---------|------|
|
||||
| embeddinggemma-300M-Q8_0 | Vector embeddings | ~300MB |
|
||||
| qwen3-reranker-0.6b-q8_0 | Result reranking | ~640MB |
|
||||
| qmd-query-expansion-1.7B | Query expansion | ~1.1GB |
|
||||
|
||||
### Verify Installation
|
||||
|
||||
```bash
|
||||
qmd --version
|
||||
qmd status
|
||||
```
|
||||
|
||||
## Quick Reference
|
||||
|
||||
| Command | What It Does | Speed |
|
||||
|---------|-------------|-------|
|
||||
| `qmd search "query"` | BM25 keyword search (no models) | ~0.2s |
|
||||
| `qmd vsearch "query"` | Semantic vector search (1 model) | ~3s |
|
||||
| `qmd query "query"` | Hybrid + reranking (all 3 models) | ~2-3s warm, ~19s cold |
|
||||
| `qmd get <docid>` | Retrieve full document content | instant |
|
||||
| `qmd multi-get "glob"` | Retrieve multiple files | instant |
|
||||
| `qmd collection add <path> --name <n>` | Add a directory as a collection | instant |
|
||||
| `qmd context add <path> "description"` | Add context metadata to improve retrieval | instant |
|
||||
| `qmd embed` | Generate/update vector embeddings | varies |
|
||||
| `qmd status` | Show index health and collection info | instant |
|
||||
| `qmd mcp` | Start MCP server (stdio) | persistent |
|
||||
| `qmd mcp --http --daemon` | Start MCP server (HTTP, warm models) | persistent |
|
||||
|
||||
## Setup Workflow
|
||||
|
||||
### 1. Add Collections
|
||||
|
||||
Point qmd at directories containing your documents:
|
||||
|
||||
```bash
|
||||
# Add a notes directory
|
||||
qmd collection add ~/notes --name notes
|
||||
|
||||
# Add project docs
|
||||
qmd collection add ~/projects/myproject/docs --name project-docs
|
||||
|
||||
# Add meeting transcripts
|
||||
qmd collection add ~/meetings --name meetings
|
||||
|
||||
# List all collections
|
||||
qmd collection list
|
||||
```
|
||||
|
||||
### 2. Add Context Descriptions
|
||||
|
||||
Context metadata helps the search engine understand what each collection
|
||||
contains. This significantly improves retrieval quality:
|
||||
|
||||
```bash
|
||||
qmd context add qmd://notes "Personal notes, ideas, and journal entries"
|
||||
qmd context add qmd://project-docs "Technical documentation for the main project"
|
||||
qmd context add qmd://meetings "Meeting transcripts and action items from team syncs"
|
||||
```
|
||||
|
||||
### 3. Generate Embeddings
|
||||
|
||||
```bash
|
||||
qmd embed
|
||||
```
|
||||
|
||||
This processes all documents in all collections and generates vector
|
||||
embeddings. Re-run after adding new documents or collections.
|
||||
|
||||
### 4. Verify
|
||||
|
||||
```bash
|
||||
qmd status # shows index health, collection stats, model info
|
||||
```
|
||||
|
||||
## Search Patterns
|
||||
|
||||
### Fast Keyword Search (BM25)
|
||||
|
||||
Best for: exact terms, code identifiers, names, known phrases.
|
||||
No models loaded — near-instant results.
|
||||
|
||||
```bash
|
||||
qmd search "authentication middleware"
|
||||
qmd search "handleError async"
|
||||
```
|
||||
|
||||
### Semantic Vector Search
|
||||
|
||||
Best for: natural language questions, conceptual queries.
|
||||
Loads embedding model (~3s first query).
|
||||
|
||||
```bash
|
||||
qmd vsearch "how does the rate limiter handle burst traffic"
|
||||
qmd vsearch "ideas for improving onboarding flow"
|
||||
```
|
||||
|
||||
### Hybrid Search with Reranking (Best Quality)
|
||||
|
||||
Best for: important queries where quality matters most.
|
||||
Uses all 3 models — query expansion, parallel BM25+vector, reranking.
|
||||
|
||||
```bash
|
||||
qmd query "what decisions were made about the database migration"
|
||||
```
|
||||
|
||||
### Structured Multi-Mode Queries
|
||||
|
||||
Combine different search types in a single query for precision:
|
||||
|
||||
```bash
|
||||
# BM25 for exact term + vector for concept
|
||||
qmd query $'lex: rate limiter\nvec: how does throttling work under load'
|
||||
|
||||
# With query expansion
|
||||
qmd query $'expand: database migration plan\nlex: "schema change"'
|
||||
```
|
||||
|
||||
### Query Syntax (lex/BM25 mode)
|
||||
|
||||
| Syntax | Effect | Example |
|
||||
|--------|--------|---------|
|
||||
| `term` | Prefix match | `perf` matches "performance" |
|
||||
| `"phrase"` | Exact phrase | `"rate limiter"` |
|
||||
| `-term` | Exclude term | `performance -sports` |
|
||||
|
||||
### HyDE (Hypothetical Document Embeddings)
|
||||
|
||||
For complex topics, write what you expect the answer to look like:
|
||||
|
||||
```bash
|
||||
qmd query $'hyde: The migration plan involves three phases. First, we add the new columns without dropping the old ones. Then we backfill data. Finally we cut over and remove legacy columns.'
|
||||
```
|
||||
|
||||
### Scoping to Collections
|
||||
|
||||
```bash
|
||||
qmd search "query" --collection notes
|
||||
qmd query "query" --collection project-docs
|
||||
```
|
||||
|
||||
### Output Formats
|
||||
|
||||
```bash
|
||||
qmd search "query" --json # JSON output (best for parsing)
|
||||
qmd search "query" --limit 5 # Limit results
|
||||
qmd get "#abc123" # Get by document ID
|
||||
qmd get "path/to/file.md" # Get by file path
|
||||
qmd get "file.md:50" -l 100 # Get specific line range
|
||||
qmd multi-get "journals/*.md" --json # Batch retrieve by glob
|
||||
```
|
||||
|
||||
## MCP Integration (Recommended)
|
||||
|
||||
qmd exposes an MCP server that provides search tools directly to
|
||||
Hermes Agent via the native MCP client. This is the preferred
|
||||
integration — once configured, the agent gets qmd tools automatically
|
||||
without needing to load this skill.
|
||||
|
||||
### Option A: Stdio Mode (Simple)
|
||||
|
||||
Add to `~/.hermes/config.yaml`:
|
||||
|
||||
```yaml
|
||||
mcp_servers:
|
||||
qmd:
|
||||
command: "qmd"
|
||||
args: ["mcp"]
|
||||
timeout: 30
|
||||
connect_timeout: 45
|
||||
```
|
||||
|
||||
This registers tools: `mcp_qmd_search`, `mcp_qmd_vsearch`,
|
||||
`mcp_qmd_deep_search`, `mcp_qmd_get`, `mcp_qmd_status`.
|
||||
|
||||
**Tradeoff:** Models load on first search call (~19s cold start),
|
||||
then stay warm for the session. Acceptable for occasional use.
|
||||
|
||||
### Option B: HTTP Daemon Mode (Fast, Recommended for Heavy Use)
|
||||
|
||||
Start the qmd daemon separately — it keeps models warm in memory:
|
||||
|
||||
```bash
|
||||
# Start daemon (persists across agent restarts)
|
||||
qmd mcp --http --daemon
|
||||
|
||||
# Runs on http://localhost:8181 by default
|
||||
```
|
||||
|
||||
Then configure Hermes Agent to connect via HTTP:
|
||||
|
||||
```yaml
|
||||
mcp_servers:
|
||||
qmd:
|
||||
url: "http://localhost:8181/mcp"
|
||||
timeout: 30
|
||||
```
|
||||
|
||||
**Tradeoff:** Uses ~2GB RAM while running, but every query is fast
|
||||
(~2-3s). Best for users who search frequently.
|
||||
|
||||
### Keeping the Daemon Running
|
||||
|
||||
#### macOS (launchd)
|
||||
|
||||
```bash
|
||||
cat > ~/Library/LaunchAgents/com.qmd.daemon.plist << 'EOF'
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN"
|
||||
"http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>com.qmd.daemon</string>
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>qmd</string>
|
||||
<string>mcp</string>
|
||||
<string>--http</string>
|
||||
<string>--daemon</string>
|
||||
</array>
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
<key>KeepAlive</key>
|
||||
<true/>
|
||||
<key>StandardOutPath</key>
|
||||
<string>/tmp/qmd-daemon.log</string>
|
||||
<key>StandardErrorPath</key>
|
||||
<string>/tmp/qmd-daemon.log</string>
|
||||
</dict>
|
||||
</plist>
|
||||
EOF
|
||||
|
||||
launchctl load ~/Library/LaunchAgents/com.qmd.daemon.plist
|
||||
```
|
||||
|
||||
#### Linux (systemd user service)
|
||||
|
||||
```bash
|
||||
mkdir -p ~/.config/systemd/user
|
||||
|
||||
cat > ~/.config/systemd/user/qmd-daemon.service << 'EOF'
|
||||
[Unit]
|
||||
Description=QMD MCP Daemon
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
ExecStart=qmd mcp --http --daemon
|
||||
Restart=on-failure
|
||||
RestartSec=10
|
||||
Environment=PATH=/usr/local/bin:/usr/bin:/bin
|
||||
|
||||
[Install]
|
||||
WantedBy=default.target
|
||||
EOF
|
||||
|
||||
systemctl --user daemon-reload
|
||||
systemctl --user enable --now qmd-daemon
|
||||
systemctl --user status qmd-daemon
|
||||
```
|
||||
|
||||
### MCP Tools Reference
|
||||
|
||||
Once connected, these tools are available as `mcp_qmd_*`:
|
||||
|
||||
| MCP Tool | Maps To | Description |
|
||||
|----------|---------|-------------|
|
||||
| `mcp_qmd_search` | `qmd search` | BM25 keyword search |
|
||||
| `mcp_qmd_vsearch` | `qmd vsearch` | Semantic vector search |
|
||||
| `mcp_qmd_deep_search` | `qmd query` | Hybrid search + reranking |
|
||||
| `mcp_qmd_get` | `qmd get` | Retrieve document by ID or path |
|
||||
| `mcp_qmd_status` | `qmd status` | Index health and stats |
|
||||
|
||||
The MCP tools accept structured JSON queries for multi-mode search:
|
||||
|
||||
```json
|
||||
{
|
||||
"searches": [
|
||||
{"type": "lex", "query": "authentication middleware"},
|
||||
{"type": "vec", "query": "how user login is verified"}
|
||||
],
|
||||
"collections": ["project-docs"],
|
||||
"limit": 10
|
||||
}
|
||||
```
|
||||
|
||||
## CLI Usage (Without MCP)
|
||||
|
||||
When MCP is not configured, use qmd directly via terminal:
|
||||
|
||||
```
|
||||
terminal(command="qmd query 'what was decided about the API redesign' --json", timeout=30)
|
||||
```
|
||||
|
||||
For setup and management tasks, always use terminal:
|
||||
|
||||
```
|
||||
terminal(command="qmd collection add ~/Documents/notes --name notes")
|
||||
terminal(command="qmd context add qmd://notes 'Personal research notes and ideas'")
|
||||
terminal(command="qmd embed")
|
||||
terminal(command="qmd status")
|
||||
```
|
||||
|
||||
## How the Search Pipeline Works
|
||||
|
||||
Understanding the internals helps choose the right search mode:
|
||||
|
||||
1. **Query Expansion** — A fine-tuned 1.7B model generates 2 alternative
|
||||
queries. The original gets 2x weight in fusion.
|
||||
2. **Parallel Retrieval** — BM25 (SQLite FTS5) and vector search run
|
||||
simultaneously across all query variants.
|
||||
3. **RRF Fusion** — Reciprocal Rank Fusion (k=60) merges results.
|
||||
Top-rank bonus: #1 gets +0.05, #2-3 get +0.02.
|
||||
4. **LLM Reranking** — qwen3-reranker scores top 30 candidates (0.0-1.0).
|
||||
5. **Position-Aware Blending** — Ranks 1-3: 75% retrieval / 25% reranker.
|
||||
Ranks 4-10: 60/40. Ranks 11+: 40/60 (trusts reranker more for long tail).
|
||||
|
||||
**Smart Chunking:** Documents are split at natural break points (headings,
|
||||
code blocks, blank lines) targeting ~900 tokens with 15% overlap. Code
|
||||
blocks are never split mid-block.
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Always add context descriptions** — `qmd context add` dramatically
|
||||
improves retrieval accuracy. Describe what each collection contains.
|
||||
2. **Re-embed after adding documents** — `qmd embed` must be re-run when
|
||||
new files are added to collections.
|
||||
3. **Use `qmd search` for speed** — when you need fast keyword lookup
|
||||
(code identifiers, exact names), BM25 is instant and needs no models.
|
||||
4. **Use `qmd query` for quality** — when the question is conceptual or
|
||||
the user needs the best possible results, use hybrid search.
|
||||
5. **Prefer MCP integration** — once configured, the agent gets native
|
||||
tools without needing to load this skill each time.
|
||||
6. **Daemon mode for frequent users** — if the user searches their
|
||||
knowledge base regularly, recommend the HTTP daemon setup.
|
||||
7. **First query in structured search gets 2x weight** — put the most
|
||||
important/certain query first when combining lex and vec.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Models downloading on first run"
|
||||
Normal — qmd auto-downloads ~2GB of GGUF models on first use.
|
||||
This is a one-time operation.
|
||||
|
||||
### Cold start latency (~19s)
|
||||
This happens when models aren't loaded in memory. Solutions:
|
||||
- Use HTTP daemon mode (`qmd mcp --http --daemon`) to keep warm
|
||||
- Use `qmd search` (BM25 only) when models aren't needed
|
||||
- MCP stdio mode loads models on first search, stays warm for session
|
||||
|
||||
### macOS: "unable to load extension"
|
||||
Install Homebrew SQLite: `brew install sqlite`
|
||||
Then ensure it's on PATH before system SQLite.
|
||||
|
||||
### "No collections found"
|
||||
Run `qmd collection add <path> --name <name>` to add directories,
|
||||
then `qmd embed` to index them.
|
||||
|
||||
### Embedding model override (CJK/multilingual)
|
||||
Set `QMD_EMBED_MODEL` environment variable for non-English content:
|
||||
```bash
|
||||
export QMD_EMBED_MODEL="your-multilingual-model"
|
||||
```
|
||||
|
||||
## Data Storage
|
||||
|
||||
- **Index & vectors:** `~/.cache/qmd/index.sqlite`
|
||||
- **Models:** Auto-downloaded to local cache on first run
|
||||
- **No cloud dependencies** — everything runs locally
|
||||
|
||||
## References
|
||||
|
||||
- [GitHub: tobi/qmd](https://github.com/tobi/qmd)
|
||||
- [QMD Changelog](https://github.com/tobi/qmd/blob/main/CHANGELOG.md)
|
||||
@@ -50,6 +50,7 @@ pty = ["ptyprocess>=0.7.0"]
|
||||
honcho = ["honcho-ai>=2.0.1"]
|
||||
mcp = ["mcp>=1.2.0"]
|
||||
homeassistant = ["aiohttp>=3.9.0"]
|
||||
yc-bench = ["yc-bench @ git+https://github.com/collinear-ai/yc-bench.git"]
|
||||
all = [
|
||||
"hermes-agent[modal]",
|
||||
"hermes-agent[daytona]",
|
||||
|
||||
471
run_agent.py
471
run_agent.py
@@ -99,6 +99,46 @@ from agent.trajectory import (
|
||||
)
|
||||
|
||||
|
||||
class IterationBudget:
|
||||
"""Thread-safe shared iteration counter for parent and child agents.
|
||||
|
||||
Tracks total LLM-call iterations consumed across a parent agent and all
|
||||
its subagents. A single ``IterationBudget`` is created by the parent
|
||||
and passed to every child so they share the same cap.
|
||||
|
||||
``execute_code`` (programmatic tool calling) iterations are refunded via
|
||||
:meth:`refund` so they don't eat into the budget.
|
||||
"""
|
||||
|
||||
def __init__(self, max_total: int):
|
||||
self.max_total = max_total
|
||||
self._used = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def consume(self) -> bool:
|
||||
"""Try to consume one iteration. Returns True if allowed."""
|
||||
with self._lock:
|
||||
if self._used >= self.max_total:
|
||||
return False
|
||||
self._used += 1
|
||||
return True
|
||||
|
||||
def refund(self) -> None:
|
||||
"""Give back one iteration (e.g. for execute_code turns)."""
|
||||
with self._lock:
|
||||
if self._used > 0:
|
||||
self._used -= 1
|
||||
|
||||
@property
|
||||
def used(self) -> int:
|
||||
return self._used
|
||||
|
||||
@property
|
||||
def remaining(self) -> int:
|
||||
with self._lock:
|
||||
return max(0, self.max_total - self._used)
|
||||
|
||||
|
||||
class AIAgent:
|
||||
"""
|
||||
AI Agent with tool calling capabilities.
|
||||
@@ -114,7 +154,7 @@ class AIAgent:
|
||||
provider: str = None,
|
||||
api_mode: str = None,
|
||||
model: str = "anthropic/claude-opus-4.6", # OpenRouter format
|
||||
max_iterations: int = 60, # Default tool-calling iterations
|
||||
max_iterations: int = 90, # Default tool-calling iterations (shared with subagents)
|
||||
tool_delay: float = 1.0,
|
||||
enabled_toolsets: List[str] = None,
|
||||
disabled_toolsets: List[str] = None,
|
||||
@@ -142,6 +182,8 @@ class AIAgent:
|
||||
skip_memory: bool = False,
|
||||
session_db=None,
|
||||
honcho_session_key: str = None,
|
||||
iteration_budget: "IterationBudget" = None,
|
||||
fallback_model: Dict[str, Any] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the AI Agent.
|
||||
@@ -152,7 +194,7 @@ class AIAgent:
|
||||
provider (str): Provider identifier (optional; used for telemetry/routing hints)
|
||||
api_mode (str): API mode override: "chat_completions" or "codex_responses"
|
||||
model (str): Model name to use (default: "anthropic/claude-opus-4.6")
|
||||
max_iterations (int): Maximum number of tool calling iterations (default: 60)
|
||||
max_iterations (int): Maximum number of tool calling iterations (default: 90)
|
||||
tool_delay (float): Delay between tool calls in seconds (default: 1.0)
|
||||
enabled_toolsets (List[str]): Only enable tools from these toolsets (optional)
|
||||
disabled_toolsets (List[str]): Disable tools from these toolsets (optional)
|
||||
@@ -172,7 +214,7 @@ class AIAgent:
|
||||
Provided by the platform layer (CLI or gateway). If None, the clarify tool returns an error.
|
||||
max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
|
||||
reasoning_config (Dict): OpenRouter reasoning configuration override (e.g. {"effort": "none"} to disable thinking).
|
||||
If None, defaults to {"enabled": True, "effort": "xhigh"} for OpenRouter. Set to disable/customize reasoning.
|
||||
If None, defaults to {"enabled": True, "effort": "medium"} for OpenRouter. Set to disable/customize reasoning.
|
||||
prefill_messages (List[Dict]): Messages to prepend to conversation history as prefilled context.
|
||||
Useful for injecting a few-shot example or priming the model's response style.
|
||||
Example: [{"role": "user", "content": "Hi!"}, {"role": "assistant", "content": "Hello!"}]
|
||||
@@ -186,6 +228,9 @@ class AIAgent:
|
||||
"""
|
||||
self.model = model
|
||||
self.max_iterations = max_iterations
|
||||
# Shared iteration budget — parent creates, children inherit.
|
||||
# Consumed by every LLM turn across parent + all subagents.
|
||||
self.iteration_budget = iteration_budget or IterationBudget(max_iterations)
|
||||
self.tool_delay = tool_delay
|
||||
self.save_trajectories = save_trajectories
|
||||
self.verbose_logging = verbose_logging
|
||||
@@ -209,13 +254,7 @@ class AIAgent:
|
||||
self.provider = "openai-codex"
|
||||
else:
|
||||
self.api_mode = "chat_completions"
|
||||
if base_url and "api.anthropic.com" in base_url.strip().lower():
|
||||
raise ValueError(
|
||||
"Anthropic's native /v1/messages API is not supported yet (planned for a future release). "
|
||||
"Hermes currently requires OpenAI-compatible /chat/completions endpoints. "
|
||||
"To use Claude models now, route through OpenRouter (OPENROUTER_API_KEY) "
|
||||
"or any OpenAI-compatible proxy that wraps the Anthropic API."
|
||||
)
|
||||
|
||||
self.tool_progress_callback = tool_progress_callback
|
||||
self.clarify_callback = clarify_callback
|
||||
self.step_callback = step_callback
|
||||
@@ -243,7 +282,7 @@ class AIAgent:
|
||||
|
||||
# Model response configuration
|
||||
self.max_tokens = max_tokens # None = use model default
|
||||
self.reasoning_config = reasoning_config # None = use default (xhigh for OpenRouter)
|
||||
self.reasoning_config = reasoning_config # None = use default (medium for OpenRouter)
|
||||
self.prefill_messages = prefill_messages or [] # Prefilled conversation turns
|
||||
|
||||
# Anthropic prompt caching: auto-enabled for Claude models via OpenRouter.
|
||||
@@ -345,6 +384,12 @@ class AIAgent:
|
||||
"X-OpenRouter-Title": "Hermes Agent",
|
||||
"X-OpenRouter-Categories": "productivity,cli-agent",
|
||||
}
|
||||
elif "api.kimi.com" in effective_base.lower():
|
||||
# Kimi Code API requires a recognized coding-agent User-Agent
|
||||
# (see https://github.com/MoonshotAI/kimi-cli)
|
||||
client_kwargs["default_headers"] = {
|
||||
"User-Agent": "KimiCLI/1.0",
|
||||
}
|
||||
|
||||
self._client_kwargs = client_kwargs # stored for rebuilding after interrupt
|
||||
try:
|
||||
@@ -362,6 +407,17 @@ class AIAgent:
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initialize OpenAI client: {e}")
|
||||
|
||||
# Provider fallback — a single backup model/provider tried when the
|
||||
# primary is exhausted (rate-limit, overload, connection failure).
|
||||
# Config shape: {"provider": "openrouter", "model": "anthropic/claude-sonnet-4"}
|
||||
self._fallback_model = fallback_model if isinstance(fallback_model, dict) else None
|
||||
self._fallback_activated = False
|
||||
if self._fallback_model:
|
||||
fb_p = self._fallback_model.get("provider", "")
|
||||
fb_m = self._fallback_model.get("model", "")
|
||||
if fb_p and fb_m and not self.quiet_mode:
|
||||
print(f"🔄 Fallback model: {fb_m} ({fb_p})")
|
||||
|
||||
# Get available tools with filtering
|
||||
self.tools = get_tool_definitions(
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
@@ -1363,7 +1419,8 @@ class AIAgent:
|
||||
if context_files_prompt:
|
||||
prompt_parts.append(context_files_prompt)
|
||||
|
||||
now = datetime.now()
|
||||
from hermes_time import now as _hermes_now
|
||||
now = _hermes_now()
|
||||
prompt_parts.append(
|
||||
f"Conversation started: {now.strftime('%A, %B %d, %Y %I:%M %p')}"
|
||||
)
|
||||
@@ -2018,6 +2075,49 @@ class AIAgent:
|
||||
|
||||
return True
|
||||
|
||||
def _try_refresh_nous_client_credentials(self, *, force: bool = True) -> bool:
|
||||
if self.api_mode != "chat_completions" or self.provider != "nous":
|
||||
return False
|
||||
|
||||
try:
|
||||
from hermes_cli.auth import resolve_nous_runtime_credentials
|
||||
|
||||
creds = resolve_nous_runtime_credentials(
|
||||
min_key_ttl_seconds=max(60, int(os.getenv("HERMES_NOUS_MIN_KEY_TTL_SECONDS", "1800"))),
|
||||
timeout_seconds=float(os.getenv("HERMES_NOUS_TIMEOUT_SECONDS", "15")),
|
||||
force_mint=force,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("Nous credential refresh failed: %s", exc)
|
||||
return False
|
||||
|
||||
api_key = creds.get("api_key")
|
||||
base_url = creds.get("base_url")
|
||||
if not isinstance(api_key, str) or not api_key.strip():
|
||||
return False
|
||||
if not isinstance(base_url, str) or not base_url.strip():
|
||||
return False
|
||||
|
||||
self.api_key = api_key.strip()
|
||||
self.base_url = base_url.strip().rstrip("/")
|
||||
self._client_kwargs["api_key"] = self.api_key
|
||||
self._client_kwargs["base_url"] = self.base_url
|
||||
# Nous requests should not inherit OpenRouter-only attribution headers.
|
||||
self._client_kwargs.pop("default_headers", None)
|
||||
|
||||
try:
|
||||
self.client.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
self.client = OpenAI(**self._client_kwargs)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to rebuild OpenAI client after Nous refresh: %s", exc)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _interruptible_api_call(self, api_kwargs: dict):
|
||||
"""
|
||||
Run the API call in a background thread so the main conversation loop
|
||||
@@ -2058,6 +2158,141 @@ class AIAgent:
|
||||
raise result["error"]
|
||||
return result["response"]
|
||||
|
||||
# ── Provider fallback ──────────────────────────────────────────────────
|
||||
|
||||
# API-key providers: provider → (base_url, [env_var_names])
|
||||
_FALLBACK_API_KEY_PROVIDERS = {
|
||||
"openrouter": (OPENROUTER_BASE_URL, ["OPENROUTER_API_KEY"]),
|
||||
"zai": ("https://api.z.ai/api/paas/v4", ["ZAI_API_KEY", "Z_AI_API_KEY"]),
|
||||
"kimi-coding": ("https://api.moonshot.ai/v1", ["KIMI_API_KEY"]),
|
||||
"minimax": ("https://api.minimax.io/v1", ["MINIMAX_API_KEY"]),
|
||||
"minimax-cn": ("https://api.minimaxi.com/v1", ["MINIMAX_CN_API_KEY"]),
|
||||
}
|
||||
|
||||
# OAuth providers: provider → (resolver_import_path, api_mode)
|
||||
# Each resolver returns {"api_key": ..., "base_url": ...}.
|
||||
_FALLBACK_OAUTH_PROVIDERS = {
|
||||
"openai-codex": ("resolve_codex_runtime_credentials", "codex_responses"),
|
||||
"nous": ("resolve_nous_runtime_credentials", "chat_completions"),
|
||||
}
|
||||
|
||||
def _resolve_fallback_credentials(
|
||||
self, fb_provider: str, fb_config: dict
|
||||
) -> Optional[tuple]:
|
||||
"""Resolve credentials for a fallback provider.
|
||||
|
||||
Returns (api_key, base_url, api_mode) on success, or None on failure.
|
||||
Handles three cases:
|
||||
1. OAuth providers (openai-codex, nous) — call credential resolver
|
||||
2. API-key providers (openrouter, zai, etc.) — read env var
|
||||
3. Custom endpoints — use base_url + api_key_env from config
|
||||
"""
|
||||
# ── 1. OAuth providers ────────────────────────────────────────
|
||||
if fb_provider in self._FALLBACK_OAUTH_PROVIDERS:
|
||||
resolver_name, api_mode = self._FALLBACK_OAUTH_PROVIDERS[fb_provider]
|
||||
try:
|
||||
import hermes_cli.auth as _auth
|
||||
resolver = getattr(_auth, resolver_name)
|
||||
creds = resolver()
|
||||
return creds["api_key"], creds["base_url"], api_mode
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
"Fallback to %s failed (credential resolution): %s",
|
||||
fb_provider, e,
|
||||
)
|
||||
return None
|
||||
|
||||
# ── 2. API-key providers ──────────────────────────────────────
|
||||
fb_key = (fb_config.get("api_key") or "").strip()
|
||||
if not fb_key:
|
||||
key_env = (fb_config.get("api_key_env") or "").strip()
|
||||
if key_env:
|
||||
fb_key = os.getenv(key_env, "")
|
||||
elif fb_provider in self._FALLBACK_API_KEY_PROVIDERS:
|
||||
for env_var in self._FALLBACK_API_KEY_PROVIDERS[fb_provider][1]:
|
||||
fb_key = os.getenv(env_var, "")
|
||||
if fb_key:
|
||||
break
|
||||
if not fb_key:
|
||||
logging.warning(
|
||||
"Fallback model configured but no API key found for provider '%s'",
|
||||
fb_provider,
|
||||
)
|
||||
return None
|
||||
|
||||
# ── 3. Resolve base URL ───────────────────────────────────────
|
||||
fb_base_url = (fb_config.get("base_url") or "").strip()
|
||||
if not fb_base_url and fb_provider in self._FALLBACK_API_KEY_PROVIDERS:
|
||||
fb_base_url = self._FALLBACK_API_KEY_PROVIDERS[fb_provider][0]
|
||||
if not fb_base_url:
|
||||
fb_base_url = OPENROUTER_BASE_URL
|
||||
|
||||
return fb_key, fb_base_url, "chat_completions"
|
||||
|
||||
def _try_activate_fallback(self) -> bool:
|
||||
"""Switch to the configured fallback model/provider.
|
||||
|
||||
Called when the primary model is failing after retries. Swaps the
|
||||
OpenAI client, model slug, and provider in-place so the retry loop
|
||||
can continue with the new backend. One-shot: returns False if
|
||||
already activated or not configured.
|
||||
"""
|
||||
if self._fallback_activated or not self._fallback_model:
|
||||
return False
|
||||
|
||||
fb = self._fallback_model
|
||||
fb_provider = (fb.get("provider") or "").strip().lower()
|
||||
fb_model = (fb.get("model") or "").strip()
|
||||
if not fb_provider or not fb_model:
|
||||
return False
|
||||
|
||||
resolved = self._resolve_fallback_credentials(fb_provider, fb)
|
||||
if resolved is None:
|
||||
return False
|
||||
fb_key, fb_base_url, fb_api_mode = resolved
|
||||
|
||||
# Build new client
|
||||
try:
|
||||
client_kwargs = {"api_key": fb_key, "base_url": fb_base_url}
|
||||
if "openrouter" in fb_base_url.lower():
|
||||
client_kwargs["default_headers"] = {
|
||||
"HTTP-Referer": "https://github.com/NousResearch/hermes-agent",
|
||||
"X-OpenRouter-Title": "Hermes Agent",
|
||||
"X-OpenRouter-Categories": "productivity,cli-agent",
|
||||
}
|
||||
elif "api.kimi.com" in fb_base_url.lower():
|
||||
client_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
|
||||
|
||||
self.client = OpenAI(**client_kwargs)
|
||||
self._client_kwargs = client_kwargs
|
||||
old_model = self.model
|
||||
self.model = fb_model
|
||||
self.provider = fb_provider
|
||||
self.base_url = fb_base_url
|
||||
self.api_mode = fb_api_mode
|
||||
self._fallback_activated = True
|
||||
|
||||
# Re-evaluate prompt caching for the new provider/model
|
||||
self._use_prompt_caching = (
|
||||
"openrouter" in fb_base_url.lower()
|
||||
and "claude" in fb_model.lower()
|
||||
)
|
||||
|
||||
print(
|
||||
f"{self.log_prefix}🔄 Primary model failed — switching to fallback: "
|
||||
f"{fb_model} via {fb_provider}"
|
||||
)
|
||||
logging.info(
|
||||
"Fallback activated: %s → %s (%s)",
|
||||
old_model, fb_model, fb_provider,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error("Failed to activate fallback model: %s", e)
|
||||
return False
|
||||
|
||||
# ── End provider fallback ──────────────────────────────────────────────
|
||||
|
||||
def _build_api_kwargs(self, api_messages: list) -> dict:
|
||||
"""Build the keyword arguments dict for the active API mode."""
|
||||
if self.api_mode == "codex_responses":
|
||||
@@ -2069,8 +2304,8 @@ class AIAgent:
|
||||
if not instructions:
|
||||
instructions = DEFAULT_AGENT_IDENTITY
|
||||
|
||||
# Resolve reasoning effort: config > default (xhigh)
|
||||
reasoning_effort = "xhigh"
|
||||
# Resolve reasoning effort: config > default (medium)
|
||||
reasoning_effort = "medium"
|
||||
reasoning_enabled = True
|
||||
if self.reasoning_config and isinstance(self.reasoning_config, dict):
|
||||
if self.reasoning_config.get("enabled") is False:
|
||||
@@ -2136,7 +2371,7 @@ class AIAgent:
|
||||
else:
|
||||
extra_body["reasoning"] = {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
"effort": "medium"
|
||||
}
|
||||
|
||||
# Nous Portal product attribution
|
||||
@@ -2396,6 +2631,8 @@ class AIAgent:
|
||||
|
||||
if self._session_db:
|
||||
try:
|
||||
# Propagate title to the new session with auto-numbering
|
||||
old_title = self._session_db.get_session_title(self.session_id)
|
||||
self._session_db.end_session(self.session_id, "compression")
|
||||
old_session_id = self.session_id
|
||||
self.session_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
|
||||
@@ -2405,6 +2642,13 @@ class AIAgent:
|
||||
model=self.model,
|
||||
parent_session_id=old_session_id,
|
||||
)
|
||||
# Auto-number the title for the continuation session
|
||||
if old_title:
|
||||
try:
|
||||
new_title = self._session_db.get_next_title_in_lineage(old_title)
|
||||
self._session_db.set_session_title(self.session_id, new_title)
|
||||
except (ValueError, Exception) as e:
|
||||
logger.debug("Could not propagate title on compression: %s", e)
|
||||
self._session_db.update_system_prompt(self.session_id, new_system_prompt)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB compression split failed: %s", e)
|
||||
@@ -2422,9 +2666,10 @@ class AIAgent:
|
||||
if remaining_calls:
|
||||
print(f"{self.log_prefix}⚡ Interrupt: skipping {len(remaining_calls)} tool call(s)")
|
||||
for skipped_tc in remaining_calls:
|
||||
skipped_name = skipped_tc.function.name
|
||||
skip_msg = {
|
||||
"role": "tool",
|
||||
"content": "[Tool execution cancelled - user interrupted]",
|
||||
"content": f"[Tool execution cancelled — {skipped_name} was skipped due to user interrupt]",
|
||||
"tool_call_id": skipped_tc.id,
|
||||
}
|
||||
messages.append(skip_msg)
|
||||
@@ -2531,7 +2776,6 @@ class AIAgent:
|
||||
context=function_args.get("context"),
|
||||
toolsets=function_args.get("toolsets"),
|
||||
tasks=tasks_arg,
|
||||
model=function_args.get("model"),
|
||||
max_iterations=function_args.get("max_iterations"),
|
||||
parent_agent=self,
|
||||
)
|
||||
@@ -2628,9 +2872,10 @@ class AIAgent:
|
||||
remaining = len(assistant_message.tool_calls) - i
|
||||
print(f"{self.log_prefix}⚡ Interrupt: skipping {remaining} remaining tool call(s)")
|
||||
for skipped_tc in assistant_message.tool_calls[i:]:
|
||||
skipped_name = skipped_tc.function.name
|
||||
skip_msg = {
|
||||
"role": "tool",
|
||||
"content": "[Tool execution skipped - user sent a new message]",
|
||||
"content": f"[Tool execution skipped — {skipped_name} was not started. User sent a new message]",
|
||||
"tool_call_id": skipped_tc.id
|
||||
}
|
||||
messages.append(skip_msg)
|
||||
@@ -2680,7 +2925,7 @@ class AIAgent:
|
||||
else:
|
||||
summary_extra_body["reasoning"] = {
|
||||
"enabled": True,
|
||||
"effort": "xhigh"
|
||||
"effort": "medium"
|
||||
}
|
||||
if _is_nous:
|
||||
summary_extra_body["tags"] = ["product=hermes-agent"]
|
||||
@@ -2792,13 +3037,15 @@ class AIAgent:
|
||||
# Generate unique task_id if not provided to isolate VMs between concurrent tasks
|
||||
effective_task_id = task_id or str(uuid.uuid4())
|
||||
|
||||
# Reset retry counters at the start of each conversation to prevent state leakage
|
||||
# Reset retry counters and iteration budget at the start of each turn
|
||||
# so subagent usage from a previous turn doesn't eat into the next one.
|
||||
self._invalid_tool_retries = 0
|
||||
self._invalid_json_retries = 0
|
||||
self._empty_content_retries = 0
|
||||
self._last_content_with_tools = None
|
||||
self._turns_since_memory = 0
|
||||
self._iters_since_skill = 0
|
||||
self.iteration_budget = IterationBudget(self.max_iterations)
|
||||
|
||||
# Initialize conversation (copy to avoid mutating the caller's list)
|
||||
messages = list(conversation_history) if conversation_history else []
|
||||
@@ -2845,9 +3092,14 @@ class AIAgent:
|
||||
)
|
||||
self._iters_since_skill = 0
|
||||
|
||||
# Honcho prefetch: retrieve user context for system prompt injection
|
||||
# Honcho prefetch: retrieve user context for system prompt injection.
|
||||
# Only on the FIRST turn of a session (empty history). On subsequent
|
||||
# turns the model already has all prior context in its conversation
|
||||
# history, and the Honcho context is baked into the stored system
|
||||
# prompt — re-fetching it would change the system message and break
|
||||
# Anthropic prompt caching.
|
||||
self._honcho_context = ""
|
||||
if self._honcho and self._honcho_session_key:
|
||||
if self._honcho and self._honcho_session_key and not conversation_history:
|
||||
try:
|
||||
self._honcho_context = self._honcho_prefetch(user_message)
|
||||
except Exception as e:
|
||||
@@ -2865,14 +3117,42 @@ class AIAgent:
|
||||
# Built once on first call, reused for all subsequent calls.
|
||||
# Only rebuilt after context compression events (which invalidate
|
||||
# the cache and reload memory from disk).
|
||||
#
|
||||
# For continuing sessions (gateway creates a fresh AIAgent per
|
||||
# message), we load the stored system prompt from the session DB
|
||||
# instead of rebuilding. Rebuilding would pick up memory changes
|
||||
# from disk that the model already knows about (it wrote them!),
|
||||
# producing a different system prompt and breaking the Anthropic
|
||||
# prefix cache.
|
||||
if self._cached_system_prompt is None:
|
||||
self._cached_system_prompt = self._build_system_prompt(system_message)
|
||||
# Store the system prompt snapshot in SQLite
|
||||
if self._session_db:
|
||||
stored_prompt = None
|
||||
if conversation_history and self._session_db:
|
||||
try:
|
||||
self._session_db.update_system_prompt(self.session_id, self._cached_system_prompt)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB update_system_prompt failed: %s", e)
|
||||
session_row = self._session_db.get_session(self.session_id)
|
||||
if session_row:
|
||||
stored_prompt = session_row.get("system_prompt") or None
|
||||
except Exception:
|
||||
pass # Fall through to build fresh
|
||||
|
||||
if stored_prompt:
|
||||
# Continuing session — reuse the exact system prompt from
|
||||
# the previous turn so the Anthropic cache prefix matches.
|
||||
self._cached_system_prompt = stored_prompt
|
||||
else:
|
||||
# First turn of a new session — build from scratch.
|
||||
self._cached_system_prompt = self._build_system_prompt(system_message)
|
||||
# Bake Honcho context into the prompt so it's stable for
|
||||
# the entire session (not re-fetched per turn).
|
||||
if self._honcho_context:
|
||||
self._cached_system_prompt = (
|
||||
self._cached_system_prompt + "\n\n" + self._honcho_context
|
||||
).strip()
|
||||
# Store the system prompt snapshot in SQLite
|
||||
if self._session_db:
|
||||
try:
|
||||
self._session_db.update_system_prompt(self.session_id, self._cached_system_prompt)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB update_system_prompt failed: %s", e)
|
||||
|
||||
active_system_prompt = self._cached_system_prompt
|
||||
|
||||
@@ -2930,7 +3210,7 @@ class AIAgent:
|
||||
# Clear any stale interrupt state at start
|
||||
self.clear_interrupt()
|
||||
|
||||
while api_call_count < self.max_iterations:
|
||||
while api_call_count < self.max_iterations and self.iteration_budget.remaining > 0:
|
||||
# Check for interrupt request (e.g., user sent new message)
|
||||
if self._interrupt_requested:
|
||||
interrupted = True
|
||||
@@ -2939,6 +3219,10 @@ class AIAgent:
|
||||
break
|
||||
|
||||
api_call_count += 1
|
||||
if not self.iteration_budget.consume():
|
||||
if not self.quiet_mode:
|
||||
print(f"\n⚠️ Session iteration budget exhausted ({self.iteration_budget.max_total} total across agent + subagents)")
|
||||
break
|
||||
|
||||
# Fire step_callback for gateway hooks (agent:step event)
|
||||
if self.step_callback is not None:
|
||||
@@ -2993,11 +3277,13 @@ class AIAgent:
|
||||
# Build the final system message: cached prompt + ephemeral system prompt.
|
||||
# The ephemeral part is appended here (not baked into the cached prompt)
|
||||
# so it stays out of the session DB and logs.
|
||||
# Note: Honcho context is baked into _cached_system_prompt on the first
|
||||
# turn and stored in the session DB, so it does NOT need to be injected
|
||||
# here. This keeps the system message identical across all turns in a
|
||||
# session, maximizing Anthropic prompt cache hits.
|
||||
effective_system = active_system_prompt or ""
|
||||
if self.ephemeral_system_prompt:
|
||||
effective_system = (effective_system + "\n\n" + self.ephemeral_system_prompt).strip()
|
||||
if self._honcho_context:
|
||||
effective_system = (effective_system + "\n\n" + self._honcho_context).strip()
|
||||
if effective_system:
|
||||
api_messages = [{"role": "system", "content": effective_system}] + api_messages
|
||||
|
||||
@@ -3015,6 +3301,13 @@ class AIAgent:
|
||||
if self._use_prompt_caching:
|
||||
api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl)
|
||||
|
||||
# Safety net: strip orphaned tool results / add stubs for missing
|
||||
# results before sending to the API. The compressor handles this
|
||||
# during compression, but orphans can also sneak in from session
|
||||
# loading or manual message manipulation.
|
||||
if hasattr(self, 'context_compressor') and self.context_compressor:
|
||||
api_messages = self.context_compressor._sanitize_tool_pairs(api_messages)
|
||||
|
||||
# Calculate approximate request size for logging
|
||||
total_chars = sum(len(str(msg)) for msg in api_messages)
|
||||
approx_tokens = total_chars // 4 # Rough estimate: 4 chars per token
|
||||
@@ -3043,9 +3336,13 @@ class AIAgent:
|
||||
api_start_time = time.time()
|
||||
retry_count = 0
|
||||
max_retries = 6 # Increased to allow longer backoff periods
|
||||
compression_attempts = 0
|
||||
max_compression_attempts = 3
|
||||
codex_auth_retry_attempted = False
|
||||
nous_auth_retry_attempted = False
|
||||
|
||||
finish_reason = "stop"
|
||||
response = None # Guard against UnboundLocalError if all retries fail
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
@@ -3137,6 +3434,10 @@ class AIAgent:
|
||||
print(f"{self.log_prefix} ⏱️ Response time: {api_duration:.2f}s (fast response often indicates rate limiting)")
|
||||
|
||||
if retry_count >= max_retries:
|
||||
# Try fallback before giving up
|
||||
if self._try_activate_fallback():
|
||||
retry_count = 0
|
||||
continue
|
||||
print(f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded for invalid responses. Giving up.")
|
||||
logging.error(f"{self.log_prefix}Invalid API response after {max_retries} retries.")
|
||||
self._persist_session(messages, conversation_history)
|
||||
@@ -3161,7 +3462,7 @@ class AIAgent:
|
||||
self._persist_session(messages, conversation_history)
|
||||
self.clear_interrupt()
|
||||
return {
|
||||
"final_response": "Operation interrupted.",
|
||||
"final_response": f"Operation interrupted: retrying API call after rate limit (retry {retry_count}/{max_retries}).",
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
@@ -3270,10 +3571,11 @@ class AIAgent:
|
||||
if thinking_spinner:
|
||||
thinking_spinner.stop("")
|
||||
thinking_spinner = None
|
||||
api_elapsed = time.time() - api_start_time
|
||||
print(f"{self.log_prefix}⚡ Interrupted during API call.")
|
||||
self._persist_session(messages, conversation_history)
|
||||
interrupted = True
|
||||
final_response = "Operation interrupted."
|
||||
final_response = f"Operation interrupted: waiting for model response ({api_elapsed:.1f}s elapsed)."
|
||||
break
|
||||
|
||||
except Exception as api_error:
|
||||
@@ -3293,6 +3595,16 @@ class AIAgent:
|
||||
if self._try_refresh_codex_client_credentials(force=True):
|
||||
print(f"{self.log_prefix}🔐 Codex auth refreshed after 401. Retrying request...")
|
||||
continue
|
||||
if (
|
||||
self.api_mode == "chat_completions"
|
||||
and self.provider == "nous"
|
||||
and status_code == 401
|
||||
and not nous_auth_retry_attempted
|
||||
):
|
||||
nous_auth_retry_attempted = True
|
||||
if self._try_refresh_nous_client_credentials(force=True):
|
||||
print(f"{self.log_prefix}🔐 Nous agent key refreshed after 401. Retrying request...")
|
||||
continue
|
||||
|
||||
retry_count += 1
|
||||
elapsed_time = time.time() - api_start_time
|
||||
@@ -3312,7 +3624,7 @@ class AIAgent:
|
||||
self._persist_session(messages, conversation_history)
|
||||
self.clear_interrupt()
|
||||
return {
|
||||
"final_response": "Operation interrupted.",
|
||||
"final_response": f"Operation interrupted: handling API error ({error_type}: {str(api_error)[:80]}).",
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
@@ -3331,7 +3643,19 @@ class AIAgent:
|
||||
)
|
||||
|
||||
if is_payload_too_large:
|
||||
print(f"{self.log_prefix}⚠️ Request payload too large (413) - attempting compression...")
|
||||
compression_attempts += 1
|
||||
if compression_attempts > max_compression_attempts:
|
||||
print(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached for payload-too-large error.")
|
||||
logging.error(f"{self.log_prefix}413 compression failed after {max_compression_attempts} attempts.")
|
||||
self._persist_session(messages, conversation_history)
|
||||
return {
|
||||
"messages": messages,
|
||||
"completed": False,
|
||||
"api_calls": api_call_count,
|
||||
"error": f"Request payload too large: max compression attempts ({max_compression_attempts}) reached.",
|
||||
"partial": True
|
||||
}
|
||||
print(f"{self.log_prefix}⚠️ Request payload too large (413) — compression attempt {compression_attempts}/{max_compression_attempts}...")
|
||||
|
||||
original_len = len(messages)
|
||||
messages, active_system_prompt = self._compress_context(
|
||||
@@ -3340,6 +3664,7 @@ class AIAgent:
|
||||
|
||||
if len(messages) < original_len:
|
||||
print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...")
|
||||
time.sleep(2) # Brief pause between compression retries
|
||||
continue # Retry with compressed messages
|
||||
else:
|
||||
print(f"{self.log_prefix}❌ Payload too large and cannot compress further.")
|
||||
@@ -3385,6 +3710,20 @@ class AIAgent:
|
||||
else:
|
||||
print(f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression...")
|
||||
|
||||
compression_attempts += 1
|
||||
if compression_attempts > max_compression_attempts:
|
||||
print(f"{self.log_prefix}❌ Max compression attempts ({max_compression_attempts}) reached.")
|
||||
logging.error(f"{self.log_prefix}Context compression failed after {max_compression_attempts} attempts.")
|
||||
self._persist_session(messages, conversation_history)
|
||||
return {
|
||||
"messages": messages,
|
||||
"completed": False,
|
||||
"api_calls": api_call_count,
|
||||
"error": f"Context length exceeded: max compression attempts ({max_compression_attempts}) reached.",
|
||||
"partial": True
|
||||
}
|
||||
print(f"{self.log_prefix} 🗜️ Context compression attempt {compression_attempts}/{max_compression_attempts}...")
|
||||
|
||||
original_len = len(messages)
|
||||
messages, active_system_prompt = self._compress_context(
|
||||
messages, system_message, approx_tokens=approx_tokens
|
||||
@@ -3393,6 +3732,7 @@ class AIAgent:
|
||||
if len(messages) < original_len or new_ctx and new_ctx < old_ctx:
|
||||
if len(messages) < original_len:
|
||||
print(f"{self.log_prefix} 🗜️ Compressed {original_len} → {len(messages)} messages, retrying...")
|
||||
time.sleep(2) # Brief pause between compression retries
|
||||
continue # Retry with compressed messages or new tier
|
||||
else:
|
||||
# Can't compress further and already at minimum tier
|
||||
@@ -3422,6 +3762,11 @@ class AIAgent:
|
||||
])) and not is_context_length_error
|
||||
|
||||
if is_client_error:
|
||||
# Try fallback before aborting — a different provider
|
||||
# may not have the same issue (rate limit, auth, etc.)
|
||||
if self._try_activate_fallback():
|
||||
retry_count = 0
|
||||
continue
|
||||
self._dump_api_request_debug(
|
||||
api_kwargs, reason="non_retryable_client_error", error=api_error,
|
||||
)
|
||||
@@ -3439,6 +3784,10 @@ class AIAgent:
|
||||
}
|
||||
|
||||
if retry_count >= max_retries:
|
||||
# Try fallback before giving up entirely
|
||||
if self._try_activate_fallback():
|
||||
retry_count = 0
|
||||
continue
|
||||
print(f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded. Giving up.")
|
||||
logging.error(f"{self.log_prefix}API call failed after {max_retries} retries. Last error: {api_error}")
|
||||
logging.error(f"{self.log_prefix}Request details - Messages: {len(api_messages)}, Approx tokens: {approx_tokens:,}")
|
||||
@@ -3459,7 +3808,7 @@ class AIAgent:
|
||||
self._persist_session(messages, conversation_history)
|
||||
self.clear_interrupt()
|
||||
return {
|
||||
"final_response": "Operation interrupted.",
|
||||
"final_response": f"Operation interrupted: retrying API call after error (retry {retry_count}/{max_retries}).",
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
@@ -3471,6 +3820,14 @@ class AIAgent:
|
||||
if interrupted:
|
||||
break
|
||||
|
||||
# Guard: if all retries exhausted without a successful response
|
||||
# (e.g. repeated context-length errors that exhausted retry_count),
|
||||
# the `response` variable is still None. Break out cleanly.
|
||||
if response is None:
|
||||
print(f"{self.log_prefix}❌ All API retries exhausted with no successful response.")
|
||||
self._persist_session(messages, conversation_history)
|
||||
break
|
||||
|
||||
try:
|
||||
if self.api_mode == "codex_responses":
|
||||
assistant_message, finish_reason = self._normalize_codex_response(response)
|
||||
@@ -3687,6 +4044,13 @@ class AIAgent:
|
||||
self._log_msg_to_db(assistant_msg)
|
||||
|
||||
self._execute_tool_calls(assistant_message, messages, effective_task_id)
|
||||
|
||||
# Refund the iteration if the ONLY tool(s) called were
|
||||
# execute_code (programmatic tool calling). These are
|
||||
# cheap RPC-style calls that shouldn't eat the budget.
|
||||
_tc_names = {tc.function.name for tc in assistant_message.tool_calls}
|
||||
if _tc_names == {"execute_code"}:
|
||||
self.iteration_budget.refund()
|
||||
|
||||
if self.compression_enabled and self.context_compressor.should_compress():
|
||||
messages, active_system_prompt = self._compress_context(
|
||||
@@ -3707,13 +4071,33 @@ class AIAgent:
|
||||
|
||||
# Check if response only has think block with no actual content after it
|
||||
if not self._has_content_after_think_block(final_response):
|
||||
# Track retries for empty-after-think responses
|
||||
# If the previous turn already delivered real content alongside
|
||||
# tool calls (e.g. "You're welcome!" + memory save), the model
|
||||
# has nothing more to say. Use the earlier content immediately
|
||||
# instead of wasting API calls on retries that won't help.
|
||||
fallback = getattr(self, '_last_content_with_tools', None)
|
||||
if fallback:
|
||||
logger.debug("Empty follow-up after tool calls — using prior turn content as final response")
|
||||
self._last_content_with_tools = None
|
||||
self._empty_content_retries = 0
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
msg = messages[i]
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
||||
tool_names = []
|
||||
for tc in msg["tool_calls"]:
|
||||
fn = tc.get("function", {})
|
||||
tool_names.append(fn.get("name", "unknown"))
|
||||
msg["content"] = f"Calling the {', '.join(tool_names)} tool{'s' if len(tool_names) > 1 else ''}..."
|
||||
break
|
||||
final_response = self._strip_think_blocks(fallback).strip()
|
||||
break
|
||||
|
||||
# No fallback available — this is a genuine empty response.
|
||||
# Retry in case the model just had a bad generation.
|
||||
if not hasattr(self, '_empty_content_retries'):
|
||||
self._empty_content_retries = 0
|
||||
self._empty_content_retries += 1
|
||||
|
||||
# Show the reasoning/thinking content so the user can see
|
||||
# what the model was thinking even though content is empty
|
||||
reasoning_text = self._extract_reasoning(assistant_message)
|
||||
print(f"{self.log_prefix}⚠️ Response only contains think block with no content after it")
|
||||
if reasoning_text:
|
||||
@@ -3869,7 +4253,12 @@ class AIAgent:
|
||||
final_response = f"I apologize, but I encountered repeated errors: {error_msg}"
|
||||
break
|
||||
|
||||
if api_call_count >= self.max_iterations and final_response is None:
|
||||
if final_response is None and (
|
||||
api_call_count >= self.max_iterations
|
||||
or self.iteration_budget.remaining <= 0
|
||||
):
|
||||
if self.iteration_budget.remaining <= 0 and not self.quiet_mode:
|
||||
print(f"\n⚠️ Session iteration budget exhausted ({self.iteration_budget.used}/{self.iteration_budget.max_total} used, including subagents)")
|
||||
final_response = self._handle_max_iterations(messages, api_call_count)
|
||||
|
||||
# Determine if conversation completed successfully
|
||||
@@ -3940,7 +4329,7 @@ def main(
|
||||
|
||||
Args:
|
||||
query (str): Natural language query for the agent. Defaults to Python 3.13 example.
|
||||
model (str): Model name to use (OpenRouter format: provider/model). Defaults to anthropic/claude-sonnet-4-20250514.
|
||||
model (str): Model name to use (OpenRouter format: provider/model). Defaults to anthropic/claude-sonnet-4.6.
|
||||
api_key (str): API key for authentication. Uses OPENROUTER_API_KEY env var if not provided.
|
||||
base_url (str): Base URL for the model API. Defaults to https://openrouter.ai/api/v1
|
||||
max_turns (int): Maximum number of API call iterations. Defaults to 10.
|
||||
|
||||
@@ -492,9 +492,23 @@ install_system_packages() {
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
elif [ -e /dev/tty ]; then
|
||||
# Non-interactive (e.g. curl | bash) but a terminal is available.
|
||||
# Read the prompt from /dev/tty (same approach the setup wizard uses).
|
||||
echo ""
|
||||
log_info "Installing ${description} requires sudo."
|
||||
read -p "Install? [Y/n] " -n 1 -r < /dev/tty
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then
|
||||
if sudo DEBIAN_FRONTEND=noninteractive NEEDRESTART_MODE=a $install_cmd < /dev/tty; then
|
||||
[ "$need_ripgrep" = true ] && HAS_RIPGREP=true && log_success "ripgrep installed"
|
||||
[ "$need_ffmpeg" = true ] && HAS_FFMPEG=true && log_success "ffmpeg installed"
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
else
|
||||
log_warn "Non-interactive mode: cannot prompt for sudo password"
|
||||
log_info "Install missing packages manually: sudo $install_cmd"
|
||||
log_warn "Non-interactive mode and no terminal available — cannot install system packages"
|
||||
log_info "Install manually after setup completes: sudo $install_cmd"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
@@ -829,6 +843,33 @@ install_node_deps() {
|
||||
log_warn "npm install failed (browser tools may not work)"
|
||||
}
|
||||
log_success "Node.js dependencies installed"
|
||||
|
||||
# Install Playwright browser + system dependencies.
|
||||
# Playwright's install-deps only supports apt/dnf/zypper natively.
|
||||
# For Arch/Manjaro we install the system libs via pacman first.
|
||||
log_info "Installing browser engine (Playwright Chromium)..."
|
||||
case "$DISTRO" in
|
||||
arch|manjaro)
|
||||
if command -v pacman &> /dev/null; then
|
||||
log_info "Arch/Manjaro detected — installing Chromium system dependencies via pacman..."
|
||||
if command -v sudo &> /dev/null && sudo -n true 2>/dev/null; then
|
||||
sudo NEEDRESTART_MODE=a pacman -S --noconfirm --needed \
|
||||
nss atk at-spi2-core cups libdrm libxkbcommon mesa pango cairo alsa-lib >/dev/null 2>&1 || true
|
||||
elif [ "$(id -u)" -eq 0 ]; then
|
||||
pacman -S --noconfirm --needed \
|
||||
nss atk at-spi2-core cups libdrm libxkbcommon mesa pango cairo alsa-lib >/dev/null 2>&1 || true
|
||||
else
|
||||
log_warn "Cannot install browser deps without sudo. Run manually:"
|
||||
log_warn " sudo pacman -S nss atk at-spi2-core cups libdrm libxkbcommon mesa pango cairo alsa-lib"
|
||||
fi
|
||||
fi
|
||||
cd "$INSTALL_DIR" && npx playwright install chromium 2>/dev/null || true
|
||||
;;
|
||||
*)
|
||||
cd "$INSTALL_DIR" && npx playwright install --with-deps chromium 2>/dev/null || true
|
||||
;;
|
||||
esac
|
||||
log_success "Browser engine installed"
|
||||
fi
|
||||
|
||||
# Install WhatsApp bridge dependencies
|
||||
|
||||
3
skills/apple/DESCRIPTION.md
Normal file
3
skills/apple/DESCRIPTION.md
Normal file
@@ -0,0 +1,3 @@
|
||||
---
|
||||
description: Apple/macOS-specific skills — iMessage, Reminders, Notes, FindMy, and macOS automation. These skills only load on macOS systems.
|
||||
---
|
||||
88
skills/apple/apple-notes/SKILL.md
Normal file
88
skills/apple/apple-notes/SKILL.md
Normal file
@@ -0,0 +1,88 @@
|
||||
---
|
||||
name: apple-notes
|
||||
description: Manage Apple Notes via the memo CLI on macOS (create, view, search, edit).
|
||||
version: 1.0.0
|
||||
author: Hermes Agent
|
||||
license: MIT
|
||||
platforms: [macos]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Notes, Apple, macOS, note-taking]
|
||||
related_skills: [obsidian]
|
||||
---
|
||||
|
||||
# Apple Notes
|
||||
|
||||
Use `memo` to manage Apple Notes directly from the terminal. Notes sync across all Apple devices via iCloud.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- **macOS** with Notes.app
|
||||
- Install: `brew tap antoniorodr/memo && brew install antoniorodr/memo/memo`
|
||||
- Grant Automation access to Notes.app when prompted (System Settings → Privacy → Automation)
|
||||
|
||||
## When to Use
|
||||
|
||||
- User asks to create, view, or search Apple Notes
|
||||
- Saving information to Notes.app for cross-device access
|
||||
- Organizing notes into folders
|
||||
- Exporting notes to Markdown/HTML
|
||||
|
||||
## When NOT to Use
|
||||
|
||||
- Obsidian vault management → use the `obsidian` skill
|
||||
- Bear Notes → separate app (not supported here)
|
||||
- Quick agent-only notes → use the `memory` tool instead
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### View Notes
|
||||
|
||||
```bash
|
||||
memo notes # List all notes
|
||||
memo notes -f "Folder Name" # Filter by folder
|
||||
memo notes -s "query" # Search notes (fuzzy)
|
||||
```
|
||||
|
||||
### Create Notes
|
||||
|
||||
```bash
|
||||
memo notes -a # Interactive editor
|
||||
memo notes -a "Note Title" # Quick add with title
|
||||
```
|
||||
|
||||
### Edit Notes
|
||||
|
||||
```bash
|
||||
memo notes -e # Interactive selection to edit
|
||||
```
|
||||
|
||||
### Delete Notes
|
||||
|
||||
```bash
|
||||
memo notes -d # Interactive selection to delete
|
||||
```
|
||||
|
||||
### Move Notes
|
||||
|
||||
```bash
|
||||
memo notes -m # Move note to folder (interactive)
|
||||
```
|
||||
|
||||
### Export Notes
|
||||
|
||||
```bash
|
||||
memo notes -ex # Export to HTML/Markdown
|
||||
```
|
||||
|
||||
## Limitations
|
||||
|
||||
- Cannot edit notes containing images or attachments
|
||||
- Interactive prompts require terminal access (use pty=true if needed)
|
||||
- macOS only — requires Apple Notes.app
|
||||
|
||||
## Rules
|
||||
|
||||
1. Prefer Apple Notes when user wants cross-device sync (iPhone/iPad/Mac)
|
||||
2. Use the `memory` tool for agent-internal notes that don't need to sync
|
||||
3. Use the `obsidian` skill for Markdown-native knowledge management
|
||||
96
skills/apple/apple-reminders/SKILL.md
Normal file
96
skills/apple/apple-reminders/SKILL.md
Normal file
@@ -0,0 +1,96 @@
|
||||
---
|
||||
name: apple-reminders
|
||||
description: Manage Apple Reminders via remindctl CLI (list, add, complete, delete).
|
||||
version: 1.0.0
|
||||
author: Hermes Agent
|
||||
license: MIT
|
||||
platforms: [macos]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Reminders, tasks, todo, macOS, Apple]
|
||||
---
|
||||
|
||||
# Apple Reminders
|
||||
|
||||
Use `remindctl` to manage Apple Reminders directly from the terminal. Tasks sync across all Apple devices via iCloud.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- **macOS** with Reminders.app
|
||||
- Install: `brew install steipete/tap/remindctl`
|
||||
- Grant Reminders permission when prompted
|
||||
- Check: `remindctl status` / Request: `remindctl authorize`
|
||||
|
||||
## When to Use
|
||||
|
||||
- User mentions "reminder" or "Reminders app"
|
||||
- Creating personal to-dos with due dates that sync to iOS
|
||||
- Managing Apple Reminders lists
|
||||
- User wants tasks to appear on their iPhone/iPad
|
||||
|
||||
## When NOT to Use
|
||||
|
||||
- Scheduling agent alerts → use the cronjob tool instead
|
||||
- Calendar events → use Apple Calendar or Google Calendar
|
||||
- Project task management → use GitHub Issues, Notion, etc.
|
||||
- If user says "remind me" but means an agent alert → clarify first
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### View Reminders
|
||||
|
||||
```bash
|
||||
remindctl # Today's reminders
|
||||
remindctl today # Today
|
||||
remindctl tomorrow # Tomorrow
|
||||
remindctl week # This week
|
||||
remindctl overdue # Past due
|
||||
remindctl all # Everything
|
||||
remindctl 2026-01-04 # Specific date
|
||||
```
|
||||
|
||||
### Manage Lists
|
||||
|
||||
```bash
|
||||
remindctl list # List all lists
|
||||
remindctl list Work # Show specific list
|
||||
remindctl list Projects --create # Create list
|
||||
remindctl list Work --delete # Delete list
|
||||
```
|
||||
|
||||
### Create Reminders
|
||||
|
||||
```bash
|
||||
remindctl add "Buy milk"
|
||||
remindctl add --title "Call mom" --list Personal --due tomorrow
|
||||
remindctl add --title "Meeting prep" --due "2026-02-15 09:00"
|
||||
```
|
||||
|
||||
### Complete / Delete
|
||||
|
||||
```bash
|
||||
remindctl complete 1 2 3 # Complete by ID
|
||||
remindctl delete 4A83 --force # Delete by ID
|
||||
```
|
||||
|
||||
### Output Formats
|
||||
|
||||
```bash
|
||||
remindctl today --json # JSON for scripting
|
||||
remindctl today --plain # TSV format
|
||||
remindctl today --quiet # Counts only
|
||||
```
|
||||
|
||||
## Date Formats
|
||||
|
||||
Accepted by `--due` and date filters:
|
||||
- `today`, `tomorrow`, `yesterday`
|
||||
- `YYYY-MM-DD`
|
||||
- `YYYY-MM-DD HH:mm`
|
||||
- ISO 8601 (`2026-01-04T12:34:56Z`)
|
||||
|
||||
## Rules
|
||||
|
||||
1. When user says "remind me", clarify: Apple Reminders (syncs to phone) vs agent cronjob alert
|
||||
2. Always confirm reminder content and due date before creating
|
||||
3. Use `--json` for programmatic parsing
|
||||
131
skills/apple/findmy/SKILL.md
Normal file
131
skills/apple/findmy/SKILL.md
Normal file
@@ -0,0 +1,131 @@
|
||||
---
|
||||
name: findmy
|
||||
description: Track Apple devices and AirTags via FindMy.app on macOS using AppleScript and screen capture.
|
||||
version: 1.0.0
|
||||
author: Hermes Agent
|
||||
license: MIT
|
||||
platforms: [macos]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [FindMy, AirTag, location, tracking, macOS, Apple]
|
||||
---
|
||||
|
||||
# Find My (Apple)
|
||||
|
||||
Track Apple devices and AirTags via the FindMy.app on macOS. Since Apple doesn't
|
||||
provide a CLI for FindMy, this skill uses AppleScript to open the app and
|
||||
screen capture to read device locations.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- **macOS** with Find My app and iCloud signed in
|
||||
- Devices/AirTags already registered in Find My
|
||||
- Screen Recording permission for terminal (System Settings → Privacy → Screen Recording)
|
||||
- **Optional but recommended**: Install `peekaboo` for better UI automation:
|
||||
`brew install steipete/tap/peekaboo`
|
||||
|
||||
## When to Use
|
||||
|
||||
- User asks "where is my [device/cat/keys/bag]?"
|
||||
- Tracking AirTag locations
|
||||
- Checking device locations (iPhone, iPad, Mac, AirPods)
|
||||
- Monitoring pet or item movement over time (AirTag patrol routes)
|
||||
|
||||
## Method 1: AppleScript + Screenshot (Basic)
|
||||
|
||||
### Open FindMy and Navigate
|
||||
|
||||
```bash
|
||||
# Open Find My app
|
||||
osascript -e 'tell application "FindMy" to activate'
|
||||
|
||||
# Wait for it to load
|
||||
sleep 3
|
||||
|
||||
# Take a screenshot of the Find My window
|
||||
screencapture -w -o /tmp/findmy.png
|
||||
```
|
||||
|
||||
Then use `vision_analyze` to read the screenshot:
|
||||
```
|
||||
vision_analyze(image_url="/tmp/findmy.png", question="What devices/items are shown and what are their locations?")
|
||||
```
|
||||
|
||||
### Switch Between Tabs
|
||||
|
||||
```bash
|
||||
# Switch to Devices tab
|
||||
osascript -e '
|
||||
tell application "System Events"
|
||||
tell process "FindMy"
|
||||
click button "Devices" of toolbar 1 of window 1
|
||||
end tell
|
||||
end tell'
|
||||
|
||||
# Switch to Items tab (AirTags)
|
||||
osascript -e '
|
||||
tell application "System Events"
|
||||
tell process "FindMy"
|
||||
click button "Items" of toolbar 1 of window 1
|
||||
end tell
|
||||
end tell'
|
||||
```
|
||||
|
||||
## Method 2: Peekaboo UI Automation (Recommended)
|
||||
|
||||
If `peekaboo` is installed, use it for more reliable UI interaction:
|
||||
|
||||
```bash
|
||||
# Open Find My
|
||||
osascript -e 'tell application "FindMy" to activate'
|
||||
sleep 3
|
||||
|
||||
# Capture and annotate the UI
|
||||
peekaboo see --app "FindMy" --annotate --path /tmp/findmy-ui.png
|
||||
|
||||
# Click on a specific device/item by element ID
|
||||
peekaboo click --on B3 --app "FindMy"
|
||||
|
||||
# Capture the detail view
|
||||
peekaboo image --app "FindMy" --path /tmp/findmy-detail.png
|
||||
```
|
||||
|
||||
Then analyze with vision:
|
||||
```
|
||||
vision_analyze(image_url="/tmp/findmy-detail.png", question="What is the location shown for this device/item? Include address and coordinates if visible.")
|
||||
```
|
||||
|
||||
## Workflow: Track AirTag Location Over Time
|
||||
|
||||
For monitoring an AirTag (e.g., tracking a cat's patrol route):
|
||||
|
||||
```bash
|
||||
# 1. Open FindMy to Items tab
|
||||
osascript -e 'tell application "FindMy" to activate'
|
||||
sleep 3
|
||||
|
||||
# 2. Click on the AirTag item (stay on page — AirTag only updates when page is open)
|
||||
|
||||
# 3. Periodically capture location
|
||||
while true; do
|
||||
screencapture -w -o /tmp/findmy-$(date +%H%M%S).png
|
||||
sleep 300 # Every 5 minutes
|
||||
done
|
||||
```
|
||||
|
||||
Analyze each screenshot with vision to extract coordinates, then compile a route.
|
||||
|
||||
## Limitations
|
||||
|
||||
- FindMy has **no CLI or API** — must use UI automation
|
||||
- AirTags only update location while the FindMy page is actively displayed
|
||||
- Location accuracy depends on nearby Apple devices in the FindMy network
|
||||
- Screen Recording permission required for screenshots
|
||||
- AppleScript UI automation may break across macOS versions
|
||||
|
||||
## Rules
|
||||
|
||||
1. Keep FindMy app in the foreground when tracking AirTags (updates stop when minimized)
|
||||
2. Use `vision_analyze` to read screenshot content — don't try to parse pixels
|
||||
3. For ongoing tracking, use a cronjob to periodically capture and log locations
|
||||
4. Respect privacy — only track devices/items the user owns
|
||||
100
skills/apple/imessage/SKILL.md
Normal file
100
skills/apple/imessage/SKILL.md
Normal file
@@ -0,0 +1,100 @@
|
||||
---
|
||||
name: imessage
|
||||
description: Send and receive iMessages/SMS via the imsg CLI on macOS.
|
||||
version: 1.0.0
|
||||
author: Hermes Agent
|
||||
license: MIT
|
||||
platforms: [macos]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [iMessage, SMS, messaging, macOS, Apple]
|
||||
---
|
||||
|
||||
# iMessage
|
||||
|
||||
Use `imsg` to read and send iMessage/SMS via macOS Messages.app.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- **macOS** with Messages.app signed in
|
||||
- Install: `brew install steipete/tap/imsg`
|
||||
- Grant Full Disk Access for terminal (System Settings → Privacy → Full Disk Access)
|
||||
- Grant Automation permission for Messages.app when prompted
|
||||
|
||||
## When to Use
|
||||
|
||||
- User asks to send an iMessage or text message
|
||||
- Reading iMessage conversation history
|
||||
- Checking recent Messages.app chats
|
||||
- Sending to phone numbers or Apple IDs
|
||||
|
||||
## When NOT to Use
|
||||
|
||||
- Telegram/Discord/Slack/WhatsApp messages → use the appropriate gateway channel
|
||||
- Group chat management (adding/removing members) → not supported
|
||||
- Bulk/mass messaging → always confirm with user first
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### List Chats
|
||||
|
||||
```bash
|
||||
imsg chats --limit 10 --json
|
||||
```
|
||||
|
||||
### View History
|
||||
|
||||
```bash
|
||||
# By chat ID
|
||||
imsg history --chat-id 1 --limit 20 --json
|
||||
|
||||
# With attachments info
|
||||
imsg history --chat-id 1 --limit 20 --attachments --json
|
||||
```
|
||||
|
||||
### Send Messages
|
||||
|
||||
```bash
|
||||
# Text only
|
||||
imsg send --to "+14155551212" --text "Hello!"
|
||||
|
||||
# With attachment
|
||||
imsg send --to "+14155551212" --text "Check this out" --file /path/to/image.jpg
|
||||
|
||||
# Force iMessage or SMS
|
||||
imsg send --to "+14155551212" --text "Hi" --service imessage
|
||||
imsg send --to "+14155551212" --text "Hi" --service sms
|
||||
```
|
||||
|
||||
### Watch for New Messages
|
||||
|
||||
```bash
|
||||
imsg watch --chat-id 1 --attachments
|
||||
```
|
||||
|
||||
## Service Options
|
||||
|
||||
- `--service imessage` — Force iMessage (requires recipient has iMessage)
|
||||
- `--service sms` — Force SMS (green bubble)
|
||||
- `--service auto` — Let Messages.app decide (default)
|
||||
|
||||
## Rules
|
||||
|
||||
1. **Always confirm recipient and message content** before sending
|
||||
2. **Never send to unknown numbers** without explicit user approval
|
||||
3. **Verify file paths** exist before attaching
|
||||
4. **Don't spam** — rate-limit yourself
|
||||
|
||||
## Example Workflow
|
||||
|
||||
User: "Text mom that I'll be late"
|
||||
|
||||
```bash
|
||||
# 1. Find mom's chat
|
||||
imsg chats --limit 20 --json | jq '.[] | select(.displayName | contains("Mom"))'
|
||||
|
||||
# 2. Confirm with user: "Found Mom at +1555123456. Send 'I'll be late' via iMessage?"
|
||||
|
||||
# 3. Send after confirmation
|
||||
imsg send --to "+1555123456" --text "I'll be late"
|
||||
```
|
||||
@@ -1,7 +1,7 @@
|
||||
---
|
||||
name: ascii-art
|
||||
description: Generate ASCII art using pyfiglet (571 fonts), cowsay, boxes, toilet, image-to-ascii conversion, and search curated art from emojicombos.com and asciiart.eu (11,000+ artworks). Falls back to LLM-generated art.
|
||||
version: 3.1.0
|
||||
description: Generate ASCII art using pyfiglet (571 fonts), cowsay, boxes, toilet, image-to-ascii, remote APIs (asciified, ascii.co.uk), and LLM fallback. No API keys required.
|
||||
version: 4.0.0
|
||||
author: 0xbyt4, Hermes Agent
|
||||
license: MIT
|
||||
dependencies: []
|
||||
@@ -14,9 +14,9 @@ metadata:
|
||||
|
||||
# ASCII Art Skill
|
||||
|
||||
Multiple tools for different ASCII art needs. All tools are local CLI programs — no API keys required.
|
||||
Multiple tools for different ASCII art needs. All tools are local CLI programs or free REST APIs — no API keys required.
|
||||
|
||||
## Tool 1: Text Banners (pyfiglet)
|
||||
## Tool 1: Text Banners (pyfiglet — local)
|
||||
|
||||
Render text as large ASCII art banners. 571 built-in fonts.
|
||||
|
||||
@@ -53,7 +53,35 @@ python3 -m pyfiglet --list_fonts # List all 571 fonts
|
||||
- Short text (1-8 chars) works best with detailed fonts like `doom` or `block`
|
||||
- Long text works better with compact fonts like `small` or `mini`
|
||||
|
||||
## Tool 2: Cowsay (Message Art)
|
||||
## Tool 2: Text Banners (asciified API — remote, no install)
|
||||
|
||||
Free REST API that converts text to ASCII art. 250+ FIGlet fonts. Returns plain text directly — no parsing needed. Use this when pyfiglet is not installed or as a quick alternative.
|
||||
|
||||
### Usage (via terminal curl)
|
||||
|
||||
```bash
|
||||
# Basic text banner (default font)
|
||||
curl -s "https://asciified.thelicato.io/api/v2/ascii?text=Hello+World"
|
||||
|
||||
# With a specific font
|
||||
curl -s "https://asciified.thelicato.io/api/v2/ascii?text=Hello&font=Slant"
|
||||
curl -s "https://asciified.thelicato.io/api/v2/ascii?text=Hello&font=Doom"
|
||||
curl -s "https://asciified.thelicato.io/api/v2/ascii?text=Hello&font=Star+Wars"
|
||||
curl -s "https://asciified.thelicato.io/api/v2/ascii?text=Hello&font=3-D"
|
||||
curl -s "https://asciified.thelicato.io/api/v2/ascii?text=Hello&font=Banner3"
|
||||
|
||||
# List all available fonts (returns JSON array)
|
||||
curl -s "https://asciified.thelicato.io/api/v2/fonts"
|
||||
```
|
||||
|
||||
### Tips
|
||||
|
||||
- URL-encode spaces as `+` in the text parameter
|
||||
- The response is plain text ASCII art — no JSON wrapping, ready to display
|
||||
- Font names are case-sensitive; use the fonts endpoint to get exact names
|
||||
- Works from any terminal with curl — no Python or pip needed
|
||||
|
||||
## Tool 3: Cowsay (Message Art)
|
||||
|
||||
Classic tool that wraps text in a speech bubble with an ASCII character.
|
||||
|
||||
@@ -97,7 +125,7 @@ cowsay -e "OO" "Msg" # Custom eyes
|
||||
cowsay -T "U " "Msg" # Custom tongue
|
||||
```
|
||||
|
||||
## Tool 3: Boxes (Decorative Borders)
|
||||
## Tool 4: Boxes (Decorative Borders)
|
||||
|
||||
Draw decorative ASCII art borders/frames around any text. 70+ built-in designs.
|
||||
|
||||
@@ -124,13 +152,15 @@ echo "Hello World" | boxes -a c # Center text
|
||||
boxes -l # List all 70+ designs
|
||||
```
|
||||
|
||||
### Combine with pyfiglet
|
||||
### Combine with pyfiglet or asciified
|
||||
|
||||
```bash
|
||||
python3 -m pyfiglet "HERMES" -f slant | boxes -d stone
|
||||
# Or without pyfiglet installed:
|
||||
curl -s "https://asciified.thelicato.io/api/v2/ascii?text=HERMES&font=Slant" | boxes -d stone
|
||||
```
|
||||
|
||||
## Tool 4: TOIlet (Colored Text Art)
|
||||
## Tool 5: TOIlet (Colored Text Art)
|
||||
|
||||
Like pyfiglet but with ANSI color effects and visual filters. Great for terminal eye candy.
|
||||
|
||||
@@ -160,14 +190,14 @@ toilet -F list # List available filters
|
||||
|
||||
**Note**: toilet outputs ANSI escape codes for colors — works in terminals but may not render in all contexts (e.g., plain text files, some chat platforms).
|
||||
|
||||
## Tool 5: Image to ASCII Art
|
||||
## Tool 6: Image to ASCII Art
|
||||
|
||||
Convert images (PNG, JPEG, GIF, WEBP) to ASCII art.
|
||||
|
||||
### Option A: ascii-image-converter (recommended, modern)
|
||||
|
||||
```bash
|
||||
# Install via snap or Go
|
||||
# Install
|
||||
sudo snap install ascii-image-converter
|
||||
# OR: go install github.com/TheZoraiz/ascii-image-converter@latest
|
||||
```
|
||||
@@ -190,63 +220,77 @@ jp2a --width=80 image.jpg
|
||||
jp2a --colors image.jpg # Colorized
|
||||
```
|
||||
|
||||
## Tool 6: Search Pre-Made ASCII Art (Web APIs)
|
||||
## Tool 7: Search Pre-Made ASCII Art
|
||||
|
||||
Search curated ASCII art databases via `web_extract`. No API keys needed.
|
||||
Search curated ASCII art from the web. Use `terminal` with `curl`.
|
||||
|
||||
### Source A: emojicombos.com (recommended first)
|
||||
### Source A: ascii.co.uk (recommended for pre-made art)
|
||||
|
||||
Huge collection of ASCII art, dot art, kaomoji, and emoji combos. Modern, meme-aware, user-submitted content. Great for pop culture, animals, objects, aesthetics.
|
||||
Large collection of classic ASCII art organized by subject. Art is inside HTML `<pre>` tags. Fetch the page with curl, then extract art with a small Python snippet.
|
||||
|
||||
**URL pattern:** `https://emojicombos.com/{term}-ascii-art`
|
||||
**URL pattern:** `https://ascii.co.uk/art/{subject}`
|
||||
|
||||
**Step 1 — Fetch the page:**
|
||||
|
||||
```bash
|
||||
curl -s 'https://ascii.co.uk/art/cat' -o /tmp/ascii_art.html
|
||||
```
|
||||
web_extract(urls=["https://emojicombos.com/cat-ascii-art"])
|
||||
web_extract(urls=["https://emojicombos.com/rocket-ascii-art"])
|
||||
web_extract(urls=["https://emojicombos.com/dragon-ascii-art"])
|
||||
web_extract(urls=["https://emojicombos.com/skull-ascii-art"])
|
||||
web_extract(urls=["https://emojicombos.com/heart-ascii-art"])
|
||||
|
||||
**Step 2 — Extract art from pre tags:**
|
||||
|
||||
```python
|
||||
import re, html
|
||||
with open('/tmp/ascii_art.html') as f:
|
||||
text = f.read()
|
||||
arts = re.findall(r'<pre[^>]*>(.*?)</pre>', text, re.DOTALL)
|
||||
for art in arts:
|
||||
clean = re.sub(r'<[^>]+>', '', art)
|
||||
clean = html.unescape(clean).strip()
|
||||
if len(clean) > 30:
|
||||
print(clean)
|
||||
print('\n---\n')
|
||||
```
|
||||
|
||||
**Available subjects** (use as URL path):
|
||||
- Animals: `cat`, `dog`, `horse`, `bird`, `fish`, `dragon`, `snake`, `rabbit`, `elephant`, `dolphin`, `butterfly`, `owl`, `wolf`, `bear`, `penguin`, `turtle`
|
||||
- Objects: `car`, `ship`, `airplane`, `rocket`, `guitar`, `computer`, `coffee`, `beer`, `cake`, `house`, `castle`, `sword`, `crown`, `key`
|
||||
- Nature: `tree`, `flower`, `sun`, `moon`, `star`, `mountain`, `ocean`, `rainbow`
|
||||
- Characters: `skull`, `robot`, `angel`, `wizard`, `pirate`, `ninja`, `alien`
|
||||
- Holidays: `christmas`, `halloween`, `valentine`
|
||||
|
||||
**Tips:**
|
||||
- Use hyphenated search terms: `hello-kitty-ascii-art`, `star-wars-ascii-art`
|
||||
- Returns a mix of classic ASCII, Braille dot art, and kaomoji — pick the best style for the user
|
||||
- Includes modern meme art and pop culture references
|
||||
- Great for kaomoji/emoticons too: `https://emojicombos.com/cat-kaomoji`
|
||||
- Preserve artist signatures/initials — important etiquette
|
||||
- Multiple art pieces per page — pick the best one for the user
|
||||
- Works reliably via curl, no JavaScript needed
|
||||
|
||||
### Source B: asciiart.eu (classic archive)
|
||||
### Source B: GitHub Octocat API (fun easter egg)
|
||||
|
||||
11,000+ classic ASCII artworks organized by category. More traditional/vintage art.
|
||||
|
||||
**Browse by category** (use as URL paths):
|
||||
- `animals/cats`, `animals/dogs`, `animals/birds`, `animals/horses`
|
||||
- `animals/dolphins`, `animals/dragons`, `animals/insects`
|
||||
- `space/rockets`, `space/stars`, `space/planets`
|
||||
- `vehicles/cars`, `vehicles/ships`, `vehicles/airplanes`
|
||||
- `food-and-drinks/coffee`, `food-and-drinks/beer`
|
||||
- `computers/computers`, `electronics/robots`
|
||||
- `art-and-design/hearts`, `art-and-design/skulls`
|
||||
- `plants/flowers`, `plants/trees`
|
||||
- `mythology/dragons`, `mythology/unicorns`
|
||||
|
||||
```
|
||||
web_extract(urls=["https://www.asciiart.eu/animals/cats"])
|
||||
web_extract(urls=["https://www.asciiart.eu/search?q=rocket"])
|
||||
```
|
||||
|
||||
**Tips:**
|
||||
- Preserve artist initials/signatures (e.g., `jgs`, `hjw`) — this is important etiquette
|
||||
- Better for classic/vintage ASCII art style
|
||||
|
||||
### Source C: GitHub Octocat API (fun easter egg)
|
||||
|
||||
Returns a random GitHub Octocat with a quote. No auth needed.
|
||||
Returns a random GitHub Octocat with a wise quote. No auth needed.
|
||||
|
||||
```bash
|
||||
curl -s https://api.github.com/octocat
|
||||
```
|
||||
|
||||
## Tool 7: LLM-Generated Custom Art (Fallback)
|
||||
## Tool 8: Fun ASCII Utilities (via curl)
|
||||
|
||||
These free services return ASCII art directly — great for fun extras.
|
||||
|
||||
### QR Codes as ASCII Art
|
||||
|
||||
```bash
|
||||
curl -s "qrenco.de/Hello+World"
|
||||
curl -s "qrenco.de/https://example.com"
|
||||
```
|
||||
|
||||
### Weather as ASCII Art
|
||||
|
||||
```bash
|
||||
curl -s "wttr.in/London" # Full weather report with ASCII graphics
|
||||
curl -s "wttr.in/Moon" # Moon phase in ASCII art
|
||||
curl -s "v2.wttr.in/London" # Detailed version
|
||||
```
|
||||
|
||||
## Tool 9: LLM-Generated Custom Art (Fallback)
|
||||
|
||||
When tools above don't have what's needed, generate ASCII art directly using these Unicode characters:
|
||||
|
||||
@@ -264,28 +308,14 @@ When tools above don't have what's needed, generate ASCII art directly using the
|
||||
- Max height: 15 lines for banners, 25 for scenes
|
||||
- Monospace only: output must render correctly in fixed-width fonts
|
||||
|
||||
## Fun Extras
|
||||
|
||||
### Star Wars in ASCII (via telnet)
|
||||
|
||||
```bash
|
||||
telnet towel.blinkenlights.nl
|
||||
```
|
||||
|
||||
### Useful Resources
|
||||
|
||||
- [asciiart.eu](https://www.asciiart.eu/) — 11,000+ artworks, searchable
|
||||
- [patorjk.com/software/taag](http://patorjk.com/software/taag/) — Web-based text-to-ASCII with font preview
|
||||
- [asciiflow.com](http://asciiflow.com/) — Interactive ASCII diagram editor (browser)
|
||||
- [awesome-ascii-art](https://github.com/moul/awesome-ascii-art) — Curated resource list
|
||||
|
||||
## Decision Flow
|
||||
|
||||
1. **Text as a banner** → pyfiglet (or toilet for colored output)
|
||||
1. **Text as a banner** → pyfiglet if installed, otherwise asciified API via curl
|
||||
2. **Wrap a message in fun character art** → cowsay
|
||||
3. **Add decorative border/frame** → boxes (can combine with pyfiglet)
|
||||
4. **Art of a thing** (cat, rocket, dragon) → emojicombos.com first, then asciiart.eu
|
||||
5. **Kaomoji / emoticons** → emojicombos.com (`{term}-kaomoji`)
|
||||
6. **Convert an image to ASCII** → ascii-image-converter or jp2a
|
||||
7. **Something custom/creative** → LLM generation with Unicode palette
|
||||
8. **Any tool not installed** → install it, or fall back to next option
|
||||
3. **Add decorative border/frame** → boxes (can combine with pyfiglet/asciified)
|
||||
4. **Art of a specific thing** (cat, rocket, dragon) → ascii.co.uk via curl + parsing
|
||||
5. **Convert an image to ASCII** → ascii-image-converter or jp2a
|
||||
6. **QR code** → qrenco.de via curl
|
||||
7. **Weather/moon art** → wttr.in via curl
|
||||
8. **Something custom/creative** → LLM generation with Unicode palette
|
||||
9. **Any tool not installed** → install it, or fall back to next option
|
||||
|
||||
162
skills/dogfood/SKILL.md
Normal file
162
skills/dogfood/SKILL.md
Normal file
@@ -0,0 +1,162 @@
|
||||
---
|
||||
name: dogfood
|
||||
description: Systematic exploratory QA testing of web applications — find bugs, capture evidence, and generate structured reports
|
||||
version: 1.0.0
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [qa, testing, browser, web, dogfood]
|
||||
related_skills: []
|
||||
---
|
||||
|
||||
# Dogfood: Systematic Web Application QA Testing
|
||||
|
||||
## Overview
|
||||
|
||||
This skill guides you through systematic exploratory QA testing of web applications using the browser toolset. You will navigate the application, interact with elements, capture evidence of issues, and produce a structured bug report.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Browser toolset must be available (`browser_navigate`, `browser_snapshot`, `browser_click`, `browser_type`, `browser_vision`, `browser_console`, `browser_scroll`, `browser_back`, `browser_press`, `browser_close`)
|
||||
- A target URL and testing scope from the user
|
||||
|
||||
## Inputs
|
||||
|
||||
The user provides:
|
||||
1. **Target URL** — the entry point for testing
|
||||
2. **Scope** — what areas/features to focus on (or "full site" for comprehensive testing)
|
||||
3. **Output directory** (optional) — where to save screenshots and the report (default: `./dogfood-output`)
|
||||
|
||||
## Workflow
|
||||
|
||||
Follow this 5-phase systematic workflow:
|
||||
|
||||
### Phase 1: Plan
|
||||
|
||||
1. Create the output directory structure:
|
||||
```
|
||||
{output_dir}/
|
||||
├── screenshots/ # Evidence screenshots
|
||||
└── report.md # Final report (generated in Phase 5)
|
||||
```
|
||||
2. Identify the testing scope based on user input.
|
||||
3. Build a rough sitemap by planning which pages and features to test:
|
||||
- Landing/home page
|
||||
- Navigation links (header, footer, sidebar)
|
||||
- Key user flows (sign up, login, search, checkout, etc.)
|
||||
- Forms and interactive elements
|
||||
- Edge cases (empty states, error pages, 404s)
|
||||
|
||||
### Phase 2: Explore
|
||||
|
||||
For each page or feature in your plan:
|
||||
|
||||
1. **Navigate** to the page:
|
||||
```
|
||||
browser_navigate(url="https://example.com/page")
|
||||
```
|
||||
|
||||
2. **Take a snapshot** to understand the DOM structure:
|
||||
```
|
||||
browser_snapshot()
|
||||
```
|
||||
|
||||
3. **Check the console** for JavaScript errors:
|
||||
```
|
||||
browser_console(clear=true)
|
||||
```
|
||||
Do this after every navigation and after every significant interaction. Silent JS errors are high-value findings.
|
||||
|
||||
4. **Take an annotated screenshot** to visually assess the page and identify interactive elements:
|
||||
```
|
||||
browser_vision(question="Describe the page layout, identify any visual issues, broken elements, or accessibility concerns", annotate=true)
|
||||
```
|
||||
The `annotate=true` flag overlays numbered `[N]` labels on interactive elements. Each `[N]` maps to ref `@eN` for subsequent browser commands.
|
||||
|
||||
5. **Test interactive elements** systematically:
|
||||
- Click buttons and links: `browser_click(ref="@eN")`
|
||||
- Fill forms: `browser_type(ref="@eN", text="test input")`
|
||||
- Test keyboard navigation: `browser_press(key="Tab")`, `browser_press(key="Enter")`
|
||||
- Scroll through content: `browser_scroll(direction="down")`
|
||||
- Test form validation with invalid inputs
|
||||
- Test empty submissions
|
||||
|
||||
6. **After each interaction**, check for:
|
||||
- Console errors: `browser_console()`
|
||||
- Visual changes: `browser_vision(question="What changed after the interaction?")`
|
||||
- Expected vs actual behavior
|
||||
|
||||
### Phase 3: Collect Evidence
|
||||
|
||||
For every issue found:
|
||||
|
||||
1. **Take a screenshot** showing the issue:
|
||||
```
|
||||
browser_vision(question="Capture and describe the issue visible on this page", annotate=false)
|
||||
```
|
||||
Save the `screenshot_path` from the response — you will reference it in the report.
|
||||
|
||||
2. **Record the details**:
|
||||
- URL where the issue occurs
|
||||
- Steps to reproduce
|
||||
- Expected behavior
|
||||
- Actual behavior
|
||||
- Console errors (if any)
|
||||
- Screenshot path
|
||||
|
||||
3. **Classify the issue** using the issue taxonomy (see `references/issue-taxonomy.md`):
|
||||
- Severity: Critical / High / Medium / Low
|
||||
- Category: Functional / Visual / Accessibility / Console / UX / Content
|
||||
|
||||
### Phase 4: Categorize
|
||||
|
||||
1. Review all collected issues.
|
||||
2. De-duplicate — merge issues that are the same bug manifesting in different places.
|
||||
3. Assign final severity and category to each issue.
|
||||
4. Sort by severity (Critical first, then High, Medium, Low).
|
||||
5. Count issues by severity and category for the executive summary.
|
||||
|
||||
### Phase 5: Report
|
||||
|
||||
Generate the final report using the template at `templates/dogfood-report-template.md`.
|
||||
|
||||
The report must include:
|
||||
1. **Executive summary** with total issue count, breakdown by severity, and testing scope
|
||||
2. **Per-issue sections** with:
|
||||
- Issue number and title
|
||||
- Severity and category badges
|
||||
- URL where observed
|
||||
- Description of the issue
|
||||
- Steps to reproduce
|
||||
- Expected vs actual behavior
|
||||
- Screenshot references (use `MEDIA:<screenshot_path>` for inline images)
|
||||
- Console errors if relevant
|
||||
3. **Summary table** of all issues
|
||||
4. **Testing notes** — what was tested, what was not, any blockers
|
||||
|
||||
Save the report to `{output_dir}/report.md`.
|
||||
|
||||
## Tools Reference
|
||||
|
||||
| Tool | Purpose |
|
||||
|------|---------|
|
||||
| `browser_navigate` | Go to a URL |
|
||||
| `browser_snapshot` | Get DOM text snapshot (accessibility tree) |
|
||||
| `browser_click` | Click an element by ref (`@eN`) or text |
|
||||
| `browser_type` | Type into an input field |
|
||||
| `browser_scroll` | Scroll up/down on the page |
|
||||
| `browser_back` | Go back in browser history |
|
||||
| `browser_press` | Press a keyboard key |
|
||||
| `browser_vision` | Screenshot + AI analysis; use `annotate=true` for element labels |
|
||||
| `browser_console` | Get JS console output and errors |
|
||||
| `browser_close` | Close the browser session |
|
||||
|
||||
## Tips
|
||||
|
||||
- **Always check `browser_console()` after navigating and after significant interactions.** Silent JS errors are among the most valuable findings.
|
||||
- **Use `annotate=true` with `browser_vision`** when you need to reason about interactive element positions or when the snapshot refs are unclear.
|
||||
- **Test with both valid and invalid inputs** — form validation bugs are common.
|
||||
- **Scroll through long pages** — content below the fold may have rendering issues.
|
||||
- **Test navigation flows** — click through multi-step processes end-to-end.
|
||||
- **Check responsive behavior** by noting any layout issues visible in screenshots.
|
||||
- **Don't forget edge cases**: empty states, very long text, special characters, rapid clicking.
|
||||
- When reporting screenshots to the user, include `MEDIA:<screenshot_path>` so they can see the evidence inline.
|
||||
109
skills/dogfood/references/issue-taxonomy.md
Normal file
109
skills/dogfood/references/issue-taxonomy.md
Normal file
@@ -0,0 +1,109 @@
|
||||
# Issue Taxonomy
|
||||
|
||||
Use this taxonomy to classify issues found during dogfood QA testing.
|
||||
|
||||
## Severity Levels
|
||||
|
||||
### Critical
|
||||
The issue makes a core feature completely unusable or causes data loss.
|
||||
|
||||
**Examples:**
|
||||
- Application crashes or shows a blank white page
|
||||
- Form submission silently loses user data
|
||||
- Authentication is completely broken (can't log in at all)
|
||||
- Payment flow fails and charges the user without completing the order
|
||||
- Security vulnerability (e.g., XSS, exposed credentials in console)
|
||||
|
||||
### High
|
||||
The issue significantly impairs functionality but a workaround may exist.
|
||||
|
||||
**Examples:**
|
||||
- A key button does nothing when clicked (but refreshing fixes it)
|
||||
- Search returns no results for valid queries
|
||||
- Form validation rejects valid input
|
||||
- Page loads but critical content is missing or garbled
|
||||
- Navigation link leads to a 404 or wrong page
|
||||
- Uncaught JavaScript exceptions in the console on core pages
|
||||
|
||||
### Medium
|
||||
The issue is noticeable and affects user experience but doesn't block core functionality.
|
||||
|
||||
**Examples:**
|
||||
- Layout is misaligned or overlapping on certain screen sections
|
||||
- Images fail to load (broken image icons)
|
||||
- Slow performance (visible loading delays > 3 seconds)
|
||||
- Form field lacks proper validation feedback (no error message on bad input)
|
||||
- Console warnings that suggest deprecated or misconfigured features
|
||||
- Inconsistent styling between similar pages
|
||||
|
||||
### Low
|
||||
Minor polish issues that don't affect functionality.
|
||||
|
||||
**Examples:**
|
||||
- Typos or grammatical errors in text content
|
||||
- Minor spacing or alignment inconsistencies
|
||||
- Placeholder text left in production ("Lorem ipsum")
|
||||
- Favicon missing
|
||||
- Console info/debug messages that shouldn't be in production
|
||||
- Subtle color contrast issues that don't fail WCAG requirements
|
||||
|
||||
## Categories
|
||||
|
||||
### Functional
|
||||
Issues where features don't work as expected.
|
||||
|
||||
- Buttons/links that don't respond
|
||||
- Forms that don't submit or submit incorrectly
|
||||
- Broken user flows (can't complete a multi-step process)
|
||||
- Incorrect data displayed
|
||||
- Features that work partially
|
||||
|
||||
### Visual
|
||||
Issues with the visual presentation of the page.
|
||||
|
||||
- Layout problems (overlapping elements, broken grids)
|
||||
- Broken images or missing media
|
||||
- Styling inconsistencies
|
||||
- Responsive design failures
|
||||
- Z-index issues (elements hidden behind others)
|
||||
- Text overflow or truncation
|
||||
|
||||
### Accessibility
|
||||
Issues that prevent or hinder access for users with disabilities.
|
||||
|
||||
- Missing alt text on meaningful images
|
||||
- Poor color contrast (fails WCAG AA)
|
||||
- Elements not reachable via keyboard navigation
|
||||
- Missing form labels or ARIA attributes
|
||||
- Focus indicators missing or unclear
|
||||
- Screen reader incompatible content
|
||||
|
||||
### Console
|
||||
Issues detected through JavaScript console output.
|
||||
|
||||
- Uncaught exceptions and unhandled promise rejections
|
||||
- Failed network requests (4xx, 5xx errors in console)
|
||||
- Deprecation warnings
|
||||
- CORS errors
|
||||
- Mixed content warnings (HTTP resources on HTTPS page)
|
||||
- Excessive console.log output left from development
|
||||
|
||||
### UX (User Experience)
|
||||
Issues where functionality works but the experience is poor.
|
||||
|
||||
- Confusing navigation or information architecture
|
||||
- Missing loading indicators (user doesn't know something is happening)
|
||||
- No feedback after user actions (e.g., button click with no visible result)
|
||||
- Inconsistent interaction patterns
|
||||
- Missing confirmation dialogs for destructive actions
|
||||
- Poor error messages that don't help the user recover
|
||||
|
||||
### Content
|
||||
Issues with the text, media, or information on the page.
|
||||
|
||||
- Typos and grammatical errors
|
||||
- Placeholder/dummy content in production
|
||||
- Outdated information
|
||||
- Missing content (empty sections)
|
||||
- Broken or dead links to external resources
|
||||
- Incorrect or misleading labels
|
||||
86
skills/dogfood/templates/dogfood-report-template.md
Normal file
86
skills/dogfood/templates/dogfood-report-template.md
Normal file
@@ -0,0 +1,86 @@
|
||||
# Dogfood QA Report
|
||||
|
||||
**Target:** {target_url}
|
||||
**Date:** {date}
|
||||
**Scope:** {scope_description}
|
||||
**Tester:** Hermes Agent (automated exploratory QA)
|
||||
|
||||
---
|
||||
|
||||
## Executive Summary
|
||||
|
||||
| Severity | Count |
|
||||
|----------|-------|
|
||||
| 🔴 Critical | {critical_count} |
|
||||
| 🟠 High | {high_count} |
|
||||
| 🟡 Medium | {medium_count} |
|
||||
| 🔵 Low | {low_count} |
|
||||
| **Total** | **{total_count}** |
|
||||
|
||||
**Overall Assessment:** {one_sentence_assessment}
|
||||
|
||||
---
|
||||
|
||||
## Issues
|
||||
|
||||
<!-- Repeat this section for each issue found, sorted by severity (Critical first) -->
|
||||
|
||||
### Issue #{issue_number}: {issue_title}
|
||||
|
||||
| Field | Value |
|
||||
|-------|-------|
|
||||
| **Severity** | {severity} |
|
||||
| **Category** | {category} |
|
||||
| **URL** | {url_where_found} |
|
||||
|
||||
**Description:**
|
||||
{detailed_description_of_the_issue}
|
||||
|
||||
**Steps to Reproduce:**
|
||||
1. {step_1}
|
||||
2. {step_2}
|
||||
3. {step_3}
|
||||
|
||||
**Expected Behavior:**
|
||||
{what_should_happen}
|
||||
|
||||
**Actual Behavior:**
|
||||
{what_actually_happens}
|
||||
|
||||
**Screenshot:**
|
||||
MEDIA:{screenshot_path}
|
||||
|
||||
**Console Errors** (if applicable):
|
||||
```
|
||||
{console_error_output}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
<!-- End of per-issue section -->
|
||||
|
||||
## Issues Summary Table
|
||||
|
||||
| # | Title | Severity | Category | URL |
|
||||
|---|-------|----------|----------|-----|
|
||||
| {n} | {title} | {severity} | {category} | {url} |
|
||||
|
||||
## Testing Coverage
|
||||
|
||||
### Pages Tested
|
||||
- {list_of_pages_visited}
|
||||
|
||||
### Features Tested
|
||||
- {list_of_features_exercised}
|
||||
|
||||
### Not Tested / Out of Scope
|
||||
- {areas_not_covered_and_why}
|
||||
|
||||
### Blockers
|
||||
- {any_issues_that_prevented_testing_certain_areas}
|
||||
|
||||
---
|
||||
|
||||
## Notes
|
||||
|
||||
{any_additional_observations_or_recommendations}
|
||||
76
skills/market-data/polymarket/SKILL.md
Normal file
76
skills/market-data/polymarket/SKILL.md
Normal file
@@ -0,0 +1,76 @@
|
||||
---
|
||||
name: polymarket
|
||||
description: Query Polymarket prediction market data — search markets, get prices, orderbooks, and price history. Read-only via public REST APIs, no API key needed.
|
||||
version: 1.0.0
|
||||
author: Hermes Agent + Teknium
|
||||
tags: [polymarket, prediction-markets, market-data, trading]
|
||||
---
|
||||
|
||||
# Polymarket — Prediction Market Data
|
||||
|
||||
Query prediction market data from Polymarket using their public REST APIs.
|
||||
All endpoints are read-only and require zero authentication.
|
||||
|
||||
See `references/api-endpoints.md` for the full endpoint reference with curl examples.
|
||||
|
||||
## When to Use
|
||||
|
||||
- User asks about prediction markets, betting odds, or event probabilities
|
||||
- User wants to know "what are the odds of X happening?"
|
||||
- User asks about Polymarket specifically
|
||||
- User wants market prices, orderbook data, or price history
|
||||
- User asks to monitor or track prediction market movements
|
||||
|
||||
## Key Concepts
|
||||
|
||||
- **Events** contain one or more **Markets** (1:many relationship)
|
||||
- **Markets** are binary outcomes with Yes/No prices between 0.00 and 1.00
|
||||
- Prices ARE probabilities: price 0.65 means the market thinks 65% likely
|
||||
- `outcomePrices` field: JSON-encoded array like `["0.80", "0.20"]`
|
||||
- `clobTokenIds` field: JSON-encoded array of two token IDs [Yes, No] for price/book queries
|
||||
- `conditionId` field: hex string used for price history queries
|
||||
- Volume is in USDC (US dollars)
|
||||
|
||||
## Three Public APIs
|
||||
|
||||
1. **Gamma API** at `gamma-api.polymarket.com` — Discovery, search, browsing
|
||||
2. **CLOB API** at `clob.polymarket.com` — Real-time prices, orderbooks, history
|
||||
3. **Data API** at `data-api.polymarket.com` — Trades, open interest
|
||||
|
||||
## Typical Workflow
|
||||
|
||||
When a user asks about prediction market odds:
|
||||
|
||||
1. **Search** using the Gamma API public-search endpoint with their query
|
||||
2. **Parse** the response — extract events and their nested markets
|
||||
3. **Present** market question, current prices as percentages, and volume
|
||||
4. **Deep dive** if asked — use clobTokenIds for orderbook, conditionId for history
|
||||
|
||||
## Presenting Results
|
||||
|
||||
Format prices as percentages for readability:
|
||||
- outcomePrices `["0.652", "0.348"]` becomes "Yes: 65.2%, No: 34.8%"
|
||||
- Always show the market question and probability
|
||||
- Include volume when available
|
||||
|
||||
Example: `"Will X happen?" — 65.2% Yes ($1.2M volume)`
|
||||
|
||||
## Parsing Double-Encoded Fields
|
||||
|
||||
The Gamma API returns `outcomePrices`, `outcomes`, and `clobTokenIds` as JSON strings
|
||||
inside JSON responses (double-encoded). When processing with Python, parse them with
|
||||
`json.loads(market['outcomePrices'])` to get the actual array.
|
||||
|
||||
## Rate Limits
|
||||
|
||||
Generous — unlikely to hit for normal usage:
|
||||
- Gamma: 4,000 requests per 10 seconds (general)
|
||||
- CLOB: 9,000 requests per 10 seconds (general)
|
||||
- Data: 1,000 requests per 10 seconds (general)
|
||||
|
||||
## Limitations
|
||||
|
||||
- This skill is read-only — it does not support placing trades
|
||||
- Trading requires wallet-based crypto authentication (EIP-712 signatures)
|
||||
- Some new markets may have empty price history
|
||||
- Geographic restrictions apply to trading but read-only data is globally accessible
|
||||
220
skills/market-data/polymarket/references/api-endpoints.md
Normal file
220
skills/market-data/polymarket/references/api-endpoints.md
Normal file
@@ -0,0 +1,220 @@
|
||||
# Polymarket API Endpoints Reference
|
||||
|
||||
All endpoints are public REST (GET), return JSON, and need no authentication.
|
||||
|
||||
## Gamma API — gamma-api.polymarket.com
|
||||
|
||||
### Search Markets
|
||||
|
||||
```
|
||||
GET /public-search?q=QUERY
|
||||
```
|
||||
|
||||
Response structure:
|
||||
```json
|
||||
{
|
||||
"events": [
|
||||
{
|
||||
"id": "12345",
|
||||
"title": "Event title",
|
||||
"slug": "event-slug",
|
||||
"volume": 1234567.89,
|
||||
"markets": [
|
||||
{
|
||||
"question": "Will X happen?",
|
||||
"outcomePrices": "[\"0.65\", \"0.35\"]",
|
||||
"outcomes": "[\"Yes\", \"No\"]",
|
||||
"clobTokenIds": "[\"TOKEN_YES\", \"TOKEN_NO\"]",
|
||||
"conditionId": "0xabc...",
|
||||
"volume": 500000
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"pagination": {"hasMore": true, "totalResults": 100}
|
||||
}
|
||||
```
|
||||
|
||||
### List Events
|
||||
|
||||
```
|
||||
GET /events?limit=N&active=true&closed=false&order=volume&ascending=false
|
||||
```
|
||||
|
||||
Parameters:
|
||||
- `limit` — max results (default varies)
|
||||
- `offset` — pagination offset
|
||||
- `active` — true/false
|
||||
- `closed` — true/false
|
||||
- `order` — sort field: `volume`, `createdAt`, `updatedAt`
|
||||
- `ascending` — true/false
|
||||
- `tag` — filter by tag slug
|
||||
- `slug` — get specific event by slug
|
||||
|
||||
Response: array of event objects. Each event includes a `markets` array.
|
||||
|
||||
Event fields: `id`, `title`, `slug`, `description`, `volume`, `liquidity`,
|
||||
`openInterest`, `active`, `closed`, `category`, `startDate`, `endDate`,
|
||||
`markets` (array of market objects).
|
||||
|
||||
### List Markets
|
||||
|
||||
```
|
||||
GET /markets?limit=N&active=true&closed=false&order=volume&ascending=false
|
||||
```
|
||||
|
||||
Same filter parameters as events, plus:
|
||||
- `slug` — get specific market by slug
|
||||
|
||||
Market fields: `id`, `question`, `conditionId`, `slug`, `description`,
|
||||
`outcomes`, `outcomePrices`, `volume`, `liquidity`, `active`, `closed`,
|
||||
`marketType`, `clobTokenIds`, `endDate`, `category`, `createdAt`.
|
||||
|
||||
Important: `outcomePrices`, `outcomes`, and `clobTokenIds` are JSON strings
|
||||
(double-encoded). Parse with json.loads() in Python.
|
||||
|
||||
### List Tags
|
||||
|
||||
```
|
||||
GET /tags
|
||||
```
|
||||
|
||||
Returns array of tag objects: `id`, `label`, `slug`.
|
||||
Use the `slug` value when filtering events/markets by tag.
|
||||
|
||||
---
|
||||
|
||||
## CLOB API — clob.polymarket.com
|
||||
|
||||
All CLOB price endpoints use `token_id` from the market's `clobTokenIds` field.
|
||||
Index 0 = Yes outcome, Index 1 = No outcome.
|
||||
|
||||
### Current Price
|
||||
|
||||
```
|
||||
GET /price?token_id=TOKEN_ID&side=buy
|
||||
```
|
||||
|
||||
Response: `{"price": "0.650"}`
|
||||
|
||||
The `side` parameter: `buy` or `sell`.
|
||||
|
||||
### Midpoint Price
|
||||
|
||||
```
|
||||
GET /midpoint?token_id=TOKEN_ID
|
||||
```
|
||||
|
||||
Response: `{"mid": "0.645"}`
|
||||
|
||||
### Spread
|
||||
|
||||
```
|
||||
GET /spread?token_id=TOKEN_ID
|
||||
```
|
||||
|
||||
Response: `{"spread": "0.02"}`
|
||||
|
||||
### Orderbook
|
||||
|
||||
```
|
||||
GET /book?token_id=TOKEN_ID
|
||||
```
|
||||
|
||||
Response:
|
||||
```json
|
||||
{
|
||||
"market": "condition_id",
|
||||
"asset_id": "token_id",
|
||||
"bids": [{"price": "0.64", "size": "500"}, ...],
|
||||
"asks": [{"price": "0.66", "size": "300"}, ...],
|
||||
"min_order_size": "5",
|
||||
"tick_size": "0.01",
|
||||
"last_trade_price": "0.65"
|
||||
}
|
||||
```
|
||||
|
||||
Bids and asks are sorted by price. Size is in shares (USDC-denominated).
|
||||
|
||||
### Price History
|
||||
|
||||
```
|
||||
GET /prices-history?market=CONDITION_ID&interval=INTERVAL&fidelity=N
|
||||
```
|
||||
|
||||
Parameters:
|
||||
- `market` — the conditionId (hex string with 0x prefix)
|
||||
- `interval` — time range: `all`, `1d`, `1w`, `1m`, `3m`, `6m`, `1y`
|
||||
- `fidelity` — number of data points to return
|
||||
|
||||
Response:
|
||||
```json
|
||||
{
|
||||
"history": [
|
||||
{"t": 1709000000, "p": "0.55"},
|
||||
{"t": 1709100000, "p": "0.58"}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
`t` is Unix timestamp, `p` is price (probability).
|
||||
|
||||
Note: Very new markets may return empty history.
|
||||
|
||||
### CLOB Markets List
|
||||
|
||||
```
|
||||
GET /markets?limit=N
|
||||
```
|
||||
|
||||
Response:
|
||||
```json
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"condition_id": "0xabc...",
|
||||
"question": "Will X?",
|
||||
"tokens": [
|
||||
{"token_id": "123...", "outcome": "Yes", "price": 0.65},
|
||||
{"token_id": "456...", "outcome": "No", "price": 0.35}
|
||||
],
|
||||
"active": true,
|
||||
"closed": false
|
||||
}
|
||||
],
|
||||
"next_cursor": "cursor_string",
|
||||
"limit": 100,
|
||||
"count": 1000
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Data API — data-api.polymarket.com
|
||||
|
||||
### Recent Trades
|
||||
|
||||
```
|
||||
GET /trades?limit=N
|
||||
GET /trades?market=CONDITION_ID&limit=N
|
||||
```
|
||||
|
||||
Trade fields: `side` (BUY/SELL), `size`, `price`, `timestamp`,
|
||||
`title`, `slug`, `outcome`, `transactionHash`, `conditionId`.
|
||||
|
||||
### Open Interest
|
||||
|
||||
```
|
||||
GET /oi?market=CONDITION_ID
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Field Cross-Reference
|
||||
|
||||
To go from a Gamma market to CLOB data:
|
||||
|
||||
1. Get market from Gamma: has `clobTokenIds` and `conditionId`
|
||||
2. Parse `clobTokenIds` (JSON string): `["YES_TOKEN", "NO_TOKEN"]`
|
||||
3. Use YES_TOKEN with `/price`, `/book`, `/midpoint`, `/spread`
|
||||
4. Use `conditionId` with `/prices-history` and Data API endpoints
|
||||
284
skills/market-data/polymarket/scripts/polymarket.py
Normal file
284
skills/market-data/polymarket/scripts/polymarket.py
Normal file
@@ -0,0 +1,284 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Polymarket CLI helper — query prediction market data.
|
||||
|
||||
Usage:
|
||||
python3 polymarket.py search "bitcoin"
|
||||
python3 polymarket.py trending [--limit 10]
|
||||
python3 polymarket.py market <slug>
|
||||
python3 polymarket.py event <slug>
|
||||
python3 polymarket.py price <token_id>
|
||||
python3 polymarket.py book <token_id>
|
||||
python3 polymarket.py history <condition_id> [--interval all] [--fidelity 50]
|
||||
python3 polymarket.py trades [--limit 10] [--market CONDITION_ID]
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import urllib.request
|
||||
import urllib.parse
|
||||
import urllib.error
|
||||
|
||||
GAMMA = "https://gamma-api.polymarket.com"
|
||||
CLOB = "https://clob.polymarket.com"
|
||||
DATA = "https://data-api.polymarket.com"
|
||||
|
||||
|
||||
def _get(url: str) -> dict | list:
|
||||
"""GET request, return parsed JSON."""
|
||||
req = urllib.request.Request(url, headers={"User-Agent": "hermes-agent/1.0"})
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=15) as resp:
|
||||
return json.loads(resp.read().decode())
|
||||
except urllib.error.HTTPError as e:
|
||||
print(f"HTTP {e.code}: {e.reason}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
except urllib.error.URLError as e:
|
||||
print(f"Connection error: {e.reason}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _parse_json_field(val):
|
||||
"""Parse double-encoded JSON fields (outcomePrices, outcomes, clobTokenIds)."""
|
||||
if isinstance(val, str):
|
||||
try:
|
||||
return json.loads(val)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return val
|
||||
return val
|
||||
|
||||
|
||||
def _fmt_pct(price_str: str) -> str:
|
||||
"""Format price string as percentage."""
|
||||
try:
|
||||
return f"{float(price_str) * 100:.1f}%"
|
||||
except (ValueError, TypeError):
|
||||
return price_str
|
||||
|
||||
|
||||
def _fmt_volume(vol) -> str:
|
||||
"""Format volume as human-readable."""
|
||||
try:
|
||||
v = float(vol)
|
||||
if v >= 1_000_000:
|
||||
return f"${v / 1_000_000:.1f}M"
|
||||
if v >= 1_000:
|
||||
return f"${v / 1_000:.1f}K"
|
||||
return f"${v:.0f}"
|
||||
except (ValueError, TypeError):
|
||||
return str(vol)
|
||||
|
||||
|
||||
def _print_market(m: dict, indent: str = ""):
|
||||
"""Print a market summary."""
|
||||
question = m.get("question", "?")
|
||||
prices = _parse_json_field(m.get("outcomePrices", "[]"))
|
||||
outcomes = _parse_json_field(m.get("outcomes", "[]"))
|
||||
vol = _fmt_volume(m.get("volume", 0))
|
||||
closed = m.get("closed", False)
|
||||
status = " [CLOSED]" if closed else ""
|
||||
|
||||
if isinstance(prices, list) and len(prices) >= 2:
|
||||
outcome_labels = outcomes if isinstance(outcomes, list) else ["Yes", "No"]
|
||||
price_str = " / ".join(
|
||||
f"{outcome_labels[i]}: {_fmt_pct(prices[i])}"
|
||||
for i in range(min(len(prices), len(outcome_labels)))
|
||||
)
|
||||
print(f"{indent}{question}{status}")
|
||||
print(f"{indent} {price_str} | Volume: {vol}")
|
||||
else:
|
||||
print(f"{indent}{question}{status} | Volume: {vol}")
|
||||
|
||||
slug = m.get("slug", "")
|
||||
if slug:
|
||||
print(f"{indent} slug: {slug}")
|
||||
|
||||
|
||||
def cmd_search(query: str):
|
||||
"""Search for markets."""
|
||||
q = urllib.parse.quote(query)
|
||||
data = _get(f"{GAMMA}/public-search?q={q}")
|
||||
events = data.get("events", [])
|
||||
total = data.get("pagination", {}).get("totalResults", len(events))
|
||||
print(f"Found {total} results for \"{query}\":\n")
|
||||
for evt in events[:10]:
|
||||
print(f"=== {evt['title']} ===")
|
||||
print(f" Volume: {_fmt_volume(evt.get('volume', 0))} | slug: {evt.get('slug', '')}")
|
||||
markets = evt.get("markets", [])
|
||||
for m in markets[:5]:
|
||||
_print_market(m, indent=" ")
|
||||
if len(markets) > 5:
|
||||
print(f" ... and {len(markets) - 5} more markets")
|
||||
print()
|
||||
|
||||
|
||||
def cmd_trending(limit: int = 10):
|
||||
"""Show trending events by volume."""
|
||||
events = _get(f"{GAMMA}/events?limit={limit}&active=true&closed=false&order=volume&ascending=false")
|
||||
print(f"Top {len(events)} trending events:\n")
|
||||
for i, evt in enumerate(events, 1):
|
||||
print(f"{i}. {evt['title']}")
|
||||
print(f" Volume: {_fmt_volume(evt.get('volume', 0))} | Markets: {len(evt.get('markets', []))}")
|
||||
print(f" slug: {evt.get('slug', '')}")
|
||||
markets = evt.get("markets", [])
|
||||
for m in markets[:3]:
|
||||
_print_market(m, indent=" ")
|
||||
if len(markets) > 3:
|
||||
print(f" ... and {len(markets) - 3} more markets")
|
||||
print()
|
||||
|
||||
|
||||
def cmd_market(slug: str):
|
||||
"""Get market details by slug."""
|
||||
markets = _get(f"{GAMMA}/markets?slug={urllib.parse.quote(slug)}")
|
||||
if not markets:
|
||||
print(f"No market found with slug: {slug}")
|
||||
return
|
||||
m = markets[0]
|
||||
print(f"Market: {m.get('question', '?')}")
|
||||
print(f"Status: {'CLOSED' if m.get('closed') else 'ACTIVE'}")
|
||||
_print_market(m)
|
||||
print(f"\n conditionId: {m.get('conditionId', 'N/A')}")
|
||||
tokens = _parse_json_field(m.get("clobTokenIds", "[]"))
|
||||
if isinstance(tokens, list):
|
||||
outcomes = _parse_json_field(m.get("outcomes", "[]"))
|
||||
for i, t in enumerate(tokens):
|
||||
label = outcomes[i] if isinstance(outcomes, list) and i < len(outcomes) else f"Outcome {i}"
|
||||
print(f" token ({label}): {t}")
|
||||
desc = m.get("description", "")
|
||||
if desc:
|
||||
print(f"\n Description: {desc[:500]}")
|
||||
|
||||
|
||||
def cmd_event(slug: str):
|
||||
"""Get event details by slug."""
|
||||
events = _get(f"{GAMMA}/events?slug={urllib.parse.quote(slug)}")
|
||||
if not events:
|
||||
print(f"No event found with slug: {slug}")
|
||||
return
|
||||
evt = events[0]
|
||||
print(f"Event: {evt['title']}")
|
||||
print(f"Volume: {_fmt_volume(evt.get('volume', 0))}")
|
||||
print(f"Status: {'CLOSED' if evt.get('closed') else 'ACTIVE'}")
|
||||
print(f"Markets: {len(evt.get('markets', []))}\n")
|
||||
for m in evt.get("markets", []):
|
||||
_print_market(m, indent=" ")
|
||||
print()
|
||||
|
||||
|
||||
def cmd_price(token_id: str):
|
||||
"""Get current price for a token."""
|
||||
buy = _get(f"{CLOB}/price?token_id={token_id}&side=buy")
|
||||
mid = _get(f"{CLOB}/midpoint?token_id={token_id}")
|
||||
spread = _get(f"{CLOB}/spread?token_id={token_id}")
|
||||
print(f"Token: {token_id[:30]}...")
|
||||
print(f" Buy price: {_fmt_pct(buy.get('price', '?'))}")
|
||||
print(f" Midpoint: {_fmt_pct(mid.get('mid', '?'))}")
|
||||
print(f" Spread: {spread.get('spread', '?')}")
|
||||
|
||||
|
||||
def cmd_book(token_id: str):
|
||||
"""Get orderbook for a token."""
|
||||
book = _get(f"{CLOB}/book?token_id={token_id}")
|
||||
bids = book.get("bids", [])
|
||||
asks = book.get("asks", [])
|
||||
last = book.get("last_trade_price", "?")
|
||||
print(f"Orderbook for {token_id[:30]}...")
|
||||
print(f"Last trade: {_fmt_pct(last)} | Tick size: {book.get('tick_size', '?')}")
|
||||
print(f"\n Top bids ({len(bids)} total):")
|
||||
# Show bids sorted by price descending (best bids first)
|
||||
sorted_bids = sorted(bids, key=lambda x: float(x.get("price", 0)), reverse=True)
|
||||
for b in sorted_bids[:10]:
|
||||
print(f" {_fmt_pct(b['price']):>7} | Size: {float(b['size']):>10.2f}")
|
||||
print(f"\n Top asks ({len(asks)} total):")
|
||||
sorted_asks = sorted(asks, key=lambda x: float(x.get("price", 0)))
|
||||
for a in sorted_asks[:10]:
|
||||
print(f" {_fmt_pct(a['price']):>7} | Size: {float(a['size']):>10.2f}")
|
||||
|
||||
|
||||
def cmd_history(condition_id: str, interval: str = "all", fidelity: int = 50):
|
||||
"""Get price history for a market."""
|
||||
data = _get(f"{CLOB}/prices-history?market={condition_id}&interval={interval}&fidelity={fidelity}")
|
||||
history = data.get("history", [])
|
||||
if not history:
|
||||
print("No price history available for this market.")
|
||||
return
|
||||
print(f"Price history ({len(history)} points, interval={interval}):\n")
|
||||
from datetime import datetime, timezone
|
||||
for pt in history:
|
||||
ts = datetime.fromtimestamp(pt["t"], tz=timezone.utc).strftime("%Y-%m-%d %H:%M")
|
||||
price = _fmt_pct(pt["p"])
|
||||
bar = "█" * int(float(pt["p"]) * 40)
|
||||
print(f" {ts} {price:>7} {bar}")
|
||||
|
||||
|
||||
def cmd_trades(limit: int = 10, market: str = None):
|
||||
"""Get recent trades."""
|
||||
url = f"{DATA}/trades?limit={limit}"
|
||||
if market:
|
||||
url += f"&market={market}"
|
||||
trades = _get(url)
|
||||
if not isinstance(trades, list):
|
||||
print(f"Unexpected response: {trades}")
|
||||
return
|
||||
print(f"Recent trades ({len(trades)}):\n")
|
||||
for t in trades:
|
||||
side = t.get("side", "?")
|
||||
price = _fmt_pct(t.get("price", "?"))
|
||||
size = t.get("size", "?")
|
||||
outcome = t.get("outcome", "?")
|
||||
title = t.get("title", "?")[:50]
|
||||
ts = t.get("timestamp", "")
|
||||
print(f" {side:4} {price:>7} x{float(size):>8.2f} [{outcome}] {title}")
|
||||
|
||||
|
||||
def main():
|
||||
args = sys.argv[1:]
|
||||
if not args or args[0] in ("-h", "--help", "help"):
|
||||
print(__doc__)
|
||||
return
|
||||
|
||||
cmd = args[0]
|
||||
|
||||
if cmd == "search" and len(args) >= 2:
|
||||
cmd_search(" ".join(args[1:]))
|
||||
elif cmd == "trending":
|
||||
limit = 10
|
||||
if "--limit" in args:
|
||||
idx = args.index("--limit")
|
||||
limit = int(args[idx + 1]) if idx + 1 < len(args) else 10
|
||||
cmd_trending(limit)
|
||||
elif cmd == "market" and len(args) >= 2:
|
||||
cmd_market(args[1])
|
||||
elif cmd == "event" and len(args) >= 2:
|
||||
cmd_event(args[1])
|
||||
elif cmd == "price" and len(args) >= 2:
|
||||
cmd_price(args[1])
|
||||
elif cmd == "book" and len(args) >= 2:
|
||||
cmd_book(args[1])
|
||||
elif cmd == "history" and len(args) >= 2:
|
||||
interval = "all"
|
||||
fidelity = 50
|
||||
if "--interval" in args:
|
||||
idx = args.index("--interval")
|
||||
interval = args[idx + 1] if idx + 1 < len(args) else "all"
|
||||
if "--fidelity" in args:
|
||||
idx = args.index("--fidelity")
|
||||
fidelity = int(args[idx + 1]) if idx + 1 < len(args) else 50
|
||||
cmd_history(args[1], interval, fidelity)
|
||||
elif cmd == "trades":
|
||||
limit = 10
|
||||
market = None
|
||||
if "--limit" in args:
|
||||
idx = args.index("--limit")
|
||||
limit = int(args[idx + 1]) if idx + 1 < len(args) else 10
|
||||
if "--market" in args:
|
||||
idx = args.index("--market")
|
||||
market = args[idx + 1] if idx + 1 < len(args) else None
|
||||
cmd_trades(limit, market)
|
||||
else:
|
||||
print(f"Unknown command: {cmd}")
|
||||
print(__doc__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
335
skills/mlops/accelerate/SKILL.md
Normal file
335
skills/mlops/accelerate/SKILL.md
Normal file
@@ -0,0 +1,335 @@
|
||||
---
|
||||
name: huggingface-accelerate
|
||||
description: Simplest distributed training API. 4 lines to add distributed support to any PyTorch script. Unified API for DeepSpeed/FSDP/Megatron/DDP. Automatic device placement, mixed precision (FP16/BF16/FP8). Interactive config, single launch command. HuggingFace ecosystem standard.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [accelerate, torch, transformers]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Distributed Training, HuggingFace, Accelerate, DeepSpeed, FSDP, Mixed Precision, PyTorch, DDP, Unified API, Simple]
|
||||
|
||||
---
|
||||
|
||||
# HuggingFace Accelerate - Unified Distributed Training
|
||||
|
||||
## Quick start
|
||||
|
||||
Accelerate simplifies distributed training to 4 lines of code.
|
||||
|
||||
**Installation**:
|
||||
```bash
|
||||
pip install accelerate
|
||||
```
|
||||
|
||||
**Convert PyTorch script** (4 lines):
|
||||
```python
|
||||
import torch
|
||||
+ from accelerate import Accelerator
|
||||
|
||||
+ accelerator = Accelerator()
|
||||
|
||||
model = torch.nn.Transformer()
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
dataloader = torch.utils.data.DataLoader(dataset)
|
||||
|
||||
+ model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
for batch in dataloader:
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch)
|
||||
- loss.backward()
|
||||
+ accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Run** (single command):
|
||||
```bash
|
||||
accelerate launch train.py
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: From single GPU to multi-GPU
|
||||
|
||||
**Original script**:
|
||||
```python
|
||||
# train.py
|
||||
import torch
|
||||
|
||||
model = torch.nn.Linear(10, 2).to('cuda')
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
|
||||
|
||||
for epoch in range(10):
|
||||
for batch in dataloader:
|
||||
batch = batch.to('cuda')
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch).mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**With Accelerate** (4 lines added):
|
||||
```python
|
||||
# train.py
|
||||
import torch
|
||||
from accelerate import Accelerator # +1
|
||||
|
||||
accelerator = Accelerator() # +2
|
||||
|
||||
model = torch.nn.Linear(10, 2)
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) # +3
|
||||
|
||||
for epoch in range(10):
|
||||
for batch in dataloader:
|
||||
# No .to('cuda') needed - automatic!
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch).mean()
|
||||
accelerator.backward(loss) # +4
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Configure** (interactive):
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
**Questions**:
|
||||
- Which machine? (single/multi GPU/TPU/CPU)
|
||||
- How many machines? (1)
|
||||
- Mixed precision? (no/fp16/bf16/fp8)
|
||||
- DeepSpeed? (no/yes)
|
||||
|
||||
**Launch** (works on any setup):
|
||||
```bash
|
||||
# Single GPU
|
||||
accelerate launch train.py
|
||||
|
||||
# Multi-GPU (8 GPUs)
|
||||
accelerate launch --multi_gpu --num_processes 8 train.py
|
||||
|
||||
# Multi-node
|
||||
accelerate launch --multi_gpu --num_processes 16 \
|
||||
--num_machines 2 --machine_rank 0 \
|
||||
--main_process_ip $MASTER_ADDR \
|
||||
train.py
|
||||
```
|
||||
|
||||
### Workflow 2: Mixed precision training
|
||||
|
||||
**Enable FP16/BF16**:
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
# FP16 (with gradient scaling)
|
||||
accelerator = Accelerator(mixed_precision='fp16')
|
||||
|
||||
# BF16 (no scaling, more stable)
|
||||
accelerator = Accelerator(mixed_precision='bf16')
|
||||
|
||||
# FP8 (H100+)
|
||||
accelerator = Accelerator(mixed_precision='fp8')
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
# Everything else is automatic!
|
||||
for batch in dataloader:
|
||||
with accelerator.autocast(): # Optional, done automatically
|
||||
loss = model(batch)
|
||||
accelerator.backward(loss)
|
||||
```
|
||||
|
||||
### Workflow 3: DeepSpeed ZeRO integration
|
||||
|
||||
**Enable DeepSpeed ZeRO-2**:
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='bf16',
|
||||
deepspeed_plugin={
|
||||
"zero_stage": 2, # ZeRO-2
|
||||
"offload_optimizer": False,
|
||||
"gradient_accumulation_steps": 4
|
||||
}
|
||||
)
|
||||
|
||||
# Same code as before!
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
```
|
||||
|
||||
**Or via config**:
|
||||
```bash
|
||||
accelerate config
|
||||
# Select: DeepSpeed → ZeRO-2
|
||||
```
|
||||
|
||||
**deepspeed_config.json**:
|
||||
```json
|
||||
{
|
||||
"fp16": {"enabled": false},
|
||||
"bf16": {"enabled": true},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {"device": "cpu"},
|
||||
"allgather_bucket_size": 5e8,
|
||||
"reduce_bucket_size": 5e8
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Launch**:
|
||||
```bash
|
||||
accelerate launch --config_file deepspeed_config.json train.py
|
||||
```
|
||||
|
||||
### Workflow 4: FSDP (Fully Sharded Data Parallel)
|
||||
|
||||
**Enable FSDP**:
|
||||
```python
|
||||
from accelerate import Accelerator, FullyShardedDataParallelPlugin
|
||||
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
sharding_strategy="FULL_SHARD", # ZeRO-3 equivalent
|
||||
auto_wrap_policy="TRANSFORMER_AUTO_WRAP",
|
||||
cpu_offload=False
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='bf16',
|
||||
fsdp_plugin=fsdp_plugin
|
||||
)
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
```
|
||||
|
||||
**Or via config**:
|
||||
```bash
|
||||
accelerate config
|
||||
# Select: FSDP → Full Shard → No CPU Offload
|
||||
```
|
||||
|
||||
### Workflow 5: Gradient accumulation
|
||||
|
||||
**Accumulate gradients**:
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
accelerator = Accelerator(gradient_accumulation_steps=4)
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
|
||||
|
||||
for batch in dataloader:
|
||||
with accelerator.accumulate(model): # Handles accumulation
|
||||
optimizer.zero_grad()
|
||||
loss = model(batch)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Effective batch size**: `batch_size * num_gpus * gradient_accumulation_steps`
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use Accelerate when**:
|
||||
- Want simplest distributed training
|
||||
- Need single script for any hardware
|
||||
- Use HuggingFace ecosystem
|
||||
- Want flexibility (DDP/DeepSpeed/FSDP/Megatron)
|
||||
- Need quick prototyping
|
||||
|
||||
**Key advantages**:
|
||||
- **4 lines**: Minimal code changes
|
||||
- **Unified API**: Same code for DDP, DeepSpeed, FSDP, Megatron
|
||||
- **Automatic**: Device placement, mixed precision, sharding
|
||||
- **Interactive config**: No manual launcher setup
|
||||
- **Single launch**: Works everywhere
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **PyTorch Lightning**: Need callbacks, high-level abstractions
|
||||
- **Ray Train**: Multi-node orchestration, hyperparameter tuning
|
||||
- **DeepSpeed**: Direct API control, advanced features
|
||||
- **Raw DDP**: Maximum control, minimal abstraction
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: Wrong device placement**
|
||||
|
||||
Don't manually move to device:
|
||||
```python
|
||||
# WRONG
|
||||
batch = batch.to('cuda')
|
||||
|
||||
# CORRECT
|
||||
# Accelerate handles it automatically after prepare()
|
||||
```
|
||||
|
||||
**Issue: Gradient accumulation not working**
|
||||
|
||||
Use context manager:
|
||||
```python
|
||||
# CORRECT
|
||||
with accelerator.accumulate(model):
|
||||
optimizer.zero_grad()
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**Issue: Checkpointing in distributed**
|
||||
|
||||
Use accelerator methods:
|
||||
```python
|
||||
# Save only on main process
|
||||
if accelerator.is_main_process:
|
||||
accelerator.save_state('checkpoint/')
|
||||
|
||||
# Load on all processes
|
||||
accelerator.load_state('checkpoint/')
|
||||
```
|
||||
|
||||
**Issue: Different results with FSDP**
|
||||
|
||||
Ensure same random seed:
|
||||
```python
|
||||
from accelerate.utils import set_seed
|
||||
set_seed(42)
|
||||
```
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Megatron integration**: See [references/megatron-integration.md](references/megatron-integration.md) for tensor parallelism, pipeline parallelism, and sequence parallelism setup.
|
||||
|
||||
**Custom plugins**: See [references/custom-plugins.md](references/custom-plugins.md) for creating custom distributed plugins and advanced configuration.
|
||||
|
||||
**Performance tuning**: See [references/performance.md](references/performance.md) for profiling, memory optimization, and best practices.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **CPU**: Works (slow)
|
||||
- **Single GPU**: Works
|
||||
- **Multi-GPU**: DDP (default), DeepSpeed, or FSDP
|
||||
- **Multi-node**: DDP, DeepSpeed, FSDP, Megatron
|
||||
- **TPU**: Supported
|
||||
- **Apple MPS**: Supported
|
||||
|
||||
**Launcher requirements**:
|
||||
- **DDP**: `torch.distributed.run` (built-in)
|
||||
- **DeepSpeed**: `deepspeed` (pip install deepspeed)
|
||||
- **FSDP**: PyTorch 1.12+ (built-in)
|
||||
- **Megatron**: Custom setup
|
||||
|
||||
## Resources
|
||||
|
||||
- Docs: https://huggingface.co/docs/accelerate
|
||||
- GitHub: https://github.com/huggingface/accelerate
|
||||
- Version: 1.11.0+
|
||||
- Tutorial: "Accelerate your scripts"
|
||||
- Examples: https://github.com/huggingface/accelerate/tree/main/examples
|
||||
- Used by: HuggingFace Transformers, TRL, PEFT, all HF libraries
|
||||
|
||||
|
||||
|
||||
453
skills/mlops/accelerate/references/custom-plugins.md
Normal file
453
skills/mlops/accelerate/references/custom-plugins.md
Normal file
@@ -0,0 +1,453 @@
|
||||
# Custom Plugins for Accelerate
|
||||
|
||||
## Overview
|
||||
|
||||
Accelerate allows creating **custom plugins** to extend distributed training strategies beyond built-in options (DDP, FSDP, DeepSpeed).
|
||||
|
||||
## Plugin Architecture
|
||||
|
||||
### Base Plugin Structure
|
||||
|
||||
```python
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
"""Custom training plugin."""
|
||||
|
||||
# Plugin configuration
|
||||
param1: int = 1
|
||||
param2: str = "default"
|
||||
|
||||
def __post_init__(self):
|
||||
# Validation logic
|
||||
if self.param1 < 1:
|
||||
raise ValueError("param1 must be >= 1")
|
||||
```
|
||||
|
||||
### Using Custom Plugin
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
# Create plugin
|
||||
custom_plugin = CustomPlugin(param1=4, param2="value")
|
||||
|
||||
# Pass to Accelerator
|
||||
accelerator = Accelerator(
|
||||
custom_plugin=custom_plugin # Not a real parameter, example only
|
||||
)
|
||||
```
|
||||
|
||||
## Built-In Plugin Examples
|
||||
|
||||
### 1. GradScalerKwargs (FP16 Configuration)
|
||||
|
||||
```python
|
||||
from accelerate.utils import GradScalerKwargs
|
||||
|
||||
# Configure gradient scaler for FP16
|
||||
scaler_kwargs = GradScalerKwargs(
|
||||
init_scale=2.**16, # Initial loss scale
|
||||
growth_factor=2.0, # Scale growth rate
|
||||
backoff_factor=0.5, # Scale backoff rate
|
||||
growth_interval=2000, # Steps between scale increases
|
||||
enabled=True # Enable scaler
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='fp16',
|
||||
kwargs_handlers=[scaler_kwargs] # Pass as kwargs handler
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Fine-tune FP16 gradient scaling behavior
|
||||
|
||||
### 2. DistributedDataParallelKwargs
|
||||
|
||||
```python
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
# Configure DDP behavior
|
||||
ddp_kwargs = DistributedDataParallelKwargs(
|
||||
bucket_cap_mb=25, # Gradient bucketing size
|
||||
find_unused_parameters=False, # Find unused params (slower)
|
||||
check_reduction=False, # Check gradient reduction
|
||||
gradient_as_bucket_view=True, # Memory optimization
|
||||
static_graph=False # Static computation graph
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
kwargs_handlers=[ddp_kwargs]
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Optimize DDP performance for specific models
|
||||
|
||||
### 3. FP8RecipeKwargs (H100 FP8)
|
||||
|
||||
```python
|
||||
from accelerate.utils import FP8RecipeKwargs
|
||||
|
||||
# Configure FP8 training (H100)
|
||||
fp8_recipe = FP8RecipeKwargs(
|
||||
backend="te", # TransformerEngine backend
|
||||
margin=0, # Scaling margin
|
||||
interval=1, # Scaling interval
|
||||
fp8_format="HYBRID", # E4M3 + E5M2 hybrid
|
||||
amax_history_len=1024, # AMAX history length
|
||||
amax_compute_algo="max" # AMAX computation algorithm
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='fp8',
|
||||
kwargs_handlers=[fp8_recipe]
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Ultra-fast training on H100 GPUs
|
||||
|
||||
## Custom DeepSpeed Configuration
|
||||
|
||||
### ZeRO-3 with CPU Offload
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import DeepSpeedPlugin
|
||||
|
||||
# Custom DeepSpeed config
|
||||
ds_plugin = DeepSpeedPlugin(
|
||||
zero_stage=3, # ZeRO-3
|
||||
offload_optimizer_device="cpu", # CPU offload optimizer
|
||||
offload_param_device="cpu", # CPU offload parameters
|
||||
zero3_init_flag=True, # ZeRO-3 initialization
|
||||
zero3_save_16bit_model=True, # Save FP16 weights
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
deepspeed_plugin=ds_plugin,
|
||||
mixed_precision='bf16'
|
||||
)
|
||||
```
|
||||
|
||||
### ZeRO-2 with NVMe Offload
|
||||
|
||||
```python
|
||||
ds_plugin = DeepSpeedPlugin(
|
||||
zero_stage=2,
|
||||
offload_optimizer_device="nvme", # NVMe offload
|
||||
offload_param_device="nvme",
|
||||
nvme_path="/local_nvme", # NVMe mount path
|
||||
)
|
||||
```
|
||||
|
||||
### Custom JSON Config
|
||||
|
||||
```python
|
||||
import json
|
||||
|
||||
# Load custom DeepSpeed config
|
||||
with open('deepspeed_config.json', 'r') as f:
|
||||
ds_config = json.load(f)
|
||||
|
||||
ds_plugin = DeepSpeedPlugin(hf_ds_config=ds_config)
|
||||
|
||||
accelerator = Accelerator(deepspeed_plugin=ds_plugin)
|
||||
```
|
||||
|
||||
**Example config** (`deepspeed_config.json`):
|
||||
```json
|
||||
{
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": 1.0,
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 1e9,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"stage3_prefetch_bucket_size": 5e8,
|
||||
"stage3_param_persistence_threshold": 1e6,
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": true
|
||||
},
|
||||
"steps_per_print": 100,
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
```
|
||||
|
||||
## Custom FSDP Configuration
|
||||
|
||||
### FSDP with Custom Auto-Wrap Policy
|
||||
|
||||
```python
|
||||
from accelerate.utils import FullyShardedDataParallelPlugin
|
||||
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
|
||||
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
|
||||
import functools
|
||||
|
||||
# Custom wrap policy (size-based)
|
||||
wrap_policy = functools.partial(
|
||||
size_based_auto_wrap_policy,
|
||||
min_num_params=1e6 # Wrap layers with 1M+ params
|
||||
)
|
||||
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO-3 equivalent
|
||||
backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # Prefetch strategy
|
||||
mixed_precision_policy=None, # Use Accelerator's mixed precision
|
||||
auto_wrap_policy=wrap_policy, # Custom wrapping
|
||||
cpu_offload=False,
|
||||
ignored_modules=None, # Modules to not wrap
|
||||
state_dict_type="FULL_STATE_DICT", # Save format
|
||||
optim_state_dict_config=None,
|
||||
limit_all_gathers=False,
|
||||
use_orig_params=True, # Use original param shapes
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
fsdp_plugin=fsdp_plugin,
|
||||
mixed_precision='bf16'
|
||||
)
|
||||
```
|
||||
|
||||
### FSDP with Transformer Auto-Wrap
|
||||
|
||||
```python
|
||||
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
|
||||
|
||||
# Wrap at transformer block level
|
||||
wrap_policy = functools.partial(
|
||||
transformer_auto_wrap_policy,
|
||||
transformer_layer_cls={GPT2Block} # Wrap GPT2Block layers
|
||||
)
|
||||
|
||||
fsdp_plugin = FullyShardedDataParallelPlugin(
|
||||
auto_wrap_policy=wrap_policy
|
||||
)
|
||||
```
|
||||
|
||||
## Creating Custom Training Strategy
|
||||
|
||||
### Example: Custom Gradient Accumulation
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
|
||||
class CustomGradientAccumulation:
|
||||
def __init__(self, steps=4, adaptive=False):
|
||||
self.steps = steps
|
||||
self.adaptive = adaptive
|
||||
self.current_step = 0
|
||||
|
||||
def should_sync(self, loss):
|
||||
"""Decide whether to sync gradients."""
|
||||
self.current_step += 1
|
||||
|
||||
# Adaptive: sync on high loss
|
||||
if self.adaptive and loss > threshold:
|
||||
self.current_step = 0
|
||||
return True
|
||||
|
||||
# Regular: sync every N steps
|
||||
if self.current_step >= self.steps:
|
||||
self.current_step = 0
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# Usage
|
||||
custom_accum = CustomGradientAccumulation(steps=8, adaptive=True)
|
||||
accelerator = Accelerator()
|
||||
|
||||
for batch in dataloader:
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
|
||||
# Scale loss
|
||||
loss = loss / custom_accum.steps
|
||||
accelerator.backward(loss)
|
||||
|
||||
# Conditional sync
|
||||
if custom_accum.should_sync(loss.item()):
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
### Example: Custom Mixed Precision
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class CustomMixedPrecision:
|
||||
"""Custom mixed precision with dynamic loss scaling."""
|
||||
|
||||
def __init__(self, init_scale=2**16, scale_window=2000):
|
||||
self.scaler = torch.cuda.amp.GradScaler(
|
||||
init_scale=init_scale,
|
||||
growth_interval=scale_window
|
||||
)
|
||||
self.scale_history = []
|
||||
|
||||
def scale_loss(self, loss):
|
||||
"""Scale loss for backward."""
|
||||
return self.scaler.scale(loss)
|
||||
|
||||
def unscale_and_clip(self, optimizer, max_norm=1.0):
|
||||
"""Unscale gradients and clip."""
|
||||
self.scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
optimizer.param_groups[0]['params'],
|
||||
max_norm
|
||||
)
|
||||
|
||||
def step(self, optimizer):
|
||||
"""Optimizer step with scaler update."""
|
||||
scale_before = self.scaler.get_scale()
|
||||
self.scaler.step(optimizer)
|
||||
self.scaler.update()
|
||||
scale_after = self.scaler.get_scale()
|
||||
|
||||
# Track scale changes
|
||||
if scale_before != scale_after:
|
||||
self.scale_history.append(scale_after)
|
||||
|
||||
# Usage
|
||||
custom_mp = CustomMixedPrecision()
|
||||
|
||||
for batch in dataloader:
|
||||
with torch.cuda.amp.autocast(dtype=torch.float16):
|
||||
loss = model(**batch).loss
|
||||
|
||||
scaled_loss = custom_mp.scale_loss(loss)
|
||||
scaled_loss.backward()
|
||||
|
||||
custom_mp.unscale_and_clip(optimizer, max_norm=1.0)
|
||||
custom_mp.step(optimizer)
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
## Advanced: Custom Distributed Backend
|
||||
|
||||
### Custom AllReduce Strategy
|
||||
|
||||
```python
|
||||
import torch.distributed as dist
|
||||
|
||||
class CustomAllReduce:
|
||||
"""Custom all-reduce with compression."""
|
||||
|
||||
def __init__(self, compression_ratio=0.1):
|
||||
self.compression_ratio = compression_ratio
|
||||
|
||||
def compress_gradients(self, tensor):
|
||||
"""Top-k gradient compression."""
|
||||
k = int(tensor.numel() * self.compression_ratio)
|
||||
values, indices = torch.topk(tensor.abs().view(-1), k)
|
||||
return values, indices
|
||||
|
||||
def all_reduce_compressed(self, tensor):
|
||||
"""All-reduce with gradient compression."""
|
||||
# Compress
|
||||
values, indices = self.compress_gradients(tensor)
|
||||
|
||||
# All-reduce compressed gradients
|
||||
dist.all_reduce(values, op=dist.ReduceOp.SUM)
|
||||
|
||||
# Decompress
|
||||
tensor_compressed = torch.zeros_like(tensor).view(-1)
|
||||
tensor_compressed[indices] = values / dist.get_world_size()
|
||||
|
||||
return tensor_compressed.view_as(tensor)
|
||||
|
||||
# Usage in training loop
|
||||
custom_ar = CustomAllReduce(compression_ratio=0.1)
|
||||
|
||||
for batch in dataloader:
|
||||
loss = model(**batch).loss
|
||||
loss.backward()
|
||||
|
||||
# Custom all-reduce
|
||||
for param in model.parameters():
|
||||
if param.grad is not None:
|
||||
param.grad.data = custom_ar.all_reduce_compressed(param.grad.data)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
## Plugin Best Practices
|
||||
|
||||
### 1. Validation in `__post_init__`
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
learning_rate: float = 1e-3
|
||||
warmup_steps: int = 1000
|
||||
|
||||
def __post_init__(self):
|
||||
# Validate parameters
|
||||
if self.learning_rate <= 0:
|
||||
raise ValueError("learning_rate must be positive")
|
||||
if self.warmup_steps < 0:
|
||||
raise ValueError("warmup_steps must be non-negative")
|
||||
|
||||
# Compute derived values
|
||||
self.min_lr = self.learning_rate * 0.1
|
||||
```
|
||||
|
||||
### 2. Compatibility Checks
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
feature_enabled: bool = True
|
||||
|
||||
def is_compatible(self, accelerator):
|
||||
"""Check if plugin is compatible with accelerator config."""
|
||||
if self.feature_enabled and accelerator.mixed_precision == 'fp8':
|
||||
raise ValueError("Custom plugin not compatible with FP8")
|
||||
return True
|
||||
```
|
||||
|
||||
### 3. State Management
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CustomPlugin:
|
||||
counter: int = 0
|
||||
history: list = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.history is None:
|
||||
self.history = []
|
||||
|
||||
def update_state(self, value):
|
||||
"""Update plugin state during training."""
|
||||
self.counter += 1
|
||||
self.history.append(value)
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Accelerate Plugins: https://huggingface.co/docs/accelerate/package_reference/kwargs
|
||||
- DeepSpeed Config: https://www.deepspeed.ai/docs/config-json/
|
||||
- FSDP Guide: https://pytorch.org/docs/stable/fsdp.html
|
||||
- Custom Training Loops: https://huggingface.co/docs/accelerate/usage_guides/training_tpu
|
||||
489
skills/mlops/accelerate/references/megatron-integration.md
Normal file
489
skills/mlops/accelerate/references/megatron-integration.md
Normal file
@@ -0,0 +1,489 @@
|
||||
# Megatron Integration with Accelerate
|
||||
|
||||
## Overview
|
||||
|
||||
Accelerate supports Megatron-LM for massive model training with tensor parallelism and pipeline parallelism.
|
||||
|
||||
**Megatron capabilities**:
|
||||
- **Tensor Parallelism (TP)**: Split layers across GPUs
|
||||
- **Pipeline Parallelism (PP)**: Split model depth across GPUs
|
||||
- **Data Parallelism (DP)**: Replicate model across GPU groups
|
||||
- **Sequence Parallelism**: Split sequences for long contexts
|
||||
|
||||
## Setup
|
||||
|
||||
### Install Megatron-LM
|
||||
|
||||
```bash
|
||||
# Clone Megatron-LM repository
|
||||
git clone https://github.com/NVIDIA/Megatron-LM.git
|
||||
cd Megatron-LM
|
||||
pip install -e .
|
||||
|
||||
# Install Apex (NVIDIA optimizations)
|
||||
git clone https://github.com/NVIDIA/apex
|
||||
cd apex
|
||||
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
|
||||
--config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
|
||||
```
|
||||
|
||||
### Accelerate Configuration
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
**Questions**:
|
||||
```
|
||||
In which compute environment are you running?
|
||||
> This machine
|
||||
|
||||
Which type of machine are you using?
|
||||
> Multi-GPU
|
||||
|
||||
How many different machines will you use?
|
||||
> 1
|
||||
|
||||
Do you want to use DeepSpeed/FSDP?
|
||||
> No
|
||||
|
||||
Do you want to use Megatron-LM?
|
||||
> Yes
|
||||
|
||||
What is the Tensor Parallelism degree? [1-8]
|
||||
> 2
|
||||
|
||||
Do you want to enable Sequence Parallelism?
|
||||
> No
|
||||
|
||||
What is the Pipeline Parallelism degree? [1-8]
|
||||
> 2
|
||||
|
||||
What is the Data Parallelism degree? [1-8]
|
||||
> 2
|
||||
|
||||
Where to perform activation checkpointing? ['SELECTIVE', 'FULL', 'NONE']
|
||||
> SELECTIVE
|
||||
|
||||
Where to perform activation partitioning? ['SEQUENTIAL', 'UNIFORM']
|
||||
> SEQUENTIAL
|
||||
```
|
||||
|
||||
**Generated config** (`~/.cache/huggingface/accelerate/default_config.yaml`):
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
distributed_type: MEGATRON_LM
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
megatron_lm_config:
|
||||
megatron_lm_gradient_clipping: 1.0
|
||||
megatron_lm_learning_rate_decay_iters: 320000
|
||||
megatron_lm_num_micro_batches: 1
|
||||
megatron_lm_pp_degree: 2
|
||||
megatron_lm_recompute_activations: true
|
||||
megatron_lm_sequence_parallelism: false
|
||||
megatron_lm_tp_degree: 2
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
```
|
||||
|
||||
## Parallelism Strategies
|
||||
|
||||
### Tensor Parallelism (TP)
|
||||
|
||||
**Splits each transformer layer across GPUs**:
|
||||
|
||||
```python
|
||||
# Layer split across 2 GPUs
|
||||
# GPU 0: First half of attention heads
|
||||
# GPU 1: Second half of attention heads
|
||||
|
||||
# Each GPU computes partial outputs
|
||||
# All-reduce combines results
|
||||
```
|
||||
|
||||
**TP degree recommendations**:
|
||||
- **TP=1**: No tensor parallelism (single GPU per layer)
|
||||
- **TP=2**: 2 GPUs per layer (good for 7-13B models)
|
||||
- **TP=4**: 4 GPUs per layer (good for 20-40B models)
|
||||
- **TP=8**: 8 GPUs per layer (good for 70B+ models)
|
||||
|
||||
**Benefits**:
|
||||
- Reduces memory per GPU
|
||||
- All-reduce communication (fast)
|
||||
|
||||
**Drawbacks**:
|
||||
- Requires fast inter-GPU bandwidth (NVLink)
|
||||
- Communication overhead per layer
|
||||
|
||||
### Pipeline Parallelism (PP)
|
||||
|
||||
**Splits model depth across GPUs**:
|
||||
|
||||
```python
|
||||
# 12-layer model, PP=4
|
||||
# GPU 0: Layers 0-2
|
||||
# GPU 1: Layers 3-5
|
||||
# GPU 2: Layers 6-8
|
||||
# GPU 3: Layers 9-11
|
||||
```
|
||||
|
||||
**PP degree recommendations**:
|
||||
- **PP=1**: No pipeline parallelism
|
||||
- **PP=2**: 2 pipeline stages (good for 20-40B models)
|
||||
- **PP=4**: 4 pipeline stages (good for 70B+ models)
|
||||
- **PP=8**: 8 pipeline stages (good for 175B+ models)
|
||||
|
||||
**Benefits**:
|
||||
- Linear memory reduction (4× PP = 4× less memory)
|
||||
- Works across nodes (slower interconnect OK)
|
||||
|
||||
**Drawbacks**:
|
||||
- Pipeline bubbles (idle time)
|
||||
- Requires micro-batching
|
||||
|
||||
### Data Parallelism (DP)
|
||||
|
||||
**Replicates model across GPU groups**:
|
||||
|
||||
```python
|
||||
# 8 GPUs, TP=2, PP=2, DP=2
|
||||
# Group 0 (GPUs 0-3): Full model replica
|
||||
# Group 1 (GPUs 4-7): Full model replica
|
||||
```
|
||||
|
||||
**DP degree**:
|
||||
- `DP = total_gpus / (TP × PP)`
|
||||
- Example: 8 GPUs, TP=2, PP=2 → DP=2
|
||||
|
||||
**Benefits**:
|
||||
- Increases throughput
|
||||
- Scales batch size
|
||||
|
||||
### Sequence Parallelism
|
||||
|
||||
**Splits long sequences across GPUs** (extends TP):
|
||||
|
||||
```python
|
||||
# 8K sequence, TP=2, Sequence Parallel=True
|
||||
# GPU 0: Tokens 0-4095
|
||||
# GPU 1: Tokens 4096-8191
|
||||
```
|
||||
|
||||
**Benefits**:
|
||||
- Enables very long sequences (100K+ tokens)
|
||||
- Reduces activation memory
|
||||
|
||||
**Requirements**:
|
||||
- Must use with TP > 1
|
||||
- RoPE/ALiBi position encodings work best
|
||||
|
||||
## Accelerate Code Example
|
||||
|
||||
### Basic Setup
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import MegatronLMPlugin
|
||||
|
||||
# Configure Megatron
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
tp_degree=2, # Tensor parallelism degree
|
||||
pp_degree=2, # Pipeline parallelism degree
|
||||
num_micro_batches=4, # Micro-batches for pipeline
|
||||
gradient_clipping=1.0, # Gradient clipping value
|
||||
sequence_parallelism=False, # Enable sequence parallelism
|
||||
recompute_activations=True, # Activation checkpointing
|
||||
use_distributed_optimizer=True, # Distributed optimizer
|
||||
custom_prepare_model_function=None, # Custom model prep
|
||||
)
|
||||
|
||||
# Initialize accelerator
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='bf16',
|
||||
megatron_lm_plugin=megatron_plugin
|
||||
)
|
||||
|
||||
# Prepare model and optimizer
|
||||
model, optimizer, train_dataloader = accelerator.prepare(
|
||||
model, optimizer, train_dataloader
|
||||
)
|
||||
|
||||
# Training loop (same as DDP!)
|
||||
for batch in train_dataloader:
|
||||
optimizer.zero_grad()
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
### Full Training Script
|
||||
|
||||
```python
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import MegatronLMPlugin
|
||||
from transformers import GPT2Config, GPT2LMHeadModel
|
||||
|
||||
def main():
|
||||
# Megatron configuration
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
tp_degree=2,
|
||||
pp_degree=2,
|
||||
num_micro_batches=4,
|
||||
gradient_clipping=1.0,
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='bf16',
|
||||
gradient_accumulation_steps=8,
|
||||
megatron_lm_plugin=megatron_plugin
|
||||
)
|
||||
|
||||
# Model
|
||||
config = GPT2Config(
|
||||
n_layer=24,
|
||||
n_head=16,
|
||||
n_embd=1024,
|
||||
)
|
||||
model = GPT2LMHeadModel(config)
|
||||
|
||||
# Optimizer
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)
|
||||
|
||||
# Prepare
|
||||
model, optimizer, train_loader = accelerator.prepare(
|
||||
model, optimizer, train_loader
|
||||
)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(num_epochs):
|
||||
for batch in train_loader:
|
||||
with accelerator.accumulate(model):
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Save checkpoint
|
||||
accelerator.wait_for_everyone()
|
||||
accelerator.save_state(f'checkpoint-epoch-{epoch}')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
```
|
||||
|
||||
### Launch Command
|
||||
|
||||
```bash
|
||||
# 8 GPUs, TP=2, PP=2, DP=2
|
||||
accelerate launch --multi_gpu --num_processes 8 train.py
|
||||
|
||||
# Multi-node (2 nodes, 8 GPUs each)
|
||||
# Node 0
|
||||
accelerate launch --multi_gpu --num_processes 16 \
|
||||
--num_machines 2 --machine_rank 0 \
|
||||
--main_process_ip $MASTER_ADDR \
|
||||
--main_process_port 29500 \
|
||||
train.py
|
||||
|
||||
# Node 1
|
||||
accelerate launch --multi_gpu --num_processes 16 \
|
||||
--num_machines 2 --machine_rank 1 \
|
||||
--main_process_ip $MASTER_ADDR \
|
||||
--main_process_port 29500 \
|
||||
train.py
|
||||
```
|
||||
|
||||
## Activation Checkpointing
|
||||
|
||||
**Reduces memory by recomputing activations**:
|
||||
|
||||
```python
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
recompute_activations=True, # Enable checkpointing
|
||||
checkpoint_num_layers=1, # Checkpoint every N layers
|
||||
distribute_checkpointed_activations=True, # Distribute across TP
|
||||
partition_activations=True, # Partition in PP
|
||||
check_for_nan_in_loss_and_grad=True, # Stability check
|
||||
)
|
||||
```
|
||||
|
||||
**Strategies**:
|
||||
- `SELECTIVE`: Checkpoint transformer blocks only
|
||||
- `FULL`: Checkpoint all layers
|
||||
- `NONE`: No checkpointing
|
||||
|
||||
**Memory savings**: 30-50% with 10-15% slowdown
|
||||
|
||||
## Distributed Optimizer
|
||||
|
||||
**Shards optimizer state across DP ranks**:
|
||||
|
||||
```python
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
use_distributed_optimizer=True, # Enable sharded optimizer
|
||||
)
|
||||
```
|
||||
|
||||
**Benefits**:
|
||||
- Reduces optimizer memory by DP degree
|
||||
- Example: DP=4 → 4× less optimizer memory per GPU
|
||||
|
||||
**Compatible with**:
|
||||
- AdamW, Adam, SGD
|
||||
- Mixed precision training
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### Micro-Batch Size
|
||||
|
||||
```python
|
||||
# Pipeline parallelism requires micro-batching
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
pp_degree=4,
|
||||
num_micro_batches=16, # 16 micro-batches per pipeline
|
||||
)
|
||||
|
||||
# Effective batch = num_micro_batches × micro_batch_size × DP
|
||||
# Example: 16 × 2 × 4 = 128
|
||||
```
|
||||
|
||||
**Recommendations**:
|
||||
- More micro-batches → less pipeline bubble
|
||||
- Typical: 4-16 micro-batches
|
||||
|
||||
### Sequence Length
|
||||
|
||||
```python
|
||||
# For long sequences, enable sequence parallelism
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
tp_degree=4,
|
||||
sequence_parallelism=True, # Required: TP > 1
|
||||
)
|
||||
|
||||
# Enables sequences up to TP × normal limit
|
||||
# Example: TP=4, 8K normal → 32K with sequence parallel
|
||||
```
|
||||
|
||||
### GPU Topology
|
||||
|
||||
**NVLink required for TP**:
|
||||
```bash
|
||||
# Check NVLink topology
|
||||
nvidia-smi topo -m
|
||||
|
||||
# Good topology (NVLink between all GPUs)
|
||||
# GPU0 - GPU1: NV12 (fast)
|
||||
# GPU0 - GPU2: NV12 (fast)
|
||||
|
||||
# Bad topology (PCIe only)
|
||||
# GPU0 - GPU4: PHB (slow, avoid TP across these)
|
||||
```
|
||||
|
||||
**Recommendations**:
|
||||
- **TP**: Within same node (NVLink)
|
||||
- **PP**: Across nodes (slower interconnect OK)
|
||||
- **DP**: Any topology
|
||||
|
||||
## Model Size Guidelines
|
||||
|
||||
| Model Size | GPUs | TP | PP | DP | Micro-Batches |
|
||||
|------------|------|----|----|----|--------------|
|
||||
| 7B | 8 | 1 | 1 | 8 | 1 |
|
||||
| 13B | 8 | 2 | 1 | 4 | 1 |
|
||||
| 20B | 16 | 4 | 1 | 4 | 1 |
|
||||
| 40B | 32 | 4 | 2 | 4 | 4 |
|
||||
| 70B | 64 | 8 | 2 | 4 | 8 |
|
||||
| 175B | 128 | 8 | 4 | 4 | 16 |
|
||||
|
||||
**Assumptions**: BF16, 2K sequence length, A100 80GB
|
||||
|
||||
## Checkpointing
|
||||
|
||||
### Save Checkpoint
|
||||
|
||||
```python
|
||||
# Save full model state
|
||||
accelerator.save_state('checkpoint-1000')
|
||||
|
||||
# Megatron saves separate files per rank
|
||||
# checkpoint-1000/
|
||||
# pytorch_model_tp_0_pp_0.bin
|
||||
# pytorch_model_tp_0_pp_1.bin
|
||||
# pytorch_model_tp_1_pp_0.bin
|
||||
# pytorch_model_tp_1_pp_1.bin
|
||||
# optimizer_tp_0_pp_0.bin
|
||||
# ...
|
||||
```
|
||||
|
||||
### Load Checkpoint
|
||||
|
||||
```python
|
||||
# Resume training
|
||||
accelerator.load_state('checkpoint-1000')
|
||||
|
||||
# Automatically loads correct shard per rank
|
||||
```
|
||||
|
||||
### Convert to Standard PyTorch
|
||||
|
||||
```bash
|
||||
# Merge Megatron checkpoint to single file
|
||||
python merge_megatron_checkpoint.py \
|
||||
--checkpoint-dir checkpoint-1000 \
|
||||
--output pytorch_model.bin
|
||||
```
|
||||
|
||||
## Common Issues
|
||||
|
||||
### Issue: OOM with Pipeline Parallelism
|
||||
|
||||
**Solution**: Increase micro-batches
|
||||
```python
|
||||
megatron_plugin = MegatronLMPlugin(
|
||||
pp_degree=4,
|
||||
num_micro_batches=16, # Increase from 4
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Slow Training
|
||||
|
||||
**Check 1**: Pipeline bubbles (PP too high)
|
||||
```python
|
||||
# Reduce PP, increase TP
|
||||
tp_degree=4 # Increase
|
||||
pp_degree=2 # Decrease
|
||||
```
|
||||
|
||||
**Check 2**: Micro-batch size too small
|
||||
```python
|
||||
num_micro_batches=8 # Increase
|
||||
```
|
||||
|
||||
### Issue: NVLink Not Detected
|
||||
|
||||
```bash
|
||||
# Verify NVLink
|
||||
nvidia-smi nvlink -s
|
||||
|
||||
# If no NVLink, avoid TP > 1
|
||||
# Use PP or DP instead
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Megatron-LM: https://github.com/NVIDIA/Megatron-LM
|
||||
- Accelerate Megatron docs: https://huggingface.co/docs/accelerate/usage_guides/megatron_lm
|
||||
- Paper: "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism"
|
||||
- NVIDIA Apex: https://github.com/NVIDIA/apex
|
||||
525
skills/mlops/accelerate/references/performance.md
Normal file
525
skills/mlops/accelerate/references/performance.md
Normal file
@@ -0,0 +1,525 @@
|
||||
# Accelerate Performance Tuning
|
||||
|
||||
## Profiling
|
||||
|
||||
### Basic Profiling
|
||||
|
||||
```python
|
||||
from accelerate import Accelerator
|
||||
import time
|
||||
|
||||
accelerator = Accelerator()
|
||||
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
batch = next(iter(dataloader))
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Profile training loop
|
||||
start = time.time()
|
||||
total_batches = 100
|
||||
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= total_batches:
|
||||
break
|
||||
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
accelerator.wait_for_everyone() # Sync all processes
|
||||
elapsed = time.time() - start
|
||||
|
||||
# Metrics
|
||||
batches_per_sec = total_batches / elapsed
|
||||
samples_per_sec = (total_batches * batch_size * accelerator.num_processes) / elapsed
|
||||
|
||||
print(f"Throughput: {samples_per_sec:.2f} samples/sec")
|
||||
print(f"Batches/sec: {batches_per_sec:.2f}")
|
||||
```
|
||||
|
||||
### PyTorch Profiler Integration
|
||||
|
||||
```python
|
||||
from torch.profiler import profile, ProfilerActivity
|
||||
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=True
|
||||
) as prof:
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= 10: # Profile first 10 batches
|
||||
break
|
||||
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Print profiling results
|
||||
print(prof.key_averages().table(
|
||||
sort_by="cuda_time_total", row_limit=20
|
||||
))
|
||||
|
||||
# Export to Chrome tracing
|
||||
prof.export_chrome_trace("trace.json")
|
||||
# View at chrome://tracing
|
||||
```
|
||||
|
||||
## Memory Optimization
|
||||
|
||||
### 1. Gradient Accumulation
|
||||
|
||||
**Problem**: Large batch size causes OOM
|
||||
|
||||
**Solution**: Accumulate gradients across micro-batches
|
||||
|
||||
```python
|
||||
accelerator = Accelerator(gradient_accumulation_steps=8)
|
||||
|
||||
# Effective batch = batch_size × accumulation_steps × num_gpus
|
||||
# Example: 4 × 8 × 8 = 256
|
||||
|
||||
for batch in dataloader:
|
||||
with accelerator.accumulate(model): # Handles accumulation logic
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
**Memory savings**: 8× less activation memory (with 8 accumulation steps)
|
||||
|
||||
### 2. Gradient Checkpointing
|
||||
|
||||
**Enable in model**:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"gpt2",
|
||||
use_cache=False # Required for gradient checkpointing
|
||||
)
|
||||
|
||||
# Enable checkpointing
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Prepare with Accelerate
|
||||
model = accelerator.prepare(model)
|
||||
```
|
||||
|
||||
**Memory savings**: 30-50% with 10-15% slowdown
|
||||
|
||||
### 3. Mixed Precision
|
||||
|
||||
**BF16 (A100/H100)**:
|
||||
```python
|
||||
accelerator = Accelerator(mixed_precision='bf16')
|
||||
|
||||
# Automatic mixed precision
|
||||
for batch in dataloader:
|
||||
outputs = model(**batch) # Forward in BF16
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss) # Backward in FP32
|
||||
optimizer.step()
|
||||
```
|
||||
|
||||
**FP16 (V100, older GPUs)**:
|
||||
```python
|
||||
from accelerate.utils import GradScalerKwargs
|
||||
|
||||
scaler_kwargs = GradScalerKwargs(
|
||||
init_scale=2.**16,
|
||||
growth_interval=2000
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision='fp16',
|
||||
kwargs_handlers=[scaler_kwargs]
|
||||
)
|
||||
```
|
||||
|
||||
**Memory savings**: 50% compared to FP32
|
||||
|
||||
### 4. CPU Offloading (DeepSpeed)
|
||||
|
||||
```python
|
||||
from accelerate.utils import DeepSpeedPlugin
|
||||
|
||||
ds_plugin = DeepSpeedPlugin(
|
||||
zero_stage=3,
|
||||
offload_optimizer_device="cpu", # Offload optimizer to CPU
|
||||
offload_param_device="cpu", # Offload parameters to CPU
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
deepspeed_plugin=ds_plugin,
|
||||
mixed_precision='bf16'
|
||||
)
|
||||
```
|
||||
|
||||
**Memory savings**: 10-20× for optimizer state, 5-10× for parameters
|
||||
|
||||
**Trade-off**: 20-30% slower due to CPU-GPU transfers
|
||||
|
||||
### 5. Flash Attention
|
||||
|
||||
```python
|
||||
# Install flash-attn
|
||||
# pip install flash-attn
|
||||
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"gpt2",
|
||||
attn_implementation="flash_attention_2" # Enable Flash Attention 2
|
||||
)
|
||||
|
||||
model = accelerator.prepare(model)
|
||||
```
|
||||
|
||||
**Memory savings**: 50% for attention, 2× faster
|
||||
|
||||
**Requirements**: A100/H100, sequence length must be multiple of 128
|
||||
|
||||
## Communication Optimization
|
||||
|
||||
### 1. Gradient Bucketing (DDP)
|
||||
|
||||
```python
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(
|
||||
bucket_cap_mb=25, # Bucket size for gradient reduction
|
||||
gradient_as_bucket_view=True, # Reduce memory copies
|
||||
static_graph=False # Set True if model doesn't change
|
||||
)
|
||||
|
||||
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
|
||||
```
|
||||
|
||||
**Recommended bucket sizes**:
|
||||
- Small models (<1B): 25 MB
|
||||
- Medium models (1-10B): 50-100 MB
|
||||
- Large models (>10B): 100-200 MB
|
||||
|
||||
### 2. Find Unused Parameters
|
||||
|
||||
```python
|
||||
# Only enable if model has unused parameters (slower!)
|
||||
ddp_kwargs = DistributedDataParallelKwargs(
|
||||
find_unused_parameters=True
|
||||
)
|
||||
```
|
||||
|
||||
**Use case**: Models with conditional branches (e.g., mixture of experts)
|
||||
|
||||
**Cost**: 10-20% slower
|
||||
|
||||
### 3. NCCL Tuning
|
||||
|
||||
```bash
|
||||
# Set environment variables before launch
|
||||
export NCCL_DEBUG=INFO # Debug info
|
||||
export NCCL_IB_DISABLE=0 # Enable InfiniBand
|
||||
export NCCL_SOCKET_IFNAME=eth0 # Network interface
|
||||
export NCCL_P2P_LEVEL=NVL # Use NVLink
|
||||
|
||||
accelerate launch train.py
|
||||
```
|
||||
|
||||
**NCCL_P2P_LEVEL options**:
|
||||
- `NVL`: NVLink (fastest, within node)
|
||||
- `PIX`: PCIe (fast, within node)
|
||||
- `PHB`: PCIe host bridge (slow, cross-node)
|
||||
|
||||
## Data Loading Optimization
|
||||
|
||||
### 1. DataLoader Workers
|
||||
|
||||
```python
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
train_loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=32,
|
||||
num_workers=4, # Parallel data loading
|
||||
pin_memory=True, # Pin memory for faster GPU transfer
|
||||
prefetch_factor=2, # Prefetch batches per worker
|
||||
persistent_workers=True # Keep workers alive between epochs
|
||||
)
|
||||
|
||||
train_loader = accelerator.prepare(train_loader)
|
||||
```
|
||||
|
||||
**Recommendations**:
|
||||
- `num_workers`: 2-4 per GPU (8 GPUs → 16-32 workers)
|
||||
- `pin_memory`: Always True for GPU training
|
||||
- `prefetch_factor`: 2-4 (higher for slow data loading)
|
||||
|
||||
### 2. Data Preprocessing
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
# Bad: Preprocess during training (slow)
|
||||
dataset = load_dataset("openwebtext")
|
||||
|
||||
for batch in dataset:
|
||||
tokens = tokenizer(batch['text']) # Slow!
|
||||
...
|
||||
|
||||
# Good: Preprocess once, save
|
||||
dataset = load_dataset("openwebtext")
|
||||
tokenized = dataset.map(
|
||||
lambda x: tokenizer(x['text']),
|
||||
batched=True,
|
||||
num_proc=8, # Parallel preprocessing
|
||||
remove_columns=['text']
|
||||
)
|
||||
tokenized.save_to_disk("preprocessed_data")
|
||||
|
||||
# Load preprocessed
|
||||
dataset = load_from_disk("preprocessed_data")
|
||||
```
|
||||
|
||||
### 3. Faster Tokenization
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
# Enable Rust-based tokenizers (10× faster)
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"gpt2",
|
||||
use_fast=True # Use fast Rust tokenizer
|
||||
)
|
||||
```
|
||||
|
||||
## Compilation (PyTorch 2.0+)
|
||||
|
||||
### Compile Model
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Compile model for faster execution
|
||||
model = torch.compile(
|
||||
model,
|
||||
mode="reduce-overhead", # Options: default, reduce-overhead, max-autotune
|
||||
fullgraph=False, # Compile entire graph (stricter)
|
||||
dynamic=True # Support dynamic shapes
|
||||
)
|
||||
|
||||
model = accelerator.prepare(model)
|
||||
```
|
||||
|
||||
**Speedup**: 10-50% depending on model
|
||||
|
||||
**Compilation modes**:
|
||||
- `default`: Balanced (best for most cases)
|
||||
- `reduce-overhead`: Min overhead (best for small batches)
|
||||
- `max-autotune`: Max performance (slow compile, best for production)
|
||||
|
||||
### Compilation Best Practices
|
||||
|
||||
```python
|
||||
# Bad: Compile after prepare (won't work)
|
||||
model = accelerator.prepare(model)
|
||||
model = torch.compile(model) # Error!
|
||||
|
||||
# Good: Compile before prepare
|
||||
model = torch.compile(model)
|
||||
model = accelerator.prepare(model)
|
||||
|
||||
# Training loop
|
||||
for batch in dataloader:
|
||||
# First iteration: slow (compilation)
|
||||
# Subsequent iterations: fast (compiled)
|
||||
outputs = model(**batch)
|
||||
...
|
||||
```
|
||||
|
||||
## Benchmarking Different Strategies
|
||||
|
||||
### Script Template
|
||||
|
||||
```python
|
||||
import time
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
|
||||
def benchmark_strategy(strategy_name, accelerator_kwargs):
|
||||
"""Benchmark a specific training strategy."""
|
||||
accelerator = Accelerator(**accelerator_kwargs)
|
||||
|
||||
# Setup
|
||||
model = create_model()
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
|
||||
dataloader = create_dataloader()
|
||||
|
||||
model, optimizer, dataloader = accelerator.prepare(
|
||||
model, optimizer, dataloader
|
||||
)
|
||||
|
||||
# Warmup
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= 10:
|
||||
break
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Benchmark
|
||||
accelerator.wait_for_everyone()
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
|
||||
num_batches = 100
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i >= num_batches:
|
||||
break
|
||||
|
||||
outputs = model(**batch)
|
||||
loss = outputs.loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
torch.cuda.synchronize()
|
||||
elapsed = time.time() - start
|
||||
|
||||
# Metrics
|
||||
throughput = (num_batches * batch_size * accelerator.num_processes) / elapsed
|
||||
memory_used = torch.cuda.max_memory_allocated() / 1e9 # GB
|
||||
|
||||
if accelerator.is_main_process:
|
||||
print(f"\n{strategy_name}:")
|
||||
print(f" Throughput: {throughput:.2f} samples/sec")
|
||||
print(f" Memory: {memory_used:.2f} GB")
|
||||
print(f" Time: {elapsed:.2f} sec")
|
||||
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# Benchmark different strategies
|
||||
strategies = [
|
||||
("DDP + FP32", {}),
|
||||
("DDP + BF16", {"mixed_precision": "bf16"}),
|
||||
("DDP + BF16 + GradAccum", {"mixed_precision": "bf16", "gradient_accumulation_steps": 4}),
|
||||
("FSDP", {"fsdp_plugin": fsdp_plugin}),
|
||||
("DeepSpeed ZeRO-2", {"deepspeed_plugin": ds_plugin_stage2}),
|
||||
("DeepSpeed ZeRO-3", {"deepspeed_plugin": ds_plugin_stage3}),
|
||||
]
|
||||
|
||||
for name, kwargs in strategies:
|
||||
benchmark_strategy(name, kwargs)
|
||||
```
|
||||
|
||||
## Performance Checklist
|
||||
|
||||
**Before training**:
|
||||
- [ ] Use BF16/FP16 mixed precision
|
||||
- [ ] Enable gradient checkpointing (if OOM)
|
||||
- [ ] Set appropriate `num_workers` (2-4 per GPU)
|
||||
- [ ] Enable `pin_memory=True`
|
||||
- [ ] Preprocess data once, not during training
|
||||
- [ ] Compile model with `torch.compile` (PyTorch 2.0+)
|
||||
|
||||
**For large models**:
|
||||
- [ ] Use FSDP or DeepSpeed ZeRO-3
|
||||
- [ ] Enable CPU offloading (if still OOM)
|
||||
- [ ] Use Flash Attention
|
||||
- [ ] Increase gradient accumulation
|
||||
|
||||
**For multi-node**:
|
||||
- [ ] Check network topology (InfiniBand > Ethernet)
|
||||
- [ ] Tune NCCL settings
|
||||
- [ ] Use larger bucket sizes for DDP
|
||||
- [ ] Verify NVLink for tensor parallelism
|
||||
|
||||
**Profiling**:
|
||||
- [ ] Profile first 10-100 batches
|
||||
- [ ] Check GPU utilization (`nvidia-smi dmon`)
|
||||
- [ ] Check data loading time (should be <5% of iteration)
|
||||
- [ ] Identify communication bottlenecks
|
||||
|
||||
## Common Performance Issues
|
||||
|
||||
### Issue: Low GPU Utilization (<80%)
|
||||
|
||||
**Cause 1**: Data loading bottleneck
|
||||
```python
|
||||
# Solution: Increase workers and prefetch
|
||||
num_workers=8
|
||||
prefetch_factor=4
|
||||
```
|
||||
|
||||
**Cause 2**: Small batch size
|
||||
```python
|
||||
# Solution: Increase batch size or use gradient accumulation
|
||||
batch_size=32 # Increase
|
||||
gradient_accumulation_steps=4 # Or accumulate
|
||||
```
|
||||
|
||||
### Issue: High Memory Usage
|
||||
|
||||
**Solution 1**: Gradient checkpointing
|
||||
```python
|
||||
model.gradient_checkpointing_enable()
|
||||
```
|
||||
|
||||
**Solution 2**: Reduce batch size, increase accumulation
|
||||
```python
|
||||
batch_size=8 # Reduce from 32
|
||||
gradient_accumulation_steps=16 # Maintain effective batch
|
||||
```
|
||||
|
||||
**Solution 3**: Use FSDP or DeepSpeed ZeRO-3
|
||||
```python
|
||||
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
|
||||
```
|
||||
|
||||
### Issue: Slow Multi-GPU Training
|
||||
|
||||
**Cause**: Communication bottleneck
|
||||
|
||||
**Check 1**: Gradient bucket size
|
||||
```python
|
||||
ddp_kwargs = DistributedDataParallelKwargs(bucket_cap_mb=100)
|
||||
```
|
||||
|
||||
**Check 2**: NCCL settings
|
||||
```bash
|
||||
export NCCL_DEBUG=INFO
|
||||
# Check for "Using NVLS" (good) vs "Using PHB" (bad)
|
||||
```
|
||||
|
||||
**Check 3**: Network bandwidth
|
||||
```bash
|
||||
# Test inter-GPU bandwidth
|
||||
nvidia-smi nvlink -s
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- Accelerate Performance: https://huggingface.co/docs/accelerate/usage_guides/performance
|
||||
- PyTorch Profiler: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
|
||||
- NCCL Tuning: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html
|
||||
- Flash Attention: https://github.com/Dao-AILab/flash-attention
|
||||
567
skills/mlops/audiocraft/SKILL.md
Normal file
567
skills/mlops/audiocraft/SKILL.md
Normal file
@@ -0,0 +1,567 @@
|
||||
---
|
||||
name: audiocraft-audio-generation
|
||||
description: PyTorch library for audio generation including text-to-music (MusicGen) and text-to-sound (AudioGen). Use when you need to generate music from text descriptions, create sound effects, or perform melody-conditioned music generation.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [audiocraft, torch>=2.0.0, transformers>=4.30.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Multimodal, Audio Generation, Text-to-Music, Text-to-Audio, MusicGen]
|
||||
|
||||
---
|
||||
|
||||
# AudioCraft: Audio Generation
|
||||
|
||||
Comprehensive guide to using Meta's AudioCraft for text-to-music and text-to-audio generation with MusicGen, AudioGen, and EnCodec.
|
||||
|
||||
## When to use AudioCraft
|
||||
|
||||
**Use AudioCraft when:**
|
||||
- Need to generate music from text descriptions
|
||||
- Creating sound effects and environmental audio
|
||||
- Building music generation applications
|
||||
- Need melody-conditioned music generation
|
||||
- Want stereo audio output
|
||||
- Require controllable music generation with style transfer
|
||||
|
||||
**Key features:**
|
||||
- **MusicGen**: Text-to-music generation with melody conditioning
|
||||
- **AudioGen**: Text-to-sound effects generation
|
||||
- **EnCodec**: High-fidelity neural audio codec
|
||||
- **Multiple model sizes**: Small (300M) to Large (3.3B)
|
||||
- **Stereo support**: Full stereo audio generation
|
||||
- **Style conditioning**: MusicGen-Style for reference-based generation
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Stable Audio**: For longer commercial music generation
|
||||
- **Bark**: For text-to-speech with music/sound effects
|
||||
- **Riffusion**: For spectogram-based music generation
|
||||
- **OpenAI Jukebox**: For raw audio generation with lyrics
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# From PyPI
|
||||
pip install audiocraft
|
||||
|
||||
# From GitHub (latest)
|
||||
pip install git+https://github.com/facebookresearch/audiocraft.git
|
||||
|
||||
# Or use HuggingFace Transformers
|
||||
pip install transformers torch torchaudio
|
||||
```
|
||||
|
||||
### Basic text-to-music (AudioCraft)
|
||||
|
||||
```python
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Set generation parameters
|
||||
model.set_generation_params(
|
||||
duration=8, # seconds
|
||||
top_k=250,
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# Generate from text
|
||||
descriptions = ["happy upbeat electronic dance music with synths"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# Save audio
|
||||
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Using HuggingFace Transformers
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
import scipy
|
||||
|
||||
# Load model and processor
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
model.to("cuda")
|
||||
|
||||
# Generate music
|
||||
inputs = processor(
|
||||
text=["80s pop track with bassy drums and synth"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
).to("cuda")
|
||||
|
||||
audio_values = model.generate(
|
||||
**inputs,
|
||||
do_sample=True,
|
||||
guidance_scale=3,
|
||||
max_new_tokens=256
|
||||
)
|
||||
|
||||
# Save
|
||||
sampling_rate = model.config.audio_encoder.sampling_rate
|
||||
scipy.io.wavfile.write("output.wav", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())
|
||||
```
|
||||
|
||||
### Text-to-sound with AudioGen
|
||||
|
||||
```python
|
||||
from audiocraft.models import AudioGen
|
||||
|
||||
# Load AudioGen
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
|
||||
model.set_generation_params(duration=5)
|
||||
|
||||
# Generate sound effects
|
||||
descriptions = ["dog barking in a park with birds chirping"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
torchaudio.save("sound.wav", wav[0].cpu(), sample_rate=16000)
|
||||
```
|
||||
|
||||
## Core concepts
|
||||
|
||||
### Architecture overview
|
||||
|
||||
```
|
||||
AudioCraft Architecture:
|
||||
┌──────────────────────────────────────────────────────────────┐
|
||||
│ Text Encoder (T5) │
|
||||
│ │ │
|
||||
│ Text Embeddings │
|
||||
└────────────────────────┬─────────────────────────────────────┘
|
||||
│
|
||||
┌────────────────────────▼─────────────────────────────────────┐
|
||||
│ Transformer Decoder (LM) │
|
||||
│ Auto-regressively generates audio tokens │
|
||||
│ Using efficient token interleaving patterns │
|
||||
└────────────────────────┬─────────────────────────────────────┘
|
||||
│
|
||||
┌────────────────────────▼─────────────────────────────────────┐
|
||||
│ EnCodec Audio Decoder │
|
||||
│ Converts tokens back to audio waveform │
|
||||
└──────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Model variants
|
||||
|
||||
| Model | Size | Description | Use Case |
|
||||
|-------|------|-------------|----------|
|
||||
| `musicgen-small` | 300M | Text-to-music | Quick generation |
|
||||
| `musicgen-medium` | 1.5B | Text-to-music | Balanced |
|
||||
| `musicgen-large` | 3.3B | Text-to-music | Best quality |
|
||||
| `musicgen-melody` | 1.5B | Text + melody | Melody conditioning |
|
||||
| `musicgen-melody-large` | 3.3B | Text + melody | Best melody |
|
||||
| `musicgen-stereo-*` | Varies | Stereo output | Stereo generation |
|
||||
| `musicgen-style` | 1.5B | Style transfer | Reference-based |
|
||||
| `audiogen-medium` | 1.5B | Text-to-sound | Sound effects |
|
||||
|
||||
### Generation parameters
|
||||
|
||||
| Parameter | Default | Description |
|
||||
|-----------|---------|-------------|
|
||||
| `duration` | 8.0 | Length in seconds (1-120) |
|
||||
| `top_k` | 250 | Top-k sampling |
|
||||
| `top_p` | 0.0 | Nucleus sampling (0 = disabled) |
|
||||
| `temperature` | 1.0 | Sampling temperature |
|
||||
| `cfg_coef` | 3.0 | Classifier-free guidance |
|
||||
|
||||
## MusicGen usage
|
||||
|
||||
### Text-to-music generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
import torchaudio
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# Configure generation
|
||||
model.set_generation_params(
|
||||
duration=30, # Up to 30 seconds
|
||||
top_k=250, # Sampling diversity
|
||||
top_p=0.0, # 0 = use top_k only
|
||||
temperature=1.0, # Creativity (higher = more varied)
|
||||
cfg_coef=3.0 # Text adherence (higher = stricter)
|
||||
)
|
||||
|
||||
# Generate multiple samples
|
||||
descriptions = [
|
||||
"epic orchestral soundtrack with strings and brass",
|
||||
"chill lo-fi hip hop beat with jazzy piano",
|
||||
"energetic rock song with electric guitar"
|
||||
]
|
||||
|
||||
# Generate (returns [batch, channels, samples])
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# Save each
|
||||
for i, audio in enumerate(wav):
|
||||
torchaudio.save(f"music_{i}.wav", audio.cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Melody-conditioned generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
import torchaudio
|
||||
|
||||
# Load melody model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
model.set_generation_params(duration=30)
|
||||
|
||||
# Load melody audio
|
||||
melody, sr = torchaudio.load("melody.wav")
|
||||
|
||||
# Generate with melody conditioning
|
||||
descriptions = ["acoustic guitar folk song"]
|
||||
wav = model.generate_with_chroma(descriptions, melody, sr)
|
||||
|
||||
torchaudio.save("melody_conditioned.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Stereo generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load stereo model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
model.set_generation_params(duration=15)
|
||||
|
||||
descriptions = ["ambient electronic music with wide stereo panning"]
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
# wav shape: [batch, 2, samples] for stereo
|
||||
print(f"Stereo shape: {wav.shape}") # [1, 2, 480000]
|
||||
torchaudio.save("stereo.wav", wav[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
### Audio continuation
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-medium")
|
||||
|
||||
# Load audio to continue
|
||||
import torchaudio
|
||||
audio, sr = torchaudio.load("intro.wav")
|
||||
|
||||
# Process with text and audio
|
||||
inputs = processor(
|
||||
audio=audio.squeeze().numpy(),
|
||||
sampling_rate=sr,
|
||||
text=["continue with a epic chorus"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
# Generate continuation
|
||||
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=512)
|
||||
```
|
||||
|
||||
## MusicGen-Style usage
|
||||
|
||||
### Style-conditioned generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load style model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-style')
|
||||
|
||||
# Configure generation with style
|
||||
model.set_generation_params(
|
||||
duration=30,
|
||||
cfg_coef=3.0,
|
||||
cfg_coef_beta=5.0 # Style influence
|
||||
)
|
||||
|
||||
# Configure style conditioner
|
||||
model.set_style_conditioner_params(
|
||||
eval_q=3, # RVQ quantizers (1-6)
|
||||
excerpt_length=3.0 # Style excerpt length
|
||||
)
|
||||
|
||||
# Load style reference
|
||||
style_audio, sr = torchaudio.load("reference_style.wav")
|
||||
|
||||
# Generate with text + style
|
||||
descriptions = ["upbeat dance track"]
|
||||
wav = model.generate_with_style(descriptions, style_audio, sr)
|
||||
```
|
||||
|
||||
### Style-only generation (no text)
|
||||
|
||||
```python
|
||||
# Generate matching style without text prompt
|
||||
model.set_generation_params(
|
||||
duration=30,
|
||||
cfg_coef=3.0,
|
||||
cfg_coef_beta=None # Disable double CFG for style-only
|
||||
)
|
||||
|
||||
wav = model.generate_with_style([None], style_audio, sr)
|
||||
```
|
||||
|
||||
## AudioGen usage
|
||||
|
||||
### Sound effect generation
|
||||
|
||||
```python
|
||||
from audiocraft.models import AudioGen
|
||||
import torchaudio
|
||||
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Generate various sounds
|
||||
descriptions = [
|
||||
"thunderstorm with heavy rain and lightning",
|
||||
"busy city traffic with car horns",
|
||||
"ocean waves crashing on rocks",
|
||||
"crackling campfire in forest"
|
||||
]
|
||||
|
||||
wav = model.generate(descriptions)
|
||||
|
||||
for i, audio in enumerate(wav):
|
||||
torchaudio.save(f"sound_{i}.wav", audio.cpu(), sample_rate=16000)
|
||||
```
|
||||
|
||||
## EnCodec usage
|
||||
|
||||
### Audio compression
|
||||
|
||||
```python
|
||||
from audiocraft.models import CompressionModel
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
# Load EnCodec
|
||||
model = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
# Load audio
|
||||
wav, sr = torchaudio.load("audio.wav")
|
||||
|
||||
# Ensure correct sample rate
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
wav = resampler(wav)
|
||||
|
||||
# Encode to tokens
|
||||
with torch.no_grad():
|
||||
encoded = model.encode(wav.unsqueeze(0))
|
||||
codes = encoded[0] # Audio codes
|
||||
|
||||
# Decode back to audio
|
||||
with torch.no_grad():
|
||||
decoded = model.decode(codes)
|
||||
|
||||
torchaudio.save("reconstructed.wav", decoded[0].cpu(), sample_rate=32000)
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Music generation pipeline
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
class MusicGenerator:
|
||||
def __init__(self, model_name="facebook/musicgen-medium"):
|
||||
self.model = MusicGen.get_pretrained(model_name)
|
||||
self.sample_rate = 32000
|
||||
|
||||
def generate(self, prompt, duration=30, temperature=1.0, cfg=3.0):
|
||||
self.model.set_generation_params(
|
||||
duration=duration,
|
||||
top_k=250,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate([prompt])
|
||||
|
||||
return wav[0].cpu()
|
||||
|
||||
def generate_batch(self, prompts, duration=30):
|
||||
self.model.set_generation_params(duration=duration)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate(prompts)
|
||||
|
||||
return wav.cpu()
|
||||
|
||||
def save(self, audio, path):
|
||||
torchaudio.save(path, audio, sample_rate=self.sample_rate)
|
||||
|
||||
# Usage
|
||||
generator = MusicGenerator()
|
||||
audio = generator.generate(
|
||||
"epic cinematic orchestral music",
|
||||
duration=30,
|
||||
temperature=1.0
|
||||
)
|
||||
generator.save(audio, "epic_music.wav")
|
||||
```
|
||||
|
||||
### Workflow 2: Sound design batch processing
|
||||
|
||||
```python
|
||||
import json
|
||||
from pathlib import Path
|
||||
from audiocraft.models import AudioGen
|
||||
import torchaudio
|
||||
|
||||
def batch_generate_sounds(sound_specs, output_dir):
|
||||
"""
|
||||
Generate multiple sounds from specifications.
|
||||
|
||||
Args:
|
||||
sound_specs: list of {"name": str, "description": str, "duration": float}
|
||||
output_dir: output directory path
|
||||
"""
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
output_dir = Path(output_dir)
|
||||
output_dir.mkdir(exist_ok=True)
|
||||
|
||||
results = []
|
||||
|
||||
for spec in sound_specs:
|
||||
model.set_generation_params(duration=spec.get("duration", 5))
|
||||
|
||||
wav = model.generate([spec["description"]])
|
||||
|
||||
output_path = output_dir / f"{spec['name']}.wav"
|
||||
torchaudio.save(str(output_path), wav[0].cpu(), sample_rate=16000)
|
||||
|
||||
results.append({
|
||||
"name": spec["name"],
|
||||
"path": str(output_path),
|
||||
"description": spec["description"]
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
# Usage
|
||||
sounds = [
|
||||
{"name": "explosion", "description": "massive explosion with debris", "duration": 3},
|
||||
{"name": "footsteps", "description": "footsteps on wooden floor", "duration": 5},
|
||||
{"name": "door", "description": "wooden door creaking and closing", "duration": 2}
|
||||
]
|
||||
|
||||
results = batch_generate_sounds(sounds, "sound_effects/")
|
||||
```
|
||||
|
||||
### Workflow 3: Gradio demo
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
def generate_music(prompt, duration, temperature, cfg_coef):
|
||||
model.set_generation_params(
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg_coef
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
|
||||
# Save to temp file
|
||||
path = "temp_output.wav"
|
||||
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
|
||||
return path
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate_music,
|
||||
inputs=[
|
||||
gr.Textbox(label="Music Description", placeholder="upbeat electronic dance music"),
|
||||
gr.Slider(1, 30, value=8, label="Duration (seconds)"),
|
||||
gr.Slider(0.5, 2.0, value=1.0, label="Temperature"),
|
||||
gr.Slider(1.0, 10.0, value=3.0, label="CFG Coefficient")
|
||||
],
|
||||
outputs=gr.Audio(label="Generated Music"),
|
||||
title="MusicGen Demo"
|
||||
)
|
||||
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
## Performance optimization
|
||||
|
||||
### Memory optimization
|
||||
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Clear cache between generations
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Generate shorter durations
|
||||
model.set_generation_params(duration=10) # Instead of 30
|
||||
|
||||
# Use half precision
|
||||
model = model.half()
|
||||
```
|
||||
|
||||
### Batch processing efficiency
|
||||
|
||||
```python
|
||||
# Process multiple prompts at once (more efficient)
|
||||
descriptions = ["prompt1", "prompt2", "prompt3", "prompt4"]
|
||||
wav = model.generate(descriptions) # Single batch
|
||||
|
||||
# Instead of
|
||||
for desc in descriptions:
|
||||
wav = model.generate([desc]) # Multiple batches (slower)
|
||||
```
|
||||
|
||||
### GPU memory requirements
|
||||
|
||||
| Model | FP32 VRAM | FP16 VRAM |
|
||||
|-------|-----------|-----------|
|
||||
| musicgen-small | ~4GB | ~2GB |
|
||||
| musicgen-medium | ~8GB | ~4GB |
|
||||
| musicgen-large | ~16GB | ~8GB |
|
||||
|
||||
## Common issues
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| CUDA OOM | Use smaller model, reduce duration |
|
||||
| Poor quality | Increase cfg_coef, better prompts |
|
||||
| Generation too short | Check max duration setting |
|
||||
| Audio artifacts | Try different temperature |
|
||||
| Stereo not working | Use stereo model variant |
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Training, fine-tuning, deployment
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/facebookresearch/audiocraft
|
||||
- **Paper (MusicGen)**: https://arxiv.org/abs/2306.05284
|
||||
- **Paper (AudioGen)**: https://arxiv.org/abs/2209.15352
|
||||
- **HuggingFace**: https://huggingface.co/facebook/musicgen-small
|
||||
- **Demo**: https://huggingface.co/spaces/facebook/MusicGen
|
||||
666
skills/mlops/audiocraft/references/advanced-usage.md
Normal file
666
skills/mlops/audiocraft/references/advanced-usage.md
Normal file
@@ -0,0 +1,666 @@
|
||||
# AudioCraft Advanced Usage Guide
|
||||
|
||||
## Fine-tuning MusicGen
|
||||
|
||||
### Custom dataset preparation
|
||||
|
||||
```python
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
import torchaudio
|
||||
|
||||
def prepare_dataset(audio_dir, output_dir, metadata_file):
|
||||
"""
|
||||
Prepare dataset for MusicGen fine-tuning.
|
||||
|
||||
Directory structure:
|
||||
output_dir/
|
||||
├── audio/
|
||||
│ ├── 0001.wav
|
||||
│ ├── 0002.wav
|
||||
│ └── ...
|
||||
└── metadata.json
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
audio_output = output_dir / "audio"
|
||||
audio_output.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load metadata (format: {"path": "...", "description": "..."})
|
||||
with open(metadata_file) as f:
|
||||
metadata = json.load(f)
|
||||
|
||||
processed = []
|
||||
|
||||
for idx, item in enumerate(metadata):
|
||||
audio_path = Path(audio_dir) / item["path"]
|
||||
|
||||
# Load and resample to 32kHz
|
||||
wav, sr = torchaudio.load(str(audio_path))
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
wav = resampler(wav)
|
||||
|
||||
# Convert to mono if stereo
|
||||
if wav.shape[0] > 1:
|
||||
wav = wav.mean(dim=0, keepdim=True)
|
||||
|
||||
# Save processed audio
|
||||
output_path = audio_output / f"{idx:04d}.wav"
|
||||
torchaudio.save(str(output_path), wav, sample_rate=32000)
|
||||
|
||||
processed.append({
|
||||
"path": str(output_path.relative_to(output_dir)),
|
||||
"description": item["description"],
|
||||
"duration": wav.shape[1] / 32000
|
||||
})
|
||||
|
||||
# Save processed metadata
|
||||
with open(output_dir / "metadata.json", "w") as f:
|
||||
json.dump(processed, f, indent=2)
|
||||
|
||||
print(f"Processed {len(processed)} samples")
|
||||
return processed
|
||||
```
|
||||
|
||||
### Fine-tuning with dora
|
||||
|
||||
```bash
|
||||
# AudioCraft uses dora for experiment management
|
||||
# Install dora
|
||||
pip install dora-search
|
||||
|
||||
# Clone AudioCraft
|
||||
git clone https://github.com/facebookresearch/audiocraft.git
|
||||
cd audiocraft
|
||||
|
||||
# Create config for fine-tuning
|
||||
cat > config/solver/musicgen/finetune.yaml << 'EOF'
|
||||
defaults:
|
||||
- musicgen/musicgen_base
|
||||
- /model: lm/musicgen_lm
|
||||
- /conditioner: cond_base
|
||||
|
||||
solver: musicgen
|
||||
autocast: true
|
||||
autocast_dtype: float16
|
||||
|
||||
optim:
|
||||
epochs: 100
|
||||
batch_size: 4
|
||||
lr: 1e-4
|
||||
ema: 0.999
|
||||
optimizer: adamw
|
||||
|
||||
dataset:
|
||||
batch_size: 4
|
||||
num_workers: 4
|
||||
train:
|
||||
- dset: your_dataset
|
||||
root: /path/to/dataset
|
||||
valid:
|
||||
- dset: your_dataset
|
||||
root: /path/to/dataset
|
||||
|
||||
checkpoint:
|
||||
save_every: 10
|
||||
keep_every_states: null
|
||||
EOF
|
||||
|
||||
# Run fine-tuning
|
||||
dora run solver=musicgen/finetune
|
||||
```
|
||||
|
||||
### LoRA fine-tuning
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from audiocraft.models import MusicGen
|
||||
import torch
|
||||
|
||||
# Load base model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Get the language model component
|
||||
lm = model.lm
|
||||
|
||||
# Configure LoRA
|
||||
lora_config = LoraConfig(
|
||||
r=8,
|
||||
lora_alpha=16,
|
||||
target_modules=["q_proj", "v_proj", "k_proj", "out_proj"],
|
||||
lora_dropout=0.05,
|
||||
bias="none"
|
||||
)
|
||||
|
||||
# Apply LoRA
|
||||
lm = get_peft_model(lm, lora_config)
|
||||
lm.print_trainable_parameters()
|
||||
```
|
||||
|
||||
## Multi-GPU Training
|
||||
|
||||
### DataParallel
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Wrap LM with DataParallel
|
||||
if torch.cuda.device_count() > 1:
|
||||
model.lm = nn.DataParallel(model.lm)
|
||||
|
||||
model.to("cuda")
|
||||
```
|
||||
|
||||
### DistributedDataParallel
|
||||
|
||||
```python
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
def setup(rank, world_size):
|
||||
dist.init_process_group("nccl", rank=rank, world_size=world_size)
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
def train(rank, world_size):
|
||||
setup(rank, world_size)
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.lm = model.lm.to(rank)
|
||||
model.lm = DDP(model.lm, device_ids=[rank])
|
||||
|
||||
# Training loop
|
||||
# ...
|
||||
|
||||
dist.destroy_process_group()
|
||||
```
|
||||
|
||||
## Custom Conditioning
|
||||
|
||||
### Adding new conditioners
|
||||
|
||||
```python
|
||||
from audiocraft.modules.conditioners import BaseConditioner
|
||||
import torch
|
||||
|
||||
class CustomConditioner(BaseConditioner):
|
||||
"""Custom conditioner for additional control signals."""
|
||||
|
||||
def __init__(self, dim, output_dim):
|
||||
super().__init__(dim, output_dim)
|
||||
self.embed = torch.nn.Linear(dim, output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
return self.embed(x)
|
||||
|
||||
def tokenize(self, x):
|
||||
# Tokenize input for conditioning
|
||||
return x
|
||||
|
||||
# Use with MusicGen
|
||||
from audiocraft.models.builders import get_lm_model
|
||||
|
||||
# Modify model config to include custom conditioner
|
||||
# This requires editing the model configuration
|
||||
```
|
||||
|
||||
### Melody conditioning internals
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen
|
||||
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider
|
||||
import torch
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# Access chroma extractor
|
||||
chroma_extractor = model.lm.condition_provider.conditioners.get('chroma')
|
||||
|
||||
# Manual chroma extraction
|
||||
def extract_chroma(audio, sr):
|
||||
"""Extract chroma features from audio."""
|
||||
import librosa
|
||||
|
||||
# Compute chroma
|
||||
chroma = librosa.feature.chroma_cqt(y=audio.numpy(), sr=sr)
|
||||
|
||||
return torch.from_numpy(chroma).float()
|
||||
|
||||
# Use extracted chroma for conditioning
|
||||
chroma = extract_chroma(melody_audio, sample_rate)
|
||||
```
|
||||
|
||||
## EnCodec Deep Dive
|
||||
|
||||
### Custom compression settings
|
||||
|
||||
```python
|
||||
from audiocraft.models import CompressionModel
|
||||
import torch
|
||||
|
||||
# Load EnCodec
|
||||
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
# Access codec parameters
|
||||
print(f"Sample rate: {encodec.sample_rate}")
|
||||
print(f"Channels: {encodec.channels}")
|
||||
print(f"Cardinality: {encodec.cardinality}") # Codebook size
|
||||
print(f"Num codebooks: {encodec.num_codebooks}")
|
||||
print(f"Frame rate: {encodec.frame_rate}")
|
||||
|
||||
# Encode with specific bandwidth
|
||||
# Lower bandwidth = more compression, lower quality
|
||||
encodec.set_target_bandwidth(6.0) # 6 kbps
|
||||
|
||||
audio = torch.randn(1, 1, 32000) # 1 second
|
||||
encoded = encodec.encode(audio)
|
||||
decoded = encodec.decode(encoded[0])
|
||||
```
|
||||
|
||||
### Streaming encoding
|
||||
|
||||
```python
|
||||
import torch
|
||||
from audiocraft.models import CompressionModel
|
||||
|
||||
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
|
||||
|
||||
def encode_streaming(audio_stream, chunk_size=32000):
|
||||
"""Encode audio in streaming fashion."""
|
||||
all_codes = []
|
||||
|
||||
for chunk in audio_stream:
|
||||
# Ensure chunk is right shape
|
||||
if chunk.dim() == 1:
|
||||
chunk = chunk.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
codes = encodec.encode(chunk)[0]
|
||||
all_codes.append(codes)
|
||||
|
||||
return torch.cat(all_codes, dim=-1)
|
||||
|
||||
def decode_streaming(codes_stream, output_stream):
|
||||
"""Decode codes in streaming fashion."""
|
||||
for codes in codes_stream:
|
||||
with torch.no_grad():
|
||||
audio = encodec.decode(codes)
|
||||
output_stream.write(audio.cpu().numpy())
|
||||
```
|
||||
|
||||
## MultiBand Diffusion
|
||||
|
||||
### Using MBD for enhanced quality
|
||||
|
||||
```python
|
||||
from audiocraft.models import MusicGen, MultiBandDiffusion
|
||||
|
||||
# Load MusicGen
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# Load MultiBand Diffusion
|
||||
mbd = MultiBandDiffusion.get_mbd_musicgen()
|
||||
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Generate with standard decoder
|
||||
descriptions = ["epic orchestral music"]
|
||||
wav_standard = model.generate(descriptions)
|
||||
|
||||
# Generate tokens and use MBD decoder
|
||||
with torch.no_grad():
|
||||
# Get tokens
|
||||
gen_tokens = model.generate_tokens(descriptions)
|
||||
|
||||
# Decode with MBD
|
||||
wav_mbd = mbd.tokens_to_wav(gen_tokens)
|
||||
|
||||
# Compare quality
|
||||
print(f"Standard shape: {wav_standard.shape}")
|
||||
print(f"MBD shape: {wav_mbd.shape}")
|
||||
```
|
||||
|
||||
## API Server Deployment
|
||||
|
||||
### FastAPI server
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
import io
|
||||
import base64
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Load model at startup
|
||||
model = None
|
||||
|
||||
@app.on_event("startup")
|
||||
async def load_model():
|
||||
global model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
prompt: str
|
||||
duration: float = 10.0
|
||||
temperature: float = 1.0
|
||||
cfg_coef: float = 3.0
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
audio_base64: str
|
||||
sample_rate: int
|
||||
duration: float
|
||||
|
||||
@app.post("/generate", response_model=GenerateResponse)
|
||||
async def generate(request: GenerateRequest):
|
||||
if model is None:
|
||||
raise HTTPException(status_code=500, detail="Model not loaded")
|
||||
|
||||
try:
|
||||
model.set_generation_params(
|
||||
duration=min(request.duration, 30),
|
||||
temperature=request.temperature,
|
||||
cfg_coef=request.cfg_coef
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([request.prompt])
|
||||
|
||||
# Convert to bytes
|
||||
buffer = io.BytesIO()
|
||||
torchaudio.save(buffer, wav[0].cpu(), sample_rate=32000, format="wav")
|
||||
buffer.seek(0)
|
||||
|
||||
audio_base64 = base64.b64encode(buffer.read()).decode()
|
||||
|
||||
return GenerateResponse(
|
||||
audio_base64=audio_base64,
|
||||
sample_rate=32000,
|
||||
duration=wav.shape[-1] / 32000
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok", "model_loaded": model is not None}
|
||||
|
||||
# Run: uvicorn server:app --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
### Batch processing service
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import torch
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
class MusicGenService:
|
||||
def __init__(self, model_name='facebook/musicgen-small', max_workers=2):
|
||||
self.model = MusicGen.get_pretrained(model_name)
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def generate_async(self, prompt, duration=10):
|
||||
"""Async generation with thread pool."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def _generate():
|
||||
with torch.no_grad():
|
||||
self.model.set_generation_params(duration=duration)
|
||||
return self.model.generate([prompt])
|
||||
|
||||
# Run in thread pool
|
||||
wav = await loop.run_in_executor(self.executor, _generate)
|
||||
return wav[0].cpu()
|
||||
|
||||
async def generate_batch_async(self, prompts, duration=10):
|
||||
"""Process multiple prompts concurrently."""
|
||||
tasks = [self.generate_async(p, duration) for p in prompts]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
# Usage
|
||||
service = MusicGenService()
|
||||
|
||||
async def main():
|
||||
prompts = ["jazz piano", "rock guitar", "electronic beats"]
|
||||
results = await service.generate_batch_async(prompts)
|
||||
return results
|
||||
```
|
||||
|
||||
## Integration Patterns
|
||||
|
||||
### LangChain tool
|
||||
|
||||
```python
|
||||
from langchain.tools import BaseTool
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
import tempfile
|
||||
|
||||
class MusicGeneratorTool(BaseTool):
|
||||
name = "music_generator"
|
||||
description = "Generate music from a text description. Input should be a detailed description of the music style, mood, and instruments."
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
self.model.set_generation_params(duration=15)
|
||||
|
||||
def _run(self, description: str) -> str:
|
||||
with torch.no_grad():
|
||||
wav = self.model.generate([description])
|
||||
|
||||
# Save to temp file
|
||||
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
|
||||
torchaudio.save(f.name, wav[0].cpu(), sample_rate=32000)
|
||||
return f"Generated music saved to: {f.name}"
|
||||
|
||||
async def _arun(self, description: str) -> str:
|
||||
return self._run(description)
|
||||
```
|
||||
|
||||
### Gradio with advanced controls
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
import torch
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
models = {}
|
||||
|
||||
def load_model(model_size):
|
||||
if model_size not in models:
|
||||
model_name = f"facebook/musicgen-{model_size}"
|
||||
models[model_size] = MusicGen.get_pretrained(model_name)
|
||||
return models[model_size]
|
||||
|
||||
def generate(prompt, duration, temperature, cfg_coef, top_k, model_size):
|
||||
model = load_model(model_size)
|
||||
|
||||
model.set_generation_params(
|
||||
duration=duration,
|
||||
temperature=temperature,
|
||||
cfg_coef=cfg_coef,
|
||||
top_k=top_k
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
|
||||
# Save
|
||||
path = "output.wav"
|
||||
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
|
||||
return path
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate,
|
||||
inputs=[
|
||||
gr.Textbox(label="Prompt", lines=3),
|
||||
gr.Slider(1, 30, value=10, label="Duration (s)"),
|
||||
gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
|
||||
gr.Slider(0.5, 10.0, value=3.0, label="CFG Coefficient"),
|
||||
gr.Slider(50, 500, value=250, step=50, label="Top-K"),
|
||||
gr.Dropdown(["small", "medium", "large"], value="small", label="Model Size")
|
||||
],
|
||||
outputs=gr.Audio(label="Generated Music"),
|
||||
title="MusicGen Advanced",
|
||||
allow_flagging="never"
|
||||
)
|
||||
|
||||
demo.launch(share=True)
|
||||
```
|
||||
|
||||
## Audio Processing Pipeline
|
||||
|
||||
### Post-processing chain
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchaudio.transforms as T
|
||||
import numpy as np
|
||||
|
||||
class AudioPostProcessor:
|
||||
def __init__(self, sample_rate=32000):
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
def normalize(self, audio, target_db=-14.0):
|
||||
"""Normalize audio to target loudness."""
|
||||
rms = torch.sqrt(torch.mean(audio ** 2))
|
||||
target_rms = 10 ** (target_db / 20)
|
||||
gain = target_rms / (rms + 1e-8)
|
||||
return audio * gain
|
||||
|
||||
def fade_in_out(self, audio, fade_duration=0.1):
|
||||
"""Apply fade in/out."""
|
||||
fade_samples = int(fade_duration * self.sample_rate)
|
||||
|
||||
# Create fade curves
|
||||
fade_in = torch.linspace(0, 1, fade_samples)
|
||||
fade_out = torch.linspace(1, 0, fade_samples)
|
||||
|
||||
# Apply fades
|
||||
audio[..., :fade_samples] *= fade_in
|
||||
audio[..., -fade_samples:] *= fade_out
|
||||
|
||||
return audio
|
||||
|
||||
def apply_reverb(self, audio, decay=0.5):
|
||||
"""Apply simple reverb effect."""
|
||||
impulse = torch.zeros(int(self.sample_rate * 0.5))
|
||||
impulse[0] = 1.0
|
||||
impulse[int(self.sample_rate * 0.1)] = decay * 0.5
|
||||
impulse[int(self.sample_rate * 0.2)] = decay * 0.25
|
||||
|
||||
# Convolve
|
||||
audio = torch.nn.functional.conv1d(
|
||||
audio.unsqueeze(0),
|
||||
impulse.unsqueeze(0).unsqueeze(0),
|
||||
padding=len(impulse) // 2
|
||||
).squeeze(0)
|
||||
|
||||
return audio
|
||||
|
||||
def process(self, audio):
|
||||
"""Full processing pipeline."""
|
||||
audio = self.normalize(audio)
|
||||
audio = self.fade_in_out(audio)
|
||||
return audio
|
||||
|
||||
# Usage with MusicGen
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
wav = model.generate(["chill ambient music"])
|
||||
processor = AudioPostProcessor()
|
||||
wav_processed = processor.process(wav[0].cpu())
|
||||
|
||||
torchaudio.save("processed.wav", wav_processed, sample_rate=32000)
|
||||
```
|
||||
|
||||
## Evaluation
|
||||
|
||||
### Audio quality metrics
|
||||
|
||||
```python
|
||||
import torch
|
||||
from audiocraft.metrics import CLAPTextConsistencyMetric
|
||||
from audiocraft.data.audio import audio_read
|
||||
|
||||
def evaluate_generation(audio_path, text_prompt):
|
||||
"""Evaluate generated audio quality."""
|
||||
# Load audio
|
||||
wav, sr = audio_read(audio_path)
|
||||
|
||||
# CLAP consistency (text-audio alignment)
|
||||
clap_metric = CLAPTextConsistencyMetric()
|
||||
clap_score = clap_metric.compute(wav, [text_prompt])
|
||||
|
||||
return {
|
||||
"clap_score": clap_score,
|
||||
"duration": wav.shape[-1] / sr
|
||||
}
|
||||
|
||||
# Batch evaluation
|
||||
def evaluate_batch(generations):
|
||||
"""Evaluate multiple generations."""
|
||||
results = []
|
||||
for gen in generations:
|
||||
result = evaluate_generation(gen["path"], gen["prompt"])
|
||||
result["prompt"] = gen["prompt"]
|
||||
results.append(result)
|
||||
|
||||
# Aggregate
|
||||
avg_clap = sum(r["clap_score"] for r in results) / len(results)
|
||||
return {
|
||||
"individual": results,
|
||||
"average_clap": avg_clap
|
||||
}
|
||||
```
|
||||
|
||||
## Model Comparison
|
||||
|
||||
### MusicGen variants benchmark
|
||||
|
||||
| Model | CLAP Score | Generation Time (10s) | VRAM |
|
||||
|-------|------------|----------------------|------|
|
||||
| musicgen-small | 0.35 | ~5s | 2GB |
|
||||
| musicgen-medium | 0.42 | ~15s | 4GB |
|
||||
| musicgen-large | 0.48 | ~30s | 8GB |
|
||||
| musicgen-melody | 0.45 | ~15s | 4GB |
|
||||
| musicgen-stereo-medium | 0.41 | ~18s | 5GB |
|
||||
|
||||
### Prompt engineering tips
|
||||
|
||||
```python
|
||||
# Good prompts - specific and descriptive
|
||||
good_prompts = [
|
||||
"upbeat electronic dance music with synthesizer leads and punchy drums at 128 bpm",
|
||||
"melancholic piano ballad with strings, slow tempo, emotional and cinematic",
|
||||
"funky disco groove with slap bass, brass section, and rhythmic guitar"
|
||||
]
|
||||
|
||||
# Bad prompts - too vague
|
||||
bad_prompts = [
|
||||
"nice music",
|
||||
"song",
|
||||
"good beat"
|
||||
]
|
||||
|
||||
# Structure: [mood] [genre] with [instruments] at [tempo/style]
|
||||
```
|
||||
504
skills/mlops/audiocraft/references/troubleshooting.md
Normal file
504
skills/mlops/audiocraft/references/troubleshooting.md
Normal file
@@ -0,0 +1,504 @@
|
||||
# AudioCraft Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### Import errors
|
||||
|
||||
**Error**: `ModuleNotFoundError: No module named 'audiocraft'`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install from PyPI
|
||||
pip install audiocraft
|
||||
|
||||
# Or from GitHub
|
||||
pip install git+https://github.com/facebookresearch/audiocraft.git
|
||||
|
||||
# Verify installation
|
||||
python -c "from audiocraft.models import MusicGen; print('OK')"
|
||||
```
|
||||
|
||||
### FFmpeg not found
|
||||
|
||||
**Error**: `RuntimeError: ffmpeg not found`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Ubuntu/Debian
|
||||
sudo apt-get install ffmpeg
|
||||
|
||||
# macOS
|
||||
brew install ffmpeg
|
||||
|
||||
# Windows (using conda)
|
||||
conda install -c conda-forge ffmpeg
|
||||
|
||||
# Verify
|
||||
ffmpeg -version
|
||||
```
|
||||
|
||||
### PyTorch CUDA mismatch
|
||||
|
||||
**Error**: `RuntimeError: CUDA error: no kernel image is available`
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Check CUDA version
|
||||
nvcc --version
|
||||
python -c "import torch; print(torch.version.cuda)"
|
||||
|
||||
# Install matching PyTorch
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
# For CUDA 11.8
|
||||
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
### xformers issues
|
||||
|
||||
**Error**: `ImportError: xformers` related errors
|
||||
|
||||
**Solutions**:
|
||||
```bash
|
||||
# Install xformers for memory efficiency
|
||||
pip install xformers
|
||||
|
||||
# Or disable xformers
|
||||
export AUDIOCRAFT_USE_XFORMERS=0
|
||||
|
||||
# In Python
|
||||
import os
|
||||
os.environ["AUDIOCRAFT_USE_XFORMERS"] = "0"
|
||||
from audiocraft.models import MusicGen
|
||||
```
|
||||
|
||||
## Model Loading Issues
|
||||
|
||||
### Out of memory during load
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError` during model loading
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Force CPU loading first
|
||||
import torch
|
||||
device = "cpu"
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small', device=device)
|
||||
model = model.to("cuda")
|
||||
|
||||
# Use HuggingFace with device_map
|
||||
from transformers import MusicgenForConditionalGeneration
|
||||
model = MusicgenForConditionalGeneration.from_pretrained(
|
||||
"facebook/musicgen-small",
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
### Download failures
|
||||
|
||||
**Error**: Connection errors or incomplete downloads
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Set cache directory
|
||||
import os
|
||||
os.environ["AUDIOCRAFT_CACHE_DIR"] = "/path/to/cache"
|
||||
|
||||
# Or for HuggingFace
|
||||
os.environ["HF_HOME"] = "/path/to/hf_cache"
|
||||
|
||||
# Resume download
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download("facebook/musicgen-small", resume_download=True)
|
||||
|
||||
# Use local files
|
||||
model = MusicGen.get_pretrained('/local/path/to/model')
|
||||
```
|
||||
|
||||
### Wrong model type
|
||||
|
||||
**Error**: Loading wrong model for task
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# For text-to-music: use MusicGen
|
||||
from audiocraft.models import MusicGen
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-medium')
|
||||
|
||||
# For text-to-sound: use AudioGen
|
||||
from audiocraft.models import AudioGen
|
||||
model = AudioGen.get_pretrained('facebook/audiogen-medium')
|
||||
|
||||
# For melody conditioning: use melody variant
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# For stereo: use stereo variant
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
```
|
||||
|
||||
## Generation Issues
|
||||
|
||||
### Empty or silent output
|
||||
|
||||
**Problem**: Generated audio is silent or very quiet
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check output
|
||||
wav = model.generate(["upbeat music"])
|
||||
print(f"Shape: {wav.shape}")
|
||||
print(f"Max amplitude: {wav.abs().max().item()}")
|
||||
print(f"Mean amplitude: {wav.abs().mean().item()}")
|
||||
|
||||
# If too quiet, normalize
|
||||
def normalize_audio(audio, target_db=-14.0):
|
||||
rms = torch.sqrt(torch.mean(audio ** 2))
|
||||
target_rms = 10 ** (target_db / 20)
|
||||
gain = target_rms / (rms + 1e-8)
|
||||
return audio * gain
|
||||
|
||||
wav_normalized = normalize_audio(wav)
|
||||
```
|
||||
|
||||
### Poor quality output
|
||||
|
||||
**Problem**: Generated music sounds bad or noisy
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use larger model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-large')
|
||||
|
||||
# Adjust generation parameters
|
||||
model.set_generation_params(
|
||||
duration=15,
|
||||
top_k=250, # Increase for more diversity
|
||||
temperature=0.8, # Lower for more focused output
|
||||
cfg_coef=4.0 # Increase for better text adherence
|
||||
)
|
||||
|
||||
# Use better prompts
|
||||
# Bad: "music"
|
||||
# Good: "upbeat electronic dance music with synthesizers and punchy drums"
|
||||
|
||||
# Try MultiBand Diffusion
|
||||
from audiocraft.models import MultiBandDiffusion
|
||||
mbd = MultiBandDiffusion.get_mbd_musicgen()
|
||||
tokens = model.generate_tokens(["prompt"])
|
||||
wav = mbd.tokens_to_wav(tokens)
|
||||
```
|
||||
|
||||
### Generation too short
|
||||
|
||||
**Problem**: Audio shorter than expected
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check duration setting
|
||||
model.set_generation_params(duration=30) # Set before generate
|
||||
|
||||
# Verify in generation
|
||||
print(f"Duration setting: {model.generation_params}")
|
||||
|
||||
# Check output shape
|
||||
wav = model.generate(["prompt"])
|
||||
actual_duration = wav.shape[-1] / 32000
|
||||
print(f"Actual duration: {actual_duration}s")
|
||||
|
||||
# Note: max duration is typically 30s
|
||||
```
|
||||
|
||||
### Melody conditioning fails
|
||||
|
||||
**Error**: Issues with melody-conditioned generation
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torchaudio
|
||||
from audiocraft.models import MusicGen
|
||||
|
||||
# Load melody model (not base model)
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-melody')
|
||||
|
||||
# Load and prepare melody
|
||||
melody, sr = torchaudio.load("melody.wav")
|
||||
|
||||
# Resample to model sample rate if needed
|
||||
if sr != 32000:
|
||||
resampler = torchaudio.transforms.Resample(sr, 32000)
|
||||
melody = resampler(melody)
|
||||
|
||||
# Ensure correct shape [batch, channels, samples]
|
||||
if melody.dim() == 1:
|
||||
melody = melody.unsqueeze(0).unsqueeze(0)
|
||||
elif melody.dim() == 2:
|
||||
melody = melody.unsqueeze(0)
|
||||
|
||||
# Convert stereo to mono
|
||||
if melody.shape[1] > 1:
|
||||
melody = melody.mean(dim=1, keepdim=True)
|
||||
|
||||
# Generate with melody
|
||||
model.set_generation_params(duration=min(melody.shape[-1] / 32000, 30))
|
||||
wav = model.generate_with_chroma(["piano cover"], melody, 32000)
|
||||
```
|
||||
|
||||
## Memory Issues
|
||||
|
||||
### CUDA out of memory
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Clear cache before generation
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Reduce duration
|
||||
model.set_generation_params(duration=10) # Instead of 30
|
||||
|
||||
# Generate one at a time
|
||||
for prompt in prompts:
|
||||
wav = model.generate([prompt])
|
||||
save_audio(wav)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Use CPU for very large generations
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small', device="cpu")
|
||||
```
|
||||
|
||||
### Memory leak during batch processing
|
||||
|
||||
**Problem**: Memory grows over time
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import gc
|
||||
import torch
|
||||
|
||||
def generate_with_cleanup(model, prompts):
|
||||
results = []
|
||||
|
||||
for prompt in prompts:
|
||||
with torch.no_grad():
|
||||
wav = model.generate([prompt])
|
||||
results.append(wav.cpu())
|
||||
|
||||
# Cleanup
|
||||
del wav
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return results
|
||||
|
||||
# Use context manager
|
||||
with torch.inference_mode():
|
||||
wav = model.generate(["prompt"])
|
||||
```
|
||||
|
||||
## Audio Format Issues
|
||||
|
||||
### Wrong sample rate
|
||||
|
||||
**Problem**: Audio plays at wrong speed
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torchaudio
|
||||
|
||||
# MusicGen outputs at 32kHz
|
||||
sample_rate = 32000
|
||||
|
||||
# AudioGen outputs at 16kHz
|
||||
sample_rate = 16000
|
||||
|
||||
# Always use correct rate when saving
|
||||
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=sample_rate)
|
||||
|
||||
# Resample if needed
|
||||
resampler = torchaudio.transforms.Resample(32000, 44100)
|
||||
wav_resampled = resampler(wav)
|
||||
```
|
||||
|
||||
### Stereo/mono mismatch
|
||||
|
||||
**Problem**: Wrong number of channels
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Check model type
|
||||
print(f"Audio channels: {wav.shape}")
|
||||
# Mono: [batch, 1, samples]
|
||||
# Stereo: [batch, 2, samples]
|
||||
|
||||
# Convert mono to stereo
|
||||
if wav.shape[1] == 1:
|
||||
wav_stereo = wav.repeat(1, 2, 1)
|
||||
|
||||
# Convert stereo to mono
|
||||
if wav.shape[1] == 2:
|
||||
wav_mono = wav.mean(dim=1, keepdim=True)
|
||||
|
||||
# Use stereo model for stereo output
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
|
||||
```
|
||||
|
||||
### Clipping and distortion
|
||||
|
||||
**Problem**: Audio has clipping or distortion
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check for clipping
|
||||
max_val = wav.abs().max().item()
|
||||
print(f"Max amplitude: {max_val}")
|
||||
|
||||
# Normalize to prevent clipping
|
||||
if max_val > 1.0:
|
||||
wav = wav / max_val
|
||||
|
||||
# Apply soft clipping
|
||||
def soft_clip(x, threshold=0.9):
|
||||
return torch.tanh(x / threshold) * threshold
|
||||
|
||||
wav_clipped = soft_clip(wav)
|
||||
|
||||
# Lower temperature during generation
|
||||
model.set_generation_params(temperature=0.7) # More controlled
|
||||
```
|
||||
|
||||
## HuggingFace Transformers Issues
|
||||
|
||||
### Processor errors
|
||||
|
||||
**Error**: Issues with MusicgenProcessor
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
from transformers import AutoProcessor, MusicgenForConditionalGeneration
|
||||
|
||||
# Load matching processor and model
|
||||
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
||||
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
||||
|
||||
# Ensure inputs are on same device
|
||||
inputs = processor(
|
||||
text=["prompt"],
|
||||
padding=True,
|
||||
return_tensors="pt"
|
||||
).to("cuda")
|
||||
|
||||
# Check processor configuration
|
||||
print(processor.tokenizer)
|
||||
print(processor.feature_extractor)
|
||||
```
|
||||
|
||||
### Generation parameter errors
|
||||
|
||||
**Error**: Invalid generation parameters
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# HuggingFace uses different parameter names
|
||||
audio_values = model.generate(
|
||||
**inputs,
|
||||
do_sample=True, # Enable sampling
|
||||
guidance_scale=3.0, # CFG (not cfg_coef)
|
||||
max_new_tokens=256, # Token limit (not duration)
|
||||
temperature=1.0
|
||||
)
|
||||
|
||||
# Calculate tokens from duration
|
||||
# ~50 tokens per second
|
||||
duration_seconds = 10
|
||||
max_tokens = duration_seconds * 50
|
||||
audio_values = model.generate(**inputs, max_new_tokens=max_tokens)
|
||||
```
|
||||
|
||||
## Performance Issues
|
||||
|
||||
### Slow generation
|
||||
|
||||
**Problem**: Generation takes too long
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
# Use smaller model
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
|
||||
# Reduce duration
|
||||
model.set_generation_params(duration=10)
|
||||
|
||||
# Use GPU
|
||||
model.to("cuda")
|
||||
|
||||
# Enable flash attention if available
|
||||
# (requires compatible hardware)
|
||||
|
||||
# Batch multiple prompts
|
||||
prompts = ["prompt1", "prompt2", "prompt3"]
|
||||
wav = model.generate(prompts) # Single batch is faster than loop
|
||||
|
||||
# Use compile (PyTorch 2.0+)
|
||||
model.lm = torch.compile(model.lm)
|
||||
```
|
||||
|
||||
### CPU fallback
|
||||
|
||||
**Problem**: Generation running on CPU instead of GPU
|
||||
|
||||
**Solutions**:
|
||||
```python
|
||||
import torch
|
||||
|
||||
# Check CUDA availability
|
||||
print(f"CUDA available: {torch.cuda.is_available()}")
|
||||
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
|
||||
|
||||
# Explicitly move to GPU
|
||||
model = MusicGen.get_pretrained('facebook/musicgen-small')
|
||||
model.to("cuda")
|
||||
|
||||
# Verify model device
|
||||
print(f"Model device: {next(model.lm.parameters()).device}")
|
||||
```
|
||||
|
||||
## Common Error Messages
|
||||
|
||||
| Error | Cause | Solution |
|
||||
|-------|-------|----------|
|
||||
| `CUDA out of memory` | Model too large | Use smaller model, reduce duration |
|
||||
| `ffmpeg not found` | FFmpeg not installed | Install FFmpeg |
|
||||
| `No module named 'audiocraft'` | Not installed | `pip install audiocraft` |
|
||||
| `RuntimeError: Expected 3D tensor` | Wrong input shape | Check tensor dimensions |
|
||||
| `KeyError: 'melody'` | Wrong model for melody | Use musicgen-melody |
|
||||
| `Sample rate mismatch` | Wrong audio format | Resample to model rate |
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **GitHub Issues**: https://github.com/facebookresearch/audiocraft/issues
|
||||
2. **HuggingFace Forums**: https://discuss.huggingface.co
|
||||
3. **Paper**: https://arxiv.org/abs/2306.05284
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Include:
|
||||
- Python version
|
||||
- PyTorch version
|
||||
- CUDA version
|
||||
- AudioCraft version: `pip show audiocraft`
|
||||
- Full error traceback
|
||||
- Minimal reproducible code
|
||||
- Hardware (GPU model, VRAM)
|
||||
81
skills/mlops/code-review/SKILL.md
Normal file
81
skills/mlops/code-review/SKILL.md
Normal file
@@ -0,0 +1,81 @@
|
||||
---
|
||||
name: code-review
|
||||
description: Guidelines for performing thorough code reviews with security and quality focus
|
||||
---
|
||||
|
||||
# Code Review Skill
|
||||
|
||||
Use this skill when reviewing code changes, pull requests, or auditing existing code.
|
||||
|
||||
## Review Checklist
|
||||
|
||||
### 1. Security First
|
||||
- [ ] No hardcoded secrets, API keys, or credentials
|
||||
- [ ] Input validation on all user-provided data
|
||||
- [ ] SQL queries use parameterized statements (no string concatenation)
|
||||
- [ ] File operations validate paths (no path traversal)
|
||||
- [ ] Authentication/authorization checks present where needed
|
||||
|
||||
### 2. Error Handling
|
||||
- [ ] All external calls (API, DB, file) have try/catch
|
||||
- [ ] Errors are logged with context (but no sensitive data)
|
||||
- [ ] User-facing errors are helpful but don't leak internals
|
||||
- [ ] Resources are cleaned up in finally blocks or context managers
|
||||
|
||||
### 3. Code Quality
|
||||
- [ ] Functions do one thing and are reasonably sized (<50 lines ideal)
|
||||
- [ ] Variable names are descriptive (no single letters except loops)
|
||||
- [ ] No commented-out code left behind
|
||||
- [ ] Complex logic has explanatory comments
|
||||
- [ ] No duplicate code (DRY principle)
|
||||
|
||||
### 4. Testing Considerations
|
||||
- [ ] Edge cases handled (empty inputs, nulls, boundaries)
|
||||
- [ ] Happy path and error paths both work
|
||||
- [ ] New code has corresponding tests (if test suite exists)
|
||||
|
||||
## Review Response Format
|
||||
|
||||
When providing review feedback, structure it as:
|
||||
|
||||
```
|
||||
## Summary
|
||||
[1-2 sentence overall assessment]
|
||||
|
||||
## Critical Issues (Must Fix)
|
||||
- Issue 1: [description + suggested fix]
|
||||
- Issue 2: ...
|
||||
|
||||
## Suggestions (Nice to Have)
|
||||
- Suggestion 1: [description]
|
||||
|
||||
## Questions
|
||||
- [Any clarifying questions about intent]
|
||||
```
|
||||
|
||||
## Common Patterns to Flag
|
||||
|
||||
### Python
|
||||
```python
|
||||
# Bad: SQL injection risk
|
||||
cursor.execute(f"SELECT * FROM users WHERE id = {user_id}")
|
||||
|
||||
# Good: Parameterized query
|
||||
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
|
||||
```
|
||||
|
||||
### JavaScript
|
||||
```javascript
|
||||
// Bad: XSS risk
|
||||
element.innerHTML = userInput;
|
||||
|
||||
// Good: Safe text content
|
||||
element.textContent = userInput;
|
||||
```
|
||||
|
||||
## Tone Guidelines
|
||||
|
||||
- Be constructive, not critical
|
||||
- Explain *why* something is an issue, not just *what*
|
||||
- Offer solutions, not just problems
|
||||
- Acknowledge good patterns you see
|
||||
224
skills/mlops/faiss/SKILL.md
Normal file
224
skills/mlops/faiss/SKILL.md
Normal file
@@ -0,0 +1,224 @@
|
||||
---
|
||||
name: faiss
|
||||
description: Facebook's library for efficient similarity search and clustering of dense vectors. Supports billions of vectors, GPU acceleration, and various index types (Flat, IVF, HNSW). Use for fast k-NN search, large-scale vector retrieval, or when you need pure similarity search without metadata. Best for high-performance applications.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [faiss-cpu, faiss-gpu, numpy]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [RAG, FAISS, Similarity Search, Vector Search, Facebook AI, GPU Acceleration, Billion-Scale, K-NN, HNSW, High Performance, Large Scale]
|
||||
|
||||
---
|
||||
|
||||
# FAISS - Efficient Similarity Search
|
||||
|
||||
Facebook AI's library for billion-scale vector similarity search.
|
||||
|
||||
## When to use FAISS
|
||||
|
||||
**Use FAISS when:**
|
||||
- Need fast similarity search on large vector datasets (millions/billions)
|
||||
- GPU acceleration required
|
||||
- Pure vector similarity (no metadata filtering needed)
|
||||
- High throughput, low latency critical
|
||||
- Offline/batch processing of embeddings
|
||||
|
||||
**Metrics**:
|
||||
- **31,700+ GitHub stars**
|
||||
- Meta/Facebook AI Research
|
||||
- **Handles billions of vectors**
|
||||
- **C++** with Python bindings
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **Chroma/Pinecone**: Need metadata filtering
|
||||
- **Weaviate**: Need full database features
|
||||
- **Annoy**: Simpler, fewer features
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# CPU only
|
||||
pip install faiss-cpu
|
||||
|
||||
# GPU support
|
||||
pip install faiss-gpu
|
||||
```
|
||||
|
||||
### Basic usage
|
||||
|
||||
```python
|
||||
import faiss
|
||||
import numpy as np
|
||||
|
||||
# Create sample data (1000 vectors, 128 dimensions)
|
||||
d = 128
|
||||
nb = 1000
|
||||
vectors = np.random.random((nb, d)).astype('float32')
|
||||
|
||||
# Create index
|
||||
index = faiss.IndexFlatL2(d) # L2 distance
|
||||
index.add(vectors) # Add vectors
|
||||
|
||||
# Search
|
||||
k = 5 # Find 5 nearest neighbors
|
||||
query = np.random.random((1, d)).astype('float32')
|
||||
distances, indices = index.search(query, k)
|
||||
|
||||
print(f"Nearest neighbors: {indices}")
|
||||
print(f"Distances: {distances}")
|
||||
```
|
||||
|
||||
## Index types
|
||||
|
||||
### 1. Flat (exact search)
|
||||
|
||||
```python
|
||||
# L2 (Euclidean) distance
|
||||
index = faiss.IndexFlatL2(d)
|
||||
|
||||
# Inner product (cosine similarity if normalized)
|
||||
index = faiss.IndexFlatIP(d)
|
||||
|
||||
# Slowest, most accurate
|
||||
```
|
||||
|
||||
### 2. IVF (inverted file) - Fast approximate
|
||||
|
||||
```python
|
||||
# Create quantizer
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
|
||||
# IVF index with 100 clusters
|
||||
nlist = 100
|
||||
index = faiss.IndexIVFFlat(quantizer, d, nlist)
|
||||
|
||||
# Train on data
|
||||
index.train(vectors)
|
||||
|
||||
# Add vectors
|
||||
index.add(vectors)
|
||||
|
||||
# Search (nprobe = clusters to search)
|
||||
index.nprobe = 10
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
### 3. HNSW (Hierarchical NSW) - Best quality/speed
|
||||
|
||||
```python
|
||||
# HNSW index
|
||||
M = 32 # Number of connections per layer
|
||||
index = faiss.IndexHNSWFlat(d, M)
|
||||
|
||||
# No training needed
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
### 4. Product Quantization - Memory efficient
|
||||
|
||||
```python
|
||||
# PQ reduces memory by 16-32×
|
||||
m = 8 # Number of subquantizers
|
||||
nbits = 8
|
||||
index = faiss.IndexPQ(d, m, nbits)
|
||||
|
||||
# Train and add
|
||||
index.train(vectors)
|
||||
index.add(vectors)
|
||||
```
|
||||
|
||||
## Save and load
|
||||
|
||||
```python
|
||||
# Save index
|
||||
faiss.write_index(index, "large.index")
|
||||
|
||||
# Load index
|
||||
index = faiss.read_index("large.index")
|
||||
|
||||
# Continue using
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
## GPU acceleration
|
||||
|
||||
```python
|
||||
# Single GPU
|
||||
res = faiss.StandardGpuResources()
|
||||
index_cpu = faiss.IndexFlatL2(d)
|
||||
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu) # GPU 0
|
||||
|
||||
# Multi-GPU
|
||||
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
|
||||
|
||||
# 10-100× faster than CPU
|
||||
```
|
||||
|
||||
## LangChain integration
|
||||
|
||||
```python
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
|
||||
# Create FAISS vector store
|
||||
vectorstore = FAISS.from_documents(docs, OpenAIEmbeddings())
|
||||
|
||||
# Save
|
||||
vectorstore.save_local("faiss_index")
|
||||
|
||||
# Load
|
||||
vectorstore = FAISS.load_local(
|
||||
"faiss_index",
|
||||
OpenAIEmbeddings(),
|
||||
allow_dangerous_deserialization=True
|
||||
)
|
||||
|
||||
# Search
|
||||
results = vectorstore.similarity_search("query", k=5)
|
||||
```
|
||||
|
||||
## LlamaIndex integration
|
||||
|
||||
```python
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
import faiss
|
||||
|
||||
# Create FAISS index
|
||||
d = 1536
|
||||
faiss_index = faiss.IndexFlatL2(d)
|
||||
|
||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Choose right index type** - Flat for <10K, IVF for 10K-1M, HNSW for quality
|
||||
2. **Normalize for cosine** - Use IndexFlatIP with normalized vectors
|
||||
3. **Use GPU for large datasets** - 10-100× faster
|
||||
4. **Save trained indices** - Training is expensive
|
||||
5. **Tune nprobe/ef_search** - Balance speed/accuracy
|
||||
6. **Monitor memory** - PQ for large datasets
|
||||
7. **Batch queries** - Better GPU utilization
|
||||
|
||||
## Performance
|
||||
|
||||
| Index Type | Build Time | Search Time | Memory | Accuracy |
|
||||
|------------|------------|-------------|--------|----------|
|
||||
| Flat | Fast | Slow | High | 100% |
|
||||
| IVF | Medium | Fast | Medium | 95-99% |
|
||||
| HNSW | Slow | Fastest | High | 99% |
|
||||
| PQ | Medium | Fast | Low | 90-95% |
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/facebookresearch/faiss ⭐ 31,700+
|
||||
- **Wiki**: https://github.com/facebookresearch/faiss/wiki
|
||||
- **License**: MIT
|
||||
|
||||
|
||||
280
skills/mlops/faiss/references/index_types.md
Normal file
280
skills/mlops/faiss/references/index_types.md
Normal file
@@ -0,0 +1,280 @@
|
||||
# FAISS Index Types Guide
|
||||
|
||||
Complete guide to choosing and using FAISS index types.
|
||||
|
||||
## Index selection guide
|
||||
|
||||
| Dataset Size | Index Type | Training | Accuracy | Speed |
|
||||
|--------------|------------|----------|----------|-------|
|
||||
| < 10K | Flat | No | 100% | Slow |
|
||||
| 10K-1M | IVF | Yes | 95-99% | Fast |
|
||||
| 1M-10M | HNSW | No | 99% | Fastest |
|
||||
| > 10M | IVF+PQ | Yes | 90-95% | Fast, low memory |
|
||||
|
||||
## Flat indices (exact search)
|
||||
|
||||
### IndexFlatL2 - L2 (Euclidean) distance
|
||||
|
||||
```python
|
||||
import faiss
|
||||
import numpy as np
|
||||
|
||||
d = 128 # Dimension
|
||||
index = faiss.IndexFlatL2(d)
|
||||
|
||||
# Add vectors
|
||||
vectors = np.random.random((1000, d)).astype('float32')
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
k = 5
|
||||
query = np.random.random((1, d)).astype('float32')
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Dataset < 10,000 vectors
|
||||
- Need 100% accuracy
|
||||
- Serving as baseline
|
||||
|
||||
### IndexFlatIP - Inner product (cosine similarity)
|
||||
|
||||
```python
|
||||
# For cosine similarity, normalize vectors first
|
||||
import faiss
|
||||
|
||||
d = 128
|
||||
index = faiss.IndexFlatIP(d)
|
||||
|
||||
# Normalize vectors (required for cosine similarity)
|
||||
faiss.normalize_L2(vectors)
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
faiss.normalize_L2(query)
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Need cosine similarity
|
||||
- Recommendation systems
|
||||
- Text embeddings
|
||||
|
||||
## IVF indices (inverted file)
|
||||
|
||||
### IndexIVFFlat - Cluster-based search
|
||||
|
||||
```python
|
||||
# Create quantizer
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
|
||||
# Create IVF index with 100 clusters
|
||||
nlist = 100 # Number of clusters
|
||||
index = faiss.IndexIVFFlat(quantizer, d, nlist)
|
||||
|
||||
# Train on data (required!)
|
||||
index.train(vectors)
|
||||
|
||||
# Add vectors
|
||||
index.add(vectors)
|
||||
|
||||
# Search (nprobe = clusters to search)
|
||||
index.nprobe = 10 # Search 10 closest clusters
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `nlist`: Number of clusters (√N to 4√N recommended)
|
||||
- `nprobe`: Clusters to search (1-nlist, higher = more accurate)
|
||||
|
||||
**Use when:**
|
||||
- Dataset 10K-1M vectors
|
||||
- Need fast approximate search
|
||||
- Can afford training time
|
||||
|
||||
### Tuning nprobe
|
||||
|
||||
```python
|
||||
# Test different nprobe values
|
||||
for nprobe in [1, 5, 10, 20, 50]:
|
||||
index.nprobe = nprobe
|
||||
distances, indices = index.search(query, k)
|
||||
# Measure recall/speed trade-off
|
||||
```
|
||||
|
||||
**Guidelines:**
|
||||
- `nprobe=1`: Fastest, ~50% recall
|
||||
- `nprobe=10`: Good balance, ~95% recall
|
||||
- `nprobe=nlist`: Exact search (same as Flat)
|
||||
|
||||
## HNSW indices (graph-based)
|
||||
|
||||
### IndexHNSWFlat - Hierarchical NSW
|
||||
|
||||
```python
|
||||
# HNSW index
|
||||
M = 32 # Number of connections per layer (16-64)
|
||||
index = faiss.IndexHNSWFlat(d, M)
|
||||
|
||||
# Optional: Set ef_construction (build time parameter)
|
||||
index.hnsw.efConstruction = 40 # Higher = better quality, slower build
|
||||
|
||||
# Add vectors (no training needed!)
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
index.hnsw.efSearch = 16 # Search time parameter
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `M`: Connections per layer (16-64, default 32)
|
||||
- `efConstruction`: Build quality (40-200, higher = better)
|
||||
- `efSearch`: Search quality (16-512, higher = more accurate)
|
||||
|
||||
**Use when:**
|
||||
- Need best quality approximate search
|
||||
- Can afford higher memory (more connections)
|
||||
- Dataset 1M-10M vectors
|
||||
|
||||
## PQ indices (product quantization)
|
||||
|
||||
### IndexPQ - Memory-efficient
|
||||
|
||||
```python
|
||||
# PQ reduces memory by 16-32×
|
||||
m = 8 # Number of subquantizers (divides d)
|
||||
nbits = 8 # Bits per subquantizer
|
||||
|
||||
index = faiss.IndexPQ(d, m, nbits)
|
||||
|
||||
# Train (required!)
|
||||
index.train(vectors)
|
||||
|
||||
# Add vectors
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
- `m`: Subquantizers (d must be divisible by m)
|
||||
- `nbits`: Bits per code (8 or 16)
|
||||
|
||||
**Memory savings:**
|
||||
- Original: d × 4 bytes (float32)
|
||||
- PQ: m bytes
|
||||
- Compression ratio: 4d/m
|
||||
|
||||
**Use when:**
|
||||
- Limited memory
|
||||
- Large datasets (> 10M vectors)
|
||||
- Can accept ~90-95% accuracy
|
||||
|
||||
### IndexIVFPQ - IVF + PQ combined
|
||||
|
||||
```python
|
||||
# Best for very large datasets
|
||||
nlist = 4096
|
||||
m = 8
|
||||
nbits = 8
|
||||
|
||||
quantizer = faiss.IndexFlatL2(d)
|
||||
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits)
|
||||
|
||||
# Train
|
||||
index.train(vectors)
|
||||
index.add(vectors)
|
||||
|
||||
# Search
|
||||
index.nprobe = 32
|
||||
distances, indices = index.search(query, k)
|
||||
```
|
||||
|
||||
**Use when:**
|
||||
- Dataset > 10M vectors
|
||||
- Need fast search + low memory
|
||||
- Can accept 90-95% accuracy
|
||||
|
||||
## GPU indices
|
||||
|
||||
### Single GPU
|
||||
|
||||
```python
|
||||
import faiss
|
||||
|
||||
# Create CPU index
|
||||
index_cpu = faiss.IndexFlatL2(d)
|
||||
|
||||
# Move to GPU
|
||||
res = faiss.StandardGpuResources() # GPU resources
|
||||
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu) # GPU 0
|
||||
|
||||
# Use normally
|
||||
index_gpu.add(vectors)
|
||||
distances, indices = index_gpu.search(query, k)
|
||||
```
|
||||
|
||||
### Multi-GPU
|
||||
|
||||
```python
|
||||
# Use all available GPUs
|
||||
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
|
||||
|
||||
# Or specific GPUs
|
||||
gpus = [0, 1, 2, 3] # Use GPUs 0-3
|
||||
index_gpu = faiss.index_cpu_to_gpus_list(index_cpu, gpus)
|
||||
```
|
||||
|
||||
**Speedup:**
|
||||
- Single GPU: 10-50× faster than CPU
|
||||
- Multi-GPU: Near-linear scaling
|
||||
|
||||
## Index factory
|
||||
|
||||
```python
|
||||
# Easy index creation with string descriptors
|
||||
index = faiss.index_factory(d, "IVF100,Flat")
|
||||
index = faiss.index_factory(d, "HNSW32")
|
||||
index = faiss.index_factory(d, "IVF4096,PQ8")
|
||||
|
||||
# Train and use
|
||||
index.train(vectors)
|
||||
index.add(vectors)
|
||||
```
|
||||
|
||||
**Common descriptors:**
|
||||
- `"Flat"`: Exact search
|
||||
- `"IVF100,Flat"`: IVF with 100 clusters
|
||||
- `"HNSW32"`: HNSW with M=32
|
||||
- `"IVF4096,PQ8"`: IVF + PQ compression
|
||||
|
||||
## Performance comparison
|
||||
|
||||
### Search speed (1M vectors, k=10)
|
||||
|
||||
| Index | Build Time | Search Time | Memory | Recall |
|
||||
|-------|------------|-------------|--------|--------|
|
||||
| Flat | 0s | 50ms | 512 MB | 100% |
|
||||
| IVF100 | 5s | 2ms | 512 MB | 95% |
|
||||
| HNSW32 | 60s | 1ms | 1GB | 99% |
|
||||
| IVF4096+PQ8 | 30s | 3ms | 32 MB | 90% |
|
||||
|
||||
*CPU (16 cores), 128-dim vectors*
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Start with Flat** - Baseline for comparison
|
||||
2. **Use IVF for medium datasets** - Good balance
|
||||
3. **Use HNSW for best quality** - If memory allows
|
||||
4. **Add PQ for memory savings** - Large datasets
|
||||
5. **GPU for > 100K vectors** - 10-50× speedup
|
||||
6. **Tune nprobe/efSearch** - Trade-off speed/accuracy
|
||||
7. **Train on representative data** - Better clustering
|
||||
8. **Save trained indices** - Avoid retraining
|
||||
|
||||
## Resources
|
||||
|
||||
- **Wiki**: https://github.com/facebookresearch/faiss/wiki
|
||||
- **Paper**: https://arxiv.org/abs/1702.08734
|
||||
370
skills/mlops/flash-attention/SKILL.md
Normal file
370
skills/mlops/flash-attention/SKILL.md
Normal file
@@ -0,0 +1,370 @@
|
||||
---
|
||||
name: optimizing-attention-flash
|
||||
description: Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training/running transformers with long sequences (>512 tokens), encountering GPU memory issues with attention, or need faster inference. Supports PyTorch native SDPA, flash-attn library, H100 FP8, and sliding window attention.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [flash-attn, torch, transformers]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Optimization, Flash Attention, Attention Optimization, Memory Efficiency, Speed Optimization, Long Context, PyTorch, SDPA, H100, FP8, Transformers]
|
||||
|
||||
---
|
||||
|
||||
# Flash Attention - Fast Memory-Efficient Attention
|
||||
|
||||
## Quick start
|
||||
|
||||
Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.
|
||||
|
||||
**PyTorch native (easiest, PyTorch 2.2+)**:
|
||||
```python
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16) # [batch, heads, seq, dim]
|
||||
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
|
||||
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
|
||||
|
||||
# Automatically uses Flash Attention if available
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
```
|
||||
|
||||
**flash-attn library (more features)**:
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# q, k, v: [batch, seqlen, nheads, headdim]
|
||||
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
|
||||
```
|
||||
|
||||
## Common workflows
|
||||
|
||||
### Workflow 1: Enable in existing PyTorch model
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
Flash Attention Integration:
|
||||
- [ ] Step 1: Check PyTorch version (≥2.2)
|
||||
- [ ] Step 2: Enable Flash Attention backend
|
||||
- [ ] Step 3: Verify speedup with profiling
|
||||
- [ ] Step 4: Test accuracy matches baseline
|
||||
```
|
||||
|
||||
**Step 1: Check PyTorch version**
|
||||
|
||||
```bash
|
||||
python -c "import torch; print(torch.__version__)"
|
||||
# Should be ≥2.2.0
|
||||
```
|
||||
|
||||
If <2.2, upgrade:
|
||||
```bash
|
||||
pip install --upgrade torch
|
||||
```
|
||||
|
||||
**Step 2: Enable Flash Attention backend**
|
||||
|
||||
Replace standard attention:
|
||||
```python
|
||||
# Before (standard attention)
|
||||
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
|
||||
out = attn_weights @ v
|
||||
|
||||
# After (Flash Attention)
|
||||
import torch.nn.functional as F
|
||||
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
|
||||
```
|
||||
|
||||
Force Flash Attention backend:
|
||||
```python
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=True,
|
||||
enable_math=False,
|
||||
enable_mem_efficient=False
|
||||
):
|
||||
out = F.scaled_dot_product_attention(q, k, v)
|
||||
```
|
||||
|
||||
**Step 3: Verify speedup with profiling**
|
||||
|
||||
```python
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
def test_attention(use_flash):
|
||||
q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
|
||||
|
||||
if use_flash:
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=True):
|
||||
return F.scaled_dot_product_attention(q, k, v)
|
||||
else:
|
||||
attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
|
||||
return attn @ v
|
||||
|
||||
# Benchmark
|
||||
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
|
||||
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())
|
||||
|
||||
print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
|
||||
print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
|
||||
```
|
||||
|
||||
Expected: 2-4x speedup for sequences >512 tokens.
|
||||
|
||||
**Step 4: Test accuracy matches baseline**
|
||||
|
||||
```python
|
||||
# Compare outputs
|
||||
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
|
||||
|
||||
# Flash Attention
|
||||
out_flash = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
# Standard attention
|
||||
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
|
||||
out_standard = attn_weights @ v
|
||||
|
||||
# Check difference
|
||||
diff = (out_flash - out_standard).abs().max()
|
||||
print(f"Max difference: {diff:.6f}")
|
||||
# Should be <1e-3 for float16
|
||||
```
|
||||
|
||||
### Workflow 2: Use flash-attn library for advanced features
|
||||
|
||||
For multi-query attention, sliding window, or H100 FP8.
|
||||
|
||||
Copy this checklist:
|
||||
|
||||
```
|
||||
flash-attn Library Setup:
|
||||
- [ ] Step 1: Install flash-attn library
|
||||
- [ ] Step 2: Modify attention code
|
||||
- [ ] Step 3: Enable advanced features
|
||||
- [ ] Step 4: Benchmark performance
|
||||
```
|
||||
|
||||
**Step 1: Install flash-attn library**
|
||||
|
||||
```bash
|
||||
# NVIDIA GPUs (CUDA 12.0+)
|
||||
pip install flash-attn --no-build-isolation
|
||||
|
||||
# Verify installation
|
||||
python -c "from flash_attn import flash_attn_func; print('Success')"
|
||||
```
|
||||
|
||||
**Step 2: Modify attention code**
|
||||
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# Input: [batch_size, seq_len, num_heads, head_dim]
|
||||
# Transpose from [batch, heads, seq, dim] if needed
|
||||
q = q.transpose(1, 2) # [batch, seq, heads, dim]
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
out = flash_attn_func(
|
||||
q, k, v,
|
||||
dropout_p=0.1,
|
||||
causal=True, # For autoregressive models
|
||||
window_size=(-1, -1), # No sliding window
|
||||
softmax_scale=None # Auto-scale
|
||||
)
|
||||
|
||||
out = out.transpose(1, 2) # Back to [batch, heads, seq, dim]
|
||||
```
|
||||
|
||||
**Step 3: Enable advanced features**
|
||||
|
||||
Multi-query attention (shared K/V across heads):
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# q: [batch, seq, num_q_heads, dim]
|
||||
# k, v: [batch, seq, num_kv_heads, dim] # Fewer KV heads
|
||||
out = flash_attn_func(q, k, v) # Automatically handles MQA
|
||||
```
|
||||
|
||||
Sliding window attention (local attention):
|
||||
```python
|
||||
# Only attend to window of 256 tokens before/after
|
||||
out = flash_attn_func(
|
||||
q, k, v,
|
||||
window_size=(256, 256), # (left, right) window
|
||||
causal=True
|
||||
)
|
||||
```
|
||||
|
||||
**Step 4: Benchmark performance**
|
||||
|
||||
```python
|
||||
import torch
|
||||
from flash_attn import flash_attn_func
|
||||
import time
|
||||
|
||||
q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
|
||||
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
_ = flash_attn_func(q, k, v)
|
||||
|
||||
# Benchmark
|
||||
torch.cuda.synchronize()
|
||||
start = time.time()
|
||||
for _ in range(100):
|
||||
out = flash_attn_func(q, k, v)
|
||||
torch.cuda.synchronize()
|
||||
end = time.time()
|
||||
|
||||
print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
|
||||
print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
|
||||
```
|
||||
|
||||
### Workflow 3: H100 FP8 optimization (FlashAttention-3)
|
||||
|
||||
For maximum performance on H100 GPUs.
|
||||
|
||||
```
|
||||
FP8 Setup:
|
||||
- [ ] Step 1: Verify H100 GPU available
|
||||
- [ ] Step 2: Install flash-attn with FP8 support
|
||||
- [ ] Step 3: Convert inputs to FP8
|
||||
- [ ] Step 4: Run with FP8 attention
|
||||
```
|
||||
|
||||
**Step 1: Verify H100 GPU**
|
||||
|
||||
```bash
|
||||
nvidia-smi --query-gpu=name --format=csv
|
||||
# Should show "H100" or "H800"
|
||||
```
|
||||
|
||||
**Step 2: Install flash-attn with FP8 support**
|
||||
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
# FP8 support included for H100
|
||||
```
|
||||
|
||||
**Step 3: Convert inputs to FP8**
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
|
||||
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
|
||||
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
|
||||
|
||||
# Convert to float8_e4m3 (FP8)
|
||||
q_fp8 = q.to(torch.float8_e4m3fn)
|
||||
k_fp8 = k.to(torch.float8_e4m3fn)
|
||||
v_fp8 = v.to(torch.float8_e4m3fn)
|
||||
```
|
||||
|
||||
**Step 4: Run with FP8 attention**
|
||||
|
||||
```python
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
# FlashAttention-3 automatically uses FP8 kernels on H100
|
||||
out = flash_attn_func(q_fp8, k_fp8, v_fp8)
|
||||
# Result: ~1.2 PFLOPS, 1.5-2x faster than FP16
|
||||
```
|
||||
|
||||
## When to use vs alternatives
|
||||
|
||||
**Use Flash Attention when:**
|
||||
- Training transformers with sequences >512 tokens
|
||||
- Running inference with long context (>2K tokens)
|
||||
- GPU memory constrained (OOM with standard attention)
|
||||
- Need 2-4x speedup without accuracy loss
|
||||
- Using PyTorch 2.2+ or can install flash-attn
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **Standard attention**: Sequences <256 tokens (overhead not worth it)
|
||||
- **xFormers**: Need more attention variants (not just speed)
|
||||
- **Memory-efficient attention**: CPU inference (Flash Attention needs GPU)
|
||||
|
||||
## Common issues
|
||||
|
||||
**Issue: ImportError: cannot import flash_attn**
|
||||
|
||||
Install with no-build-isolation flag:
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Or install CUDA toolkit first:
|
||||
```bash
|
||||
conda install cuda -c nvidia
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
**Issue: Slower than expected (no speedup)**
|
||||
|
||||
Flash Attention benefits increase with sequence length:
|
||||
- <512 tokens: Minimal speedup (10-20%)
|
||||
- 512-2K tokens: 2-3x speedup
|
||||
- >2K tokens: 3-4x speedup
|
||||
|
||||
Check sequence length is sufficient.
|
||||
|
||||
**Issue: RuntimeError: CUDA error**
|
||||
|
||||
Verify GPU supports Flash Attention:
|
||||
```python
|
||||
import torch
|
||||
print(torch.cuda.get_device_capability())
|
||||
# Should be ≥(7, 5) for Turing+
|
||||
```
|
||||
|
||||
Flash Attention requires:
|
||||
- Ampere (A100, A10): ✅ Full support
|
||||
- Turing (T4): ✅ Supported
|
||||
- Volta (V100): ❌ Not supported
|
||||
|
||||
**Issue: Accuracy degradation**
|
||||
|
||||
Check dtype is float16 or bfloat16 (not float32):
|
||||
```python
|
||||
q = q.to(torch.float16) # Or torch.bfloat16
|
||||
```
|
||||
|
||||
Flash Attention uses float16/bfloat16 for speed. Float32 not supported.
|
||||
|
||||
## Advanced topics
|
||||
|
||||
**Integration with HuggingFace Transformers**: See [references/transformers-integration.md](references/transformers-integration.md) for enabling Flash Attention in BERT, GPT, Llama models.
|
||||
|
||||
**Performance benchmarks**: See [references/benchmarks.md](references/benchmarks.md) for detailed speed and memory comparisons across GPUs and sequence lengths.
|
||||
|
||||
**Algorithm details**: See [references/algorithm.md](references/algorithm.md) for tiling strategy, recomputation, and IO complexity analysis.
|
||||
|
||||
**Advanced features**: See [references/advanced-features.md](references/advanced-features.md) for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.
|
||||
|
||||
## Hardware requirements
|
||||
|
||||
- **GPU**: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
|
||||
- **VRAM**: Same as standard attention (Flash Attention doesn't increase memory)
|
||||
- **CUDA**: 12.0+ (11.8 minimum)
|
||||
- **PyTorch**: 2.2+ for native support
|
||||
|
||||
**Not supported**: V100 (Volta), CPU inference
|
||||
|
||||
## Resources
|
||||
|
||||
- Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
|
||||
- Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024)
|
||||
- Blog: https://tridao.me/blog/2024/flash3/
|
||||
- GitHub: https://github.com/Dao-AILab/flash-attention
|
||||
- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
|
||||
|
||||
|
||||
215
skills/mlops/flash-attention/references/benchmarks.md
Normal file
215
skills/mlops/flash-attention/references/benchmarks.md
Normal file
@@ -0,0 +1,215 @@
|
||||
# Performance Benchmarks
|
||||
|
||||
## Contents
|
||||
- Speed comparisons across GPUs
|
||||
- Memory usage analysis
|
||||
- Scaling with sequence length
|
||||
- Training vs inference performance
|
||||
- Flash Attention versions comparison
|
||||
|
||||
## Speed comparisons across GPUs
|
||||
|
||||
### A100 80GB (Ampere)
|
||||
|
||||
**Forward pass time** (milliseconds, batch=8, heads=32, dim=64):
|
||||
|
||||
| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 | Speedup (FA2) |
|
||||
|------------|----------|--------------|--------------|---------------|
|
||||
| 512 | 1.2 | 0.9 | N/A | 1.3x |
|
||||
| 1024 | 3.8 | 1.4 | N/A | 2.7x |
|
||||
| 2048 | 14.2 | 4.8 | N/A | 3.0x |
|
||||
| 4096 | 55.1 | 17.3 | N/A | 3.2x |
|
||||
| 8192 | 218.5 | 66.2 | N/A | 3.3x |
|
||||
|
||||
### H100 80GB (Hopper)
|
||||
|
||||
**Forward pass time** (milliseconds, same config):
|
||||
|
||||
| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 (FP16) | Flash Attn 3 (FP8) | Best Speedup |
|
||||
|------------|----------|--------------|---------------------|--------------------|--------------|
|
||||
| 512 | 0.8 | 0.6 | 0.4 | 0.3 | 2.7x |
|
||||
| 1024 | 2.6 | 1.0 | 0.6 | 0.4 | 6.5x |
|
||||
| 2048 | 9.8 | 3.4 | 2.0 | 1.3 | 7.5x |
|
||||
| 4096 | 38.2 | 12.5 | 7.2 | 4.8 | 8.0x |
|
||||
| 8192 | 151.4 | 47.8 | 27.1 | 18.2 | 8.3x |
|
||||
|
||||
**Key insight**: Flash Attention 3 on H100 with FP8 achieves ~1.2 PFLOPS (75% of theoretical max).
|
||||
|
||||
### A10G 24GB (Ampere)
|
||||
|
||||
**Forward pass time** (milliseconds, batch=4):
|
||||
|
||||
| Seq Length | Standard | Flash Attn 2 | Speedup |
|
||||
|------------|----------|--------------|---------|
|
||||
| 512 | 2.1 | 1.6 | 1.3x |
|
||||
| 1024 | 6.8 | 2.8 | 2.4x |
|
||||
| 2048 | 25.9 | 9.4 | 2.8x |
|
||||
| 4096 | 102.1 | 35.2 | 2.9x |
|
||||
|
||||
## Memory usage analysis
|
||||
|
||||
### GPU memory consumption (batch=8, heads=32, dim=64)
|
||||
|
||||
**Standard attention memory**:
|
||||
|
||||
| Seq Length | Attention Matrix | KV Cache | Total | Notes |
|
||||
|------------|------------------|----------|-------|-------|
|
||||
| 512 | 8 MB | 32 MB | 40 MB | Manageable |
|
||||
| 2048 | 128 MB | 128 MB | 256 MB | Growing |
|
||||
| 8192 | 2048 MB (2 GB) | 512 MB | 2.5 GB | Large |
|
||||
| 32768 | 32768 MB (32 GB) | 2048 MB | 34 GB | OOM on 24GB GPUs |
|
||||
|
||||
**Flash Attention 2 memory**:
|
||||
|
||||
| Seq Length | Attention (on-chip) | KV Cache | Total | Reduction |
|
||||
|------------|---------------------|----------|-------|-----------|
|
||||
| 512 | 0 MB (recomputed) | 32 MB | 32 MB | 20% |
|
||||
| 2048 | 0 MB | 128 MB | 128 MB | 50% |
|
||||
| 8192 | 0 MB | 512 MB | 512 MB | 80% |
|
||||
| 32768 | 0 MB | 2048 MB | 2 GB | 94% |
|
||||
|
||||
**Key insight**: Flash Attention doesn't materialize attention matrix, saving O(N²) memory.
|
||||
|
||||
### Memory scaling comparison
|
||||
|
||||
**Llama 2 7B model memory** (float16, batch=1):
|
||||
|
||||
| Context Length | Standard Attention | Flash Attention 2 | Can Fit 24GB GPU? |
|
||||
|----------------|-------------------|-------------------|-------------------|
|
||||
| 2K | 3.2 GB | 2.1 GB | Both: Yes |
|
||||
| 4K | 5.8 GB | 2.8 GB | Both: Yes |
|
||||
| 8K | 12.1 GB | 4.2 GB | Both: Yes |
|
||||
| 16K | 26.3 GB (OOM) | 7.8 GB | Only Flash: Yes |
|
||||
| 32K | OOM | 14.2 GB | Only Flash: Yes |
|
||||
|
||||
### Training memory (Llama 2 7B, batch=4)
|
||||
|
||||
| Context | Standard (GB) | Flash Attn (GB) | Reduction |
|
||||
|---------|---------------|-----------------|-----------|
|
||||
| 2K | 18.2 | 12.4 | 32% |
|
||||
| 4K | 34.8 | 16.8 | 52% |
|
||||
| 8K | OOM (>40GB) | 26.2 | Fits! |
|
||||
|
||||
## Scaling with sequence length
|
||||
|
||||
### Computational complexity
|
||||
|
||||
**Standard attention**:
|
||||
- Time: O(N² × d)
|
||||
- Memory: O(N² + N × d)
|
||||
|
||||
**Flash Attention**:
|
||||
- Time: O(N² × d) (same, but with better constants)
|
||||
- Memory: O(N × d) (linear!)
|
||||
|
||||
### Empirical scaling (A100, batch=1, heads=32, dim=64)
|
||||
|
||||
**Time per token (milliseconds)**:
|
||||
|
||||
| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|
||||
|----------|-----|-----|-----|-----|-----|------|
|
||||
| Standard | 0.15 | 0.37 | 1.11 | 3.44 | 13.4 | 52.8 |
|
||||
| Flash Attn 2 | 0.11 | 0.14 | 0.24 | 0.43 | 0.83 | 1.64 |
|
||||
| Speedup | 1.4x | 2.6x | 4.6x | 8.0x | 16.1x | 32.2x |
|
||||
|
||||
**Observation**: Speedup increases quadratically with sequence length!
|
||||
|
||||
### Memory per token (MB)
|
||||
|
||||
| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|
||||
|----------|-----|-----|-----|-----|-----|------|
|
||||
| Standard | 0.08 | 0.13 | 0.25 | 0.64 | 2.05 | 8.13 |
|
||||
| Flash Attn 2 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 |
|
||||
|
||||
**Observation**: Flash Attention memory per token is constant!
|
||||
|
||||
## Training vs inference performance
|
||||
|
||||
### Training (forward + backward, Llama 2 7B, A100)
|
||||
|
||||
| Batch × Seq | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|
||||
|-------------|------------------------|--------------------------|---------|
|
||||
| 4 × 2K | 1.2 | 3.1 | 2.6x |
|
||||
| 8 × 2K | 2.1 | 5.8 | 2.8x |
|
||||
| 4 × 4K | 0.4 | 1.3 | 3.3x |
|
||||
| 8 × 4K | OOM | 2.4 | Enabled |
|
||||
| 2 × 8K | 0.1 | 0.4 | 4.0x |
|
||||
|
||||
### Inference (generation, Llama 2 7B, A100)
|
||||
|
||||
| Context Length | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|
||||
|----------------|----------------------|-------------------------|---------|
|
||||
| 512 | 48 | 52 | 1.1x |
|
||||
| 2K | 42 | 62 | 1.5x |
|
||||
| 4K | 31 | 58 | 1.9x |
|
||||
| 8K | 18 | 51 | 2.8x |
|
||||
| 16K | OOM | 42 | Enabled |
|
||||
|
||||
**Note**: Inference speedup less dramatic than training because generation is memory-bound (KV cache accesses).
|
||||
|
||||
## Flash Attention versions comparison
|
||||
|
||||
### Flash Attention 1 vs 2 vs 3 (H100, seq=4096, batch=8)
|
||||
|
||||
| Metric | FA1 | FA2 | FA3 (FP16) | FA3 (FP8) |
|
||||
|--------|-----|-----|------------|-----------|
|
||||
| Forward time (ms) | 28.4 | 12.5 | 7.2 | 4.8 |
|
||||
| Memory (GB) | 4.8 | 4.2 | 4.2 | 2.8 |
|
||||
| TFLOPS | 180 | 420 | 740 | 1150 |
|
||||
| GPU util % | 35% | 55% | 75% | 82% |
|
||||
|
||||
**Key improvements**:
|
||||
- FA2: 2.3x faster than FA1 (better parallelism)
|
||||
- FA3 (FP16): 1.7x faster than FA2 (H100 async optimizations)
|
||||
- FA3 (FP8): 2.6x faster than FA2 (low precision)
|
||||
|
||||
### Features by version
|
||||
|
||||
| Feature | FA1 | FA2 | FA3 |
|
||||
|---------|-----|-----|-----|
|
||||
| Basic attention | ✅ | ✅ | ✅ |
|
||||
| Causal masking | ✅ | ✅ | ✅ |
|
||||
| Multi-query attention | ❌ | ✅ | ✅ |
|
||||
| Sliding window | ❌ | ✅ | ✅ |
|
||||
| Paged KV cache | ❌ | ✅ | ✅ |
|
||||
| FP8 support | ❌ | ❌ | ✅ (H100 only) |
|
||||
| Work partitioning | Basic | Advanced | Optimal |
|
||||
|
||||
## Real-world model benchmarks
|
||||
|
||||
### Llama 2 models (A100 80GB, batch=4, seq=2048)
|
||||
|
||||
| Model | Params | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|
||||
|-------|--------|------------------------|--------------------------|---------|
|
||||
| Llama 2 7B | 7B | 1.2 | 3.1 | 2.6x |
|
||||
| Llama 2 13B | 13B | 0.6 | 1.7 | 2.8x |
|
||||
| Llama 2 70B | 70B | 0.12 | 0.34 | 2.8x |
|
||||
|
||||
### GPT-style models (seq=1024)
|
||||
|
||||
| Model | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|
||||
|-------|----------------------|-------------------------|---------|
|
||||
| GPT-2 (124M) | 520 | 680 | 1.3x |
|
||||
| GPT-J (6B) | 42 | 98 | 2.3x |
|
||||
| GPT-NeoX (20B) | 8 | 22 | 2.75x |
|
||||
|
||||
## Recommendations by use case
|
||||
|
||||
**Training large models (>7B parameters)**:
|
||||
- Use Flash Attention 2 on A100
|
||||
- Use Flash Attention 3 FP8 on H100 for maximum speed
|
||||
- Expected: 2.5-3x speedup
|
||||
|
||||
**Long context inference (>4K tokens)**:
|
||||
- Flash Attention essential (enables contexts standard attention can't handle)
|
||||
- Expected: 2-4x speedup, 5-10x memory reduction
|
||||
|
||||
**Short sequences (<512 tokens)**:
|
||||
- Flash Attention provides 1.2-1.5x speedup
|
||||
- Minimal memory benefit
|
||||
- Still worth enabling (no downside)
|
||||
|
||||
**Multi-user serving**:
|
||||
- Flash Attention reduces per-request memory
|
||||
- Allows higher concurrent batch sizes
|
||||
- Can serve 2-3x more users on same hardware
|
||||
@@ -0,0 +1,293 @@
|
||||
# HuggingFace Transformers Integration
|
||||
|
||||
## Contents
|
||||
- Enabling Flash Attention in Transformers
|
||||
- Supported model architectures
|
||||
- Configuration examples
|
||||
- Performance comparisons
|
||||
- Troubleshooting model-specific issues
|
||||
|
||||
## Enabling Flash Attention in Transformers
|
||||
|
||||
HuggingFace Transformers (v4.36+) supports Flash Attention 2 natively.
|
||||
|
||||
**Simple enable for any supported model**:
|
||||
```python
|
||||
from transformers import AutoModel
|
||||
|
||||
model = AutoModel.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto"
|
||||
)
|
||||
```
|
||||
|
||||
**Install requirements**:
|
||||
```bash
|
||||
pip install transformers>=4.36
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
## Supported model architectures
|
||||
|
||||
As of Transformers 4.40:
|
||||
|
||||
**Fully supported**:
|
||||
- Llama / Llama 2 / Llama 3
|
||||
- Mistral / Mixtral
|
||||
- Falcon
|
||||
- GPT-NeoX
|
||||
- Phi / Phi-2 / Phi-3
|
||||
- Qwen / Qwen2
|
||||
- Gemma
|
||||
- Starcoder2
|
||||
- GPT-J
|
||||
- OPT
|
||||
- BLOOM
|
||||
|
||||
**Partially supported** (encoder-decoder):
|
||||
- BART
|
||||
- T5 / Flan-T5
|
||||
- Whisper
|
||||
|
||||
**Check support**:
|
||||
```python
|
||||
from transformers import AutoConfig
|
||||
|
||||
config = AutoConfig.from_pretrained("model-name")
|
||||
print(config._attn_implementation_internal)
|
||||
# 'flash_attention_2' if supported
|
||||
```
|
||||
|
||||
## Configuration examples
|
||||
|
||||
### Llama 2 with Flash Attention
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
import torch
|
||||
|
||||
model_id = "meta-llama/Llama-2-7b-hf"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
# Generate
|
||||
inputs = tokenizer("Once upon a time", return_tensors="pt").to("cuda")
|
||||
outputs = model.generate(**inputs, max_length=100)
|
||||
print(tokenizer.decode(outputs[0]))
|
||||
```
|
||||
|
||||
### Mistral with Flash Attention for long context
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
import torch
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"mistralai/Mistral-7B-v0.1",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16, # Better for long context
|
||||
device_map="auto",
|
||||
max_position_embeddings=32768 # Extended context
|
||||
)
|
||||
|
||||
# Process long document (32K tokens)
|
||||
long_text = "..." * 10000
|
||||
inputs = tokenizer(long_text, return_tensors="pt", truncation=False).to("cuda")
|
||||
outputs = model.generate(**inputs, max_new_tokens=512)
|
||||
```
|
||||
|
||||
### Fine-tuning with Flash Attention
|
||||
|
||||
```python
|
||||
from transformers import Trainer, TrainingArguments
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
num_train_epochs=3,
|
||||
fp16=True, # Must match model dtype
|
||||
optim="adamw_torch_fused" # Fast optimizer
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### Multi-GPU training
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
import torch
|
||||
|
||||
# Model parallelism with Flash Attention
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-13b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto", # Automatic multi-GPU placement
|
||||
max_memory={0: "20GB", 1: "20GB"} # Limit per GPU
|
||||
)
|
||||
```
|
||||
|
||||
## Performance comparisons
|
||||
|
||||
### Memory usage (Llama 2 7B, batch=1)
|
||||
|
||||
| Sequence Length | Standard Attention | Flash Attention 2 | Reduction |
|
||||
|-----------------|-------------------|-------------------|-----------|
|
||||
| 512 | 1.2 GB | 0.9 GB | 25% |
|
||||
| 2048 | 3.8 GB | 1.4 GB | 63% |
|
||||
| 8192 | 14.2 GB | 3.2 GB | 77% |
|
||||
| 32768 | OOM (>24GB) | 10.8 GB | Fits! |
|
||||
|
||||
### Speed (tokens/sec, A100 80GB)
|
||||
|
||||
| Model | Standard | Flash Attn 2 | Speedup |
|
||||
|-------|----------|--------------|---------|
|
||||
| Llama 2 7B (seq=2048) | 42 | 118 | 2.8x |
|
||||
| Llama 2 13B (seq=4096) | 18 | 52 | 2.9x |
|
||||
| Llama 2 70B (seq=2048) | 4 | 11 | 2.75x |
|
||||
|
||||
### Training throughput (samples/sec)
|
||||
|
||||
| Model | Batch Size | Standard | Flash Attn 2 | Speedup |
|
||||
|-------|------------|----------|--------------|---------|
|
||||
| Llama 2 7B | 4 | 1.2 | 3.1 | 2.6x |
|
||||
| Llama 2 7B | 8 | 2.1 | 5.8 | 2.8x |
|
||||
| Llama 2 13B | 2 | 0.6 | 1.7 | 2.8x |
|
||||
|
||||
## Troubleshooting model-specific issues
|
||||
|
||||
### Issue: Model doesn't support Flash Attention
|
||||
|
||||
Check support list above. If not supported, use PyTorch SDPA as fallback:
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"model-name",
|
||||
attn_implementation="sdpa", # PyTorch native (still faster)
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: CUDA out of memory during loading
|
||||
|
||||
Reduce memory footprint:
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"model-name",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
max_memory={0: "18GB"}, # Reserve memory for KV cache
|
||||
low_cpu_mem_usage=True
|
||||
)
|
||||
```
|
||||
|
||||
### Issue: Slower inference than expected
|
||||
|
||||
Ensure dtype matches:
|
||||
|
||||
```python
|
||||
# Model and inputs must both be float16/bfloat16
|
||||
model = model.to(torch.float16)
|
||||
inputs = tokenizer(..., return_tensors="pt").to("cuda")
|
||||
inputs = {k: v.to(torch.float16) if v.dtype == torch.float32 else v
|
||||
for k, v in inputs.items()}
|
||||
```
|
||||
|
||||
### Issue: Different outputs vs standard attention
|
||||
|
||||
Flash Attention is numerically equivalent but uses different computation order. Small differences (<1e-3) are normal:
|
||||
|
||||
```python
|
||||
# Compare outputs
|
||||
model_standard = AutoModelForCausalLM.from_pretrained("model-name", torch_dtype=torch.float16)
|
||||
model_flash = AutoModelForCausalLM.from_pretrained(
|
||||
"model-name",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
inputs = tokenizer("Test", return_tensors="pt").to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
out_standard = model_standard(**inputs).logits
|
||||
out_flash = model_flash(**inputs).logits
|
||||
|
||||
diff = (out_standard - out_flash).abs().max()
|
||||
print(f"Max diff: {diff:.6f}") # Should be ~1e-3 to 1e-4
|
||||
```
|
||||
|
||||
### Issue: ImportError during model loading
|
||||
|
||||
Install flash-attn:
|
||||
```bash
|
||||
pip install flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
Or disable Flash Attention:
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"model-name",
|
||||
attn_implementation="eager", # Standard PyTorch
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Always use float16/bfloat16** with Flash Attention (not float32)
|
||||
2. **Set device_map="auto"** for automatic memory management
|
||||
3. **Use bfloat16 for long context** (better numerical stability)
|
||||
4. **Enable gradient checkpointing** for training large models
|
||||
5. **Monitor memory** with `torch.cuda.max_memory_allocated()`
|
||||
|
||||
**Example with all best practices**:
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, TrainingArguments
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16, # Better for training
|
||||
device_map="auto",
|
||||
low_cpu_mem_usage=True
|
||||
)
|
||||
|
||||
# Enable gradient checkpointing for memory
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
# Training with optimizations
|
||||
training_args = TrainingArguments(
|
||||
output_dir="./results",
|
||||
per_device_train_batch_size=8,
|
||||
gradient_accumulation_steps=2,
|
||||
bf16=True, # Match model dtype
|
||||
optim="adamw_torch_fused",
|
||||
gradient_checkpointing=True
|
||||
)
|
||||
```
|
||||
430
skills/mlops/gguf/SKILL.md
Normal file
430
skills/mlops/gguf/SKILL.md
Normal file
@@ -0,0 +1,430 @@
|
||||
---
|
||||
name: gguf-quantization
|
||||
description: GGUF format and llama.cpp quantization for efficient CPU/GPU inference. Use when deploying models on consumer hardware, Apple Silicon, or when needing flexible quantization from 2-8 bit without GPU requirements.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [llama-cpp-python>=0.2.0]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [GGUF, Quantization, llama.cpp, CPU Inference, Apple Silicon, Model Compression, Optimization]
|
||||
|
||||
---
|
||||
|
||||
# GGUF - Quantization Format for llama.cpp
|
||||
|
||||
The GGUF (GPT-Generated Unified Format) is the standard file format for llama.cpp, enabling efficient inference on CPUs, Apple Silicon, and GPUs with flexible quantization options.
|
||||
|
||||
## When to use GGUF
|
||||
|
||||
**Use GGUF when:**
|
||||
- Deploying on consumer hardware (laptops, desktops)
|
||||
- Running on Apple Silicon (M1/M2/M3) with Metal acceleration
|
||||
- Need CPU inference without GPU requirements
|
||||
- Want flexible quantization (Q2_K to Q8_0)
|
||||
- Using local AI tools (LM Studio, Ollama, text-generation-webui)
|
||||
|
||||
**Key advantages:**
|
||||
- **Universal hardware**: CPU, Apple Silicon, NVIDIA, AMD support
|
||||
- **No Python runtime**: Pure C/C++ inference
|
||||
- **Flexible quantization**: 2-8 bit with various methods (K-quants)
|
||||
- **Ecosystem support**: LM Studio, Ollama, koboldcpp, and more
|
||||
- **imatrix**: Importance matrix for better low-bit quality
|
||||
|
||||
**Use alternatives instead:**
|
||||
- **AWQ/GPTQ**: Maximum accuracy with calibration on NVIDIA GPUs
|
||||
- **HQQ**: Fast calibration-free quantization for HuggingFace
|
||||
- **bitsandbytes**: Simple integration with transformers library
|
||||
- **TensorRT-LLM**: Production NVIDIA deployment with maximum speed
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Clone llama.cpp
|
||||
git clone https://github.com/ggml-org/llama.cpp
|
||||
cd llama.cpp
|
||||
|
||||
# Build (CPU)
|
||||
make
|
||||
|
||||
# Build with CUDA (NVIDIA)
|
||||
make GGML_CUDA=1
|
||||
|
||||
# Build with Metal (Apple Silicon)
|
||||
make GGML_METAL=1
|
||||
|
||||
# Install Python bindings (optional)
|
||||
pip install llama-cpp-python
|
||||
```
|
||||
|
||||
### Convert model to GGUF
|
||||
|
||||
```bash
|
||||
# Install requirements
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Convert HuggingFace model to GGUF (FP16)
|
||||
python convert_hf_to_gguf.py ./path/to/model --outfile model-f16.gguf
|
||||
|
||||
# Or specify output type
|
||||
python convert_hf_to_gguf.py ./path/to/model \
|
||||
--outfile model-f16.gguf \
|
||||
--outtype f16
|
||||
```
|
||||
|
||||
### Quantize model
|
||||
|
||||
```bash
|
||||
# Basic quantization to Q4_K_M
|
||||
./llama-quantize model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
|
||||
# Quantize with importance matrix (better quality)
|
||||
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix
|
||||
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
### Run inference
|
||||
|
||||
```bash
|
||||
# CLI inference
|
||||
./llama-cli -m model-q4_k_m.gguf -p "Hello, how are you?"
|
||||
|
||||
# Interactive mode
|
||||
./llama-cli -m model-q4_k_m.gguf --interactive
|
||||
|
||||
# With GPU offload
|
||||
./llama-cli -m model-q4_k_m.gguf -ngl 35 -p "Hello!"
|
||||
```
|
||||
|
||||
## Quantization types
|
||||
|
||||
### K-quant methods (recommended)
|
||||
|
||||
| Type | Bits | Size (7B) | Quality | Use Case |
|
||||
|------|------|-----------|---------|----------|
|
||||
| Q2_K | 2.5 | ~2.8 GB | Low | Extreme compression |
|
||||
| Q3_K_S | 3.0 | ~3.0 GB | Low-Med | Memory constrained |
|
||||
| Q3_K_M | 3.3 | ~3.3 GB | Medium | Balance |
|
||||
| Q4_K_S | 4.0 | ~3.8 GB | Med-High | Good balance |
|
||||
| Q4_K_M | 4.5 | ~4.1 GB | High | **Recommended default** |
|
||||
| Q5_K_S | 5.0 | ~4.6 GB | High | Quality focused |
|
||||
| Q5_K_M | 5.5 | ~4.8 GB | Very High | High quality |
|
||||
| Q6_K | 6.0 | ~5.5 GB | Excellent | Near-original |
|
||||
| Q8_0 | 8.0 | ~7.2 GB | Best | Maximum quality |
|
||||
|
||||
### Legacy methods
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| Q4_0 | 4-bit, basic |
|
||||
| Q4_1 | 4-bit with delta |
|
||||
| Q5_0 | 5-bit, basic |
|
||||
| Q5_1 | 5-bit with delta |
|
||||
|
||||
**Recommendation**: Use K-quant methods (Q4_K_M, Q5_K_M) for best quality/size ratio.
|
||||
|
||||
## Conversion workflows
|
||||
|
||||
### Workflow 1: HuggingFace to GGUF
|
||||
|
||||
```bash
|
||||
# 1. Download model
|
||||
huggingface-cli download meta-llama/Llama-3.1-8B --local-dir ./llama-3.1-8b
|
||||
|
||||
# 2. Convert to GGUF (FP16)
|
||||
python convert_hf_to_gguf.py ./llama-3.1-8b \
|
||||
--outfile llama-3.1-8b-f16.gguf \
|
||||
--outtype f16
|
||||
|
||||
# 3. Quantize
|
||||
./llama-quantize llama-3.1-8b-f16.gguf llama-3.1-8b-q4_k_m.gguf Q4_K_M
|
||||
|
||||
# 4. Test
|
||||
./llama-cli -m llama-3.1-8b-q4_k_m.gguf -p "Hello!" -n 50
|
||||
```
|
||||
|
||||
### Workflow 2: With importance matrix (better quality)
|
||||
|
||||
```bash
|
||||
# 1. Convert to GGUF
|
||||
python convert_hf_to_gguf.py ./model --outfile model-f16.gguf
|
||||
|
||||
# 2. Create calibration text (diverse samples)
|
||||
cat > calibration.txt << 'EOF'
|
||||
The quick brown fox jumps over the lazy dog.
|
||||
Machine learning is a subset of artificial intelligence.
|
||||
Python is a popular programming language.
|
||||
# Add more diverse text samples...
|
||||
EOF
|
||||
|
||||
# 3. Generate importance matrix
|
||||
./llama-imatrix -m model-f16.gguf \
|
||||
-f calibration.txt \
|
||||
--chunk 512 \
|
||||
-o model.imatrix \
|
||||
-ngl 35 # GPU layers if available
|
||||
|
||||
# 4. Quantize with imatrix
|
||||
./llama-quantize --imatrix model.imatrix \
|
||||
model-f16.gguf \
|
||||
model-q4_k_m.gguf \
|
||||
Q4_K_M
|
||||
```
|
||||
|
||||
### Workflow 3: Multiple quantizations
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
MODEL="llama-3.1-8b-f16.gguf"
|
||||
IMATRIX="llama-3.1-8b.imatrix"
|
||||
|
||||
# Generate imatrix once
|
||||
./llama-imatrix -m $MODEL -f wiki.txt -o $IMATRIX -ngl 35
|
||||
|
||||
# Create multiple quantizations
|
||||
for QUANT in Q4_K_M Q5_K_M Q6_K Q8_0; do
|
||||
OUTPUT="llama-3.1-8b-${QUANT,,}.gguf"
|
||||
./llama-quantize --imatrix $IMATRIX $MODEL $OUTPUT $QUANT
|
||||
echo "Created: $OUTPUT ($(du -h $OUTPUT | cut -f1))"
|
||||
done
|
||||
```
|
||||
|
||||
## Python usage
|
||||
|
||||
### llama-cpp-python
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
# Load model
|
||||
llm = Llama(
|
||||
model_path="./model-q4_k_m.gguf",
|
||||
n_ctx=4096, # Context window
|
||||
n_gpu_layers=35, # GPU offload (0 for CPU only)
|
||||
n_threads=8 # CPU threads
|
||||
)
|
||||
|
||||
# Generate
|
||||
output = llm(
|
||||
"What is machine learning?",
|
||||
max_tokens=256,
|
||||
temperature=0.7,
|
||||
stop=["</s>", "\n\n"]
|
||||
)
|
||||
print(output["choices"][0]["text"])
|
||||
```
|
||||
|
||||
### Chat completion
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="./model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35,
|
||||
chat_format="llama-3" # Or "chatml", "mistral", etc.
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is Python?"}
|
||||
]
|
||||
|
||||
response = llm.create_chat_completion(
|
||||
messages=messages,
|
||||
max_tokens=256,
|
||||
temperature=0.7
|
||||
)
|
||||
print(response["choices"][0]["message"]["content"])
|
||||
```
|
||||
|
||||
### Streaming
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(model_path="./model-q4_k_m.gguf", n_gpu_layers=35)
|
||||
|
||||
# Stream tokens
|
||||
for chunk in llm(
|
||||
"Explain quantum computing:",
|
||||
max_tokens=256,
|
||||
stream=True
|
||||
):
|
||||
print(chunk["choices"][0]["text"], end="", flush=True)
|
||||
```
|
||||
|
||||
## Server mode
|
||||
|
||||
### Start OpenAI-compatible server
|
||||
|
||||
```bash
|
||||
# Start server
|
||||
./llama-server -m model-q4_k_m.gguf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080 \
|
||||
-ngl 35 \
|
||||
-c 4096
|
||||
|
||||
# Or with Python bindings
|
||||
python -m llama_cpp.server \
|
||||
--model model-q4_k_m.gguf \
|
||||
--n_gpu_layers 35 \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080
|
||||
```
|
||||
|
||||
### Use with OpenAI client
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1",
|
||||
api_key="not-needed"
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="local-model",
|
||||
messages=[{"role": "user", "content": "Hello!"}],
|
||||
max_tokens=256
|
||||
)
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
## Hardware optimization
|
||||
|
||||
### Apple Silicon (Metal)
|
||||
|
||||
```bash
|
||||
# Build with Metal
|
||||
make clean && make GGML_METAL=1
|
||||
|
||||
# Run with Metal acceleration
|
||||
./llama-cli -m model.gguf -ngl 99 -p "Hello"
|
||||
|
||||
# Python with Metal
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=99, # Offload all layers
|
||||
n_threads=1 # Metal handles parallelism
|
||||
)
|
||||
```
|
||||
|
||||
### NVIDIA CUDA
|
||||
|
||||
```bash
|
||||
# Build with CUDA
|
||||
make clean && make GGML_CUDA=1
|
||||
|
||||
# Run with CUDA
|
||||
./llama-cli -m model.gguf -ngl 35 -p "Hello"
|
||||
|
||||
# Specify GPU
|
||||
CUDA_VISIBLE_DEVICES=0 ./llama-cli -m model.gguf -ngl 35
|
||||
```
|
||||
|
||||
### CPU optimization
|
||||
|
||||
```bash
|
||||
# Build with AVX2/AVX512
|
||||
make clean && make
|
||||
|
||||
# Run with optimal threads
|
||||
./llama-cli -m model.gguf -t 8 -p "Hello"
|
||||
|
||||
# Python CPU config
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=0, # CPU only
|
||||
n_threads=8, # Match physical cores
|
||||
n_batch=512 # Batch size for prompt processing
|
||||
)
|
||||
```
|
||||
|
||||
## Integration with tools
|
||||
|
||||
### Ollama
|
||||
|
||||
```bash
|
||||
# Create Modelfile
|
||||
cat > Modelfile << 'EOF'
|
||||
FROM ./model-q4_k_m.gguf
|
||||
TEMPLATE """{{ .System }}
|
||||
{{ .Prompt }}"""
|
||||
PARAMETER temperature 0.7
|
||||
PARAMETER num_ctx 4096
|
||||
EOF
|
||||
|
||||
# Create Ollama model
|
||||
ollama create mymodel -f Modelfile
|
||||
|
||||
# Run
|
||||
ollama run mymodel "Hello!"
|
||||
```
|
||||
|
||||
### LM Studio
|
||||
|
||||
1. Place GGUF file in `~/.cache/lm-studio/models/`
|
||||
2. Open LM Studio and select the model
|
||||
3. Configure context length and GPU offload
|
||||
4. Start inference
|
||||
|
||||
### text-generation-webui
|
||||
|
||||
```bash
|
||||
# Place in models folder
|
||||
cp model-q4_k_m.gguf text-generation-webui/models/
|
||||
|
||||
# Start with llama.cpp loader
|
||||
python server.py --model model-q4_k_m.gguf --loader llama.cpp --n-gpu-layers 35
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Use K-quants**: Q4_K_M offers best quality/size balance
|
||||
2. **Use imatrix**: Always use importance matrix for Q4 and below
|
||||
3. **GPU offload**: Offload as many layers as VRAM allows
|
||||
4. **Context length**: Start with 4096, increase if needed
|
||||
5. **Thread count**: Match physical CPU cores, not logical
|
||||
6. **Batch size**: Increase n_batch for faster prompt processing
|
||||
|
||||
## Common issues
|
||||
|
||||
**Model loads slowly:**
|
||||
```bash
|
||||
# Use mmap for faster loading
|
||||
./llama-cli -m model.gguf --mmap
|
||||
```
|
||||
|
||||
**Out of memory:**
|
||||
```bash
|
||||
# Reduce GPU layers
|
||||
./llama-cli -m model.gguf -ngl 20 # Reduce from 35
|
||||
|
||||
# Or use smaller quantization
|
||||
./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M
|
||||
```
|
||||
|
||||
**Poor quality at low bits:**
|
||||
```bash
|
||||
# Always use imatrix for Q4 and below
|
||||
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix
|
||||
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
- **[Advanced Usage](references/advanced-usage.md)** - Batching, speculative decoding, custom builds
|
||||
- **[Troubleshooting](references/troubleshooting.md)** - Common issues, debugging, benchmarks
|
||||
|
||||
## Resources
|
||||
|
||||
- **Repository**: https://github.com/ggml-org/llama.cpp
|
||||
- **Python Bindings**: https://github.com/abetlen/llama-cpp-python
|
||||
- **Pre-quantized Models**: https://huggingface.co/TheBloke
|
||||
- **GGUF Converter**: https://huggingface.co/spaces/ggml-org/gguf-my-repo
|
||||
- **License**: MIT
|
||||
504
skills/mlops/gguf/references/advanced-usage.md
Normal file
504
skills/mlops/gguf/references/advanced-usage.md
Normal file
@@ -0,0 +1,504 @@
|
||||
# GGUF Advanced Usage Guide
|
||||
|
||||
## Speculative Decoding
|
||||
|
||||
### Draft Model Approach
|
||||
|
||||
```bash
|
||||
# Use smaller model as draft for faster generation
|
||||
./llama-speculative \
|
||||
-m large-model-q4_k_m.gguf \
|
||||
-md draft-model-q4_k_m.gguf \
|
||||
-p "Write a story about AI" \
|
||||
-n 500 \
|
||||
--draft 8 # Draft tokens before verification
|
||||
```
|
||||
|
||||
### Self-Speculative Decoding
|
||||
|
||||
```bash
|
||||
# Use same model with different context for speculation
|
||||
./llama-cli -m model-q4_k_m.gguf \
|
||||
--lookup-cache-static lookup.bin \
|
||||
--lookup-cache-dynamic lookup-dynamic.bin \
|
||||
-p "Hello world"
|
||||
```
|
||||
|
||||
## Batched Inference
|
||||
|
||||
### Process Multiple Prompts
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35,
|
||||
n_batch=512 # Larger batch for parallel processing
|
||||
)
|
||||
|
||||
prompts = [
|
||||
"What is Python?",
|
||||
"Explain machine learning.",
|
||||
"Describe neural networks."
|
||||
]
|
||||
|
||||
# Process in batch (each prompt gets separate context)
|
||||
for prompt in prompts:
|
||||
output = llm(prompt, max_tokens=100)
|
||||
print(f"Q: {prompt}")
|
||||
print(f"A: {output['choices'][0]['text']}\n")
|
||||
```
|
||||
|
||||
### Server Batching
|
||||
|
||||
```bash
|
||||
# Start server with batching
|
||||
./llama-server -m model-q4_k_m.gguf \
|
||||
--host 0.0.0.0 \
|
||||
--port 8080 \
|
||||
-ngl 35 \
|
||||
-c 4096 \
|
||||
--parallel 4 # Concurrent requests
|
||||
--cont-batching # Continuous batching
|
||||
```
|
||||
|
||||
## Custom Model Conversion
|
||||
|
||||
### Convert with Vocabulary Modifications
|
||||
|
||||
```python
|
||||
# custom_convert.py
|
||||
import sys
|
||||
sys.path.insert(0, './llama.cpp')
|
||||
|
||||
from convert_hf_to_gguf import main
|
||||
from gguf import GGUFWriter
|
||||
|
||||
# Custom conversion with modified vocab
|
||||
def convert_with_custom_vocab(model_path, output_path):
|
||||
# Load and modify tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
# Add special tokens if needed
|
||||
special_tokens = {"additional_special_tokens": ["<|custom|>"]}
|
||||
tokenizer.add_special_tokens(special_tokens)
|
||||
tokenizer.save_pretrained(model_path)
|
||||
|
||||
# Then run standard conversion
|
||||
main([model_path, "--outfile", output_path])
|
||||
```
|
||||
|
||||
### Convert Specific Architecture
|
||||
|
||||
```bash
|
||||
# For Mistral-style models
|
||||
python convert_hf_to_gguf.py ./mistral-model \
|
||||
--outfile mistral-f16.gguf \
|
||||
--outtype f16
|
||||
|
||||
# For Qwen models
|
||||
python convert_hf_to_gguf.py ./qwen-model \
|
||||
--outfile qwen-f16.gguf \
|
||||
--outtype f16
|
||||
|
||||
# For Phi models
|
||||
python convert_hf_to_gguf.py ./phi-model \
|
||||
--outfile phi-f16.gguf \
|
||||
--outtype f16
|
||||
```
|
||||
|
||||
## Advanced Quantization
|
||||
|
||||
### Mixed Quantization
|
||||
|
||||
```bash
|
||||
# Quantize different layer types differently
|
||||
./llama-quantize model-f16.gguf model-mixed.gguf Q4_K_M \
|
||||
--allow-requantize \
|
||||
--leave-output-tensor
|
||||
```
|
||||
|
||||
### Quantization with Token Embeddings
|
||||
|
||||
```bash
|
||||
# Keep embeddings at higher precision
|
||||
./llama-quantize model-f16.gguf model-q4.gguf Q4_K_M \
|
||||
--token-embedding-type f16
|
||||
```
|
||||
|
||||
### IQ Quantization (Importance-aware)
|
||||
|
||||
```bash
|
||||
# Ultra-low bit quantization with importance
|
||||
./llama-quantize --imatrix model.imatrix \
|
||||
model-f16.gguf model-iq2_xxs.gguf IQ2_XXS
|
||||
|
||||
# Available IQ types: IQ2_XXS, IQ2_XS, IQ2_S, IQ3_XXS, IQ3_XS, IQ3_S, IQ4_XS
|
||||
```
|
||||
|
||||
## Memory Optimization
|
||||
|
||||
### Memory Mapping
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
# Use memory mapping for large models
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
use_mmap=True, # Memory map the model
|
||||
use_mlock=False, # Don't lock in RAM
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
### Partial GPU Offload
|
||||
|
||||
```python
|
||||
# Calculate layers to offload based on VRAM
|
||||
import subprocess
|
||||
|
||||
def get_free_vram_gb():
|
||||
result = subprocess.run(
|
||||
['nvidia-smi', '--query-gpu=memory.free', '--format=csv,nounits,noheader'],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
return int(result.stdout.strip()) / 1024
|
||||
|
||||
# Estimate layers based on VRAM (rough: 0.5GB per layer for 7B Q4)
|
||||
free_vram = get_free_vram_gb()
|
||||
layers_to_offload = int(free_vram / 0.5)
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_gpu_layers=min(layers_to_offload, 35) # Cap at total layers
|
||||
)
|
||||
```
|
||||
|
||||
### KV Cache Optimization
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
# Optimize KV cache for long contexts
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_ctx=8192, # Large context
|
||||
n_gpu_layers=35,
|
||||
type_k=1, # Q8_0 for K cache (1)
|
||||
type_v=1, # Q8_0 for V cache (1)
|
||||
# Or use Q4_0 (2) for more compression
|
||||
)
|
||||
```
|
||||
|
||||
## Context Management
|
||||
|
||||
### Context Shifting
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35
|
||||
)
|
||||
|
||||
# Handle long conversations with context shifting
|
||||
conversation = []
|
||||
max_history = 10
|
||||
|
||||
def chat(user_message):
|
||||
conversation.append({"role": "user", "content": user_message})
|
||||
|
||||
# Keep only recent history
|
||||
if len(conversation) > max_history * 2:
|
||||
conversation = conversation[-max_history * 2:]
|
||||
|
||||
response = llm.create_chat_completion(
|
||||
messages=conversation,
|
||||
max_tokens=256
|
||||
)
|
||||
|
||||
assistant_message = response["choices"][0]["message"]["content"]
|
||||
conversation.append({"role": "assistant", "content": assistant_message})
|
||||
return assistant_message
|
||||
```
|
||||
|
||||
### Save and Load State
|
||||
|
||||
```bash
|
||||
# Save state to file
|
||||
./llama-cli -m model.gguf \
|
||||
-p "Once upon a time" \
|
||||
--save-session session.bin \
|
||||
-n 100
|
||||
|
||||
# Load and continue
|
||||
./llama-cli -m model.gguf \
|
||||
--load-session session.bin \
|
||||
-p " and they lived" \
|
||||
-n 100
|
||||
```
|
||||
|
||||
## Grammar Constrained Generation
|
||||
|
||||
### JSON Output
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama, LlamaGrammar
|
||||
|
||||
# Define JSON grammar
|
||||
json_grammar = LlamaGrammar.from_string('''
|
||||
root ::= object
|
||||
object ::= "{" ws pair ("," ws pair)* "}" ws
|
||||
pair ::= string ":" ws value
|
||||
value ::= string | number | object | array | "true" | "false" | "null"
|
||||
array ::= "[" ws value ("," ws value)* "]" ws
|
||||
string ::= "\\"" [^"\\\\]* "\\""
|
||||
number ::= [0-9]+
|
||||
ws ::= [ \\t\\n]*
|
||||
''')
|
||||
|
||||
llm = Llama(model_path="model-q4_k_m.gguf", n_gpu_layers=35)
|
||||
|
||||
output = llm(
|
||||
"Output a JSON object with name and age:",
|
||||
grammar=json_grammar,
|
||||
max_tokens=100
|
||||
)
|
||||
print(output["choices"][0]["text"])
|
||||
```
|
||||
|
||||
### Custom Grammar
|
||||
|
||||
```python
|
||||
# Grammar for specific format
|
||||
answer_grammar = LlamaGrammar.from_string('''
|
||||
root ::= "Answer: " letter "\\n" "Explanation: " explanation
|
||||
letter ::= [A-D]
|
||||
explanation ::= [a-zA-Z0-9 .,!?]+
|
||||
''')
|
||||
|
||||
output = llm(
|
||||
"Q: What is 2+2? A) 3 B) 4 C) 5 D) 6",
|
||||
grammar=answer_grammar,
|
||||
max_tokens=100
|
||||
)
|
||||
```
|
||||
|
||||
## LoRA Integration
|
||||
|
||||
### Load LoRA Adapter
|
||||
|
||||
```bash
|
||||
# Apply LoRA at runtime
|
||||
./llama-cli -m base-model-q4_k_m.gguf \
|
||||
--lora lora-adapter.gguf \
|
||||
--lora-scale 1.0 \
|
||||
-p "Hello!"
|
||||
```
|
||||
|
||||
### Multiple LoRA Adapters
|
||||
|
||||
```bash
|
||||
# Stack multiple adapters
|
||||
./llama-cli -m base-model.gguf \
|
||||
--lora adapter1.gguf --lora-scale 0.5 \
|
||||
--lora adapter2.gguf --lora-scale 0.5 \
|
||||
-p "Hello!"
|
||||
```
|
||||
|
||||
### Python LoRA Usage
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="base-model-q4_k_m.gguf",
|
||||
lora_path="lora-adapter.gguf",
|
||||
lora_scale=1.0,
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
## Embedding Generation
|
||||
|
||||
### Extract Embeddings
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="model-q4_k_m.gguf",
|
||||
embedding=True, # Enable embedding mode
|
||||
n_gpu_layers=35
|
||||
)
|
||||
|
||||
# Get embeddings
|
||||
embeddings = llm.embed("This is a test sentence.")
|
||||
print(f"Embedding dimension: {len(embeddings)}")
|
||||
```
|
||||
|
||||
### Batch Embeddings
|
||||
|
||||
```python
|
||||
texts = [
|
||||
"Machine learning is fascinating.",
|
||||
"Deep learning uses neural networks.",
|
||||
"Python is a programming language."
|
||||
]
|
||||
|
||||
embeddings = [llm.embed(text) for text in texts]
|
||||
|
||||
# Calculate similarity
|
||||
import numpy as np
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
|
||||
|
||||
sim = cosine_similarity(embeddings[0], embeddings[1])
|
||||
print(f"Similarity: {sim:.4f}")
|
||||
```
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### Benchmark Script
|
||||
|
||||
```python
|
||||
import time
|
||||
from llama_cpp import Llama
|
||||
|
||||
def benchmark(model_path, prompt, n_tokens=100, n_runs=5):
|
||||
llm = Llama(
|
||||
model_path=model_path,
|
||||
n_gpu_layers=35,
|
||||
n_ctx=2048,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Warmup
|
||||
llm(prompt, max_tokens=10)
|
||||
|
||||
# Benchmark
|
||||
times = []
|
||||
for _ in range(n_runs):
|
||||
start = time.time()
|
||||
output = llm(prompt, max_tokens=n_tokens)
|
||||
elapsed = time.time() - start
|
||||
times.append(elapsed)
|
||||
|
||||
avg_time = sum(times) / len(times)
|
||||
tokens_per_sec = n_tokens / avg_time
|
||||
|
||||
print(f"Model: {model_path}")
|
||||
print(f"Avg time: {avg_time:.2f}s")
|
||||
print(f"Tokens/sec: {tokens_per_sec:.1f}")
|
||||
|
||||
return tokens_per_sec
|
||||
|
||||
# Compare quantizations
|
||||
for quant in ["q4_k_m", "q5_k_m", "q8_0"]:
|
||||
benchmark(f"model-{quant}.gguf", "Explain quantum computing:", 100)
|
||||
```
|
||||
|
||||
### Optimal Configuration Finder
|
||||
|
||||
```python
|
||||
def find_optimal_config(model_path, target_vram_gb=8):
|
||||
"""Find optimal n_gpu_layers and n_batch for target VRAM."""
|
||||
from llama_cpp import Llama
|
||||
import gc
|
||||
|
||||
best_config = None
|
||||
best_speed = 0
|
||||
|
||||
for n_gpu_layers in range(0, 50, 5):
|
||||
for n_batch in [128, 256, 512, 1024]:
|
||||
try:
|
||||
gc.collect()
|
||||
llm = Llama(
|
||||
model_path=model_path,
|
||||
n_gpu_layers=n_gpu_layers,
|
||||
n_batch=n_batch,
|
||||
n_ctx=2048,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
# Quick benchmark
|
||||
start = time.time()
|
||||
llm("Hello", max_tokens=50)
|
||||
speed = 50 / (time.time() - start)
|
||||
|
||||
if speed > best_speed:
|
||||
best_speed = speed
|
||||
best_config = {
|
||||
"n_gpu_layers": n_gpu_layers,
|
||||
"n_batch": n_batch,
|
||||
"speed": speed
|
||||
}
|
||||
|
||||
del llm
|
||||
gc.collect()
|
||||
|
||||
except Exception as e:
|
||||
print(f"OOM at layers={n_gpu_layers}, batch={n_batch}")
|
||||
break
|
||||
|
||||
return best_config
|
||||
```
|
||||
|
||||
## Multi-GPU Setup
|
||||
|
||||
### Distribute Across GPUs
|
||||
|
||||
```bash
|
||||
# Split model across multiple GPUs
|
||||
./llama-cli -m large-model.gguf \
|
||||
--tensor-split 0.5,0.5 \
|
||||
-ngl 60 \
|
||||
-p "Hello!"
|
||||
```
|
||||
|
||||
### Python Multi-GPU
|
||||
|
||||
```python
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
|
||||
|
||||
from llama_cpp import Llama
|
||||
|
||||
llm = Llama(
|
||||
model_path="large-model-q4_k_m.gguf",
|
||||
n_gpu_layers=60,
|
||||
tensor_split=[0.5, 0.5] # Split evenly across 2 GPUs
|
||||
)
|
||||
```
|
||||
|
||||
## Custom Builds
|
||||
|
||||
### Build with All Optimizations
|
||||
|
||||
```bash
|
||||
# Clean build with all CPU optimizations
|
||||
make clean
|
||||
LLAMA_OPENBLAS=1 LLAMA_BLAS_VENDOR=OpenBLAS make -j
|
||||
|
||||
# With CUDA and cuBLAS
|
||||
make clean
|
||||
GGML_CUDA=1 LLAMA_CUBLAS=1 make -j
|
||||
|
||||
# With specific CUDA architecture
|
||||
GGML_CUDA=1 CUDA_DOCKER_ARCH=sm_86 make -j
|
||||
```
|
||||
|
||||
### CMake Build
|
||||
|
||||
```bash
|
||||
mkdir build && cd build
|
||||
cmake .. -DGGML_CUDA=ON -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build . --config Release -j
|
||||
```
|
||||
442
skills/mlops/gguf/references/troubleshooting.md
Normal file
442
skills/mlops/gguf/references/troubleshooting.md
Normal file
@@ -0,0 +1,442 @@
|
||||
# GGUF Troubleshooting Guide
|
||||
|
||||
## Installation Issues
|
||||
|
||||
### Build Fails
|
||||
|
||||
**Error**: `make: *** No targets specified and no makefile found`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Ensure you're in llama.cpp directory
|
||||
cd llama.cpp
|
||||
make
|
||||
```
|
||||
|
||||
**Error**: `fatal error: cuda_runtime.h: No such file or directory`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Install CUDA toolkit
|
||||
# Ubuntu
|
||||
sudo apt install nvidia-cuda-toolkit
|
||||
|
||||
# Or set CUDA path
|
||||
export CUDA_PATH=/usr/local/cuda
|
||||
export PATH=$CUDA_PATH/bin:$PATH
|
||||
make GGML_CUDA=1
|
||||
```
|
||||
|
||||
### Python Bindings Issues
|
||||
|
||||
**Error**: `ERROR: Failed building wheel for llama-cpp-python`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Install build dependencies
|
||||
pip install cmake scikit-build-core
|
||||
|
||||
# For CUDA support
|
||||
CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
|
||||
|
||||
# For Metal (macOS)
|
||||
CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
|
||||
```
|
||||
|
||||
**Error**: `ImportError: libcudart.so.XX: cannot open shared object file`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Add CUDA libraries to path
|
||||
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
|
||||
|
||||
# Or reinstall with correct CUDA version
|
||||
pip uninstall llama-cpp-python
|
||||
CUDACXX=/usr/local/cuda/bin/nvcc CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python
|
||||
```
|
||||
|
||||
## Conversion Issues
|
||||
|
||||
### Model Not Supported
|
||||
|
||||
**Error**: `KeyError: 'model.embed_tokens.weight'`
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Check model architecture
|
||||
python -c "from transformers import AutoConfig; print(AutoConfig.from_pretrained('./model').architectures)"
|
||||
|
||||
# Use appropriate conversion script
|
||||
# For most models:
|
||||
python convert_hf_to_gguf.py ./model --outfile model.gguf
|
||||
|
||||
# For older models, check if legacy script needed
|
||||
```
|
||||
|
||||
### Vocabulary Mismatch
|
||||
|
||||
**Error**: `RuntimeError: Vocabulary size mismatch`
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Ensure tokenizer matches model
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("./model")
|
||||
model = AutoModelForCausalLM.from_pretrained("./model")
|
||||
|
||||
print(f"Tokenizer vocab size: {len(tokenizer)}")
|
||||
print(f"Model vocab size: {model.config.vocab_size}")
|
||||
|
||||
# If mismatch, resize embeddings before conversion
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
model.save_pretrained("./model-fixed")
|
||||
```
|
||||
|
||||
### Out of Memory During Conversion
|
||||
|
||||
**Error**: `torch.cuda.OutOfMemoryError` during conversion
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Use CPU for conversion
|
||||
CUDA_VISIBLE_DEVICES="" python convert_hf_to_gguf.py ./model --outfile model.gguf
|
||||
|
||||
# Or use low memory mode
|
||||
python convert_hf_to_gguf.py ./model --outfile model.gguf --outtype f16
|
||||
```
|
||||
|
||||
## Quantization Issues
|
||||
|
||||
### Wrong Output File Size
|
||||
|
||||
**Problem**: Quantized file is larger than expected
|
||||
|
||||
**Check**:
|
||||
```bash
|
||||
# Verify quantization type
|
||||
./llama-cli -m model.gguf --verbose
|
||||
|
||||
# Expected sizes for 7B model:
|
||||
# Q4_K_M: ~4.1 GB
|
||||
# Q5_K_M: ~4.8 GB
|
||||
# Q8_0: ~7.2 GB
|
||||
# F16: ~13.5 GB
|
||||
```
|
||||
|
||||
### Quantization Crashes
|
||||
|
||||
**Error**: `Segmentation fault` during quantization
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Increase stack size
|
||||
ulimit -s unlimited
|
||||
|
||||
# Or use less threads
|
||||
./llama-quantize -t 4 model-f16.gguf model-q4.gguf Q4_K_M
|
||||
```
|
||||
|
||||
### Poor Quality After Quantization
|
||||
|
||||
**Problem**: Model outputs gibberish after quantization
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Use importance matrix**:
|
||||
```bash
|
||||
# Generate imatrix with good calibration data
|
||||
./llama-imatrix -m model-f16.gguf \
|
||||
-f wiki_sample.txt \
|
||||
--chunk 512 \
|
||||
-o model.imatrix
|
||||
|
||||
# Quantize with imatrix
|
||||
./llama-quantize --imatrix model.imatrix \
|
||||
model-f16.gguf model-q4_k_m.gguf Q4_K_M
|
||||
```
|
||||
|
||||
2. **Try higher precision**:
|
||||
```bash
|
||||
# Use Q5_K_M or Q6_K instead of Q4
|
||||
./llama-quantize model-f16.gguf model-q5_k_m.gguf Q5_K_M
|
||||
```
|
||||
|
||||
3. **Check original model**:
|
||||
```bash
|
||||
# Test FP16 version first
|
||||
./llama-cli -m model-f16.gguf -p "Hello, how are you?" -n 50
|
||||
```
|
||||
|
||||
## Inference Issues
|
||||
|
||||
### Slow Generation
|
||||
|
||||
**Problem**: Generation is slower than expected
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Enable GPU offload**:
|
||||
```bash
|
||||
./llama-cli -m model.gguf -ngl 35 -p "Hello"
|
||||
```
|
||||
|
||||
2. **Optimize batch size**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_batch=512, # Increase for faster prompt processing
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
3. **Use appropriate threads**:
|
||||
```bash
|
||||
# Match physical cores, not logical
|
||||
./llama-cli -m model.gguf -t 8 -p "Hello"
|
||||
```
|
||||
|
||||
4. **Enable Flash Attention** (if supported):
|
||||
```bash
|
||||
./llama-cli -m model.gguf -ngl 35 --flash-attn -p "Hello"
|
||||
```
|
||||
|
||||
### Out of Memory
|
||||
|
||||
**Error**: `CUDA out of memory` or system freeze
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Reduce GPU layers**:
|
||||
```python
|
||||
# Start low and increase
|
||||
llm = Llama(model_path="model.gguf", n_gpu_layers=10)
|
||||
```
|
||||
|
||||
2. **Use smaller quantization**:
|
||||
```bash
|
||||
./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M
|
||||
```
|
||||
|
||||
3. **Reduce context length**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_ctx=2048, # Reduce from 4096
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
4. **Quantize KV cache**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
type_k=2, # Q4_0 for K cache
|
||||
type_v=2, # Q4_0 for V cache
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
### Garbage Output
|
||||
|
||||
**Problem**: Model outputs random characters or nonsense
|
||||
|
||||
**Diagnose**:
|
||||
```python
|
||||
# Check model loading
|
||||
llm = Llama(model_path="model.gguf", verbose=True)
|
||||
|
||||
# Test with simple prompt
|
||||
output = llm("1+1=", max_tokens=5, temperature=0)
|
||||
print(output)
|
||||
```
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Check model integrity**:
|
||||
```bash
|
||||
# Verify GGUF file
|
||||
./llama-cli -m model.gguf --verbose 2>&1 | head -50
|
||||
```
|
||||
|
||||
2. **Use correct chat format**:
|
||||
```python
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
chat_format="llama-3" # Match your model: chatml, mistral, etc.
|
||||
)
|
||||
```
|
||||
|
||||
3. **Check temperature**:
|
||||
```python
|
||||
# Use lower temperature for deterministic output
|
||||
output = llm("Hello", max_tokens=50, temperature=0.1)
|
||||
```
|
||||
|
||||
### Token Issues
|
||||
|
||||
**Error**: `RuntimeError: unknown token` or encoding errors
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Ensure UTF-8 encoding
|
||||
prompt = "Hello, world!".encode('utf-8').decode('utf-8')
|
||||
output = llm(prompt, max_tokens=50)
|
||||
```
|
||||
|
||||
## Server Issues
|
||||
|
||||
### Connection Refused
|
||||
|
||||
**Error**: `Connection refused` when accessing server
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Bind to all interfaces
|
||||
./llama-server -m model.gguf --host 0.0.0.0 --port 8080
|
||||
|
||||
# Check if port is in use
|
||||
lsof -i :8080
|
||||
```
|
||||
|
||||
### Server Crashes Under Load
|
||||
|
||||
**Problem**: Server crashes with multiple concurrent requests
|
||||
|
||||
**Solutions**:
|
||||
|
||||
1. **Limit parallelism**:
|
||||
```bash
|
||||
./llama-server -m model.gguf \
|
||||
--parallel 2 \
|
||||
-c 4096 \
|
||||
--cont-batching
|
||||
```
|
||||
|
||||
2. **Add request timeout**:
|
||||
```bash
|
||||
./llama-server -m model.gguf --timeout 300
|
||||
```
|
||||
|
||||
3. **Monitor memory**:
|
||||
```bash
|
||||
watch -n 1 nvidia-smi # For GPU
|
||||
watch -n 1 free -h # For RAM
|
||||
```
|
||||
|
||||
### API Compatibility Issues
|
||||
|
||||
**Problem**: OpenAI client not working with server
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
# Use correct base URL format
|
||||
client = OpenAI(
|
||||
base_url="http://localhost:8080/v1", # Include /v1
|
||||
api_key="not-needed"
|
||||
)
|
||||
|
||||
# Use correct model name
|
||||
response = client.chat.completions.create(
|
||||
model="local", # Or the actual model name
|
||||
messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
```
|
||||
|
||||
## Apple Silicon Issues
|
||||
|
||||
### Metal Not Working
|
||||
|
||||
**Problem**: Metal acceleration not enabled
|
||||
|
||||
**Check**:
|
||||
```bash
|
||||
# Verify Metal support
|
||||
./llama-cli -m model.gguf --verbose 2>&1 | grep -i metal
|
||||
```
|
||||
|
||||
**Fix**:
|
||||
```bash
|
||||
# Rebuild with Metal
|
||||
make clean
|
||||
make GGML_METAL=1
|
||||
|
||||
# Python bindings
|
||||
CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall
|
||||
```
|
||||
|
||||
### Incorrect Memory Usage on M1/M2
|
||||
|
||||
**Problem**: Model uses too much unified memory
|
||||
|
||||
**Fix**:
|
||||
```python
|
||||
# Offload all layers for Metal
|
||||
llm = Llama(
|
||||
model_path="model.gguf",
|
||||
n_gpu_layers=99, # Offload everything
|
||||
n_threads=1 # Metal handles parallelism
|
||||
)
|
||||
```
|
||||
|
||||
## Debugging
|
||||
|
||||
### Enable Verbose Output
|
||||
|
||||
```bash
|
||||
# CLI verbose mode
|
||||
./llama-cli -m model.gguf --verbose -p "Hello" -n 50
|
||||
|
||||
# Python verbose
|
||||
llm = Llama(model_path="model.gguf", verbose=True)
|
||||
```
|
||||
|
||||
### Check Model Metadata
|
||||
|
||||
```bash
|
||||
# View GGUF metadata
|
||||
./llama-cli -m model.gguf --verbose 2>&1 | head -100
|
||||
```
|
||||
|
||||
### Validate GGUF File
|
||||
|
||||
```python
|
||||
import struct
|
||||
|
||||
def validate_gguf(filepath):
|
||||
with open(filepath, 'rb') as f:
|
||||
magic = f.read(4)
|
||||
if magic != b'GGUF':
|
||||
print(f"Invalid magic: {magic}")
|
||||
return False
|
||||
|
||||
version = struct.unpack('<I', f.read(4))[0]
|
||||
print(f"GGUF version: {version}")
|
||||
|
||||
tensor_count = struct.unpack('<Q', f.read(8))[0]
|
||||
metadata_count = struct.unpack('<Q', f.read(8))[0]
|
||||
print(f"Tensors: {tensor_count}, Metadata: {metadata_count}")
|
||||
|
||||
return True
|
||||
|
||||
validate_gguf("model.gguf")
|
||||
```
|
||||
|
||||
## Getting Help
|
||||
|
||||
1. **GitHub Issues**: https://github.com/ggml-org/llama.cpp/issues
|
||||
2. **Discussions**: https://github.com/ggml-org/llama.cpp/discussions
|
||||
3. **Reddit**: r/LocalLLaMA
|
||||
|
||||
### Reporting Issues
|
||||
|
||||
Include:
|
||||
- llama.cpp version/commit hash
|
||||
- Build command used
|
||||
- Model name and quantization
|
||||
- Full error message/stack trace
|
||||
- Hardware: CPU/GPU model, RAM, VRAM
|
||||
- OS version
|
||||
- Minimal reproduction steps
|
||||
97
skills/mlops/grpo-rl-training/README.md
Normal file
97
skills/mlops/grpo-rl-training/README.md
Normal file
@@ -0,0 +1,97 @@
|
||||
# GRPO/RL Training Skill
|
||||
|
||||
**Expert-level guidance for Group Relative Policy Optimization with TRL**
|
||||
|
||||
## 📁 Skill Structure
|
||||
|
||||
```
|
||||
grpo-rl-training/
|
||||
├── SKILL.md # Main skill documentation (READ THIS FIRST)
|
||||
├── README.md # This file
|
||||
├── templates/
|
||||
│ └── basic_grpo_training.py # Production-ready training template
|
||||
└── examples/
|
||||
└── reward_functions_library.py # 20+ reward function examples
|
||||
```
|
||||
|
||||
## 🚀 Quick Start
|
||||
|
||||
1. **Read SKILL.md** - Comprehensive guide with all concepts and patterns
|
||||
2. **Copy `templates/basic_grpo_training.py`** - Start with working code
|
||||
3. **Browse `examples/reward_functions_library.py`** - Pick reward functions for your task
|
||||
4. **Modify for your use case** - Adapt dataset, rewards, and config
|
||||
|
||||
## 💡 What's Inside
|
||||
|
||||
### SKILL.md (Main Documentation)
|
||||
- Core GRPO concepts and algorithm fundamentals
|
||||
- Complete implementation workflow (dataset → rewards → training → deployment)
|
||||
- 10+ reward function examples with code
|
||||
- Hyperparameter tuning guide
|
||||
- Training insights (loss behavior, metrics, debugging)
|
||||
- Troubleshooting guide
|
||||
- Production best practices
|
||||
|
||||
### Templates
|
||||
- **basic_grpo_training.py**: Minimal, production-ready training script
|
||||
- Uses Qwen 2.5 1.5B Instruct
|
||||
- 3 reward functions (format + correctness)
|
||||
- LoRA for efficient training
|
||||
- Fully documented and ready to run
|
||||
|
||||
### Examples
|
||||
- **reward_functions_library.py**: 20+ battle-tested reward functions
|
||||
- Correctness rewards (exact match, fuzzy match, numeric, code execution)
|
||||
- Format rewards (XML, JSON, strict/soft)
|
||||
- Length rewards (ideal length, min/max)
|
||||
- Style rewards (reasoning quality, citations, repetition penalty)
|
||||
- Combined rewards (multi-objective optimization)
|
||||
- Preset collections for common tasks
|
||||
|
||||
## 📖 Usage for Agents
|
||||
|
||||
When this skill is loaded in your agent's context:
|
||||
|
||||
1. **Always read SKILL.md first** before implementing
|
||||
2. **Start simple** - Use length-based reward to validate setup
|
||||
3. **Build incrementally** - Add one reward function at a time
|
||||
4. **Reference examples** - Copy patterns from reward_functions_library.py
|
||||
5. **Monitor training** - Watch reward metrics (not loss!)
|
||||
|
||||
## 🎯 Common Use Cases
|
||||
|
||||
| Task Type | Recommended Rewards | Template |
|
||||
|-----------|---------------------|----------|
|
||||
| Math reasoning | `MATH_REASONING_REWARDS` preset | basic_grpo_training.py |
|
||||
| Code generation | `CODE_GENERATION_REWARDS` preset | Modify dataset in template |
|
||||
| Summarization | `SUMMARIZATION_REWARDS` preset | Adjust prompts + rewards |
|
||||
| Q&A | `QA_REWARDS` preset | Use fuzzy match + citations |
|
||||
|
||||
## ⚠️ Critical Reminders
|
||||
|
||||
- **Loss goes UP during training** - This is normal (it's KL divergence)
|
||||
- **Use 3-5 reward functions** - Single rewards often fail
|
||||
- **Test rewards before training** - Debug each function independently
|
||||
- **Monitor reward_std** - Should stay > 0.1 (avoid mode collapse)
|
||||
- **Start with num_generations=4-8** - Scale up if GPU allows
|
||||
|
||||
## 🔗 External Resources
|
||||
|
||||
- [TRL Documentation](https://huggingface.co/docs/trl)
|
||||
- [DeepSeek R1 Paper](https://arxiv.org/abs/2501.12948)
|
||||
- [Open R1 Implementation](https://github.com/huggingface/open-r1)
|
||||
- [Unsloth (2-3x faster)](https://docs.unsloth.ai/)
|
||||
|
||||
## 📝 Version
|
||||
|
||||
**v1.0.0** - Initial release (January 2025)
|
||||
|
||||
## 👨💻 Maintained By
|
||||
|
||||
Orchestra Research
|
||||
For questions or improvements, see https://orchestra.com
|
||||
|
||||
---
|
||||
|
||||
**License:** MIT
|
||||
**Last Updated:** January 2025
|
||||
575
skills/mlops/grpo-rl-training/SKILL.md
Normal file
575
skills/mlops/grpo-rl-training/SKILL.md
Normal file
@@ -0,0 +1,575 @@
|
||||
---
|
||||
name: grpo-rl-training
|
||||
description: Expert guidance for GRPO/RL fine-tuning with TRL for reasoning and task-specific model training
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [transformers>=4.47.0, trl>=0.14.0, datasets>=3.2.0, peft>=0.14.0, torch]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Post-Training, Reinforcement Learning, GRPO, TRL, RLHF, Reward Modeling, Reasoning, DPO, PPO, Structured Output]
|
||||
|
||||
---
|
||||
|
||||
# GRPO/RL Training with TRL
|
||||
|
||||
Expert-level guidance for implementing Group Relative Policy Optimization (GRPO) using the Transformer Reinforcement Learning (TRL) library. This skill provides battle-tested patterns, critical insights, and production-ready workflows for fine-tuning language models with custom reward functions.
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use GRPO training when you need to:
|
||||
- **Enforce specific output formats** (e.g., XML tags, JSON, structured reasoning)
|
||||
- **Teach verifiable tasks** with objective correctness metrics (math, coding, fact-checking)
|
||||
- **Improve reasoning capabilities** by rewarding chain-of-thought patterns
|
||||
- **Align models to domain-specific behaviors** without labeled preference data
|
||||
- **Optimize for multiple objectives** simultaneously (format + correctness + style)
|
||||
|
||||
**Do NOT use GRPO for:**
|
||||
- Simple supervised fine-tuning tasks (use SFT instead)
|
||||
- Tasks without clear reward signals
|
||||
- When you already have high-quality preference pairs (use DPO/PPO instead)
|
||||
|
||||
---
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. GRPO Algorithm Fundamentals
|
||||
|
||||
**Key Mechanism:**
|
||||
- Generates **multiple completions** for each prompt (group size: 4-16)
|
||||
- Compares completions within each group using reward functions
|
||||
- Updates policy to favor higher-rewarded responses relative to the group
|
||||
|
||||
**Critical Difference from PPO:**
|
||||
- No separate reward model needed
|
||||
- More sample-efficient (learns from within-group comparisons)
|
||||
- Simpler to implement and debug
|
||||
|
||||
**Mathematical Intuition:**
|
||||
```
|
||||
For each prompt p:
|
||||
1. Generate N completions: {c₁, c₂, ..., cₙ}
|
||||
2. Compute rewards: {r₁, r₂, ..., rₙ}
|
||||
3. Learn to increase probability of high-reward completions
|
||||
relative to low-reward ones in the same group
|
||||
```
|
||||
|
||||
### 2. Reward Function Design Philosophy
|
||||
|
||||
**Golden Rules:**
|
||||
1. **Compose multiple reward functions** - Each handles one aspect (format, correctness, style)
|
||||
2. **Scale rewards appropriately** - Higher weight = stronger signal
|
||||
3. **Use incremental rewards** - Partial credit for partial compliance
|
||||
4. **Test rewards independently** - Debug each reward function in isolation
|
||||
|
||||
**Reward Function Types:**
|
||||
|
||||
| Type | Use Case | Example Weight |
|
||||
|------|----------|----------------|
|
||||
| **Correctness** | Verifiable tasks (math, code) | 2.0 (highest) |
|
||||
| **Format** | Strict structure enforcement | 0.5-1.0 |
|
||||
| **Length** | Encourage verbosity/conciseness | 0.1-0.5 |
|
||||
| **Style** | Penalize unwanted patterns | -0.5 to 0.5 |
|
||||
|
||||
---
|
||||
|
||||
## Implementation Workflow
|
||||
|
||||
### Step 1: Dataset Preparation
|
||||
|
||||
**Critical Requirements:**
|
||||
- Prompts in chat format (list of dicts with 'role' and 'content')
|
||||
- Include system prompts to set expectations
|
||||
- For verifiable tasks, include ground truth answers as additional columns
|
||||
|
||||
**Example Structure:**
|
||||
```python
|
||||
from datasets import load_dataset, Dataset
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
Respond in the following format:
|
||||
<reasoning>
|
||||
[Your step-by-step thinking]
|
||||
</reasoning>
|
||||
<answer>
|
||||
[Final answer]
|
||||
</answer>
|
||||
"""
|
||||
|
||||
def prepare_dataset(raw_data):
|
||||
"""
|
||||
Transform raw data into GRPO-compatible format.
|
||||
|
||||
Returns: Dataset with columns:
|
||||
- 'prompt': List[Dict] with role/content (system + user messages)
|
||||
- 'answer': str (ground truth, optional but recommended)
|
||||
"""
|
||||
return raw_data.map(lambda x: {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': extract_answer(x['raw_answer'])
|
||||
})
|
||||
```
|
||||
|
||||
**Pro Tips:**
|
||||
- Use one-shot or few-shot examples in system prompt for complex formats
|
||||
- Keep prompts concise (max_prompt_length: 256-512 tokens)
|
||||
- Validate data quality before training (garbage in = garbage out)
|
||||
|
||||
### Step 2: Reward Function Implementation
|
||||
|
||||
**Template Structure:**
|
||||
```python
|
||||
def reward_function_name(
|
||||
prompts, # List[List[Dict]]: Original prompts
|
||||
completions, # List[List[Dict]]: Model generations
|
||||
answer=None, # Optional: Ground truth from dataset
|
||||
**kwargs # Additional dataset columns
|
||||
) -> list[float]:
|
||||
"""
|
||||
Evaluate completions and return rewards.
|
||||
|
||||
Returns: List of floats (one per completion)
|
||||
"""
|
||||
# Extract completion text
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
|
||||
# Compute rewards
|
||||
rewards = []
|
||||
for response in responses:
|
||||
score = compute_score(response)
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
```
|
||||
|
||||
**Example 1: Correctness Reward (Math/Coding)**
|
||||
```python
|
||||
def correctness_reward(prompts, completions, answer, **kwargs):
|
||||
"""Reward correct answers with high score."""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
extracted = [extract_final_answer(r) for r in responses]
|
||||
return [2.0 if ans == gt else 0.0
|
||||
for ans, gt in zip(extracted, answer)]
|
||||
```
|
||||
|
||||
**Example 2: Format Reward (Structured Output)**
|
||||
```python
|
||||
import re
|
||||
|
||||
def format_reward(completions, **kwargs):
|
||||
"""Reward XML-like structured format."""
|
||||
pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
return [1.0 if re.search(pattern, r, re.DOTALL) else 0.0
|
||||
for r in responses]
|
||||
```
|
||||
|
||||
**Example 3: Incremental Format Reward (Partial Credit)**
|
||||
```python
|
||||
def incremental_format_reward(completions, **kwargs):
|
||||
"""Award partial credit for format compliance."""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
score = 0.0
|
||||
if '<reasoning>' in r:
|
||||
score += 0.25
|
||||
if '</reasoning>' in r:
|
||||
score += 0.25
|
||||
if '<answer>' in r:
|
||||
score += 0.25
|
||||
if '</answer>' in r:
|
||||
score += 0.25
|
||||
# Penalize extra text after closing tag
|
||||
if r.count('</answer>') == 1:
|
||||
extra_text = r.split('</answer>')[-1].strip()
|
||||
score -= len(extra_text) * 0.001
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
```
|
||||
|
||||
**Critical Insight:**
|
||||
Combine 3-5 reward functions for robust training. Order matters less than diversity of signals.
|
||||
|
||||
### Step 3: Training Configuration
|
||||
|
||||
**Memory-Optimized Config (Small GPU)**
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(
|
||||
output_dir="outputs/grpo-model",
|
||||
|
||||
# Learning rate
|
||||
learning_rate=5e-6, # Lower = more stable
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.99,
|
||||
weight_decay=0.1,
|
||||
warmup_ratio=0.1,
|
||||
lr_scheduler_type='cosine',
|
||||
|
||||
# Batch settings
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4, # Effective batch = 4
|
||||
|
||||
# GRPO-specific
|
||||
num_generations=8, # Group size: 8-16 recommended
|
||||
max_prompt_length=256,
|
||||
max_completion_length=512,
|
||||
|
||||
# Training duration
|
||||
num_train_epochs=1,
|
||||
max_steps=None, # Or set fixed steps (e.g., 500)
|
||||
|
||||
# Optimization
|
||||
bf16=True, # Faster on A100/H100
|
||||
optim="adamw_8bit", # Memory-efficient optimizer
|
||||
max_grad_norm=0.1,
|
||||
|
||||
# Logging
|
||||
logging_steps=1,
|
||||
save_steps=100,
|
||||
report_to="wandb", # Or "none" for no logging
|
||||
)
|
||||
```
|
||||
|
||||
**High-Performance Config (Large GPU)**
|
||||
```python
|
||||
training_args = GRPOConfig(
|
||||
output_dir="outputs/grpo-model",
|
||||
learning_rate=1e-5,
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=2,
|
||||
num_generations=16, # Larger groups = better signal
|
||||
max_prompt_length=512,
|
||||
max_completion_length=1024,
|
||||
num_train_epochs=1,
|
||||
bf16=True,
|
||||
use_vllm=True, # Fast generation with vLLM
|
||||
logging_steps=10,
|
||||
)
|
||||
```
|
||||
|
||||
**Critical Hyperparameters:**
|
||||
|
||||
| Parameter | Impact | Tuning Advice |
|
||||
|-----------|--------|---------------|
|
||||
| `num_generations` | Group size for comparison | Start with 8, increase to 16 if GPU allows |
|
||||
| `learning_rate` | Convergence speed/stability | 5e-6 (safe), 1e-5 (faster, riskier) |
|
||||
| `max_completion_length` | Output verbosity | Match your task (512 for reasoning, 256 for short answers) |
|
||||
| `gradient_accumulation_steps` | Effective batch size | Increase if GPU memory limited |
|
||||
|
||||
### Step 4: Model Setup and Training
|
||||
|
||||
**Standard Setup (Transformers)**
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import LoraConfig
|
||||
from trl import GRPOTrainer
|
||||
|
||||
# Load model
|
||||
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2", # 2-3x faster
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Optional: LoRA for parameter-efficient training
|
||||
peft_config = LoraConfig(
|
||||
r=16, # Rank (higher = more capacity)
|
||||
lora_alpha=32, # Scaling factor (typically 2*r)
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"
|
||||
],
|
||||
task_type="CAUSAL_LM",
|
||||
lora_dropout=0.05,
|
||||
)
|
||||
|
||||
# Initialize trainer
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
reward_funcs=[
|
||||
incremental_format_reward,
|
||||
format_reward,
|
||||
correctness_reward,
|
||||
],
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
peft_config=peft_config, # Remove for full fine-tuning
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer.train()
|
||||
|
||||
# Save
|
||||
trainer.save_model("final_model")
|
||||
```
|
||||
|
||||
**Unsloth Setup (2-3x Faster)**
|
||||
```python
|
||||
from unsloth import FastLanguageModel
|
||||
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="google/gemma-3-1b-it",
|
||||
max_seq_length=1024,
|
||||
load_in_4bit=True,
|
||||
fast_inference=True,
|
||||
max_lora_rank=32,
|
||||
)
|
||||
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=32,
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"],
|
||||
lora_alpha=32,
|
||||
use_gradient_checkpointing="unsloth",
|
||||
)
|
||||
|
||||
# Rest is identical to standard setup
|
||||
trainer = GRPOTrainer(model=model, ...)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Critical Training Insights
|
||||
|
||||
### 1. Loss Behavior (EXPECTED PATTERN)
|
||||
- **Loss starts near 0 and INCREASES during training**
|
||||
- This is CORRECT - loss measures KL divergence from initial policy
|
||||
- Model is learning (diverging from original behavior to optimize rewards)
|
||||
- Monitor reward metrics instead of loss for progress
|
||||
|
||||
### 2. Reward Tracking
|
||||
Key metrics to watch:
|
||||
- `reward`: Average across all completions
|
||||
- `reward_std`: Diversity within groups (should remain > 0)
|
||||
- `kl`: KL divergence from reference (should grow moderately)
|
||||
|
||||
**Healthy Training Pattern:**
|
||||
```
|
||||
Step Reward Reward_Std KL
|
||||
100 0.5 0.3 0.02
|
||||
200 0.8 0.25 0.05
|
||||
300 1.2 0.2 0.08 ← Good progression
|
||||
400 1.5 0.15 0.12
|
||||
```
|
||||
|
||||
**Warning Signs:**
|
||||
- Reward std → 0 (model collapsing to single response)
|
||||
- KL exploding (> 0.5) (diverging too much, reduce LR)
|
||||
- Reward stuck (reward functions too harsh or model capacity issue)
|
||||
|
||||
### 3. Common Pitfalls and Solutions
|
||||
|
||||
| Problem | Symptom | Solution |
|
||||
|---------|---------|----------|
|
||||
| **Mode collapse** | All completions identical | Increase `num_generations`, add diversity penalty |
|
||||
| **No learning** | Flat rewards | Check reward function logic, increase LR |
|
||||
| **OOM errors** | GPU memory exceeded | Reduce `num_generations`, enable gradient checkpointing |
|
||||
| **Slow training** | < 1 it/s | Enable `use_vllm=True`, use Unsloth, reduce seq length |
|
||||
| **Format ignored** | Model doesn't follow structure | Increase format reward weight, add incremental rewards |
|
||||
|
||||
---
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### 1. Multi-Stage Training
|
||||
For complex tasks, train in stages:
|
||||
|
||||
```python
|
||||
# Stage 1: Format compliance (epochs=1)
|
||||
trainer_stage1 = GRPOTrainer(
|
||||
model=model,
|
||||
reward_funcs=[incremental_format_reward, format_reward],
|
||||
...
|
||||
)
|
||||
trainer_stage1.train()
|
||||
|
||||
# Stage 2: Correctness (epochs=1)
|
||||
trainer_stage2 = GRPOTrainer(
|
||||
model=model,
|
||||
reward_funcs=[format_reward, correctness_reward],
|
||||
...
|
||||
)
|
||||
trainer_stage2.train()
|
||||
```
|
||||
|
||||
### 2. Adaptive Reward Scaling
|
||||
```python
|
||||
class AdaptiveReward:
|
||||
def __init__(self, base_reward_func, initial_weight=1.0):
|
||||
self.func = base_reward_func
|
||||
self.weight = initial_weight
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
rewards = self.func(*args, **kwargs)
|
||||
return [r * self.weight for r in rewards]
|
||||
|
||||
def adjust_weight(self, success_rate):
|
||||
"""Increase weight if model struggling, decrease if succeeding."""
|
||||
if success_rate < 0.3:
|
||||
self.weight *= 1.2
|
||||
elif success_rate > 0.8:
|
||||
self.weight *= 0.9
|
||||
```
|
||||
|
||||
### 3. Custom Dataset Integration
|
||||
```python
|
||||
def load_custom_knowledge_base(csv_path):
|
||||
"""Example: School communication platform docs."""
|
||||
import pandas as pd
|
||||
df = pd.read_csv(csv_path)
|
||||
|
||||
dataset = Dataset.from_pandas(df).map(lambda x: {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': CUSTOM_SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': x['expert_answer']
|
||||
})
|
||||
return dataset
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Deployment and Inference
|
||||
|
||||
### Save and Merge LoRA
|
||||
```python
|
||||
# Merge LoRA adapters into base model
|
||||
if hasattr(trainer.model, 'merge_and_unload'):
|
||||
merged_model = trainer.model.merge_and_unload()
|
||||
merged_model.save_pretrained("production_model")
|
||||
tokenizer.save_pretrained("production_model")
|
||||
```
|
||||
|
||||
### Inference Example
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
generator = pipeline(
|
||||
"text-generation",
|
||||
model="production_model",
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
|
||||
result = generator(
|
||||
[
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': "What is 15 + 27?"}
|
||||
],
|
||||
max_new_tokens=256,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
top_p=0.9
|
||||
)
|
||||
print(result[0]['generated_text'])
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Best Practices Checklist
|
||||
|
||||
**Before Training:**
|
||||
- [ ] Validate dataset format (prompts as List[Dict])
|
||||
- [ ] Test reward functions on sample data
|
||||
- [ ] Calculate expected max_prompt_length from data
|
||||
- [ ] Choose appropriate num_generations based on GPU memory
|
||||
- [ ] Set up logging (wandb recommended)
|
||||
|
||||
**During Training:**
|
||||
- [ ] Monitor reward progression (should increase)
|
||||
- [ ] Check reward_std (should stay > 0.1)
|
||||
- [ ] Watch for OOM errors (reduce batch size if needed)
|
||||
- [ ] Sample generations every 50-100 steps
|
||||
- [ ] Validate format compliance on holdout set
|
||||
|
||||
**After Training:**
|
||||
- [ ] Merge LoRA weights if using PEFT
|
||||
- [ ] Test on diverse prompts
|
||||
- [ ] Compare to baseline model
|
||||
- [ ] Document reward weights and hyperparameters
|
||||
- [ ] Save reproducibility config
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting Guide
|
||||
|
||||
### Debugging Workflow
|
||||
1. **Isolate reward functions** - Test each independently
|
||||
2. **Check data distribution** - Ensure diversity in prompts
|
||||
3. **Reduce complexity** - Start with single reward, add gradually
|
||||
4. **Monitor generations** - Print samples every N steps
|
||||
5. **Validate extraction logic** - Ensure answer parsing works
|
||||
|
||||
### Quick Fixes
|
||||
```python
|
||||
# Debug reward function
|
||||
def debug_reward(completions, **kwargs):
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
for i, r in enumerate(responses[:2]): # Print first 2
|
||||
print(f"Response {i}: {r[:200]}...")
|
||||
return [1.0] * len(responses) # Dummy rewards
|
||||
|
||||
# Test without training
|
||||
trainer = GRPOTrainer(..., reward_funcs=[debug_reward])
|
||||
trainer.generate_completions(dataset[:1]) # Generate without updating
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## References and Resources
|
||||
|
||||
**Official Documentation:**
|
||||
- TRL GRPO Trainer: https://huggingface.co/docs/trl/grpo_trainer
|
||||
- DeepSeek R1 Paper: https://arxiv.org/abs/2501.12948
|
||||
- Unsloth Docs: https://docs.unsloth.ai/
|
||||
|
||||
**Example Repositories:**
|
||||
- Open R1 Implementation: https://github.com/huggingface/open-r1
|
||||
- TRL Examples: https://github.com/huggingface/trl/tree/main/examples
|
||||
|
||||
**Recommended Reading:**
|
||||
- Progressive Disclosure Pattern for agent instructions
|
||||
- Reward shaping in RL (Ng et al.)
|
||||
- LoRA paper (Hu et al., 2021)
|
||||
|
||||
---
|
||||
|
||||
## Usage Instructions for Agents
|
||||
|
||||
When this skill is loaded:
|
||||
|
||||
1. **Read this entire file** before implementing GRPO training
|
||||
2. **Start with the simplest reward function** (e.g., length-based) to validate setup
|
||||
3. **Use the templates** in `templates/` directory as starting points
|
||||
4. **Reference examples** in `examples/` for task-specific implementations
|
||||
5. **Follow the workflow** sequentially (don't skip steps)
|
||||
6. **Debug incrementally** - add one reward function at a time
|
||||
|
||||
**Critical Reminders:**
|
||||
- Always use multiple reward functions (3-5 is optimal)
|
||||
- Monitor reward metrics, not loss
|
||||
- Test reward functions before training
|
||||
- Start small (num_generations=4), scale up gradually
|
||||
- Save checkpoints frequently (every 100 steps)
|
||||
|
||||
This skill is designed for **expert-level implementation**. Beginners should start with supervised fine-tuning before attempting GRPO.
|
||||
|
||||
|
||||
|
||||
228
skills/mlops/grpo-rl-training/templates/basic_grpo_training.py
Normal file
228
skills/mlops/grpo-rl-training/templates/basic_grpo_training.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Basic GRPO Training Template
|
||||
=============================
|
||||
|
||||
A minimal, production-ready template for GRPO training with TRL.
|
||||
Adapt this for your specific task by modifying:
|
||||
1. Dataset loading (get_dataset function)
|
||||
2. Reward functions (reward_*_func)
|
||||
3. System prompt (SYSTEM_PROMPT)
|
||||
4. Hyperparameters (GRPOConfig)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import re
|
||||
from datasets import load_dataset, Dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from peft import LoraConfig
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
|
||||
# ==================== CONFIGURATION ====================
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
OUTPUT_DIR = "outputs/grpo-model"
|
||||
MAX_PROMPT_LENGTH = 256
|
||||
MAX_COMPLETION_LENGTH = 512
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
Respond in the following format:
|
||||
<reasoning>
|
||||
[Your step-by-step thinking]
|
||||
</reasoning>
|
||||
<answer>
|
||||
[Final answer]
|
||||
</answer>
|
||||
"""
|
||||
|
||||
# ==================== DATASET ====================
|
||||
|
||||
def get_dataset(split="train"):
|
||||
"""
|
||||
Load and prepare your dataset.
|
||||
|
||||
Returns: Dataset with columns:
|
||||
- 'prompt': List[Dict] with role/content
|
||||
- 'answer': str (ground truth, optional)
|
||||
"""
|
||||
# Example: GSM8K math dataset
|
||||
data = load_dataset('openai/gsm8k', 'main')[split]
|
||||
|
||||
def process_example(x):
|
||||
# Extract ground truth answer
|
||||
answer = x['answer'].split('####')[1].strip() if '####' in x['answer'] else None
|
||||
|
||||
return {
|
||||
'prompt': [
|
||||
{'role': 'system', 'content': SYSTEM_PROMPT},
|
||||
{'role': 'user', 'content': x['question']}
|
||||
],
|
||||
'answer': answer
|
||||
}
|
||||
|
||||
return data.map(process_example)
|
||||
|
||||
# ==================== HELPER FUNCTIONS ====================
|
||||
|
||||
def extract_xml_tag(text: str, tag: str) -> str:
|
||||
"""Extract content between XML tags."""
|
||||
pattern = f'<{tag}>(.*?)</{tag}>'
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
return match.group(1).strip() if match else ""
|
||||
|
||||
def extract_answer(text: str) -> str:
|
||||
"""Extract the final answer from structured output."""
|
||||
return extract_xml_tag(text, 'answer')
|
||||
|
||||
# ==================== REWARD FUNCTIONS ====================
|
||||
|
||||
def correctness_reward_func(prompts, completions, answer, **kwargs):
|
||||
"""
|
||||
Reward correct answers.
|
||||
Weight: 2.0 (highest priority)
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
extracted = [extract_answer(r) for r in responses]
|
||||
return [2.0 if ans == gt else 0.0 for ans, gt in zip(extracted, answer)]
|
||||
|
||||
def format_reward_func(completions, **kwargs):
|
||||
"""
|
||||
Reward proper XML format.
|
||||
Weight: 0.5
|
||||
"""
|
||||
pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
return [0.5 if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses]
|
||||
|
||||
def incremental_format_reward_func(completions, **kwargs):
|
||||
"""
|
||||
Incremental reward for partial format compliance.
|
||||
Weight: up to 0.5
|
||||
"""
|
||||
responses = [comp[0]['content'] for comp in completions]
|
||||
rewards = []
|
||||
|
||||
for r in responses:
|
||||
score = 0.0
|
||||
if '<reasoning>' in r:
|
||||
score += 0.125
|
||||
if '</reasoning>' in r:
|
||||
score += 0.125
|
||||
if '<answer>' in r:
|
||||
score += 0.125
|
||||
if '</answer>' in r:
|
||||
score += 0.125
|
||||
|
||||
# Penalize extra content after closing tag
|
||||
if '</answer>' in r:
|
||||
extra = r.split('</answer>')[-1].strip()
|
||||
score -= len(extra) * 0.001
|
||||
|
||||
rewards.append(score)
|
||||
|
||||
return rewards
|
||||
|
||||
# ==================== MODEL SETUP ====================
|
||||
|
||||
def setup_model_and_tokenizer():
|
||||
"""Load model and tokenizer with optimizations."""
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
MODEL_NAME,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
device_map="auto"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
def get_peft_config():
|
||||
"""LoRA configuration for parameter-efficient training."""
|
||||
return LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
target_modules=[
|
||||
"q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj"
|
||||
],
|
||||
task_type="CAUSAL_LM",
|
||||
lora_dropout=0.05,
|
||||
)
|
||||
|
||||
# ==================== TRAINING ====================
|
||||
|
||||
def main():
|
||||
"""Main training function."""
|
||||
|
||||
# Load data
|
||||
print("Loading dataset...")
|
||||
dataset = get_dataset()
|
||||
print(f"Dataset size: {len(dataset)}")
|
||||
|
||||
# Setup model
|
||||
print("Loading model...")
|
||||
model, tokenizer = setup_model_and_tokenizer()
|
||||
|
||||
# Training configuration
|
||||
training_args = GRPOConfig(
|
||||
output_dir=OUTPUT_DIR,
|
||||
run_name="grpo-training",
|
||||
|
||||
# Learning rate
|
||||
learning_rate=5e-6,
|
||||
adam_beta1=0.9,
|
||||
adam_beta2=0.99,
|
||||
weight_decay=0.1,
|
||||
warmup_ratio=0.1,
|
||||
lr_scheduler_type='cosine',
|
||||
|
||||
# Batch settings
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4,
|
||||
|
||||
# GRPO specific
|
||||
num_generations=8,
|
||||
max_prompt_length=MAX_PROMPT_LENGTH,
|
||||
max_completion_length=MAX_COMPLETION_LENGTH,
|
||||
|
||||
# Training duration
|
||||
num_train_epochs=1,
|
||||
|
||||
# Optimization
|
||||
bf16=True,
|
||||
optim="adamw_8bit",
|
||||
max_grad_norm=0.1,
|
||||
|
||||
# Logging
|
||||
logging_steps=1,
|
||||
save_steps=100,
|
||||
report_to="wandb", # Change to "none" to disable logging
|
||||
)
|
||||
|
||||
# Initialize trainer
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
processing_class=tokenizer,
|
||||
reward_funcs=[
|
||||
incremental_format_reward_func,
|
||||
format_reward_func,
|
||||
correctness_reward_func,
|
||||
],
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
peft_config=get_peft_config(),
|
||||
)
|
||||
|
||||
# Train
|
||||
print("Starting training...")
|
||||
trainer.train()
|
||||
|
||||
# Save final model
|
||||
print(f"Saving model to {OUTPUT_DIR}/final")
|
||||
trainer.save_model(f"{OUTPUT_DIR}/final")
|
||||
|
||||
print("Training complete!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
575
skills/mlops/guidance/SKILL.md
Normal file
575
skills/mlops/guidance/SKILL.md
Normal file
@@ -0,0 +1,575 @@
|
||||
---
|
||||
name: guidance
|
||||
description: Control LLM output with regex and grammars, guarantee valid JSON/XML/code generation, enforce structured formats, and build multi-step workflows with Guidance - Microsoft Research's constrained generation framework
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [guidance, transformers]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [Prompt Engineering, Guidance, Constrained Generation, Structured Output, JSON Validation, Grammar, Microsoft Research, Format Enforcement, Multi-Step Workflows]
|
||||
|
||||
---
|
||||
|
||||
# Guidance: Constrained LLM Generation
|
||||
|
||||
## When to Use This Skill
|
||||
|
||||
Use Guidance when you need to:
|
||||
- **Control LLM output syntax** with regex or grammars
|
||||
- **Guarantee valid JSON/XML/code** generation
|
||||
- **Reduce latency** vs traditional prompting approaches
|
||||
- **Enforce structured formats** (dates, emails, IDs, etc.)
|
||||
- **Build multi-step workflows** with Pythonic control flow
|
||||
- **Prevent invalid outputs** through grammatical constraints
|
||||
|
||||
**GitHub Stars**: 18,000+ | **From**: Microsoft Research
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Base installation
|
||||
pip install guidance
|
||||
|
||||
# With specific backends
|
||||
pip install guidance[transformers] # Hugging Face models
|
||||
pip install guidance[llama_cpp] # llama.cpp models
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Example: Structured Generation
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
# Load model (supports OpenAI, Transformers, llama.cpp)
|
||||
lm = models.OpenAI("gpt-4")
|
||||
|
||||
# Generate with constraints
|
||||
result = lm + "The capital of France is " + gen("capital", max_tokens=5)
|
||||
|
||||
print(result["capital"]) # "Paris"
|
||||
```
|
||||
|
||||
### With Anthropic Claude
|
||||
|
||||
```python
|
||||
from guidance import models, gen, system, user, assistant
|
||||
|
||||
# Configure Claude
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Use context managers for chat format
|
||||
with system():
|
||||
lm += "You are a helpful assistant."
|
||||
|
||||
with user():
|
||||
lm += "What is the capital of France?"
|
||||
|
||||
with assistant():
|
||||
lm += gen(max_tokens=20)
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. Context Managers
|
||||
|
||||
Guidance uses Pythonic context managers for chat-style interactions.
|
||||
|
||||
```python
|
||||
from guidance import system, user, assistant, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# System message
|
||||
with system():
|
||||
lm += "You are a JSON generation expert."
|
||||
|
||||
# User message
|
||||
with user():
|
||||
lm += "Generate a person object with name and age."
|
||||
|
||||
# Assistant response
|
||||
with assistant():
|
||||
lm += gen("response", max_tokens=100)
|
||||
|
||||
print(lm["response"])
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Natural chat flow
|
||||
- Clear role separation
|
||||
- Easy to read and maintain
|
||||
|
||||
### 2. Constrained Generation
|
||||
|
||||
Guidance ensures outputs match specified patterns using regex or grammars.
|
||||
|
||||
#### Regex Constraints
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Constrain to valid email format
|
||||
lm += "Email: " + gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
|
||||
|
||||
# Constrain to date format (YYYY-MM-DD)
|
||||
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}")
|
||||
|
||||
# Constrain to phone number
|
||||
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}")
|
||||
|
||||
print(lm["email"]) # Guaranteed valid email
|
||||
print(lm["date"]) # Guaranteed YYYY-MM-DD format
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
- Regex converted to grammar at token level
|
||||
- Invalid tokens filtered during generation
|
||||
- Model can only produce matching outputs
|
||||
|
||||
#### Selection Constraints
|
||||
|
||||
```python
|
||||
from guidance import models, gen, select
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Constrain to specific choices
|
||||
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
|
||||
|
||||
# Multiple-choice selection
|
||||
lm += "Best answer: " + select(
|
||||
["A) Paris", "B) London", "C) Berlin", "D) Madrid"],
|
||||
name="answer"
|
||||
)
|
||||
|
||||
print(lm["sentiment"]) # One of: positive, negative, neutral
|
||||
print(lm["answer"]) # One of: A, B, C, or D
|
||||
```
|
||||
|
||||
### 3. Token Healing
|
||||
|
||||
Guidance automatically "heals" token boundaries between prompt and generation.
|
||||
|
||||
**Problem:** Tokenization creates unnatural boundaries.
|
||||
|
||||
```python
|
||||
# Without token healing
|
||||
prompt = "The capital of France is "
|
||||
# Last token: " is "
|
||||
# First generated token might be " Par" (with leading space)
|
||||
# Result: "The capital of France is Paris" (double space!)
|
||||
```
|
||||
|
||||
**Solution:** Guidance backs up one token and regenerates.
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Token healing enabled by default
|
||||
lm += "The capital of France is " + gen("capital", max_tokens=5)
|
||||
# Result: "The capital of France is Paris" (correct spacing)
|
||||
```
|
||||
|
||||
**Benefits:**
|
||||
- Natural text boundaries
|
||||
- No awkward spacing issues
|
||||
- Better model performance (sees natural token sequences)
|
||||
|
||||
### 4. Grammar-Based Generation
|
||||
|
||||
Define complex structures using context-free grammars.
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# JSON grammar (simplified)
|
||||
json_grammar = """
|
||||
{
|
||||
"name": <gen name regex="[A-Za-z ]+" max_tokens=20>,
|
||||
"age": <gen age regex="[0-9]+" max_tokens=3>,
|
||||
"email": <gen email regex="[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}" max_tokens=50>
|
||||
}
|
||||
"""
|
||||
|
||||
# Generate valid JSON
|
||||
lm += gen("person", grammar=json_grammar)
|
||||
|
||||
print(lm["person"]) # Guaranteed valid JSON structure
|
||||
```
|
||||
|
||||
**Use cases:**
|
||||
- Complex structured outputs
|
||||
- Nested data structures
|
||||
- Programming language syntax
|
||||
- Domain-specific languages
|
||||
|
||||
### 5. Guidance Functions
|
||||
|
||||
Create reusable generation patterns with the `@guidance` decorator.
|
||||
|
||||
```python
|
||||
from guidance import guidance, gen, models
|
||||
|
||||
@guidance
|
||||
def generate_person(lm):
|
||||
"""Generate a person with name and age."""
|
||||
lm += "Name: " + gen("name", max_tokens=20, stop="\n")
|
||||
lm += "\nAge: " + gen("age", regex=r"[0-9]+", max_tokens=3)
|
||||
return lm
|
||||
|
||||
# Use the function
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_person(lm)
|
||||
|
||||
print(lm["name"])
|
||||
print(lm["age"])
|
||||
```
|
||||
|
||||
**Stateful Functions:**
|
||||
|
||||
```python
|
||||
@guidance(stateless=False)
|
||||
def react_agent(lm, question, tools, max_rounds=5):
|
||||
"""ReAct agent with tool use."""
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
for i in range(max_rounds):
|
||||
# Thought
|
||||
lm += f"Thought {i+1}: " + gen("thought", stop="\n")
|
||||
|
||||
# Action
|
||||
lm += "\nAction: " + select(list(tools.keys()), name="action")
|
||||
|
||||
# Execute tool
|
||||
tool_result = tools[lm["action"]]()
|
||||
lm += f"\nObservation: {tool_result}\n\n"
|
||||
|
||||
# Check if done
|
||||
lm += "Done? " + select(["Yes", "No"], name="done")
|
||||
if lm["done"] == "Yes":
|
||||
break
|
||||
|
||||
# Final answer
|
||||
lm += "\nFinal Answer: " + gen("answer", max_tokens=100)
|
||||
return lm
|
||||
```
|
||||
|
||||
## Backend Configuration
|
||||
|
||||
### Anthropic Claude
|
||||
|
||||
```python
|
||||
from guidance import models
|
||||
|
||||
lm = models.Anthropic(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
api_key="your-api-key" # Or set ANTHROPIC_API_KEY env var
|
||||
)
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
```python
|
||||
lm = models.OpenAI(
|
||||
model="gpt-4o-mini",
|
||||
api_key="your-api-key" # Or set OPENAI_API_KEY env var
|
||||
)
|
||||
```
|
||||
|
||||
### Local Models (Transformers)
|
||||
|
||||
```python
|
||||
from guidance.models import Transformers
|
||||
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda" # Or "cpu"
|
||||
)
|
||||
```
|
||||
|
||||
### Local Models (llama.cpp)
|
||||
|
||||
```python
|
||||
from guidance.models import LlamaCpp
|
||||
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35
|
||||
)
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Pattern 1: JSON Generation
|
||||
|
||||
```python
|
||||
from guidance import models, gen, system, user, assistant
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
with system():
|
||||
lm += "You generate valid JSON."
|
||||
|
||||
with user():
|
||||
lm += "Generate a user profile with name, age, and email."
|
||||
|
||||
with assistant():
|
||||
lm += """{
|
||||
"name": """ + gen("name", regex=r'"[A-Za-z ]+"', max_tokens=30) + """,
|
||||
"age": """ + gen("age", regex=r"[0-9]+", max_tokens=3) + """,
|
||||
"email": """ + gen("email", regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"', max_tokens=50) + """
|
||||
}"""
|
||||
|
||||
print(lm) # Valid JSON guaranteed
|
||||
```
|
||||
|
||||
### Pattern 2: Classification
|
||||
|
||||
```python
|
||||
from guidance import models, gen, select
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
text = "This product is amazing! I love it."
|
||||
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
|
||||
lm += "\nConfidence: " + gen("confidence", regex=r"[0-9]+", max_tokens=3) + "%"
|
||||
|
||||
print(f"Sentiment: {lm['sentiment']}")
|
||||
print(f"Confidence: {lm['confidence']}%")
|
||||
```
|
||||
|
||||
### Pattern 3: Multi-Step Reasoning
|
||||
|
||||
```python
|
||||
from guidance import models, gen, guidance
|
||||
|
||||
@guidance
|
||||
def chain_of_thought(lm, question):
|
||||
"""Generate answer with step-by-step reasoning."""
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
# Generate multiple reasoning steps
|
||||
for i in range(3):
|
||||
lm += f"Step {i+1}: " + gen(f"step_{i+1}", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Final answer
|
||||
lm += "\nTherefore, the answer is: " + gen("answer", max_tokens=50)
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = chain_of_thought(lm, "What is 15% of 200?")
|
||||
|
||||
print(lm["answer"])
|
||||
```
|
||||
|
||||
### Pattern 4: ReAct Agent
|
||||
|
||||
```python
|
||||
from guidance import models, gen, select, guidance
|
||||
|
||||
@guidance(stateless=False)
|
||||
def react_agent(lm, question):
|
||||
"""ReAct agent with tool use."""
|
||||
tools = {
|
||||
"calculator": lambda expr: eval(expr),
|
||||
"search": lambda query: f"Search results for: {query}",
|
||||
}
|
||||
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
for round in range(5):
|
||||
# Thought
|
||||
lm += f"Thought: " + gen("thought", stop="\n") + "\n"
|
||||
|
||||
# Action selection
|
||||
lm += "Action: " + select(["calculator", "search", "answer"], name="action")
|
||||
|
||||
if lm["action"] == "answer":
|
||||
lm += "\nFinal Answer: " + gen("answer", max_tokens=100)
|
||||
break
|
||||
|
||||
# Action input
|
||||
lm += "\nAction Input: " + gen("action_input", stop="\n") + "\n"
|
||||
|
||||
# Execute tool
|
||||
if lm["action"] in tools:
|
||||
result = tools[lm["action"]](lm["action_input"])
|
||||
lm += f"Observation: {result}\n\n"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = react_agent(lm, "What is 25 * 4 + 10?")
|
||||
print(lm["answer"])
|
||||
```
|
||||
|
||||
### Pattern 5: Data Extraction
|
||||
|
||||
```python
|
||||
from guidance import models, gen, guidance
|
||||
|
||||
@guidance
|
||||
def extract_entities(lm, text):
|
||||
"""Extract structured entities from text."""
|
||||
lm += f"Text: {text}\n\n"
|
||||
|
||||
# Extract person
|
||||
lm += "Person: " + gen("person", stop="\n", max_tokens=30) + "\n"
|
||||
|
||||
# Extract organization
|
||||
lm += "Organization: " + gen("organization", stop="\n", max_tokens=30) + "\n"
|
||||
|
||||
# Extract date
|
||||
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}", max_tokens=10) + "\n"
|
||||
|
||||
# Extract location
|
||||
lm += "Location: " + gen("location", stop="\n", max_tokens=30) + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
text = "Tim Cook announced at Apple Park on 2024-09-15 in Cupertino."
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = extract_entities(lm, text)
|
||||
|
||||
print(f"Person: {lm['person']}")
|
||||
print(f"Organization: {lm['organization']}")
|
||||
print(f"Date: {lm['date']}")
|
||||
print(f"Location: {lm['location']}")
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Use Regex for Format Validation
|
||||
|
||||
```python
|
||||
# ✅ Good: Regex ensures valid format
|
||||
lm += "Email: " + gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
|
||||
|
||||
# ❌ Bad: Free generation may produce invalid emails
|
||||
lm += "Email: " + gen("email", max_tokens=50)
|
||||
```
|
||||
|
||||
### 2. Use select() for Fixed Categories
|
||||
|
||||
```python
|
||||
# ✅ Good: Guaranteed valid category
|
||||
lm += "Status: " + select(["pending", "approved", "rejected"], name="status")
|
||||
|
||||
# ❌ Bad: May generate typos or invalid values
|
||||
lm += "Status: " + gen("status", max_tokens=20)
|
||||
```
|
||||
|
||||
### 3. Leverage Token Healing
|
||||
|
||||
```python
|
||||
# Token healing is enabled by default
|
||||
# No special action needed - just concatenate naturally
|
||||
lm += "The capital is " + gen("capital") # Automatic healing
|
||||
```
|
||||
|
||||
### 4. Use stop Sequences
|
||||
|
||||
```python
|
||||
# ✅ Good: Stop at newline for single-line outputs
|
||||
lm += "Name: " + gen("name", stop="\n")
|
||||
|
||||
# ❌ Bad: May generate multiple lines
|
||||
lm += "Name: " + gen("name", max_tokens=50)
|
||||
```
|
||||
|
||||
### 5. Create Reusable Functions
|
||||
|
||||
```python
|
||||
# ✅ Good: Reusable pattern
|
||||
@guidance
|
||||
def generate_person(lm):
|
||||
lm += "Name: " + gen("name", stop="\n")
|
||||
lm += "\nAge: " + gen("age", regex=r"[0-9]+")
|
||||
return lm
|
||||
|
||||
# Use multiple times
|
||||
lm = generate_person(lm)
|
||||
lm += "\n\n"
|
||||
lm = generate_person(lm)
|
||||
```
|
||||
|
||||
### 6. Balance Constraints
|
||||
|
||||
```python
|
||||
# ✅ Good: Reasonable constraints
|
||||
lm += gen("name", regex=r"[A-Za-z ]+", max_tokens=30)
|
||||
|
||||
# ❌ Too strict: May fail or be very slow
|
||||
lm += gen("name", regex=r"^(John|Jane)$", max_tokens=10)
|
||||
```
|
||||
|
||||
## Comparison to Alternatives
|
||||
|
||||
| Feature | Guidance | Instructor | Outlines | LMQL |
|
||||
|---------|----------|------------|----------|------|
|
||||
| Regex Constraints | ✅ Yes | ❌ No | ✅ Yes | ✅ Yes |
|
||||
| Grammar Support | ✅ CFG | ❌ No | ✅ CFG | ✅ CFG |
|
||||
| Pydantic Validation | ❌ No | ✅ Yes | ✅ Yes | ❌ No |
|
||||
| Token Healing | ✅ Yes | ❌ No | ✅ Yes | ❌ No |
|
||||
| Local Models | ✅ Yes | ⚠️ Limited | ✅ Yes | ✅ Yes |
|
||||
| API Models | ✅ Yes | ✅ Yes | ⚠️ Limited | ✅ Yes |
|
||||
| Pythonic Syntax | ✅ Yes | ✅ Yes | ✅ Yes | ❌ SQL-like |
|
||||
| Learning Curve | Low | Low | Medium | High |
|
||||
|
||||
**When to choose Guidance:**
|
||||
- Need regex/grammar constraints
|
||||
- Want token healing
|
||||
- Building complex workflows with control flow
|
||||
- Using local models (Transformers, llama.cpp)
|
||||
- Prefer Pythonic syntax
|
||||
|
||||
**When to choose alternatives:**
|
||||
- Instructor: Need Pydantic validation with automatic retrying
|
||||
- Outlines: Need JSON schema validation
|
||||
- LMQL: Prefer declarative query syntax
|
||||
|
||||
## Performance Characteristics
|
||||
|
||||
**Latency Reduction:**
|
||||
- 30-50% faster than traditional prompting for constrained outputs
|
||||
- Token healing reduces unnecessary regeneration
|
||||
- Grammar constraints prevent invalid token generation
|
||||
|
||||
**Memory Usage:**
|
||||
- Minimal overhead vs unconstrained generation
|
||||
- Grammar compilation cached after first use
|
||||
- Efficient token filtering at inference time
|
||||
|
||||
**Token Efficiency:**
|
||||
- Prevents wasted tokens on invalid outputs
|
||||
- No need for retry loops
|
||||
- Direct path to valid outputs
|
||||
|
||||
## Resources
|
||||
|
||||
- **Documentation**: https://guidance.readthedocs.io
|
||||
- **GitHub**: https://github.com/guidance-ai/guidance (18k+ stars)
|
||||
- **Notebooks**: https://github.com/guidance-ai/guidance/tree/main/notebooks
|
||||
- **Discord**: Community support available
|
||||
|
||||
## See Also
|
||||
|
||||
- `references/constraints.md` - Comprehensive regex and grammar patterns
|
||||
- `references/backends.md` - Backend-specific configuration
|
||||
- `references/examples.md` - Production-ready examples
|
||||
|
||||
|
||||
554
skills/mlops/guidance/references/backends.md
Normal file
554
skills/mlops/guidance/references/backends.md
Normal file
@@ -0,0 +1,554 @@
|
||||
# Backend Configuration Guide
|
||||
|
||||
Complete guide to configuring Guidance with different LLM backends.
|
||||
|
||||
## Table of Contents
|
||||
- API-Based Models (Anthropic, OpenAI)
|
||||
- Local Models (Transformers, llama.cpp)
|
||||
- Backend Comparison
|
||||
- Performance Tuning
|
||||
- Advanced Configuration
|
||||
|
||||
## API-Based Models
|
||||
|
||||
### Anthropic Claude
|
||||
|
||||
#### Basic Setup
|
||||
|
||||
```python
|
||||
from guidance import models
|
||||
|
||||
# Using environment variable
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
# Reads ANTHROPIC_API_KEY from environment
|
||||
|
||||
# Explicit API key
|
||||
lm = models.Anthropic(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
api_key="your-api-key-here"
|
||||
)
|
||||
```
|
||||
|
||||
#### Available Models
|
||||
|
||||
```python
|
||||
# Claude 3.5 Sonnet (Latest, recommended)
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Claude 3.7 Sonnet (Fast, cost-effective)
|
||||
lm = models.Anthropic("claude-sonnet-3.7-20250219")
|
||||
|
||||
# Claude 3 Opus (Most capable)
|
||||
lm = models.Anthropic("claude-3-opus-20240229")
|
||||
|
||||
# Claude 3.5 Haiku (Fastest, cheapest)
|
||||
lm = models.Anthropic("claude-3-5-haiku-20241022")
|
||||
```
|
||||
|
||||
#### Configuration Options
|
||||
|
||||
```python
|
||||
lm = models.Anthropic(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
api_key="your-api-key",
|
||||
max_tokens=4096, # Max tokens to generate
|
||||
temperature=0.7, # Sampling temperature (0-1)
|
||||
top_p=0.9, # Nucleus sampling
|
||||
timeout=30, # Request timeout (seconds)
|
||||
max_retries=3 # Retry failed requests
|
||||
)
|
||||
```
|
||||
|
||||
#### With Context Managers
|
||||
|
||||
```python
|
||||
from guidance import models, system, user, assistant, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
with system():
|
||||
lm += "You are a helpful assistant."
|
||||
|
||||
with user():
|
||||
lm += "What is the capital of France?"
|
||||
|
||||
with assistant():
|
||||
lm += gen(max_tokens=50)
|
||||
|
||||
print(lm)
|
||||
```
|
||||
|
||||
### OpenAI
|
||||
|
||||
#### Basic Setup
|
||||
|
||||
```python
|
||||
from guidance import models
|
||||
|
||||
# Using environment variable
|
||||
lm = models.OpenAI("gpt-4o")
|
||||
# Reads OPENAI_API_KEY from environment
|
||||
|
||||
# Explicit API key
|
||||
lm = models.OpenAI(
|
||||
model="gpt-4o",
|
||||
api_key="your-api-key-here"
|
||||
)
|
||||
```
|
||||
|
||||
#### Available Models
|
||||
|
||||
```python
|
||||
# GPT-4o (Latest, multimodal)
|
||||
lm = models.OpenAI("gpt-4o")
|
||||
|
||||
# GPT-4o Mini (Fast, cost-effective)
|
||||
lm = models.OpenAI("gpt-4o-mini")
|
||||
|
||||
# GPT-4 Turbo
|
||||
lm = models.OpenAI("gpt-4-turbo")
|
||||
|
||||
# GPT-3.5 Turbo (Cheapest)
|
||||
lm = models.OpenAI("gpt-3.5-turbo")
|
||||
```
|
||||
|
||||
#### Configuration Options
|
||||
|
||||
```python
|
||||
lm = models.OpenAI(
|
||||
model="gpt-4o-mini",
|
||||
api_key="your-api-key",
|
||||
max_tokens=2048,
|
||||
temperature=0.7,
|
||||
top_p=1.0,
|
||||
frequency_penalty=0.0,
|
||||
presence_penalty=0.0,
|
||||
timeout=30
|
||||
)
|
||||
```
|
||||
|
||||
#### Chat Format
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.OpenAI("gpt-4o-mini")
|
||||
|
||||
# OpenAI uses chat format
|
||||
lm += [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "What is 2+2?"}
|
||||
]
|
||||
|
||||
# Generate response
|
||||
lm += gen(max_tokens=50)
|
||||
```
|
||||
|
||||
### Azure OpenAI
|
||||
|
||||
```python
|
||||
from guidance import models
|
||||
|
||||
lm = models.AzureOpenAI(
|
||||
model="gpt-4o",
|
||||
azure_endpoint="https://your-resource.openai.azure.com/",
|
||||
api_key="your-azure-api-key",
|
||||
api_version="2024-02-15-preview",
|
||||
deployment_name="your-deployment-name"
|
||||
)
|
||||
```
|
||||
|
||||
## Local Models
|
||||
|
||||
### Transformers (Hugging Face)
|
||||
|
||||
#### Basic Setup
|
||||
|
||||
```python
|
||||
from guidance.models import Transformers
|
||||
|
||||
# Load model from Hugging Face
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct")
|
||||
```
|
||||
|
||||
#### GPU Configuration
|
||||
|
||||
```python
|
||||
# Use GPU
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Use specific GPU
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda:0" # GPU 0
|
||||
)
|
||||
|
||||
# Use CPU
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cpu"
|
||||
)
|
||||
```
|
||||
|
||||
#### Advanced Configuration
|
||||
|
||||
```python
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda",
|
||||
torch_dtype="float16", # Use FP16 (faster, less memory)
|
||||
load_in_8bit=True, # 8-bit quantization
|
||||
max_memory={0: "20GB"}, # GPU memory limit
|
||||
offload_folder="./offload" # Offload to disk if needed
|
||||
)
|
||||
```
|
||||
|
||||
#### Popular Models
|
||||
|
||||
```python
|
||||
# Phi-4 (Microsoft)
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct")
|
||||
lm = Transformers("microsoft/Phi-3-medium-4k-instruct")
|
||||
|
||||
# Llama 3 (Meta)
|
||||
lm = Transformers("meta-llama/Llama-3.1-8B-Instruct")
|
||||
lm = Transformers("meta-llama/Llama-3.1-70B-Instruct")
|
||||
|
||||
# Mistral (Mistral AI)
|
||||
lm = Transformers("mistralai/Mistral-7B-Instruct-v0.3")
|
||||
lm = Transformers("mistralai/Mixtral-8x7B-Instruct-v0.1")
|
||||
|
||||
# Qwen (Alibaba)
|
||||
lm = Transformers("Qwen/Qwen2.5-7B-Instruct")
|
||||
|
||||
# Gemma (Google)
|
||||
lm = Transformers("google/gemma-2-9b-it")
|
||||
```
|
||||
|
||||
#### Generation Configuration
|
||||
|
||||
```python
|
||||
lm = Transformers(
|
||||
"microsoft/Phi-4-mini-instruct",
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Configure generation
|
||||
from guidance import gen
|
||||
|
||||
result = lm + gen(
|
||||
max_tokens=100,
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
top_k=50,
|
||||
repetition_penalty=1.1
|
||||
)
|
||||
```
|
||||
|
||||
### llama.cpp
|
||||
|
||||
#### Basic Setup
|
||||
|
||||
```python
|
||||
from guidance.models import LlamaCpp
|
||||
|
||||
# Load GGUF model
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.gguf",
|
||||
n_ctx=4096 # Context window
|
||||
)
|
||||
```
|
||||
|
||||
#### GPU Configuration
|
||||
|
||||
```python
|
||||
# Use GPU acceleration
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=35, # Offload 35 layers to GPU
|
||||
n_threads=8 # CPU threads for remaining layers
|
||||
)
|
||||
|
||||
# Full GPU offload
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.gguf",
|
||||
n_ctx=4096,
|
||||
n_gpu_layers=-1 # Offload all layers
|
||||
)
|
||||
```
|
||||
|
||||
#### Advanced Configuration
|
||||
|
||||
```python
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/llama-3.1-8b-instruct.Q4_K_M.gguf",
|
||||
n_ctx=8192, # Context window (tokens)
|
||||
n_gpu_layers=35, # GPU layers
|
||||
n_threads=8, # CPU threads
|
||||
n_batch=512, # Batch size for prompt processing
|
||||
use_mmap=True, # Memory-map the model file
|
||||
use_mlock=False, # Lock model in RAM
|
||||
seed=42, # Random seed
|
||||
verbose=False # Suppress verbose output
|
||||
)
|
||||
```
|
||||
|
||||
#### Quantized Models
|
||||
|
||||
```python
|
||||
# Q4_K_M (4-bit, recommended for most cases)
|
||||
lm = LlamaCpp("/path/to/model.Q4_K_M.gguf")
|
||||
|
||||
# Q5_K_M (5-bit, better quality)
|
||||
lm = LlamaCpp("/path/to/model.Q5_K_M.gguf")
|
||||
|
||||
# Q8_0 (8-bit, high quality)
|
||||
lm = LlamaCpp("/path/to/model.Q8_0.gguf")
|
||||
|
||||
# F16 (16-bit float, highest quality)
|
||||
lm = LlamaCpp("/path/to/model.F16.gguf")
|
||||
```
|
||||
|
||||
#### Popular GGUF Models
|
||||
|
||||
```python
|
||||
# Llama 3.1
|
||||
lm = LlamaCpp("llama-3.1-8b-instruct.Q4_K_M.gguf")
|
||||
|
||||
# Mistral
|
||||
lm = LlamaCpp("mistral-7b-instruct-v0.3.Q4_K_M.gguf")
|
||||
|
||||
# Phi-4
|
||||
lm = LlamaCpp("phi-4-mini-instruct.Q4_K_M.gguf")
|
||||
```
|
||||
|
||||
## Backend Comparison
|
||||
|
||||
### Feature Matrix
|
||||
|
||||
| Feature | Anthropic | OpenAI | Transformers | llama.cpp |
|
||||
|---------|-----------|--------|--------------|-----------|
|
||||
| Constrained Generation | ✅ Full | ✅ Full | ✅ Full | ✅ Full |
|
||||
| Token Healing | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |
|
||||
| Streaming | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |
|
||||
| GPU Support | N/A | N/A | ✅ Yes | ✅ Yes |
|
||||
| Quantization | N/A | N/A | ✅ Yes | ✅ Yes |
|
||||
| Cost | $$$ | $$$ | Free | Free |
|
||||
| Latency | Low | Low | Medium | Low |
|
||||
| Setup Difficulty | Easy | Easy | Medium | Medium |
|
||||
|
||||
### Performance Characteristics
|
||||
|
||||
**Anthropic Claude:**
|
||||
- **Latency**: 200-500ms (API call)
|
||||
- **Throughput**: Limited by API rate limits
|
||||
- **Cost**: $3-15 per 1M input tokens
|
||||
- **Best for**: Production systems, high-quality outputs
|
||||
|
||||
**OpenAI:**
|
||||
- **Latency**: 200-400ms (API call)
|
||||
- **Throughput**: Limited by API rate limits
|
||||
- **Cost**: $0.15-30 per 1M input tokens
|
||||
- **Best for**: Cost-sensitive production, gpt-4o-mini
|
||||
|
||||
**Transformers:**
|
||||
- **Latency**: 50-200ms (local inference)
|
||||
- **Throughput**: GPU-dependent (10-100 tokens/sec)
|
||||
- **Cost**: Hardware cost only
|
||||
- **Best for**: Privacy-sensitive, high-volume, experimentation
|
||||
|
||||
**llama.cpp:**
|
||||
- **Latency**: 30-150ms (local inference)
|
||||
- **Throughput**: Hardware-dependent (20-150 tokens/sec)
|
||||
- **Cost**: Hardware cost only
|
||||
- **Best for**: Edge deployment, Apple Silicon, CPU inference
|
||||
|
||||
### Memory Requirements
|
||||
|
||||
**Transformers (FP16):**
|
||||
- 7B model: ~14GB GPU VRAM
|
||||
- 13B model: ~26GB GPU VRAM
|
||||
- 70B model: ~140GB GPU VRAM (multi-GPU)
|
||||
|
||||
**llama.cpp (Q4_K_M):**
|
||||
- 7B model: ~4.5GB RAM
|
||||
- 13B model: ~8GB RAM
|
||||
- 70B model: ~40GB RAM
|
||||
|
||||
**Optimization Tips:**
|
||||
- Use quantized models (Q4_K_M) for lower memory
|
||||
- Use GPU offloading for faster inference
|
||||
- Use CPU inference for smaller models (<7B)
|
||||
|
||||
## Performance Tuning
|
||||
|
||||
### API Models (Anthropic, OpenAI)
|
||||
|
||||
#### Reduce Latency
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Use lower max_tokens (faster response)
|
||||
lm += gen(max_tokens=100) # Instead of 1000
|
||||
|
||||
# Use streaming (perceived latency reduction)
|
||||
for chunk in lm.stream(gen(max_tokens=500)):
|
||||
print(chunk, end="", flush=True)
|
||||
```
|
||||
|
||||
#### Reduce Cost
|
||||
|
||||
```python
|
||||
# Use cheaper models
|
||||
lm = models.Anthropic("claude-3-5-haiku-20241022") # vs Sonnet
|
||||
lm = models.OpenAI("gpt-4o-mini") # vs gpt-4o
|
||||
|
||||
# Reduce context size
|
||||
# - Keep prompts concise
|
||||
# - Avoid large few-shot examples
|
||||
# - Use max_tokens limits
|
||||
```
|
||||
|
||||
### Local Models (Transformers, llama.cpp)
|
||||
|
||||
#### Optimize GPU Usage
|
||||
|
||||
```python
|
||||
from guidance.models import Transformers
|
||||
|
||||
# Use FP16 for 2x speedup
|
||||
lm = Transformers(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
device="cuda",
|
||||
torch_dtype="float16"
|
||||
)
|
||||
|
||||
# Use 8-bit quantization for 4x memory reduction
|
||||
lm = Transformers(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
device="cuda",
|
||||
load_in_8bit=True
|
||||
)
|
||||
|
||||
# Use flash attention (requires flash-attn package)
|
||||
lm = Transformers(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
device="cuda",
|
||||
use_flash_attention_2=True
|
||||
)
|
||||
```
|
||||
|
||||
#### Optimize llama.cpp
|
||||
|
||||
```python
|
||||
from guidance.models import LlamaCpp
|
||||
|
||||
# Maximize GPU layers
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.Q4_K_M.gguf",
|
||||
n_gpu_layers=-1 # All layers on GPU
|
||||
)
|
||||
|
||||
# Optimize batch size
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.Q4_K_M.gguf",
|
||||
n_batch=512, # Larger batch = faster prompt processing
|
||||
n_gpu_layers=-1
|
||||
)
|
||||
|
||||
# Use Metal (Apple Silicon)
|
||||
lm = LlamaCpp(
|
||||
model_path="/path/to/model.Q4_K_M.gguf",
|
||||
n_gpu_layers=-1, # Use Metal GPU acceleration
|
||||
use_mmap=True
|
||||
)
|
||||
```
|
||||
|
||||
#### Batch Processing
|
||||
|
||||
```python
|
||||
# Process multiple requests efficiently
|
||||
requests = [
|
||||
"What is 2+2?",
|
||||
"What is the capital of France?",
|
||||
"What is photosynthesis?"
|
||||
]
|
||||
|
||||
# Bad: Sequential processing
|
||||
for req in requests:
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct")
|
||||
lm += req + gen(max_tokens=50)
|
||||
|
||||
# Good: Reuse loaded model
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct")
|
||||
for req in requests:
|
||||
lm += req + gen(max_tokens=50)
|
||||
```
|
||||
|
||||
## Advanced Configuration
|
||||
|
||||
### Custom Model Configurations
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from guidance.models import Transformers
|
||||
|
||||
# Load custom model
|
||||
tokenizer = AutoTokenizer.from_pretrained("your-model")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"your-model",
|
||||
device_map="auto",
|
||||
torch_dtype="float16"
|
||||
)
|
||||
|
||||
# Use with Guidance
|
||||
lm = Transformers(model=model, tokenizer=tokenizer)
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# API keys
|
||||
export ANTHROPIC_API_KEY="sk-ant-..."
|
||||
export OPENAI_API_KEY="sk-..."
|
||||
|
||||
# Transformers cache
|
||||
export HF_HOME="/path/to/cache"
|
||||
export TRANSFORMERS_CACHE="/path/to/cache"
|
||||
|
||||
# GPU selection
|
||||
export CUDA_VISIBLE_DEVICES=0,1 # Use GPU 0 and 1
|
||||
```
|
||||
|
||||
### Debugging
|
||||
|
||||
```python
|
||||
# Enable verbose logging
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
# Check backend info
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
print(f"Model: {lm.model_name}")
|
||||
print(f"Backend: {lm.backend}")
|
||||
|
||||
# Check GPU usage (Transformers)
|
||||
lm = Transformers("microsoft/Phi-4-mini-instruct", device="cuda")
|
||||
print(f"Device: {lm.device}")
|
||||
print(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Anthropic Docs**: https://docs.anthropic.com
|
||||
- **OpenAI Docs**: https://platform.openai.com/docs
|
||||
- **Hugging Face Models**: https://huggingface.co/models
|
||||
- **llama.cpp**: https://github.com/ggerganov/llama.cpp
|
||||
- **GGUF Models**: https://huggingface.co/models?library=gguf
|
||||
674
skills/mlops/guidance/references/constraints.md
Normal file
674
skills/mlops/guidance/references/constraints.md
Normal file
@@ -0,0 +1,674 @@
|
||||
# Comprehensive Constraint Patterns
|
||||
|
||||
Guide to regex constraints, grammar-based generation, and token healing in Guidance.
|
||||
|
||||
## Table of Contents
|
||||
- Regex Constraints
|
||||
- Grammar-Based Generation
|
||||
- Token Healing
|
||||
- Selection Constraints
|
||||
- Complex Patterns
|
||||
- Performance Optimization
|
||||
|
||||
## Regex Constraints
|
||||
|
||||
### Basic Patterns
|
||||
|
||||
#### Numeric Constraints
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Integer (positive)
|
||||
lm += "Age: " + gen("age", regex=r"[0-9]+")
|
||||
|
||||
# Integer (with negatives)
|
||||
lm += "Temperature: " + gen("temp", regex=r"-?[0-9]+")
|
||||
|
||||
# Float (positive)
|
||||
lm += "Price: $" + gen("price", regex=r"[0-9]+\.[0-9]{2}")
|
||||
|
||||
# Float (with negatives and optional decimals)
|
||||
lm += "Value: " + gen("value", regex=r"-?[0-9]+(\.[0-9]+)?")
|
||||
|
||||
# Percentage (0-100)
|
||||
lm += "Progress: " + gen("progress", regex=r"(100|[0-9]{1,2})")
|
||||
|
||||
# Range (1-5 stars)
|
||||
lm += "Rating: " + gen("rating", regex=r"[1-5]") + " stars"
|
||||
```
|
||||
|
||||
#### Text Constraints
|
||||
|
||||
```python
|
||||
# Alphabetic only
|
||||
lm += "Name: " + gen("name", regex=r"[A-Za-z]+")
|
||||
|
||||
# Alphabetic with spaces
|
||||
lm += "Full Name: " + gen("full_name", regex=r"[A-Za-z ]+")
|
||||
|
||||
# Alphanumeric
|
||||
lm += "Username: " + gen("username", regex=r"[A-Za-z0-9_]+")
|
||||
|
||||
# Capitalized words
|
||||
lm += "Title: " + gen("title", regex=r"[A-Z][a-z]+( [A-Z][a-z]+)*")
|
||||
|
||||
# Lowercase only
|
||||
lm += "Code: " + gen("code", regex=r"[a-z0-9-]+")
|
||||
|
||||
# Specific length
|
||||
lm += "ID: " + gen("id", regex=r"[A-Z]{3}-[0-9]{6}") # e.g., "ABC-123456"
|
||||
```
|
||||
|
||||
#### Date and Time Constraints
|
||||
|
||||
```python
|
||||
# Date (YYYY-MM-DD)
|
||||
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}")
|
||||
|
||||
# Date (MM/DD/YYYY)
|
||||
lm += "Date: " + gen("date_us", regex=r"\d{2}/\d{2}/\d{4}")
|
||||
|
||||
# Time (HH:MM)
|
||||
lm += "Time: " + gen("time", regex=r"\d{2}:\d{2}")
|
||||
|
||||
# Time (HH:MM:SS)
|
||||
lm += "Time: " + gen("time_full", regex=r"\d{2}:\d{2}:\d{2}")
|
||||
|
||||
# ISO 8601 datetime
|
||||
lm += "Timestamp: " + gen(
|
||||
"timestamp",
|
||||
regex=r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z"
|
||||
)
|
||||
|
||||
# Year (YYYY)
|
||||
lm += "Year: " + gen("year", regex=r"(19|20)\d{2}")
|
||||
|
||||
# Month name
|
||||
lm += "Month: " + gen(
|
||||
"month",
|
||||
regex=r"(January|February|March|April|May|June|July|August|September|October|November|December)"
|
||||
)
|
||||
```
|
||||
|
||||
#### Contact Information
|
||||
|
||||
```python
|
||||
# Email
|
||||
lm += "Email: " + gen(
|
||||
"email",
|
||||
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
|
||||
)
|
||||
|
||||
# Phone (US format)
|
||||
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}")
|
||||
|
||||
# Phone (international format)
|
||||
lm += "Phone: " + gen("phone_intl", regex=r"\+[0-9]{1,3}-[0-9]{1,14}")
|
||||
|
||||
# ZIP code (US)
|
||||
lm += "ZIP: " + gen("zip", regex=r"\d{5}(-\d{4})?")
|
||||
|
||||
# Postal code (Canada)
|
||||
lm += "Postal: " + gen("postal", regex=r"[A-Z]\d[A-Z] \d[A-Z]\d")
|
||||
|
||||
# URL
|
||||
lm += "URL: " + gen(
|
||||
"url",
|
||||
regex=r"https?://[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}(/[a-zA-Z0-9._~:/?#\[\]@!$&'()*+,;=-]*)?"
|
||||
)
|
||||
```
|
||||
|
||||
### Advanced Patterns
|
||||
|
||||
#### JSON Field Constraints
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# String field with quotes
|
||||
lm += '"name": ' + gen("name", regex=r'"[A-Za-z ]+"')
|
||||
|
||||
# Numeric field (no quotes)
|
||||
lm += '"age": ' + gen("age", regex=r"[0-9]+")
|
||||
|
||||
# Boolean field
|
||||
lm += '"active": ' + gen("active", regex=r"(true|false)")
|
||||
|
||||
# Null field
|
||||
lm += '"optional": ' + gen("optional", regex=r"(null|[0-9]+)")
|
||||
|
||||
# Array of strings
|
||||
lm += '"tags": [' + gen(
|
||||
"tags",
|
||||
regex=r'"[a-z]+"(, "[a-z]+")*'
|
||||
) + ']'
|
||||
|
||||
# Complete JSON object
|
||||
lm += """{
|
||||
"name": """ + gen("name", regex=r'"[A-Za-z ]+"') + """,
|
||||
"age": """ + gen("age", regex=r"[0-9]+") + """,
|
||||
"email": """ + gen(
|
||||
"email",
|
||||
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
|
||||
) + """
|
||||
}"""
|
||||
```
|
||||
|
||||
#### Code Patterns
|
||||
|
||||
```python
|
||||
# Python variable name
|
||||
lm += "Variable: " + gen("var", regex=r"[a-z_][a-z0-9_]*")
|
||||
|
||||
# Python function name
|
||||
lm += "Function: " + gen("func", regex=r"[a-z_][a-z0-9_]*")
|
||||
|
||||
# Hex color code
|
||||
lm += "Color: #" + gen("color", regex=r"[0-9A-Fa-f]{6}")
|
||||
|
||||
# UUID
|
||||
lm += "UUID: " + gen(
|
||||
"uuid",
|
||||
regex=r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
|
||||
)
|
||||
|
||||
# Git commit hash (short)
|
||||
lm += "Commit: " + gen("commit", regex=r"[0-9a-f]{7}")
|
||||
|
||||
# Semantic version
|
||||
lm += "Version: " + gen("version", regex=r"[0-9]+\.[0-9]+\.[0-9]+")
|
||||
|
||||
# IP address (IPv4)
|
||||
lm += "IP: " + gen(
|
||||
"ip",
|
||||
regex=r"((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"
|
||||
)
|
||||
```
|
||||
|
||||
#### Domain-Specific Patterns
|
||||
|
||||
```python
|
||||
# Credit card number
|
||||
lm += "Card: " + gen("card", regex=r"\d{4}-\d{4}-\d{4}-\d{4}")
|
||||
|
||||
# Social Security Number (US)
|
||||
lm += "SSN: " + gen("ssn", regex=r"\d{3}-\d{2}-\d{4}")
|
||||
|
||||
# ISBN-13
|
||||
lm += "ISBN: " + gen("isbn", regex=r"978-\d{1,5}-\d{1,7}-\d{1,7}-\d")
|
||||
|
||||
# License plate (US)
|
||||
lm += "Plate: " + gen("plate", regex=r"[A-Z]{3}-\d{4}")
|
||||
|
||||
# Currency amount
|
||||
lm += "Amount: $" + gen("amount", regex=r"[0-9]{1,3}(,[0-9]{3})*\.[0-9]{2}")
|
||||
|
||||
# Percentage with decimal
|
||||
lm += "Rate: " + gen("rate", regex=r"[0-9]+\.[0-9]{1,2}%")
|
||||
```
|
||||
|
||||
## Grammar-Based Generation
|
||||
|
||||
### JSON Grammar
|
||||
|
||||
```python
|
||||
from guidance import models, gen, guidance
|
||||
|
||||
@guidance
|
||||
def json_object(lm):
|
||||
"""Generate valid JSON object."""
|
||||
lm += "{\n"
|
||||
|
||||
# Name field (required)
|
||||
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
|
||||
# Age field (required)
|
||||
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + ",\n"
|
||||
|
||||
# Email field (required)
|
||||
lm += ' "email": ' + gen(
|
||||
"email",
|
||||
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
|
||||
) + ",\n"
|
||||
|
||||
# Active field (required, boolean)
|
||||
lm += ' "active": ' + gen("active", regex=r"(true|false)") + "\n"
|
||||
|
||||
lm += "}"
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = json_object(lm)
|
||||
print(lm) # Valid JSON guaranteed
|
||||
```
|
||||
|
||||
### Nested JSON Grammar
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def nested_json(lm):
|
||||
"""Generate nested JSON structure."""
|
||||
lm += "{\n"
|
||||
|
||||
# User object
|
||||
lm += ' "user": {\n'
|
||||
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + "\n"
|
||||
lm += " },\n"
|
||||
|
||||
# Address object
|
||||
lm += ' "address": {\n'
|
||||
lm += ' "street": ' + gen("street", regex=r'"[A-Za-z0-9 ]+"') + ",\n"
|
||||
lm += ' "city": ' + gen("city", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "zip": ' + gen("zip", regex=r'"\d{5}"') + "\n"
|
||||
lm += " }\n"
|
||||
|
||||
lm += "}"
|
||||
return lm
|
||||
```
|
||||
|
||||
### Array Grammar
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def json_array(lm, count=3):
|
||||
"""Generate JSON array with fixed count."""
|
||||
lm += "[\n"
|
||||
|
||||
for i in range(count):
|
||||
lm += " {\n"
|
||||
lm += ' "id": ' + gen(f"id_{i}", regex=r"[0-9]+") + ",\n"
|
||||
lm += ' "name": ' + gen(f"name_{i}", regex=r'"[A-Za-z ]+"') + "\n"
|
||||
lm += " }"
|
||||
if i < count - 1:
|
||||
lm += ","
|
||||
lm += "\n"
|
||||
|
||||
lm += "]"
|
||||
return lm
|
||||
```
|
||||
|
||||
### XML Grammar
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def xml_document(lm):
|
||||
"""Generate valid XML document."""
|
||||
lm += '<?xml version="1.0"?>\n'
|
||||
lm += "<person>\n"
|
||||
|
||||
# Name element
|
||||
lm += " <name>" + gen("name", regex=r"[A-Za-z ]+") + "</name>\n"
|
||||
|
||||
# Age element
|
||||
lm += " <age>" + gen("age", regex=r"[0-9]+") + "</age>\n"
|
||||
|
||||
# Email element
|
||||
lm += " <email>" + gen(
|
||||
"email",
|
||||
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
|
||||
) + "</email>\n"
|
||||
|
||||
lm += "</person>"
|
||||
return lm
|
||||
```
|
||||
|
||||
### CSV Grammar
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def csv_row(lm):
|
||||
"""Generate CSV row."""
|
||||
lm += gen("name", regex=r"[A-Za-z ]+") + ","
|
||||
lm += gen("age", regex=r"[0-9]+") + ","
|
||||
lm += gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
|
||||
return lm
|
||||
|
||||
@guidance
|
||||
def csv_document(lm, rows=5):
|
||||
"""Generate complete CSV."""
|
||||
# Header
|
||||
lm += "Name,Age,Email\n"
|
||||
|
||||
# Rows
|
||||
for i in range(rows):
|
||||
lm = csv_row(lm)
|
||||
if i < rows - 1:
|
||||
lm += "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
## Token Healing
|
||||
|
||||
### How Token Healing Works
|
||||
|
||||
**Problem:** Tokenization creates unnatural boundaries.
|
||||
|
||||
```python
|
||||
# Example without token healing
|
||||
prompt = "The capital of France is "
|
||||
# Tokenization: ["The", " capital", " of", " France", " is", " "]
|
||||
# Model sees last token: " "
|
||||
# First generated token might include leading space: " Paris"
|
||||
# Result: "The capital of France is Paris" (double space)
|
||||
```
|
||||
|
||||
**Solution:** Guidance backs up and regenerates the last token.
|
||||
|
||||
```python
|
||||
from guidance import models, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Token healing enabled by default
|
||||
lm += "The capital of France is " + gen("capital", max_tokens=5)
|
||||
|
||||
# Process:
|
||||
# 1. Back up to token before " is "
|
||||
# 2. Regenerate " is" + "capital" together
|
||||
# 3. Result: "The capital of France is Paris" (correct)
|
||||
```
|
||||
|
||||
### Token Healing Examples
|
||||
|
||||
#### Natural Continuations
|
||||
|
||||
```python
|
||||
# Before token healing
|
||||
lm += "The function name is get" + gen("rest")
|
||||
# Might generate: "The function name is get User" (space before User)
|
||||
|
||||
# With token healing
|
||||
lm += "The function name is get" + gen("rest")
|
||||
# Generates: "The function name is getUser" (correct camelCase)
|
||||
```
|
||||
|
||||
#### Code Generation
|
||||
|
||||
```python
|
||||
# Function name completion
|
||||
lm += "def calculate_" + gen("rest", stop="(")
|
||||
# Token healing ensures smooth connection: "calculate_total"
|
||||
|
||||
# Variable name completion
|
||||
lm += "my_" + gen("var_name", regex=r"[a-z_]+")
|
||||
# Token healing ensures: "my_variable_name" (not "my_ variable_name")
|
||||
```
|
||||
|
||||
#### Domain-Specific Terms
|
||||
|
||||
```python
|
||||
# Medical terms
|
||||
lm += "The patient has hyper" + gen("condition")
|
||||
# Token healing helps: "hypertension" (not "hyper tension")
|
||||
|
||||
# Technical terms
|
||||
lm += "Using micro" + gen("tech")
|
||||
# Token healing helps: "microservices" (not "micro services")
|
||||
```
|
||||
|
||||
### Disabling Token Healing
|
||||
|
||||
```python
|
||||
# Disable token healing if needed (rare)
|
||||
lm += gen("text", token_healing=False)
|
||||
```
|
||||
|
||||
## Selection Constraints
|
||||
|
||||
### Basic Selection
|
||||
|
||||
```python
|
||||
from guidance import models, select
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
# Simple selection
|
||||
lm += "Status: " + select(["active", "inactive", "pending"], name="status")
|
||||
|
||||
# Boolean selection
|
||||
lm += "Approved: " + select(["Yes", "No"], name="approved")
|
||||
|
||||
# Multiple choice
|
||||
lm += "Answer: " + select(
|
||||
["A) Paris", "B) London", "C) Berlin", "D) Madrid"],
|
||||
name="answer"
|
||||
)
|
||||
```
|
||||
|
||||
### Conditional Selection
|
||||
|
||||
```python
|
||||
from guidance import models, select, gen, guidance
|
||||
|
||||
@guidance
|
||||
def conditional_fields(lm):
|
||||
"""Generate fields conditionally based on type."""
|
||||
lm += "Type: " + select(["person", "company"], name="type")
|
||||
|
||||
if lm["type"] == "person":
|
||||
lm += "\nName: " + gen("name", regex=r"[A-Za-z ]+")
|
||||
lm += "\nAge: " + gen("age", regex=r"[0-9]+")
|
||||
else:
|
||||
lm += "\nCompany Name: " + gen("company", regex=r"[A-Za-z ]+")
|
||||
lm += "\nEmployees: " + gen("employees", regex=r"[0-9]+")
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Repeated Selection
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def multiple_selections(lm):
|
||||
"""Select multiple items."""
|
||||
lm += "Select 3 colors:\n"
|
||||
|
||||
colors = ["red", "blue", "green", "yellow", "purple"]
|
||||
|
||||
for i in range(3):
|
||||
lm += f"{i+1}. " + select(colors, name=f"color_{i}") + "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
## Complex Patterns
|
||||
|
||||
### Pattern 1: Structured Forms
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def user_form(lm):
|
||||
"""Generate structured user form."""
|
||||
lm += "=== User Registration ===\n\n"
|
||||
|
||||
# Name (alphabetic only)
|
||||
lm += "Full Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
# Age (numeric)
|
||||
lm += "Age: " + gen("age", regex=r"[0-9]+", max_tokens=3) + "\n"
|
||||
|
||||
# Email (validated format)
|
||||
lm += "Email: " + gen(
|
||||
"email",
|
||||
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
|
||||
stop="\n"
|
||||
) + "\n"
|
||||
|
||||
# Phone (US format)
|
||||
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}") + "\n"
|
||||
|
||||
# Account type (selection)
|
||||
lm += "Account Type: " + select(
|
||||
["Standard", "Premium", "Enterprise"],
|
||||
name="account_type"
|
||||
) + "\n"
|
||||
|
||||
# Active status (boolean)
|
||||
lm += "Active: " + select(["Yes", "No"], name="active") + "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Pattern 2: Multi-Entity Extraction
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def extract_entities(lm, text):
|
||||
"""Extract multiple entities with constraints."""
|
||||
lm += f"Text: {text}\n\n"
|
||||
|
||||
# Person name (alphabetic)
|
||||
lm += "Person: " + gen("person", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
# Organization (alphanumeric with spaces)
|
||||
lm += "Organization: " + gen(
|
||||
"organization",
|
||||
regex=r"[A-Za-z0-9 ]+",
|
||||
stop="\n"
|
||||
) + "\n"
|
||||
|
||||
# Date (YYYY-MM-DD format)
|
||||
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}") + "\n"
|
||||
|
||||
# Location (alphabetic with spaces)
|
||||
lm += "Location: " + gen("location", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
# Amount (currency)
|
||||
lm += "Amount: $" + gen("amount", regex=r"[0-9,]+\.[0-9]{2}") + "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Pattern 3: Code Generation
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_python_function(lm):
|
||||
"""Generate Python function with constraints."""
|
||||
# Function name (valid Python identifier)
|
||||
lm += "def " + gen("func_name", regex=r"[a-z_][a-z0-9_]*") + "("
|
||||
|
||||
# Parameter name
|
||||
lm += gen("param", regex=r"[a-z_][a-z0-9_]*") + "):\n"
|
||||
|
||||
# Docstring
|
||||
lm += ' """' + gen("docstring", stop='"""', max_tokens=50) + '"""\n'
|
||||
|
||||
# Function body (constrained to valid Python)
|
||||
lm += " return " + gen("return_value", stop="\n") + "\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Pattern 4: Hierarchical Data
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def org_chart(lm):
|
||||
"""Generate organizational chart."""
|
||||
lm += "Company: " + gen("company", regex=r"[A-Za-z ]+") + "\n\n"
|
||||
|
||||
# CEO
|
||||
lm += "CEO: " + gen("ceo", regex=r"[A-Za-z ]+") + "\n"
|
||||
|
||||
# Departments
|
||||
for dept in ["Engineering", "Sales", "Marketing"]:
|
||||
lm += f"\n{dept} Department:\n"
|
||||
lm += " Head: " + gen(f"{dept.lower()}_head", regex=r"[A-Za-z ]+") + "\n"
|
||||
lm += " Size: " + gen(f"{dept.lower()}_size", regex=r"[0-9]+") + " employees\n"
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Best Practices
|
||||
|
||||
#### 1. Use Specific Patterns
|
||||
|
||||
```python
|
||||
# ✅ Good: Specific pattern
|
||||
lm += gen("age", regex=r"[0-9]{1,3}") # Fast
|
||||
|
||||
# ❌ Bad: Overly broad pattern
|
||||
lm += gen("age", regex=r"[0-9]+") # Slower
|
||||
```
|
||||
|
||||
#### 2. Limit Max Tokens
|
||||
|
||||
```python
|
||||
# ✅ Good: Reasonable limit
|
||||
lm += gen("name", max_tokens=30)
|
||||
|
||||
# ❌ Bad: No limit
|
||||
lm += gen("name") # May generate forever
|
||||
```
|
||||
|
||||
#### 3. Use stop Sequences
|
||||
|
||||
```python
|
||||
# ✅ Good: Stop at newline
|
||||
lm += gen("line", stop="\n")
|
||||
|
||||
# ❌ Bad: Rely on max_tokens
|
||||
lm += gen("line", max_tokens=100)
|
||||
```
|
||||
|
||||
#### 4. Cache Compiled Grammars
|
||||
|
||||
```python
|
||||
# Grammars are cached automatically after first use
|
||||
# No manual caching needed
|
||||
@guidance
|
||||
def reusable_pattern(lm):
|
||||
"""This grammar is compiled once and cached."""
|
||||
lm += gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
|
||||
return lm
|
||||
|
||||
# First call: compiles grammar
|
||||
lm = reusable_pattern(lm)
|
||||
|
||||
# Subsequent calls: uses cached grammar (fast)
|
||||
lm = reusable_pattern(lm)
|
||||
```
|
||||
|
||||
#### 5. Avoid Overlapping Constraints
|
||||
|
||||
```python
|
||||
# ✅ Good: Clear constraints
|
||||
lm += gen("age", regex=r"[0-9]+", max_tokens=3)
|
||||
|
||||
# ❌ Bad: Conflicting constraints
|
||||
lm += gen("age", regex=r"[0-9]{2}", max_tokens=10) # max_tokens unnecessary
|
||||
```
|
||||
|
||||
### Performance Benchmarks
|
||||
|
||||
**Regex vs Free Generation:**
|
||||
- Simple regex (digits): ~1.2x slower than free gen
|
||||
- Complex regex (email): ~1.5x slower than free gen
|
||||
- Grammar-based: ~2x slower than free gen
|
||||
|
||||
**But:**
|
||||
- 100% valid outputs (vs ~70% with free gen + validation)
|
||||
- No retry loops needed
|
||||
- Overall faster end-to-end for structured outputs
|
||||
|
||||
**Optimization Tips:**
|
||||
- Use regex for critical fields only
|
||||
- Use `select()` for small fixed sets (fastest)
|
||||
- Use `stop` sequences when possible (faster than max_tokens)
|
||||
- Cache compiled grammars by reusing functions
|
||||
|
||||
## Resources
|
||||
|
||||
- **Token Healing Paper**: https://arxiv.org/abs/2306.17648
|
||||
- **Guidance Docs**: https://guidance.readthedocs.io
|
||||
- **GitHub**: https://github.com/guidance-ai/guidance
|
||||
767
skills/mlops/guidance/references/examples.md
Normal file
767
skills/mlops/guidance/references/examples.md
Normal file
@@ -0,0 +1,767 @@
|
||||
# Production-Ready Examples
|
||||
|
||||
Real-world examples of using Guidance for structured generation, agents, and workflows.
|
||||
|
||||
## Table of Contents
|
||||
- JSON Generation
|
||||
- Data Extraction
|
||||
- Classification Systems
|
||||
- Agent Systems
|
||||
- Multi-Step Workflows
|
||||
- Code Generation
|
||||
- Production Tips
|
||||
|
||||
## JSON Generation
|
||||
|
||||
### Basic JSON
|
||||
|
||||
```python
|
||||
from guidance import models, gen, guidance
|
||||
|
||||
@guidance
|
||||
def generate_user(lm):
|
||||
"""Generate valid user JSON."""
|
||||
lm += "{\n"
|
||||
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + ",\n"
|
||||
lm += ' "email": ' + gen(
|
||||
"email",
|
||||
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
|
||||
) + "\n"
|
||||
lm += "}"
|
||||
return lm
|
||||
|
||||
# Use it
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm += "Generate a user profile:\n"
|
||||
lm = generate_user(lm)
|
||||
|
||||
print(lm)
|
||||
# Output: Valid JSON guaranteed
|
||||
```
|
||||
|
||||
### Nested JSON
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_order(lm):
|
||||
"""Generate nested order JSON."""
|
||||
lm += "{\n"
|
||||
|
||||
# Customer info
|
||||
lm += ' "customer": {\n'
|
||||
lm += ' "name": ' + gen("customer_name", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "email": ' + gen(
|
||||
"customer_email",
|
||||
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
|
||||
) + "\n"
|
||||
lm += " },\n"
|
||||
|
||||
# Order details
|
||||
lm += ' "order": {\n'
|
||||
lm += ' "id": ' + gen("order_id", regex=r'"ORD-[0-9]{6}"') + ",\n"
|
||||
lm += ' "date": ' + gen("order_date", regex=r'"\d{4}-\d{2}-\d{2}"') + ",\n"
|
||||
lm += ' "total": ' + gen("order_total", regex=r"[0-9]+\.[0-9]{2}") + "\n"
|
||||
lm += " },\n"
|
||||
|
||||
# Status
|
||||
lm += ' "status": ' + gen(
|
||||
"status",
|
||||
regex=r'"(pending|processing|shipped|delivered)"'
|
||||
) + "\n"
|
||||
|
||||
lm += "}"
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_order(lm)
|
||||
```
|
||||
|
||||
### JSON Array
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_user_list(lm, count=3):
|
||||
"""Generate JSON array of users."""
|
||||
lm += "[\n"
|
||||
|
||||
for i in range(count):
|
||||
lm += " {\n"
|
||||
lm += ' "id": ' + gen(f"id_{i}", regex=r"[0-9]+") + ",\n"
|
||||
lm += ' "name": ' + gen(f"name_{i}", regex=r'"[A-Za-z ]+"') + ",\n"
|
||||
lm += ' "active": ' + gen(f"active_{i}", regex=r"(true|false)") + "\n"
|
||||
lm += " }"
|
||||
if i < count - 1:
|
||||
lm += ","
|
||||
lm += "\n"
|
||||
|
||||
lm += "]"
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_user_list(lm, count=5)
|
||||
```
|
||||
|
||||
### Dynamic JSON Schema
|
||||
|
||||
```python
|
||||
import json
|
||||
from guidance import models, gen, guidance
|
||||
|
||||
@guidance
|
||||
def json_from_schema(lm, schema):
|
||||
"""Generate JSON matching a schema."""
|
||||
lm += "{\n"
|
||||
|
||||
fields = list(schema["properties"].items())
|
||||
for i, (field_name, field_schema) in enumerate(fields):
|
||||
lm += f' "{field_name}": '
|
||||
|
||||
# Handle different types
|
||||
if field_schema["type"] == "string":
|
||||
if "pattern" in field_schema:
|
||||
lm += gen(field_name, regex=f'"{field_schema["pattern"]}"')
|
||||
else:
|
||||
lm += gen(field_name, regex=r'"[^"]+"')
|
||||
elif field_schema["type"] == "number":
|
||||
lm += gen(field_name, regex=r"[0-9]+(\.[0-9]+)?")
|
||||
elif field_schema["type"] == "integer":
|
||||
lm += gen(field_name, regex=r"[0-9]+")
|
||||
elif field_schema["type"] == "boolean":
|
||||
lm += gen(field_name, regex=r"(true|false)")
|
||||
|
||||
if i < len(fields) - 1:
|
||||
lm += ","
|
||||
lm += "\n"
|
||||
|
||||
lm += "}"
|
||||
return lm
|
||||
|
||||
# Define schema
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
"score": {"type": "number"},
|
||||
"active": {"type": "boolean"}
|
||||
}
|
||||
}
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = json_from_schema(lm, schema)
|
||||
```
|
||||
|
||||
## Data Extraction
|
||||
|
||||
### Extract from Text
|
||||
|
||||
```python
|
||||
from guidance import models, gen, guidance, system, user, assistant
|
||||
|
||||
@guidance
|
||||
def extract_person_info(lm, text):
|
||||
"""Extract structured info from text."""
|
||||
lm += f"Text: {text}\n\n"
|
||||
|
||||
with assistant():
|
||||
lm += "Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
lm += "Age: " + gen("age", regex=r"[0-9]+", max_tokens=3) + "\n"
|
||||
lm += "Occupation: " + gen("occupation", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
lm += "Email: " + gen(
|
||||
"email",
|
||||
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
|
||||
stop="\n"
|
||||
) + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
text = "John Smith is a 35-year-old software engineer. Contact: john@example.com"
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
with system():
|
||||
lm += "You extract structured information from text."
|
||||
|
||||
with user():
|
||||
lm = extract_person_info(lm, text)
|
||||
|
||||
print(f"Name: {lm['name']}")
|
||||
print(f"Age: {lm['age']}")
|
||||
print(f"Occupation: {lm['occupation']}")
|
||||
print(f"Email: {lm['email']}")
|
||||
```
|
||||
|
||||
### Multi-Entity Extraction
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def extract_entities(lm, text):
|
||||
"""Extract multiple entity types."""
|
||||
lm += f"Analyze: {text}\n\n"
|
||||
|
||||
# Person entities
|
||||
lm += "People:\n"
|
||||
for i in range(3): # Up to 3 people
|
||||
lm += f"- " + gen(f"person_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
# Organization entities
|
||||
lm += "\nOrganizations:\n"
|
||||
for i in range(2): # Up to 2 orgs
|
||||
lm += f"- " + gen(f"org_{i}", regex=r"[A-Za-z0-9 ]+", stop="\n") + "\n"
|
||||
|
||||
# Dates
|
||||
lm += "\nDates:\n"
|
||||
for i in range(2): # Up to 2 dates
|
||||
lm += f"- " + gen(f"date_{i}", regex=r"\d{4}-\d{2}-\d{2}", stop="\n") + "\n"
|
||||
|
||||
# Locations
|
||||
lm += "\nLocations:\n"
|
||||
for i in range(2): # Up to 2 locations
|
||||
lm += f"- " + gen(f"location_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
text = """
|
||||
Tim Cook and Satya Nadella met at Microsoft headquarters in Redmond on 2024-09-15
|
||||
to discuss the collaboration between Apple and Microsoft. The meeting continued
|
||||
in Cupertino on 2024-09-20.
|
||||
"""
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = extract_entities(lm, text)
|
||||
```
|
||||
|
||||
### Batch Extraction
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def batch_extract(lm, texts):
|
||||
"""Extract from multiple texts."""
|
||||
lm += "Batch Extraction Results:\n\n"
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
lm += f"=== Item {i+1} ===\n"
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Name: " + gen(f"name_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
|
||||
lm += "Sentiment: " + gen(
|
||||
f"sentiment_{i}",
|
||||
regex=r"(positive|negative|neutral)",
|
||||
stop="\n"
|
||||
) + "\n\n"
|
||||
|
||||
return lm
|
||||
|
||||
texts = [
|
||||
"Alice is happy with the product",
|
||||
"Bob is disappointed with the service",
|
||||
"Carol has no strong feelings either way"
|
||||
]
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = batch_extract(lm, texts)
|
||||
```
|
||||
|
||||
## Classification Systems
|
||||
|
||||
### Sentiment Analysis
|
||||
|
||||
```python
|
||||
from guidance import models, select, gen
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
|
||||
text = "This product is absolutely amazing! Best purchase ever."
|
||||
|
||||
lm += f"Text: {text}\n\n"
|
||||
lm += "Sentiment: " + select(
|
||||
["positive", "negative", "neutral"],
|
||||
name="sentiment"
|
||||
)
|
||||
lm += "\nConfidence: " + gen("confidence", regex=r"[0-9]{1,3}") + "%\n"
|
||||
lm += "Reasoning: " + gen("reasoning", stop="\n", max_tokens=50)
|
||||
|
||||
print(f"Sentiment: {lm['sentiment']}")
|
||||
print(f"Confidence: {lm['confidence']}%")
|
||||
print(f"Reasoning: {lm['reasoning']}")
|
||||
```
|
||||
|
||||
### Multi-Label Classification
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def classify_article(lm, text):
|
||||
"""Classify article with multiple labels."""
|
||||
lm += f"Article: {text}\n\n"
|
||||
|
||||
# Primary category
|
||||
lm += "Primary Category: " + select(
|
||||
["Technology", "Business", "Science", "Politics", "Entertainment"],
|
||||
name="primary_category"
|
||||
) + "\n"
|
||||
|
||||
# Secondary categories (up to 3)
|
||||
lm += "\nSecondary Categories:\n"
|
||||
categories = ["Technology", "Business", "Science", "Politics", "Entertainment"]
|
||||
for i in range(3):
|
||||
lm += f"{i+1}. " + select(categories, name=f"secondary_{i}") + "\n"
|
||||
|
||||
# Tags
|
||||
lm += "\nTags: " + gen("tags", stop="\n", max_tokens=50) + "\n"
|
||||
|
||||
# Target audience
|
||||
lm += "Target Audience: " + select(
|
||||
["General", "Expert", "Beginner"],
|
||||
name="audience"
|
||||
)
|
||||
|
||||
return lm
|
||||
|
||||
article = """
|
||||
Apple announced new AI features in iOS 18, leveraging machine learning to improve
|
||||
battery life and performance. The company's stock rose 5% following the announcement.
|
||||
"""
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = classify_article(lm, article)
|
||||
```
|
||||
|
||||
### Intent Classification
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def classify_intent(lm, message):
|
||||
"""Classify user intent."""
|
||||
lm += f"User Message: {message}\n\n"
|
||||
|
||||
# Intent
|
||||
lm += "Intent: " + select(
|
||||
["question", "complaint", "request", "feedback", "other"],
|
||||
name="intent"
|
||||
) + "\n"
|
||||
|
||||
# Urgency
|
||||
lm += "Urgency: " + select(
|
||||
["low", "medium", "high", "critical"],
|
||||
name="urgency"
|
||||
) + "\n"
|
||||
|
||||
# Department
|
||||
lm += "Route To: " + select(
|
||||
["support", "sales", "billing", "technical"],
|
||||
name="department"
|
||||
) + "\n"
|
||||
|
||||
# Sentiment
|
||||
lm += "Sentiment: " + select(
|
||||
["positive", "neutral", "negative"],
|
||||
name="sentiment"
|
||||
)
|
||||
|
||||
return lm
|
||||
|
||||
message = "My account was charged twice for the same order. Need help ASAP!"
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = classify_intent(lm, message)
|
||||
|
||||
print(f"Intent: {lm['intent']}")
|
||||
print(f"Urgency: {lm['urgency']}")
|
||||
print(f"Department: {lm['department']}")
|
||||
```
|
||||
|
||||
## Agent Systems
|
||||
|
||||
### ReAct Agent
|
||||
|
||||
```python
|
||||
from guidance import models, gen, select, guidance
|
||||
|
||||
@guidance(stateless=False)
|
||||
def react_agent(lm, question, tools, max_rounds=5):
|
||||
"""ReAct agent with tool use."""
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
for round in range(max_rounds):
|
||||
# Thought
|
||||
lm += f"Thought {round+1}: " + gen("thought", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Action selection
|
||||
lm += "Action: " + select(
|
||||
list(tools.keys()) + ["answer"],
|
||||
name="action"
|
||||
)
|
||||
|
||||
if lm["action"] == "answer":
|
||||
lm += "\n\nFinal Answer: " + gen("answer", max_tokens=200)
|
||||
break
|
||||
|
||||
# Action input
|
||||
lm += "\nAction Input: " + gen("action_input", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Execute tool
|
||||
if lm["action"] in tools:
|
||||
try:
|
||||
result = tools[lm["action"]](lm["action_input"])
|
||||
lm += f"Observation: {result}\n\n"
|
||||
except Exception as e:
|
||||
lm += f"Observation: Error - {str(e)}\n\n"
|
||||
|
||||
return lm
|
||||
|
||||
# Define tools
|
||||
tools = {
|
||||
"calculator": lambda expr: eval(expr),
|
||||
"search": lambda query: f"Search results for '{query}': [Mock results]",
|
||||
"weather": lambda city: f"Weather in {city}: Sunny, 72°F"
|
||||
}
|
||||
|
||||
# Use agent
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = react_agent(lm, "What is (25 * 4) + 10?", tools)
|
||||
|
||||
print(lm["answer"])
|
||||
```
|
||||
|
||||
### Multi-Agent System
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def coordinator_agent(lm, task):
|
||||
"""Coordinator that delegates to specialists."""
|
||||
lm += f"Task: {task}\n\n"
|
||||
|
||||
# Determine which specialist to use
|
||||
lm += "Specialist: " + select(
|
||||
["researcher", "writer", "coder", "analyst"],
|
||||
name="specialist"
|
||||
) + "\n"
|
||||
|
||||
lm += "Reasoning: " + gen("reasoning", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
@guidance
|
||||
def researcher_agent(lm, query):
|
||||
"""Research specialist."""
|
||||
lm += f"Research Query: {query}\n\n"
|
||||
lm += "Findings:\n"
|
||||
for i in range(3):
|
||||
lm += f"{i+1}. " + gen(f"finding_{i}", stop="\n", max_tokens=100) + "\n"
|
||||
return lm
|
||||
|
||||
@guidance
|
||||
def writer_agent(lm, topic):
|
||||
"""Writing specialist."""
|
||||
lm += f"Topic: {topic}\n\n"
|
||||
lm += "Title: " + gen("title", stop="\n", max_tokens=50) + "\n"
|
||||
lm += "Content:\n" + gen("content", max_tokens=500)
|
||||
return lm
|
||||
|
||||
# Coordination workflow
|
||||
task = "Write an article about AI safety"
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = coordinator_agent(lm, task)
|
||||
|
||||
specialist = lm["specialist"]
|
||||
if specialist == "researcher":
|
||||
lm = researcher_agent(lm, task)
|
||||
elif specialist == "writer":
|
||||
lm = writer_agent(lm, task)
|
||||
```
|
||||
|
||||
### Tool Use with Validation
|
||||
|
||||
```python
|
||||
@guidance(stateless=False)
|
||||
def validated_tool_agent(lm, question):
|
||||
"""Agent with validated tool calls."""
|
||||
tools = {
|
||||
"add": lambda a, b: float(a) + float(b),
|
||||
"multiply": lambda a, b: float(a) * float(b),
|
||||
"divide": lambda a, b: float(a) / float(b) if float(b) != 0 else "Error: Division by zero"
|
||||
}
|
||||
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
for i in range(5):
|
||||
# Select tool
|
||||
lm += "Tool: " + select(list(tools.keys()) + ["done"], name="tool")
|
||||
|
||||
if lm["tool"] == "done":
|
||||
lm += "\nAnswer: " + gen("answer", max_tokens=100)
|
||||
break
|
||||
|
||||
# Get validated numeric arguments
|
||||
lm += "\nArg1: " + gen("arg1", regex=r"-?[0-9]+(\.[0-9]+)?") + "\n"
|
||||
lm += "Arg2: " + gen("arg2", regex=r"-?[0-9]+(\.[0-9]+)?") + "\n"
|
||||
|
||||
# Execute
|
||||
result = tools[lm["tool"]](lm["arg1"], lm["arg2"])
|
||||
lm += f"Result: {result}\n\n"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = validated_tool_agent(lm, "What is (10 + 5) * 3?")
|
||||
```
|
||||
|
||||
## Multi-Step Workflows
|
||||
|
||||
### Chain of Thought
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def chain_of_thought(lm, question):
|
||||
"""Multi-step reasoning with CoT."""
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
# Generate reasoning steps
|
||||
lm += "Let me think step by step:\n\n"
|
||||
for i in range(4):
|
||||
lm += f"Step {i+1}: " + gen(f"step_{i+1}", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Final answer
|
||||
lm += "\nTherefore, the answer is: " + gen("answer", stop="\n", max_tokens=50)
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = chain_of_thought(lm, "If a train travels 60 mph for 2.5 hours, how far does it go?")
|
||||
|
||||
print(lm["answer"])
|
||||
```
|
||||
|
||||
### Self-Consistency
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def self_consistency(lm, question, num_samples=3):
|
||||
"""Generate multiple reasoning paths and aggregate."""
|
||||
lm += f"Question: {question}\n\n"
|
||||
|
||||
answers = []
|
||||
for i in range(num_samples):
|
||||
lm += f"=== Attempt {i+1} ===\n"
|
||||
lm += "Reasoning: " + gen(f"reasoning_{i}", stop="\n", max_tokens=100) + "\n"
|
||||
lm += "Answer: " + gen(f"answer_{i}", stop="\n", max_tokens=50) + "\n\n"
|
||||
answers.append(lm[f"answer_{i}"])
|
||||
|
||||
# Aggregate (simple majority vote)
|
||||
from collections import Counter
|
||||
most_common = Counter(answers).most_common(1)[0][0]
|
||||
|
||||
lm += f"Final Answer (by majority): {most_common}\n"
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = self_consistency(lm, "What is 15% of 200?")
|
||||
```
|
||||
|
||||
### Planning and Execution
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def plan_and_execute(lm, goal):
|
||||
"""Plan tasks then execute them."""
|
||||
lm += f"Goal: {goal}\n\n"
|
||||
|
||||
# Planning phase
|
||||
lm += "Plan:\n"
|
||||
num_steps = 4
|
||||
for i in range(num_steps):
|
||||
lm += f"{i+1}. " + gen(f"plan_step_{i}", stop="\n", max_tokens=100) + "\n"
|
||||
|
||||
# Execution phase
|
||||
lm += "\nExecution:\n\n"
|
||||
for i in range(num_steps):
|
||||
lm += f"Step {i+1}: {lm[f'plan_step_{i}']}\n"
|
||||
lm += "Status: " + select(["completed", "in-progress", "blocked"], name=f"status_{i}") + "\n"
|
||||
lm += "Result: " + gen(f"result_{i}", stop="\n", max_tokens=150) + "\n\n"
|
||||
|
||||
# Summary
|
||||
lm += "Summary: " + gen("summary", max_tokens=200)
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = plan_and_execute(lm, "Build a REST API for a blog platform")
|
||||
```
|
||||
|
||||
## Code Generation
|
||||
|
||||
### Python Function
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_python_function(lm, description):
|
||||
"""Generate Python function from description."""
|
||||
lm += f"Description: {description}\n\n"
|
||||
|
||||
# Function signature
|
||||
lm += "def " + gen("func_name", regex=r"[a-z_][a-z0-9_]*") + "("
|
||||
lm += gen("params", regex=r"[a-z_][a-z0-9_]*(, [a-z_][a-z0-9_]*)*") + "):\n"
|
||||
|
||||
# Docstring
|
||||
lm += ' """' + gen("docstring", stop='"""', max_tokens=100) + '"""\n'
|
||||
|
||||
# Function body
|
||||
lm += " " + gen("body", stop="\n", max_tokens=200) + "\n"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_python_function(lm, "Check if a number is prime")
|
||||
|
||||
print(lm)
|
||||
```
|
||||
|
||||
### SQL Query
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_sql(lm, description):
|
||||
"""Generate SQL query from description."""
|
||||
lm += f"Description: {description}\n\n"
|
||||
lm += "SQL Query:\n"
|
||||
|
||||
# SELECT clause
|
||||
lm += "SELECT " + gen("select_clause", stop=" FROM", max_tokens=100)
|
||||
|
||||
# FROM clause
|
||||
lm += " FROM " + gen("from_clause", stop=" WHERE", max_tokens=50)
|
||||
|
||||
# WHERE clause (optional)
|
||||
lm += " WHERE " + gen("where_clause", stop=";", max_tokens=100) + ";"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_sql(lm, "Get all users who signed up in the last 30 days")
|
||||
```
|
||||
|
||||
### API Endpoint
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def generate_api_endpoint(lm, description):
|
||||
"""Generate REST API endpoint."""
|
||||
lm += f"Description: {description}\n\n"
|
||||
|
||||
# HTTP method
|
||||
lm += "Method: " + select(["GET", "POST", "PUT", "DELETE"], name="method") + "\n"
|
||||
|
||||
# Path
|
||||
lm += "Path: /" + gen("path", regex=r"[a-z0-9/-]+", stop="\n") + "\n"
|
||||
|
||||
# Request body (if POST/PUT)
|
||||
if lm["method"] in ["POST", "PUT"]:
|
||||
lm += "\nRequest Body:\n"
|
||||
lm += "{\n"
|
||||
lm += ' "field1": ' + gen("field1", regex=r'"[a-z_]+"') + ",\n"
|
||||
lm += ' "field2": ' + gen("field2", regex=r'"[a-z_]+"') + "\n"
|
||||
lm += "}\n"
|
||||
|
||||
# Response
|
||||
lm += "\nResponse (200 OK):\n"
|
||||
lm += "{\n"
|
||||
lm += ' "status": "success",\n'
|
||||
lm += ' "data": ' + gen("response_data", max_tokens=100) + "\n"
|
||||
lm += "}\n"
|
||||
|
||||
return lm
|
||||
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm = generate_api_endpoint(lm, "Create a new blog post")
|
||||
```
|
||||
|
||||
## Production Tips
|
||||
|
||||
### Error Handling
|
||||
|
||||
```python
|
||||
@guidance
|
||||
def safe_extraction(lm, text):
|
||||
"""Extract with fallback handling."""
|
||||
try:
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n", max_tokens=30)
|
||||
return lm
|
||||
except Exception as e:
|
||||
# Fallback to less strict extraction
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Name: " + gen("name", stop="\n", max_tokens=30)
|
||||
return lm
|
||||
```
|
||||
|
||||
### Caching
|
||||
|
||||
```python
|
||||
from functools import lru_cache
|
||||
|
||||
@lru_cache(maxsize=100)
|
||||
def cached_generation(text):
|
||||
"""Cache LLM generations."""
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
lm += f"Analyze: {text}\n"
|
||||
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
|
||||
return lm["sentiment"]
|
||||
|
||||
# First call: hits LLM
|
||||
result1 = cached_generation("This is great!")
|
||||
|
||||
# Second call: returns cached result
|
||||
result2 = cached_generation("This is great!") # Instant!
|
||||
```
|
||||
|
||||
### Monitoring
|
||||
|
||||
```python
|
||||
import time
|
||||
|
||||
@guidance
|
||||
def monitored_generation(lm, text):
|
||||
"""Track generation metrics."""
|
||||
start_time = time.time()
|
||||
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Analysis: " + gen("analysis", max_tokens=100)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Log metrics
|
||||
print(f"Generation time: {elapsed:.2f}s")
|
||||
print(f"Output length: {len(lm['analysis'])} chars")
|
||||
|
||||
return lm
|
||||
```
|
||||
|
||||
### Batch Processing
|
||||
|
||||
```python
|
||||
def batch_process(texts, batch_size=10):
|
||||
"""Process texts in batches."""
|
||||
lm = models.Anthropic("claude-sonnet-4-5-20250929")
|
||||
results = []
|
||||
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch = texts[i:i+batch_size]
|
||||
|
||||
for text in batch:
|
||||
lm += f"Text: {text}\n"
|
||||
lm += "Sentiment: " + select(
|
||||
["positive", "negative", "neutral"],
|
||||
name=f"sentiment_{i}"
|
||||
) + "\n\n"
|
||||
|
||||
results.extend([lm[f"sentiment_{i}"] for i in range(len(batch))])
|
||||
|
||||
return results
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **Guidance Notebooks**: https://github.com/guidance-ai/guidance/tree/main/notebooks
|
||||
- **Guidance Docs**: https://guidance.readthedocs.io
|
||||
- **Community Examples**: https://github.com/guidance-ai/guidance/discussions
|
||||
307
skills/mlops/llava/SKILL.md
Normal file
307
skills/mlops/llava/SKILL.md
Normal file
@@ -0,0 +1,307 @@
|
||||
---
|
||||
name: llava
|
||||
description: Large Language and Vision Assistant. Enables visual instruction tuning and image-based conversations. Combines CLIP vision encoder with Vicuna/LLaMA language models. Supports multi-turn image chat, visual question answering, and instruction following. Use for vision-language chatbots or image understanding tasks. Best for conversational image analysis.
|
||||
version: 1.0.0
|
||||
author: Orchestra Research
|
||||
license: MIT
|
||||
dependencies: [transformers, torch, pillow]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [LLaVA, Vision-Language, Multimodal, Visual Question Answering, Image Chat, CLIP, Vicuna, Conversational AI, Instruction Tuning, VQA]
|
||||
|
||||
---
|
||||
|
||||
# LLaVA - Large Language and Vision Assistant
|
||||
|
||||
Open-source vision-language model for conversational image understanding.
|
||||
|
||||
## When to use LLaVA
|
||||
|
||||
**Use when:**
|
||||
- Building vision-language chatbots
|
||||
- Visual question answering (VQA)
|
||||
- Image description and captioning
|
||||
- Multi-turn image conversations
|
||||
- Visual instruction following
|
||||
- Document understanding with images
|
||||
|
||||
**Metrics**:
|
||||
- **23,000+ GitHub stars**
|
||||
- GPT-4V level capabilities (targeted)
|
||||
- Apache 2.0 License
|
||||
- Multiple model sizes (7B-34B params)
|
||||
|
||||
**Use alternatives instead**:
|
||||
- **GPT-4V**: Highest quality, API-based
|
||||
- **CLIP**: Simple zero-shot classification
|
||||
- **BLIP-2**: Better for captioning only
|
||||
- **Flamingo**: Research, not open-source
|
||||
|
||||
## Quick start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Clone repository
|
||||
git clone https://github.com/haotian-liu/LLaVA
|
||||
cd LLaVA
|
||||
|
||||
# Install
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### Basic usage
|
||||
|
||||
```python
|
||||
from llava.model.builder import load_pretrained_model
|
||||
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
|
||||
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
||||
from llava.conversation import conv_templates
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
# Load model
|
||||
model_path = "liuhaotian/llava-v1.5-7b"
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
||||
model_path=model_path,
|
||||
model_base=None,
|
||||
model_name=get_model_name_from_path(model_path)
|
||||
)
|
||||
|
||||
# Load image
|
||||
image = Image.open("image.jpg")
|
||||
image_tensor = process_images([image], image_processor, model.config)
|
||||
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
|
||||
|
||||
# Create conversation
|
||||
conv = conv_templates["llava_v1"].copy()
|
||||
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
|
||||
conv.append_message(conv.roles[1], None)
|
||||
prompt = conv.get_prompt()
|
||||
|
||||
# Generate response
|
||||
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
output_ids = model.generate(
|
||||
input_ids,
|
||||
images=image_tensor,
|
||||
do_sample=True,
|
||||
temperature=0.2,
|
||||
max_new_tokens=512
|
||||
)
|
||||
|
||||
response = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
|
||||
print(response)
|
||||
```
|
||||
|
||||
## Available models
|
||||
|
||||
| Model | Parameters | VRAM | Quality |
|
||||
|-------|------------|------|---------|
|
||||
| LLaVA-v1.5-7B | 7B | ~14 GB | Good |
|
||||
| LLaVA-v1.5-13B | 13B | ~28 GB | Better |
|
||||
| LLaVA-v1.6-34B | 34B | ~70 GB | Best |
|
||||
|
||||
```python
|
||||
# Load different models
|
||||
model_7b = "liuhaotian/llava-v1.5-7b"
|
||||
model_13b = "liuhaotian/llava-v1.5-13b"
|
||||
model_34b = "liuhaotian/llava-v1.6-34b"
|
||||
|
||||
# 4-bit quantization for lower VRAM
|
||||
load_4bit = True # Reduces VRAM by ~4×
|
||||
```
|
||||
|
||||
## CLI usage
|
||||
|
||||
```bash
|
||||
# Single image query
|
||||
python -m llava.serve.cli \
|
||||
--model-path liuhaotian/llava-v1.5-7b \
|
||||
--image-file image.jpg \
|
||||
--query "What is in this image?"
|
||||
|
||||
# Multi-turn conversation
|
||||
python -m llava.serve.cli \
|
||||
--model-path liuhaotian/llava-v1.5-7b \
|
||||
--image-file image.jpg
|
||||
# Then type questions interactively
|
||||
```
|
||||
|
||||
## Web UI (Gradio)
|
||||
|
||||
```bash
|
||||
# Launch Gradio interface
|
||||
python -m llava.serve.gradio_web_server \
|
||||
--model-path liuhaotian/llava-v1.5-7b \
|
||||
--load-4bit # Optional: reduce VRAM
|
||||
|
||||
# Access at http://localhost:7860
|
||||
```
|
||||
|
||||
## Multi-turn conversations
|
||||
|
||||
```python
|
||||
# Initialize conversation
|
||||
conv = conv_templates["llava_v1"].copy()
|
||||
|
||||
# Turn 1
|
||||
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
|
||||
conv.append_message(conv.roles[1], None)
|
||||
response1 = generate(conv, model, image) # "A dog playing in a park"
|
||||
|
||||
# Turn 2
|
||||
conv.messages[-1][1] = response1 # Add previous response
|
||||
conv.append_message(conv.roles[0], "What breed is the dog?")
|
||||
conv.append_message(conv.roles[1], None)
|
||||
response2 = generate(conv, model, image) # "Golden Retriever"
|
||||
|
||||
# Turn 3
|
||||
conv.messages[-1][1] = response2
|
||||
conv.append_message(conv.roles[0], "What time of day is it?")
|
||||
conv.append_message(conv.roles[1], None)
|
||||
response3 = generate(conv, model, image)
|
||||
```
|
||||
|
||||
## Common tasks
|
||||
|
||||
### Image captioning
|
||||
|
||||
```python
|
||||
question = "Describe this image in detail."
|
||||
response = ask(model, image, question)
|
||||
```
|
||||
|
||||
### Visual question answering
|
||||
|
||||
```python
|
||||
question = "How many people are in the image?"
|
||||
response = ask(model, image, question)
|
||||
```
|
||||
|
||||
### Object detection (textual)
|
||||
|
||||
```python
|
||||
question = "List all the objects you can see in this image."
|
||||
response = ask(model, image, question)
|
||||
```
|
||||
|
||||
### Scene understanding
|
||||
|
||||
```python
|
||||
question = "What is happening in this scene?"
|
||||
response = ask(model, image, question)
|
||||
```
|
||||
|
||||
### Document understanding
|
||||
|
||||
```python
|
||||
question = "What is the main topic of this document?"
|
||||
response = ask(model, document_image, question)
|
||||
```
|
||||
|
||||
## Training custom model
|
||||
|
||||
```bash
|
||||
# Stage 1: Feature alignment (558K image-caption pairs)
|
||||
bash scripts/v1_5/pretrain.sh
|
||||
|
||||
# Stage 2: Visual instruction tuning (150K instruction data)
|
||||
bash scripts/v1_5/finetune.sh
|
||||
```
|
||||
|
||||
## Quantization (reduce VRAM)
|
||||
|
||||
```python
|
||||
# 4-bit quantization
|
||||
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
||||
model_path="liuhaotian/llava-v1.5-13b",
|
||||
model_base=None,
|
||||
model_name=get_model_name_from_path("liuhaotian/llava-v1.5-13b"),
|
||||
load_4bit=True # Reduces VRAM ~4×
|
||||
)
|
||||
|
||||
# 8-bit quantization
|
||||
load_8bit=True # Reduces VRAM ~2×
|
||||
```
|
||||
|
||||
## Best practices
|
||||
|
||||
1. **Start with 7B model** - Good quality, manageable VRAM
|
||||
2. **Use 4-bit quantization** - Reduces VRAM significantly
|
||||
3. **GPU required** - CPU inference extremely slow
|
||||
4. **Clear prompts** - Specific questions get better answers
|
||||
5. **Multi-turn conversations** - Maintain conversation context
|
||||
6. **Temperature 0.2-0.7** - Balance creativity/consistency
|
||||
7. **max_new_tokens 512-1024** - For detailed responses
|
||||
8. **Batch processing** - Process multiple images sequentially
|
||||
|
||||
## Performance
|
||||
|
||||
| Model | VRAM (FP16) | VRAM (4-bit) | Speed (tokens/s) |
|
||||
|-------|-------------|--------------|------------------|
|
||||
| 7B | ~14 GB | ~4 GB | ~20 |
|
||||
| 13B | ~28 GB | ~8 GB | ~12 |
|
||||
| 34B | ~70 GB | ~18 GB | ~5 |
|
||||
|
||||
*On A100 GPU*
|
||||
|
||||
## Benchmarks
|
||||
|
||||
LLaVA achieves competitive scores on:
|
||||
- **VQAv2**: 78.5%
|
||||
- **GQA**: 62.0%
|
||||
- **MM-Vet**: 35.4%
|
||||
- **MMBench**: 64.3%
|
||||
|
||||
## Limitations
|
||||
|
||||
1. **Hallucinations** - May describe things not in image
|
||||
2. **Spatial reasoning** - Struggles with precise locations
|
||||
3. **Small text** - Difficulty reading fine print
|
||||
4. **Object counting** - Imprecise for many objects
|
||||
5. **VRAM requirements** - Need powerful GPU
|
||||
6. **Inference speed** - Slower than CLIP
|
||||
|
||||
## Integration with frameworks
|
||||
|
||||
### LangChain
|
||||
|
||||
```python
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
class LLaVALLM(LLM):
|
||||
def _call(self, prompt, stop=None):
|
||||
# Custom LLaVA inference
|
||||
return response
|
||||
|
||||
llm = LLaVALLM()
|
||||
```
|
||||
|
||||
### Gradio App
|
||||
|
||||
```python
|
||||
import gradio as gr
|
||||
|
||||
def chat(image, text, history):
|
||||
response = ask_llava(model, image, text)
|
||||
return response
|
||||
|
||||
demo = gr.ChatInterface(
|
||||
chat,
|
||||
additional_inputs=[gr.Image(type="pil")],
|
||||
title="LLaVA Chat"
|
||||
)
|
||||
demo.launch()
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
- **GitHub**: https://github.com/haotian-liu/LLaVA ⭐ 23,000+
|
||||
- **Paper**: https://arxiv.org/abs/2304.08485
|
||||
- **Demo**: https://llava.hliu.cc
|
||||
- **Models**: https://huggingface.co/liuhaotian
|
||||
- **License**: Apache 2.0
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user