Compare commits

..

122 Commits

Author SHA1 Message Date
teknium1
f984cc335b feat: enhance auxiliary model configuration and environment variable handling
- Added support for auxiliary model overrides in the configuration, allowing users to specify providers and models for vision and web extraction tasks.
- Updated the CLI configuration example to include new auxiliary model settings.
- Enhanced the environment variable mapping in the CLI to accommodate auxiliary model configurations.
- Improved the resolution logic for auxiliary clients to support task-specific provider overrides.
- Updated relevant documentation and comments for clarity on the new features and their usage.
2026-03-07 08:52:06 -08:00
teknium1
0a82396718 feat: shared iteration budget across parent + subagents
Subagent tool calls now count toward the same session-wide iteration
limit as the parent agent. Previously, each subagent had its own
independent counter, so a parent with max_iterations=60 could spawn
3 subagents each doing 50 calls = 150 total tool calls unmetered.

Changes:
- IterationBudget: thread-safe shared counter (run_agent.py)
  - consume(): try to use one iteration, returns False if exhausted
  - refund(): give back one iteration (for execute_code turns)
  - Thread-safe via Lock (subagents run in ThreadPoolExecutor)
- Parent creates the budget, children inherit it via delegate_tool.py
- execute_code turns are refunded (don't count against budget)
- Default raised from 60 → 90 to account for shared consumption
- Per-child cap (50) still applies as a safety valve

The per-child max_iterations (default 50) remains as a per-child
ceiling, but the shared budget is the hard session-wide limit.
A child stops at whichever comes first.
2026-03-07 08:16:37 -08:00
teknium1
5da55ea1e3 fix: sanitize orphaned tool-call/result pairs in message compression
Enhance message compression by adding a method to clean up orphaned tool-call and tool-result pairs. This ensures that the API receives well-formed messages, preventing errors related to mismatched IDs. The new functionality includes removing orphaned results and adding stub results for missing calls, improving overall message integrity during compression.
2026-03-07 08:08:00 -08:00
teknium1
064c009deb feat: show update-available notice in CLI banner
Check how many commits behind origin/main the local repo is and
display a warning in the welcome banner:

  ⚠ 12 commits behind — run hermes update to update

- git fetch cached for 6 hours (avoids repeated network calls)
- Falls back gracefully if offline or not a git repo
- Never breaks the banner — all errors silently caught
2026-03-07 07:35:36 -08:00
teknium1
caab1cf453 fix: update setup/config UI for local browser mode
- tools_config.py: Add 'Local Browser' as first provider option
  (no API keys needed, same npm install for agent-browser)
- setup.py: Show 'Browser Automation (local)' when agent-browser
  CLI is found but no Browserbase key is set
- config.py: Mark BROWSERBASE_* descriptions as optional
- status.py: Note that local browser works without Browserbase
2026-03-07 01:23:27 -08:00
teknium1
55c70f3508 fix: strip MarkdownV2 escapes from Telegram plaintext fallback
When Telegram's MarkdownV2 parser rejects a message, the send() fallback
was sending the already-escaped text as plain text. This caused users to
see raw backslashes before every special character (periods, dashes,
parentheses, etc.) — e.g. 'sentence\.' or '\-\-auto\-approve'.

Changes:
- Add _strip_mdv2() to reverse MarkdownV2 escaping for clean plaintext
- Use stripped text in the send() fallback path instead of raw escaped chunk
- Add logging when the MDV2 fallback is triggered for diagnostics
- Add logger to telegram.py (was missing)

The edit_message() fallback already correctly used the original content;
this brings send() in line with that behavior.
2026-03-07 01:23:18 -08:00
teknium1
d29249b8fa feat: local browser backend — zero-cost headless Chromium via agent-browser
Add local browser mode as an automatic fallback when Browserbase
credentials are not configured. Uses the same agent-browser CLI with
--session (local Chromium) instead of --cdp (cloud Browserbase).

The agent-facing API is completely unchanged — all 10 browser_* tools
produce identical output in both modes. Auto-detection:
  - BROWSERBASE_API_KEY set → cloud mode (existing behavior)
  - No key → local mode (new, free, headless Chromium)

Changes:
- _is_local_mode(): auto-detect based on env vars
- _create_local_session(): lightweight session (no API call)
- _get_session_info(): branches on local vs cloud
- _run_browser_command(): --session in local, --cdp in cloud
- check_browser_requirements(): only needs agent-browser CLI in local mode
- _emergency_cleanup: CLI close in local, API release in cloud
- cleanup_browser/browser_close: skip BB API calls in local mode
- Registry: removed requires_env — check_fn handles both modes

Setup for local mode:
  npm install -g agent-browser
  agent-browser install              # downloads Chromium
  agent-browser install --with-deps  # also installs system libs (Docker/Debian)

Closes #374 (Phase 1)
2026-03-07 01:14:57 -08:00
teknium1
f668e9fc75 feat: platform-conditional skill loading + Apple/macOS skills
Add a 'platforms' field to SKILL.md frontmatter that restricts skills
to specific operating systems. Skills with platforms: [macos] only
appear in the system prompt, skills_list(), and slash commands on macOS.
Skills without the field load everywhere (backward compatible).

Implementation:
- skill_matches_platform() in tools/skills_tool.py — core filter
- Wired into all 3 discovery paths: prompt_builder.py, skills_tool.py,
  skill_commands.py
- 28 new tests across 3 test files

New bundled Apple/macOS skills (all platforms: [macos]):
- imessage — Send/receive iMessages via imsg CLI
- apple-reminders — Manage Reminders via remindctl CLI
- apple-notes — Manage Notes via memo CLI
- findmy — Track devices/AirTags via AppleScript + screen capture

Docs updated: CONTRIBUTING.md, AGENTS.md, creating-skills.md,
skills.md (user guide)
2026-03-07 00:47:54 -08:00
teknium1
74fe1e2254 chore: remove TODO.md — all items tracked as issues
All remaining TODO items have covering issues:
- Local Browser via CDP: #374, #493
- Signal Integration: #405
- Plugin/Extension System: #359
- MCP Client Improvements: #581 (new)
- Filesystem Checkpointing: #452

Completed items (MCP core support) already shipped in PR #301.
2026-03-07 00:07:14 -08:00
teknium1
348936752a fix: simplify timezone migration to use os.getenv directly
The previous 'get_env_value' in dir() check always evaluated to False
(dir() returns local scope, not module scope), making the left branch
dead code. Simplified to just os.getenv() which was the fallback anyway.
2026-03-07 00:05:05 -08:00
teknium1
69a36a3361 Merge PR #309: fix(timezone): timezone-aware now() for prompt, cron, and execute_code
Authored by areu01or00. Adds timezone support via hermes_time.now() helper
with IANA timezone resolution (HERMES_TIMEZONE env → config.yaml → server-local).
Updates system prompt timestamp, cron scheduling, and execute_code sandbox TZ
injection. Includes config migration (v4→v5) and comprehensive test coverage.
2026-03-07 00:04:41 -08:00
Teknium
8712dd6d1c Merge pull request #308 from batuhankocyigit/patch-2
fix: rename misspelled directory 'fouth-edition' to 'fourth-edition'
2026-03-06 23:43:09 -08:00
teknium1
55a21fe37b docs: add Environments, Benchmarks & Data Generation guide
Comprehensive developer guide covering:
- Architecture (BaseEnv → HermesAgentBaseEnv → concrete envs)
- All three benchmarks (TerminalBench2, TBLite, YC-Bench)
- Training environments (TerminalTestEnv, HermesSweEnv)
- Core components (AgentLoop, ToolContext, Tool Call Parsers)
- Two-phase operation (Phase 1 OpenAI, Phase 2 VLLM)
- Running environments (evaluate, process, serve modes)
- Creating new environments (training + eval-only)
- Configuration reference and prerequisites

Also updates environments/README.md directory tree to include
TBLite and YC-Bench benchmarks.
2026-03-06 23:31:45 -08:00
teknium1
f55f625277 chore: reorder terminal backends in setup wizard
Local, Docker, Modal, SSH, Daytona, Singularity (Linux-only, last).
2026-03-06 22:21:57 -08:00
teknium1
9dac85b069 fix: uv pip install fails outside venv in setup wizard
uv pip install requires a virtual environment by default. When hermes
is installed system-wide or via pipx, the setup wizard's SDK installs
(daytona, swe-rex[modal], tinker-atropos) fail with 'No virtual
environment found'. Fix by passing --python sys.executable to uv,
which targets the correct Python regardless of venv state.

Also show the actual error message on install failure so users can
debug.
2026-03-06 21:55:33 -08:00
teknium1
99bd69baa8 Merge feat/modular-setup-wizard: modular setup wizard with section subcommands and tool-first UX
- 5 standalone sections: hermes setup [model|terminal|gateway|tools|agent]
- Returning user menu with section shortcuts
- Tool-first UX: category -> provider -> API key flow
- Unified hermes tools / hermes setup tools
- Fixed dict-format model config display bug

Closes #567
2026-03-06 21:12:30 -08:00
teknium1
a62a137a4f fix: handle dict-format model config in setup wizard display
config['model'] can be a dict (old format: {default, base_url, provider})
or a string (new format). The setup wizard was showing the raw dict in
'Keep current' and 'Model set to' messages. Now extracts the model name
from either format.
2026-03-06 21:11:40 -08:00
teknium1
82b18e8ac2 feat: unify hermes tools and hermes setup tools into single flow
Both 'hermes tools' and 'hermes setup tools' now use the same unified
flow in tools_config.py:

1. Select platform (CLI, Telegram, Discord, etc.)
2. Toggle all 18 toolsets on/off in checklist
3. Newly enabled tools that need API keys → provider-aware config
   (e.g., TTS shows Edge/OpenAI/ElevenLabs picker)
4. Already-configured tools that stay enabled → silent, no prompts
5. Menu option: 'Reconfigure an existing tool' for updating
   providers or API keys on tools that are already set up

Key changes:
- Move TOOL_CATEGORIES, provider config, and post-setup hooks from
  setup.py to tools_config.py
- Replace flat _check_and_prompt_requirements() with provider-aware
  _configure_toolset() that uses TOOL_CATEGORIES
- Add _reconfigure_tool() flow for updating existing configs
- setup.py's setup_tools() now delegates to tools_command()
- tools_command() menu adds 'Reconfigure' option alongside platforms
- Only prompt for API keys on tools that are NEWLY toggled on AND
  don't already have keys configured

No breaking changes. All 2013 tests pass.
2026-03-06 21:02:00 -08:00
teknium1
0111c9848d fix: remove ANSI codes and em dashes from menu labels
simple_term_menu miscalculates string widths when labels contain
ANSI escape codes (from color()) or em dashes, causing duplicated
and garbled lines on arrow key navigation.

Replace color() status indicators with plain text [configured]/[active]
and em dashes with regular dashes in all prompt_choice/prompt_checklist
labels.
2026-03-06 21:02:00 -08:00
teknium1
ab9cadfeee feat: modular setup wizard with section subcommands and tool-first UX
Restructure the monolithic hermes setup wizard into independently-runnable
sections with a category-first tool configuration experience.

Changes:
- Break setup into 5 sections: model, terminal, gateway, tools, agent
- Each section is a standalone function, runnable individually via
  'hermes setup model', 'hermes setup terminal', etc.
- Returning users get a menu: Quick Setup / Full Setup / individual sections
- First-time users get a guided walkthrough of all sections

Tool Configuration UX overhaul:
- Replace flat API key checklist with category-first approach
- Show tool types (TTS, Web Search, Image Gen, etc.) as top-level items
- Within each category, let users pick a provider:
  - TTS: Microsoft Edge (Free), OpenAI, ElevenLabs
  - Web: Firecrawl Cloud, Firecrawl Self-Hosted
  - Image Gen: FAL.ai
  - Browser: Browserbase
  - Smart Home: Home Assistant
  - RL Training: Tinker/Atropos
  - GitHub: Personal Access Token
- Shows configured status on each tool and provider
- Only prompts for API keys after provider selection

Also:
- Add section argument to setup argparse parser in main.py
- Update summary to show new section commands
- Add self-hosted Firecrawl and Home Assistant to tool setup
- All 2013 tests pass
2026-03-06 21:02:00 -08:00
teknium1
ce28f847ce fix: update OpenRouter model names for yc-bench config
Use anthropic/claude-sonnet-4.6 (OpenRouter format) instead of
anthropic/claude-sonnet-4-20250514 (direct API format).
2026-03-06 19:58:56 -08:00
teknium1
b4fbb6fe10 feat: add YC-Bench long-horizon agent benchmark environment
Adds eval-only benchmark for YC-Bench (collinear-ai/yc-bench), a
deterministic long-horizon benchmark where the agent acts as CEO of an
AI startup over a simulated 1-3 year run.

Key design decisions verified against the official yc-bench repo:
- Uses 'sim init' (NOT 'yc-bench run') to avoid starting a competing
  built-in agent loop
- Correct DB table names: 'companies' and 'sim_events'
- Correct 4 domains: research, inference, data_environment, training
- Penalty values are preset-dependent (not hardcoded in system prompt)
- Sequential evaluation (each run is 100-500 turns)
- Follows TerminalBench2 patterns: KeyboardInterrupt handling,
  cleanup_all_environments(), tqdm logging handler, streaming JSONL

yc-bench added as optional dependency: pip install hermes-agent[yc-bench]

Closes #340
2026-03-06 19:25:56 -08:00
teknium1
82d7e9429e chore: add GLM/Kimi/MiniMax models to insights pricing (zero cost)
These direct providers don't return cost in API responses and their
per-token pricing isn't readily available externally. Treat as local
models with zero cost so they appear in /insights without fake estimates.
2026-03-06 19:12:14 -08:00
teknium1
e2821effb5 feat: add direct API-key providers as auxiliary client fallbacks
When the user only has a z.ai/Kimi/MiniMax API key (no OpenRouter key),
auxiliary tasks (context compression, web summarization, session search)
now fall back to the configured direct provider instead of returning None.

Resolution chain: OpenRouter -> Nous -> Custom endpoint -> Codex OAuth
-> direct API-key providers -> None.

Uses cheap/fast models for auxiliary tasks:
- zai: glm-4.5-flash
- kimi-coding: kimi-k2-turbo-preview
- minimax/minimax-cn: MiniMax-M2.5-highspeed

Vision auxiliary intentionally NOT modified — vision needs multimodal
models (Gemini) that these providers don't serve.
2026-03-06 19:08:54 -08:00
teknium1
9742f11fda chore: add context lengths for Kimi and MiniMax models
Adds DEFAULT_CONTEXT_LENGTHS entries for kimi-k2.5 (262144), kimi-k2-thinking
(262144), kimi-k2-turbo-preview (262144), kimi-k2-0905-preview (131072),
MiniMax-M2.5/M2.5-highspeed/M2.1 (204800), and glm-4.5/4.5-flash (131072).

Avoids unnecessary 2M-token probe on first use with direct providers.
2026-03-06 19:01:38 -08:00
teknium1
388dd4789c feat: add z.ai/GLM, Kimi/Moonshot, MiniMax as first-class providers
Adds 4 new direct API-key providers (zai, kimi-coding, minimax, minimax-cn)
to the inference provider system. All use standard OpenAI-compatible
chat/completions endpoints with Bearer token auth.

Core changes:
- auth.py: Extended ProviderConfig with api_key_env_vars and base_url_env_var
  fields. Added providers to PROVIDER_REGISTRY. Added provider aliases
  (glm, z-ai, zhipu, kimi, moonshot). Added auto-detection of API-key
  providers in resolve_provider(). Added resolve_api_key_provider_credentials()
  and get_api_key_provider_status() helpers.
- runtime_provider.py: Added generic API-key provider branch in
  resolve_runtime_provider() — any provider with auth_type='api_key'
  is automatically handled.
- main.py: Added providers to hermes model menu with generic
  _model_flow_api_key_provider() flow. Updated _has_any_provider_configured()
  to check all provider env vars. Updated argparse --provider choices.
- setup.py: Added providers to setup wizard with API key prompts and
  curated model lists.
- config.py: Added env vars (GLM_API_KEY, KIMI_API_KEY, MINIMAX_API_KEY,
  etc.) to OPTIONAL_ENV_VARS.
- status.py: Added API key display and provider status section.
- doctor.py: Added connectivity checks for each provider endpoint.
- cli.py: Updated provider docstrings.

Docs: Updated README.md, .env.example, cli-config.yaml.example,
cli-commands.md, environment-variables.md, configuration.md.

Tests: 50 new tests covering registry, aliases, resolution, auto-detection,
credential resolution, and runtime provider dispatch.

Inspired by PR #33 (numman-ali) which proposed a provider registry approach.
Credit to tars90percent (PR #473) and manuelschipper (PR #420) for related
provider improvements merged earlier in this changeset.
2026-03-06 18:55:18 -08:00
Teknium
fdebca4573 Merge pull request #571 from NousResearch/rewbs/nous-key-remint-attempt-on-401
fix: implement Nous credential refresh on 401 error for retry logic
2026-03-06 18:52:01 -08:00
teknium1
479dfc096a Merge PR #473: Update model id in OpenRouter from minimax-m2.1 to minimax-m2.5
Authored by tars90percent. Updates remaining minimax-m2.1 references to
minimax-m2.5 in rl_training_tool.py and docs.
2026-03-06 18:43:18 -08:00
teknium1
3c6c11b7c9 Merge PR #420: fix: respect OPENAI_BASE_URL when resolving API key priority
Authored by manuelschipper. Adds GLM-4.7 and GLM-5 context lengths (202752)
to model_metadata.py. The key priority fix (prefer OPENAI_API_KEY for
non-OpenRouter endpoints) was already applied in PR #295; merged the Z.ai
mention into the comment.
2026-03-06 18:43:13 -08:00
Robin Fernandes
bc091eb7ef fix: implement Nous credential refresh on 401 error for retry logic 2026-03-07 13:34:23 +11:00
teknium1
f75b1d21b4 fix: execute_code and delegate_task now respect disabled toolsets
When a user disables the web toolset via 'hermes tools', the execute_code
schema description still hardcoded web_search/web_extract as available,
causing the model to keep trying to use them. Similarly, delegate_task
always defaulted to ['terminal', 'file', 'web'] for subagents regardless
of the parent's config.

Changes:
- execute_code schema is now built dynamically via build_execute_code_schema()
  based on which sandbox tools are actually enabled
- model_tools.py rebuilds the execute_code schema at definition time using
  the intersection of sandbox-allowed and session-enabled tools
- delegate_task now inherits the parent agent's enabled_toolsets instead of
  hardcoding DEFAULT_TOOLSETS when no explicit toolsets are specified
- delegate_task description updated to say 'inherits your enabled toolsets'

Reported by kotyKD on Discord.
2026-03-06 17:36:14 -08:00
teknium1
94053d75a6 fix: custom endpoint no longer leaks OPENROUTER_API_KEY (#560)
API key selection is now base_url-aware: when the resolved base_url
targets OpenRouter, OPENROUTER_API_KEY takes priority (preserving the
#289 fix). When hitting any other endpoint (Z.ai, vLLM, custom, etc.),
OPENAI_API_KEY takes priority so the OpenRouter key doesn't leak.

Applied in both the runtime provider resolver (the real code path) and
the CLI initial default (for consistency).

Fixes #560.
2026-03-06 17:16:14 -08:00
teknium1
2a68099675 fix(tests): isolate tests from user ~/.hermes/ config and SOUL.md
_make_cli() now patches CLI_CONFIG with clean defaults so
test_cli_init tests don't depend on the developer's local config.yaml.
test_empty_dir_returns_empty now mocks Path.home() so it doesn't pick
up a global SOUL.md.

Credit to teyrebaz33 for identifying and fixing these in PR #557.
Fixes #555.
2026-03-06 17:10:35 -08:00
teknium1
6cd3bc6640 Merge PR #563: fix: prevent data loss in skills sync on copy/update failure
Authored by 0xbyt4. Two bugs fixed:
1. Failed copytree no longer poisons the manifest (skill gets retried)
2. Failed update no longer destroys user's copy (backup + restore)
2026-03-06 17:01:30 -08:00
0xbyt4
211b55815e fix: prevent data loss in skills sync on copy/update failure
Two bugs in sync_skills():

1. Failed copytree poisons manifest: when shutil.copytree fails (disk
   full, permission error), the skill is still recorded in the manifest.
   On the next sync, the skill appears as "in manifest but not on disk"
   which is interpreted as "user deliberately deleted it" — the skill
   is never retried.  Fix: only write to manifest on successful copy.

2. Failed update destroys user copy: rmtree deletes the existing skill
   directory before copytree runs. If copytree then fails, the user's
   skill is gone with no way to recover.  Fix: move to .bak before
   copying, restore from backup if copytree fails.

Both bugs are proven by new regression tests that fail on the old code
and pass on the fix.
2026-03-07 03:58:32 +03:00
teknium1
8ae4a6f824 fix: improve handling of empty responses after tool calls
- Added fallback mechanism to utilize previous content when the model generates an empty response after tool calls, reducing unnecessary API retries.
- Enhanced logging to indicate when prior content is used as a final response.
- Updated logic to ensure that genuine empty responses are retried appropriately, maintaining user experience.
2026-03-06 16:54:31 -08:00
teknium1
b98301677a docs: add /insights to all help menus and documentation
- website/docs/reference/cli-commands.md: Added 'hermes insights' terminal
  command section with --days and --source flags, plus /insights slash command
  in the Conversation section
- website/docs/user-guide/cli.md: Added /insights to slash commands table
- website/docs/user-guide/messaging/index.md: Added /insights to gateway
  chat commands table
- website/docs/user-guide/sessions.md: Added cross-reference to hermes
  insights from the sessions stats section
2026-03-06 16:48:58 -08:00
teknium1
f2fdde5ba4 fix: show user-modified skills count in hermes update output 2026-03-06 16:14:43 -08:00
teknium1
4f56e31dc7 fix: track origin hashes in skills manifest to preserve user modifications
Upgrade skills_sync manifest to v2 format (name:origin_hash). The origin
hash records the MD5 of the bundled skill at the time it was last synced.

On update, the user's copy is compared against the origin hash:
- User copy == origin hash → unmodified → safe to update from bundled
- User copy != origin hash → user customized → skip (preserve changes)

v1 manifests (plain names) are auto-migrated: the user's current hash
becomes the baseline, so future syncs can detect modifications.

Output now shows user-modified skills:
  ~ whisper (user-modified, skipping)

27 tests covering all scenarios including v1→v2 migration, user
modification detection, update after migration, and origin hash tracking.
2009 tests pass.
2026-03-06 16:13:58 -08:00
Teknium
6d3804770c Merge pull request #552 from NousResearch/feat/insights
feat: /insights command — usage analytics, cost estimation & activity patterns
2026-03-06 16:00:28 -08:00
teknium1
ab0f4126cf fix: restore all removed bundled skills + fix skills sync system
- Restored 21 skills removed in commits 757d012 and 740dd92:
  accelerate, audiocraft, code-review, faiss, flash-attention, gguf,
  grpo-rl-training, guidance, llava, nemo-curator, obliteratus, peft,
  pytorch-fsdp, pytorch-lightning, simpo, slime, stable-diffusion,
  tensorrt-llm, torchtitan, trl-fine-tuning, whisper

- Rewrote sync_skills() with proper update semantics:
  * New skills (not in manifest): copied to user dir
  * Existing skills (in manifest + on disk): updated via hash comparison
  * User-deleted skills (in manifest, not on disk): respected, not re-added
  * Stale manifest entries (removed from bundled): cleaned from manifest

- Added sync_skills() to CLI startup (cmd_chat) and gateway startup
  (start_gateway) — previously only ran during 'hermes update'

- Updated cmd_update output to show new/updated/cleaned counts

- Rewrote tests: 20 tests covering manifest CRUD, dir hashing, fresh
  install, user deletion respect, update detection, stale cleanup, and
  name collision handling

75 bundled skills total. 2002 tests pass.
2026-03-06 15:57:30 -08:00
teknium1
585f8528b2 fix: deep review — prefix matching, tool_calls extraction, query perf, serialization
Issues found and fixed during deep code path review:

1. CRITICAL: Prefix matching returned wrong prices for dated model names
   - 'gpt-4o-mini-2024-07-18' matched gpt-4o ($2.50) instead of gpt-4o-mini ($0.15)
   - Same for o3-mini→o3 (9x), gpt-4.1-mini→gpt-4.1 (5x), gpt-4.1-nano→gpt-4.1 (20x)
   - Fix: use longest-match-wins strategy instead of first-match
   - Removed dangerous key.startswith(bare) reverse matching

2. CRITICAL: Top Tools section was empty for CLI sessions
   - run_agent.py doesn't set tool_name on tool response messages (pre-existing)
   - Insights now also extracts tool names from tool_calls JSON on assistant
     messages, which IS populated for all sessions
   - Uses max() merge strategy to avoid double-counting between sources

3. SELECT * replaced with explicit column list
   - Skips system_prompt and model_config blobs (can be thousands of chars)
   - Reduces memory and I/O for large session counts

4. Sets in overview dict converted to sorted lists
   - models_with_pricing / models_without_pricing were Python sets
   - Sets aren't JSON-serializable — would crash json.dumps()

5. Negative duration guard
   - end > start check prevents negative durations from clock drift

6. Model breakdown sort fallback
   - When all tokens are 0, now sorts by session count instead of arbitrary order

7. Removed unused timedelta import

Added 6 new tests: dated model pricing (4), tool_calls JSON extraction,
JSON serialization safety. Total: 69 tests.
2026-03-06 14:50:57 -08:00
teknium1
75f523f5c0 fix: unknown/custom models get zero cost instead of fake estimates
Custom OAI endpoints, self-hosted models, and local inference should NOT
show fabricated cost estimates. Changed default pricing from $3/$12 per
million tokens to $0/$0 for unrecognized models.

- Added _has_known_pricing() to distinguish commercial vs custom models
- Models with known pricing show $ amounts; unknown models show 'N/A'
- Overview shows asterisk + note when some models lack pricing data
- Gateway format adds '(excludes custom/self-hosted models)' note
- Added 7 new tests for custom model cost handling
2026-03-06 14:18:19 -08:00
teknium1
68fbae5692 docs: add Custom & Self-Hosted LLM Providers guide
Comprehensive guide for using Hermes Agent with alternative LLM backends:
- Ollama (local models, zero config)
- vLLM (high-performance GPU inference)
- SGLang (RadixAttention, prefix caching)
- llama.cpp / llama-server (CPU & Metal inference)
- LiteLLM Proxy (multi-provider gateway)
- ClawRouter (cost-optimized routing with complexity scoring)
- 10+ other compatible providers table (Together, Groq, DeepSeek, etc.)
- Choosing the Right Setup decision table
- General custom endpoint setup instructions

All of these work via the existing OPENAI_BASE_URL + OPENAI_API_KEY
custom endpoint support — no code changes needed.
2026-03-06 14:16:06 -08:00
teknium1
80f1dd8d37 docs: add Custom & Self-Hosted LLM Providers guide
Comprehensive guide for using Hermes Agent with alternative LLM backends:
- Ollama (local models, zero config)
- vLLM (high-performance GPU inference)
- SGLang (RadixAttention, prefix caching)
- llama.cpp / llama-server (CPU & Metal inference)
- LiteLLM Proxy (multi-provider gateway)
- ClawRouter (cost-optimized routing with complexity scoring)
- 10+ other compatible providers table (Together, Groq, DeepSeek, etc.)
- Choosing the Right Setup decision table
- General custom endpoint setup instructions

All of these work via the existing OPENAI_BASE_URL + OPENAI_API_KEY
custom endpoint support — no code changes needed.
2026-03-06 14:15:57 -08:00
teknium1
b52b37ae64 feat: add /insights command with usage analytics and cost estimation
Inspired by Claude Code's /insights, adapted for Hermes Agent's multi-platform
architecture. Analyzes session history from state.db to produce comprehensive
usage insights.

Features:
- Overview stats: sessions, messages, tokens, estimated cost, active time
- Model breakdown: per-model sessions, tokens, and cost estimation
- Platform breakdown: CLI vs Telegram vs Discord etc. (unique to Hermes)
- Tool usage ranking: most-used tools with percentages
- Activity patterns: day-of-week chart, peak hours, streaks
- Notable sessions: longest, most messages, most tokens, most tool calls
- Cost estimation: real pricing data for 25+ models (OpenAI, Anthropic,
  DeepSeek, Google, Meta) with fuzzy model name matching
- Configurable time window: --days flag (default 30)
- Source filtering: --source flag to filter by platform

Three entry points:
- /insights slash command in CLI (supports --days and --source flags)
- /insights slash command in gateway (compact markdown format)
- hermes insights CLI subcommand (standalone)

Includes 56 tests covering pricing helpers, format helpers, empty DB,
populated DB with multi-platform data, filtering, formatting, and edge cases.
2026-03-06 14:04:59 -08:00
teknium1
d63b363cde refactor: extract atomic_json_write helper, add 24 checkpoint tests
Extract the duplicated temp-file + fsync + os.replace pattern from
batch_runner.py (1 instance) and process_registry.py (2 instances) into
a shared utils.atomic_json_write() function.

Add 12 tests for atomic_json_write covering: valid JSON, parent dir
creation, overwrite, crash safety (original preserved on error), no temp
file leaks, string paths, unicode, custom indent, concurrent writes.

Add 12 tests for batch_runner checkpoint behavior covering:
_save_checkpoint (valid JSON, last_updated, overwrite, lock/no-lock,
parent dirs, no temp leaks), _load_checkpoint (missing file, existing
data, corrupt JSON), and resume logic (preserves prior progress,
different run_name starts fresh).
2026-03-06 05:50:12 -08:00
teknium1
c05c60665e Merge PR #298: Make process_registry checkpoint writes atomic
Authored by aydnOktay. Companion to PR #297 (batch_runner). Applies the
same atomic write pattern (temp file + fsync + os.replace) to both
_write_checkpoint() and recover_from_checkpoint() in process_registry.py.
Prevents checkpoint corruption on gateway crashes. Also improves error
handling: bare 'pass' replaced with logger.debug(..., exc_info=True)
for better debugging.
2026-03-06 05:32:35 -08:00
teknium1
b4873a5de7 fix(setup): Escape skips instead of exiting, add control hints to all prompts
Previously pressing Escape in any setup wizard menu called sys.exit(1),
killing the entire wizard with no way to recover. Now:

- prompt_choice: Escape keeps the current default and moves on (prints
  'Skipped (keeping current)'). Shows '↑/↓ Navigate  Enter Select
  Esc Skip  Ctrl+C Exit' hint.
- prompt_checklist: Escape returns pre-selected items instead of empty
  list. Shows 'SPACE Toggle  ENTER Confirm  ESC Skip  Ctrl+C Exit'.
- prompt_yes_no: now catches KeyboardInterrupt/EOFError properly.
- Fallback number prompts also show control hints.

Ctrl+C still exits the wizard cleanly.
2026-03-06 05:27:11 -08:00
teknium1
913f8ce0a5 Merge PR #297: Make batch_runner checkpoint incremental and atomic
Authored by aydnOktay. Three improvements to batch_runner fault tolerance:
1) Atomic checkpoint writes (temp file + fsync + os.replace) to prevent
   corruption on crashes — same pattern as auth.py's _save_auth_store().
2) Incremental checkpoints after each batch result instead of only at end,
   so interrupted runs can resume with minimal progress loss.
3) Resume loads existing checkpoint state instead of initializing empty,
   preventing clobber of prior progress.

Conflict resolved: kept both the incremental checkpoint logic (PR) and
the batch worker error handling (HEAD) in the imap_unordered loop.
2026-03-06 05:16:31 -08:00
teknium1
4a63737227 Merge PR #433: fix(whatsapp): replace Linux-only fuser with cross-platform port cleanup
Authored by Farukest. Fixes #432. Extracts _kill_port_process() helper
that uses netstat+taskkill on Windows and fuser on Linux. Previously,
fuser calls were inline with bare except-pass, so on Windows orphaned
bridge processes were never cleaned up — causing 'address already in use'
errors on reconnect. Includes 5 tests covering both platforms, port
matching edge cases, and exception suppression.
2026-03-06 04:52:25 -08:00
teknium1
3e93db16bd Merge PR #436: fix: use _max_tokens_param in max-iterations retry path
Authored by Farukest. Fixes #435. The retry summary in
_handle_max_iterations() hardcoded max_tokens instead of using
_max_tokens_param(), which returns max_completion_tokens for direct
OpenAI API (required by gpt-4o, o-series). The first attempt already
used _max_tokens_param correctly — only the retry path was wrong.
Includes 4 tests for _max_tokens_param provider detection.
2026-03-06 04:46:24 -08:00
teknium1
f863a42351 Merge PR #441: fix(gateway): return response from /retry handler instead of discarding it
Authored by PercyDikec. Fixes #440. _handle_retry_command called
_handle_message(retry_event) but discarded the return value, returning
None instead. Since only _process_message_background sends the response
via adapter.send(), this meant the agent would run (tool progress was
visible) but the final answer was silently dropped on all platforms.
2026-03-06 04:42:54 -08:00
teknium1
dc55f493be fix: add missing re.DOTALL to DeepSeek V3.1 parser (same bug as V3)
The V3.1 parser had the same issue — .*? without re.DOTALL fails to
match multi-line JSON arguments. Found during review of PR #444.
2026-03-06 04:41:47 -08:00
teknium1
936fda3f9e Merge PR #444: fix: add missing re.DOTALL flag to DeepSeek V3 tool call parser
Authored by PercyDikec. Fixes #443. Without re.DOTALL, the regex .*
doesn't match newlines, so multi-line JSON arguments (the normal case)
silently fail to parse. Every other parser in the codebase that matches
across lines already uses re.DOTALL.
2026-03-06 04:39:53 -08:00
teknium1
ecb8148a9f Merge PR #448: fix(cli): use correct dict key for codex auth file path in status output
Authored by PercyDikec. Fixes #447. The status display used
codex_status.get('auth_file') but get_codex_auth_status() in auth.py
returns the path under 'auth_store' (line 1220). This one-char key
mismatch silently dropped the auth file path from 'hermes status'.
2026-03-06 04:34:46 -08:00
teknium1
2dbbedc05a docs: rebrand messaging — 'the self-improving AI agent'
- Lead with the learning loop: autonomous skill creation, skill
  self-improvement, memory nudges, FTS5 session search, Honcho
  dialectic user modeling
- 'Runs anywhere' angle: 6 backends, serverless persistence with
  Daytona/Modal, not tied to your laptop
- 'Built by model trainers' replaces 'model-agnostic'
- Updated README tagline, feature table, subtitle
- Updated docs landing page hero, description, key features
- Updated docusaurus tagline and pyproject.toml description
2026-03-06 04:34:06 -08:00
teknium1
c30967806c test: add 26 tests for set_config_value secret routing
Verifies explicit allowlist keys, catch-all _API_KEY/_TOKEN patterns,
case insensitivity, TERMINAL_SSH prefix, and config.yaml routing for
non-secret keys. Covers the fix from PR #469.
2026-03-06 04:26:18 -08:00
teknium1
145f719d30 Merge PR #469: fix(config): route API keys and tokens to .env instead of config.yaml
Authored by ygd58. Fixes #465. Adds missing keys to allowlist and
catch-all patterns (_API_KEY, _TOKEN suffixes) for future-proofing.
2026-03-06 04:23:49 -08:00
teknium1
b89eb29174 fix: correct mock tool name 'search' → 'search_files' in test_code_execution
The mock handler checked for function_name == 'search' but the RPC
sends 'search_files'. Any test exercising search_files through the
mock would get 'Unknown tool' instead of the canned response.
2026-03-06 03:53:43 -08:00
teknium1
3670089a42 docs: add Daytona to batch_runner, process_registry, agent_loop, tool_context
Add daytona_image to batch_runner per-prompt container image overrides
so batch processing works with the Daytona backend. Update inline
comments in RL environment files (agent_loop, tool_context) and
process_registry docstrings to include Daytona in backend lists.
2026-03-06 03:49:59 -08:00
teknium1
3982fcf095 fix: sync execute_code sandbox stubs with real tool schemas
The _TOOL_STUBS dict in code_execution_tool.py was out of sync with the
actual tool schemas, causing TypeErrors when the LLM used parameters it
sees in its system prompt but the sandbox stubs didn't accept:

search_files:
  - Added missing params: context, offset, output_mode
  - Fixed target default: 'grep' → 'content' (old value was obsolete)

patch:
  - Added missing params: mode, patch (V4A multi-file patch support)

Also added 4 drift-detection tests (TestStubSchemaDrift) that will
catch future divergence between stubs and real schemas:
  - test_stubs_cover_all_schema_params: every schema param in stub
  - test_stubs_pass_all_params_to_rpc: every stub param sent over RPC
  - test_search_files_target_uses_current_values: no obsolete values
  - test_generated_module_accepts_all_params: generated code compiles

All 28 tests pass.
2026-03-06 03:40:06 -08:00
teknium1
8481fdcf08 docs: complete Daytona backend documentation coverage
Update all remaining files that enumerate terminal backends to include
Daytona. Covers security docs (bypass info, backend comparison table),
environment variables reference (DAYTONA_API_KEY, TERMINAL_DAYTONA_IMAGE,
container resources header), AGENTS.md (architecture tree, config keys),
environments/README.md, hermes_base_env.py field description, and various
module docstrings.

Follow-up to PR #451 merge.
2026-03-06 03:37:05 -08:00
teknium1
39299e2de4 Merge PR #451: feat: Add Daytona environment backend
Authored by rovle. Adds Daytona as the sixth terminal execution backend
with cloud sandboxes, persistent workspaces, and full CLI/gateway integration.
Includes 24 unit tests and 8 integration tests.
2026-03-06 03:32:40 -08:00
teknium1
efec4fcaab feat(execute_code): add json_parse, shell_quote, retry helpers to sandbox
The execute_code sandbox generates a hermes_tools.py stub module for LLM
scripts. Three common failure modes keep tripping up scripts:

1. json.loads(strict=True) rejects control chars in terminal() output
   (e.g., GitHub issue bodies with literal tabs/newlines)
2. Shell backtick/quote interpretation when interpolating dynamic content
   into terminal() commands (markdown with backticks gets eaten by bash)
3. No retry logic for transient network failures (API timeouts, rate limits)

Adds three convenience helpers to the generated hermes_tools module:

- json_parse(text) — json.loads with strict=False for tolerant parsing
- shell_quote(s) — shlex.quote() for safe shell interpolation
- retry(fn, max_attempts=3, delay=2) — exponential backoff wrapper

Also updates the EXECUTE_CODE_SCHEMA description to document these helpers
so LLMs know they're available without importing anything extra.

Includes 7 new tests (unit + integration) covering all three helpers.
2026-03-06 01:52:46 -08:00
teknium1
5ce2c47d60 docs: update all docs for optional-skills and browse command
Update 7 documentation files to reflect:
- optional-skills/ directory in all project structure trees
- 'hermes skills browse' in all CLI command listings
- '/skills browse' in all slash command references
- Three-tier skill placement (bundled → optional → hub)
- 'official' trust level in trust level tables
- Updated /skills description from 'Search, install...' to 'Browse, search...'

Files updated:
- CONTRIBUTING.md (skill classification, project tree, section title)
- AGENTS.md (project tree, Skills Hub description, source adapters list)
- website/docs/reference/cli-commands.md (CLI table, slash command table)
- website/docs/developer-guide/creating-skills.md (structure, classification, trust)
- website/docs/user-guide/features/skills.md (hub commands, trust table, slash commands)
- website/docs/user-guide/cli.md (slash command description)
- website/docs/developer-guide/architecture.md (project tree)
2026-03-06 01:46:34 -08:00
teknium1
f6f3d1de9b fix: review fixes — path traversal guard, trust_style consistency, edge cases
Address code review findings:

Security (Medium):
- Path traversal guard in OptionalSkillSource.fetch() — resolve() and
  validate that the path stays within optional-skills/ before reading

Bug fixes (Medium):
- Add 'builtin' to trust_style dicts in do_inspect() and
  _resolve_short_name() — official skills now show bright_cyan 'official'
  label consistently across all display functions (5/5 dicts fixed)

Edge cases (Low):
- Clamp page_size to [1, 100] in do_browse() to prevent ZeroDivisionError
- Update SkillMeta.source docstring to include 'official'
- Add browse command to optional-skills/DESCRIPTION.md
2026-03-06 01:40:01 -08:00
teknium1
ec0fe3242a feat: 'hermes skills browse' — paginated browsing of all hub skills
Add a browse command that shows all available skills across all registries,
paginated and sorted with official skills first.

Usage:
  hermes skills browse                    # all sources, page 1
  hermes skills browse --source official  # only official optional skills
  hermes skills browse --page 2           # page 2
  hermes skills browse --size 30          # 30 per page
  /skills browse                          # slash command in chat

Features:
- Official optional skills always appear first (★ marker, cyan styling)
- Per-source limits prevent overloading (100 official/github, 50 others)
- Deduplication by name preferring higher trust
- Sorted: official > trusted > community, then alphabetical
- Page navigation hints at bottom
- Source counts summary
- Works in both CLI and /skills chat interface
- Added 'official' as source filter option for search command too
2026-03-06 01:29:45 -08:00
teknium1
f2e24faaca feat: optional skills — official skills shipped but not activated by default
Add 'optional-skills/' directory for official skills that ship with the repo
but are not copied to ~/.hermes/skills/ during setup. They are:
- NOT shown to the model in the system prompt
- NOT copied during hermes setup/update
- Discoverable via 'hermes skills search' labeled as 'official'
- Installable via 'hermes skills install' with builtin trust (no third-party warning)
- Auto-categorized on install based on directory structure

Implementation:
- OptionalSkillSource adapter in tools/skills_hub.py (search/fetch/inspect)
- Added to create_source_router() as first source (highest priority)
- Trust level 'builtin' for official skills in skills_guard.py
- Friendly install message for official skills (no third-party warning)
- 'official' label in cyan in search results and skill list

First optional skill: Blackbox CLI (autonomous-ai-agents/blackbox)
- Multi-model coding agent with built-in judge/Chairman pattern
- Delegates to Claude, Codex, Gemini, and Blackbox models
- Open-source CLI (GPL-3.0, TypeScript, forked from Gemini CLI)
- Requires paid Blackbox AI API key

Refs: #475
2026-03-06 01:24:11 -08:00
teknium1
8c80b96318 chore: update OpenRouter model list
- Remove opus-4.5 and gpt-5.2
- Reorder GPT: 5.4-pro, 5.4, 5.3-codex
- Add qwen/qwen3.5-plus-02-15 and qwen/qwen3.5-35b-a3b
- Update z-ai/glm-4.7 → glm-5
- Update minimax/minimax-m2.1 → minimax-m2.5
2026-03-06 00:52:45 -08:00
teknium1
2387465dcc chore: add openai/gpt-5.4-pro and stepfun/step-3.5-flash to OpenRouter models 2026-03-06 00:49:25 -08:00
tars90percent
32636ecf8a Update MiniMax model ID from m2.1 to m2.5 2026-03-06 16:47:48 +08:00
ygd58
6055adbe1b fix(config): route API keys and tokens to .env instead of config.yaml 2026-03-06 08:55:36 +01:00
teknium1
ffd2f8dc50 docs: add Vision & Image Paste guide with platform compatibility
New docs page covering clipboard image paste across all platforms:
- Platform compatibility table (macOS, Linux X11/Wayland, WSL2, VSCode, SSH)
- Setup instructions per platform (xclip, wl-paste, powershell.exe)
- Explanation of terminal paste limitations and why /paste exists
- SSH workarounds (file upload, URLs, X11 forwarding, messaging)
- Keybinding reference (Alt+V, Ctrl+V, /paste) with when each works

Also updates CLI commands reference with /paste command and
Alt+V keybinding documentation.
2026-03-05 23:51:46 -08:00
teknium1
e93b4d1dcd feat: Alt+V keybinding for clipboard image paste
Alt key combos pass through all terminal emulators (sent as ESC + key),
unlike Ctrl+V which terminals intercept for text paste. This is the
reliable way to attach clipboard images on WSL2, Windows Terminal,
VSCode, and SSH sessions where Ctrl+V never reaches the application
for image-only clipboard content.

Also adds 'Paste image: Alt+V (or /paste)' hint to /help output.
2026-03-05 22:48:39 -08:00
teknium1
014a5b712d fix: prevent duplicate gateway instances from running simultaneously
start_gateway() now checks for an existing running instance via PID file
before starting. If another gateway is already running under the same
HERMES_HOME, it refuses to start with a clear error message directing the
user to 'hermes gateway restart' or 'hermes gateway stop'.

Also fixes gateway/status.py to respect the HERMES_HOME env var instead of
hardcoding ~/.hermes. This scopes the PID file per HERMES_HOME directory,
which lays the groundwork for future multi-profile support where distinct
HERMES_HOME directories can run concurrent gateway instances independently.
2026-03-05 20:35:33 -08:00
teknium1
2317d115cd fix: clipboard image paste on WSL2, Wayland, and VSCode terminal
The original implementation only supported xclip (X11), which silently
fails on WSL2 (can't access Windows clipboard for images), Wayland
desktops (xclip is X11-only), and VSCode terminal on WSL2.

Clipboard backend changes (hermes_cli/clipboard.py):
- WSL2: detect via /proc/version, use powershell.exe with .NET
  System.Windows.Forms.Clipboard to extract images as base64 PNG
- Wayland: use wl-paste with MIME type detection, auto-convert BMP
  to PNG for WSLg environments (via Pillow or ImageMagick)
- Dispatch order: WSL → Wayland → X11 (xclip), with fallthrough
- New has_clipboard_image() for lightweight clipboard checks
- Cache WSL detection result per-process

CLI changes (cli.py):
- /paste command: explicit clipboard image check for terminals where
  BracketedPaste doesn't fire (image-only clipboard in VSCode/WinTerm)
- Ctrl+V keybinding: fallback for Linux terminals where Ctrl+V sends
  raw byte instead of triggering bracketed paste

Tests: 80 tests (up from 37) covering WSL, Wayland, X11 dispatch,
BMP conversion, has_clipboard_image, and /paste command.
2026-03-05 20:22:44 -08:00
teknium1
8253b54be9 test: strengthen assertions in skill_manager + memory_tool (batch 3)
test_skill_manager_tool.py (20 weak → 0):
  - Validation error messages verified against exact strings
  - Name validation: checks specific invalid name echoed in error
  - Frontmatter validation: exact error text for missing fields,
    unclosed markers, empty content, invalid YAML
  - File path validation: traversal, disallowed dirs, root-level

test_memory_tool.py (13 weak → 0):
  - Security scan tests verify both 'Blocked' prefix AND specific
    threat pattern ID (prompt_injection, exfil_curl, etc.)
  - Invisible unicode tests verify exact codepoint strings
  - Snapshot test verifies type, header, content, and isolation
2026-03-05 18:51:43 -08:00
teknium1
5c867fd79f test: strengthen assertions across 3 more test files (batch 2)
test_run_agent.py (2 weak → 0, +13 assertions):
  - Session ID validated against actual YYYYMMDD_HHMMSS_hex format
  - API failure verifies error message propagation
  - Invalid JSON args verifies empty dict fallback + message structure
  - Context compression verifies final_response + completed flag
  - Invalid tool name retry verifies api_calls count
  - Invalid response verifies completed/failed/error structure

test_model_tools.py (3 weak → 0):
  - Unknown tool error includes tool name in message
  - Exception returns dict with 'error' key + non-empty message
  - get_all_tool_names verifies both web_search AND terminal present

test_approval.py (1 weak → 0, assert ratio 1.1 → 2.2):
  - Dangerous commands verify description content (delete, shell, drop, etc.)
  - Safe commands explicitly assert key AND desc are None
  - Pre/post condition checks for state management
2026-03-05 18:46:30 -08:00
teknium1
a44e041acf test: strengthen assertions across 7 test files (batch 1)
Replaced weak 'is not None' / '> 0' / 'len >= 1' assertions with
concrete value checks across the most flagged test files:

gateway/test_pairing.py (11 weak → 0):
  - Code assertions verify isinstance + len == CODE_LENGTH
  - Approval results verify dict structure + specific user_id/user_name
  - Added code2 != code1 check in rate_limit_expires

test_hermes_state.py (6 weak → 0):
  - ended_at verified as float timestamp
  - Search result counts exact (== 2, not >= 1)
  - Context verified as non-empty list
  - Export verified as dict, session ID verified

test_cli_init.py (4 weak → 0):
  - max_turns asserts exact value (60)
  - model asserts string with provider/name format

gateway/test_hooks.py (2 zero-assert tests → fixed):
  - test_no_handlers_for_event: verifies no handler registered
  - test_handler_error_does_not_propagate: verifies handler count + return

gateway/test_platform_base.py (9 weak image tests → fixed):
  - extract_images tests now verify actual URL and alt_text
  - truncate_message verifies content preservation after splitting

cron/test_scheduler.py (1 weak → 0):
  - resolve_origin verifies dict equality, not just existence

cron/test_jobs.py (2 weak → 0 + 4 new tests):
  - Schedule parsing verifies ISO timestamp type
  - Cron expression verifies result is valid datetime string
  - NEW: 4 tests for update_job() (was completely untested)
2026-03-05 18:39:37 -08:00
teknium1
e9f05b3524 test: comprehensive tests for model metadata + firecrawl config
model_metadata tests (61 tests, was 39):
  - Token estimation: concrete value assertions, unicode, tool_call messages,
    vision multimodal content, additive verification
  - Context length resolution: cache-over-API priority, no-base_url skips cache,
    missing context_length key in API response
  - API metadata fetch: canonical_slug aliasing, TTL expiry with time mock,
    stale cache fallback on API failure, malformed JSON resilience
  - Probe tiers: above-max returns 2M, zero returns None
  - Error parsing: Anthropic format ('X > Y maximum'), LM Studio, empty string,
    unreasonably large numbers — also fixed parser to handle Anthropic format
  - Cache: corruption resilience (garbage YAML, wrong structure), value updates,
    special chars in model names

Firecrawl config tests (8 tests, was 4):
  - Singleton caching (core purpose — verified constructor called once)
  - Constructor failure recovery (retry after exception)
  - Return value actually asserted (not just constructor args)
  - Empty string env vars treated as absent
  - Proper setup/teardown for env var isolation
2026-03-05 18:22:39 -08:00
teknium1
e2a834578d refactor: extract clipboard methods + comprehensive tests (37 tests)
Refactored image paste internals for testability:
- Extracted _try_attach_clipboard_image() method (clipboard → state)
- Extracted _build_multimodal_content() method (images → OpenAI format)
- chat() now delegates to these instead of inline logic

Tests organized in 4 levels:
  Level 1 (19 tests): Clipboard module — every platform path with
    realistic subprocess simulation (tools writing files, timeouts,
    empty files, cleanup on failure)
  Level 2 (8 tests): _build_multimodal_content — base64 encoding,
    MIME types (png/jpg/webp/unknown), missing files, multiple images,
    default question for empty text
  Level 3 (5 tests): _try_attach_clipboard_image — state management,
    counter increment/rollback, naming convention, mixed success/failure
  Level 4 (5 tests): Queue routing — tuple unpacking, command detection,
    images-only payloads, text-only payloads
2026-03-05 18:07:53 -08:00
teknium1
ffc752a79e test: improve clipboard tests with realistic scenarios and multimodal coverage
Rewrote clipboard tests from 11 shallow mocks to 21 realistic tests:
- Success paths now simulate tools actually writing files (not pre-created)
- osascript: success with PNG, success with TIFF, extraction-fail cases
- pngpaste: empty file rejection edge case
- Linux: extraction failure cleanup verification
- New TestMultimodalConversion class: base64 encoding, MIME types,
  multiple images, missing file handling, default question fallback
2026-03-05 17:58:06 -08:00
teknium1
399562a7d1 feat: clipboard image paste in CLI (Cmd+V / Ctrl+V)
Copy an image to clipboard (screenshot, browser, etc.) and paste into
the Hermes CLI. The image is saved to ~/.hermes/images/, shown as a
badge above the input ([📎 Image #1]), and sent to the model as a
base64-encoded OpenAI vision multimodal content block.

Implementation:
- hermes_cli/clipboard.py: clean module with platform-specific extraction
  - macOS: pngpaste (if installed) → osascript fallback (always available)
  - Linux: xclip (apt install xclip)
- cli.py: BracketedPaste key handler checks clipboard on every paste,
  image bar widget shows attached images, chat() converts to multimodal
  content format, Ctrl+C clears attachments

Inspired by @m0at's fork (https://github.com/m0at/hermes-agent) which
implemented image paste support for local vision models. Reimplemented
cleanly as a separate module with tests.
2026-03-05 17:55:41 -08:00
teknium1
fec8a0da72 Merge PR #296: fix(cron): close lock_fd on failed flock to prevent fd leak
Authored by alireza78a. When flock() raises on a concurrent tick, the
file descriptor was leaked because the except clause returned without
closing it. Adds lock_fd=None init and close in the except path.
2026-03-05 17:05:06 -08:00
teknium1
9f4542b3db fix: require Python 3.11+ in pyproject.toml
Was incorrectly set to >=3.10. Hermes uses tomllib and other 3.11+
features. CONTRIBUTING.md and README already say 3.11+.
2026-03-05 17:04:08 -08:00
teknium1
363633e2ba fix: allow self-hosted Firecrawl without API key + add self-hosting docs
On top of PR #460: self-hosted Firecrawl instances don't require an API
key (USE_DB_AUTHENTICATION=false), so don't force users to set a dummy
FIRECRAWL_API_KEY when FIRECRAWL_API_URL is set. Also adds a proper
self-hosting section to the configuration docs explaining what you get,
what you lose, and how to set it up (Docker stack, tradeoffs vs cloud).

Added 2 more tests (URL-only without key, neither-set raises).
2026-03-05 16:44:21 -08:00
teknium1
a41ba57a7a Merge PR #460: feat(tools): add support for self-hosted firecrawl
Authored by caentzminger. Adds optional FIRECRAWL_API_URL env var to point
the Firecrawl client at a self-hosted instance instead of the cloud API.
2026-03-05 16:41:30 -08:00
teknium1
884c8ea70a chore: add openai/gpt-5.4 to OpenRouter preferred models list 2026-03-05 16:13:45 -08:00
teknium1
c886333d32 feat: smart context length probing with persistent caching + banner display
Replaces the unsafe 128K fallback for unknown models with a descending
probe strategy (2M → 1M → 512K → 200K → 128K → 64K → 32K). When a
context-length error occurs, the agent steps down tiers and retries.
The discovered limit is cached per model+provider combo in
~/.hermes/context_length_cache.yaml so subsequent sessions skip probing.

Also parses API error messages to extract the actual context limit
(e.g. 'maximum context length is 32768 tokens') for instant resolution.

The CLI banner now displays the context window size next to the model
name (e.g. 'claude-opus-4 · 200K context · Nous Research').

Changes:
- agent/model_metadata.py: CONTEXT_PROBE_TIERS, persistent cache
  (save/load/get), parse_context_limit_from_error(), get_next_probe_tier()
- agent/context_compressor.py: accepts base_url, passes to metadata
- run_agent.py: step-down logic in context error handler, caches on success
- cli.py + hermes_cli/banner.py: context length in welcome banner
- tests: 22 new tests for probing, parsing, and caching

Addresses #132. PR #319's approach (8K default) rejected — too conservative.
2026-03-05 16:09:57 -08:00
teknium1
55b173dd03 refactor: move shutil import to module level
Cleanup on top of PR #305 — replace two inline 'import shutil as _shutil'
with a single module-level import.
2026-03-05 15:57:05 -08:00
dmahan93
9079a27814 fix: prompt box and response box span full terminal width on wide screens
- Replace hardcoded '─' * 200 horizontal rules with Window(char='─')
  so prompt_toolkit fills the entire terminal width automatically
- Use shutil.get_terminal_size().columns instead of Rich Console.width
  for response box, separator line, and input height calculation
  (more reliable inside patch_stdout context)
2026-03-05 15:57:05 -08:00
caentzminger
d7d10b14cd feat(tools): add support for self-hosted firecrawl
Adds optional FIRECRAWL_API_URL environment variable to support
self-hosted Firecrawl deployments alongside the cloud service.

- Add FIRECRAWL_API_URL to optional env vars in hermes_cli/config.py
- Update _get_firecrawl_client() in tools/web_tools.py to accept custom API URL
- Add tests for client initialization with/without URL
- Document new env var in installation and config guides
2026-03-05 16:16:18 -06:00
rovle
a6499b6107 fix(daytona): use shell timeout wrapper instead of broken SDK exec timeout
The Daytona SDK's process.exec(timeout=N) parameter is not enforced —
the server-side timeout never fires and the SDK has no client-side
fallback, causing commands to hang indefinitely.

Fix: wrap commands with timeout N sh -c '...' (coreutils) which
reliably kills the process and returns exit code 124. Added
shlex.quote for proper shell escaping and a secondary deadline (timeout + 10s) that force-stops the sandbox if the shell timeout somehow fails.

Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 13:12:41 -08:00
rovle
74a36b0729 docs: add Daytona to backend lists in docs
Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 11:55:41 -08:00
rovle
efc7a7b957 fix(daytona): don't guess /root on cwd probe failure, keep constructor default; update tests to reflect this
Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 11:49:35 -08:00
rovle
4f1464b3af fix(daytona): default disk to 10GB to match platform limit
Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 11:37:30 -08:00
rovle
3a41079fac fix(daytona): add optional dependency group to pyproject.toml
Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 11:13:12 -08:00
rovle
5279540bb4 fix(daytona): add missing config mappings in gateway, CLI defaults, and config display
Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 11:12:50 -08:00
rovle
577da79a47 fix(daytona): make disk cap visible and use SDK enum for sandbox
state

- Replace logger.warning with warnings.warn for the disk cap so users
  actually see it (logger was suppressed by CLI's log level config)
- Use SandboxState enum instead of string literals in
_ensure_sandbox_ready

Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 11:03:39 -08:00
rovle
1faa9648d3 chore(daytona): cap the disk size to current maximum on daytona sandboxes
Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 10:43:41 -08:00
PercyDikec
ad57bf1e4b fix(cli): use correct dict key for codex auth file path in status output 2026-03-05 21:27:12 +03:00
rovle
d5efb82c7c test(daytona): add unit and integration tests for Daytona backend
Unit tests cover cwd resolution, sandbox persistence/resume, cleanup,
command execution, resource conversion, interrupt handling, retry
exhaustion, and sandbox readiness checks. Integration tests verify
basic commands, filesystem ops, session persistence, and task
isolation against a live Daytona API.

Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 10:26:22 -08:00
rovle
ea2f7ef2f6 docs(config): add Daytona disk limit hint and fix default cwd in example
Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 10:02:22 -08:00
rovle
435530018b fix(daytona): resolve cwd by detecting home directory inside the sandbox 2026-03-05 10:02:22 -08:00
rovle
df61054a84 feat(cli): add Daytona to setup wizard, doctor, and status display
Add Daytona as a backend choice in the interactive setup wizard with
SDK installation and API key prompts. Show Daytona image in status
output and validate API key + SDK in doctor checks. Add OPTION 6
example in cli-config.yaml.example.

Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 10:02:22 -08:00
rovle
690b8bb563 feat(cli): add Daytona config mapping and env var sync
Wire TERMINAL_DAYTONA_IMAGE through cli.py env_mappings and
hermes_cli/config.py so `hermes config set` propagates correctly.
2026-03-05 10:02:21 -08:00
rovle
c43451a50b feat(terminal): integrate Daytona backend into tool pipeline
Add Daytona to image selection, container_config guards, environment
factory, requirements check, and diagnostics in terminal_tool.py and
file_tools.py. Also add to sandboxed-backend approval bypass.

Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 10:02:21 -08:00
rovle
1e312c6582 feat(environments): add Daytona cloud sandbox backend
New execution backend using the Daytona Python SDK. Supports persistent
sandboxes via stop/start lifecycle, interrupt handling, and automatic
retry on transient errors.

Signed-off-by: rovle <lovre.pesut@gmail.com>
2026-03-05 10:02:21 -08:00
PercyDikec
e36c8cd49a fix: add missing re.DOTALL flag to DeepSeek V3 tool call parser 2026-03-05 20:32:38 +03:00
PercyDikec
16cb6d1a6e fix(gateway): return response from /retry handler instead of discarding it 2026-03-05 19:59:54 +03:00
Farukest
e25ad79d5d fix: use _max_tokens_param in max-iterations retry path
The retry summary in _handle_max_iterations hardcodes max_tokens instead
of calling _max_tokens_param(). For direct OpenAI API users (gpt-4o,
o-series), the correct parameter name is max_completion_tokens. The first
attempt at line 2697 already uses _max_tokens_param correctly but the
retry path at line 2743 was missed.
2026-03-05 17:49:37 +03:00
Farukest
82cb1752d9 fix(whatsapp): replace Linux-only fuser with cross-platform port cleanup
fuser command does not exist on Windows, causing orphaned bridge processes
to never be cleaned up. On crash recovery, the port stays occupied and the
next connect() fails with address-already-in-use.

Add _kill_port_process() helper that uses netstat+taskkill on Windows and
fuser on Linux/macOS. Replace both call sites in connect() and disconnect().
2026-03-05 17:13:14 +03:00
Dev User
3221818b6e fix: respect OPENAI_BASE_URL when resolving API key priority
When base_url points to a non-OpenRouter endpoint (e.g. Z.ai),
OPENROUTER_API_KEY incorrectly takes priority over OPENAI_API_KEY,
sending the wrong credentials. This causes 401 errors on the main
inference path and forces users to comment out OPENROUTER_API_KEY,
which then breaks auxiliary clients (compression, vision).

Fix: check whether base_url contains "openrouter" and swap the key
priority accordingly. Also adds GLM-4.7 and GLM-5 context lengths
to DEFAULT_CONTEXT_LENGTHS.
2026-03-05 08:25:16 +00:00
areu01or00
a1c25046a9 fix(timezone): add timezone-aware clock across agent, cron, and execute_code 2026-03-03 18:23:40 +05:30
BathreeNode
d10108f8ca fix: rename misspelled directory 'fouth-edition' to 'fourth-edition'
The ECMA schema directory was misspelled as 'fouth-edition'
instead of 'fourth-edition'. Renamed all 4 files within to
correct the path:

- opc-contentTypes.xsd
- opc-coreProperties.xsd
- opc-digSig.xsd
- opc-relationships.xsd
2026-03-03 09:21:28 +03:00
BathreeNode
8b520f9848 fix: rename misspelled directory 'fouth-edition' to 'fourth-edition'
The ECMA schema directory was misspelled as 'fouth-edition'
instead of 'fourth-edition'. Renamed all 4 files within to
correct the path:

- opc-contentTypes.xsd
- opc-coreProperties.xsd
- opc-digSig.xsd
- opc-relationships.xsd
2026-03-03 09:20:47 +03:00
BathreeNode
a718aed1be fix: rename misspelled directory 'fouth-edition' to 'fourth-edition'
The ECMA schema directory was misspelled as 'fouth-edition'
instead of 'fourth-edition'. Renamed all 4 files within to
correct the path:

- opc-contentTypes.xsd
- opc-coreProperties.xsd
- opc-digSig.xsd
- opc-relationships.xsd
2026-03-03 09:20:07 +03:00
BathreeNode
5f29e7b63c fix: rename misspelled directory 'fouth-edition' to 'fourth-edition'
The ECMA schema directory was misspelled as 'fouth-edition'
instead of 'fourth-edition'. Renamed all 4 files within to
correct the path:

- opc-contentTypes.xsd
- opc-coreProperties.xsd
- opc-digSig.xsd
- opc-relationships.xsd
2026-03-03 09:17:13 +03:00
aydnOktay
5fa3e24b76 Make process_registry checkpoint writes atomic 2026-03-03 02:44:01 +03:00
aydnOktay
ac6d747fa6 Make batch_runner checkpoint incremental and atomic 2026-03-03 01:43:07 +03:00
alireza78a
ee541c84f1 fix(cron): close lock_fd on failed flock to prevent fd leak 2026-03-03 02:09:56 +03:30
202 changed files with 41658 additions and 2339 deletions

View File

@@ -13,6 +13,34 @@ OPENROUTER_API_KEY=
# Examples: anthropic/claude-opus-4.6, openai/gpt-4o, google/gemini-3-flash-preview, zhipuai/glm-4-plus
LLM_MODEL=anthropic/claude-opus-4.6
# =============================================================================
# LLM PROVIDER (z.ai / GLM)
# =============================================================================
# z.ai provides access to ZhipuAI GLM models (GLM-4-Plus, etc.)
# Get your key at: https://z.ai or https://open.bigmodel.cn
GLM_API_KEY=
# GLM_BASE_URL=https://api.z.ai/api/paas/v4 # Override default base URL
# =============================================================================
# LLM PROVIDER (Kimi / Moonshot)
# =============================================================================
# Kimi/Moonshot provides access to Moonshot AI coding models
# Get your key at: https://platform.moonshot.ai
KIMI_API_KEY=
# KIMI_BASE_URL=https://api.moonshot.ai/v1 # Override default base URL
# =============================================================================
# LLM PROVIDER (MiniMax)
# =============================================================================
# MiniMax provides access to MiniMax models (global endpoint)
# Get your key at: https://www.minimax.io
MINIMAX_API_KEY=
# MINIMAX_BASE_URL=https://api.minimax.io/v1 # Override default base URL
# MiniMax China endpoint (for users in mainland China)
MINIMAX_CN_API_KEY=
# MINIMAX_CN_BASE_URL=https://api.minimaxi.com/v1 # Override default base URL
# =============================================================================
# TOOL API KEYS
# =============================================================================

View File

@@ -44,7 +44,8 @@ hermes-agent/
│ │ ├── docker.py # Docker container execution
│ │ ├── ssh.py # SSH remote execution
│ │ ├── singularity.py # Singularity/Apptainer + SIF management
│ │ ── modal.py # Modal cloud execution
│ │ ── modal.py # Modal cloud execution
│ │ └── daytona.py # Daytona cloud sandboxes
│ ├── terminal_tool.py # Terminal orchestration (sudo, lifecycle, factory)
│ ├── todo_tool.py # Planning & task management
│ ├── process_registry.py # Background process management
@@ -55,6 +56,7 @@ hermes-agent/
├── cron/ # Scheduler implementation
├── environments/ # RL training environments (Atropos integration)
├── skills/ # Bundled skill sources
├── optional-skills/ # Official optional skills (not activated by default)
├── cli.py # Interactive CLI orchestrator (HermesCLI class)
├── run_agent.py # AIAgent class (core conversation loop)
├── model_tools.py # Tool orchestration (thin layer over tools/registry.py)
@@ -202,7 +204,7 @@ Every installed skill in `~/.hermes/skills/` is automatically registered as a sl
The skill name (from frontmatter or folder name) becomes the command: `axolotl``/axolotl`.
Implementation (`agent/skill_commands.py`, shared between CLI and gateway):
1. `scan_skill_commands()` scans all SKILL.md files at startup
1. `scan_skill_commands()` scans all SKILL.md files at startup, filtering out skills incompatible with the current OS platform (via the `platforms` frontmatter field)
2. `build_skill_invocation_message()` loads the SKILL.md content and builds a user-turn message
3. The message includes the full skill content, a list of supporting files (not loaded), and the user's instruction
4. Supporting files can be loaded on demand via the `skill_view` tool
@@ -421,16 +423,19 @@ The system uses `_config_version` to detect outdated configs:
API keys are loaded from `~/.hermes/.env`:
- `OPENROUTER_API_KEY` - Main LLM API access (primary provider)
- `FIRECRAWL_API_KEY` - Web search/extract tools
- `FIRECRAWL_API_URL` - Self-hosted Firecrawl endpoint (optional)
- `BROWSERBASE_API_KEY` / `BROWSERBASE_PROJECT_ID` - Browser automation
- `FAL_KEY` - Image generation (FLUX model)
- `NOUS_API_KEY` - Vision and Mixture-of-Agents tools
Terminal tool configuration (in `~/.hermes/config.yaml`):
- `terminal.backend` - Backend: local, docker, singularity, modal, or ssh
- `terminal.backend` - Backend: local, docker, singularity, modal, daytona, or ssh
- `terminal.cwd` - Working directory ("." = host CWD for local only; for remote backends set an absolute path inside the target, or omit to use the backend's default)
- `terminal.docker_image` - Image for Docker backend
- `terminal.singularity_image` - Image for Singularity backend
- `terminal.modal_image` - Image for Modal backend
- `terminal.daytona_image` - Image for Daytona backend
- `DAYTONA_API_KEY` - API key for Daytona backend (in .env)
- SSH: `TERMINAL_SSH_HOST`, `TERMINAL_SSH_USER`, `TERMINAL_SSH_KEY` in .env
Agent behavior (in `~/.hermes/.env`):
@@ -494,7 +499,7 @@ terminal(command="pytest -v tests/", background=true)
- `process(action="submit", session_id="proc_abc123", data="yes")` -- send + Enter
**Key behaviors:**
- Background processes execute through the configured terminal backend (local/Docker/Modal/SSH/Singularity) -- never directly on the host unless `TERMINAL_ENV=local`
- Background processes execute through the configured terminal backend (local/Docker/Modal/Daytona/SSH/Singularity) -- never directly on the host unless `TERMINAL_ENV=local`
- The `wait` action blocks the tool call until the process finishes, times out, or is interrupted by a new user message
- PTY mode (`pty=true` on terminal) enables interactive CLI tools (Codex, Claude Code)
- In RL training, background processes are auto-killed when the episode ends (`tool_context.cleanup()`)
@@ -652,6 +657,7 @@ SKILL.md files use YAML frontmatter (agentskills.io format):
name: skill-name
description: Brief description for listing
version: 1.0.0
platforms: [macos] # Optional — restrict to specific OS (macos/linux/windows)
metadata:
hermes:
tags: [tag1, tag2]
@@ -660,12 +666,14 @@ metadata:
# Skill Content...
```
**Skills Hub** — user-driven skill search/install from online registries (GitHub, ClawHub, Claude marketplaces, LobeHub). Not exposed as an agent tool — the model cannot search for or install skills. Users manage skills via `hermes skills ...` CLI commands or the `/skills` slash command in chat.
**Platform filtering** — Skills with a `platforms` field are automatically excluded from the system prompt index, `skills_list()`, and slash commands on incompatible platforms. Skills without the field load everywhere (backward compatible). See `skills/apple/` for macOS-only examples (iMessage, Reminders, Notes, FindMy).
**Skills Hub** — user-driven skill search/install from online registries and official optional skills. Sources: official optional skills (shipped with repo, labeled "official"), GitHub (openai/skills, anthropics/skills, custom taps), ClawHub, Claude marketplace, LobeHub. Not exposed as an agent tool — the model cannot search for or install skills. Users manage skills via `hermes skills browse/search/install` CLI commands or the `/skills` slash command in chat.
Key files:
- `tools/skills_tool.py` — Agent-facing skill list/view (progressive disclosure)
- `tools/skills_guard.py` — Security scanner (regex + LLM audit, trust-aware install policy)
- `tools/skills_hub.py` — Source adapters (GitHub, ClawHub, Claude marketplace, LobeHub), lock file, auth
- `tools/skills_hub.py` — Source adapters (OptionalSkillSource, GitHub, ClawHub, Claude marketplace, LobeHub), lock file, auth
- `hermes_cli/skills_hub.py` — CLI subcommands + `/skills` slash command handler
---

View File

@@ -43,7 +43,9 @@ Bundled skills (in `skills/`) ship with every Hermes install. They should be **b
- Document handling, web research, common dev workflows, system administration
- Used regularly by a wide range of people
If your skill is specialized (a niche engineering tool, a specific SaaS integration, a game), it's better suited for a **Skills Hub**upload it to a skills registry and share it in the [Nous Research Discord](https://discord.gg/NousResearch). Users can install it with `hermes skills install`.
If your skill is official and useful but not universally needed (e.g., a paid service integration, a heavyweight dependency), put it in **`optional-skills/`** — it ships with the repo but isn't activated by default. Users can discover it via `hermes skills browse` (labeled "official") and install it with `hermes skills install` (no third-party warning, builtin trust).
If your skill is specialized, community-contributed, or niche, it's better suited for a **Skills Hub** — upload it to a skills registry and share it in the [Nous Research Discord](https://discord.gg/NousResearch). Users can install it with `hermes skills install`.
---
@@ -153,7 +155,7 @@ hermes-agent/
│ ├── skill_tools.py # Skill search, load, manage
│ └── environments/ # Terminal execution backends
│ ├── base.py # BaseEnvironment ABC
│ ├── local.py, docker.py, ssh.py, singularity.py, modal.py
│ ├── local.py, docker.py, ssh.py, singularity.py, modal.py, daytona.py
├── gateway/ # Messaging gateway
│ ├── run.py # GatewayRunner — platform lifecycle, message routing, cron
@@ -168,6 +170,7 @@ hermes-agent/
│ └── whatsapp-bridge/ # Node.js WhatsApp bridge (Baileys)
├── skills/ # Bundled skills (copied to ~/.hermes/skills/ on install)
├── optional-skills/ # Official optional skills (discoverable via hub, not activated by default)
├── environments/ # RL training environments (Atropos integration)
├── tests/ # Test suite
├── website/ # Documentation site (hermes-agent.nousresearch.com)
@@ -294,9 +297,9 @@ If it's a new toolset, add it to `toolsets.py` and to the relevant platform pres
---
## Adding a Bundled Skill
## Adding a Skill
Bundled skills live in `skills/` organized by category:
Bundled skills live in `skills/` organized by category. Official optional skills use the same structure in `optional-skills/`:
```
skills/
@@ -322,6 +325,9 @@ description: Brief description (shown in skill search results)
version: 1.0.0
author: Your Name
license: MIT
platforms: [macos, linux] # Optional — restrict to specific OS platforms
# Valid: macos, linux, windows
# Omit to load on all platforms (default)
metadata:
hermes:
tags: [Category, Subcategory, Keywords]
@@ -348,6 +354,18 @@ Known failure modes and how to handle them.
How the agent confirms it worked.
```
### Platform-specific skills
Skills can declare which OS platforms they support via the `platforms` frontmatter field. Skills with this field are automatically hidden from the system prompt, `skills_list()`, and slash commands on incompatible platforms.
```yaml
platforms: [macos] # macOS only (e.g., iMessage, Apple Reminders)
platforms: [macos, linux] # macOS and Linux
platforms: [windows] # Windows only
```
If the field is omitted or empty, the skill loads on all platforms (backward compatible). See `skills/apple/` for examples of macOS-only skills.
### Skill guidelines
- **No external dependencies unless absolutely necessary.** Prefer stdlib Python, curl, and existing Hermes tools (`web_extract`, `terminal`, `read_file`).

View File

@@ -11,17 +11,17 @@
<a href="https://nousresearch.com"><img src="https://img.shields.io/badge/Built%20by-Nous%20Research-blueviolet?style=for-the-badge" alt="Built by Nous Research"></a>
</p>
**The fully open-source AI agent that grows with you.** Install it on a machine, give it your messaging accounts, and it becomes a persistent personal agent — learning your projects, building its own skills, running tasks on a schedule, and reaching you wherever you are.
**The self-improving AI agent built by [Nous Research](https://nousresearch.com).** It's the only agent with a built-in learning loop — it creates skills from experience, improves them during use, nudges itself to persist knowledge, searches its own past conversations, and builds a deepening model of who you are across sessions. Run it on a $5 VPS, a GPU cluster, or serverless infrastructure that costs nearly nothing when idle. It's not tied to your laptop — talk to it from Telegram while it works on a cloud VM.
Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai), OpenAI Codex, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.
Use any model you want — [Nous Portal](https://portal.nousresearch.com), [OpenRouter](https://openrouter.ai) (200+ models), [z.ai/GLM](https://z.ai), [Kimi/Moonshot](https://platform.moonshot.ai), [MiniMax](https://www.minimax.io), OpenAI, or your own endpoint. Switch with `hermes model` — no code changes, no lock-in.
<table>
<tr><td><b>A real terminal interface</b></td><td>Full TUI with multiline editing, slash-command autocomplete, conversation history, interrupt-and-redirect, and streaming tool output.</td></tr>
<tr><td><b>Lives where you do</b></td><td>Telegram, Discord, Slack, WhatsApp, and CLI — all from a single gateway process. Voice memo transcription, cross-platform conversation continuity.</td></tr>
<tr><td><b>Grows the longer it runs</b></td><td>Persistent memory across sessions. When it solves a hard problem, it writes a skill document for next time. Skills are searchable, shareable, and compatible with the <a href="https://agentskills.io">agentskills.io</a> open standard.</td></tr>
<tr><td><b>A closed learning loop</b></td><td>Agent-curated memory with periodic nudges. Autonomous skill creation after complex tasks. Skills self-improve during use. FTS5 session search with LLM summarization for cross-session recall. <a href="https://github.com/plastic-labs/honcho">Honcho</a> dialectic user modeling. Compatible with the <a href="https://agentskills.io">agentskills.io</a> open standard.</td></tr>
<tr><td><b>Scheduled automations</b></td><td>Built-in cron scheduler with delivery to any platform. Daily reports, nightly backups, weekly audits — all in natural language, running unattended.</td></tr>
<tr><td><b>Delegates and parallelizes</b></td><td>Spawn isolated subagents for parallel workstreams. Write Python scripts that call tools via RPC, collapsing multi-step pipelines into zero-context-cost turns.</td></tr>
<tr><td><b>Real sandboxing</b></td><td>Five terminal backends — local, Docker, SSH, Singularity, and Modal — with persistent workspaces and container security hardening.</td></tr>
<tr><td><b>Runs anywhere, not just your laptop</b></td><td>Six terminal backends — local, Docker, SSH, Daytona, Singularity, and Modal. Daytona and Modal offer serverless persistence — your agent's environment hibernates when idle and wakes on demand, costing nearly nothing between sessions. Run it on a $5 VPS or a GPU cluster.</td></tr>
<tr><td><b>Research-ready</b></td><td>Batch trajectory generation, Atropos RL environments, trajectory compression for training the next generation of tool-calling models.</td></tr>
</table>

129
TODO.md
View File

@@ -1,129 +0,0 @@
# Hermes Agent - Future Improvements
---
## 3. Local Browser Control via CDP 🌐
**Status:** Not started (currently Browserbase cloud only)
**Priority:** Medium
Support local Chrome/Chromium via Chrome DevTools Protocol alongside existing Browserbase cloud backend.
**What other agents do:**
- **OpenClaw**: Full CDP-based Chrome control with snapshots, actions, uploads, profiles, file chooser, PDF save, console messages, tab management. Uses local Chrome for persistent login sessions.
- **Cline**: Headless browser with Computer Use (click, type, scroll, screenshot, console logs)
**Our approach:**
- Add a `local` backend option to `browser_tool.py` using Playwright or raw CDP
- Config toggle: `browser.backend: local | browserbase | auto`
- `auto` mode: try local first, fall back to Browserbase
- Local advantages: free, persistent login sessions, no API key needed
- Local disadvantages: no CAPTCHA solving, no stealth mode, requires Chrome installed
- Reuse the same 10-tool interface -- just swap the backend
- Later: Chrome profile management for persistent sessions across restarts
---
## 4. Signal Integration 📡
**Status:** Not started
**Priority:** Low
New platform adapter using signal-cli daemon (JSON-RPC HTTP + SSE). Requires Java runtime and phone number registration.
**Reference:** OpenClaw has Signal support via signal-cli.
---
## 5. Plugin/Extension System 🔌
**Status:** Partially implemented (event hooks exist in `gateway/hooks.py`)
**Priority:** Medium
Full Python plugin interface that goes beyond the current hook system.
**What other agents do:**
- **OpenClaw**: Plugin SDK with tool-send capabilities, lifecycle phase hooks (before-agent-start, after-tool-call, model-override), plugin registry with install/uninstall.
- **Pi**: Extensions are TypeScript modules that can register tools, commands, keyboard shortcuts, custom UI widgets, overlays, status lines, dialogs, compaction hooks, raw terminal input listeners. Extremely comprehensive.
- **OpenCode**: MCP client support (stdio, SSE, StreamableHTTP), OAuth auth for MCP servers. Also has Copilot/Codex plugins.
- **Codex**: Full MCP integration with skill dependencies.
- **Cline**: MCP integration + lifecycle hooks with cancellation support.
**Our approach (phased):**
### Phase 1: Enhanced hooks
- Expand the existing `gateway/hooks.py` to support more events: `before-tool-call`, `after-tool-call`, `before-response`, `context-compress`, `session-end`
- Allow hooks to modify tool results (e.g., filter sensitive output)
### Phase 2: Plugin interface
- `~/.hermes/plugins/<name>/plugin.yaml` + `handler.py`
- Plugins can: register new tools, add CLI commands, subscribe to events, inject system prompt sections
- `hermes plugin list|install|uninstall|create` CLI commands
- Plugin discovery and validation on startup
### Phase 3: MCP support (industry standard) ✅ DONE
- ✅ MCP client that connects to external MCP servers (stdio + HTTP/StreamableHTTP)
- ✅ Config: `mcp_servers` in config.yaml with connection details
- ✅ Each MCP server's tools auto-registered as a dynamic toolset
- Future: Resources, Prompts, Progress notifications, `hermes mcp` CLI command
---
## 6. MCP (Model Context Protocol) Support 🔗 ✅ DONE
**Status:** Implemented (PR #301)
**Priority:** Complete
Native MCP client support with stdio and HTTP/StreamableHTTP transports, auto-discovery, reconnection with exponential backoff, env var filtering, and credential stripping. See `docs/mcp.md` for full documentation.
**Still TODO:**
- `hermes mcp` CLI subcommand (list/test/status)
- `hermes tools` UI integration for MCP toolsets
- MCP Resources and Prompts support
- OAuth authentication for remote servers
- Progress notifications for long-running tools
---
## 8. Filesystem Checkpointing / Rollback 🔄
**Status:** Not started
**Priority:** Low-Medium
Automatic filesystem snapshots after each agent loop iteration so the user can roll back destructive changes to their project.
**What other agents do:**
- **Cline**: Workspace checkpoints at each step with Compare/Restore UI
- **OpenCode**: Git-backed workspace snapshots per step, with weekly gc
- **Codex**: Sandboxed execution with commit-per-step, rollback on failure
**Our approach:**
- After each tool call (or batch of tool calls in a single turn) that modifies files, create a lightweight checkpoint of the affected files
- Git-based when the project is a repo: auto-commit to a detached/temporary branch (`hermes/checkpoints/<session>`) after each agent turn, squash or discard on session end
- Non-git fallback: tar snapshots of changed files in `~/.hermes/checkpoints/<session_id>/`
- `hermes rollback` CLI command to restore to a previous checkpoint
- Agent-accessible via a `checkpoint` tool: `list` (show available restore points), `restore` (roll back to a named point), `diff` (show what changed since a checkpoint)
- Configurable: off by default (opt-in via `config.yaml`), since auto-committing can be surprising
- Cleanup: checkpoints expire after session ends (or configurable retention period)
- Integration with the terminal backend: works with local, SSH, and Docker backends (snapshots happen on the execution host)
---
## Implementation Priority Order
### Tier 1: Next Up
1. ~~MCP Support -- #6~~ ✅ Done (PR #301)
### Tier 2: Quality of Life
3. Local Browser Control via CDP -- #3
4. Plugin/Extension System -- #5
### Tier 3: Nice to Have
5. Session Branching / Checkpoints -- #7
6. Filesystem Checkpointing / Rollback -- #8
7. Signal Integration -- #4

View File

@@ -4,18 +4,20 @@ Provides a single resolution chain so every consumer (context compression,
session search, web extraction, vision analysis, browser vision) picks up
the best available backend without duplicating fallback logic.
Resolution order for text tasks:
Resolution order (same for text and vision tasks):
1. OpenRouter (OPENROUTER_API_KEY)
2. Nous Portal (~/.hermes/auth.json active provider)
3. Custom endpoint (OPENAI_BASE_URL + OPENAI_API_KEY)
4. Codex OAuth (Responses API via chatgpt.com with gpt-5.3-codex,
wrapped to look like a chat.completions client)
5. None
5. Direct API-key providers (z.ai/GLM, Kimi/Moonshot, MiniMax, MiniMax-CN)
— checked via PROVIDER_REGISTRY entries with auth_type='api_key'
6. None
Resolution order for vision/multimodal tasks:
1. OpenRouter
2. Nous Portal
3. None (custom endpoints can't substitute for Gemini multimodal)
Per-task provider overrides (e.g. AUXILIARY_VISION_PROVIDER,
CONTEXT_COMPRESSION_PROVIDER) can force a specific provider for each task:
"openrouter", "nous", or "main" (= steps 3-5).
Default "auto" follows the full chain above.
"""
import json
@@ -31,6 +33,14 @@ from hermes_constants import OPENROUTER_BASE_URL
logger = logging.getLogger(__name__)
# Default auxiliary models for direct API-key providers (cheap/fast for side tasks)
_API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = {
"zai": "glm-4.5-flash",
"kimi-coding": "kimi-k2-turbo-preview",
"minimax": "MiniMax-M2.5-highspeed",
"minimax-cn": "MiniMax-M2.5-highspeed",
}
# OpenRouter app attribution headers
_OR_HEADERS = {
"HTTP-Referer": "https://github.com/NousResearch/hermes-agent",
@@ -282,53 +292,159 @@ def _read_codex_access_token() -> Optional[str]:
return None
# ── Public API ──────────────────────────────────────────────────────────────
def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
"""Try each API-key provider in PROVIDER_REGISTRY order.
def get_text_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
"""Return (client, model_slug) for text-only auxiliary tasks.
Falls through OpenRouter -> Nous Portal -> custom endpoint -> Codex OAuth -> (None, None).
Returns (client, model) for the first provider whose env var is set,
or (None, None) if none are configured.
"""
# 1. OpenRouter
or_key = os.getenv("OPENROUTER_API_KEY")
if or_key:
logger.debug("Auxiliary text client: OpenRouter")
return OpenAI(api_key=or_key, base_url=OPENROUTER_BASE_URL,
default_headers=_OR_HEADERS), _OPENROUTER_MODEL
try:
from hermes_cli.auth import PROVIDER_REGISTRY
except ImportError:
logger.debug("Could not import PROVIDER_REGISTRY for API-key fallback")
return None, None
# 2. Nous Portal
nous = _read_nous_auth()
if nous:
global auxiliary_is_nous
auxiliary_is_nous = True
logger.debug("Auxiliary text client: Nous Portal")
return (
OpenAI(api_key=_nous_api_key(nous), base_url=_nous_base_url()),
_NOUS_MODEL,
)
for provider_id, pconfig in PROVIDER_REGISTRY.items():
if pconfig.auth_type != "api_key":
continue
# Check if any of the provider's env vars are set
api_key = ""
for env_var in pconfig.api_key_env_vars:
val = os.getenv(env_var, "").strip()
if val:
api_key = val
break
if not api_key:
continue
# Resolve base URL (with optional env-var override)
base_url = pconfig.inference_base_url
if pconfig.base_url_env_var:
env_url = os.getenv(pconfig.base_url_env_var, "").strip()
if env_url:
base_url = env_url.rstrip("/")
model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id, "default")
logger.debug("Auxiliary text client: %s (%s)", pconfig.name, model)
return OpenAI(api_key=api_key, base_url=base_url), model
# 3. Custom endpoint (both base URL and key must be set)
custom_base = os.getenv("OPENAI_BASE_URL")
custom_key = os.getenv("OPENAI_API_KEY")
if custom_base and custom_key:
model = os.getenv("OPENAI_MODEL") or os.getenv("LLM_MODEL") or "gpt-4o-mini"
logger.debug("Auxiliary text client: custom endpoint (%s)", model)
return OpenAI(api_key=custom_key, base_url=custom_base), model
# 4. Codex OAuth -- uses the Responses API (only endpoint the token
# can access), wrapped to look like a chat.completions client.
codex_token = _read_codex_access_token()
if codex_token:
logger.debug("Auxiliary text client: Codex OAuth (%s via Responses API)", _CODEX_AUX_MODEL)
real_client = OpenAI(api_key=codex_token, base_url=_CODEX_AUX_BASE_URL)
return CodexAuxiliaryClient(real_client, _CODEX_AUX_MODEL), _CODEX_AUX_MODEL
# 5. Nothing available
logger.debug("Auxiliary text client: none available")
return None, None
def get_async_text_auxiliary_client():
# ── Provider resolution helpers ─────────────────────────────────────────────
def _get_auxiliary_provider(task: str = "") -> str:
"""Read the provider override for a specific auxiliary task.
Checks AUXILIARY_{TASK}_PROVIDER first (e.g. AUXILIARY_VISION_PROVIDER),
then CONTEXT_{TASK}_PROVIDER (for the compression section's summary_provider),
then falls back to "auto". Returns one of: "auto", "openrouter", "nous", "main".
"""
if task:
for prefix in ("AUXILIARY_", "CONTEXT_"):
val = os.getenv(f"{prefix}{task.upper()}_PROVIDER", "").strip().lower()
if val and val != "auto":
return val
return "auto"
def _try_openrouter() -> Tuple[Optional[OpenAI], Optional[str]]:
or_key = os.getenv("OPENROUTER_API_KEY")
if not or_key:
return None, None
logger.debug("Auxiliary client: OpenRouter")
return OpenAI(api_key=or_key, base_url=OPENROUTER_BASE_URL,
default_headers=_OR_HEADERS), _OPENROUTER_MODEL
def _try_nous() -> Tuple[Optional[OpenAI], Optional[str]]:
nous = _read_nous_auth()
if not nous:
return None, None
global auxiliary_is_nous
auxiliary_is_nous = True
logger.debug("Auxiliary client: Nous Portal")
return (
OpenAI(api_key=_nous_api_key(nous), base_url=_nous_base_url()),
_NOUS_MODEL,
)
def _try_custom_endpoint() -> Tuple[Optional[OpenAI], Optional[str]]:
custom_base = os.getenv("OPENAI_BASE_URL")
custom_key = os.getenv("OPENAI_API_KEY")
if not custom_base or not custom_key:
return None, None
model = os.getenv("OPENAI_MODEL") or os.getenv("LLM_MODEL") or "gpt-4o-mini"
logger.debug("Auxiliary client: custom endpoint (%s)", model)
return OpenAI(api_key=custom_key, base_url=custom_base), model
def _try_codex() -> Tuple[Optional[Any], Optional[str]]:
codex_token = _read_codex_access_token()
if not codex_token:
return None, None
logger.debug("Auxiliary client: Codex OAuth (%s via Responses API)", _CODEX_AUX_MODEL)
real_client = OpenAI(api_key=codex_token, base_url=_CODEX_AUX_BASE_URL)
return CodexAuxiliaryClient(real_client, _CODEX_AUX_MODEL), _CODEX_AUX_MODEL
def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[str]]:
"""Resolve a specific forced provider. Returns (None, None) if creds missing."""
if forced == "openrouter":
client, model = _try_openrouter()
if client is None:
logger.warning("auxiliary.provider=openrouter but OPENROUTER_API_KEY not set")
return client, model
if forced == "nous":
client, model = _try_nous()
if client is None:
logger.warning("auxiliary.provider=nous but Nous Portal not configured (run: hermes login)")
return client, model
if forced == "main":
# "main" = skip OpenRouter/Nous, use the main chat model's credentials.
for try_fn in (_try_custom_endpoint, _try_codex, _resolve_api_key_provider):
client, model = try_fn()
if client is not None:
return client, model
logger.warning("auxiliary.provider=main but no main endpoint credentials found")
return None, None
# Unknown provider name — fall through to auto
logger.warning("Unknown auxiliary.provider=%r, falling back to auto", forced)
return None, None
def _resolve_auto() -> Tuple[Optional[OpenAI], Optional[str]]:
"""Full auto-detection chain: OpenRouter → Nous → custom → Codex → API-key → None."""
for try_fn in (_try_openrouter, _try_nous, _try_custom_endpoint,
_try_codex, _resolve_api_key_provider):
client, model = try_fn()
if client is not None:
return client, model
logger.debug("Auxiliary client: none available")
return None, None
# ── Public API ──────────────────────────────────────────────────────────────
def get_text_auxiliary_client(task: str = "") -> Tuple[Optional[OpenAI], Optional[str]]:
"""Return (client, default_model_slug) for text-only auxiliary tasks.
Args:
task: Optional task name ("compression", "web_extract") to check
for a task-specific provider override.
Callers may override the returned model with a per-task env var
(e.g. CONTEXT_COMPRESSION_MODEL, AUXILIARY_WEB_EXTRACT_MODEL).
"""
forced = _get_auxiliary_provider(task)
if forced != "auto":
return _resolve_forced_provider(forced)
return _resolve_auto()
def get_async_text_auxiliary_client(task: str = ""):
"""Return (async_client, model_slug) for async consumers.
For standard providers returns (AsyncOpenAI, model). For Codex returns
@@ -337,7 +453,7 @@ def get_async_text_auxiliary_client():
"""
from openai import AsyncOpenAI
sync_client, model = get_text_auxiliary_client()
sync_client, model = get_text_auxiliary_client(task)
if sync_client is None:
return None, None
@@ -354,30 +470,16 @@ def get_async_text_auxiliary_client():
def get_vision_auxiliary_client() -> Tuple[Optional[OpenAI], Optional[str]]:
"""Return (client, model_slug) for vision/multimodal auxiliary tasks.
"""Return (client, default_model_slug) for vision/multimodal auxiliary tasks.
Only OpenRouter and Nous Portal qualify — custom endpoints cannot
substitute for Gemini multimodal.
Checks AUXILIARY_VISION_PROVIDER for a forced provider, otherwise
auto-detects. Callers may override the returned model with
AUXILIARY_VISION_MODEL.
"""
# 1. OpenRouter
or_key = os.getenv("OPENROUTER_API_KEY")
if or_key:
logger.debug("Auxiliary vision client: OpenRouter")
return OpenAI(api_key=or_key, base_url=OPENROUTER_BASE_URL,
default_headers=_OR_HEADERS), _OPENROUTER_MODEL
# 2. Nous Portal
nous = _read_nous_auth()
if nous:
logger.debug("Auxiliary vision client: Nous Portal")
return (
OpenAI(api_key=_nous_api_key(nous), base_url=_nous_base_url()),
_NOUS_MODEL,
)
# 3. Nothing suitable
logger.debug("Auxiliary vision client: none available")
return None, None
forced = _get_auxiliary_provider("vision")
if forced != "auto":
return _resolve_forced_provider(forced)
return _resolve_auto()
def get_auxiliary_extra_body() -> dict:

View File

@@ -34,23 +34,26 @@ class ContextCompressor:
summary_target_tokens: int = 2500,
quiet_mode: bool = False,
summary_model_override: str = None,
base_url: str = "",
):
self.model = model
self.base_url = base_url
self.threshold_percent = threshold_percent
self.protect_first_n = protect_first_n
self.protect_last_n = protect_last_n
self.summary_target_tokens = summary_target_tokens
self.quiet_mode = quiet_mode
self.context_length = get_model_context_length(model)
self.context_length = get_model_context_length(model, base_url=base_url)
self.threshold_tokens = int(self.context_length * threshold_percent)
self.compression_count = 0
self._context_probed = False # True after a step-down from context error
self.last_prompt_tokens = 0
self.last_completion_tokens = 0
self.last_total_tokens = 0
self.client, default_model = get_text_auxiliary_client()
self.client, default_model = get_text_auxiliary_client("compression")
self.summary_model = summary_model_override or default_model
def update_from_response(self, usage: Dict[str, Any]):
@@ -193,10 +196,111 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
logger.debug("Could not build fallback auxiliary client: %s", exc)
return None, None
# ------------------------------------------------------------------
# Tool-call / tool-result pair integrity helpers
# ------------------------------------------------------------------
@staticmethod
def _get_tool_call_id(tc) -> str:
"""Extract the call ID from a tool_call entry (dict or SimpleNamespace)."""
if isinstance(tc, dict):
return tc.get("id", "")
return getattr(tc, "id", "") or ""
def _sanitize_tool_pairs(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Fix orphaned tool_call / tool_result pairs after compression.
Two failure modes:
1. A tool *result* references a call_id whose assistant tool_call was
removed (summarized/truncated). The API rejects this with
"No tool call found for function call output with call_id ...".
2. An assistant message has tool_calls whose results were dropped.
The API rejects this because every tool_call must be followed by
a tool result with the matching call_id.
This method removes orphaned results and inserts stub results for
orphaned calls so the message list is always well-formed.
"""
surviving_call_ids: set = set()
for msg in messages:
if msg.get("role") == "assistant":
for tc in msg.get("tool_calls") or []:
cid = self._get_tool_call_id(tc)
if cid:
surviving_call_ids.add(cid)
result_call_ids: set = set()
for msg in messages:
if msg.get("role") == "tool":
cid = msg.get("tool_call_id")
if cid:
result_call_ids.add(cid)
# 1. Remove tool results whose call_id has no matching assistant tool_call
orphaned_results = result_call_ids - surviving_call_ids
if orphaned_results:
messages = [
m for m in messages
if not (m.get("role") == "tool" and m.get("tool_call_id") in orphaned_results)
]
if not self.quiet_mode:
logger.info("Compression sanitizer: removed %d orphaned tool result(s)", len(orphaned_results))
# 2. Add stub results for assistant tool_calls whose results were dropped
missing_results = surviving_call_ids - result_call_ids
if missing_results:
patched: List[Dict[str, Any]] = []
for msg in messages:
patched.append(msg)
if msg.get("role") == "assistant":
for tc in msg.get("tool_calls") or []:
cid = self._get_tool_call_id(tc)
if cid in missing_results:
patched.append({
"role": "tool",
"content": "[Result from earlier conversation — see context summary above]",
"tool_call_id": cid,
})
messages = patched
if not self.quiet_mode:
logger.info("Compression sanitizer: added %d stub tool result(s)", len(missing_results))
return messages
def _align_boundary_forward(self, messages: List[Dict[str, Any]], idx: int) -> int:
"""Push a compress-start boundary forward past any orphan tool results.
If ``messages[idx]`` is a tool result, slide forward until we hit a
non-tool message so we don't start the summarised region mid-group.
"""
while idx < len(messages) and messages[idx].get("role") == "tool":
idx += 1
return idx
def _align_boundary_backward(self, messages: List[Dict[str, Any]], idx: int) -> int:
"""Pull a compress-end boundary backward to avoid splitting a
tool_call / result group.
If the message just before ``idx`` is an assistant message with
tool_calls, those tool results will start at ``idx`` and would be
separated from their parent. Move backwards to include the whole
group in the summarised region.
"""
if idx <= 0 or idx >= len(messages):
return idx
prev = messages[idx - 1]
if prev.get("role") == "assistant" and prev.get("tool_calls"):
# The results for this assistant turn sit at idx..idx+k.
# Include the assistant message in the summarised region too.
idx -= 1
return idx
def compress(self, messages: List[Dict[str, Any]], current_tokens: int = None) -> List[Dict[str, Any]]:
"""Compress conversation messages by summarizing middle turns.
Keeps first N + last N turns, summarizes everything in between.
After compression, orphaned tool_call / tool_result pairs are cleaned
up so the API never receives mismatched IDs.
"""
n_messages = len(messages)
if n_messages <= self.protect_first_n + self.protect_last_n + 1:
@@ -209,6 +313,12 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
if compress_start >= compress_end:
return messages
# Adjust boundaries to avoid splitting tool_call/result groups.
compress_start = self._align_boundary_forward(messages, compress_start)
compress_end = self._align_boundary_backward(messages, compress_end)
if compress_start >= compress_end:
return messages
turns_to_summarize = messages[compress_start:compress_end]
display_tokens = current_tokens if current_tokens else self.last_prompt_tokens or estimate_messages_tokens_rough(messages)
@@ -230,6 +340,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
tail = messages[-self.protect_last_n:]
kept.extend(m.copy() for m in tail)
self.compression_count += 1
kept = self._sanitize_tool_pairs(kept)
if not self.quiet_mode:
print(f" ✂️ Truncated: {len(messages)}{len(kept)} messages (dropped middle turns)")
return kept
@@ -253,6 +364,8 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
self.compression_count += 1
compressed = self._sanitize_tool_pairs(compressed)
if not self.quiet_mode:
new_estimate = estimate_messages_tokens_rough(compressed)
saved_estimate = display_tokens - new_estimate

818
agent/insights.py Normal file
View File

@@ -0,0 +1,818 @@
"""
Session Insights Engine for Hermes Agent.
Analyzes historical session data from the SQLite state database to produce
comprehensive usage insights — token consumption, cost estimates, tool usage
patterns, activity trends, model/platform breakdowns, and session metrics.
Inspired by Claude Code's /insights command, adapted for Hermes Agent's
multi-platform architecture with additional cost estimation and platform
breakdown capabilities.
Usage:
from agent.insights import InsightsEngine
engine = InsightsEngine(db)
report = engine.generate(days=30)
print(engine.format_terminal(report))
"""
import json
import time
from collections import Counter, defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional
# =========================================================================
# Model pricing (USD per million tokens) — approximate as of early 2026
# =========================================================================
MODEL_PRICING = {
# OpenAI
"gpt-4o": {"input": 2.50, "output": 10.00},
"gpt-4o-mini": {"input": 0.15, "output": 0.60},
"gpt-4.1": {"input": 2.00, "output": 8.00},
"gpt-4.1-mini": {"input": 0.40, "output": 1.60},
"gpt-4.1-nano": {"input": 0.10, "output": 0.40},
"gpt-4.5-preview": {"input": 75.00, "output": 150.00},
"gpt-5": {"input": 10.00, "output": 30.00},
"gpt-5.4": {"input": 10.00, "output": 30.00},
"o3": {"input": 10.00, "output": 40.00},
"o3-mini": {"input": 1.10, "output": 4.40},
"o4-mini": {"input": 1.10, "output": 4.40},
# Anthropic
"claude-opus-4-20250514": {"input": 15.00, "output": 75.00},
"claude-sonnet-4-20250514": {"input": 3.00, "output": 15.00},
"claude-3-5-sonnet-20241022": {"input": 3.00, "output": 15.00},
"claude-3-5-haiku-20241022": {"input": 0.80, "output": 4.00},
"claude-3-opus-20240229": {"input": 15.00, "output": 75.00},
"claude-3-haiku-20240307": {"input": 0.25, "output": 1.25},
# DeepSeek
"deepseek-chat": {"input": 0.14, "output": 0.28},
"deepseek-reasoner": {"input": 0.55, "output": 2.19},
# Google
"gemini-2.5-pro": {"input": 1.25, "output": 10.00},
"gemini-2.5-flash": {"input": 0.15, "output": 0.60},
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
# Meta (via providers)
"llama-4-maverick": {"input": 0.50, "output": 0.70},
"llama-4-scout": {"input": 0.20, "output": 0.30},
# Z.AI / GLM (direct provider — pricing not published externally, treat as local)
"glm-5": {"input": 0.0, "output": 0.0},
"glm-4.7": {"input": 0.0, "output": 0.0},
"glm-4.5": {"input": 0.0, "output": 0.0},
"glm-4.5-flash": {"input": 0.0, "output": 0.0},
# Kimi / Moonshot (direct provider — pricing not published externally, treat as local)
"kimi-k2.5": {"input": 0.0, "output": 0.0},
"kimi-k2-thinking": {"input": 0.0, "output": 0.0},
"kimi-k2-turbo-preview": {"input": 0.0, "output": 0.0},
"kimi-k2-0905-preview": {"input": 0.0, "output": 0.0},
# MiniMax (direct provider — pricing not published externally, treat as local)
"MiniMax-M2.5": {"input": 0.0, "output": 0.0},
"MiniMax-M2.5-highspeed": {"input": 0.0, "output": 0.0},
"MiniMax-M2.1": {"input": 0.0, "output": 0.0},
}
# Fallback: unknown/custom models get zero cost (we can't assume pricing
# for self-hosted models, custom OAI endpoints, local inference, etc.)
_DEFAULT_PRICING = {"input": 0.0, "output": 0.0}
def _has_known_pricing(model_name: str) -> bool:
"""Check if a model has known pricing (vs unknown/custom endpoint)."""
return _get_pricing(model_name) is not _DEFAULT_PRICING
def _get_pricing(model_name: str) -> Dict[str, float]:
"""Look up pricing for a model. Uses fuzzy matching on model name.
Returns _DEFAULT_PRICING (zero cost) for unknown/custom models —
we can't assume costs for self-hosted endpoints, local inference, etc.
"""
if not model_name:
return _DEFAULT_PRICING
# Strip provider prefix (e.g., "anthropic/claude-..." -> "claude-...")
bare = model_name.split("/")[-1].lower()
# Exact match first
if bare in MODEL_PRICING:
return MODEL_PRICING[bare]
# Fuzzy prefix match — prefer the LONGEST matching key to avoid
# e.g. "gpt-4o" matching before "gpt-4o-mini" for "gpt-4o-mini-2024-07-18"
best_match = None
best_len = 0
for key, price in MODEL_PRICING.items():
if bare.startswith(key) and len(key) > best_len:
best_match = price
best_len = len(key)
if best_match:
return best_match
# Keyword heuristics (checked in most-specific-first order)
if "opus" in bare:
return {"input": 15.00, "output": 75.00}
if "sonnet" in bare:
return {"input": 3.00, "output": 15.00}
if "haiku" in bare:
return {"input": 0.80, "output": 4.00}
if "gpt-4o-mini" in bare:
return {"input": 0.15, "output": 0.60}
if "gpt-4o" in bare:
return {"input": 2.50, "output": 10.00}
if "gpt-5" in bare:
return {"input": 10.00, "output": 30.00}
if "deepseek" in bare:
return {"input": 0.14, "output": 0.28}
if "gemini" in bare:
return {"input": 0.15, "output": 0.60}
return _DEFAULT_PRICING
def _estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
"""Estimate the USD cost for a given model and token counts."""
pricing = _get_pricing(model)
return (input_tokens * pricing["input"] + output_tokens * pricing["output"]) / 1_000_000
def _format_duration(seconds: float) -> str:
"""Format seconds into a human-readable duration string."""
if seconds < 60:
return f"{seconds:.0f}s"
minutes = seconds / 60
if minutes < 60:
return f"{minutes:.0f}m"
hours = minutes / 60
if hours < 24:
remaining_min = int(minutes % 60)
return f"{int(hours)}h {remaining_min}m" if remaining_min else f"{int(hours)}h"
days = hours / 24
return f"{days:.1f}d"
def _bar_chart(values: List[int], max_width: int = 20) -> List[str]:
"""Create simple horizontal bar chart strings from values."""
peak = max(values) if values else 1
if peak == 0:
return ["" for _ in values]
return ["" * max(1, int(v / peak * max_width)) if v > 0 else "" for v in values]
class InsightsEngine:
"""
Analyzes session history and produces usage insights.
Works directly with a SessionDB instance (or raw sqlite3 connection)
to query session and message data.
"""
def __init__(self, db):
"""
Initialize with a SessionDB instance.
Args:
db: A SessionDB instance (from hermes_state.py)
"""
self.db = db
self._conn = db._conn
def generate(self, days: int = 30, source: str = None) -> Dict[str, Any]:
"""
Generate a complete insights report.
Args:
days: Number of days to look back (default: 30)
source: Optional filter by source platform
Returns:
Dict with all computed insights
"""
cutoff = time.time() - (days * 86400)
# Gather raw data
sessions = self._get_sessions(cutoff, source)
tool_usage = self._get_tool_usage(cutoff, source)
message_stats = self._get_message_stats(cutoff, source)
if not sessions:
return {
"days": days,
"source_filter": source,
"empty": True,
"overview": {},
"models": [],
"platforms": [],
"tools": [],
"activity": {},
"top_sessions": [],
}
# Compute insights
overview = self._compute_overview(sessions, message_stats)
models = self._compute_model_breakdown(sessions)
platforms = self._compute_platform_breakdown(sessions)
tools = self._compute_tool_breakdown(tool_usage)
activity = self._compute_activity_patterns(sessions)
top_sessions = self._compute_top_sessions(sessions)
return {
"days": days,
"source_filter": source,
"empty": False,
"generated_at": time.time(),
"overview": overview,
"models": models,
"platforms": platforms,
"tools": tools,
"activity": activity,
"top_sessions": top_sessions,
}
# =========================================================================
# Data gathering (SQL queries)
# =========================================================================
# Columns we actually need (skip system_prompt, model_config blobs)
_SESSION_COLS = ("id, source, model, started_at, ended_at, "
"message_count, tool_call_count, input_tokens, output_tokens")
def _get_sessions(self, cutoff: float, source: str = None) -> List[Dict]:
"""Fetch sessions within the time window."""
if source:
cursor = self._conn.execute(
f"""SELECT {self._SESSION_COLS} FROM sessions
WHERE started_at >= ? AND source = ?
ORDER BY started_at DESC""",
(cutoff, source),
)
else:
cursor = self._conn.execute(
f"""SELECT {self._SESSION_COLS} FROM sessions
WHERE started_at >= ?
ORDER BY started_at DESC""",
(cutoff,),
)
return [dict(row) for row in cursor.fetchall()]
def _get_tool_usage(self, cutoff: float, source: str = None) -> List[Dict]:
"""Get tool call counts from messages.
Uses two sources:
1. tool_name column on 'tool' role messages (set by gateway)
2. tool_calls JSON on 'assistant' role messages (covers CLI where
tool_name is not populated on tool responses)
"""
tool_counts = Counter()
# Source 1: explicit tool_name on tool response messages
if source:
cursor = self._conn.execute(
"""SELECT m.tool_name, COUNT(*) as count
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ? AND s.source = ?
AND m.role = 'tool' AND m.tool_name IS NOT NULL
GROUP BY m.tool_name
ORDER BY count DESC""",
(cutoff, source),
)
else:
cursor = self._conn.execute(
"""SELECT m.tool_name, COUNT(*) as count
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ?
AND m.role = 'tool' AND m.tool_name IS NOT NULL
GROUP BY m.tool_name
ORDER BY count DESC""",
(cutoff,),
)
for row in cursor.fetchall():
tool_counts[row["tool_name"]] += row["count"]
# Source 2: extract from tool_calls JSON on assistant messages
# (covers CLI sessions where tool_name is NULL on tool responses)
if source:
cursor2 = self._conn.execute(
"""SELECT m.tool_calls
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ? AND s.source = ?
AND m.role = 'assistant' AND m.tool_calls IS NOT NULL""",
(cutoff, source),
)
else:
cursor2 = self._conn.execute(
"""SELECT m.tool_calls
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ?
AND m.role = 'assistant' AND m.tool_calls IS NOT NULL""",
(cutoff,),
)
tool_calls_counts = Counter()
for row in cursor2.fetchall():
try:
calls = row["tool_calls"]
if isinstance(calls, str):
calls = json.loads(calls)
if isinstance(calls, list):
for call in calls:
func = call.get("function", {}) if isinstance(call, dict) else {}
name = func.get("name")
if name:
tool_calls_counts[name] += 1
except (json.JSONDecodeError, TypeError, AttributeError):
continue
# Merge: prefer tool_name source, supplement with tool_calls source
# for tools not already counted
if not tool_counts and tool_calls_counts:
# No tool_name data at all — use tool_calls exclusively
tool_counts = tool_calls_counts
elif tool_counts and tool_calls_counts:
# Both sources have data — use whichever has the higher count per tool
# (they may overlap, so take the max to avoid double-counting)
all_tools = set(tool_counts) | set(tool_calls_counts)
merged = Counter()
for tool in all_tools:
merged[tool] = max(tool_counts.get(tool, 0), tool_calls_counts.get(tool, 0))
tool_counts = merged
# Convert to the expected format
return [
{"tool_name": name, "count": count}
for name, count in tool_counts.most_common()
]
def _get_message_stats(self, cutoff: float, source: str = None) -> Dict:
"""Get aggregate message statistics."""
if source:
cursor = self._conn.execute(
"""SELECT
COUNT(*) as total_messages,
SUM(CASE WHEN m.role = 'user' THEN 1 ELSE 0 END) as user_messages,
SUM(CASE WHEN m.role = 'assistant' THEN 1 ELSE 0 END) as assistant_messages,
SUM(CASE WHEN m.role = 'tool' THEN 1 ELSE 0 END) as tool_messages
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ? AND s.source = ?""",
(cutoff, source),
)
else:
cursor = self._conn.execute(
"""SELECT
COUNT(*) as total_messages,
SUM(CASE WHEN m.role = 'user' THEN 1 ELSE 0 END) as user_messages,
SUM(CASE WHEN m.role = 'assistant' THEN 1 ELSE 0 END) as assistant_messages,
SUM(CASE WHEN m.role = 'tool' THEN 1 ELSE 0 END) as tool_messages
FROM messages m
JOIN sessions s ON s.id = m.session_id
WHERE s.started_at >= ?""",
(cutoff,),
)
row = cursor.fetchone()
return dict(row) if row else {
"total_messages": 0, "user_messages": 0,
"assistant_messages": 0, "tool_messages": 0,
}
# =========================================================================
# Computation
# =========================================================================
def _compute_overview(self, sessions: List[Dict], message_stats: Dict) -> Dict:
"""Compute high-level overview statistics."""
total_input = sum(s.get("input_tokens") or 0 for s in sessions)
total_output = sum(s.get("output_tokens") or 0 for s in sessions)
total_tokens = total_input + total_output
total_tool_calls = sum(s.get("tool_call_count") or 0 for s in sessions)
total_messages = sum(s.get("message_count") or 0 for s in sessions)
# Cost estimation (weighted by model)
total_cost = 0.0
models_with_pricing = set()
models_without_pricing = set()
for s in sessions:
model = s.get("model") or ""
inp = s.get("input_tokens") or 0
out = s.get("output_tokens") or 0
total_cost += _estimate_cost(model, inp, out)
display = model.split("/")[-1] if "/" in model else (model or "unknown")
if _has_known_pricing(model):
models_with_pricing.add(display)
else:
models_without_pricing.add(display)
# Session duration stats (guard against negative durations from clock drift)
durations = []
for s in sessions:
start = s.get("started_at")
end = s.get("ended_at")
if start and end and end > start:
durations.append(end - start)
total_hours = sum(durations) / 3600 if durations else 0
avg_duration = sum(durations) / len(durations) if durations else 0
# Earliest and latest session
started_timestamps = [s["started_at"] for s in sessions if s.get("started_at")]
date_range_start = min(started_timestamps) if started_timestamps else None
date_range_end = max(started_timestamps) if started_timestamps else None
return {
"total_sessions": len(sessions),
"total_messages": total_messages,
"total_tool_calls": total_tool_calls,
"total_input_tokens": total_input,
"total_output_tokens": total_output,
"total_tokens": total_tokens,
"estimated_cost": total_cost,
"total_hours": total_hours,
"avg_session_duration": avg_duration,
"avg_messages_per_session": total_messages / len(sessions) if sessions else 0,
"avg_tokens_per_session": total_tokens / len(sessions) if sessions else 0,
"user_messages": message_stats.get("user_messages") or 0,
"assistant_messages": message_stats.get("assistant_messages") or 0,
"tool_messages": message_stats.get("tool_messages") or 0,
"date_range_start": date_range_start,
"date_range_end": date_range_end,
"models_with_pricing": sorted(models_with_pricing),
"models_without_pricing": sorted(models_without_pricing),
}
def _compute_model_breakdown(self, sessions: List[Dict]) -> List[Dict]:
"""Break down usage by model."""
model_data = defaultdict(lambda: {
"sessions": 0, "input_tokens": 0, "output_tokens": 0,
"total_tokens": 0, "tool_calls": 0, "cost": 0.0,
})
for s in sessions:
model = s.get("model") or "unknown"
# Normalize: strip provider prefix for display
display_model = model.split("/")[-1] if "/" in model else model
d = model_data[display_model]
d["sessions"] += 1
inp = s.get("input_tokens") or 0
out = s.get("output_tokens") or 0
d["input_tokens"] += inp
d["output_tokens"] += out
d["total_tokens"] += inp + out
d["tool_calls"] += s.get("tool_call_count") or 0
d["cost"] += _estimate_cost(model, inp, out)
d["has_pricing"] = _has_known_pricing(model)
result = [
{"model": model, **data}
for model, data in model_data.items()
]
# Sort by tokens first, fall back to session count when tokens are 0
result.sort(key=lambda x: (x["total_tokens"], x["sessions"]), reverse=True)
return result
def _compute_platform_breakdown(self, sessions: List[Dict]) -> List[Dict]:
"""Break down usage by platform/source."""
platform_data = defaultdict(lambda: {
"sessions": 0, "messages": 0, "input_tokens": 0,
"output_tokens": 0, "total_tokens": 0, "tool_calls": 0,
})
for s in sessions:
source = s.get("source") or "unknown"
d = platform_data[source]
d["sessions"] += 1
d["messages"] += s.get("message_count") or 0
inp = s.get("input_tokens") or 0
out = s.get("output_tokens") or 0
d["input_tokens"] += inp
d["output_tokens"] += out
d["total_tokens"] += inp + out
d["tool_calls"] += s.get("tool_call_count") or 0
result = [
{"platform": platform, **data}
for platform, data in platform_data.items()
]
result.sort(key=lambda x: x["sessions"], reverse=True)
return result
def _compute_tool_breakdown(self, tool_usage: List[Dict]) -> List[Dict]:
"""Process tool usage data into a ranked list with percentages."""
total_calls = sum(t["count"] for t in tool_usage) if tool_usage else 0
result = []
for t in tool_usage:
pct = (t["count"] / total_calls * 100) if total_calls else 0
result.append({
"tool": t["tool_name"],
"count": t["count"],
"percentage": pct,
})
return result
def _compute_activity_patterns(self, sessions: List[Dict]) -> Dict:
"""Analyze activity patterns by day of week and hour."""
day_counts = Counter() # 0=Monday ... 6=Sunday
hour_counts = Counter()
daily_counts = Counter() # date string -> count
for s in sessions:
ts = s.get("started_at")
if not ts:
continue
dt = datetime.fromtimestamp(ts)
day_counts[dt.weekday()] += 1
hour_counts[dt.hour] += 1
daily_counts[dt.strftime("%Y-%m-%d")] += 1
day_names = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
day_breakdown = [
{"day": day_names[i], "count": day_counts.get(i, 0)}
for i in range(7)
]
hour_breakdown = [
{"hour": i, "count": hour_counts.get(i, 0)}
for i in range(24)
]
# Busiest day and hour
busiest_day = max(day_breakdown, key=lambda x: x["count"]) if day_breakdown else None
busiest_hour = max(hour_breakdown, key=lambda x: x["count"]) if hour_breakdown else None
# Active days (days with at least one session)
active_days = len(daily_counts)
# Streak calculation
if daily_counts:
all_dates = sorted(daily_counts.keys())
current_streak = 1
max_streak = 1
for i in range(1, len(all_dates)):
d1 = datetime.strptime(all_dates[i - 1], "%Y-%m-%d")
d2 = datetime.strptime(all_dates[i], "%Y-%m-%d")
if (d2 - d1).days == 1:
current_streak += 1
max_streak = max(max_streak, current_streak)
else:
current_streak = 1
else:
max_streak = 0
return {
"by_day": day_breakdown,
"by_hour": hour_breakdown,
"busiest_day": busiest_day,
"busiest_hour": busiest_hour,
"active_days": active_days,
"max_streak": max_streak,
}
def _compute_top_sessions(self, sessions: List[Dict]) -> List[Dict]:
"""Find notable sessions (longest, most messages, most tokens)."""
top = []
# Longest by duration
sessions_with_duration = [
s for s in sessions
if s.get("started_at") and s.get("ended_at")
]
if sessions_with_duration:
longest = max(
sessions_with_duration,
key=lambda s: (s["ended_at"] - s["started_at"]),
)
dur = longest["ended_at"] - longest["started_at"]
top.append({
"label": "Longest session",
"session_id": longest["id"][:16],
"value": _format_duration(dur),
"date": datetime.fromtimestamp(longest["started_at"]).strftime("%b %d"),
})
# Most messages
most_msgs = max(sessions, key=lambda s: s.get("message_count") or 0)
if (most_msgs.get("message_count") or 0) > 0:
top.append({
"label": "Most messages",
"session_id": most_msgs["id"][:16],
"value": f"{most_msgs['message_count']} msgs",
"date": datetime.fromtimestamp(most_msgs["started_at"]).strftime("%b %d") if most_msgs.get("started_at") else "?",
})
# Most tokens
most_tokens = max(
sessions,
key=lambda s: (s.get("input_tokens") or 0) + (s.get("output_tokens") or 0),
)
token_total = (most_tokens.get("input_tokens") or 0) + (most_tokens.get("output_tokens") or 0)
if token_total > 0:
top.append({
"label": "Most tokens",
"session_id": most_tokens["id"][:16],
"value": f"{token_total:,} tokens",
"date": datetime.fromtimestamp(most_tokens["started_at"]).strftime("%b %d") if most_tokens.get("started_at") else "?",
})
# Most tool calls
most_tools = max(sessions, key=lambda s: s.get("tool_call_count") or 0)
if (most_tools.get("tool_call_count") or 0) > 0:
top.append({
"label": "Most tool calls",
"session_id": most_tools["id"][:16],
"value": f"{most_tools['tool_call_count']} calls",
"date": datetime.fromtimestamp(most_tools["started_at"]).strftime("%b %d") if most_tools.get("started_at") else "?",
})
return top
# =========================================================================
# Formatting
# =========================================================================
def format_terminal(self, report: Dict) -> str:
"""Format the insights report for terminal display (CLI)."""
if report.get("empty"):
days = report.get("days", 30)
src = f" (source: {report['source_filter']})" if report.get("source_filter") else ""
return f" No sessions found in the last {days} days{src}."
lines = []
o = report["overview"]
days = report["days"]
src_filter = report.get("source_filter")
# Header
lines.append("")
lines.append(" ╔══════════════════════════════════════════════════════════╗")
lines.append(" ║ 📊 Hermes Insights ║")
period_label = f"Last {days} days"
if src_filter:
period_label += f" ({src_filter})"
padding = 58 - len(period_label) - 2
left_pad = padding // 2
right_pad = padding - left_pad
lines.append(f"{' ' * left_pad} {period_label} {' ' * right_pad}")
lines.append(" ╚══════════════════════════════════════════════════════════╝")
lines.append("")
# Date range
if o.get("date_range_start") and o.get("date_range_end"):
start_str = datetime.fromtimestamp(o["date_range_start"]).strftime("%b %d, %Y")
end_str = datetime.fromtimestamp(o["date_range_end"]).strftime("%b %d, %Y")
lines.append(f" Period: {start_str}{end_str}")
lines.append("")
# Overview
lines.append(" 📋 Overview")
lines.append(" " + "" * 56)
lines.append(f" Sessions: {o['total_sessions']:<12} Messages: {o['total_messages']:,}")
lines.append(f" Tool calls: {o['total_tool_calls']:<12,} User messages: {o['user_messages']:,}")
lines.append(f" Input tokens: {o['total_input_tokens']:<12,} Output tokens: {o['total_output_tokens']:,}")
cost_str = f"${o['estimated_cost']:.2f}"
if o.get("models_without_pricing"):
cost_str += " *"
lines.append(f" Total tokens: {o['total_tokens']:<12,} Est. cost: {cost_str}")
if o["total_hours"] > 0:
lines.append(f" Active time: ~{_format_duration(o['total_hours'] * 3600):<11} Avg session: ~{_format_duration(o['avg_session_duration'])}")
lines.append(f" Avg msgs/session: {o['avg_messages_per_session']:.1f}")
lines.append("")
# Model breakdown
if report["models"]:
lines.append(" 🤖 Models Used")
lines.append(" " + "" * 56)
lines.append(f" {'Model':<30} {'Sessions':>8} {'Tokens':>12} {'Cost':>8}")
for m in report["models"]:
model_name = m["model"][:28]
if m.get("has_pricing"):
cost_cell = f"${m['cost']:>6.2f}"
else:
cost_cell = " N/A"
lines.append(f" {model_name:<30} {m['sessions']:>8} {m['total_tokens']:>12,} {cost_cell}")
if o.get("models_without_pricing"):
lines.append(f" * Cost N/A for custom/self-hosted models")
lines.append("")
# Platform breakdown
if len(report["platforms"]) > 1 or (report["platforms"] and report["platforms"][0]["platform"] != "cli"):
lines.append(" 📱 Platforms")
lines.append(" " + "" * 56)
lines.append(f" {'Platform':<14} {'Sessions':>8} {'Messages':>10} {'Tokens':>14}")
for p in report["platforms"]:
lines.append(f" {p['platform']:<14} {p['sessions']:>8} {p['messages']:>10,} {p['total_tokens']:>14,}")
lines.append("")
# Tool usage
if report["tools"]:
lines.append(" 🔧 Top Tools")
lines.append(" " + "" * 56)
lines.append(f" {'Tool':<28} {'Calls':>8} {'%':>8}")
for t in report["tools"][:15]: # Top 15
lines.append(f" {t['tool']:<28} {t['count']:>8,} {t['percentage']:>7.1f}%")
if len(report["tools"]) > 15:
lines.append(f" ... and {len(report['tools']) - 15} more tools")
lines.append("")
# Activity patterns
act = report.get("activity", {})
if act.get("by_day"):
lines.append(" 📅 Activity Patterns")
lines.append(" " + "" * 56)
# Day of week chart
day_values = [d["count"] for d in act["by_day"]]
bars = _bar_chart(day_values, max_width=15)
for i, d in enumerate(act["by_day"]):
bar = bars[i]
lines.append(f" {d['day']} {bar:<15} {d['count']}")
lines.append("")
# Peak hours (show top 5 busiest hours)
busy_hours = sorted(act["by_hour"], key=lambda x: x["count"], reverse=True)
busy_hours = [h for h in busy_hours if h["count"] > 0][:5]
if busy_hours:
hour_strs = []
for h in busy_hours:
hr = h["hour"]
ampm = "AM" if hr < 12 else "PM"
display_hr = hr % 12 or 12
hour_strs.append(f"{display_hr}{ampm} ({h['count']})")
lines.append(f" Peak hours: {', '.join(hour_strs)}")
if act.get("active_days"):
lines.append(f" Active days: {act['active_days']}")
if act.get("max_streak") and act["max_streak"] > 1:
lines.append(f" Best streak: {act['max_streak']} consecutive days")
lines.append("")
# Notable sessions
if report.get("top_sessions"):
lines.append(" 🏆 Notable Sessions")
lines.append(" " + "" * 56)
for ts in report["top_sessions"]:
lines.append(f" {ts['label']:<20} {ts['value']:<18} ({ts['date']}, {ts['session_id']})")
lines.append("")
return "\n".join(lines)
def format_gateway(self, report: Dict) -> str:
"""Format the insights report for gateway/messaging (shorter)."""
if report.get("empty"):
days = report.get("days", 30)
return f"No sessions found in the last {days} days."
lines = []
o = report["overview"]
days = report["days"]
lines.append(f"📊 **Hermes Insights** — Last {days} days\n")
# Overview
lines.append(f"**Sessions:** {o['total_sessions']} | **Messages:** {o['total_messages']:,} | **Tool calls:** {o['total_tool_calls']:,}")
lines.append(f"**Tokens:** {o['total_tokens']:,} (in: {o['total_input_tokens']:,} / out: {o['total_output_tokens']:,})")
cost_note = ""
if o.get("models_without_pricing"):
cost_note = " _(excludes custom/self-hosted models)_"
lines.append(f"**Est. cost:** ${o['estimated_cost']:.2f}{cost_note}")
if o["total_hours"] > 0:
lines.append(f"**Active time:** ~{_format_duration(o['total_hours'] * 3600)} | **Avg session:** ~{_format_duration(o['avg_session_duration'])}")
lines.append("")
# Models (top 5)
if report["models"]:
lines.append("**🤖 Models:**")
for m in report["models"][:5]:
cost_str = f"${m['cost']:.2f}" if m.get("has_pricing") else "N/A"
lines.append(f" {m['model'][:25]}{m['sessions']} sessions, {m['total_tokens']:,} tokens, {cost_str}")
lines.append("")
# Platforms (if multi-platform)
if len(report["platforms"]) > 1:
lines.append("**📱 Platforms:**")
for p in report["platforms"]:
lines.append(f" {p['platform']}{p['sessions']} sessions, {p['messages']:,} msgs")
lines.append("")
# Tools (top 8)
if report["tools"]:
lines.append("**🔧 Top Tools:**")
for t in report["tools"][:8]:
lines.append(f" {t['tool']}{t['count']:,} calls ({t['percentage']:.1f}%)")
lines.append("")
# Activity summary
act = report.get("activity", {})
if act.get("busiest_day") and act.get("busiest_hour"):
hr = act["busiest_hour"]["hour"]
ampm = "AM" if hr < 12 else "PM"
display_hr = hr % 12 or 12
lines.append(f"**📅 Busiest:** {act['busiest_day']['day']}s ({act['busiest_day']['count']} sessions), {display_hr}{ampm} ({act['busiest_hour']['count']} sessions)")
if act.get("active_days"):
lines.append(f"**Active days:** {act['active_days']}", )
if act.get("max_streak", 0) > 1:
lines.append(f"**Best streak:** {act['max_streak']} consecutive days")
return "\n".join(lines)

View File

@@ -5,10 +5,14 @@ and run_agent.py for pre-flight context checks.
"""
import logging
import os
import re
import time
from typing import Any, Dict, List
from pathlib import Path
from typing import Any, Dict, List, Optional
import requests
import yaml
from hermes_constants import OPENROUTER_MODELS_URL
@@ -18,6 +22,18 @@ _model_metadata_cache: Dict[str, Dict[str, Any]] = {}
_model_metadata_cache_time: float = 0
_MODEL_CACHE_TTL = 3600
# Descending tiers for context length probing when the model is unknown.
# We start high and step down on context-length errors until one works.
CONTEXT_PROBE_TIERS = [
2_000_000,
1_000_000,
512_000,
200_000,
128_000,
64_000,
32_000,
]
DEFAULT_CONTEXT_LENGTHS = {
"anthropic/claude-opus-4": 200000,
"anthropic/claude-opus-4.5": 200000,
@@ -33,6 +49,17 @@ DEFAULT_CONTEXT_LENGTHS = {
"meta-llama/llama-3.3-70b-instruct": 131072,
"deepseek/deepseek-chat-v3": 65536,
"qwen/qwen-2.5-72b-instruct": 32768,
"glm-4.7": 202752,
"glm-5": 202752,
"glm-4.5": 131072,
"glm-4.5-flash": 131072,
"kimi-k2.5": 262144,
"kimi-k2-thinking": 262144,
"kimi-k2-turbo-preview": 262144,
"kimi-k2-0905-preview": 131072,
"MiniMax-M2.5": 204800,
"MiniMax-M2.5-highspeed": 204800,
"MiniMax-M2.1": 204800,
}
@@ -71,17 +98,117 @@ def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any
return _model_metadata_cache or {}
def get_model_context_length(model: str) -> int:
"""Get the context length for a model (API first, then fallback defaults)."""
def _get_context_cache_path() -> Path:
"""Return path to the persistent context length cache file."""
hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
return hermes_home / "context_length_cache.yaml"
def _load_context_cache() -> Dict[str, int]:
"""Load the model+provider → context_length cache from disk."""
path = _get_context_cache_path()
if not path.exists():
return {}
try:
with open(path) as f:
data = yaml.safe_load(f) or {}
return data.get("context_lengths", {})
except Exception as e:
logger.debug("Failed to load context length cache: %s", e)
return {}
def save_context_length(model: str, base_url: str, length: int) -> None:
"""Persist a discovered context length for a model+provider combo.
Cache key is ``model@base_url`` so the same model name served from
different providers can have different limits.
"""
key = f"{model}@{base_url}"
cache = _load_context_cache()
if cache.get(key) == length:
return # already stored
cache[key] = length
path = _get_context_cache_path()
try:
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "w") as f:
yaml.dump({"context_lengths": cache}, f, default_flow_style=False)
logger.info("Cached context length %s%s tokens", key, f"{length:,}")
except Exception as e:
logger.debug("Failed to save context length cache: %s", e)
def get_cached_context_length(model: str, base_url: str) -> Optional[int]:
"""Look up a previously discovered context length for model+provider."""
key = f"{model}@{base_url}"
cache = _load_context_cache()
return cache.get(key)
def get_next_probe_tier(current_length: int) -> Optional[int]:
"""Return the next lower probe tier, or None if already at minimum."""
for tier in CONTEXT_PROBE_TIERS:
if tier < current_length:
return tier
return None
def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
"""Try to extract the actual context limit from an API error message.
Many providers include the limit in their error text, e.g.:
- "maximum context length is 32768 tokens"
- "context_length_exceeded: 131072"
- "Maximum context size 32768 exceeded"
- "model's max context length is 65536"
"""
error_lower = error_msg.lower()
# Pattern: look for numbers near context-related keywords
patterns = [
r'(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})',
r'context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})',
r'(\d{4,})\s*(?:token)?\s*(?:context|limit)',
r'>\s*(\d{4,})\s*(?:max|limit|token)', # "250000 tokens > 200000 maximum"
r'(\d{4,})\s*(?:max(?:imum)?)\b', # "200000 maximum"
]
for pattern in patterns:
match = re.search(pattern, error_lower)
if match:
limit = int(match.group(1))
# Sanity check: must be a reasonable context length
if 1024 <= limit <= 10_000_000:
return limit
return None
def get_model_context_length(model: str, base_url: str = "") -> int:
"""Get the context length for a model.
Resolution order:
1. Persistent cache (previously discovered via probing)
2. OpenRouter API metadata
3. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match)
4. First probe tier (2M) — will be narrowed on first context error
"""
# 1. Check persistent cache (model+provider)
if base_url:
cached = get_cached_context_length(model, base_url)
if cached is not None:
return cached
# 2. OpenRouter API metadata
metadata = fetch_model_metadata()
if model in metadata:
return metadata[model].get("context_length", 128000)
# 3. Hardcoded defaults (fuzzy match)
for default_model, length in DEFAULT_CONTEXT_LENGTHS.items():
if default_model in model or model in default_model:
return length
return 128000
# 4. Unknown model — start at highest probe tier
return CONTEXT_PROBE_TIERS[0]
def estimate_tokens_rough(text: str) -> int:

View File

@@ -142,12 +142,28 @@ def _read_skill_description(skill_file: Path, max_chars: int = 60) -> str:
return ""
def _skill_is_platform_compatible(skill_file: Path) -> bool:
"""Quick check if a SKILL.md is compatible with the current OS platform.
Reads just enough to parse the ``platforms`` frontmatter field.
Skills without the field (the vast majority) are always compatible.
"""
try:
from tools.skills_tool import _parse_frontmatter, skill_matches_platform
raw = skill_file.read_text(encoding="utf-8")[:2000]
frontmatter, _ = _parse_frontmatter(raw)
return skill_matches_platform(frontmatter)
except Exception:
return True # Err on the side of showing the skill
def build_skills_system_prompt() -> str:
"""Build a compact skill index for the system prompt.
Scans ~/.hermes/skills/ for SKILL.md files grouped by category.
Includes per-skill descriptions from frontmatter so the model can
match skills by meaning, not just name.
Filters out skills incompatible with the current OS platform.
"""
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
skills_dir = hermes_home / "skills"
@@ -159,6 +175,9 @@ def build_skills_system_prompt() -> str:
# Each entry: (skill_name, description)
skills_by_category: dict[str, list[tuple[str, str]]] = {}
for skill_file in skills_dir.rglob("SKILL.md"):
# Skip skills incompatible with the current OS platform
if not _skill_is_platform_compatible(skill_file):
continue
rel_path = skill_file.relative_to(skills_dir)
parts = rel_path.parts
if len(parts) >= 2:

View File

@@ -22,7 +22,7 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
global _skill_commands
_skill_commands = {}
try:
from tools.skills_tool import SKILLS_DIR, _parse_frontmatter
from tools.skills_tool import SKILLS_DIR, _parse_frontmatter, skill_matches_platform
if not SKILLS_DIR.exists():
return _skill_commands
for skill_md in SKILLS_DIR.rglob("SKILL.md"):
@@ -31,6 +31,9 @@ def scan_skill_commands() -> Dict[str, Dict[str, Any]]:
try:
content = skill_md.read_text(encoding='utf-8')
frontmatter, body = _parse_frontmatter(content)
# Skip skills incompatible with the current OS platform
if not skill_matches_platform(frontmatter):
continue
name = frontmatter.get('name', skill_md.parent.name)
description = frontmatter.get('description', '')
if not description:

View File

@@ -29,7 +29,6 @@ from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from multiprocessing import Pool, Lock
import traceback
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.console import Console
import fire
@@ -250,7 +249,7 @@ def _process_single_prompt(
task_id = f"task_{prompt_index}"
# Per-prompt container image override: if the dataset row has an 'image' field,
# register it for this task's sandbox. Works with Docker, Modal, and Singularity.
# register it for this task's sandbox. Works with Docker, Modal, Singularity, and Daytona.
container_image = prompt_data.get("image") or prompt_data.get("docker_image")
if container_image:
# Verify the image is accessible before spending tokens on the agent loop.
@@ -292,6 +291,7 @@ def _process_single_prompt(
"docker_image": container_image,
"modal_image": container_image,
"singularity_image": f"docker://{container_image}",
"daytona_image": container_image,
}
if prompt_data.get("cwd"):
overrides["cwd"] = prompt_data["cwd"]
@@ -700,14 +700,13 @@ class BatchRunner:
lock (Lock): Optional lock for thread-safe access
"""
checkpoint_data["last_updated"] = datetime.now().isoformat()
from utils import atomic_json_write
if lock:
with lock:
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
atomic_json_write(self.checkpoint_file, checkpoint_data)
else:
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
atomic_json_write(self.checkpoint_file, checkpoint_data)
def _scan_completed_prompts_by_content(self) -> set:
"""
@@ -832,13 +831,15 @@ class BatchRunner:
print(f" New batches created: {len(batches_to_process)}")
print("=" * 70 + "\n")
# Initialize checkpoint data (needed for saving at the end)
checkpoint_data = {
"run_name": self.run_name,
"completed_prompts": [],
"batch_stats": {},
"last_updated": None
}
# Load existing checkpoint (so resume doesn't clobber prior progress)
checkpoint_data = self._load_checkpoint()
if checkpoint_data.get("run_name") != self.run_name:
checkpoint_data = {
"run_name": self.run_name,
"completed_prompts": [],
"batch_stats": {},
"last_updated": None
}
# Prepare configuration for workers
config = {
@@ -860,7 +861,7 @@ class BatchRunner:
}
# For backward compatibility, still track by index (but this is secondary to content matching)
completed_prompts_set = set()
completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
# Aggregate statistics across all batches
total_tool_stats = {}
@@ -869,6 +870,9 @@ class BatchRunner:
print(f"\n🔧 Initializing {self.num_workers} worker processes...")
# Checkpoint writes happen in the parent process; keep a lock for safety.
checkpoint_lock = Lock()
# Process batches in parallel
with Pool(processes=self.num_workers) as pool:
# Create tasks for each batch
@@ -914,6 +918,25 @@ class BatchRunner:
for result in pool.imap_unordered(_process_batch_worker, tasks):
results.append(result)
progress.update(task, advance=1)
# Incremental checkpoint update (so resume works after crash)
try:
batch_num = result.get('batch_num')
completed = result.get('completed_prompts', []) or []
completed_prompts_set.update(completed)
if isinstance(batch_num, int):
checkpoint_data.setdefault('batch_stats', {})[str(batch_num)] = {
'processed': result.get('processed', 0),
'skipped': result.get('skipped', 0),
'discarded_no_reasoning': result.get('discarded_no_reasoning', 0),
}
checkpoint_data['completed_prompts'] = sorted(completed_prompts_set)
self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
except Exception as ckpt_err:
# Don't fail the run if checkpoint write fails
print(f"⚠️ Warning: Failed to save incremental checkpoint: {ckpt_err}")
except Exception as e:
logger.error("Batch worker failed: %s", e, exc_info=True)
raise
@@ -945,9 +968,12 @@ class BatchRunner:
for key in total_reasoning_stats:
total_reasoning_stats[key] += batch_result.get("reasoning_stats", {}).get(key, 0)
# Save final checkpoint
checkpoint_data["completed_prompts"] = all_completed_prompts
self._save_checkpoint(checkpoint_data)
# Save final checkpoint (best-effort; incremental writes already happened)
try:
checkpoint_data["completed_prompts"] = all_completed_prompts
self._save_checkpoint(checkpoint_data, lock=checkpoint_lock)
except Exception as ckpt_err:
print(f"⚠️ Warning: Failed to save final checkpoint: {ckpt_err}")
# Calculate success rates
for tool_name in total_tool_stats:

View File

@@ -13,6 +13,10 @@ model:
# "auto" - Use Nous Portal if logged in, otherwise OpenRouter/env vars (default)
# "openrouter" - Always use OpenRouter API key from OPENROUTER_API_KEY
# "nous" - Always use Nous Portal (requires: hermes login)
# "zai" - Use z.ai / ZhipuAI GLM models (requires: GLM_API_KEY)
# "kimi-coding"- Use Kimi / Moonshot AI models (requires: KIMI_API_KEY)
# "minimax" - Use MiniMax global endpoint (requires: MINIMAX_API_KEY)
# "minimax-cn" - Use MiniMax China endpoint (requires: MINIMAX_CN_API_KEY)
# Can also be overridden with --provider flag or HERMES_INFERENCE_PROVIDER env var.
provider: "auto"
@@ -116,8 +120,23 @@ terminal:
# timeout: 180
# lifetime_seconds: 300
# modal_image: "nikolaik/python-nodejs:python3.11-nodejs20"
# -----------------------------------------------------------------------------
# OPTION 6: Daytona cloud execution
# Commands run in Daytona cloud sandboxes
# Great for: Cloud dev environments, persistent workspaces, team collaboration
# Requires: pip install daytona, DAYTONA_API_KEY env var
# -----------------------------------------------------------------------------
# terminal:
# backend: "daytona"
# cwd: "~"
# timeout: 180
# lifetime_seconds: 300
# daytona_image: "nikolaik/python-nodejs:python3.11-nodejs20"
# container_disk: 10240 # Daytona max is 10GB per sandbox
#
# --- Container resource limits (docker, singularity, modal -- ignored for local/ssh) ---
# --- Container resource limits (docker, singularity, modal, daytona -- ignored for local/ssh) ---
# These settings apply to all container backends. They control the resources
# allocated to the sandbox and whether its filesystem persists across sessions.
container_cpu: 1 # CPU cores
@@ -180,8 +199,58 @@ compression:
threshold: 0.85
# Model to use for generating summaries (fast/cheap recommended)
# This model compresses the middle turns into a concise summary
# This model compresses the middle turns into a concise summary.
# IMPORTANT: it receives the full middle section of the conversation, so it
# MUST support a context length at least as large as your main model's.
summary_model: "google/gemini-3-flash-preview"
# Provider for the summary model (default: "auto")
# Options: "auto", "openrouter", "nous", "main"
# summary_provider: "auto"
# =============================================================================
# Auxiliary Models (Advanced — Experimental)
# =============================================================================
# Hermes uses lightweight "auxiliary" models for side tasks: image analysis,
# browser screenshot analysis, web page summarization, and context compression.
#
# By default these use Gemini Flash via OpenRouter or Nous Portal and are
# auto-detected from your credentials. You do NOT need to change anything
# here for normal usage.
#
# WARNING: Overriding these with providers other than OpenRouter or Nous Portal
# is EXPERIMENTAL and may not work. Not all models/providers support vision,
# produce usable summaries, or accept the same API format. Change at your own
# risk — if things break, reset to "auto" / empty values.
#
# Each task has its own provider + model pair so you can mix providers.
# For example: OpenRouter for vision (needs multimodal), but your main
# local endpoint for compression (just needs text).
#
# Provider options:
# "auto" - Best available: OpenRouter → Nous Portal → main endpoint (default)
# "openrouter" - Force OpenRouter (requires OPENROUTER_API_KEY)
# "nous" - Force Nous Portal (requires: hermes login)
# "main" - Use the same provider & credentials as your main chat model.
# Skips OpenRouter/Nous and uses your custom endpoint
# (OPENAI_BASE_URL), Codex OAuth, or API-key provider directly.
# Useful if you run a local model and want auxiliary tasks to
# use it too.
#
# Model: leave empty to use the provider's default. When empty, OpenRouter
# uses "google/gemini-3-flash-preview" and Nous uses "gemini-3-flash".
# Other providers pick a sensible default automatically.
#
# auxiliary:
# # Image analysis: vision_analyze tool + browser screenshots
# vision:
# provider: "auto"
# model: "" # e.g. "google/gemini-2.5-flash", "openai/gpt-4o"
#
# # Web page scraping / summarization + browser page text extraction
# web_extract:
# provider: "auto"
# model: ""
# =============================================================================
# Persistent Memory

320
cli.py
View File

@@ -14,6 +14,7 @@ Usage:
import logging
import os
import shutil
import sys
import json
import atexit
@@ -157,6 +158,7 @@ def load_cli_config() -> Dict[str, Any]:
"docker_image": "python:3.11",
"singularity_image": "docker://python:3.11",
"modal_image": "python:3.11",
"daytona_image": "nikolaik/python-nodejs:python3.11-nodejs20",
},
"browser": {
"inactivity_timeout": 120, # Auto-cleanup inactive browser sessions after 2 min
@@ -167,7 +169,7 @@ def load_cli_config() -> Dict[str, Any]:
"summary_model": "google/gemini-3-flash-preview", # Fast/cheap model for summaries
},
"agent": {
"max_turns": 60, # Default max tool-calling iterations
"max_turns": 90, # Default max tool-calling iterations (shared with subagents)
"verbose": False,
"system_prompt": "",
"prefill_messages_file": "",
@@ -283,12 +285,13 @@ def load_cli_config() -> Dict[str, Any]:
"docker_image": "TERMINAL_DOCKER_IMAGE",
"singularity_image": "TERMINAL_SINGULARITY_IMAGE",
"modal_image": "TERMINAL_MODAL_IMAGE",
"daytona_image": "TERMINAL_DAYTONA_IMAGE",
# SSH config
"ssh_host": "TERMINAL_SSH_HOST",
"ssh_user": "TERMINAL_SSH_USER",
"ssh_port": "TERMINAL_SSH_PORT",
"ssh_key": "TERMINAL_SSH_KEY",
# Container resource config (docker, singularity, modal -- ignored for local/ssh)
# Container resource config (docker, singularity, modal, daytona -- ignored for local/ssh)
"container_cpu": "TERMINAL_CONTAINER_CPU",
"container_memory": "TERMINAL_CONTAINER_MEMORY",
"container_disk": "TERMINAL_CONTAINER_DISK",
@@ -329,12 +332,36 @@ def load_cli_config() -> Dict[str, Any]:
"enabled": "CONTEXT_COMPRESSION_ENABLED",
"threshold": "CONTEXT_COMPRESSION_THRESHOLD",
"summary_model": "CONTEXT_COMPRESSION_MODEL",
"summary_provider": "CONTEXT_COMPRESSION_PROVIDER",
}
for config_key, env_var in compression_env_mappings.items():
if config_key in compression_config:
os.environ[env_var] = str(compression_config[config_key])
# Apply auxiliary model overrides to environment variables.
# Vision and web_extract each have their own provider + model pair.
# (Compression is handled in the compression section above.)
# Only set env vars for non-empty / non-default values so auto-detection
# still works.
auxiliary_config = defaults.get("auxiliary", {})
auxiliary_task_env = {
# config key → (provider env var, model env var)
"vision": ("AUXILIARY_VISION_PROVIDER", "AUXILIARY_VISION_MODEL"),
"web_extract": ("AUXILIARY_WEB_EXTRACT_PROVIDER", "AUXILIARY_WEB_EXTRACT_MODEL"),
}
for task_key, (prov_env, model_env) in auxiliary_task_env.items():
task_cfg = auxiliary_config.get(task_key, {})
if not isinstance(task_cfg, dict):
continue
prov = str(task_cfg.get("provider", "")).strip()
model = str(task_cfg.get("model", "")).strip()
if prov and prov != "auto":
os.environ[prov_env] = prov
if model:
os.environ[model_env] = model
return defaults
# Load configuration at module startup
@@ -507,7 +534,18 @@ def _get_available_skills() -> Dict[str, List[str]]:
return skills_by_category
def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dict] = None, enabled_toolsets: List[str] = None, session_id: str = None):
def _format_context_length(tokens: int) -> str:
"""Format a token count for display (e.g. 128000 → '128K', 1048576 → '1M')."""
if tokens >= 1_000_000:
val = tokens / 1_000_000
return f"{val:g}M"
elif tokens >= 1_000:
val = tokens / 1_000
return f"{val:g}K"
return str(tokens)
def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dict] = None, enabled_toolsets: List[str] = None, session_id: str = None, context_length: int = None):
"""
Build and print a Claude Code-style welcome banner with caduceus on left and info on right.
@@ -518,6 +556,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dic
tools: List of tool definitions
enabled_toolsets: List of enabled toolset names
session_id: Unique session identifier for logging
context_length: Model's context window size in tokens
"""
from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS
@@ -543,7 +582,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str, tools: List[dic
if len(model_short) > 28:
model_short = model_short[:25] + "..."
left_lines.append(f"[#FFBF00]{model_short}[/] [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]")
ctx_str = f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else ""
left_lines.append(f"[#FFBF00]{model_short}[/]{ctx_str} [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]")
left_lines.append(f"[dim #B8860B]{cwd}[/]")
# Add session ID if provided
@@ -690,6 +730,7 @@ COMMANDS = {
"/cron": "Manage scheduled tasks (list, add, remove)",
"/skills": "Search, install, inspect, or manage skills from online registries",
"/platforms": "Show gateway/messaging platform status",
"/paste": "Check clipboard for an image and attach it",
"/reload-mcp": "Reload MCP servers from config.yaml",
"/quit": "Exit the CLI (also: /exit, /q)",
}
@@ -816,10 +857,10 @@ class HermesCLI:
Args:
model: Model to use (default: from env or claude-sonnet)
toolsets: List of toolsets to enable (default: all)
provider: Inference provider ("auto", "openrouter", "nous", "openai-codex")
provider: Inference provider ("auto", "openrouter", "nous", "openai-codex", "zai", "kimi-coding", "minimax", "minimax-cn")
api_key: API key (default: from environment)
base_url: API base URL (default: OpenRouter)
max_turns: Maximum tool-calling iterations (default: 60)
max_turns: Maximum tool-calling iterations shared with subagents (default: 90)
verbose: Enable verbose logging
compact: Use compact display mode
resume: Session ID to resume (restores conversation history from SQLite)
@@ -853,7 +894,13 @@ class HermesCLI:
or os.getenv("OPENAI_BASE_URL")
or os.getenv("OPENROUTER_BASE_URL", CLI_CONFIG["model"]["base_url"])
)
self.api_key = api_key or os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY")
# Match key to resolved base_url: OpenRouter URL → prefer OPENROUTER_API_KEY,
# custom endpoint → prefer OPENAI_API_KEY (issue #560).
# Note: _ensure_runtime_credentials() re-resolves this before first use.
if "openrouter.ai" in self.base_url:
self.api_key = api_key or os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY")
else:
self.api_key = api_key or os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY")
self._nous_key_expires_at: Optional[str] = None
self._nous_key_source: Optional[str] = None
# Max turns priority: CLI arg > config file > env var > default
@@ -866,7 +913,7 @@ class HermesCLI:
elif os.getenv("HERMES_MAX_ITERATIONS"):
self.max_turns = int(os.getenv("HERMES_MAX_ITERATIONS"))
else:
self.max_turns = 60
self.max_turns = 90
# Parse and validate toolsets
self.enabled_toolsets = toolsets
@@ -1078,6 +1125,11 @@ class HermesCLI:
# Get terminal working directory (where commands will execute)
cwd = os.getenv("TERMINAL_CWD", os.getcwd())
# Get context length for display
ctx_len = None
if hasattr(self, 'agent') and self.agent and hasattr(self.agent, 'context_compressor'):
ctx_len = self.agent.context_compressor.context_length
# Build and display the banner
build_welcome_banner(
console=self.console,
@@ -1086,6 +1138,7 @@ class HermesCLI:
tools=tools,
enabled_toolsets=self.enabled_toolsets,
session_id=self.session_id,
context_length=ctx_len,
)
# Show tool availability warnings if any tools are disabled
@@ -1093,6 +1146,69 @@ class HermesCLI:
self.console.print()
def _try_attach_clipboard_image(self) -> bool:
"""Check clipboard for an image and attach it if found.
Saves the image to ~/.hermes/images/ and appends the path to
``_attached_images``. Returns True if an image was attached.
"""
from hermes_cli.clipboard import save_clipboard_image
img_dir = Path.home() / ".hermes" / "images"
self._image_counter += 1
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
img_path = img_dir / f"clip_{ts}_{self._image_counter}.png"
if save_clipboard_image(img_path):
self._attached_images.append(img_path)
return True
self._image_counter -= 1
return False
def _handle_paste_command(self):
"""Handle /paste — explicitly check clipboard for an image.
This is the reliable fallback for terminals where BracketedPaste
doesn't fire for image-only clipboard content (e.g., VSCode terminal,
Windows Terminal with WSL2).
"""
from hermes_cli.clipboard import has_clipboard_image
if has_clipboard_image():
if self._try_attach_clipboard_image():
n = len(self._attached_images)
_cprint(f" 📎 Image #{n} attached from clipboard")
else:
_cprint(f" {_DIM}(>_<) Clipboard has an image but extraction failed{_RST}")
else:
_cprint(f" {_DIM}(._.) No image found in clipboard{_RST}")
def _build_multimodal_content(self, text: str, images: list) -> list:
"""Convert text + image paths into OpenAI vision multimodal content.
Returns a list of content parts suitable for the ``content`` field
of a ``user`` message.
"""
import base64 as _b64
content_parts = []
text_part = text if isinstance(text, str) and text else "What do you see in this image?"
content_parts.append({"type": "text", "text": text_part})
_MIME = {
"png": "image/png", "jpg": "image/jpeg", "jpeg": "image/jpeg",
"gif": "image/gif", "webp": "image/webp",
}
for img_path in images:
if img_path.exists():
data = _b64.b64encode(img_path.read_bytes()).decode()
ext = img_path.suffix.lower().lstrip(".")
mime = _MIME.get(ext, "image/png")
content_parts.append({
"type": "image_url",
"image_url": {"url": f"data:{mime};base64,{data}"}
})
return content_parts
def _show_tool_availability_warnings(self):
"""Show warnings about disabled tools due to missing API keys."""
try:
@@ -1162,7 +1278,8 @@ class HermesCLI:
_cprint(f" {_GOLD}{cmd:<22}{_RST} {_DIM}-{_RST} {info['description']}")
_cprint(f"\n {_DIM}Tip: Just type your message to chat with Hermes!{_RST}")
_cprint(f" {_DIM}Multi-line: Alt+Enter for a new line{_RST}\n")
_cprint(f" {_DIM}Multi-line: Alt+Enter for a new line{_RST}")
_cprint(f" {_DIM}Paste image: Alt+V (or /paste){_RST}\n")
def show_tools(self):
"""Display available tools with kawaii ASCII art."""
@@ -1771,6 +1888,10 @@ class HermesCLI:
self._manual_compress()
elif cmd_lower == "/usage":
self._show_usage()
elif cmd_lower.startswith("/insights"):
self._show_insights(cmd_original)
elif cmd_lower == "/paste":
self._handle_paste_command()
elif cmd_lower == "/reload-mcp":
self._reload_mcp()
else:
@@ -1894,6 +2015,39 @@ class HermesCLI:
for quiet_logger in ('tools', 'minisweagent', 'run_agent', 'trajectory_compressor', 'cron', 'hermes_cli'):
logging.getLogger(quiet_logger).setLevel(logging.ERROR)
def _show_insights(self, command: str = "/insights"):
"""Show usage insights and analytics from session history."""
# Parse optional --days flag
parts = command.split()
days = 30
source = None
i = 1
while i < len(parts):
if parts[i] == "--days" and i + 1 < len(parts):
try:
days = int(parts[i + 1])
except ValueError:
print(f" Invalid --days value: {parts[i + 1]}")
return
i += 2
elif parts[i] == "--source" and i + 1 < len(parts):
source = parts[i + 1]
i += 2
else:
i += 1
try:
from hermes_state import SessionDB
from agent.insights import InsightsEngine
db = SessionDB()
engine = InsightsEngine(db)
report = engine.generate(days=days, source=source)
print(engine.format_terminal(report))
db.close()
except Exception as e:
print(f" Error generating insights: {e}")
def _reload_mcp(self):
"""Reload MCP servers: disconnect all, re-read config.yaml, reconnect.
@@ -2115,20 +2269,21 @@ class HermesCLI:
self._approval_state = None
self._approval_deadline = 0
self._invalidate()
_cprint(f"\n{_DIM} ⏱ Timeout — denying command{_RST}")
return "deny"
def chat(self, message: str) -> Optional[str]:
def chat(self, message, images: list = None) -> Optional[str]:
"""
Send a message to the agent and get a response.
Handles streaming output, interrupt detection (user typing while agent
is working), and re-queueing of interrupted messages.
Uses a dedicated _interrupt_queue (separate from _pending_input) to avoid
race conditions between the process_loop and interrupt monitoring. Messages
typed while the agent is running go to _interrupt_queue; messages typed while
idle go to _pending_input.
Args:
message: The user's message
message: The user's message (str or multimodal content list)
images: Optional list of Path objects for attached images
Returns:
The agent's response, or None on error
@@ -2141,10 +2296,19 @@ class HermesCLI:
if not self._init_agent():
return None
# Convert attached images to OpenAI vision multimodal content
if images:
message = self._build_multimodal_content(
message if isinstance(message, str) else "", images
)
for img_path in images:
if img_path.exists():
_cprint(f" {_DIM}📎 attached {img_path.name} ({img_path.stat().st_size // 1024}KB){_RST}")
# Add user message to history
self.conversation_history.append({"role": "user", "content": message})
w = self.console.width
w = shutil.get_terminal_size().columns
_cprint(f"{_GOLD}{'' * w}{_RST}")
print(flush=True)
@@ -2220,7 +2384,7 @@ class HermesCLI:
response = response + "\n\n---\n_[Interrupted - processing new message]_"
if response:
w = self.console.width
w = shutil.get_terminal_size().columns
label = " ⚕ Hermes "
fill = w - 2 - len(label) # 2 for ╭ and ╮
top = f"{_GOLD}╭─{label}{'' * max(fill - 1, 0)}{_RST}"
@@ -2305,6 +2469,10 @@ class HermesCLI:
self._approval_state = None # dict with command, description, choices, selected, response_queue
self._approval_deadline = 0
# Clipboard image attachments (paste images into the CLI)
self._attached_images: list[Path] = []
self._image_counter = 0
# Register callbacks so terminal_tool prompts route through our UI
set_sudo_password_callback(self._sudo_password_callback)
set_approval_callback(self._approval_callback)
@@ -2374,11 +2542,18 @@ class HermesCLI:
# --- Normal input routing ---
text = event.app.current_buffer.text.strip()
if text:
if self._agent_running and not text.startswith("/"):
self._interrupt_queue.put(text)
has_images = bool(self._attached_images)
if text or has_images:
# Snapshot and clear attached images
images = list(self._attached_images)
self._attached_images.clear()
event.app.invalidate()
# Bundle text + images as a tuple when images are present
payload = (text, images) if images else text
if self._agent_running and not (text and text.startswith("/")):
self._interrupt_queue.put(payload)
else:
self._pending_input.put(text)
self._pending_input.put(payload)
event.app.current_buffer.reset(append_to_history=True)
@kb.add('escape', 'enter')
@@ -2491,10 +2666,12 @@ class HermesCLI:
print("\n⚡ Interrupting agent... (press Ctrl+C again to force exit)")
self.agent.interrupt()
else:
# If there's text in the input buffer, clear it (like bash).
# If the buffer is already empty, exit.
if event.app.current_buffer.text:
# If there's text or images, clear them (like bash).
# If everything is already empty, exit.
if event.app.current_buffer.text or self._attached_images:
event.app.current_buffer.reset()
self._attached_images.clear()
event.app.invalidate()
else:
self._should_exit = True
event.app.exit()
@@ -2504,7 +2681,53 @@ class HermesCLI:
"""Handle Ctrl+D - exit."""
self._should_exit = True
event.app.exit()
from prompt_toolkit.keys import Keys
@kb.add(Keys.BracketedPaste, eager=True)
def handle_paste(event):
"""Handle terminal paste — detect clipboard images.
When the terminal supports bracketed paste, Ctrl+V / Cmd+V
triggers this with the pasted text. We also check the
clipboard for an image on every paste event.
"""
pasted_text = event.data or ""
if self._try_attach_clipboard_image():
event.app.invalidate()
if pasted_text:
event.current_buffer.insert_text(pasted_text)
@kb.add('c-v')
def handle_ctrl_v(event):
"""Fallback image paste for terminals without bracketed paste.
On Linux terminals (GNOME Terminal, Konsole, etc.), Ctrl+V
sends raw byte 0x16 instead of triggering a paste. This
binding catches that and checks the clipboard for images.
On terminals that DO intercept Ctrl+V for paste (macOS
Terminal, iTerm2, VSCode, Windows Terminal), the bracketed
paste handler fires instead and this binding never triggers.
"""
if self._try_attach_clipboard_image():
event.app.invalidate()
@kb.add('escape', 'v')
def handle_alt_v(event):
"""Alt+V — paste image from clipboard.
Alt key combos pass through all terminal emulators (sent as
ESC + key), unlike Ctrl+V which terminals intercept for text
paste. This is the reliable way to attach clipboard images
on WSL2, VSCode, and any terminal over SSH where Ctrl+V
can't reach the application for image-only clipboard.
"""
if self._try_attach_clipboard_image():
event.app.invalidate()
else:
# No image found — show a hint
pass # silent when no image (avoid noise on accidental press)
# Dynamic prompt: shows Hermes symbol when agent is working,
# or answer prompt when clarify freetext mode is active.
cli_ref = self
@@ -2540,7 +2763,7 @@ class HermesCLI:
def _input_height():
try:
doc = input_area.buffer.document
available_width = (cli_ref.console.width or 80) - 4 # subtract prompt width
available_width = shutil.get_terminal_size().columns - 4 # subtract prompt width
if available_width < 10:
available_width = 40
visual_lines = 0
@@ -2801,13 +3024,35 @@ class HermesCLI:
# Horizontal rules above and below the input (bronze, 1 line each).
# The bottom rule moves down as the TextArea grows with newlines.
# Using char='─' instead of hardcoded repetition so the rule
# always spans the full terminal width on any screen size.
input_rule_top = Window(
content=FormattedTextControl([('class:input-rule', '' * 200)]),
char='',
height=1,
style='class:input-rule',
)
input_rule_bot = Window(
content=FormattedTextControl([('class:input-rule', '' * 200)]),
char='',
height=1,
style='class:input-rule',
)
# Image attachment indicator — shows badges like [📎 Image #1] above input
cli_ref = self
def _get_image_bar():
if not cli_ref._attached_images:
return []
base = cli_ref._image_counter - len(cli_ref._attached_images) + 1
badges = " ".join(
f"[📎 Image #{base + i}]"
for i in range(len(cli_ref._attached_images))
)
return [("class:image-badge", f" {badges} ")]
image_bar = Window(
content=FormattedTextControl(_get_image_bar),
height=Condition(lambda: bool(cli_ref._attached_images)),
)
# Layout: interactive prompt widgets + ruled input at bottom.
@@ -2821,6 +3066,7 @@ class HermesCLI:
clarify_widget,
spacer,
input_rule_top,
image_bar,
input_area,
input_rule_bot,
CompletionsMenu(max_height=12, scroll_offset=1),
@@ -2836,6 +3082,8 @@ class HermesCLI:
'hint': '#555555 italic',
# Bronze horizontal rules around the input area
'input-rule': '#CD7F32',
# Clipboard image attachment badges
'image-badge': '#87CEEB bold',
'completion-menu': 'bg:#1a1a2e #FFF8DC',
'completion-menu.completion': 'bg:#1a1a2e #FFF8DC',
'completion-menu.completion.current': 'bg:#333355 #FFD700',
@@ -2885,9 +3133,14 @@ class HermesCLI:
if not user_input:
continue
# Unpack image payload: (text, [Path, ...]) or plain str
submit_images = []
if isinstance(user_input, tuple):
user_input, submit_images = user_input
# Check for commands
if user_input.startswith("/"):
if isinstance(user_input, str) and user_input.startswith("/"):
print(f"\n⚙️ {user_input}")
if not self.process_command(user_input):
self._should_exit = True
@@ -2898,7 +3151,7 @@ class HermesCLI:
# Expand paste references back to full content
import re as _re
paste_match = _re.match(r'\[Pasted text #\d+: \d+ lines → (.+)\]', user_input)
paste_match = _re.match(r'\[Pasted text #\d+: \d+ lines → (.+)\]', user_input) if isinstance(user_input, str) else None
if paste_match:
paste_path = Path(paste_match.group(1))
if paste_path.exists():
@@ -2920,12 +3173,17 @@ class HermesCLI:
print()
_cprint(f"{_GOLD}{_RST} {_BOLD}{user_input}{_RST}")
# Show image attachment count
if submit_images:
n = len(submit_images)
_cprint(f" {_DIM}📎 {n} image{'s' if n > 1 else ''} attached{_RST}")
# Regular chat - run agent
self._agent_running = True
app.invalidate() # Refresh status line
try:
self.chat(user_input)
self.chat(user_input, images=submit_images or None)
finally:
self._agent_running = False
app.invalidate() # Refresh status line
@@ -2995,7 +3253,7 @@ def main(
q: Shorthand for --query
toolsets: Comma-separated list of toolsets to enable (e.g., "web,terminal")
model: Model to use (default: anthropic/claude-opus-4-20250514)
provider: Inference provider ("auto", "openrouter", "nous")
provider: Inference provider ("auto", "openrouter", "nous", "openai-codex", "zai", "kimi-coding", "minimax", "minimax-cn")
api_key: API key for authentication
base_url: Base URL for the API
max_turns: Maximum tool-calling iterations (default: 60)

View File

@@ -14,6 +14,8 @@ from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional, Dict, List, Any
from hermes_time import now as _hermes_now
try:
from croniter import croniter
HAS_CRONITER = True
@@ -128,7 +130,7 @@ def parse_schedule(schedule: str) -> Dict[str, Any]:
# Duration like "30m", "2h", "1d" → one-shot from now
try:
minutes = parse_duration(schedule)
run_at = datetime.now() + timedelta(minutes=minutes)
run_at = _hermes_now() + timedelta(minutes=minutes)
return {
"kind": "once",
"run_at": run_at.isoformat(),
@@ -146,37 +148,50 @@ def parse_schedule(schedule: str) -> Dict[str, Any]:
)
def _ensure_aware(dt: datetime) -> datetime:
"""Make a naive datetime tz-aware using the configured timezone.
Handles backward compatibility: timestamps stored before timezone support
are naive (server-local). We assume they were in the same timezone as
the current configuration so comparisons work without crashing.
"""
if dt.tzinfo is None:
tz = _hermes_now().tzinfo
return dt.replace(tzinfo=tz)
return dt
def compute_next_run(schedule: Dict[str, Any], last_run_at: Optional[str] = None) -> Optional[str]:
"""
Compute the next run time for a schedule.
Returns ISO timestamp string, or None if no more runs.
"""
now = datetime.now()
now = _hermes_now()
if schedule["kind"] == "once":
run_at = datetime.fromisoformat(schedule["run_at"])
run_at = _ensure_aware(datetime.fromisoformat(schedule["run_at"]))
# If in the future, return it; if in the past, no more runs
return schedule["run_at"] if run_at > now else None
elif schedule["kind"] == "interval":
minutes = schedule["minutes"]
if last_run_at:
# Next run is last_run + interval
last = datetime.fromisoformat(last_run_at)
last = _ensure_aware(datetime.fromisoformat(last_run_at))
next_run = last + timedelta(minutes=minutes)
else:
# First run is now + interval
next_run = now + timedelta(minutes=minutes)
return next_run.isoformat()
elif schedule["kind"] == "cron":
if not HAS_CRONITER:
return None
cron = croniter(schedule["expr"], now)
next_run = cron.get_next(datetime)
return next_run.isoformat()
return None
@@ -204,7 +219,7 @@ def save_jobs(jobs: List[Dict[str, Any]]):
fd, tmp_path = tempfile.mkstemp(dir=str(JOBS_FILE.parent), suffix='.tmp', prefix='.jobs_')
try:
with os.fdopen(fd, 'w', encoding='utf-8') as f:
json.dump({"jobs": jobs, "updated_at": datetime.now().isoformat()}, f, indent=2)
json.dump({"jobs": jobs, "updated_at": _hermes_now().isoformat()}, f, indent=2)
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, JOBS_FILE)
@@ -249,7 +264,7 @@ def create_job(
deliver = "origin" if origin else "local"
job_id = uuid.uuid4().hex[:12]
now = datetime.now().isoformat()
now = _hermes_now().isoformat()
job = {
"id": job_id,
@@ -328,7 +343,7 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
jobs = load_jobs()
for i, job in enumerate(jobs):
if job["id"] == job_id:
now = datetime.now().isoformat()
now = _hermes_now().isoformat()
job["last_run_at"] = now
job["last_status"] = "ok" if success else "error"
job["last_error"] = error if not success else None
@@ -361,7 +376,7 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
def get_due_jobs() -> List[Dict[str, Any]]:
"""Get all jobs that are due to run now."""
now = datetime.now()
now = _hermes_now()
jobs = load_jobs()
due = []
@@ -373,7 +388,7 @@ def get_due_jobs() -> List[Dict[str, Any]]:
if not next_run:
continue
next_run_dt = datetime.fromisoformat(next_run)
next_run_dt = _ensure_aware(datetime.fromisoformat(next_run))
if next_run_dt <= now:
due.append(job)
@@ -386,7 +401,7 @@ def save_job_output(job_id: str, output: str):
job_output_dir = OUTPUT_DIR / job_id
job_output_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
timestamp = _hermes_now().strftime("%Y-%m-%d_%H-%M-%S")
output_file = job_output_dir / f"{timestamp}.md"
with open(output_file, 'w', encoding='utf-8') as f:

View File

@@ -27,6 +27,8 @@ from datetime import datetime
from pathlib import Path
from typing import Optional
from hermes_time import now as _hermes_now
logger = logging.getLogger(__name__)
# Add parent directory to path for imports
@@ -207,7 +209,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
provider=runtime.get("provider"),
api_mode=runtime.get("api_mode"),
quiet_mode=True,
session_id=f"cron_{job_id}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
session_id=f"cron_{job_id}_{_hermes_now().strftime('%Y%m%d_%H%M%S')}"
)
result = agent.run_conversation(prompt)
@@ -219,7 +221,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
output = f"""# Cron Job: {job_name}
**Job ID:** {job_id}
**Run Time:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
**Schedule:** {job.get('schedule_display', 'N/A')}
## Prompt
@@ -241,7 +243,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
output = f"""# Cron Job: {job_name} (FAILED)
**Job ID:** {job_id}
**Run Time:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
**Schedule:** {job.get('schedule_display', 'N/A')}
## Prompt
@@ -280,6 +282,7 @@ def tick(verbose: bool = True) -> int:
_LOCK_DIR.mkdir(parents=True, exist_ok=True)
# Cross-platform file locking: fcntl on Unix, msvcrt on Windows
lock_fd = None
try:
lock_fd = open(_LOCK_FILE, "w")
if fcntl:
@@ -288,17 +291,19 @@ def tick(verbose: bool = True) -> int:
msvcrt.locking(lock_fd.fileno(), msvcrt.LK_NBLCK, 1)
except (OSError, IOError):
logger.debug("Tick skipped — another instance holds the lock")
if lock_fd is not None:
lock_fd.close()
return 0
try:
due_jobs = get_due_jobs()
if verbose and not due_jobs:
logger.info("%s - No jobs due", datetime.now().strftime('%H:%M:%S'))
logger.info("%s - No jobs due", _hermes_now().strftime('%H:%M:%S'))
return 0
if verbose:
logger.info("%s - %s job(s) due", datetime.now().strftime('%H:%M:%S'), len(due_jobs))
logger.info("%s - %s job(s) due", _hermes_now().strftime('%H:%M:%S'), len(due_jobs))
executed = 0
for job in due_jobs:

View File

@@ -0,0 +1,344 @@
# send_file Integration Map — Hermes Agent Codebase Deep Dive
## 1. environments/tool_context.py — Base64 File Transfer Implementation
### upload_file() (lines 153-205)
- Reads local file as raw bytes, base64-encodes to ASCII string
- Creates parent dirs in sandbox via `self.terminal(f"mkdir -p {parent}")`
- **Chunk size:** 60,000 chars (~60KB per shell command)
- **Small files (<=60KB b64):** Single `printf '%s' '{b64}' | base64 -d > {remote_path}`
- **Large files:** Writes chunks to `/tmp/_hermes_upload.b64` via `printf >> append`, then `base64 -d` to target
- **Error handling:** Checks local file exists; returns `{exit_code, output}`
- **Size limits:** No explicit limit, but shell arg limit ~2MB means chunking is necessary for files >~45KB raw
- **No theoretical max** — but very large files would be slow (many terminal round trips)
### download_file() (lines 234-278)
- Runs `base64 {remote_path}` inside sandbox, captures stdout
- Strips output, base64-decodes to raw bytes
- Writes to host filesystem with parent dir creation
- **Error handling:** Checks exit code, empty output, decode errors
- Returns `{success: bool, bytes: int}` or `{success: false, error: str}`
- **Size limit:** Bounded by terminal output buffer (practical limit ~few MB via base64 terminal output)
### Promotion potential:
- These methods work via `self.terminal()` — they're environment-agnostic
- Could be directly lifted into a new tool that operates on the agent's current sandbox
- For send_file, this `download_file()` pattern is the key: it extracts files from sandbox → host
## 2. tools/environments/base.py — BaseEnvironment Interface
### Current methods:
- `execute(command, cwd, timeout, stdin_data)``{output, returncode}`
- `cleanup()` — release resources
- `stop()` — alias for cleanup
- `_prepare_command()` — sudo transformation
- `_build_run_kwargs()` — subprocess kwargs
- `_timeout_result()` — standard timeout dict
### What would need to be added for file transfer:
- **Nothing required at this level.** File transfer can be implemented via `execute()` (base64 over terminal, like ToolContext does) or via environment-specific methods.
- Optional: `upload_file(local_path, remote_path)` and `download_file(remote_path, local_path)` methods could be added to BaseEnvironment for optimized per-backend transfers, but the base64-over-terminal approach already works universally.
## 3. tools/environments/docker.py — Docker Container Details
### Container ID tracking:
- `self._container_id` stored at init from `self._inner.container_id`
- Inner is `minisweagent.environments.docker.DockerEnvironment`
- Container ID is a standard Docker container hash
### docker cp feasibility:
- **YES**, `docker cp` could be used for optimized file transfer:
- `docker cp {container_id}:{remote_path} {local_path}` (download)
- `docker cp {local_path} {container_id}:{remote_path}` (upload)
- Much faster than base64-over-terminal for large files
- Container ID is directly accessible via `env._container_id` or `env._inner.container_id`
### Volumes mounted:
- **Persistent mode:** Bind mounts at `~/.hermes/sandboxes/docker/{task_id}/workspace``/workspace` and `.../home``/root`
- **Ephemeral mode:** tmpfs at `/workspace` (10GB), `/home` (1GB), `/root` (1GB)
- **User volumes:** From `config.yaml docker_volumes` (arbitrary `-v` mounts)
- **Security tmpfs:** `/tmp` (512MB), `/var/tmp` (256MB), `/run` (64MB)
### Direct host access for persistent mode:
- If persistent, files at `/workspace/foo.txt` are just `~/.hermes/sandboxes/docker/{task_id}/workspace/foo.txt` on host — no transfer needed!
## 4. tools/environments/ssh.py — SSH Connection Management
### Connection management:
- Uses SSH ControlMaster for persistent connection
- Control socket at `/tmp/hermes-ssh/{user}@{host}:{port}.sock`
- ControlPersist=300 (5 min keepalive)
- BatchMode=yes (non-interactive)
- Stores: `self.host`, `self.user`, `self.port`, `self.key_path`
### SCP/SFTP feasibility:
- **YES**, SCP can piggyback on the ControlMaster socket:
- `scp -o ControlPath={socket} {user}@{host}:{remote} {local}` (download)
- `scp -o ControlPath={socket} {local} {user}@{host}:{remote}` (upload)
- Same SSH key and connection reuse — zero additional auth
- Would be much faster than base64-over-terminal for large files
## 5. tools/environments/modal.py — Modal Sandbox Filesystem
### Filesystem API exposure:
- **Not directly.** The inner `SwerexModalEnvironment` wraps Modal's sandbox
- The sandbox object is accessible at: `env._inner.deployment._sandbox`
- Modal's Python SDK exposes `sandbox.open()` for file I/O — but only via async API
- Currently only used for `snapshot_filesystem()` during cleanup
- **Could use:** `sandbox.open(path, "rb")` to read files or `sandbox.open(path, "wb")` to write
- **Alternative:** Base64-over-terminal already works via `execute()` — simpler, no SDK dependency
## 6. gateway/platforms/base.py — MEDIA: Tag Flow (Complete)
### extract_media() (lines 587-620):
- **Pattern:** `MEDIA:\S+` — extracts file paths after MEDIA: prefix
- **Voice flag:** `[[audio_as_voice]]` global directive sets `is_voice=True` for all media in message
- Returns `List[Tuple[str, bool]]` (path, is_voice) and cleaned content
### _process_message_background() media routing (lines 752-786):
- After extracting MEDIA tags, routes by file extension:
- `.ogg .opus .mp3 .wav .m4a``send_voice()`
- `.mp4 .mov .avi .mkv .3gp``send_video()`
- `.jpg .jpeg .png .webp .gif``send_image_file()`
- **Everything else** → `send_document()`
- This routing already supports arbitrary files!
### send_* method inventory (base class):
- `send(chat_id, content, reply_to, metadata)` — ABSTRACT, text
- `send_image(chat_id, image_url, caption, reply_to)` — URL-based images
- `send_animation(chat_id, animation_url, caption, reply_to)` — GIF animations
- `send_voice(chat_id, audio_path, caption, reply_to)` — voice messages
- `send_video(chat_id, video_path, caption, reply_to)` — video files
- `send_document(chat_id, file_path, caption, file_name, reply_to)` — generic files
- `send_image_file(chat_id, image_path, caption, reply_to)` — local image files
- `send_typing(chat_id)` — typing indicator
- `edit_message(chat_id, message_id, content)` — edit sent messages
### What's missing:
- **Telegram:** No override for `send_document` or `send_image_file` — falls back to text!
- **Discord:** No override for `send_document` — falls back to text!
- **WhatsApp:** Has `send_document` and `send_image_file` via bridge — COMPLETE.
- The base class defaults just send "📎 File: /path" as text — useless for actual file delivery.
## 7. gateway/platforms/telegram.py — Send Method Analysis
### Implemented send methods:
- `send()` — MarkdownV2 text with fallback to plain
- `send_voice()``.ogg`/`.opus` as `send_voice()`, others as `send_audio()`
- `send_image()` — URL-based via `send_photo()`
- `send_animation()` — GIF via `send_animation()`
- `send_typing()` — "typing" chat action
- `edit_message()` — edit text messages
### MISSING:
- **`send_document()` NOT overridden** — Need to add `self._bot.send_document(chat_id, document=open(file_path, 'rb'), ...)`
- **`send_image_file()` NOT overridden** — Need to add `self._bot.send_photo(chat_id, photo=open(path, 'rb'), ...)`
- **`send_video()` NOT overridden** — Need to add `self._bot.send_video(...)`
## 8. gateway/platforms/discord.py — Send Method Analysis
### Implemented send methods:
- `send()` — text messages with chunking
- `send_voice()` — discord.File attachment
- `send_image()` — downloads URL, creates discord.File attachment
- `send_typing()` — channel.typing()
- `edit_message()` — edit text messages
### MISSING:
- **`send_document()` NOT overridden** — Need to add discord.File attachment
- **`send_image_file()` NOT overridden** — Need to add discord.File from local path
- **`send_video()` NOT overridden** — Need to add discord.File attachment
## 9. gateway/run.py — User File Attachment Handling
### Current attachment flow:
1. **Telegram photos** (line 509-529): Download via `photo.get_file()``cache_image_from_bytes()` → vision auto-analysis
2. **Telegram voice** (line 532-541): Download → `cache_audio_from_bytes()` → STT transcription
3. **Telegram audio** (line 542-551): Same pattern
4. **Telegram documents** (line 553-617): Extension validation against `SUPPORTED_DOCUMENT_TYPES`, 20MB limit, content injection for text files
5. **Discord attachments** (line 717-751): Content-type detection, image/audio caching, URL fallback for other types
6. **Gateway run.py** (lines 818-883): Auto-analyzes images with vision, transcribes audio, enriches document messages with context notes
### Key insight: Files are always cached to host filesystem first, then processed. The agent sees local file paths.
## 10. tools/terminal_tool.py — Terminal Tool & Environment Interaction
### How it manages environments:
- Global dict `_active_environments: Dict[str, Any]` keyed by task_id
- Per-task creation locks prevent duplicate sandbox creation
- Auto-cleanup thread kills idle environments after `TERMINAL_LIFETIME_SECONDS`
- `_get_env_config()` reads all TERMINAL_* env vars for backend selection
- `_create_environment()` factory creates the right backend type
### Could send_file piggyback?
- **YES.** send_file needs access to the same environment to extract files from sandboxes.
- It can reuse `_active_environments[task_id]` to get the environment, then:
- Docker: Use `docker cp` via `env._container_id`
- SSH: Use `scp` via `env.control_socket`
- Local: Just read the file directly
- Modal: Use base64-over-terminal via `env.execute()`
- The file_tools.py module already does this with `ShellFileOperations` — read_file/write_file/search/patch all share the same env instance.
## 11. tools/tts_tool.py — Working Example of File Delivery
### Flow:
1. Generate audio file to `~/.hermes/audio_cache/tts_TIMESTAMP.{ogg,mp3}`
2. Return JSON with `media_tag: "MEDIA:/path/to/file"`
3. For Telegram voice: prepend `[[audio_as_voice]]` directive
4. The LLM includes the MEDIA tag in its response text
5. `BasePlatformAdapter._process_message_background()` calls `extract_media()` to find the tag
6. Routes by extension → `send_voice()` for audio files
7. Platform adapter sends the file natively
### Key pattern: Tool saves file to host → returns MEDIA: path → LLM echoes it → gateway extracts → platform delivers
## 12. tools/image_generation_tool.py — Working Example of Image Delivery
### Flow:
1. Call FAL.ai API → get image URL
2. Return JSON with `image: "https://fal.media/..."` URL
3. The LLM includes the URL in markdown: `![description](URL)`
4. `BasePlatformAdapter.extract_images()` finds `![alt](url)` patterns
5. Routes through `send_image()` (URL) or `send_animation()` (GIF)
6. Platform downloads and sends natively
### Key difference from TTS: Images are URL-based, not local files. The gateway downloads at send time.
---
# INTEGRATION MAP: Where send_file Hooks In
## Architecture Decision: MEDIA: Tag Protocol vs. New Tool
The MEDIA: tag protocol is already the established pattern for file delivery. Two options:
### Option A: Pure MEDIA: Tag (Minimal Change)
- No new tool needed
- Agent downloads file from sandbox to host using terminal (base64)
- Saves to known location (e.g., `~/.hermes/file_cache/`)
- Includes `MEDIA:/path` in response text
- Existing routing in `_process_message_background()` handles delivery
- **Problem:** Agent has to manually do base64 dance + know about MEDIA: convention
### Option B: Dedicated send_file Tool (Recommended)
- New tool that the agent calls with `(file_path, caption?)`
- Tool handles the sandbox → host extraction automatically
- Returns MEDIA: tag that gets routed through existing pipeline
- Much cleaner agent experience
## Implementation Plan for Option B
### Files to CREATE:
1. **`tools/send_file_tool.py`** — The new tool
- Accepts: `file_path` (path in sandbox), `caption` (optional)
- Detects environment backend from `_active_environments`
- Extracts file from sandbox:
- **local:** `shutil.copy()` or direct path
- **docker:** `docker cp {container_id}:{path} {local_cache}/`
- **ssh:** `scp -o ControlPath=... {user}@{host}:{path} {local_cache}/`
- **modal:** base64-over-terminal via `env.execute("base64 {path}")`
- Saves to `~/.hermes/file_cache/{uuid}_{filename}`
- Returns: `MEDIA:/cached/path` in response for gateway to pick up
- Register with `registry.register(name="send_file", toolset="file", ...)`
### Files to MODIFY:
2. **`gateway/platforms/telegram.py`** — Add missing send methods:
```python
async def send_document(self, chat_id, file_path, caption=None, file_name=None, reply_to=None):
with open(file_path, "rb") as f:
msg = await self._bot.send_document(
chat_id=int(chat_id), document=f,
caption=caption, filename=file_name or os.path.basename(file_path))
return SendResult(success=True, message_id=str(msg.message_id))
async def send_image_file(self, chat_id, image_path, caption=None, reply_to=None):
with open(image_path, "rb") as f:
msg = await self._bot.send_photo(chat_id=int(chat_id), photo=f, caption=caption)
return SendResult(success=True, message_id=str(msg.message_id))
async def send_video(self, chat_id, video_path, caption=None, reply_to=None):
with open(video_path, "rb") as f:
msg = await self._bot.send_video(chat_id=int(chat_id), video=f, caption=caption)
return SendResult(success=True, message_id=str(msg.message_id))
```
3. **`gateway/platforms/discord.py`** — Add missing send methods:
```python
async def send_document(self, chat_id, file_path, caption=None, file_name=None, reply_to=None):
channel = self._client.get_channel(int(chat_id)) or await self._client.fetch_channel(int(chat_id))
with open(file_path, "rb") as f:
file = discord.File(io.BytesIO(f.read()), filename=file_name or os.path.basename(file_path))
msg = await channel.send(content=caption, file=file)
return SendResult(success=True, message_id=str(msg.id))
async def send_image_file(self, chat_id, image_path, caption=None, reply_to=None):
# Same pattern as send_document with image filename
async def send_video(self, chat_id, video_path, caption=None, reply_to=None):
# Same pattern, discord renders video attachments inline
```
4. **`toolsets.py`** — Add `"send_file"` to `_HERMES_CORE_TOOLS` list
5. **`agent/prompt_builder.py`** — Update platform hints to mention send_file tool
### Code that can be REUSED (zero rewrite):
- `BasePlatformAdapter.extract_media()` — Already extracts MEDIA: tags
- `BasePlatformAdapter._process_message_background()` — Already routes by extension
- `ToolContext.download_file()` — Base64-over-terminal extraction pattern
- `tools/terminal_tool.py` _active_environments dict — Environment access
- `tools/registry.py` — Tool registration infrastructure
- `gateway/platforms/base.py` send_document/send_image_file/send_video signatures — Already defined
### Code that needs to be WRITTEN from scratch:
1. `tools/send_file_tool.py` (~150 lines):
- File extraction from each environment backend type
- Local file cache management
- Registry registration
2. Telegram `send_document` + `send_image_file` + `send_video` overrides (~40 lines)
3. Discord `send_document` + `send_image_file` + `send_video` overrides (~50 lines)
### Total effort: ~240 lines of new code, ~5 lines of config changes
## Key Environment-Specific Extract Strategies
| Backend | Extract Method | Speed | Complexity |
|------------|-------------------------------|----------|------------|
| local | shutil.copy / direct path | Instant | None |
| docker | `docker cp container:path .` | Fast | Low |
| docker+vol | Direct host path access | Instant | None |
| ssh | `scp -o ControlPath=...` | Fast | Low |
| modal | base64-over-terminal | Moderate | Medium |
| singularity| Direct path (overlay mount) | Fast | Low |
## Data Flow Summary
```
Agent calls send_file(file_path="/workspace/output.pdf", caption="Here's the report")
send_file_tool.py:
1. Get environment from _active_environments[task_id]
2. Detect backend type (docker/ssh/modal/local)
3. Extract file to ~/.hermes/file_cache/{uuid}_{filename}
4. Return: '{"success": true, "media_tag": "MEDIA:/home/user/.hermes/file_cache/abc123_output.pdf"}'
LLM includes MEDIA: tag in its response text
BasePlatformAdapter._process_message_background():
1. extract_media(response) → finds MEDIA:/path
2. Checks extension: .pdf → send_document()
3. Calls platform-specific send_document(chat_id, file_path, caption)
TelegramAdapter.send_document() / DiscordAdapter.send_document():
Opens file, sends via platform API as native document attachment
User receives downloadable file in chat
```

View File

@@ -40,7 +40,7 @@ This directory contains the integration layer between **hermes-agent's** tool-ca
- `evaluate_log()` for saving eval results to JSON + samples.jsonl
**HermesAgentBaseEnv** (`hermes_base_env.py`) extends BaseEnv with hermes-agent specifics:
- Sets `os.environ["TERMINAL_ENV"]` to configure the terminal backend (local, docker, modal, ssh, singularity)
- Sets `os.environ["TERMINAL_ENV"]` to configure the terminal backend (local, docker, modal, daytona, ssh, singularity)
- Resolves hermes-agent toolsets via `_resolve_tools_for_group()` (calls `get_tool_definitions()` which queries `tools/registry.py`)
- Implements `collect_trajectory()` which runs the full agent loop and computes rewards
- Supports two-phase operation (Phase 1: OpenAI server, Phase 2: VLLM ManagedServer)
@@ -195,8 +195,12 @@ environments/
│ └── hermes_swe_env.py
└── benchmarks/ # Evaluation benchmarks
── terminalbench_2/
└── terminalbench2_env.py
── terminalbench_2/ # 89 terminal tasks, Modal sandboxes
└── terminalbench2_env.py
├── tblite/ # 100 calibrated tasks (fast TB2 proxy)
│ └── tblite_env.py
└── yc_bench/ # Long-horizon strategic benchmark
└── yc_bench_env.py
```
## Concrete Environments
@@ -324,7 +328,7 @@ For eval benchmarks, follow the pattern in `terminalbench2_env.py`:
| `distribution` | Probabilistic toolset distribution name | `None` |
| `max_agent_turns` | Max LLM calls per rollout | `30` |
| `agent_temperature` | Sampling temperature | `1.0` |
| `terminal_backend` | `local`, `docker`, `modal`, `ssh`, `singularity` | `local` |
| `terminal_backend` | `local`, `docker`, `modal`, `daytona`, `ssh`, `singularity` | `local` |
| `system_prompt` | System message for the agent | `None` |
| `tool_call_parser` | Parser name for Phase 2 | `hermes` |
| `eval_handling` | `STOP_TRAIN`, `LIMIT_TRAIN`, `NONE` | `STOP_TRAIN` |

View File

@@ -23,7 +23,7 @@ from typing import Any, Dict, List, Optional, Set
from model_tools import handle_function_call
# Thread pool for running sync tool calls that internally use asyncio.run()
# (e.g., mini-swe-agent's modal/docker backends). Running them in a separate
# (e.g., mini-swe-agent's modal/docker/daytona backends). Running them in a separate
# thread gives them a clean event loop so they don't deadlock inside Atropos's loop.
# Size must be large enough for concurrent eval tasks (e.g., 89 TB2 tasks all
# making tool calls). Too small = thread pool starvation, tasks queue for minutes.
@@ -336,7 +336,7 @@ class HermesAgentLoop:
tool_elapsed = _time.monotonic() - tool_submit_time
else:
# Run tool calls in a thread pool so backends that
# use asyncio.run() internally (modal, docker) get
# use asyncio.run() internally (modal, docker, daytona) get
# a clean event loop instead of deadlocking.
loop = asyncio.get_event_loop()
# Capture current tool_name/args for the lambda

View File

@@ -0,0 +1,115 @@
# YC-Bench: Long-Horizon Agent Benchmark
[YC-Bench](https://github.com/collinear-ai/yc-bench) by [Collinear AI](https://collinear.ai/) is a deterministic, long-horizon benchmark that tests LLM agents' ability to act as a tech startup CEO. The agent manages a simulated company over 1-3 years, making compounding decisions about resource allocation, cash flow, task management, and prestige specialisation across 4 skill domains.
Unlike TerminalBench2 (which evaluates per-task coding ability with binary pass/fail), YC-Bench measures **long-term strategic coherence** — whether an agent can maintain consistent strategy, manage compounding consequences, and adapt plans over hundreds of turns.
## Setup
```bash
# Install yc-bench (optional dependency)
pip install "hermes-agent[yc-bench]"
# Or install from source
git clone https://github.com/collinear-ai/yc-bench
cd yc-bench && pip install -e .
# Verify
yc-bench --help
```
## Running
```bash
# From the repo root:
bash environments/benchmarks/yc_bench/run_eval.sh
# Or directly:
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
--config environments/benchmarks/yc_bench/default.yaml
# Override model:
bash environments/benchmarks/yc_bench/run_eval.sh \
--openai.model_name anthropic/claude-opus-4-20250514
# Quick single-preset test:
bash environments/benchmarks/yc_bench/run_eval.sh \
--env.presets '["fast_test"]' --env.seeds '[1]'
```
## How It Works
### Architecture
```
HermesAgentLoop (our agent)
-> terminal tool -> subprocess("yc-bench company status") -> JSON output
-> terminal tool -> subprocess("yc-bench task accept --task-id X") -> JSON
-> terminal tool -> subprocess("yc-bench sim resume") -> JSON (advance time)
-> ... (100-500 turns per run)
```
The environment initialises the simulation via `yc-bench sim init` (NOT `yc-bench run`, which would start yc-bench's own built-in agent loop). Our `HermesAgentLoop` then drives all interaction through CLI commands.
### Simulation Mechanics
- **4 skill domains**: research, inference, data_environment, training
- **Prestige system** (1.0-10.0): Gates access to higher-paying tasks
- **Employee management**: Junior/Mid/Senior with domain-specific skill rates
- **Throughput splitting**: `effective_rate = base_rate / N` active tasks per employee
- **Financial pressure**: Monthly payroll, bankruptcy = game over
- **Deterministic**: SHA256-based RNG — same seed + preset = same world
### Difficulty Presets
| Preset | Employees | Tasks | Focus |
|-----------|-----------|-------|-------|
| tutorial | 3 | 50 | Basic loop mechanics |
| easy | 5 | 100 | Throughput awareness |
| **medium**| 5 | 150 | Prestige climbing + domain specialisation |
| **hard** | 7 | 200 | Precise ETA reasoning |
| nightmare | 8 | 300 | Sustained perfection under payroll pressure |
| fast_test | (varies) | (varies) | Quick validation (~50 turns) |
Default eval runs **fast_test + medium + hard** × 3 seeds = 9 runs.
### Scoring
```
composite = 0.5 × survival + 0.5 × normalised_funds
```
- **Survival** (binary): Did the company avoid bankruptcy?
- **Normalised funds** (0.0-1.0): Log-scale relative to initial $250K capital
## Configuration
Key fields in `default.yaml`:
| Field | Default | Description |
|-------|---------|-------------|
| `presets` | `["fast_test", "medium", "hard"]` | Which presets to evaluate |
| `seeds` | `[1, 2, 3]` | RNG seeds per preset |
| `max_agent_turns` | 200 | Max LLM calls per run |
| `run_timeout` | 3600 | Wall-clock timeout per run (seconds) |
| `survival_weight` | 0.5 | Weight of survival in composite score |
| `funds_weight` | 0.5 | Weight of normalised funds in composite |
| `horizon_years` | null | Override horizon (null = auto from preset) |
## Cost & Time Estimates
Each run is 100-500 LLM turns. Approximate costs per run at typical API rates:
| Preset | Turns | Time | Est. Cost |
|--------|-------|------|-----------|
| fast_test | ~50 | 5-10 min | $1-5 |
| medium | ~200 | 20-40 min | $5-15 |
| hard | ~300 | 30-60 min | $10-25 |
Full default eval (9 runs): ~3-6 hours, $50-200 depending on model.
## References
- [collinear-ai/yc-bench](https://github.com/collinear-ai/yc-bench) — Official repository
- [Collinear AI](https://collinear.ai/) — Company behind yc-bench
- [TerminalBench2](../terminalbench_2/) — Per-task coding benchmark (complementary)

View File

@@ -0,0 +1,43 @@
# YC-Bench Evaluation -- Default Configuration
#
# Long-horizon agent benchmark: agent plays CEO of an AI startup over
# a simulated 1-3 year run, interacting via yc-bench CLI subcommands.
#
# Requires: pip install "hermes-agent[yc-bench]"
#
# Usage:
# python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
# --config environments/benchmarks/yc_bench/default.yaml
#
# # Override model:
# python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
# --config environments/benchmarks/yc_bench/default.yaml \
# --openai.model_name anthropic/claude-opus-4-20250514
env:
enabled_toolsets: ["terminal"]
max_agent_turns: 200
max_token_length: 32000
agent_temperature: 0.0
terminal_backend: "local"
terminal_timeout: 60
presets: ["fast_test", "medium", "hard"]
seeds: [1, 2, 3]
run_timeout: 3600 # 60 min wall-clock per run, auto-FAIL if exceeded
survival_weight: 0.5 # weight of binary survival in composite score
funds_weight: 0.5 # weight of normalised final funds in composite score
db_dir: "/tmp/yc_bench_dbs"
company_name: "BenchCo"
start_date: "01/01/2025" # MM/DD/YYYY (yc-bench convention)
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
use_wandb: true
wandb_name: "yc-bench"
ensure_scores_are_not_same: false
data_dir_to_save_evals: "environments/benchmarks/evals/yc-bench"
openai:
base_url: "https://openrouter.ai/api/v1"
model_name: "anthropic/claude-sonnet-4.6"
server_type: "openai"
health_check: false
# api_key loaded from OPENROUTER_API_KEY in .env

View File

@@ -0,0 +1,34 @@
#!/bin/bash
# YC-Bench Evaluation
#
# Requires: pip install "hermes-agent[yc-bench]"
#
# Run from repo root:
# bash environments/benchmarks/yc_bench/run_eval.sh
#
# Override model:
# bash environments/benchmarks/yc_bench/run_eval.sh \
# --openai.model_name anthropic/claude-opus-4-20250514
#
# Run a single preset:
# bash environments/benchmarks/yc_bench/run_eval.sh \
# --env.presets '["fast_test"]' --env.seeds '[1]'
set -euo pipefail
mkdir -p logs evals/yc-bench
LOG_FILE="logs/yc_bench_$(date +%Y%m%d_%H%M%S).log"
echo "YC-Bench Evaluation"
echo "Log: $LOG_FILE"
echo ""
PYTHONUNBUFFERED=1 LOGLEVEL="${LOGLEVEL:-INFO}" \
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
--config environments/benchmarks/yc_bench/default.yaml \
"$@" \
2>&1 | tee "$LOG_FILE"
echo ""
echo "Log saved to: $LOG_FILE"

View File

@@ -0,0 +1,847 @@
"""
YCBenchEvalEnv -- YC-Bench Long-Horizon Agent Benchmark Environment
Evaluates agentic LLMs on YC-Bench: a deterministic, long-horizon benchmark
where the agent acts as CEO of an AI startup over a simulated 1-3 year run.
The agent manages cash flow, employees, tasks, and prestige across 4 domains,
interacting exclusively via CLI subprocess calls against a SQLite-backed
discrete-event simulation.
Unlike TerminalBench2 (per-task binary pass/fail), YC-Bench measures sustained
multi-turn strategic coherence -- whether an agent can manage compounding
decisions over hundreds of turns without going bankrupt.
This is an eval-only environment. Run via:
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
--config environments/benchmarks/yc_bench/default.yaml
The evaluate flow:
1. setup() -- Verifies yc-bench installed, builds eval matrix (preset x seed)
2. evaluate() -- Iterates over all runs sequentially through:
a. rollout_and_score_eval() -- Per-run agent loop
- Initialises a fresh yc-bench simulation via `sim init` (NOT `run`)
- Runs HermesAgentLoop with terminal tool only
- Reads final SQLite DB to extract score
- Returns survival (0/1) + normalised funds score
b. Aggregates per-preset and overall metrics
c. Logs results via evaluate_log() and wandb
Key features:
- CLI-only interface: agent calls yc-bench subcommands via terminal tool
- Deterministic: same seed + preset = same world (SHA256-based RNG)
- Multi-dimensional scoring: survival + normalised final funds
- Per-preset difficulty breakdown in results
- Isolated SQLite DB per run (no cross-run state leakage)
Requires: pip install hermes-agent[yc-bench]
"""
import asyncio
import datetime
import json
import logging
import math
import os
import sqlite3
import subprocess
import sys
import threading
import time
import uuid
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
if str(_repo_root) not in sys.path:
sys.path.insert(0, str(_repo_root))
from pydantic import Field
from atroposlib.envs.base import EvalHandlingEnum
from atroposlib.envs.server_handling.server_manager import APIServerConfig
from environments.agent_loop import HermesAgentLoop
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
logger = logging.getLogger(__name__)
# =============================================================================
# System prompt
# =============================================================================
YC_BENCH_SYSTEM_PROMPT = """\
You are the autonomous CEO of an early-stage AI startup in a deterministic
business simulation. You manage the company exclusively through the `yc-bench`
CLI tool. Your primary goal is to **survive** until the simulation horizon ends
without going bankrupt, while **maximising final funds**.
## Simulation Mechanics
- **Funds**: You start with $250,000 seed capital. Revenue comes from completing
tasks. Rewards scale with your prestige: `base × (1 + scale × (prestige 1))`.
- **Domains**: There are 4 skill domains: **research**, **inference**,
**data_environment**, and **training**. Each has its own prestige level
(1.0-10.0). Higher prestige unlocks better-paying tasks.
- **Employees**: You have employees (Junior/Mid/Senior) with domain-specific
skill rates. **Throughput splits**: `effective_rate = base_rate / N` where N
is the number of active tasks assigned to that employee. Focus beats breadth.
- **Payroll**: Deducted automatically on the first business day of each month.
Running out of funds = bankruptcy = game over.
- **Time**: The simulation runs on business days (Mon-Fri), 09:00-18:00.
Time only advances when you call `yc-bench sim resume`.
## Task Lifecycle
1. Browse market tasks with `market browse`
2. Accept a task with `task accept` (this sets its deadline)
3. Assign employees with `task assign`
4. Dispatch with `task dispatch` to start work
5. Call `sim resume` to advance time and let employees make progress
6. Tasks complete when all domain requirements are fulfilled
**Penalties for failure vary by difficulty preset.** Completing a task on time
earns full reward + prestige gain. Missing a deadline or cancelling a task
incurs prestige penalties -- cancelling is always more costly than letting a
task fail, so cancel only as a last resort.
## CLI Commands
### Observe
- `yc-bench company status` -- funds, prestige, runway
- `yc-bench employee list` -- skills, salary, active tasks
- `yc-bench market browse [--domain D] [--required-prestige-lte N]` -- available tasks
- `yc-bench task list [--status active|planned]` -- your tasks
- `yc-bench task inspect --task-id UUID` -- progress, deadline, assignments
- `yc-bench finance ledger [--category monthly_payroll|task_reward]` -- transaction history
- `yc-bench report monthly` -- monthly P&L
### Act
- `yc-bench task accept --task-id UUID` -- accept from market
- `yc-bench task assign --task-id UUID --employee-id UUID` -- assign employee
- `yc-bench task dispatch --task-id UUID` -- start work (needs >=1 assignment)
- `yc-bench task cancel --task-id UUID --reason "text"` -- cancel (prestige penalty)
- `yc-bench sim resume` -- advance simulation clock
### Memory (persists across context truncation)
- `yc-bench scratchpad read` -- read your persistent notes
- `yc-bench scratchpad write --content "text"` -- overwrite notes
- `yc-bench scratchpad append --content "text"` -- append to notes
- `yc-bench scratchpad clear` -- clear notes
## Strategy Guidelines
1. **Specialise in 2-3 domains** to climb the prestige ladder faster and unlock
high-reward tasks. Don't spread thin across all 4 domains early on.
2. **Focus employees** -- assigning one employee to many tasks halves their
throughput per additional task. Keep assignments concentrated.
3. **Use the scratchpad** to track your strategy, upcoming deadlines, and
employee assignments. This persists even if conversation context is truncated.
4. **Monitor runway** -- always know how many months of payroll you can cover.
Accept high-reward tasks before payroll dates.
5. **Don't over-accept** -- taking too many tasks and missing deadlines cascades
into prestige loss, locking you out of profitable contracts.
6. Use `finance ledger` and `report monthly` to track revenue trends.
## Your Turn
Each turn:
1. Call `yc-bench company status` and `yc-bench task list` to orient yourself.
2. Check for completed tasks and pending deadlines.
3. Browse market for profitable tasks within your prestige level.
4. Accept, assign, and dispatch tasks strategically.
5. Call `yc-bench sim resume` to advance time.
6. Repeat until the simulation ends.
Think step by step before acting."""
# Starting funds in cents ($250,000)
INITIAL_FUNDS_CENTS = 25_000_000
# Default horizon per preset (years)
_PRESET_HORIZONS = {
"tutorial": 1,
"easy": 1,
"medium": 1,
"hard": 1,
"nightmare": 1,
"fast_test": 1,
"default": 3,
"high_reward": 1,
}
# =============================================================================
# Configuration
# =============================================================================
class YCBenchEvalConfig(HermesAgentEnvConfig):
"""
Configuration for the YC-Bench evaluation environment.
Extends HermesAgentEnvConfig with YC-Bench-specific settings for
preset selection, seed control, scoring, and simulation parameters.
"""
presets: List[str] = Field(
default=["fast_test", "medium", "hard"],
description="YC-Bench preset names to evaluate.",
)
seeds: List[int] = Field(
default=[1, 2, 3],
description="Random seeds -- each preset x seed = one run.",
)
run_timeout: int = Field(
default=3600,
description="Maximum wall-clock seconds per run. Default 60 minutes.",
)
survival_weight: float = Field(
default=0.5,
description="Weight of survival (0/1) in composite score.",
)
funds_weight: float = Field(
default=0.5,
description="Weight of normalised final funds in composite score.",
)
db_dir: str = Field(
default="/tmp/yc_bench_dbs",
description="Directory for per-run SQLite databases.",
)
horizon_years: Optional[int] = Field(
default=None,
description=(
"Simulation horizon in years. If None (default), inferred from "
"preset name (1 year for most, 3 for 'default')."
),
)
company_name: str = Field(
default="BenchCo",
description="Name of the simulated company.",
)
start_date: str = Field(
default="01/01/2025",
description="Simulation start date in MM/DD/YYYY format (yc-bench convention).",
)
# =============================================================================
# Scoring helpers
# =============================================================================
def _read_final_score(db_path: str) -> Dict[str, Any]:
"""
Read final game state from a YC-Bench SQLite database.
Returns dict with final_funds_cents (int), survived (bool),
terminal_reason (str).
Note: yc-bench table names are plural -- 'companies' not 'company',
'sim_events' not 'simulation_log'.
"""
if not os.path.exists(db_path):
logger.warning("DB not found at %s", db_path)
return {
"final_funds_cents": 0,
"survived": False,
"terminal_reason": "db_missing",
}
conn = None
try:
conn = sqlite3.connect(db_path)
cur = conn.cursor()
# Read final funds from the 'companies' table
cur.execute("SELECT funds_cents FROM companies LIMIT 1")
row = cur.fetchone()
funds = row[0] if row else 0
# Determine terminal reason from 'sim_events' table
terminal_reason = "unknown"
try:
cur.execute(
"SELECT event_type FROM sim_events "
"WHERE event_type IN ('bankruptcy', 'horizon_end') "
"ORDER BY scheduled_at DESC LIMIT 1"
)
event_row = cur.fetchone()
if event_row:
terminal_reason = event_row[0]
except sqlite3.OperationalError:
# Table may not exist if simulation didn't progress
pass
survived = funds >= 0 and terminal_reason != "bankruptcy"
return {
"final_funds_cents": funds,
"survived": survived,
"terminal_reason": terminal_reason,
}
except Exception as e:
logger.error("Failed to read DB %s: %s", db_path, e)
return {
"final_funds_cents": 0,
"survived": False,
"terminal_reason": f"db_error: {e}",
}
finally:
if conn:
conn.close()
def _compute_composite_score(
final_funds_cents: int,
survived: bool,
survival_weight: float = 0.5,
funds_weight: float = 0.5,
initial_funds_cents: int = INITIAL_FUNDS_CENTS,
) -> float:
"""
Compute composite score from survival and final funds.
Score = survival_weight * survival_score
+ funds_weight * normalised_funds_score
Normalised funds uses log-scale relative to initial capital:
- funds <= 0: 0.0
- funds == initial: ~0.15
- funds == 10x: ~0.52
- funds == 100x: 1.0
"""
survival_score = 1.0 if survived else 0.0
if final_funds_cents <= 0:
funds_score = 0.0
else:
max_ratio = 100.0
ratio = final_funds_cents / max(initial_funds_cents, 1)
funds_score = min(math.log1p(ratio) / math.log1p(max_ratio), 1.0)
return survival_weight * survival_score + funds_weight * funds_score
# =============================================================================
# Main Environment
# =============================================================================
class YCBenchEvalEnv(HermesAgentBaseEnv):
"""
YC-Bench long-horizon agent benchmark environment (eval-only).
Each eval item is a (preset, seed) pair. The environment initialises the
simulation via ``yc-bench sim init`` (NOT ``yc-bench run`` which would start
a competing built-in agent loop). The HermesAgentLoop then drives the
interaction by calling individual yc-bench CLI commands via the terminal tool.
After the agent loop ends, the SQLite DB is read to extract the final score.
Scoring:
composite = 0.5 * survival + 0.5 * normalised_funds
"""
name = "yc-bench"
env_config_cls = YCBenchEvalConfig
@classmethod
def config_init(cls) -> Tuple[YCBenchEvalConfig, List[APIServerConfig]]:
env_config = YCBenchEvalConfig(
enabled_toolsets=["terminal"],
disabled_toolsets=None,
distribution=None,
max_agent_turns=200,
max_token_length=32000,
agent_temperature=0.0,
system_prompt=YC_BENCH_SYSTEM_PROMPT,
terminal_backend="local",
terminal_timeout=60,
presets=["fast_test", "medium", "hard"],
seeds=[1, 2, 3],
run_timeout=3600,
survival_weight=0.5,
funds_weight=0.5,
db_dir="/tmp/yc_bench_dbs",
eval_handling=EvalHandlingEnum.STOP_TRAIN,
group_size=1,
steps_per_eval=1,
total_steps=1,
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
use_wandb=True,
wandb_name="yc-bench",
ensure_scores_are_not_same=False,
)
server_configs = [
APIServerConfig(
base_url="https://openrouter.ai/api/v1",
model_name="anthropic/claude-sonnet-4.6",
server_type="openai",
api_key=os.getenv("OPENROUTER_API_KEY", ""),
health_check=False,
)
]
return env_config, server_configs
# =========================================================================
# Setup
# =========================================================================
async def setup(self):
"""Verify yc-bench is installed and build the eval matrix."""
# Verify yc-bench CLI is available
try:
result = subprocess.run(
["yc-bench", "--help"], capture_output=True, text=True, timeout=10
)
if result.returncode != 0:
raise FileNotFoundError
except (FileNotFoundError, subprocess.TimeoutExpired):
raise RuntimeError(
"yc-bench CLI not found. Install with:\n"
' pip install "hermes-agent[yc-bench]"\n'
"Or: git clone https://github.com/collinear-ai/yc-bench "
"&& cd yc-bench && pip install -e ."
)
print("yc-bench CLI verified.")
# Build eval matrix: preset x seed
self.all_eval_items = [
{"preset": preset, "seed": seed}
for preset in self.config.presets
for seed in self.config.seeds
]
self.iter = 0
os.makedirs(self.config.db_dir, exist_ok=True)
self.eval_metrics: List[Tuple[str, float]] = []
# Streaming JSONL log for crash-safe result persistence
log_dir = os.path.join(os.path.dirname(__file__), "logs")
os.makedirs(log_dir, exist_ok=True)
run_ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
self._streaming_path = os.path.join(log_dir, f"samples_{run_ts}.jsonl")
self._streaming_file = open(self._streaming_path, "w")
self._streaming_lock = threading.Lock()
print(f"\nYC-Bench eval matrix: {len(self.all_eval_items)} runs")
for item in self.all_eval_items:
print(f" preset={item['preset']!r} seed={item['seed']}")
print(f"Streaming results to: {self._streaming_path}\n")
def _save_result(self, result: Dict[str, Any]):
"""Write a single run result to the streaming JSONL file immediately."""
if not hasattr(self, "_streaming_file") or self._streaming_file.closed:
return
with self._streaming_lock:
self._streaming_file.write(
json.dumps(result, ensure_ascii=False, default=str) + "\n"
)
self._streaming_file.flush()
# =========================================================================
# Training pipeline stubs (eval-only -- not used)
# =========================================================================
async def get_next_item(self):
item = self.all_eval_items[self.iter % len(self.all_eval_items)]
self.iter += 1
return item
def format_prompt(self, item: Dict[str, Any]) -> str:
preset = item["preset"]
seed = item["seed"]
return (
f"A new YC-Bench simulation has been initialized "
f"(preset='{preset}', seed={seed}).\n"
f"Your company '{self.config.company_name}' is ready.\n\n"
"Begin by calling:\n"
"1. `yc-bench company status` -- see your starting funds and prestige\n"
"2. `yc-bench employee list` -- see your team and their skills\n"
"3. `yc-bench market browse --required-prestige-lte 1` -- find tasks "
"you can take\n\n"
"Then accept 2-3 tasks, assign employees, dispatch them, and call "
"`yc-bench sim resume` to advance time. Repeat this loop until the "
"simulation ends (horizon reached or bankruptcy)."
)
async def compute_reward(self, item, result, ctx) -> float:
return 0.0
async def collect_trajectories(self, item):
return None, []
async def score(self, rollout_group_data):
return None
# =========================================================================
# Per-run evaluation
# =========================================================================
async def rollout_and_score_eval(self, eval_item: Dict[str, Any]) -> Dict:
"""
Evaluate a single (preset, seed) run.
1. Sets DATABASE_URL and YC_BENCH_EXPERIMENT env vars
2. Initialises the simulation via ``yc-bench sim init`` (NOT ``run``)
3. Runs HermesAgentLoop with terminal tool
4. Reads SQLite DB to compute final score
5. Returns result dict with survival, funds, and composite score
"""
preset = eval_item["preset"]
seed = eval_item["seed"]
run_id = str(uuid.uuid4())[:8]
run_key = f"{preset}_seed{seed}_{run_id}"
from tqdm import tqdm
tqdm.write(f" [START] preset={preset!r} seed={seed} (run_id={run_id})")
run_start = time.time()
# Isolated DB per run -- prevents cross-run state leakage
db_path = os.path.join(self.config.db_dir, f"yc_bench_{run_key}.db")
os.environ["DATABASE_URL"] = f"sqlite:///{db_path}"
os.environ["YC_BENCH_EXPERIMENT"] = preset
# Determine horizon: explicit config override > preset lookup > default 1
horizon = self.config.horizon_years or _PRESET_HORIZONS.get(preset, 1)
try:
# ----------------------------------------------------------
# Step 1: Initialise the simulation via CLI
# IMPORTANT: We use `sim init`, NOT `yc-bench run`.
# `yc-bench run` starts yc-bench's own LLM agent loop (via
# LiteLLM), which would compete with our HermesAgentLoop.
# `sim init` just sets up the world and returns.
# ----------------------------------------------------------
init_cmd = [
"yc-bench", "sim", "init",
"--seed", str(seed),
"--start-date", self.config.start_date,
"--company-name", self.config.company_name,
"--horizon-years", str(horizon),
]
init_result = subprocess.run(
init_cmd, capture_output=True, text=True, timeout=30,
)
if init_result.returncode != 0:
error_msg = (init_result.stderr or init_result.stdout).strip()
raise RuntimeError(f"yc-bench sim init failed: {error_msg}")
tqdm.write(f" Simulation initialized (horizon={horizon}yr)")
# ----------------------------------------------------------
# Step 2: Run the HermesAgentLoop
# ----------------------------------------------------------
tools, valid_names = self._resolve_tools_for_group()
messages: List[Dict[str, Any]] = [
{"role": "system", "content": YC_BENCH_SYSTEM_PROMPT},
{"role": "user", "content": self.format_prompt(eval_item)},
]
agent = HermesAgentLoop(
server=self.server,
tool_schemas=tools,
valid_tool_names=valid_names,
max_turns=self.config.max_agent_turns,
task_id=run_id,
temperature=self.config.agent_temperature,
max_tokens=self.config.max_token_length,
extra_body=self.config.extra_body,
)
result = await agent.run(messages)
# ----------------------------------------------------------
# Step 3: Read final score from the simulation DB
# ----------------------------------------------------------
score_data = _read_final_score(db_path)
final_funds = score_data["final_funds_cents"]
survived = score_data["survived"]
terminal_reason = score_data["terminal_reason"]
composite = _compute_composite_score(
final_funds_cents=final_funds,
survived=survived,
survival_weight=self.config.survival_weight,
funds_weight=self.config.funds_weight,
)
elapsed = time.time() - run_start
status = "SURVIVED" if survived else "BANKRUPT"
if final_funds >= 0:
funds_str = f"${final_funds / 100:,.0f}"
else:
funds_str = f"-${abs(final_funds) / 100:,.0f}"
tqdm.write(
f" [{status}] preset={preset!r} seed={seed} "
f"funds={funds_str} score={composite:.3f} "
f"turns={result.turns_used} ({elapsed:.0f}s)"
)
out = {
"preset": preset,
"seed": seed,
"survived": survived,
"final_funds_cents": final_funds,
"final_funds_usd": final_funds / 100,
"terminal_reason": terminal_reason,
"composite_score": composite,
"turns_used": result.turns_used,
"finished_naturally": result.finished_naturally,
"elapsed_seconds": elapsed,
"db_path": db_path,
"messages": result.messages,
}
self._save_result(out)
return out
except Exception as e:
elapsed = time.time() - run_start
logger.error("Run %s failed: %s", run_key, e, exc_info=True)
tqdm.write(
f" [ERROR] preset={preset!r} seed={seed}: {e} ({elapsed:.0f}s)"
)
out = {
"preset": preset,
"seed": seed,
"survived": False,
"final_funds_cents": 0,
"final_funds_usd": 0.0,
"terminal_reason": f"error: {e}",
"composite_score": 0.0,
"turns_used": 0,
"error": str(e),
"elapsed_seconds": elapsed,
}
self._save_result(out)
return out
# =========================================================================
# Evaluate
# =========================================================================
async def _run_with_timeout(self, item: Dict[str, Any]) -> Dict:
"""Wrap a single rollout with a wall-clock timeout."""
preset = item["preset"]
seed = item["seed"]
try:
return await asyncio.wait_for(
self.rollout_and_score_eval(item),
timeout=self.config.run_timeout,
)
except asyncio.TimeoutError:
from tqdm import tqdm
tqdm.write(
f" [TIMEOUT] preset={preset!r} seed={seed} "
f"(exceeded {self.config.run_timeout}s)"
)
out = {
"preset": preset,
"seed": seed,
"survived": False,
"final_funds_cents": 0,
"final_funds_usd": 0.0,
"terminal_reason": f"timeout ({self.config.run_timeout}s)",
"composite_score": 0.0,
"turns_used": 0,
"error": "timeout",
}
self._save_result(out)
return out
async def evaluate(self, *args, **kwargs) -> None:
"""
Run YC-Bench evaluation over all (preset, seed) combinations.
Runs sequentially -- each run is 100-500 turns, parallelising would
be prohibitively expensive and cause env var conflicts.
"""
start_time = time.time()
from tqdm import tqdm
# --- tqdm-compatible logging handler (TB2 pattern) ---
class _TqdmHandler(logging.Handler):
def emit(self, record):
try:
tqdm.write(self.format(record))
except Exception:
self.handleError(record)
root = logging.getLogger()
handler = _TqdmHandler()
handler.setFormatter(
logging.Formatter("%(levelname)s %(name)s: %(message)s")
)
root.handlers = [handler]
for noisy in ("httpx", "openai"):
logging.getLogger(noisy).setLevel(logging.WARNING)
# --- Print config summary ---
print(f"\n{'='*60}")
print("Starting YC-Bench Evaluation")
print(f"{'='*60}")
print(f" Presets: {self.config.presets}")
print(f" Seeds: {self.config.seeds}")
print(f" Total runs: {len(self.all_eval_items)}")
print(f" Max turns/run: {self.config.max_agent_turns}")
print(f" Run timeout: {self.config.run_timeout}s")
print(f"{'='*60}\n")
results = []
pbar = tqdm(
total=len(self.all_eval_items), desc="YC-Bench", dynamic_ncols=True
)
try:
for item in self.all_eval_items:
result = await self._run_with_timeout(item)
results.append(result)
survived_count = sum(1 for r in results if r.get("survived"))
pbar.set_postfix_str(
f"survived={survived_count}/{len(results)}"
)
pbar.update(1)
except (KeyboardInterrupt, asyncio.CancelledError):
tqdm.write("\n[INTERRUPTED] Stopping evaluation...")
pbar.close()
try:
from tools.terminal_tool import cleanup_all_environments
cleanup_all_environments()
except Exception:
pass
if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
self._streaming_file.close()
return
pbar.close()
end_time = time.time()
# --- Compute metrics ---
valid = [r for r in results if r is not None]
if not valid:
print("Warning: No valid results.")
return
total = len(valid)
survived_total = sum(1 for r in valid if r.get("survived"))
survival_rate = survived_total / total if total else 0.0
avg_score = (
sum(r.get("composite_score", 0) for r in valid) / total
if total
else 0.0
)
preset_results: Dict[str, List[Dict]] = defaultdict(list)
for r in valid:
preset_results[r["preset"]].append(r)
eval_metrics = {
"eval/survival_rate": survival_rate,
"eval/avg_composite_score": avg_score,
"eval/total_runs": total,
"eval/survived_runs": survived_total,
"eval/evaluation_time_seconds": end_time - start_time,
}
for preset, items in sorted(preset_results.items()):
ps = sum(1 for r in items if r.get("survived"))
pt = len(items)
pa = (
sum(r.get("composite_score", 0) for r in items) / pt
if pt
else 0
)
key = preset.replace("-", "_")
eval_metrics[f"eval/survival_rate_{key}"] = ps / pt if pt else 0
eval_metrics[f"eval/avg_score_{key}"] = pa
self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]
# --- Print summary ---
print(f"\n{'='*60}")
print("YC-Bench Evaluation Results")
print(f"{'='*60}")
print(
f"Overall survival rate: {survival_rate:.1%} "
f"({survived_total}/{total})"
)
print(f"Average composite score: {avg_score:.4f}")
print(f"Evaluation time: {end_time - start_time:.1f}s")
print("\nPer-preset breakdown:")
for preset, items in sorted(preset_results.items()):
ps = sum(1 for r in items if r.get("survived"))
pt = len(items)
pa = (
sum(r.get("composite_score", 0) for r in items) / pt
if pt
else 0
)
print(f" {preset}: {ps}/{pt} survived avg_score={pa:.4f}")
for r in items:
status = "SURVIVED" if r.get("survived") else "BANKRUPT"
funds = r.get("final_funds_usd", 0)
print(
f" seed={r['seed']} [{status}] "
f"${funds:,.0f} "
f"score={r.get('composite_score', 0):.3f}"
)
print(f"{'='*60}\n")
# --- Log results ---
samples = [
{k: v for k, v in r.items() if k != "messages"} for r in valid
]
try:
await self.evaluate_log(
metrics=eval_metrics,
samples=samples,
start_time=start_time,
end_time=end_time,
generation_parameters={
"temperature": self.config.agent_temperature,
"max_tokens": self.config.max_token_length,
"max_agent_turns": self.config.max_agent_turns,
},
)
except Exception as e:
print(f"Error logging results: {e}")
# --- Cleanup (TB2 pattern) ---
if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
self._streaming_file.close()
print(f"Results saved to: {self._streaming_path}")
try:
from tools.terminal_tool import cleanup_all_environments
cleanup_all_environments()
except Exception:
pass
try:
from environments.agent_loop import _tool_executor
_tool_executor.shutdown(wait=False, cancel_futures=True)
except Exception:
pass
# =========================================================================
# Wandb logging
# =========================================================================
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
"""Log YC-Bench-specific metrics to wandb."""
if wandb_metrics is None:
wandb_metrics = {}
for k, v in self.eval_metrics:
wandb_metrics[k] = v
self.eval_metrics = []
await super().wandb_log(wandb_metrics)
if __name__ == "__main__":
YCBenchEvalEnv.cli()

View File

@@ -114,8 +114,8 @@ class HermesAgentEnvConfig(BaseEnvConfig):
# --- Terminal backend ---
terminal_backend: str = Field(
default="local",
description="Terminal backend: 'local', 'docker', 'modal', 'ssh', 'singularity'. "
"Modal recommended for production RL (cloud isolation per rollout).",
description="Terminal backend: 'local', 'docker', 'modal', 'daytona', 'ssh', 'singularity'. "
"Modal or Daytona recommended for production RL (cloud isolation per rollout).",
)
terminal_timeout: int = Field(
default=120,

View File

@@ -35,7 +35,8 @@ class DeepSeekV31ToolCallParser(ToolCallParser):
# Regex captures: function_name, function_arguments
PATTERN = re.compile(
r"<tool▁call▁begin>(?P<function_name>.*?)<tool▁sep>(?P<function_arguments>.*?)<tool▁call▁end>"
r"<tool▁call▁begin>(?P<function_name>.*?)<tool▁sep>(?P<function_arguments>.*?)<tool▁call▁end>",
re.DOTALL,
)
def parse(self, text: str) -> ParseResult:

View File

@@ -38,7 +38,8 @@ class DeepSeekV3ToolCallParser(ToolCallParser):
# Regex captures: type, function_name, function_arguments
PATTERN = re.compile(
r"<tool▁call▁begin>(?P<type>.*)<tool▁sep>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*)\n```<tool▁call▁end>"
r"<tool▁call▁begin>(?P<type>.*)<tool▁sep>(?P<function_name>.*)\n```json\n(?P<function_arguments>.*)\n```<tool▁call▁end>",
re.DOTALL,
)
def parse(self, text: str) -> ParseResult:

View File

@@ -44,7 +44,7 @@ _tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
def _run_tool_in_thread(tool_name: str, arguments: Dict[str, Any], task_id: str) -> str:
"""
Run a tool call in a thread pool executor so backends that use asyncio.run()
internally (modal, docker) get a clean event loop.
internally (modal, docker, daytona) get a clean event loop.
If we're already in an async context, executes handle_function_call() in a
disposable worker thread and blocks for the result.
@@ -95,7 +95,7 @@ class ToolContext:
backend = os.getenv("TERMINAL_ENV", "local")
logger.debug("ToolContext.terminal [%s backend] task=%s: %s", backend, self.task_id[:8], command[:100])
# Run via thread helper so modal/docker backends' asyncio.run() doesn't deadlock
# Run via thread helper so modal/docker/daytona backends' asyncio.run() doesn't deadlock
result = _run_tool_in_thread(
"terminal",
{"command": command, "timeout": timeout},

View File

@@ -8,10 +8,13 @@ Uses python-telegram-bot library for:
"""
import asyncio
import logging
import os
import re
from typing import Dict, List, Optional, Any
logger = logging.getLogger(__name__)
try:
from telegram import Update, Bot, Message
from telegram.ext import (
@@ -73,6 +76,19 @@ def _escape_mdv2(text: str) -> str:
return _MDV2_ESCAPE_RE.sub(r'\\\1', text)
def _strip_mdv2(text: str) -> str:
"""Strip MarkdownV2 escape backslashes to produce clean plain text.
Also removes MarkdownV2 bold markers (*text* -> text) so the fallback
doesn't show stray asterisks from header/bold conversion.
"""
# Remove escape backslashes before special characters
cleaned = re.sub(r'\\([_*\[\]()~`>#\+\-=|{}.!\\])', r'\1', text)
# Remove MarkdownV2 bold markers that format_message converted from **bold**
cleaned = re.sub(r'\*([^*]+)\*', r'\1', cleaned)
return cleaned
class TelegramAdapter(BasePlatformAdapter):
"""
Telegram bot adapter.
@@ -199,9 +215,13 @@ class TelegramAdapter(BasePlatformAdapter):
except Exception as md_error:
# Markdown parsing failed, try plain text
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower():
logger.warning("[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error)
# Strip MDV2 escape backslashes so the user doesn't
# see raw backslashes littered through the message.
plain_chunk = _strip_mdv2(chunk)
msg = await self._bot.send_message(
chat_id=int(chat_id),
text=chunk,
text=plain_chunk,
parse_mode=None, # Plain text
reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
message_thread_id=int(thread_id) if thread_id else None,

View File

@@ -28,6 +28,41 @@ from typing import Dict, List, Optional, Any
logger = logging.getLogger(__name__)
def _kill_port_process(port: int) -> None:
"""Kill any process listening on the given TCP port."""
try:
if _IS_WINDOWS:
# Use netstat to find the PID bound to this port, then taskkill
result = subprocess.run(
["netstat", "-ano", "-p", "TCP"],
capture_output=True, text=True, timeout=5,
)
for line in result.stdout.splitlines():
parts = line.split()
if len(parts) >= 5 and parts[3] == "LISTENING":
local_addr = parts[1]
if local_addr.endswith(f":{port}"):
try:
subprocess.run(
["taskkill", "/PID", parts[4], "/F"],
capture_output=True, timeout=5,
)
except subprocess.SubprocessError:
pass
else:
result = subprocess.run(
["fuser", f"{port}/tcp"],
capture_output=True, timeout=5,
)
if result.returncode == 0:
subprocess.run(
["fuser", "-k", f"{port}/tcp"],
capture_output=True, timeout=5,
)
except Exception:
pass
import sys
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
@@ -145,21 +180,9 @@ class WhatsAppAdapter(BasePlatformAdapter):
self._session_path.mkdir(parents=True, exist_ok=True)
# Kill any orphaned bridge from a previous gateway run
try:
result = subprocess.run(
["fuser", f"{self._bridge_port}/tcp"],
capture_output=True, timeout=5,
)
if result.returncode == 0:
# Port is in use — kill the process
subprocess.run(
["fuser", "-k", f"{self._bridge_port}/tcp"],
capture_output=True, timeout=5,
)
import time
time.sleep(2)
except Exception:
pass
_kill_port_process(self._bridge_port)
import time
time.sleep(1)
# Start the bridge process in its own process group.
# Route output to a log file so QR codes, errors, and reconnection
@@ -293,13 +316,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
print(f"[{self.name}] Error stopping bridge: {e}")
# Also kill any orphaned bridge processes on our port
try:
subprocess.run(
["fuser", "-k", f"{self._bridge_port}/tcp"],
capture_output=True, timeout=5,
)
except Exception:
pass
_kill_port_process(self._bridge_port)
self._running = False
self._bridge_process = None

View File

@@ -66,6 +66,7 @@ if _config_path.exists():
"docker_image": "TERMINAL_DOCKER_IMAGE",
"singularity_image": "TERMINAL_SINGULARITY_IMAGE",
"modal_image": "TERMINAL_MODAL_IMAGE",
"daytona_image": "TERMINAL_DAYTONA_IMAGE",
"ssh_host": "TERMINAL_SSH_HOST",
"ssh_user": "TERMINAL_SSH_USER",
"ssh_port": "TERMINAL_SSH_PORT",
@@ -92,6 +93,11 @@ if _config_path.exists():
if _agent_cfg and isinstance(_agent_cfg, dict):
if "max_turns" in _agent_cfg:
os.environ["HERMES_MAX_ITERATIONS"] = str(_agent_cfg["max_turns"])
# Timezone: bridge config.yaml → HERMES_TIMEZONE env var.
# HERMES_TIMEZONE from .env takes precedence (already in os.environ).
_tz_cfg = _cfg.get("timezone", "")
if _tz_cfg and isinstance(_tz_cfg, str) and "HERMES_TIMEZONE" not in os.environ:
os.environ["HERMES_TIMEZONE"] = _tz_cfg.strip()
except Exception:
pass # Non-fatal; gateway can still run with .env values
@@ -658,7 +664,7 @@ class GatewayRunner:
# Emit command:* hook for any recognized slash command
_known_commands = {"new", "reset", "help", "status", "stop", "model",
"personality", "retry", "undo", "sethome", "set-home",
"compress", "usage", "reload-mcp", "update"}
"compress", "usage", "insights", "reload-mcp", "update"}
if command and command in _known_commands:
await self.hooks.emit(f"command:{command}", {
"platform": source.platform.value if source.platform else "",
@@ -700,6 +706,9 @@ class GatewayRunner:
if command == "usage":
return await self._handle_usage_command(event)
if command == "insights":
return await self._handle_insights_command(event)
if command == "reload-mcp":
return await self._handle_reload_mcp_command(event)
@@ -1103,6 +1112,7 @@ class GatewayRunner:
"`/sethome` — Set this chat as the home channel",
"`/compress` — Compress conversation context",
"`/usage` — Show token usage for this session",
"`/insights [days]` — Show usage insights and analytics",
"`/reload-mcp` — Reload MCP servers from config",
"`/update` — Update Hermes Agent to the latest version",
"`/help` — Show this message",
@@ -1253,8 +1263,7 @@ class GatewayRunner:
)
# Let the normal message handler process it
await self._handle_message(retry_event)
return None # Response sent through normal flow
return await self._handle_message(retry_event)
async def _handle_undo_command(self, event: MessageEvent) -> str:
"""Handle /undo command - remove the last user/assistant exchange."""
@@ -1397,6 +1406,53 @@ class GatewayRunner:
)
return "No usage data available for this session."
async def _handle_insights_command(self, event: MessageEvent) -> str:
"""Handle /insights command -- show usage insights and analytics."""
import asyncio as _asyncio
args = event.get_command_args().strip()
days = 30
source = None
# Parse simple args: /insights 7 or /insights --days 7
if args:
parts = args.split()
i = 0
while i < len(parts):
if parts[i] == "--days" and i + 1 < len(parts):
try:
days = int(parts[i + 1])
except ValueError:
return f"Invalid --days value: {parts[i + 1]}"
i += 2
elif parts[i] == "--source" and i + 1 < len(parts):
source = parts[i + 1]
i += 2
elif parts[i].isdigit():
days = int(parts[i])
i += 1
else:
i += 1
try:
from hermes_state import SessionDB
from agent.insights import InsightsEngine
loop = _asyncio.get_event_loop()
def _run_insights():
db = SessionDB()
engine = InsightsEngine(db)
report = engine.generate(days=days, source=source)
result = engine.format_gateway(report)
db.close()
return result
return await loop.run_in_executor(None, _run_insights)
except Exception as e:
logger.error("Insights command error: %s", e, exc_info=True)
return f"Error generating insights: {e}"
async def _handle_reload_mcp_command(self, event: MessageEvent) -> str:
"""Handle /reload-mcp command -- disconnect and reconnect all MCP servers."""
loop = asyncio.get_event_loop()
@@ -2041,7 +2097,7 @@ class GatewayRunner:
os.environ["HERMES_SESSION_KEY"] = session_key or ""
# Read from env var or use default (same as CLI)
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "60"))
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "90"))
# Map platform enum to the platform hint key the agent understands.
# Platform.LOCAL ("local") maps to "cli"; others pass through as-is.
@@ -2389,6 +2445,34 @@ async def start_gateway(config: Optional[GatewayConfig] = None) -> bool:
Returns True if the gateway ran successfully, False if it failed to start.
A False return causes a non-zero exit code so systemd can auto-restart.
"""
# ── Duplicate-instance guard ──────────────────────────────────────
# Prevent two gateways from running under the same HERMES_HOME.
# The PID file is scoped to HERMES_HOME, so future multi-profile
# setups (each profile using a distinct HERMES_HOME) will naturally
# allow concurrent instances without tripping this guard.
from gateway.status import get_running_pid
existing_pid = get_running_pid()
if existing_pid is not None and existing_pid != os.getpid():
hermes_home = os.getenv("HERMES_HOME", "~/.hermes")
logger.error(
"Another gateway instance is already running (PID %d, HERMES_HOME=%s). "
"Use 'hermes gateway restart' to replace it, or 'hermes gateway stop' first.",
existing_pid, hermes_home,
)
print(
f"\n❌ Gateway already running (PID {existing_pid}).\n"
f" Use 'hermes gateway restart' to replace it,\n"
f" or 'hermes gateway stop' to kill it first.\n"
)
return False
# Sync bundled skills on gateway start (fast -- skips unchanged)
try:
from tools.skills_sync import sync_skills
sync_skills(quiet=True)
except Exception:
pass
# Configure rotating file log so gateway output is persisted for debugging
log_dir = _hermes_home / 'logs'
log_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -3,37 +3,59 @@ Gateway runtime status helpers.
Provides PID-file based detection of whether the gateway daemon is running,
used by send_message's check_fn to gate availability in the CLI.
The PID file lives at ``{HERMES_HOME}/gateway.pid``. HERMES_HOME defaults to
``~/.hermes`` but can be overridden via the environment variable. This means
separate HERMES_HOME directories naturally get separate PID files — a property
that will be useful when we add named profiles (multiple agents running
concurrently under distinct configurations).
"""
import os
from pathlib import Path
from typing import Optional
_PID_FILE = Path.home() / ".hermes" / "gateway.pid"
def _get_pid_path() -> Path:
"""Return the path to the gateway PID file, respecting HERMES_HOME."""
home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
return home / "gateway.pid"
def write_pid_file() -> None:
"""Write the current process PID to the gateway PID file."""
_PID_FILE.parent.mkdir(parents=True, exist_ok=True)
_PID_FILE.write_text(str(os.getpid()))
pid_path = _get_pid_path()
pid_path.parent.mkdir(parents=True, exist_ok=True)
pid_path.write_text(str(os.getpid()))
def remove_pid_file() -> None:
"""Remove the gateway PID file if it exists."""
try:
_PID_FILE.unlink(missing_ok=True)
_get_pid_path().unlink(missing_ok=True)
except Exception:
pass
def get_running_pid() -> Optional[int]:
"""Return the PID of a running gateway instance, or ``None``.
Checks the PID file and verifies the process is actually alive.
Cleans up stale PID files automatically.
"""
pid_path = _get_pid_path()
if not pid_path.exists():
return None
try:
pid = int(pid_path.read_text().strip())
os.kill(pid, 0) # signal 0 = existence check, no actual signal sent
return pid
except (ValueError, ProcessLookupError, PermissionError):
# Stale PID file — process is gone
remove_pid_file()
return None
def is_gateway_running() -> bool:
"""Check if the gateway daemon is currently running."""
if not _PID_FILE.exists():
return False
try:
pid = int(_PID_FILE.read_text().strip())
os.kill(pid, 0) # signal 0 = existence check, no actual signal sent
return True
except (ValueError, ProcessLookupError, PermissionError):
# Stale PID file -- process is gone
remove_pid_file()
return False
return get_running_pid() is not None

View File

@@ -72,15 +72,19 @@ CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
@dataclass
class ProviderConfig:
"""Describes a known OAuth provider."""
"""Describes a known inference provider."""
id: str
name: str
auth_type: str # "oauth_device_code" or "api_key"
auth_type: str # "oauth_device_code", "oauth_external", or "api_key"
portal_base_url: str = ""
inference_base_url: str = ""
client_id: str = ""
scope: str = ""
extra: Dict[str, Any] = field(default_factory=dict)
# For API-key providers: env vars to check (in priority order)
api_key_env_vars: tuple = ()
# Optional env var for base URL override
base_url_env_var: str = ""
PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
@@ -99,6 +103,38 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
auth_type="oauth_external",
inference_base_url=DEFAULT_CODEX_BASE_URL,
),
"zai": ProviderConfig(
id="zai",
name="Z.AI / GLM",
auth_type="api_key",
inference_base_url="https://api.z.ai/api/paas/v4",
api_key_env_vars=("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
base_url_env_var="GLM_BASE_URL",
),
"kimi-coding": ProviderConfig(
id="kimi-coding",
name="Kimi / Moonshot",
auth_type="api_key",
inference_base_url="https://api.moonshot.ai/v1",
api_key_env_vars=("KIMI_API_KEY",),
base_url_env_var="KIMI_BASE_URL",
),
"minimax": ProviderConfig(
id="minimax",
name="MiniMax",
auth_type="api_key",
inference_base_url="https://api.minimax.io/v1",
api_key_env_vars=("MINIMAX_API_KEY",),
base_url_env_var="MINIMAX_BASE_URL",
),
"minimax-cn": ProviderConfig(
id="minimax-cn",
name="MiniMax (China)",
auth_type="api_key",
inference_base_url="https://api.minimaxi.com/v1",
api_key_env_vars=("MINIMAX_CN_API_KEY",),
base_url_env_var="MINIMAX_CN_BASE_URL",
),
}
@@ -355,10 +391,19 @@ def resolve_provider(
1. active_provider in auth.json with valid credentials
2. Explicit CLI api_key/base_url -> "openrouter"
3. OPENAI_API_KEY or OPENROUTER_API_KEY env vars -> "openrouter"
4. Fallback: "openrouter"
4. Provider-specific API keys (GLM, Kimi, MiniMax) -> that provider
5. Fallback: "openrouter"
"""
normalized = (requested or "auto").strip().lower()
# Normalize provider aliases
_PROVIDER_ALIASES = {
"glm": "zai", "z-ai": "zai", "z.ai": "zai", "zhipu": "zai",
"kimi": "kimi-coding", "moonshot": "kimi-coding",
"minimax-china": "minimax-cn", "minimax_cn": "minimax-cn",
}
normalized = _PROVIDER_ALIASES.get(normalized, normalized)
if normalized in {"openrouter", "custom"}:
return "openrouter"
if normalized in PROVIDER_REGISTRY:
@@ -387,6 +432,14 @@ def resolve_provider(
if os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY"):
return "openrouter"
# Auto-detect API-key providers by checking their env vars
for pid, pconfig in PROVIDER_REGISTRY.items():
if pconfig.auth_type != "api_key":
continue
for env_var in pconfig.api_key_env_vars:
if os.getenv(env_var, "").strip():
return pid
return "openrouter"
@@ -1230,6 +1283,37 @@ def get_codex_auth_status() -> Dict[str, Any]:
}
def get_api_key_provider_status(provider_id: str) -> Dict[str, Any]:
"""Status snapshot for API-key providers (z.ai, Kimi, MiniMax)."""
pconfig = PROVIDER_REGISTRY.get(provider_id)
if not pconfig or pconfig.auth_type != "api_key":
return {"configured": False}
api_key = ""
key_source = ""
for env_var in pconfig.api_key_env_vars:
val = os.getenv(env_var, "").strip()
if val:
api_key = val
key_source = env_var
break
base_url = pconfig.inference_base_url
if pconfig.base_url_env_var:
env_url = os.getenv(pconfig.base_url_env_var, "").strip()
if env_url:
base_url = env_url
return {
"configured": bool(api_key),
"provider": provider_id,
"name": pconfig.name,
"key_source": key_source,
"base_url": base_url,
"logged_in": bool(api_key), # compat with OAuth status shape
}
def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]:
"""Generic auth status dispatcher."""
target = provider_id or get_active_provider()
@@ -1237,9 +1321,49 @@ def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]:
return get_nous_auth_status()
if target == "openai-codex":
return get_codex_auth_status()
# API-key providers
pconfig = PROVIDER_REGISTRY.get(target)
if pconfig and pconfig.auth_type == "api_key":
return get_api_key_provider_status(target)
return {"logged_in": False}
def resolve_api_key_provider_credentials(provider_id: str) -> Dict[str, Any]:
"""Resolve API key and base URL for an API-key provider.
Returns dict with: provider, api_key, base_url, source.
"""
pconfig = PROVIDER_REGISTRY.get(provider_id)
if not pconfig or pconfig.auth_type != "api_key":
raise AuthError(
f"Provider '{provider_id}' is not an API-key provider.",
provider=provider_id,
code="invalid_provider",
)
api_key = ""
key_source = ""
for env_var in pconfig.api_key_env_vars:
val = os.getenv(env_var, "").strip()
if val:
api_key = val
key_source = env_var
break
base_url = pconfig.inference_base_url
if pconfig.base_url_env_var:
env_url = os.getenv(pconfig.base_url_env_var, "").strip()
if env_url:
base_url = env_url.rstrip("/")
return {
"provider": provider_id,
"api_key": api_key,
"base_url": base_url.rstrip("/"),
"source": key_source or "default",
}
# =============================================================================
# External credential detection
# =============================================================================

View File

@@ -1,10 +1,15 @@
"""Welcome banner, ASCII art, and skills summary for the CLI.
"""Welcome banner, ASCII art, skills summary, and update check for the CLI.
Pure display functions with no HermesCLI state dependency.
"""
import json
import logging
import os
import subprocess
import time
from pathlib import Path
from typing import Dict, List, Any
from typing import Dict, List, Any, Optional
from rich.console import Console
from rich.panel import Panel
@@ -13,6 +18,8 @@ from rich.table import Table
from prompt_toolkit import print_formatted_text as _pt_print
from prompt_toolkit.formatted_text import ANSI as _PT_ANSI
logger = logging.getLogger(__name__)
# =========================================================================
# ANSI building blocks for conversation display
@@ -95,15 +102,93 @@ def get_available_skills() -> Dict[str, List[str]]:
return skills_by_category
# =========================================================================
# Update check
# =========================================================================
# Cache update check results for 6 hours to avoid repeated git fetches
_UPDATE_CHECK_CACHE_SECONDS = 6 * 3600
def check_for_updates() -> Optional[int]:
"""Check how many commits behind origin/main the local repo is.
Does a ``git fetch`` at most once every 6 hours (cached to
``~/.hermes/.update_check``). Returns the number of commits behind,
or ``None`` if the check fails or isn't applicable.
"""
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
repo_dir = hermes_home / "hermes-agent"
cache_file = hermes_home / ".update_check"
# Must be a git repo
if not (repo_dir / ".git").exists():
return None
# Read cache
now = time.time()
try:
if cache_file.exists():
cached = json.loads(cache_file.read_text())
if now - cached.get("ts", 0) < _UPDATE_CHECK_CACHE_SECONDS:
return cached.get("behind")
except Exception:
pass
# Fetch latest refs (fast — only downloads ref metadata, no files)
try:
subprocess.run(
["git", "fetch", "origin", "--quiet"],
capture_output=True, timeout=10,
cwd=str(repo_dir),
)
except Exception:
pass # Offline or timeout — use stale refs, that's fine
# Count commits behind
try:
result = subprocess.run(
["git", "rev-list", "--count", "HEAD..origin/main"],
capture_output=True, text=True, timeout=5,
cwd=str(repo_dir),
)
if result.returncode == 0:
behind = int(result.stdout.strip())
else:
behind = None
except Exception:
behind = None
# Write cache
try:
cache_file.write_text(json.dumps({"ts": now, "behind": behind}))
except Exception:
pass
return behind
# =========================================================================
# Welcome banner
# =========================================================================
def _format_context_length(tokens: int) -> str:
"""Format a token count for display (e.g. 128000 → '128K', 1048576 → '1M')."""
if tokens >= 1_000_000:
val = tokens / 1_000_000
return f"{val:g}M"
elif tokens >= 1_000:
val = tokens / 1_000
return f"{val:g}K"
return str(tokens)
def build_welcome_banner(console: Console, model: str, cwd: str,
tools: List[dict] = None,
enabled_toolsets: List[str] = None,
session_id: str = None,
get_toolset_for_tool=None):
get_toolset_for_tool=None,
context_length: int = None):
"""Build and print a welcome banner with caduceus on left and info on right.
Args:
@@ -114,6 +199,7 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
enabled_toolsets: List of enabled toolset names.
session_id: Session identifier.
get_toolset_for_tool: Callable to map tool name -> toolset name.
context_length: Model's context window size in tokens.
"""
from model_tools import check_tool_availability, TOOLSET_REQUIREMENTS
if get_toolset_for_tool is None:
@@ -135,7 +221,8 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
model_short = model.split("/")[-1] if "/" in model else model
if len(model_short) > 28:
model_short = model_short[:25] + "..."
left_lines.append(f"[#FFBF00]{model_short}[/] [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]")
ctx_str = f" [dim #B8860B]·[/] [dim #B8860B]{_format_context_length(context_length)} context[/]" if context_length else ""
left_lines.append(f"[#FFBF00]{model_short}[/]{ctx_str} [dim #B8860B]·[/] [dim #B8860B]Nous Research[/]")
left_lines.append(f"[dim #B8860B]{cwd}[/]")
if session_id:
left_lines.append(f"[dim #8B8682]Session: {session_id}[/]")
@@ -245,6 +332,18 @@ def build_welcome_banner(console: Console, model: str, cwd: str,
summary_parts.append("/help for commands")
right_lines.append(f"[dim #B8860B]{' · '.join(summary_parts)}[/]")
# Update check — show if behind origin/main
try:
behind = check_for_updates()
if behind and behind > 0:
commits_word = "commit" if behind == 1 else "commits"
right_lines.append(
f"[bold yellow]⚠ {behind} {commits_word} behind[/]"
f"[dim yellow] — run [bold]hermes update[/bold] to update[/]"
)
except Exception:
pass # Never break the banner over an update check
right_content = "\n".join(right_lines)
layout_table.add_row(left_content, right_content)

352
hermes_cli/clipboard.py Normal file
View File

@@ -0,0 +1,352 @@
"""Clipboard image extraction for macOS, Linux, and WSL2.
Provides a single function `save_clipboard_image(dest)` that checks the
system clipboard for image data, saves it to *dest* as PNG, and returns
True on success. No external Python dependencies — uses only OS-level
CLI tools that ship with the platform (or are commonly installed).
Platform support:
macOS — osascript (always available), pngpaste (if installed)
WSL2 — powershell.exe via .NET System.Windows.Forms.Clipboard
Linux — wl-paste (Wayland), xclip (X11)
"""
import base64
import logging
import os
import subprocess
import sys
from pathlib import Path
logger = logging.getLogger(__name__)
# Cache WSL detection (checked once per process)
_wsl_detected: bool | None = None
def save_clipboard_image(dest: Path) -> bool:
"""Extract an image from the system clipboard and save it as PNG.
Returns True if an image was found and saved, False otherwise.
"""
dest.parent.mkdir(parents=True, exist_ok=True)
if sys.platform == "darwin":
return _macos_save(dest)
return _linux_save(dest)
def has_clipboard_image() -> bool:
"""Quick check: does the clipboard currently contain an image?
Lighter than save_clipboard_image — doesn't extract or write anything.
"""
if sys.platform == "darwin":
return _macos_has_image()
if _is_wsl():
return _wsl_has_image()
if os.environ.get("WAYLAND_DISPLAY"):
return _wayland_has_image()
return _xclip_has_image()
# ── macOS ────────────────────────────────────────────────────────────────
def _macos_save(dest: Path) -> bool:
"""Try pngpaste first (fast, handles more formats), fall back to osascript."""
return _macos_pngpaste(dest) or _macos_osascript(dest)
def _macos_has_image() -> bool:
"""Check if macOS clipboard contains image data."""
try:
info = subprocess.run(
["osascript", "-e", "clipboard info"],
capture_output=True, text=True, timeout=3,
)
return "«class PNGf»" in info.stdout or "«class TIFF»" in info.stdout
except Exception:
return False
def _macos_pngpaste(dest: Path) -> bool:
"""Use pngpaste (brew install pngpaste) — fastest, cleanest."""
try:
r = subprocess.run(
["pngpaste", str(dest)],
capture_output=True, timeout=3,
)
if r.returncode == 0 and dest.exists() and dest.stat().st_size > 0:
return True
except FileNotFoundError:
pass # pngpaste not installed
except Exception as e:
logger.debug("pngpaste failed: %s", e)
return False
def _macos_osascript(dest: Path) -> bool:
"""Use osascript to extract PNG data from clipboard (always available)."""
if not _macos_has_image():
return False
# Extract as PNG
script = (
'try\n'
' set imgData to the clipboard as «class PNGf»\n'
f' set f to open for access POSIX file "{dest}" with write permission\n'
' write imgData to f\n'
' close access f\n'
'on error\n'
' return "fail"\n'
'end try\n'
)
try:
r = subprocess.run(
["osascript", "-e", script],
capture_output=True, text=True, timeout=5,
)
if r.returncode == 0 and "fail" not in r.stdout and dest.exists() and dest.stat().st_size > 0:
return True
except Exception as e:
logger.debug("osascript clipboard extract failed: %s", e)
return False
# ── Linux ────────────────────────────────────────────────────────────────
def _is_wsl() -> bool:
"""Detect if running inside WSL (1 or 2)."""
global _wsl_detected
if _wsl_detected is not None:
return _wsl_detected
try:
with open("/proc/version", "r") as f:
_wsl_detected = "microsoft" in f.read().lower()
except Exception:
_wsl_detected = False
return _wsl_detected
def _linux_save(dest: Path) -> bool:
"""Try clipboard backends in priority order: WSL → Wayland → X11."""
if _is_wsl():
if _wsl_save(dest):
return True
# Fall through — WSLg might have wl-paste or xclip working
if os.environ.get("WAYLAND_DISPLAY"):
if _wayland_save(dest):
return True
return _xclip_save(dest)
# ── WSL2 (powershell.exe) ────────────────────────────────────────────────
# PowerShell script: get clipboard image as base64-encoded PNG on stdout.
# Using .NET System.Windows.Forms.Clipboard — always available on Windows.
_PS_CHECK_IMAGE = (
"Add-Type -AssemblyName System.Windows.Forms;"
"[System.Windows.Forms.Clipboard]::ContainsImage()"
)
_PS_EXTRACT_IMAGE = (
"Add-Type -AssemblyName System.Windows.Forms;"
"Add-Type -AssemblyName System.Drawing;"
"$img = [System.Windows.Forms.Clipboard]::GetImage();"
"if ($null -eq $img) { exit 1 }"
"$ms = New-Object System.IO.MemoryStream;"
"$img.Save($ms, [System.Drawing.Imaging.ImageFormat]::Png);"
"[System.Convert]::ToBase64String($ms.ToArray())"
)
def _wsl_has_image() -> bool:
"""Check if Windows clipboard has an image (via powershell.exe)."""
try:
r = subprocess.run(
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command",
_PS_CHECK_IMAGE],
capture_output=True, text=True, timeout=8,
)
return r.returncode == 0 and "True" in r.stdout
except FileNotFoundError:
logger.debug("powershell.exe not found — WSL clipboard unavailable")
except Exception as e:
logger.debug("WSL clipboard check failed: %s", e)
return False
def _wsl_save(dest: Path) -> bool:
"""Extract clipboard image via powershell.exe → base64 → decode to PNG."""
try:
r = subprocess.run(
["powershell.exe", "-NoProfile", "-NonInteractive", "-Command",
_PS_EXTRACT_IMAGE],
capture_output=True, text=True, timeout=15,
)
if r.returncode != 0:
return False
b64_data = r.stdout.strip()
if not b64_data:
return False
png_bytes = base64.b64decode(b64_data)
dest.write_bytes(png_bytes)
return dest.exists() and dest.stat().st_size > 0
except FileNotFoundError:
logger.debug("powershell.exe not found — WSL clipboard unavailable")
except Exception as e:
logger.debug("WSL clipboard extraction failed: %s", e)
dest.unlink(missing_ok=True)
return False
# ── Wayland (wl-paste) ──────────────────────────────────────────────────
def _wayland_has_image() -> bool:
"""Check if Wayland clipboard has image content."""
try:
r = subprocess.run(
["wl-paste", "--list-types"],
capture_output=True, text=True, timeout=3,
)
return r.returncode == 0 and any(
t.startswith("image/") for t in r.stdout.splitlines()
)
except FileNotFoundError:
logger.debug("wl-paste not installed — Wayland clipboard unavailable")
except Exception:
pass
return False
def _wayland_save(dest: Path) -> bool:
"""Use wl-paste to extract clipboard image (Wayland sessions)."""
try:
# Check available MIME types
types_r = subprocess.run(
["wl-paste", "--list-types"],
capture_output=True, text=True, timeout=3,
)
if types_r.returncode != 0:
return False
types = types_r.stdout.splitlines()
# Prefer PNG, fall back to other image formats
mime = None
for preferred in ("image/png", "image/jpeg", "image/bmp",
"image/gif", "image/webp"):
if preferred in types:
mime = preferred
break
if not mime:
return False
# Extract the image data
with open(dest, "wb") as f:
subprocess.run(
["wl-paste", "--type", mime],
stdout=f, stderr=subprocess.DEVNULL, timeout=5, check=True,
)
if not dest.exists() or dest.stat().st_size == 0:
return False
# BMP needs conversion to PNG (common in WSLg where only BMP
# is bridged from Windows clipboard via RDP).
if mime == "image/bmp":
return _convert_to_png(dest)
return True
except FileNotFoundError:
logger.debug("wl-paste not installed — Wayland clipboard unavailable")
except Exception as e:
logger.debug("wl-paste clipboard extraction failed: %s", e)
dest.unlink(missing_ok=True)
return False
def _convert_to_png(path: Path) -> bool:
"""Convert an image file to PNG in-place (requires Pillow or ImageMagick)."""
# Try Pillow first (likely installed in the venv)
try:
from PIL import Image
img = Image.open(path)
img.save(path, "PNG")
return True
except ImportError:
pass
except Exception as e:
logger.debug("Pillow BMP→PNG conversion failed: %s", e)
# Fall back to ImageMagick convert
try:
tmp = path.with_suffix(".bmp")
path.rename(tmp)
r = subprocess.run(
["convert", str(tmp), "png:" + str(path)],
capture_output=True, timeout=5,
)
tmp.unlink(missing_ok=True)
if r.returncode == 0 and path.exists() and path.stat().st_size > 0:
return True
except FileNotFoundError:
logger.debug("ImageMagick not installed — cannot convert BMP to PNG")
except Exception as e:
logger.debug("ImageMagick BMP→PNG conversion failed: %s", e)
# Can't convert — BMP is still usable as-is for most APIs
return path.exists() and path.stat().st_size > 0
# ── X11 (xclip) ─────────────────────────────────────────────────────────
def _xclip_has_image() -> bool:
"""Check if X11 clipboard has image content."""
try:
r = subprocess.run(
["xclip", "-selection", "clipboard", "-t", "TARGETS", "-o"],
capture_output=True, text=True, timeout=3,
)
return r.returncode == 0 and "image/png" in r.stdout
except FileNotFoundError:
pass
except Exception:
pass
return False
def _xclip_save(dest: Path) -> bool:
"""Use xclip to extract clipboard image (X11 sessions)."""
# Check if clipboard has image content
try:
targets = subprocess.run(
["xclip", "-selection", "clipboard", "-t", "TARGETS", "-o"],
capture_output=True, text=True, timeout=3,
)
if "image/png" not in targets.stdout:
return False
except FileNotFoundError:
logger.debug("xclip not installed — X11 clipboard image paste unavailable")
return False
except Exception:
return False
# Extract PNG data
try:
with open(dest, "wb") as f:
subprocess.run(
["xclip", "-selection", "clipboard", "-t", "image/png", "-o"],
stdout=f, stderr=subprocess.DEVNULL, timeout=5, check=True,
)
if dest.exists() and dest.stat().st_size > 0:
return True
except Exception as e:
logger.debug("xclip image extraction failed: %s", e)
dest.unlink(missing_ok=True)
return False

View File

@@ -28,6 +28,7 @@ COMMANDS = {
"/verbose": "Cycle tool progress display: off → new → all → verbose",
"/compress": "Manually compress conversation context (flush memories + summarize)",
"/usage": "Show token usage for the current session",
"/insights": "Show usage insights and analytics (last 30 days)",
"/quit": "Exit the CLI (also: /exit, /q)",
}

View File

@@ -71,7 +71,8 @@ DEFAULT_CONFIG = {
"docker_image": "nikolaik/python-nodejs:python3.11-nodejs20",
"singularity_image": "docker://nikolaik/python-nodejs:python3.11-nodejs20",
"modal_image": "nikolaik/python-nodejs:python3.11-nodejs20",
# Container resource limits (docker, singularity, modal — ignored for local/ssh)
"daytona_image": "nikolaik/python-nodejs:python3.11-nodejs20",
# Container resource limits (docker, singularity, modal, daytona — ignored for local/ssh)
"container_cpu": 1,
"container_memory": 5120, # MB (default 5GB)
"container_disk": 51200, # MB (default 50GB)
@@ -86,6 +87,20 @@ DEFAULT_CONFIG = {
"enabled": True,
"threshold": 0.85,
"summary_model": "google/gemini-3-flash-preview",
"summary_provider": "auto",
},
# Auxiliary model overrides (advanced). By default Hermes auto-selects
# the provider and model for each side task. Set these to override.
"auxiliary": {
"vision": {
"provider": "auto", # auto | openrouter | nous | main
"model": "", # e.g. "google/gemini-2.5-flash", "gpt-4o"
},
"web_extract": {
"provider": "auto",
"model": "",
},
},
"display": {
@@ -140,9 +155,13 @@ DEFAULT_CONFIG = {
# (apiKey, workspace, peerName, sessions, enabled) comes from the global config.
"honcho": {},
# IANA timezone (e.g. "Asia/Kolkata", "America/New_York").
# Empty string means use server-local time.
"timezone": "",
# Permanently allowed dangerous command patterns (added via "always" approval)
"command_allowlist": [],
# Config schema version - bump this when adding new required fields
"_config_version": 5,
}
@@ -169,6 +188,86 @@ OPTIONAL_ENV_VARS = {
"category": "provider",
"advanced": True,
},
"GLM_API_KEY": {
"description": "Z.AI / GLM API key (also recognized as ZAI_API_KEY / Z_AI_API_KEY)",
"prompt": "Z.AI / GLM API key",
"url": "https://z.ai/",
"password": True,
"category": "provider",
"advanced": True,
},
"ZAI_API_KEY": {
"description": "Z.AI API key (alias for GLM_API_KEY)",
"prompt": "Z.AI API key",
"url": "https://z.ai/",
"password": True,
"category": "provider",
"advanced": True,
},
"Z_AI_API_KEY": {
"description": "Z.AI API key (alias for GLM_API_KEY)",
"prompt": "Z.AI API key",
"url": "https://z.ai/",
"password": True,
"category": "provider",
"advanced": True,
},
"GLM_BASE_URL": {
"description": "Z.AI / GLM base URL override",
"prompt": "Z.AI / GLM base URL (leave empty for default)",
"url": None,
"password": False,
"category": "provider",
"advanced": True,
},
"KIMI_API_KEY": {
"description": "Kimi / Moonshot API key",
"prompt": "Kimi API key",
"url": "https://platform.moonshot.cn/",
"password": True,
"category": "provider",
"advanced": True,
},
"KIMI_BASE_URL": {
"description": "Kimi / Moonshot base URL override",
"prompt": "Kimi base URL (leave empty for default)",
"url": None,
"password": False,
"category": "provider",
"advanced": True,
},
"MINIMAX_API_KEY": {
"description": "MiniMax API key (international)",
"prompt": "MiniMax API key",
"url": "https://www.minimax.io/",
"password": True,
"category": "provider",
"advanced": True,
},
"MINIMAX_BASE_URL": {
"description": "MiniMax base URL override",
"prompt": "MiniMax base URL (leave empty for default)",
"url": None,
"password": False,
"category": "provider",
"advanced": True,
},
"MINIMAX_CN_API_KEY": {
"description": "MiniMax API key (China endpoint)",
"prompt": "MiniMax (China) API key",
"url": "https://www.minimaxi.com/",
"password": True,
"category": "provider",
"advanced": True,
},
"MINIMAX_CN_BASE_URL": {
"description": "MiniMax (China) base URL override",
"prompt": "MiniMax (China) base URL (leave empty for default)",
"url": None,
"password": False,
"category": "provider",
"advanced": True,
},
# ── Tool API keys ──
"FIRECRAWL_API_KEY": {
@@ -179,8 +278,16 @@ OPTIONAL_ENV_VARS = {
"password": True,
"category": "tool",
},
"FIRECRAWL_API_URL": {
"description": "Firecrawl API URL for self-hosted instances (optional)",
"prompt": "Firecrawl API URL (leave empty for cloud)",
"url": None,
"password": False,
"category": "tool",
"advanced": True,
},
"BROWSERBASE_API_KEY": {
"description": "Browserbase API key for browser automation",
"description": "Browserbase API key for cloud browser (optional — local browser works without this)",
"prompt": "Browserbase API key",
"url": "https://browserbase.com/",
"tools": ["browser_navigate", "browser_click"],
@@ -188,7 +295,7 @@ OPTIONAL_ENV_VARS = {
"category": "tool",
},
"BROWSERBASE_PROJECT_ID": {
"description": "Browserbase project ID",
"description": "Browserbase project ID (optional — only needed for cloud browser)",
"prompt": "Browserbase project ID",
"url": "https://browserbase.com/",
"tools": ["browser_navigate", "browser_click"],
@@ -476,6 +583,22 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A
if not quiet:
print(f" ✓ Migrated tool progress to config.yaml: {display['tool_progress']}")
# ── Version 4 → 5: add timezone field ──
if current_ver < 5:
config = load_config()
if "timezone" not in config:
old_tz = os.getenv("HERMES_TIMEZONE", "")
if old_tz and old_tz.strip():
config["timezone"] = old_tz.strip()
results["config_added"].append(f"timezone={old_tz.strip()} (from HERMES_TIMEZONE)")
else:
config["timezone"] = ""
results["config_added"].append("timezone= (empty, uses server-local)")
save_config(config)
if not quiet:
tz_display = config["timezone"] or "(server-local)"
print(f" ✓ Added timezone to config.yaml: {tz_display}")
if current_ver < latest_ver and not quiet:
print(f"Config version: {current_ver}{latest_ver}")
@@ -753,12 +876,25 @@ def show_config():
print(f" Modal image: {terminal.get('modal_image', 'python:3.11')}")
modal_token = get_env_value('MODAL_TOKEN_ID')
print(f" Modal token: {'configured' if modal_token else '(not set)'}")
elif terminal.get('backend') == 'daytona':
print(f" Daytona image: {terminal.get('daytona_image', 'nikolaik/python-nodejs:python3.11-nodejs20')}")
daytona_key = get_env_value('DAYTONA_API_KEY')
print(f" API key: {'configured' if daytona_key else '(not set)'}")
elif terminal.get('backend') == 'ssh':
ssh_host = get_env_value('TERMINAL_SSH_HOST')
ssh_user = get_env_value('TERMINAL_SSH_USER')
print(f" SSH host: {ssh_host or '(not set)'}")
print(f" SSH user: {ssh_user or '(not set)'}")
# Timezone
print()
print(color("◆ Timezone", Colors.CYAN, Colors.BOLD))
tz = config.get('timezone', '')
if tz:
print(f" Timezone: {tz}")
else:
print(f" Timezone: {color('(server-local)', Colors.DIM)}")
# Compression
print()
print(color("◆ Context Compression", Colors.CYAN, Colors.BOLD))
@@ -768,6 +904,31 @@ def show_config():
if enabled:
print(f" Threshold: {compression.get('threshold', 0.85) * 100:.0f}%")
print(f" Model: {compression.get('summary_model', 'google/gemini-3-flash-preview')}")
comp_provider = compression.get('summary_provider', 'auto')
if comp_provider != 'auto':
print(f" Provider: {comp_provider}")
# Auxiliary models
auxiliary = config.get('auxiliary', {})
aux_tasks = {
"Vision": auxiliary.get('vision', {}),
"Web extract": auxiliary.get('web_extract', {}),
}
has_overrides = any(
t.get('provider', 'auto') != 'auto' or t.get('model', '')
for t in aux_tasks.values()
)
if has_overrides:
print()
print(color("◆ Auxiliary Models (overrides)", Colors.CYAN, Colors.BOLD))
for label, task_cfg in aux_tasks.items():
prov = task_cfg.get('provider', 'auto')
mdl = task_cfg.get('model', '')
if prov != 'auto' or mdl:
parts = [f"provider={prov}"]
if mdl:
parts.append(f"model={mdl}")
print(f" {label:12s} {', '.join(parts)}")
# Messaging
print()
@@ -820,15 +981,16 @@ def set_config_value(key: str, value: str):
"""Set a configuration value."""
# Check if it's an API key (goes to .env)
api_keys = [
'OPENROUTER_API_KEY', 'ANTHROPIC_API_KEY', 'VOICE_TOOLS_OPENAI_KEY',
'FIRECRAWL_API_KEY', 'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID',
'OPENROUTER_API_KEY', 'OPENAI_API_KEY', 'ANTHROPIC_API_KEY', 'VOICE_TOOLS_OPENAI_KEY',
'FIRECRAWL_API_KEY', 'FIRECRAWL_API_URL', 'BROWSERBASE_API_KEY', 'BROWSERBASE_PROJECT_ID',
'FAL_KEY', 'TELEGRAM_BOT_TOKEN', 'DISCORD_BOT_TOKEN',
'TERMINAL_SSH_HOST', 'TERMINAL_SSH_USER', 'TERMINAL_SSH_KEY',
'SUDO_PASSWORD', 'SLACK_BOT_TOKEN', 'SLACK_APP_TOKEN',
'GITHUB_TOKEN', 'HONCHO_API_KEY',
'GITHUB_TOKEN', 'HONCHO_API_KEY', 'NOUS_API_KEY', 'WANDB_API_KEY',
'TINKER_API_KEY',
]
if key.upper() in api_keys or key.upper().startswith('TERMINAL_SSH'):
if key.upper() in api_keys or key.upper().endswith('_API_KEY') or key.upper().endswith('_TOKEN') or key.upper().startswith('TERMINAL_SSH'):
save_env_value(key.upper(), value)
print(f"✓ Set {key} in {get_env_path()}")
return
@@ -878,6 +1040,7 @@ def set_config_value(key: str, value: str):
"terminal.docker_image": "TERMINAL_DOCKER_IMAGE",
"terminal.singularity_image": "TERMINAL_SINGULARITY_IMAGE",
"terminal.modal_image": "TERMINAL_MODAL_IMAGE",
"terminal.daytona_image": "TERMINAL_DAYTONA_IMAGE",
"terminal.cwd": "TERMINAL_CWD",
"terminal.timeout": "TERMINAL_TIMEOUT",
}

View File

@@ -132,7 +132,11 @@ def run_doctor(args):
# Check for common issues
content = env_path.read_text()
if "OPENROUTER_API_KEY" in content or "ANTHROPIC_API_KEY" in content:
if any(k in content for k in (
"OPENROUTER_API_KEY", "ANTHROPIC_API_KEY",
"GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY",
"KIMI_API_KEY", "MINIMAX_API_KEY", "MINIMAX_CN_API_KEY",
)):
check_ok("API key configured")
else:
check_warn("No API key found in ~/.hermes/.env")
@@ -355,6 +359,21 @@ def run_doctor(args):
check_fail("TERMINAL_SSH_HOST not set", "(required for TERMINAL_ENV=ssh)")
issues.append("Set TERMINAL_SSH_HOST in .env")
# Daytona (if using daytona backend)
if terminal_env == "daytona":
daytona_key = os.getenv("DAYTONA_API_KEY")
if daytona_key:
check_ok("Daytona API key", "(configured)")
else:
check_fail("DAYTONA_API_KEY not set", "(required for TERMINAL_ENV=daytona)")
issues.append("Set DAYTONA_API_KEY environment variable")
try:
from daytona import Daytona
check_ok("daytona SDK", "(installed)")
except ImportError:
check_fail("daytona SDK not installed", "(pip install daytona)")
issues.append("Install daytona SDK: pip install daytona")
# Node.js + agent-browser (for browser automation tools)
if shutil.which("node"):
check_ok("Node.js")
@@ -453,7 +472,42 @@ def run_doctor(args):
print(f"\r {color('', Colors.YELLOW)} Anthropic API {color(msg, Colors.DIM)} ")
except Exception as e:
print(f"\r {color('', Colors.YELLOW)} Anthropic API {color(f'({e})', Colors.DIM)} ")
# -- API-key providers (Z.AI/GLM, Kimi, MiniMax, MiniMax-CN) --
_apikey_providers = [
("Z.AI / GLM", ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"), "https://api.z.ai/api/paas/v4/models", "GLM_BASE_URL"),
("Kimi / Moonshot", ("KIMI_API_KEY",), "https://api.moonshot.ai/v1/models", "KIMI_BASE_URL"),
("MiniMax", ("MINIMAX_API_KEY",), "https://api.minimax.io/v1/models", "MINIMAX_BASE_URL"),
("MiniMax (China)", ("MINIMAX_CN_API_KEY",), "https://api.minimaxi.com/v1/models", "MINIMAX_CN_BASE_URL"),
]
for _pname, _env_vars, _default_url, _base_env in _apikey_providers:
_key = ""
for _ev in _env_vars:
_key = os.getenv(_ev, "")
if _key:
break
if _key:
_label = _pname.ljust(20)
print(f" Checking {_pname} API...", end="", flush=True)
try:
import httpx
_base = os.getenv(_base_env, "")
_url = (_base.rstrip("/") + "/models") if _base else _default_url
_resp = httpx.get(
_url,
headers={"Authorization": f"Bearer {_key}"},
timeout=10,
)
if _resp.status_code == 200:
print(f"\r {color('', Colors.GREEN)} {_label} ")
elif _resp.status_code == 401:
print(f"\r {color('', Colors.RED)} {_label} {color('(invalid API key)', Colors.DIM)} ")
issues.append(f"Check {_env_vars[0]} in .env")
else:
print(f"\r {color('', Colors.YELLOW)} {_label} {color(f'(HTTP {_resp.status_code})', Colors.DIM)} ")
except Exception as _e:
print(f"\r {color('', Colors.YELLOW)} {_label} {color(f'({_e})', Colors.DIM)} ")
# =========================================================================
# Check: Submodules
# =========================================================================

View File

@@ -64,7 +64,13 @@ def _has_any_provider_configured() -> bool:
# Check env vars (may be set by .env or shell).
# OPENAI_BASE_URL alone counts — local models (vLLM, llama.cpp, etc.)
# often don't require an API key.
provider_env_vars = ("OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "OPENAI_BASE_URL")
from hermes_cli.auth import PROVIDER_REGISTRY
# Collect all provider env vars
provider_env_vars = {"OPENROUTER_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", "OPENAI_BASE_URL"}
for pconfig in PROVIDER_REGISTRY.values():
if pconfig.auth_type == "api_key":
provider_env_vars.update(pconfig.api_key_env_vars)
if any(os.getenv(v) for v in provider_env_vars):
return True
@@ -143,6 +149,13 @@ def cmd_chat(args):
print("You can run 'hermes setup' at any time to configure.")
sys.exit(1)
# Sync bundled skills on every CLI launch (fast -- skips unchanged skills)
try:
from tools.skills_sync import sync_skills
sync_skills(quiet=True)
except Exception:
pass
# Import and run the CLI
from cli import main as cli_main
@@ -404,6 +417,10 @@ def cmd_model(args):
"openrouter": "OpenRouter",
"nous": "Nous Portal",
"openai-codex": "OpenAI Codex",
"zai": "Z.AI / GLM",
"kimi-coding": "Kimi / Moonshot",
"minimax": "MiniMax",
"minimax-cn": "MiniMax (China)",
"custom": "Custom endpoint",
}
active_label = provider_labels.get(active, active)
@@ -418,11 +435,16 @@ def cmd_model(args):
("openrouter", "OpenRouter (100+ models, pay-per-use)"),
("nous", "Nous Portal (Nous Research subscription)"),
("openai-codex", "OpenAI Codex"),
("zai", "Z.AI / GLM (Zhipu AI direct API)"),
("kimi-coding", "Kimi / Moonshot (Moonshot AI direct API)"),
("minimax", "MiniMax (global direct API)"),
("minimax-cn", "MiniMax China (domestic direct API)"),
("custom", "Custom endpoint (self-hosted / VLLM / etc.)"),
]
# Reorder so the active provider is at the top
active_key = active if active in ("openrouter", "nous", "openai-codex") else "custom"
known_keys = {k for k, _ in providers}
active_key = active if active in known_keys else "custom"
ordered = []
for key, label in providers:
if key == active_key:
@@ -447,6 +469,8 @@ def cmd_model(args):
_model_flow_openai_codex(config, current_model)
elif selected_provider == "custom":
_model_flow_custom(config)
elif selected_provider in ("zai", "kimi-coding", "minimax", "minimax-cn"):
_model_flow_api_key_provider(config, selected_provider, current_model)
def _prompt_provider_choice(choices):
@@ -716,6 +740,117 @@ def _model_flow_custom(config):
print("Endpoint saved. Use `/model` in chat or `hermes model` to set a model.")
# Curated model lists for direct API-key providers
_PROVIDER_MODELS = {
"zai": [
"glm-5",
"glm-4.7",
"glm-4.5",
"glm-4.5-flash",
],
"kimi-coding": [
"kimi-k2.5",
"kimi-k2-thinking",
"kimi-k2-turbo-preview",
"kimi-k2-0905-preview",
],
"minimax": [
"MiniMax-M2.5",
"MiniMax-M2.5-highspeed",
"MiniMax-M2.1",
],
"minimax-cn": [
"MiniMax-M2.5",
"MiniMax-M2.5-highspeed",
"MiniMax-M2.1",
],
}
def _model_flow_api_key_provider(config, provider_id, current_model=""):
"""Generic flow for API-key providers (z.ai, Kimi, MiniMax)."""
from hermes_cli.auth import (
PROVIDER_REGISTRY, _prompt_model_selection, _save_model_choice,
_update_config_for_provider, deactivate_provider,
)
from hermes_cli.config import get_env_value, save_env_value, load_config, save_config
pconfig = PROVIDER_REGISTRY[provider_id]
key_env = pconfig.api_key_env_vars[0] if pconfig.api_key_env_vars else ""
base_url_env = pconfig.base_url_env_var or ""
# Check / prompt for API key
existing_key = ""
for ev in pconfig.api_key_env_vars:
existing_key = get_env_value(ev) or os.getenv(ev, "")
if existing_key:
break
if not existing_key:
print(f"No {pconfig.name} API key configured.")
if key_env:
try:
new_key = input(f"{key_env} (or Enter to cancel): ").strip()
except (KeyboardInterrupt, EOFError):
print()
return
if not new_key:
print("Cancelled.")
return
save_env_value(key_env, new_key)
print("API key saved.")
print()
else:
print(f" {pconfig.name} API key: {existing_key[:8]}... ✓")
print()
# Optional base URL override
current_base = ""
if base_url_env:
current_base = get_env_value(base_url_env) or os.getenv(base_url_env, "")
effective_base = current_base or pconfig.inference_base_url
try:
override = input(f"Base URL [{effective_base}]: ").strip()
except (KeyboardInterrupt, EOFError):
print()
override = ""
if override and base_url_env:
save_env_value(base_url_env, override)
effective_base = override
# Model selection
model_list = _PROVIDER_MODELS.get(provider_id, [])
if model_list:
selected = _prompt_model_selection(model_list, current_model=current_model)
else:
try:
selected = input("Model name: ").strip()
except (KeyboardInterrupt, EOFError):
selected = None
if selected:
# Clear custom endpoint if set (avoid confusion)
if get_env_value("OPENAI_BASE_URL"):
save_env_value("OPENAI_BASE_URL", "")
save_env_value("OPENAI_API_KEY", "")
_save_model_choice(selected)
# Update config with provider and base URL
cfg = load_config()
model = cfg.get("model")
if isinstance(model, dict):
model["provider"] = provider_id
model["base_url"] = effective_base
save_config(cfg)
deactivate_provider()
print(f"Default model set to: {selected} (via {pconfig.name})")
else:
print("No change.")
def cmd_login(args):
"""Authenticate Hermes CLI with a provider."""
from hermes_cli.auth import login_command
@@ -851,11 +986,17 @@ def _update_via_zip(args):
# Sync skills
try:
from tools.skills_sync import sync_skills
print("Checking for new bundled skills...")
print("Syncing bundled skills...")
result = sync_skills(quiet=True)
if result["copied"]:
print(f" + {len(result['copied'])} new skill(s): {', '.join(result['copied'])}")
else:
print(f" + {len(result['copied'])} new: {', '.join(result['copied'])}")
if result.get("updated"):
print(f"{len(result['updated'])} updated: {', '.join(result['updated'])}")
if result.get("user_modified"):
print(f" ~ {len(result['user_modified'])} user-modified (kept)")
if result.get("cleaned"):
print(f" {len(result['cleaned'])} removed from manifest")
if not result["copied"] and not result.get("updated"):
print(" ✓ Skills are up to date")
except Exception:
pass
@@ -961,15 +1102,21 @@ def cmd_update(args):
print()
print("✓ Code updated!")
# Sync any new bundled skills (manifest-based -- won't overwrite or re-add deleted skills)
# Sync bundled skills (copies new, updates changed, respects user deletions)
try:
from tools.skills_sync import sync_skills
print()
print("Checking for new bundled skills...")
print("Syncing bundled skills...")
result = sync_skills(quiet=True)
if result["copied"]:
print(f" + {len(result['copied'])} new skill(s): {', '.join(result['copied'])}")
else:
print(f" + {len(result['copied'])} new: {', '.join(result['copied'])}")
if result.get("updated"):
print(f"{len(result['updated'])} updated: {', '.join(result['updated'])}")
if result.get("user_modified"):
print(f" ~ {len(result['user_modified'])} user-modified (kept)")
if result.get("cleaned"):
print(f" {len(result['cleaned'])} removed from manifest")
if not result["copied"] and not result.get("updated"):
print(" ✓ Skills are up to date")
except Exception as e:
logger.debug("Skills sync during update failed: %s", e)
@@ -1122,7 +1269,7 @@ For more help on a command:
)
chat_parser.add_argument(
"--provider",
choices=["auto", "openrouter", "nous", "openai-codex"],
choices=["auto", "openrouter", "nous", "openai-codex", "zai", "kimi-coding", "minimax", "minimax-cn"],
default=None,
help="Inference provider (default: auto)"
)
@@ -1200,7 +1347,15 @@ For more help on a command:
setup_parser = subparsers.add_parser(
"setup",
help="Interactive setup wizard",
description="Configure Hermes Agent with an interactive wizard"
description="Configure Hermes Agent with an interactive wizard. "
"Run a specific section: hermes setup model|terminal|gateway|tools|agent"
)
setup_parser.add_argument(
"section",
nargs="?",
choices=["model", "terminal", "gateway", "tools", "agent"],
default=None,
help="Run a specific setup section instead of the full wizard"
)
setup_parser.add_argument(
"--non-interactive",
@@ -1424,9 +1579,16 @@ For more help on a command:
)
skills_subparsers = skills_parser.add_subparsers(dest="skills_action")
skills_browse = skills_subparsers.add_parser("browse", help="Browse all available skills (paginated)")
skills_browse.add_argument("--page", type=int, default=1, help="Page number (default: 1)")
skills_browse.add_argument("--size", type=int, default=20, help="Results per page (default: 20)")
skills_browse.add_argument("--source", default="all",
choices=["all", "official", "github", "clawhub", "lobehub"],
help="Filter by source (default: all)")
skills_search = skills_subparsers.add_parser("search", help="Search skill registries")
skills_search.add_argument("query", help="Search query")
skills_search.add_argument("--source", default="all", choices=["all", "github", "clawhub", "lobehub"])
skills_search.add_argument("--source", default="all", choices=["all", "official", "github", "clawhub", "lobehub"])
skills_search.add_argument("--limit", type=int, default=10, help="Max results")
skills_install = skills_subparsers.add_parser("install", help="Install a skill")
@@ -1603,6 +1765,32 @@ For more help on a command:
sessions_parser.set_defaults(func=cmd_sessions)
# =========================================================================
# insights command
# =========================================================================
insights_parser = subparsers.add_parser(
"insights",
help="Show usage insights and analytics",
description="Analyze session history to show token usage, costs, tool patterns, and activity trends"
)
insights_parser.add_argument("--days", type=int, default=30, help="Number of days to analyze (default: 30)")
insights_parser.add_argument("--source", help="Filter by platform (cli, telegram, discord, etc.)")
def cmd_insights(args):
try:
from hermes_state import SessionDB
from agent.insights import InsightsEngine
db = SessionDB()
engine = InsightsEngine(db)
report = engine.generate(days=args.days, source=args.source)
print(engine.format_terminal(report))
db.close()
except Exception as e:
print(f"Error generating insights: {e}")
insights_parser.set_defaults(func=cmd_insights)
# =========================================================================
# version command
# =========================================================================

View File

@@ -9,14 +9,17 @@ Add, remove, or reorder entries here — both `hermes setup` and
OPENROUTER_MODELS: list[tuple[str, str]] = [
("anthropic/claude-opus-4.6", "recommended"),
("anthropic/claude-sonnet-4.5", ""),
("anthropic/claude-opus-4.5", ""),
("openai/gpt-5.2", ""),
("openai/gpt-5.4-pro", ""),
("openai/gpt-5.4", ""),
("openai/gpt-5.3-codex", ""),
("google/gemini-3-pro-preview", ""),
("google/gemini-3-flash-preview", ""),
("z-ai/glm-4.7", ""),
("qwen/qwen3.5-plus-02-15", ""),
("qwen/qwen3.5-35b-a3b", ""),
("stepfun/step-3.5-flash", ""),
("z-ai/glm-5", ""),
("moonshotai/kimi-k2.5", ""),
("minimax/minimax-m2.1", ""),
("minimax/minimax-m2.5", ""),
]

View File

@@ -7,10 +7,12 @@ from typing import Any, Dict, Optional
from hermes_cli.auth import (
AuthError,
PROVIDER_REGISTRY,
format_auth_error,
resolve_provider,
resolve_nous_runtime_credentials,
resolve_codex_runtime_credentials,
resolve_api_key_provider_credentials,
)
from hermes_cli.config import load_config
from hermes_constants import OPENROUTER_BASE_URL
@@ -72,12 +74,26 @@ def _resolve_openrouter_runtime(
or OPENROUTER_BASE_URL
).rstrip("/")
api_key = (
explicit_api_key
or os.getenv("OPENROUTER_API_KEY")
or os.getenv("OPENAI_API_KEY")
or ""
)
# Choose API key based on whether the resolved base_url targets OpenRouter.
# When hitting OpenRouter, prefer OPENROUTER_API_KEY (issue #289).
# When hitting a custom endpoint (e.g. Z.ai, local LLM), prefer
# OPENAI_API_KEY so the OpenRouter key doesn't leak to an unrelated
# provider (issues #420, #560).
_is_openrouter_url = "openrouter.ai" in base_url
if _is_openrouter_url:
api_key = (
explicit_api_key
or os.getenv("OPENROUTER_API_KEY")
or os.getenv("OPENAI_API_KEY")
or ""
)
else:
api_key = (
explicit_api_key
or os.getenv("OPENAI_API_KEY")
or os.getenv("OPENROUTER_API_KEY")
or ""
)
source = "explicit" if (explicit_api_key or explicit_base_url) else "env/config"
@@ -132,6 +148,19 @@ def resolve_runtime_provider(
"requested_provider": requested_provider,
}
# API-key providers (z.ai/GLM, Kimi, MiniMax, MiniMax-CN)
pconfig = PROVIDER_REGISTRY.get(provider)
if pconfig and pconfig.auth_type == "api_key":
creds = resolve_api_key_provider_credentials(provider)
return {
"provider": provider,
"api_mode": "chat_completions",
"base_url": creds.get("base_url", "").rstrip("/"),
"api_key": creds.get("api_key", ""),
"source": creds.get("source", "env"),
"requested_provider": requested_provider,
}
runtime = _resolve_openrouter_runtime(
requested_provider=requested_provider,
explicit_api_key=explicit_api_key,

File diff suppressed because it is too large Load Diff

View File

@@ -57,8 +57,9 @@ def _resolve_short_name(name: str, sources, console: Console) -> str:
table.add_column("Trust", style="dim")
table.add_column("Identifier", style="bold cyan")
for r in exact:
trust_style = {"trusted": "green", "community": "yellow"}.get(r.trust_level, "dim")
table.add_row(r.source, f"[{trust_style}]{r.trust_level}[/]", r.identifier)
trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(r.trust_level, "dim")
trust_label = "official" if r.source == "official" else r.trust_level
table.add_row(r.source, f"[{trust_style}]{trust_label}[/]", r.identifier)
c.print(table)
c.print("[bold]Use the full identifier to install a specific one.[/]\n")
return ""
@@ -99,12 +100,13 @@ def do_search(query: str, source: str = "all", limit: int = 10,
table.add_column("Identifier", style="dim")
for r in results:
trust_style = {"trusted": "green", "community": "yellow"}.get(r.trust_level, "dim")
trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(r.trust_level, "dim")
trust_label = "official" if r.source == "official" else r.trust_level
table.add_row(
r.name,
r.description[:60] + ("..." if len(r.description) > 60 else ""),
r.source,
f"[{trust_style}]{r.trust_level}[/]",
f"[{trust_style}]{trust_label}[/]",
r.identifier,
)
@@ -113,6 +115,130 @@ def do_search(query: str, source: str = "all", limit: int = 10,
"hermes skills install <identifier> to install[/]\n")
def do_browse(page: int = 1, page_size: int = 20, source: str = "all",
console: Optional[Console] = None) -> None:
"""Browse all available skills across registries, paginated.
Official skills are always shown first, regardless of source filter.
"""
from tools.skills_hub import (
GitHubAuth, create_source_router, OptionalSkillSource, SkillMeta,
)
# Clamp page_size to safe range
page_size = max(1, min(page_size, 100))
c = console or _console
auth = GitHubAuth()
sources = create_source_router(auth)
# Collect results from all (or filtered) sources
# Use empty query to get everything; per-source limits prevent overload
_TRUST_RANK = {"builtin": 3, "trusted": 2, "community": 1}
_PER_SOURCE_LIMIT = {"official": 100, "github": 100, "clawhub": 50,
"claude-marketplace": 50, "lobehub": 50}
all_results: list = []
source_counts: dict = {}
for src in sources:
sid = src.source_id()
if source != "all" and sid != source and sid != "official":
# Always include official source for the "first" placement
continue
try:
limit = _PER_SOURCE_LIMIT.get(sid, 50)
results = src.search("", limit=limit)
source_counts[sid] = len(results)
all_results.extend(results)
except Exception:
continue
if not all_results:
c.print("[dim]No skills found in the Skills Hub.[/]\n")
return
# Deduplicate by name, preferring higher trust
seen: dict = {}
for r in all_results:
rank = _TRUST_RANK.get(r.trust_level, 0)
if r.name not in seen or rank > _TRUST_RANK.get(seen[r.name].trust_level, 0):
seen[r.name] = r
deduped = list(seen.values())
# Sort: official first, then by trust level (desc), then alphabetically
deduped.sort(key=lambda r: (
-_TRUST_RANK.get(r.trust_level, 0),
r.source != "official",
r.name.lower(),
))
# Paginate
total = len(deduped)
total_pages = max(1, (total + page_size - 1) // page_size)
page = max(1, min(page, total_pages))
start = (page - 1) * page_size
end = min(start + page_size, total)
page_items = deduped[start:end]
# Count official vs other
official_count = sum(1 for r in deduped if r.source == "official")
# Build header
source_label = f"{source}" if source != "all" else "— all sources"
c.print(f"\n[bold]Skills Hub — Browse {source_label}[/]"
f" [dim]({total} skills, page {page}/{total_pages})[/]")
if official_count > 0 and page == 1:
c.print(f"[bright_cyan]★ {official_count} official optional skill(s) from Nous Research[/]")
c.print()
# Build table
table = Table(show_header=True, header_style="bold")
table.add_column("#", style="dim", width=4, justify="right")
table.add_column("Name", style="bold cyan", max_width=25)
table.add_column("Description", max_width=50)
table.add_column("Source", style="dim", width=12)
table.add_column("Trust", width=10)
for i, r in enumerate(page_items, start=start + 1):
trust_style = {"builtin": "bright_cyan", "trusted": "green",
"community": "yellow"}.get(r.trust_level, "dim")
trust_label = "★ official" if r.source == "official" else r.trust_level
desc = r.description[:50]
if len(r.description) > 50:
desc += "..."
table.add_row(
str(i),
r.name,
desc,
r.source,
f"[{trust_style}]{trust_label}[/]",
)
c.print(table)
# Navigation hints
nav_parts = []
if page > 1:
nav_parts.append(f"[cyan]--page {page - 1}[/] ← prev")
if page < total_pages:
nav_parts.append(f"[cyan]--page {page + 1}[/] → next")
if nav_parts:
c.print(f" {' | '.join(nav_parts)}")
# Source summary
if source == "all" and source_counts:
parts = [f"{sid}: {ct}" for sid, ct in sorted(source_counts.items())]
c.print(f" [dim]Sources: {', '.join(parts)}[/]")
c.print("[dim]Use: hermes skills inspect <identifier> to preview, "
"hermes skills install <identifier> to install[/]\n")
def do_install(identifier: str, category: str = "", force: bool = False,
console: Optional[Console] = None) -> None:
"""Fetch, quarantine, scan, confirm, and install a skill."""
@@ -147,6 +273,12 @@ def do_install(identifier: str, category: str = "", force: bool = False,
c.print(f"[bold red]Error:[/] Could not fetch '{identifier}' from any source.\n")
return
# Auto-detect category for official skills (e.g. "official/autonomous-ai-agents/blackbox")
if bundle.source == "official" and not category:
id_parts = bundle.identifier.split("/") # ["official", "category", "skill"]
if len(id_parts) >= 3:
category = id_parts[1]
# Check if already installed
lock = HubLockFile()
existing = lock.get_installed(bundle.name)
@@ -177,18 +309,28 @@ def do_install(identifier: str, category: str = "", force: bool = False,
f"{len(result.findings)}_findings")
return
# Confirm with user — always show risk warning regardless of source
# Confirm with user — show appropriate warning based on source
if not force:
c.print()
c.print(Panel(
"[bold yellow]You are installing a third-party skill at your own risk.[/]\n\n"
"External skills can contain instructions that influence agent behavior,\n"
"shell commands, and scripts. Even after automated scanning, you should\n"
"review the installed files before use.\n\n"
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
title="Disclaimer",
border_style="yellow",
))
if bundle.source == "official":
c.print(Panel(
"[bold bright_cyan]This is an official optional skill maintained by Nous Research.[/]\n\n"
"It ships with hermes-agent but is not activated by default.\n"
"Installing will copy it to your skills directory where the agent can use it.\n\n"
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
title="Official Skill",
border_style="bright_cyan",
))
else:
c.print(Panel(
"[bold yellow]You are installing a third-party skill at your own risk.[/]\n\n"
"External skills can contain instructions that influence agent behavior,\n"
"shell commands, and scripts. Even after automated scanning, you should\n"
"review the installed files before use.\n\n"
f"Files will be at: [cyan]~/.hermes/skills/{category + '/' if category else ''}{bundle.name}/[/]",
title="Disclaimer",
border_style="yellow",
))
c.print(f"[bold]Install '{bundle.name}'?[/]")
try:
answer = input("Confirm [y/N]: ").strip().lower()
@@ -237,13 +379,14 @@ def do_inspect(identifier: str, console: Optional[Console] = None) -> None:
break
c.print()
trust_style = {"trusted": "green", "community": "yellow"}.get(meta.trust_level, "dim")
trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(meta.trust_level, "dim")
trust_label = "official" if meta.source == "official" else meta.trust_level
info_lines = [
f"[bold]Name:[/] {meta.name}",
f"[bold]Description:[/] {meta.description}",
f"[bold]Source:[/] {meta.source}",
f"[bold]Trust:[/] [{trust_style}]{meta.trust_level}[/]",
f"[bold]Trust:[/] [{trust_style}]{trust_label}[/]",
f"[bold]Identifier:[/] {meta.identifier}",
]
if meta.tags:
@@ -297,8 +440,9 @@ def do_list(source_filter: str = "all", console: Optional[Console] = None) -> No
if source_filter == "builtin" and hub_entry:
continue
trust_style = {"builtin": "blue", "trusted": "green", "community": "yellow"}.get(trust, "dim")
table.add_row(name, category, source_display, f"[{trust_style}]{trust}[/]")
trust_style = {"builtin": "bright_cyan", "trusted": "green", "community": "yellow"}.get(trust, "dim")
trust_label = "official" if source_display == "official" else trust
table.add_row(name, category, source_display, f"[{trust_style}]{trust_label}[/]")
c.print(table)
c.print(f"[dim]{len(hub_installed)} hub-installed, "
@@ -658,7 +802,9 @@ def skills_command(args) -> None:
"""Router for `hermes skills <subcommand>` — called from hermes_cli/main.py."""
action = getattr(args, "skills_action", None)
if action == "search":
if action == "browse":
do_browse(page=args.page, page_size=args.size, source=args.source)
elif action == "search":
do_search(args.query, source=args.source, limit=args.limit)
elif action == "install":
do_install(args.identifier, category=args.category, force=args.force)
@@ -692,7 +838,7 @@ def skills_command(args) -> None:
return
do_tap(tap_action, repo=repo)
else:
_console.print("Usage: hermes skills [search|install|inspect|list|audit|uninstall|publish|snapshot|tap]\n")
_console.print("Usage: hermes skills [browse|search|install|inspect|list|audit|uninstall|publish|snapshot|tap]\n")
_console.print("Run 'hermes skills <command> --help' for details.\n")
@@ -732,7 +878,32 @@ def handle_skills_slash(cmd: str, console: Optional[Console] = None) -> None:
action = parts[0].lower()
args = parts[1:]
if action == "search":
if action == "browse":
page = 1
page_size = 20
source = "all"
i = 0
while i < len(args):
if args[i] == "--page" and i + 1 < len(args):
try:
page = int(args[i + 1])
except ValueError:
pass
i += 2
elif args[i] == "--size" and i + 1 < len(args):
try:
page_size = int(args[i + 1])
except ValueError:
pass
i += 2
elif args[i] == "--source" and i + 1 < len(args):
source = args[i + 1]
i += 2
else:
i += 1
do_browse(page=page, page_size=page_size, source=source, console=c)
elif action == "search":
if not args:
c.print("[bold red]Usage:[/] /skills search <query> [--source github] [--limit N]\n")
return
@@ -838,6 +1009,7 @@ def _print_skills_help(console: Console) -> None:
"""Print help for the /skills slash command."""
console.print(Panel(
"[bold]Skills Hub Commands:[/]\n\n"
" [cyan]browse[/] [--source official] Browse all available skills (paginated)\n"
" [cyan]search[/] <query> Search registries for skills\n"
" [cyan]install[/] <identifier> Install a skill (with security scan)\n"
" [cyan]inspect[/] <identifier> Preview a skill without installing\n"

View File

@@ -79,8 +79,12 @@ def show_status(args):
"OpenRouter": "OPENROUTER_API_KEY",
"Anthropic": "ANTHROPIC_API_KEY",
"OpenAI": "OPENAI_API_KEY",
"Z.AI/GLM": "GLM_API_KEY",
"Kimi": "KIMI_API_KEY",
"MiniMax": "MINIMAX_API_KEY",
"MiniMax-CN": "MINIMAX_CN_API_KEY",
"Firecrawl": "FIRECRAWL_API_KEY",
"Browserbase": "BROWSERBASE_API_KEY",
"Browserbase": "BROWSERBASE_API_KEY", # Optional — local browser works without this
"FAL": "FAL_KEY",
"Tinker": "TINKER_API_KEY",
"WandB": "WANDB_API_KEY",
@@ -128,7 +132,7 @@ def show_status(args):
f" {'OpenAI Codex':<12} {check_mark(codex_logged_in)} "
f"{'logged in' if codex_logged_in else 'not logged in (run: hermes model)'}"
)
codex_auth_file = codex_status.get("auth_file")
codex_auth_file = codex_status.get("auth_store")
if codex_auth_file:
print(f" Auth file: {codex_auth_file}")
codex_last_refresh = _format_iso_timestamp(codex_status.get("last_refresh"))
@@ -137,6 +141,28 @@ def show_status(args):
if codex_status.get("error") and not codex_logged_in:
print(f" Error: {codex_status.get('error')}")
# =========================================================================
# API-Key Providers
# =========================================================================
print()
print(color("◆ API-Key Providers", Colors.CYAN, Colors.BOLD))
apikey_providers = {
"Z.AI / GLM": ("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"),
"Kimi / Moonshot": ("KIMI_API_KEY",),
"MiniMax": ("MINIMAX_API_KEY",),
"MiniMax (China)": ("MINIMAX_CN_API_KEY",),
}
for pname, env_vars in apikey_providers.items():
key_val = ""
for ev in env_vars:
key_val = get_env_value(ev) or ""
if key_val:
break
configured = bool(key_val)
label = "configured" if configured else "not configured (run: hermes model)"
print(f" {pname:<16} {check_mark(configured)} {label}")
# =========================================================================
# Terminal Configuration
# =========================================================================
@@ -163,6 +189,9 @@ def show_status(args):
elif terminal_env == "docker":
docker_image = os.getenv("TERMINAL_DOCKER_IMAGE", "python:3.11-slim")
print(f" Docker Image: {docker_image}")
elif terminal_env == "daytona":
daytona_image = os.getenv("TERMINAL_DAYTONA_IMAGE", "nikolaik/python-nodejs:python3.11-nodejs20")
print(f" Daytona Image: {daytona_image}")
sudo_password = os.getenv("SUDO_PASSWORD", "")
print(f" Sudo: {check_mark(bool(sudo_password))} {'enabled' if sudo_password else 'disabled'}")

View File

@@ -1,7 +1,10 @@
"""
Interactive tool configuration for Hermes Agent.
Unified tool configuration for Hermes Agent.
`hermes tools` and `hermes setup tools` both enter this module.
Select a platform → toggle toolsets on/off → for newly enabled tools
that need API keys, run through provider-aware configuration.
`hermes tools` — select a platform, then toggle toolsets on/off via checklist.
Saves per-platform tool configuration to ~/.hermes/config.yaml under
the `platform_toolsets` key.
"""
@@ -12,9 +15,63 @@ from typing import Dict, List, Set
import os
from hermes_cli.config import load_config, save_config, get_env_value, save_env_value
from hermes_cli.config import (
load_config, save_config, get_env_value, save_env_value,
get_hermes_home,
)
from hermes_cli.colors import Colors, color
PROJECT_ROOT = Path(__file__).parent.parent.resolve()
# ─── UI Helpers (shared with setup.py) ────────────────────────────────────────
def _print_info(text: str):
print(color(f" {text}", Colors.DIM))
def _print_success(text: str):
print(color(f"{text}", Colors.GREEN))
def _print_warning(text: str):
print(color(f"{text}", Colors.YELLOW))
def _print_error(text: str):
print(color(f"{text}", Colors.RED))
def _prompt(question: str, default: str = None, password: bool = False) -> str:
if default:
display = f"{question} [{default}]: "
else:
display = f"{question}: "
try:
if password:
import getpass
value = getpass.getpass(color(display, Colors.YELLOW))
else:
value = input(color(display, Colors.YELLOW))
return value.strip() or default or ""
except (KeyboardInterrupt, EOFError):
print()
return default or ""
def _prompt_yes_no(question: str, default: bool = True) -> bool:
default_str = "Y/n" if default else "y/N"
while True:
try:
value = input(color(f"{question} [{default_str}]: ", Colors.YELLOW)).strip().lower()
except (KeyboardInterrupt, EOFError):
print()
return default
if not value:
return default
if value in ('y', 'yes'):
return True
if value in ('n', 'no'):
return False
# ─── Toolset Registry ─────────────────────────────────────────────────────────
# Toolsets shown in the configurator, grouped for display.
# Each entry: (toolset_name, label, description)
# These map to keys in toolsets.py TOOLSETS dict.
@@ -49,6 +106,187 @@ PLATFORMS = {
}
# ─── Tool Categories (provider-aware configuration) ──────────────────────────
# Maps toolset keys to their provider options. When a toolset is newly enabled,
# we use this to show provider selection and prompt for the right API keys.
# Toolsets not in this map either need no config or use the simple fallback.
TOOL_CATEGORIES = {
"tts": {
"name": "Text-to-Speech",
"icon": "🔊",
"providers": [
{
"name": "Microsoft Edge TTS",
"tag": "Free - no API key needed",
"env_vars": [],
"tts_provider": "edge",
},
{
"name": "OpenAI TTS",
"tag": "Premium - high quality voices",
"env_vars": [
{"key": "VOICE_TOOLS_OPENAI_KEY", "prompt": "OpenAI API key", "url": "https://platform.openai.com/api-keys"},
],
"tts_provider": "openai",
},
{
"name": "ElevenLabs",
"tag": "Premium - most natural voices",
"env_vars": [
{"key": "ELEVENLABS_API_KEY", "prompt": "ElevenLabs API key", "url": "https://elevenlabs.io/app/settings/api-keys"},
],
"tts_provider": "elevenlabs",
},
],
},
"web": {
"name": "Web Search & Extract",
"icon": "🔍",
"providers": [
{
"name": "Firecrawl Cloud",
"tag": "Recommended - hosted service",
"env_vars": [
{"key": "FIRECRAWL_API_KEY", "prompt": "Firecrawl API key", "url": "https://firecrawl.dev"},
],
},
{
"name": "Firecrawl Self-Hosted",
"tag": "Free - run your own instance",
"env_vars": [
{"key": "FIRECRAWL_API_URL", "prompt": "Your Firecrawl instance URL (e.g., http://localhost:3002)"},
],
},
],
},
"image_gen": {
"name": "Image Generation",
"icon": "🎨",
"providers": [
{
"name": "FAL.ai",
"tag": "FLUX 2 Pro with auto-upscaling",
"env_vars": [
{"key": "FAL_KEY", "prompt": "FAL API key", "url": "https://fal.ai/dashboard/keys"},
],
},
],
},
"browser": {
"name": "Browser Automation",
"icon": "🌐",
"providers": [
{
"name": "Local Browser",
"tag": "Free headless Chromium (no API key needed)",
"env_vars": [],
"post_setup": "browserbase", # Same npm install for agent-browser
},
{
"name": "Browserbase",
"tag": "Cloud browser with stealth & proxies",
"env_vars": [
{"key": "BROWSERBASE_API_KEY", "prompt": "Browserbase API key", "url": "https://browserbase.com"},
{"key": "BROWSERBASE_PROJECT_ID", "prompt": "Browserbase project ID"},
],
"post_setup": "browserbase",
},
],
},
"homeassistant": {
"name": "Smart Home",
"icon": "🏠",
"providers": [
{
"name": "Home Assistant",
"tag": "REST API integration",
"env_vars": [
{"key": "HASS_TOKEN", "prompt": "Home Assistant Long-Lived Access Token"},
{"key": "HASS_URL", "prompt": "Home Assistant URL", "default": "http://homeassistant.local:8123"},
],
},
],
},
"rl": {
"name": "RL Training",
"icon": "🧪",
"requires_python": (3, 11),
"providers": [
{
"name": "Tinker / Atropos",
"tag": "RL training platform",
"env_vars": [
{"key": "TINKER_API_KEY", "prompt": "Tinker API key", "url": "https://tinker-console.thinkingmachines.ai/keys"},
{"key": "WANDB_API_KEY", "prompt": "WandB API key", "url": "https://wandb.ai/authorize"},
],
"post_setup": "rl_training",
},
],
},
}
# Simple env-var requirements for toolsets NOT in TOOL_CATEGORIES.
# Used as a fallback for tools like vision/moa that just need an API key.
TOOLSET_ENV_REQUIREMENTS = {
"vision": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
"moa": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
}
# ─── Post-Setup Hooks ─────────────────────────────────────────────────────────
def _run_post_setup(post_setup_key: str):
"""Run post-setup hooks for tools that need extra installation steps."""
import shutil
if post_setup_key == "browserbase":
node_modules = PROJECT_ROOT / "node_modules" / "agent-browser"
if not node_modules.exists() and shutil.which("npm"):
_print_info(" Installing Node.js dependencies for browser tools...")
import subprocess
result = subprocess.run(
["npm", "install", "--silent"],
capture_output=True, text=True, cwd=str(PROJECT_ROOT)
)
if result.returncode == 0:
_print_success(" Node.js dependencies installed")
else:
_print_warning(" npm install failed - run manually: cd ~/.hermes/hermes-agent && npm install")
elif not node_modules.exists():
_print_warning(" Node.js not found - browser tools require: npm install (in hermes-agent directory)")
elif post_setup_key == "rl_training":
try:
__import__("tinker_atropos")
except ImportError:
tinker_dir = PROJECT_ROOT / "tinker-atropos"
if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists():
_print_info(" Installing tinker-atropos submodule...")
import subprocess
uv_bin = shutil.which("uv")
if uv_bin:
result = subprocess.run(
[uv_bin, "pip", "install", "--python", sys.executable, "-e", str(tinker_dir)],
capture_output=True, text=True
)
else:
result = subprocess.run(
[sys.executable, "-m", "pip", "install", "-e", str(tinker_dir)],
capture_output=True, text=True
)
if result.returncode == 0:
_print_success(" tinker-atropos installed")
else:
_print_warning(" tinker-atropos install failed - run manually:")
_print_info(' uv pip install -e "./tinker-atropos"')
else:
_print_warning(" tinker-atropos submodule not found - run:")
_print_info(" git submodule update --init --recursive")
_print_info(' uv pip install -e "./tinker-atropos"')
# ─── Platform / Toolset Helpers ───────────────────────────────────────────────
def _get_enabled_platforms() -> List[str]:
"""Return platform keys that are configured (have tokens or are CLI)."""
enabled = ["cli"]
@@ -97,6 +335,28 @@ def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: Set[
save_config(config)
def _toolset_has_keys(ts_key: str) -> bool:
"""Check if a toolset's required API keys are configured."""
# Check TOOL_CATEGORIES first (provider-aware)
cat = TOOL_CATEGORIES.get(ts_key)
if cat:
for provider in cat["providers"]:
env_vars = provider.get("env_vars", [])
if not env_vars:
return True # Free provider (e.g., Edge TTS)
if all(get_env_value(v["key"]) for v in env_vars):
return True
return False
# Fallback to simple requirements
requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
if not requirements:
return True
return all(get_env_value(var) for var, _ in requirements)
# ─── Menu Helpers ─────────────────────────────────────────────────────────────
def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
"""Single-select menu (arrow keys)."""
print(color(question, Colors.YELLOW))
@@ -114,7 +374,7 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
)
idx = menu.show()
if idx is None:
sys.exit(0)
return default
print()
return idx
except (ImportError, NotImplementedError):
@@ -132,15 +392,7 @@ def _prompt_choice(question: str, choices: list, default: int = 0) -> int:
return idx
except (ValueError, KeyboardInterrupt, EOFError):
print()
sys.exit(0)
def _toolset_has_keys(ts_key: str) -> bool:
"""Check if a toolset's required API keys are configured."""
requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
if not requirements:
return True
return all(get_env_value(var) for var, _ in requirements)
return default
def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str]:
@@ -150,8 +402,8 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
labels = []
for ts_key, ts_label, ts_desc in CONFIGURABLE_TOOLSETS:
suffix = ""
if not _toolset_has_keys(ts_key) and TOOLSET_ENV_REQUIREMENTS.get(ts_key):
suffix = " no API key"
if not _toolset_has_keys(ts_key) and (TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key)):
suffix = " [no API key]"
labels.append(f"{ts_label} ({ts_desc}){suffix}")
pre_selected_indices = [
@@ -302,77 +554,294 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
return {CONFIGURABLE_TOOLSETS[i][0] for i in selected}
# Map toolset keys to the env vars they require and where to get them
TOOLSET_ENV_REQUIREMENTS = {
"web": [("FIRECRAWL_API_KEY", "https://firecrawl.dev/")],
"browser": [("BROWSERBASE_API_KEY", "https://browserbase.com/"),
("BROWSERBASE_PROJECT_ID", None)],
"vision": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
"image_gen": [("FAL_KEY", "https://fal.ai/")],
"moa": [("OPENROUTER_API_KEY", "https://openrouter.ai/keys")],
"tts": [], # Edge TTS is free, no key needed
"rl": [("TINKER_API_KEY", "https://tinker-console.thinkingmachines.ai/keys"),
("WANDB_API_KEY", "https://wandb.ai/authorize")],
"homeassistant": [("HASS_TOKEN", "Home Assistant > Profile > Long-Lived Access Tokens"),
("HASS_URL", None)],
}
# ─── Provider-Aware Configuration ────────────────────────────────────────────
def _configure_toolset(ts_key: str, config: dict):
"""Configure a toolset - provider selection + API keys.
Uses TOOL_CATEGORIES for provider-aware config, falls back to simple
env var prompts for toolsets not in TOOL_CATEGORIES.
"""
cat = TOOL_CATEGORIES.get(ts_key)
if cat:
_configure_tool_category(ts_key, cat, config)
else:
# Simple fallback for vision, moa, etc.
_configure_simple_requirements(ts_key)
def _check_and_prompt_requirements(newly_enabled: Set[str]):
"""Check if newly enabled toolsets have missing API keys and offer to set them up."""
for ts_key in sorted(newly_enabled):
requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
if not requirements:
continue
def _configure_tool_category(ts_key: str, cat: dict, config: dict):
"""Configure a tool category with provider selection."""
icon = cat.get("icon", "")
name = cat["name"]
providers = cat["providers"]
missing = [(var, url) for var, url in requirements if not get_env_value(var)]
if not missing:
continue
ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
print()
print(color(f"{ts_label} requires configuration:", Colors.YELLOW))
for var, url in missing:
if url:
print(color(f" {var}", Colors.CYAN) + color(f" ({url})", Colors.DIM))
else:
print(color(f" {var}", Colors.CYAN))
print()
try:
response = input(color(" Set up now? [Y/n] ", Colors.YELLOW)).strip().lower()
except (KeyboardInterrupt, EOFError):
# Check Python version requirement
if cat.get("requires_python"):
req = cat["requires_python"]
if sys.version_info < req:
print()
continue
_print_error(f" {name} requires Python {req[0]}.{req[1]}+ (current: {sys.version_info.major}.{sys.version_info.minor})")
_print_info(" Upgrade Python and reinstall to enable this tool.")
return
if response in ("", "y", "yes"):
for var, url in missing:
if url:
print(color(f" Get key at: {url}", Colors.DIM))
try:
import getpass
value = getpass.getpass(color(f" {var}: ", Colors.YELLOW))
except (KeyboardInterrupt, EOFError):
print()
break
if value.strip():
save_env_value(var, value.strip())
print(color(f" ✓ Saved", Colors.GREEN))
if len(providers) == 1:
# Single provider - configure directly
provider = providers[0]
print()
print(color(f" --- {icon} {name} ({provider['name']}) ---", Colors.CYAN))
if provider.get("tag"):
_print_info(f" {provider['tag']}")
_configure_provider(provider, config)
else:
# Multiple providers - let user choose
print()
print(color(f" --- {icon} {name} - Choose a provider ---", Colors.CYAN))
print()
# Plain text labels only (no ANSI codes in menu items)
provider_choices = []
for p in providers:
tag = f" ({p['tag']})" if p.get("tag") else ""
configured = ""
env_vars = p.get("env_vars", [])
if not env_vars or all(get_env_value(v["key"]) for v in env_vars):
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
configured = " [active]"
elif not env_vars:
configured = " [active]" if config.get("tts", {}).get("provider", "edge") == p.get("tts_provider", "") else ""
else:
print(color(f" Skipped", Colors.DIM))
configured = " [configured]"
provider_choices.append(f"{p['name']}{tag}{configured}")
# Detect current provider as default
default_idx = 0
for i, p in enumerate(providers):
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
default_idx = i
break
env_vars = p.get("env_vars", [])
if env_vars and all(get_env_value(v["key"]) for v in env_vars):
default_idx = i
break
provider_idx = _prompt_choice(" Select provider:", provider_choices, default_idx)
_configure_provider(providers[provider_idx], config)
def _configure_provider(provider: dict, config: dict):
"""Configure a single provider - prompt for API keys and set config."""
env_vars = provider.get("env_vars", [])
# Set TTS provider in config if applicable
if provider.get("tts_provider"):
config.setdefault("tts", {})["provider"] = provider["tts_provider"]
if not env_vars:
_print_success(f" {provider['name']} - no configuration needed!")
return
# Prompt for each required env var
all_configured = True
for var in env_vars:
existing = get_env_value(var["key"])
if existing:
_print_success(f" {var['key']}: already configured")
# Don't ask to update - this is a new enable flow.
# Reconfigure is handled separately.
else:
print(color(" Skipped — configure later with 'hermes setup'", Colors.DIM))
url = var.get("url", "")
if url:
_print_info(f" Get yours at: {url}")
default_val = var.get("default", "")
if default_val:
value = _prompt(f" {var.get('prompt', var['key'])}", default_val)
else:
value = _prompt(f" {var.get('prompt', var['key'])}", password=True)
if value:
save_env_value(var["key"], value)
_print_success(f" Saved")
else:
_print_warning(f" Skipped")
all_configured = False
# Run post-setup hooks if needed
if provider.get("post_setup") and all_configured:
_run_post_setup(provider["post_setup"])
if all_configured:
_print_success(f" {provider['name']} configured!")
def tools_command(args):
"""Entry point for `hermes tools`."""
def _configure_simple_requirements(ts_key: str):
"""Simple fallback for toolsets that just need env vars (no provider selection)."""
requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
if not requirements:
return
missing = [(var, url) for var, url in requirements if not get_env_value(var)]
if not missing:
return
ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
print()
print(color(f" {ts_label} requires configuration:", Colors.YELLOW))
for var, url in missing:
if url:
_print_info(f" Get key at: {url}")
value = _prompt(f" {var}", password=True)
if value and value.strip():
save_env_value(var, value.strip())
_print_success(f" Saved")
else:
_print_warning(f" Skipped")
def _reconfigure_tool(config: dict):
"""Let user reconfigure an existing tool's provider or API key."""
# Build list of configurable tools that are currently set up
configurable = []
for ts_key, ts_label, _ in CONFIGURABLE_TOOLSETS:
cat = TOOL_CATEGORIES.get(ts_key)
reqs = TOOLSET_ENV_REQUIREMENTS.get(ts_key)
if cat or reqs:
if _toolset_has_keys(ts_key):
configurable.append((ts_key, ts_label))
if not configurable:
_print_info("No configured tools to reconfigure.")
return
choices = [label for _, label in configurable]
choices.append("Cancel")
idx = _prompt_choice(" Which tool would you like to reconfigure?", choices, len(choices) - 1)
if idx >= len(configurable):
return # Cancel
ts_key, ts_label = configurable[idx]
cat = TOOL_CATEGORIES.get(ts_key)
if cat:
_configure_tool_category_for_reconfig(ts_key, cat, config)
else:
_reconfigure_simple_requirements(ts_key)
save_config(config)
def _configure_tool_category_for_reconfig(ts_key: str, cat: dict, config: dict):
"""Reconfigure a tool category - provider selection + API key update."""
icon = cat.get("icon", "")
name = cat["name"]
providers = cat["providers"]
if len(providers) == 1:
provider = providers[0]
print()
print(color(f" --- {icon} {name} ({provider['name']}) ---", Colors.CYAN))
_reconfigure_provider(provider, config)
else:
print()
print(color(f" --- {icon} {name} - Choose a provider ---", Colors.CYAN))
print()
provider_choices = []
for p in providers:
tag = f" ({p['tag']})" if p.get("tag") else ""
configured = ""
env_vars = p.get("env_vars", [])
if not env_vars or all(get_env_value(v["key"]) for v in env_vars):
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
configured = " [active]"
elif not env_vars:
configured = ""
else:
configured = " [configured]"
provider_choices.append(f"{p['name']}{tag}{configured}")
default_idx = 0
for i, p in enumerate(providers):
if p.get("tts_provider") and config.get("tts", {}).get("provider") == p["tts_provider"]:
default_idx = i
break
env_vars = p.get("env_vars", [])
if env_vars and all(get_env_value(v["key"]) for v in env_vars):
default_idx = i
break
provider_idx = _prompt_choice(" Select provider:", provider_choices, default_idx)
_reconfigure_provider(providers[provider_idx], config)
def _reconfigure_provider(provider: dict, config: dict):
"""Reconfigure a provider - update API keys."""
env_vars = provider.get("env_vars", [])
if provider.get("tts_provider"):
config.setdefault("tts", {})["provider"] = provider["tts_provider"]
_print_success(f" TTS provider set to: {provider['tts_provider']}")
if not env_vars:
_print_success(f" {provider['name']} - no configuration needed!")
return
for var in env_vars:
existing = get_env_value(var["key"])
if existing:
_print_info(f" {var['key']}: configured ({existing[:8]}...)")
url = var.get("url", "")
if url:
_print_info(f" Get yours at: {url}")
default_val = var.get("default", "")
value = _prompt(f" {var.get('prompt', var['key'])} (Enter to keep current)", password=not default_val)
if value and value.strip():
save_env_value(var["key"], value.strip())
_print_success(f" Updated")
else:
_print_info(f" Kept current")
def _reconfigure_simple_requirements(ts_key: str):
"""Reconfigure simple env var requirements."""
requirements = TOOLSET_ENV_REQUIREMENTS.get(ts_key, [])
if not requirements:
return
ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
print()
print(color(f" {ts_label}:", Colors.CYAN))
for var, url in requirements:
existing = get_env_value(var)
if existing:
_print_info(f" {var}: configured ({existing[:8]}...)")
if url:
_print_info(f" Get key at: {url}")
value = _prompt(f" {var} (Enter to keep current)", password=True)
if value and value.strip():
save_env_value(var, value.strip())
_print_success(f" Updated")
else:
_print_info(f" Kept current")
# ─── Main Entry Point ─────────────────────────────────────────────────────────
def tools_command(args=None):
"""Entry point for `hermes tools` and `hermes setup tools`."""
config = load_config()
enabled_platforms = _get_enabled_platforms()
print()
print(color("⚕ Hermes Tool Configuration", Colors.CYAN, Colors.BOLD))
print(color(" Enable or disable tools per platform.", Colors.DIM))
print(color(" Tools that need API keys will be configured when enabled.", Colors.DIM))
print()
# Build platform choices
@@ -380,22 +849,28 @@ def tools_command(args):
platform_keys = []
for pkey in enabled_platforms:
pinfo = PLATFORMS[pkey]
# Count currently enabled toolsets
current = _get_platform_tools(config, pkey)
count = len(current)
total = len(CONFIGURABLE_TOOLSETS)
platform_choices.append(f"Configure {pinfo['label']} ({count}/{total} enabled)")
platform_keys.append(pkey)
platform_choices.append("Done — save and exit")
platform_choices.append("Reconfigure an existing tool's provider or API key")
platform_choices.append("Done")
while True:
idx = _prompt_choice("Select a platform to configure:", platform_choices, default=0)
idx = _prompt_choice("Select an option:", platform_choices, default=0)
# "Done" selected
if idx == len(platform_keys):
if idx == len(platform_keys) + 1:
break
# "Reconfigure" selected
if idx == len(platform_keys):
_reconfigure_tool(config)
print()
continue
pkey = platform_keys[idx]
pinfo = PLATFORMS[pkey]
@@ -418,11 +893,15 @@ def tools_command(args):
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
print(color(f" - {label}", Colors.RED))
# Prompt for missing API keys on newly enabled toolsets
# Configure newly enabled toolsets that need API keys
if added:
_check_and_prompt_requirements(added)
for ts_key in sorted(added):
if TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key):
if not _toolset_has_keys(ts_key):
_configure_toolset(ts_key, config)
_save_platform_tools(config, pkey, new_enabled)
save_config(config)
print(color(f" ✓ Saved {pinfo['label']} configuration", Colors.GREEN))
else:
print(color(f" No changes to {pinfo['label']}", Colors.DIM))

119
hermes_time.py Normal file
View File

@@ -0,0 +1,119 @@
"""
Timezone-aware clock for Hermes.
Provides a single ``now()`` helper that returns a timezone-aware datetime
based on the user's configured IANA timezone (e.g. ``Asia/Kolkata``).
Resolution order:
1. ``HERMES_TIMEZONE`` environment variable
2. ``timezone`` key in ``~/.hermes/config.yaml``
3. Falls back to the server's local time (``datetime.now().astimezone()``)
Invalid timezone values log a warning and fall back safely — Hermes never
crashes due to a bad timezone string.
"""
import logging
import os
from datetime import datetime, timezone as _tz
from pathlib import Path
from typing import Optional
logger = logging.getLogger(__name__)
try:
from zoneinfo import ZoneInfo
except ImportError:
# Python 3.8 fallback (shouldn't be needed — Hermes requires 3.9+)
from backports.zoneinfo import ZoneInfo # type: ignore[no-redef]
# Cached state — resolved once, reused on every call.
# Call reset_cache() to force re-resolution (e.g. after config changes).
_cached_tz: Optional[ZoneInfo] = None
_cached_tz_name: Optional[str] = None
_cache_resolved: bool = False
def _resolve_timezone_name() -> str:
"""Read the configured IANA timezone string (or empty string).
This does file I/O when falling through to config.yaml, so callers
should cache the result rather than calling on every ``now()``.
"""
# 1. Environment variable (highest priority — set by Supervisor, etc.)
tz_env = os.getenv("HERMES_TIMEZONE", "").strip()
if tz_env:
return tz_env
# 2. config.yaml ``timezone`` key
try:
import yaml
hermes_home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
config_path = hermes_home / "config.yaml"
if config_path.exists():
with open(config_path) as f:
cfg = yaml.safe_load(f) or {}
tz_cfg = cfg.get("timezone", "")
if isinstance(tz_cfg, str) and tz_cfg.strip():
return tz_cfg.strip()
except Exception:
pass
return ""
def _get_zoneinfo(name: str) -> Optional[ZoneInfo]:
"""Validate and return a ZoneInfo, or None if invalid."""
if not name:
return None
try:
return ZoneInfo(name)
except (KeyError, Exception) as exc:
logger.warning(
"Invalid timezone '%s': %s. Falling back to server local time.",
name, exc,
)
return None
def get_timezone() -> Optional[ZoneInfo]:
"""Return the user's configured ZoneInfo, or None (meaning server-local).
Resolved once and cached. Call ``reset_cache()`` after config changes.
"""
global _cached_tz, _cached_tz_name, _cache_resolved
if not _cache_resolved:
_cached_tz_name = _resolve_timezone_name()
_cached_tz = _get_zoneinfo(_cached_tz_name)
_cache_resolved = True
return _cached_tz
def get_timezone_name() -> str:
"""Return the IANA name of the configured timezone, or empty string."""
global _cached_tz_name, _cache_resolved
if not _cache_resolved:
get_timezone() # populates cache
return _cached_tz_name or ""
def now() -> datetime:
"""
Return the current time as a timezone-aware datetime.
If a valid timezone is configured, returns wall-clock time in that zone.
Otherwise returns the server's local time (via ``astimezone()``).
"""
tz = get_timezone()
if tz is not None:
return datetime.now(tz)
# No timezone configured — use server-local (still tz-aware)
return datetime.now().astimezone()
def reset_cache() -> None:
"""Clear the cached timezone. Used by tests and after config changes."""
global _cached_tz, _cached_tz_name, _cache_resolved
_cached_tz = None
_cached_tz_name = None
_cache_resolved = False

View File

@@ -1,64 +0,0 @@
"""Modal deployment configuration for hermes-agent.
Deploys the FastAPI streaming wrapper as a serverless ASGI app on Modal.
Usage:
modal deploy modal_app.py # Deploy to Modal
modal serve modal_app.py # Local dev with hot-reload
"""
import modal
image = (
modal.Image.debian_slim(python_version="3.11")
.apt_install("git")
.pip_install(
"fastapi[standard]",
"uvicorn",
"openai",
"python-dotenv",
"fire",
"httpx",
"rich",
"tenacity",
"pyyaml",
"requests",
"jinja2",
"pydantic>=2.0",
"prompt_toolkit",
"firecrawl-py",
"fal-client",
"edge-tts",
"litellm>=1.75.5",
"typer",
"platformdirs",
"PyJWT[crypto]",
)
.add_local_dir(".", "/app", copy=True, ignore=[".git", "__pycache__", "venv", ".venv", "*.pyc"])
)
app = modal.App("hermes-agent", image=image)
@app.function(
min_containers=0,
scaledown_window=300,
timeout=600,
secrets=[modal.Secret.from_name("hermes-secrets")],
)
@modal.concurrent(max_inputs=10)
@modal.asgi_app()
def web():
import os
import sys
from pathlib import Path
# Force HERMES_HOME to a known writable path inside the container
hermes_home = "/tmp/hermes"
os.environ["HERMES_HOME"] = hermes_home
Path(hermes_home).mkdir(parents=True, exist_ok=True)
(Path(hermes_home) / "logs").mkdir(parents=True, exist_ok=True)
sys.path.insert(0, "/app")
from serve import app as fastapi_app
return fastapi_app

View File

@@ -225,6 +225,18 @@ def get_tool_definitions(
# Ask the registry for schemas (only returns tools whose check_fn passes)
filtered_tools = registry.get_definitions(tools_to_include, quiet=quiet_mode)
# Rebuild execute_code schema to only list sandbox tools that are actually
# enabled. Without this, the model sees "web_search is available in
# execute_code" even when the user disabled the web toolset (#560-discord).
if "execute_code" in tools_to_include:
from tools.code_execution_tool import SANDBOX_ALLOWED_TOOLS, build_execute_code_schema
sandbox_enabled = SANDBOX_ALLOWED_TOOLS & tools_to_include
dynamic_schema = build_execute_code_schema(sandbox_enabled)
for i, td in enumerate(filtered_tools):
if td.get("function", {}).get("name") == "execute_code":
filtered_tools[i] = {"type": "function", "function": dynamic_schema}
break
if not quiet_mode:
if filtered_tools:
tool_names = [t["function"]["name"] for t in filtered_tools]

View File

@@ -0,0 +1,24 @@
# Optional Skills
Official skills maintained by Nous Research that are **not activated by default**.
These skills ship with the hermes-agent repository but are not copied to
`~/.hermes/skills/` during setup. They are discoverable via the Skills Hub:
```bash
hermes skills browse # browse all skills, official shown first
hermes skills browse --source official # browse only official optional skills
hermes skills search <query> # finds optional skills labeled "official"
hermes skills install <identifier> # copies to ~/.hermes/skills/ and activates
```
## Why optional?
Some skills are useful but not broadly needed by every user:
- **Niche integrations** — specific paid services, specialized tools
- **Experimental features** — promising but not yet proven
- **Heavyweight dependencies** — require significant setup (API keys, installs)
By keeping them optional, we keep the default skill set lean while still
providing curated, tested, official skills for users who want them.

View File

@@ -0,0 +1,2 @@
Optional autonomous AI agent integrations — external coding agent CLIs
that can be delegated to for independent coding tasks.

View File

@@ -0,0 +1,143 @@
---
name: blackbox
description: Delegate coding tasks to Blackbox AI CLI agent. Multi-model agent with built-in judge that runs tasks through multiple LLMs and picks the best result. Requires the blackbox CLI and a Blackbox AI API key.
version: 1.0.0
author: Hermes Agent (Nous Research)
license: MIT
metadata:
hermes:
tags: [Coding-Agent, Blackbox, Multi-Agent, Judge, Multi-Model]
related_skills: [claude-code, codex, hermes-agent]
---
# Blackbox CLI
Delegate coding tasks to [Blackbox AI](https://www.blackbox.ai/) via the Hermes terminal. Blackbox is a multi-model coding agent CLI that dispatches tasks to multiple LLMs (Claude, Codex, Gemini, Blackbox Pro) and uses a judge to select the best implementation.
The CLI is [open-source](https://github.com/blackboxaicode/cli) (GPL-3.0, TypeScript, forked from Gemini CLI) and supports interactive sessions, non-interactive one-shots, checkpointing, MCP, and vision model switching.
## Prerequisites
- Node.js 20+ installed
- Blackbox CLI installed: `npm install -g @blackboxai/cli`
- Or install from source:
```
git clone https://github.com/blackboxaicode/cli.git
cd cli && npm install && npm install -g .
```
- API key from [app.blackbox.ai/dashboard](https://app.blackbox.ai/dashboard)
- Configured: run `blackbox configure` and enter your API key
- Use `pty=true` in terminal calls — Blackbox CLI is an interactive terminal app
## One-Shot Tasks
```
terminal(command="blackbox --prompt 'Add JWT authentication with refresh tokens to the Express API'", workdir="/path/to/project", pty=true)
```
For quick scratch work:
```
terminal(command="cd $(mktemp -d) && git init && blackbox --prompt 'Build a REST API for todos with SQLite'", pty=true)
```
## Background Mode (Long Tasks)
For tasks that take minutes, use background mode so you can monitor progress:
```
# Start in background with PTY
terminal(command="blackbox --prompt 'Refactor the auth module to use OAuth 2.0'", workdir="~/project", background=true, pty=true)
# Returns session_id
# Monitor progress
process(action="poll", session_id="<id>")
process(action="log", session_id="<id>")
# Send input if Blackbox asks a question
process(action="submit", session_id="<id>", data="yes")
# Kill if needed
process(action="kill", session_id="<id>")
```
## Checkpoints & Resume
Blackbox CLI has built-in checkpoint support for pausing and resuming tasks:
```
# After a task completes, Blackbox shows a checkpoint tag
# Resume with a follow-up task:
terminal(command="blackbox --resume-checkpoint 'task-abc123-2026-03-06' --prompt 'Now add rate limiting to the endpoints'", workdir="~/project", pty=true)
```
## Session Commands
During an interactive session, use these commands:
| Command | Effect |
|---------|--------|
| `/compress` | Shrink conversation history to save tokens |
| `/clear` | Wipe history and start fresh |
| `/stats` | View current token usage |
| `Ctrl+C` | Cancel current operation |
## PR Reviews
Clone to a temp directory to avoid modifying the working tree:
```
terminal(command="REVIEW=$(mktemp -d) && git clone https://github.com/user/repo.git $REVIEW && cd $REVIEW && gh pr checkout 42 && blackbox --prompt 'Review this PR against main. Check for bugs, security issues, and code quality.'", pty=true)
```
## Parallel Work
Spawn multiple Blackbox instances for independent tasks:
```
terminal(command="blackbox --prompt 'Fix the login bug'", workdir="/tmp/issue-1", background=true, pty=true)
terminal(command="blackbox --prompt 'Add unit tests for auth'", workdir="/tmp/issue-2", background=true, pty=true)
# Monitor all
process(action="list")
```
## Multi-Model Mode
Blackbox's unique feature is running the same task through multiple models and judging the results. Configure which models to use via `blackbox configure` — select multiple providers to enable the Chairman/judge workflow where the CLI evaluates outputs from different models and picks the best one.
## Key Flags
| Flag | Effect |
|------|--------|
| `--prompt "task"` | Non-interactive one-shot execution |
| `--resume-checkpoint "tag"` | Resume from a saved checkpoint |
| `--yolo` | Auto-approve all actions and model switches |
| `blackbox session` | Start interactive chat session |
| `blackbox configure` | Change settings, providers, models |
| `blackbox info` | Display system information |
## Vision Support
Blackbox automatically detects images in input and can switch to multimodal analysis. VLM modes:
- `"once"` — Switch model for current query only
- `"session"` — Switch for entire session
- `"persist"` — Stay on current model (no switch)
## Token Limits
Control token usage via `.blackboxcli/settings.json`:
```json
{
"sessionTokenLimit": 32000
}
```
## Rules
1. **Always use `pty=true`** — Blackbox CLI is an interactive terminal app and will hang without a PTY
2. **Use `workdir`** — keep the agent focused on the right directory
3. **Background for long tasks** — use `background=true` and monitor with `process` tool
4. **Don't interfere** — monitor with `poll`/`log`, don't kill sessions because they're slow
5. **Report results** — after completion, check what changed and summarize for the user
6. **Credits cost money** — Blackbox uses a credit-based system; multi-model mode consumes credits faster
7. **Check prerequisites** — verify `blackbox` CLI is installed before attempting delegation

View File

@@ -5,9 +5,9 @@ build-backend = "setuptools.build_meta"
[project]
name = "hermes-agent"
version = "0.1.0"
description = "AI agent with advanced tool-calling and toolsets"
description = "The self-improving AI agent — creates skills from experience, improves them during use, and runs anywhere"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.11"
authors = [{ name = "Nous Research" }]
license = { text = "MIT" }
dependencies = [
@@ -39,7 +39,7 @@ dependencies = [
[project.optional-dependencies]
modal = ["swe-rex[modal]>=1.4.0"]
serve = ["fastapi[standard]", "uvicorn"]
daytona = ["daytona>=0.148.0"]
dev = ["pytest", "pytest-asyncio"]
messaging = ["python-telegram-bot>=20.0", "discord.py>=2.0", "aiohttp>=3.9.0", "slack-bolt>=1.18.0", "slack-sdk>=3.27.0"]
cron = ["croniter"]
@@ -50,9 +50,10 @@ pty = ["ptyprocess>=0.7.0"]
honcho = ["honcho-ai>=2.0.1"]
mcp = ["mcp>=1.2.0"]
homeassistant = ["aiohttp>=3.9.0"]
yc-bench = ["yc-bench @ git+https://github.com/collinear-ai/yc-bench.git"]
all = [
"hermes-agent[modal]",
"hermes-agent[serve]",
"hermes-agent[daytona]",
"hermes-agent[messaging]",
"hermes-agent[cron]",
"hermes-agent[cli]",

View File

@@ -26,7 +26,6 @@ import json
import logging
logger = logging.getLogger(__name__)
import os
import queue
import random
import re
import sys
@@ -83,6 +82,8 @@ from agent.prompt_builder import (
from agent.model_metadata import (
fetch_model_metadata, get_model_context_length,
estimate_tokens_rough, estimate_messages_tokens_rough,
get_next_probe_tier, parse_context_limit_from_error,
save_context_length,
)
from agent.context_compressor import ContextCompressor
from agent.prompt_caching import apply_anthropic_cache_control
@@ -98,6 +99,46 @@ from agent.trajectory import (
)
class IterationBudget:
"""Thread-safe shared iteration counter for parent and child agents.
Tracks total LLM-call iterations consumed across a parent agent and all
its subagents. A single ``IterationBudget`` is created by the parent
and passed to every child so they share the same cap.
``execute_code`` (programmatic tool calling) iterations are refunded via
:meth:`refund` so they don't eat into the budget.
"""
def __init__(self, max_total: int):
self.max_total = max_total
self._used = 0
self._lock = threading.Lock()
def consume(self) -> bool:
"""Try to consume one iteration. Returns True if allowed."""
with self._lock:
if self._used >= self.max_total:
return False
self._used += 1
return True
def refund(self) -> None:
"""Give back one iteration (e.g. for execute_code turns)."""
with self._lock:
if self._used > 0:
self._used -= 1
@property
def used(self) -> int:
return self._used
@property
def remaining(self) -> int:
with self._lock:
return max(0, self.max_total - self._used)
class AIAgent:
"""
AI Agent with tool calling capabilities.
@@ -113,7 +154,7 @@ class AIAgent:
provider: str = None,
api_mode: str = None,
model: str = "anthropic/claude-opus-4.6", # OpenRouter format
max_iterations: int = 60, # Default tool-calling iterations
max_iterations: int = 90, # Default tool-calling iterations (shared with subagents)
tool_delay: float = 1.0,
enabled_toolsets: List[str] = None,
disabled_toolsets: List[str] = None,
@@ -141,8 +182,7 @@ class AIAgent:
skip_memory: bool = False,
session_db=None,
honcho_session_key: str = None,
event_queue: "queue.Queue | None" = None,
extra_tags: List[str] = None,
iteration_budget: "IterationBudget" = None,
):
"""
Initialize the AI Agent.
@@ -153,7 +193,7 @@ class AIAgent:
provider (str): Provider identifier (optional; used for telemetry/routing hints)
api_mode (str): API mode override: "chat_completions" or "codex_responses"
model (str): Model name to use (default: "anthropic/claude-opus-4.6")
max_iterations (int): Maximum number of tool calling iterations (default: 60)
max_iterations (int): Maximum number of tool calling iterations (default: 90)
tool_delay (float): Delay between tool calls in seconds (default: 1.0)
enabled_toolsets (List[str]): Only enable tools from these toolsets (optional)
disabled_toolsets (List[str]): Disable tools from these toolsets (optional)
@@ -187,6 +227,9 @@ class AIAgent:
"""
self.model = model
self.max_iterations = max_iterations
# Shared iteration budget — parent creates, children inherit.
# Consumed by every LLM turn across parent + all subagents.
self.iteration_budget = iteration_budget or IterationBudget(max_iterations)
self.tool_delay = tool_delay
self.save_trajectories = save_trajectories
self.verbose_logging = verbose_logging
@@ -220,8 +263,6 @@ class AIAgent:
self.tool_progress_callback = tool_progress_callback
self.clarify_callback = clarify_callback
self.step_callback = step_callback
self.event_queue: queue.Queue | None = event_queue
self._extra_tags: List[str] = extra_tags or []
self._last_reported_tool = None # Track for "new tool" mode
# Interrupt mechanism for breaking out of tool loops
@@ -260,7 +301,7 @@ class AIAgent:
# Persistent error log -- always writes WARNING+ to ~/.hermes/logs/errors.log
# so tool failures, API errors, etc. are inspectable after the fact.
from agent.redact import RedactingFormatter
_error_log_dir = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")) / "logs"
_error_log_dir = Path.home() / ".hermes" / "logs"
_error_log_dir.mkdir(parents=True, exist_ok=True)
_error_log_path = _error_log_dir / "errors.log"
from logging.handlers import RotatingFileHandler
@@ -541,6 +582,7 @@ class AIAgent:
summary_target_tokens=500,
summary_model_override=compression_summary_model,
quiet_mode=self.quiet_mode,
base_url=self.base_url,
)
self.compression_enabled = compression_enabled
self._user_turn_count = 0
@@ -1310,19 +1352,6 @@ class AIAgent:
except Exception as e:
logger.debug("Honcho sync failed (non-fatal): %s", e)
def _emit_event(self, event: Dict[str, Any]) -> None:
"""Push a structured event onto the event queue (if one is attached).
Used by the serve layer to stream intermediate agent progress
(text tokens, tool calls, tool results) back to callers over SSE.
No-op when ``event_queue`` is ``None`` (CLI / gateway usage).
"""
if self.event_queue is not None:
try:
self.event_queue.put_nowait(event)
except Exception:
pass
def _build_system_prompt(self, system_message: str = None) -> str:
"""
Assemble the full system prompt from all layers.
@@ -1378,7 +1407,8 @@ class AIAgent:
if context_files_prompt:
prompt_parts.append(context_files_prompt)
now = datetime.now()
from hermes_time import now as _hermes_now
now = _hermes_now()
prompt_parts.append(
f"Conversation started: {now.strftime('%A, %B %d, %Y %I:%M %p')}"
)
@@ -2033,6 +2063,49 @@ class AIAgent:
return True
def _try_refresh_nous_client_credentials(self, *, force: bool = True) -> bool:
if self.api_mode != "chat_completions" or self.provider != "nous":
return False
try:
from hermes_cli.auth import resolve_nous_runtime_credentials
creds = resolve_nous_runtime_credentials(
min_key_ttl_seconds=max(60, int(os.getenv("HERMES_NOUS_MIN_KEY_TTL_SECONDS", "1800"))),
timeout_seconds=float(os.getenv("HERMES_NOUS_TIMEOUT_SECONDS", "15")),
force_mint=force,
)
except Exception as exc:
logger.debug("Nous credential refresh failed: %s", exc)
return False
api_key = creds.get("api_key")
base_url = creds.get("base_url")
if not isinstance(api_key, str) or not api_key.strip():
return False
if not isinstance(base_url, str) or not base_url.strip():
return False
self.api_key = api_key.strip()
self.base_url = base_url.strip().rstrip("/")
self._client_kwargs["api_key"] = self.api_key
self._client_kwargs["base_url"] = self.base_url
# Nous requests should not inherit OpenRouter-only attribution headers.
self._client_kwargs.pop("default_headers", None)
try:
self.client.close()
except Exception:
pass
try:
self.client = OpenAI(**self._client_kwargs)
except Exception as exc:
logger.warning("Failed to rebuild OpenAI client after Nous refresh: %s", exc)
return False
return True
def _interruptible_api_call(self, api_kwargs: dict):
"""
Run the API call in a background thread so the main conversation loop
@@ -2154,11 +2227,9 @@ class AIAgent:
"effort": "xhigh"
}
# Nous Portal product attribution + caller-supplied tags
# Nous Portal product attribution
if _is_nous:
tags = list(self._extra_tags)
tags.append("product=hermes-agent")
extra_body["tags"] = tags
extra_body["tags"] = ["product=hermes-agent"]
if extra_body:
api_kwargs["extra_body"] = extra_body
@@ -2474,13 +2545,6 @@ class AIAgent:
except Exception as cb_err:
logging.debug(f"Tool progress callback error: {cb_err}")
self._emit_event({
"type": "tool-call",
"name": function_name,
"args": function_args,
"status": "calling",
})
tool_start_time = time.time()
if function_name == "todo":
@@ -2644,14 +2708,6 @@ class AIAgent:
messages.append(tool_msg)
self._log_msg_to_db(tool_msg)
self._emit_event({
"type": "tool-result",
"name": function_name,
"output": function_result[:4000],
"status": "complete",
"duration": round(tool_duration, 2),
})
if not self.quiet_mode:
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
print(f" ✅ Tool {i} completed in {tool_duration:.2f}s - {response_preview}")
@@ -2775,7 +2831,7 @@ class AIAgent:
"messages": api_messages,
}
if self.max_tokens is not None:
summary_kwargs["max_tokens"] = self.max_tokens
summary_kwargs.update(self._max_tokens_param(self.max_tokens))
if summary_extra_body:
summary_kwargs["extra_body"] = summary_extra_body
@@ -2962,7 +3018,7 @@ class AIAgent:
# Clear any stale interrupt state at start
self.clear_interrupt()
while api_call_count < self.max_iterations:
while api_call_count < self.max_iterations and self.iteration_budget.remaining > 0:
# Check for interrupt request (e.g., user sent new message)
if self._interrupt_requested:
interrupted = True
@@ -2971,6 +3027,10 @@ class AIAgent:
break
api_call_count += 1
if not self.iteration_budget.consume():
if not self.quiet_mode:
print(f"\n⚠️ Session iteration budget exhausted ({self.iteration_budget.max_total} total across agent + subagents)")
break
# Fire step_callback for gateway hooks (agent:step event)
if self.step_callback is not None:
@@ -3047,6 +3107,13 @@ class AIAgent:
if self._use_prompt_caching:
api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl)
# Safety net: strip orphaned tool results / add stubs for missing
# results before sending to the API. The compressor handles this
# during compression, but orphans can also sneak in from session
# loading or manual message manipulation.
if hasattr(self, 'context_compressor') and self.context_compressor:
api_messages = self.context_compressor._sanitize_tool_pairs(api_messages)
# Calculate approximate request size for logging
total_chars = sum(len(str(msg)) for msg in api_messages)
approx_tokens = total_chars // 4 # Rough estimate: 4 chars per token
@@ -3076,6 +3143,7 @@ class AIAgent:
retry_count = 0
max_retries = 6 # Increased to allow longer backoff periods
codex_auth_retry_attempted = False
nous_auth_retry_attempted = False
finish_reason = "stop"
@@ -3271,6 +3339,13 @@ class AIAgent:
}
self.context_compressor.update_from_response(usage_dict)
# Cache discovered context length after successful call
if self.context_compressor._context_probed:
ctx = self.context_compressor.context_length
save_context_length(self.model, self.base_url, ctx)
print(f"{self.log_prefix}💾 Cached context length: {ctx:,} tokens for {self.model}")
self.context_compressor._context_probed = False
self.session_prompt_tokens += prompt_tokens
self.session_completion_tokens += completion_tokens
self.session_total_tokens += total_tokens
@@ -3318,6 +3393,16 @@ class AIAgent:
if self._try_refresh_codex_client_credentials(force=True):
print(f"{self.log_prefix}🔐 Codex auth refreshed after 401. Retrying request...")
continue
if (
self.api_mode == "chat_completions"
and self.provider == "nous"
and status_code == 401
and not nous_auth_retry_attempted
):
nous_auth_retry_attempted = True
if self._try_refresh_nous_client_credentials(force=True):
print(f"{self.log_prefix}🔐 Nous agent key refreshed after 401. Retrying request...")
continue
retry_count += 1
elapsed_time = time.time() - api_start_time
@@ -3390,18 +3475,37 @@ class AIAgent:
])
if is_context_length_error:
print(f"{self.log_prefix}⚠️ Context length exceeded - attempting compression...")
compressor = self.context_compressor
old_ctx = compressor.context_length
# Try to parse the actual limit from the error message
parsed_limit = parse_context_limit_from_error(error_msg)
if parsed_limit and parsed_limit < old_ctx:
new_ctx = parsed_limit
print(f"{self.log_prefix}⚠️ Context limit detected from API: {new_ctx:,} tokens (was {old_ctx:,})")
else:
# Step down to the next probe tier
new_ctx = get_next_probe_tier(old_ctx)
if new_ctx and new_ctx < old_ctx:
compressor.context_length = new_ctx
compressor.threshold_tokens = int(new_ctx * compressor.threshold_percent)
compressor._context_probed = True
print(f"{self.log_prefix}⚠️ Context length exceeded — stepping down: {old_ctx:,}{new_ctx:,} tokens")
else:
print(f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression...")
original_len = len(messages)
messages, active_system_prompt = self._compress_context(
messages, system_message, approx_tokens=approx_tokens
)
if len(messages) < original_len:
print(f"{self.log_prefix} 🗜️ Compressed {original_len}{len(messages)} messages, retrying...")
continue # Retry with compressed messages
if len(messages) < original_len or new_ctx and new_ctx < old_ctx:
if len(messages) < original_len:
print(f"{self.log_prefix} 🗜️ Compressed {original_len}{len(messages)} messages, retrying...")
continue # Retry with compressed messages or new tier
else:
# Can't compress further
# Can't compress further and already at minimum tier
print(f"{self.log_prefix}❌ Context length exceeded and cannot compress further.")
print(f"{self.log_prefix} 💡 The conversation has accumulated too much content.")
logging.error(f"{self.log_prefix}Context length exceeded: {approx_tokens:,} tokens. Cannot compress further.")
@@ -3693,6 +3797,13 @@ class AIAgent:
self._log_msg_to_db(assistant_msg)
self._execute_tool_calls(assistant_message, messages, effective_task_id)
# Refund the iteration if the ONLY tool(s) called were
# execute_code (programmatic tool calling). These are
# cheap RPC-style calls that shouldn't eat the budget.
_tc_names = {tc.function.name for tc in assistant_message.tool_calls}
if _tc_names == {"execute_code"}:
self.iteration_budget.refund()
if self.compression_enabled and self.context_compressor.should_compress():
messages, active_system_prompt = self._compress_context(
@@ -3713,13 +3824,33 @@ class AIAgent:
# Check if response only has think block with no actual content after it
if not self._has_content_after_think_block(final_response):
# Track retries for empty-after-think responses
# If the previous turn already delivered real content alongside
# tool calls (e.g. "You're welcome!" + memory save), the model
# has nothing more to say. Use the earlier content immediately
# instead of wasting API calls on retries that won't help.
fallback = getattr(self, '_last_content_with_tools', None)
if fallback:
logger.debug("Empty follow-up after tool calls — using prior turn content as final response")
self._last_content_with_tools = None
self._empty_content_retries = 0
for i in range(len(messages) - 1, -1, -1):
msg = messages[i]
if msg.get("role") == "assistant" and msg.get("tool_calls"):
tool_names = []
for tc in msg["tool_calls"]:
fn = tc.get("function", {})
tool_names.append(fn.get("name", "unknown"))
msg["content"] = f"Calling the {', '.join(tool_names)} tool{'s' if len(tool_names) > 1 else ''}..."
break
final_response = self._strip_think_blocks(fallback).strip()
break
# No fallback available — this is a genuine empty response.
# Retry in case the model just had a bad generation.
if not hasattr(self, '_empty_content_retries'):
self._empty_content_retries = 0
self._empty_content_retries += 1
# Show the reasoning/thinking content so the user can see
# what the model was thinking even though content is empty
reasoning_text = self._extract_reasoning(assistant_message)
print(f"{self.log_prefix}⚠️ Response only contains think block with no content after it")
if reasoning_text:
@@ -3814,9 +3945,6 @@ class AIAgent:
# Strip <think> blocks from user-facing response (keep raw in messages for trajectory)
final_response = self._strip_think_blocks(final_response).strip()
if final_response:
self._emit_event({"type": "text", "text": final_response})
final_msg = self._build_assistant_message(assistant_message, finish_reason)
@@ -3913,8 +4041,6 @@ class AIAgent:
# Clear interrupt state after handling
self.clear_interrupt()
self._emit_event({"type": "done"})
return result

124
serve.py
View File

@@ -1,124 +0,0 @@
"""FastAPI streaming wrapper for AIAgent.
Exposes hermes-agent as an HTTP service with SSE streaming.
Run locally with: uvicorn serve:app --host 0.0.0.0 --port 8000
Deploy on Modal via modal_app.py.
"""
import asyncio
import json
import logging
import os
import queue
import threading
from pathlib import Path
from typing import Any
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
logger = logging.getLogger(__name__)
# Force HERMES_HOME to a writable path. Modal secrets may set HERMES_HOME to
# a non-existent path (e.g. /app/tinker-atropos) — override unconditionally.
_hermes_home = Path("/tmp/hermes")
_hermes_home.mkdir(parents=True, exist_ok=True)
(_hermes_home / "logs").mkdir(parents=True, exist_ok=True)
os.environ["HERMES_HOME"] = str(_hermes_home)
# Pre-import modules that register signal handlers so they run in the
# main thread (signal.signal() fails if called from a worker thread).
try:
import tools.browser_tool # noqa: F401
except Exception:
pass
try:
from run_agent import AIAgent # noqa: F401
except Exception as e:
logger.warning("Failed to pre-import AIAgent: %s", e)
app = FastAPI(title="hermes-agent", version="0.1.0")
@app.get("/health")
async def health():
return {"status": "ok"}
@app.post("/v1/agent/stream")
async def agent_stream(request: Request):
body = await request.json()
messages = body.get("messages", [])
model = body.get("model", "anthropic/claude-opus-4.6")
system_prompt = body.get("system_prompt")
toolsets = body.get("toolsets")
max_iterations = body.get("max_iterations", 30)
base_url = body.get("base_url") or os.getenv("AGENT_LLM_BASE_URL")
api_key = body.get("api_key") or os.getenv("AGENT_LLM_API_KEY")
tags = body.get("tags")
user_message = ""
conversation_history = []
for msg in messages:
if msg.get("role") == "user":
user_message = msg.get("content", "")
conversation_history.append(msg)
if conversation_history and conversation_history[-1].get("role") == "user":
user_message = conversation_history.pop().get("content", "")
eq: queue.Queue[dict[str, Any]] = queue.Queue(maxsize=512)
def run_agent():
try:
agent = AIAgent(
model=model,
base_url=base_url,
api_key=api_key,
max_iterations=max_iterations,
quiet_mode=True,
enabled_toolsets=toolsets,
event_queue=eq,
ephemeral_system_prompt=system_prompt,
extra_tags=tags,
)
result = agent.run_conversation(
user_message=user_message,
conversation_history=conversation_history or None,
)
if result and result.get("failed"):
eq.put({"type": "error", "error": result.get("error", "Agent failed")})
eq.put({"type": "done"})
except Exception as e:
logger.exception("Agent error")
eq.put({"type": "error", "error": str(e)})
eq.put({"type": "done"})
thread = threading.Thread(target=run_agent, daemon=True)
thread.start()
loop = asyncio.get_event_loop()
async def event_generator():
while True:
try:
event = await loop.run_in_executor(None, lambda: eq.get(timeout=120))
except queue.Empty:
yield "data: {\"type\": \"done\"}\n\n"
break
yield f"data: {json.dumps(event)}\n\n"
if event.get("type") == "done":
break
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"X-Accel-Buffering": "no",
},
)

View File

@@ -0,0 +1,3 @@
---
description: Apple/macOS-specific skills — iMessage, Reminders, Notes, FindMy, and macOS automation. These skills only load on macOS systems.
---

View File

@@ -0,0 +1,88 @@
---
name: apple-notes
description: Manage Apple Notes via the memo CLI on macOS (create, view, search, edit).
version: 1.0.0
author: Hermes Agent
license: MIT
platforms: [macos]
metadata:
hermes:
tags: [Notes, Apple, macOS, note-taking]
related_skills: [obsidian]
---
# Apple Notes
Use `memo` to manage Apple Notes directly from the terminal. Notes sync across all Apple devices via iCloud.
## Prerequisites
- **macOS** with Notes.app
- Install: `brew tap antoniorodr/memo && brew install antoniorodr/memo/memo`
- Grant Automation access to Notes.app when prompted (System Settings → Privacy → Automation)
## When to Use
- User asks to create, view, or search Apple Notes
- Saving information to Notes.app for cross-device access
- Organizing notes into folders
- Exporting notes to Markdown/HTML
## When NOT to Use
- Obsidian vault management → use the `obsidian` skill
- Bear Notes → separate app (not supported here)
- Quick agent-only notes → use the `memory` tool instead
## Quick Reference
### View Notes
```bash
memo notes # List all notes
memo notes -f "Folder Name" # Filter by folder
memo notes -s "query" # Search notes (fuzzy)
```
### Create Notes
```bash
memo notes -a # Interactive editor
memo notes -a "Note Title" # Quick add with title
```
### Edit Notes
```bash
memo notes -e # Interactive selection to edit
```
### Delete Notes
```bash
memo notes -d # Interactive selection to delete
```
### Move Notes
```bash
memo notes -m # Move note to folder (interactive)
```
### Export Notes
```bash
memo notes -ex # Export to HTML/Markdown
```
## Limitations
- Cannot edit notes containing images or attachments
- Interactive prompts require terminal access (use pty=true if needed)
- macOS only — requires Apple Notes.app
## Rules
1. Prefer Apple Notes when user wants cross-device sync (iPhone/iPad/Mac)
2. Use the `memory` tool for agent-internal notes that don't need to sync
3. Use the `obsidian` skill for Markdown-native knowledge management

View File

@@ -0,0 +1,96 @@
---
name: apple-reminders
description: Manage Apple Reminders via remindctl CLI (list, add, complete, delete).
version: 1.0.0
author: Hermes Agent
license: MIT
platforms: [macos]
metadata:
hermes:
tags: [Reminders, tasks, todo, macOS, Apple]
---
# Apple Reminders
Use `remindctl` to manage Apple Reminders directly from the terminal. Tasks sync across all Apple devices via iCloud.
## Prerequisites
- **macOS** with Reminders.app
- Install: `brew install steipete/tap/remindctl`
- Grant Reminders permission when prompted
- Check: `remindctl status` / Request: `remindctl authorize`
## When to Use
- User mentions "reminder" or "Reminders app"
- Creating personal to-dos with due dates that sync to iOS
- Managing Apple Reminders lists
- User wants tasks to appear on their iPhone/iPad
## When NOT to Use
- Scheduling agent alerts → use the cronjob tool instead
- Calendar events → use Apple Calendar or Google Calendar
- Project task management → use GitHub Issues, Notion, etc.
- If user says "remind me" but means an agent alert → clarify first
## Quick Reference
### View Reminders
```bash
remindctl # Today's reminders
remindctl today # Today
remindctl tomorrow # Tomorrow
remindctl week # This week
remindctl overdue # Past due
remindctl all # Everything
remindctl 2026-01-04 # Specific date
```
### Manage Lists
```bash
remindctl list # List all lists
remindctl list Work # Show specific list
remindctl list Projects --create # Create list
remindctl list Work --delete # Delete list
```
### Create Reminders
```bash
remindctl add "Buy milk"
remindctl add --title "Call mom" --list Personal --due tomorrow
remindctl add --title "Meeting prep" --due "2026-02-15 09:00"
```
### Complete / Delete
```bash
remindctl complete 1 2 3 # Complete by ID
remindctl delete 4A83 --force # Delete by ID
```
### Output Formats
```bash
remindctl today --json # JSON for scripting
remindctl today --plain # TSV format
remindctl today --quiet # Counts only
```
## Date Formats
Accepted by `--due` and date filters:
- `today`, `tomorrow`, `yesterday`
- `YYYY-MM-DD`
- `YYYY-MM-DD HH:mm`
- ISO 8601 (`2026-01-04T12:34:56Z`)
## Rules
1. When user says "remind me", clarify: Apple Reminders (syncs to phone) vs agent cronjob alert
2. Always confirm reminder content and due date before creating
3. Use `--json` for programmatic parsing

View File

@@ -0,0 +1,131 @@
---
name: findmy
description: Track Apple devices and AirTags via FindMy.app on macOS using AppleScript and screen capture.
version: 1.0.0
author: Hermes Agent
license: MIT
platforms: [macos]
metadata:
hermes:
tags: [FindMy, AirTag, location, tracking, macOS, Apple]
---
# Find My (Apple)
Track Apple devices and AirTags via the FindMy.app on macOS. Since Apple doesn't
provide a CLI for FindMy, this skill uses AppleScript to open the app and
screen capture to read device locations.
## Prerequisites
- **macOS** with Find My app and iCloud signed in
- Devices/AirTags already registered in Find My
- Screen Recording permission for terminal (System Settings → Privacy → Screen Recording)
- **Optional but recommended**: Install `peekaboo` for better UI automation:
`brew install steipete/tap/peekaboo`
## When to Use
- User asks "where is my [device/cat/keys/bag]?"
- Tracking AirTag locations
- Checking device locations (iPhone, iPad, Mac, AirPods)
- Monitoring pet or item movement over time (AirTag patrol routes)
## Method 1: AppleScript + Screenshot (Basic)
### Open FindMy and Navigate
```bash
# Open Find My app
osascript -e 'tell application "FindMy" to activate'
# Wait for it to load
sleep 3
# Take a screenshot of the Find My window
screencapture -w -o /tmp/findmy.png
```
Then use `vision_analyze` to read the screenshot:
```
vision_analyze(image_url="/tmp/findmy.png", question="What devices/items are shown and what are their locations?")
```
### Switch Between Tabs
```bash
# Switch to Devices tab
osascript -e '
tell application "System Events"
tell process "FindMy"
click button "Devices" of toolbar 1 of window 1
end tell
end tell'
# Switch to Items tab (AirTags)
osascript -e '
tell application "System Events"
tell process "FindMy"
click button "Items" of toolbar 1 of window 1
end tell
end tell'
```
## Method 2: Peekaboo UI Automation (Recommended)
If `peekaboo` is installed, use it for more reliable UI interaction:
```bash
# Open Find My
osascript -e 'tell application "FindMy" to activate'
sleep 3
# Capture and annotate the UI
peekaboo see --app "FindMy" --annotate --path /tmp/findmy-ui.png
# Click on a specific device/item by element ID
peekaboo click --on B3 --app "FindMy"
# Capture the detail view
peekaboo image --app "FindMy" --path /tmp/findmy-detail.png
```
Then analyze with vision:
```
vision_analyze(image_url="/tmp/findmy-detail.png", question="What is the location shown for this device/item? Include address and coordinates if visible.")
```
## Workflow: Track AirTag Location Over Time
For monitoring an AirTag (e.g., tracking a cat's patrol route):
```bash
# 1. Open FindMy to Items tab
osascript -e 'tell application "FindMy" to activate'
sleep 3
# 2. Click on the AirTag item (stay on page — AirTag only updates when page is open)
# 3. Periodically capture location
while true; do
screencapture -w -o /tmp/findmy-$(date +%H%M%S).png
sleep 300 # Every 5 minutes
done
```
Analyze each screenshot with vision to extract coordinates, then compile a route.
## Limitations
- FindMy has **no CLI or API** — must use UI automation
- AirTags only update location while the FindMy page is actively displayed
- Location accuracy depends on nearby Apple devices in the FindMy network
- Screen Recording permission required for screenshots
- AppleScript UI automation may break across macOS versions
## Rules
1. Keep FindMy app in the foreground when tracking AirTags (updates stop when minimized)
2. Use `vision_analyze` to read screenshot content — don't try to parse pixels
3. For ongoing tracking, use a cronjob to periodically capture and log locations
4. Respect privacy — only track devices/items the user owns

View File

@@ -0,0 +1,100 @@
---
name: imessage
description: Send and receive iMessages/SMS via the imsg CLI on macOS.
version: 1.0.0
author: Hermes Agent
license: MIT
platforms: [macos]
metadata:
hermes:
tags: [iMessage, SMS, messaging, macOS, Apple]
---
# iMessage
Use `imsg` to read and send iMessage/SMS via macOS Messages.app.
## Prerequisites
- **macOS** with Messages.app signed in
- Install: `brew install steipete/tap/imsg`
- Grant Full Disk Access for terminal (System Settings → Privacy → Full Disk Access)
- Grant Automation permission for Messages.app when prompted
## When to Use
- User asks to send an iMessage or text message
- Reading iMessage conversation history
- Checking recent Messages.app chats
- Sending to phone numbers or Apple IDs
## When NOT to Use
- Telegram/Discord/Slack/WhatsApp messages → use the appropriate gateway channel
- Group chat management (adding/removing members) → not supported
- Bulk/mass messaging → always confirm with user first
## Quick Reference
### List Chats
```bash
imsg chats --limit 10 --json
```
### View History
```bash
# By chat ID
imsg history --chat-id 1 --limit 20 --json
# With attachments info
imsg history --chat-id 1 --limit 20 --attachments --json
```
### Send Messages
```bash
# Text only
imsg send --to "+14155551212" --text "Hello!"
# With attachment
imsg send --to "+14155551212" --text "Check this out" --file /path/to/image.jpg
# Force iMessage or SMS
imsg send --to "+14155551212" --text "Hi" --service imessage
imsg send --to "+14155551212" --text "Hi" --service sms
```
### Watch for New Messages
```bash
imsg watch --chat-id 1 --attachments
```
## Service Options
- `--service imessage` — Force iMessage (requires recipient has iMessage)
- `--service sms` — Force SMS (green bubble)
- `--service auto` — Let Messages.app decide (default)
## Rules
1. **Always confirm recipient and message content** before sending
2. **Never send to unknown numbers** without explicit user approval
3. **Verify file paths** exist before attaching
4. **Don't spam** — rate-limit yourself
## Example Workflow
User: "Text mom that I'll be late"
```bash
# 1. Find mom's chat
imsg chats --limit 20 --json | jq '.[] | select(.displayName | contains("Mom"))'
# 2. Confirm with user: "Found Mom at +1555123456. Send 'I'll be late' via iMessage?"
# 3. Send after confirmation
imsg send --to "+1555123456" --text "I'll be late"
```

View File

@@ -0,0 +1,335 @@
---
name: huggingface-accelerate
description: Simplest distributed training API. 4 lines to add distributed support to any PyTorch script. Unified API for DeepSpeed/FSDP/Megatron/DDP. Automatic device placement, mixed precision (FP16/BF16/FP8). Interactive config, single launch command. HuggingFace ecosystem standard.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [accelerate, torch, transformers]
metadata:
hermes:
tags: [Distributed Training, HuggingFace, Accelerate, DeepSpeed, FSDP, Mixed Precision, PyTorch, DDP, Unified API, Simple]
---
# HuggingFace Accelerate - Unified Distributed Training
## Quick start
Accelerate simplifies distributed training to 4 lines of code.
**Installation**:
```bash
pip install accelerate
```
**Convert PyTorch script** (4 lines):
```python
import torch
+ from accelerate import Accelerator
+ accelerator = Accelerator()
model = torch.nn.Transformer()
optimizer = torch.optim.Adam(model.parameters())
dataloader = torch.utils.data.DataLoader(dataset)
+ model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
for batch in dataloader:
optimizer.zero_grad()
loss = model(batch)
- loss.backward()
+ accelerator.backward(loss)
optimizer.step()
```
**Run** (single command):
```bash
accelerate launch train.py
```
## Common workflows
### Workflow 1: From single GPU to multi-GPU
**Original script**:
```python
# train.py
import torch
model = torch.nn.Linear(10, 2).to('cuda')
optimizer = torch.optim.Adam(model.parameters())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
for epoch in range(10):
for batch in dataloader:
batch = batch.to('cuda')
optimizer.zero_grad()
loss = model(batch).mean()
loss.backward()
optimizer.step()
```
**With Accelerate** (4 lines added):
```python
# train.py
import torch
from accelerate import Accelerator # +1
accelerator = Accelerator() # +2
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) # +3
for epoch in range(10):
for batch in dataloader:
# No .to('cuda') needed - automatic!
optimizer.zero_grad()
loss = model(batch).mean()
accelerator.backward(loss) # +4
optimizer.step()
```
**Configure** (interactive):
```bash
accelerate config
```
**Questions**:
- Which machine? (single/multi GPU/TPU/CPU)
- How many machines? (1)
- Mixed precision? (no/fp16/bf16/fp8)
- DeepSpeed? (no/yes)
**Launch** (works on any setup):
```bash
# Single GPU
accelerate launch train.py
# Multi-GPU (8 GPUs)
accelerate launch --multi_gpu --num_processes 8 train.py
# Multi-node
accelerate launch --multi_gpu --num_processes 16 \
--num_machines 2 --machine_rank 0 \
--main_process_ip $MASTER_ADDR \
train.py
```
### Workflow 2: Mixed precision training
**Enable FP16/BF16**:
```python
from accelerate import Accelerator
# FP16 (with gradient scaling)
accelerator = Accelerator(mixed_precision='fp16')
# BF16 (no scaling, more stable)
accelerator = Accelerator(mixed_precision='bf16')
# FP8 (H100+)
accelerator = Accelerator(mixed_precision='fp8')
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
# Everything else is automatic!
for batch in dataloader:
with accelerator.autocast(): # Optional, done automatically
loss = model(batch)
accelerator.backward(loss)
```
### Workflow 3: DeepSpeed ZeRO integration
**Enable DeepSpeed ZeRO-2**:
```python
from accelerate import Accelerator
accelerator = Accelerator(
mixed_precision='bf16',
deepspeed_plugin={
"zero_stage": 2, # ZeRO-2
"offload_optimizer": False,
"gradient_accumulation_steps": 4
}
)
# Same code as before!
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```
**Or via config**:
```bash
accelerate config
# Select: DeepSpeed → ZeRO-2
```
**deepspeed_config.json**:
```json
{
"fp16": {"enabled": false},
"bf16": {"enabled": true},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {"device": "cpu"},
"allgather_bucket_size": 5e8,
"reduce_bucket_size": 5e8
}
}
```
**Launch**:
```bash
accelerate launch --config_file deepspeed_config.json train.py
```
### Workflow 4: FSDP (Fully Sharded Data Parallel)
**Enable FSDP**:
```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin
fsdp_plugin = FullyShardedDataParallelPlugin(
sharding_strategy="FULL_SHARD", # ZeRO-3 equivalent
auto_wrap_policy="TRANSFORMER_AUTO_WRAP",
cpu_offload=False
)
accelerator = Accelerator(
mixed_precision='bf16',
fsdp_plugin=fsdp_plugin
)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```
**Or via config**:
```bash
accelerate config
# Select: FSDP → Full Shard → No CPU Offload
```
### Workflow 5: Gradient accumulation
**Accumulate gradients**:
```python
from accelerate import Accelerator
accelerator = Accelerator(gradient_accumulation_steps=4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
for batch in dataloader:
with accelerator.accumulate(model): # Handles accumulation
optimizer.zero_grad()
loss = model(batch)
accelerator.backward(loss)
optimizer.step()
```
**Effective batch size**: `batch_size * num_gpus * gradient_accumulation_steps`
## When to use vs alternatives
**Use Accelerate when**:
- Want simplest distributed training
- Need single script for any hardware
- Use HuggingFace ecosystem
- Want flexibility (DDP/DeepSpeed/FSDP/Megatron)
- Need quick prototyping
**Key advantages**:
- **4 lines**: Minimal code changes
- **Unified API**: Same code for DDP, DeepSpeed, FSDP, Megatron
- **Automatic**: Device placement, mixed precision, sharding
- **Interactive config**: No manual launcher setup
- **Single launch**: Works everywhere
**Use alternatives instead**:
- **PyTorch Lightning**: Need callbacks, high-level abstractions
- **Ray Train**: Multi-node orchestration, hyperparameter tuning
- **DeepSpeed**: Direct API control, advanced features
- **Raw DDP**: Maximum control, minimal abstraction
## Common issues
**Issue: Wrong device placement**
Don't manually move to device:
```python
# WRONG
batch = batch.to('cuda')
# CORRECT
# Accelerate handles it automatically after prepare()
```
**Issue: Gradient accumulation not working**
Use context manager:
```python
# CORRECT
with accelerator.accumulate(model):
optimizer.zero_grad()
accelerator.backward(loss)
optimizer.step()
```
**Issue: Checkpointing in distributed**
Use accelerator methods:
```python
# Save only on main process
if accelerator.is_main_process:
accelerator.save_state('checkpoint/')
# Load on all processes
accelerator.load_state('checkpoint/')
```
**Issue: Different results with FSDP**
Ensure same random seed:
```python
from accelerate.utils import set_seed
set_seed(42)
```
## Advanced topics
**Megatron integration**: See [references/megatron-integration.md](references/megatron-integration.md) for tensor parallelism, pipeline parallelism, and sequence parallelism setup.
**Custom plugins**: See [references/custom-plugins.md](references/custom-plugins.md) for creating custom distributed plugins and advanced configuration.
**Performance tuning**: See [references/performance.md](references/performance.md) for profiling, memory optimization, and best practices.
## Hardware requirements
- **CPU**: Works (slow)
- **Single GPU**: Works
- **Multi-GPU**: DDP (default), DeepSpeed, or FSDP
- **Multi-node**: DDP, DeepSpeed, FSDP, Megatron
- **TPU**: Supported
- **Apple MPS**: Supported
**Launcher requirements**:
- **DDP**: `torch.distributed.run` (built-in)
- **DeepSpeed**: `deepspeed` (pip install deepspeed)
- **FSDP**: PyTorch 1.12+ (built-in)
- **Megatron**: Custom setup
## Resources
- Docs: https://huggingface.co/docs/accelerate
- GitHub: https://github.com/huggingface/accelerate
- Version: 1.11.0+
- Tutorial: "Accelerate your scripts"
- Examples: https://github.com/huggingface/accelerate/tree/main/examples
- Used by: HuggingFace Transformers, TRL, PEFT, all HF libraries

View File

@@ -0,0 +1,453 @@
# Custom Plugins for Accelerate
## Overview
Accelerate allows creating **custom plugins** to extend distributed training strategies beyond built-in options (DDP, FSDP, DeepSpeed).
## Plugin Architecture
### Base Plugin Structure
```python
from accelerate.utils import DistributedDataParallelKwargs
from dataclasses import dataclass
@dataclass
class CustomPlugin:
"""Custom training plugin."""
# Plugin configuration
param1: int = 1
param2: str = "default"
def __post_init__(self):
# Validation logic
if self.param1 < 1:
raise ValueError("param1 must be >= 1")
```
### Using Custom Plugin
```python
from accelerate import Accelerator
# Create plugin
custom_plugin = CustomPlugin(param1=4, param2="value")
# Pass to Accelerator
accelerator = Accelerator(
custom_plugin=custom_plugin # Not a real parameter, example only
)
```
## Built-In Plugin Examples
### 1. GradScalerKwargs (FP16 Configuration)
```python
from accelerate.utils import GradScalerKwargs
# Configure gradient scaler for FP16
scaler_kwargs = GradScalerKwargs(
init_scale=2.**16, # Initial loss scale
growth_factor=2.0, # Scale growth rate
backoff_factor=0.5, # Scale backoff rate
growth_interval=2000, # Steps between scale increases
enabled=True # Enable scaler
)
accelerator = Accelerator(
mixed_precision='fp16',
kwargs_handlers=[scaler_kwargs] # Pass as kwargs handler
)
```
**Use case**: Fine-tune FP16 gradient scaling behavior
### 2. DistributedDataParallelKwargs
```python
from accelerate.utils import DistributedDataParallelKwargs
# Configure DDP behavior
ddp_kwargs = DistributedDataParallelKwargs(
bucket_cap_mb=25, # Gradient bucketing size
find_unused_parameters=False, # Find unused params (slower)
check_reduction=False, # Check gradient reduction
gradient_as_bucket_view=True, # Memory optimization
static_graph=False # Static computation graph
)
accelerator = Accelerator(
kwargs_handlers=[ddp_kwargs]
)
```
**Use case**: Optimize DDP performance for specific models
### 3. FP8RecipeKwargs (H100 FP8)
```python
from accelerate.utils import FP8RecipeKwargs
# Configure FP8 training (H100)
fp8_recipe = FP8RecipeKwargs(
backend="te", # TransformerEngine backend
margin=0, # Scaling margin
interval=1, # Scaling interval
fp8_format="HYBRID", # E4M3 + E5M2 hybrid
amax_history_len=1024, # AMAX history length
amax_compute_algo="max" # AMAX computation algorithm
)
accelerator = Accelerator(
mixed_precision='fp8',
kwargs_handlers=[fp8_recipe]
)
```
**Use case**: Ultra-fast training on H100 GPUs
## Custom DeepSpeed Configuration
### ZeRO-3 with CPU Offload
```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin
# Custom DeepSpeed config
ds_plugin = DeepSpeedPlugin(
zero_stage=3, # ZeRO-3
offload_optimizer_device="cpu", # CPU offload optimizer
offload_param_device="cpu", # CPU offload parameters
zero3_init_flag=True, # ZeRO-3 initialization
zero3_save_16bit_model=True, # Save FP16 weights
)
accelerator = Accelerator(
deepspeed_plugin=ds_plugin,
mixed_precision='bf16'
)
```
### ZeRO-2 with NVMe Offload
```python
ds_plugin = DeepSpeedPlugin(
zero_stage=2,
offload_optimizer_device="nvme", # NVMe offload
offload_param_device="nvme",
nvme_path="/local_nvme", # NVMe mount path
)
```
### Custom JSON Config
```python
import json
# Load custom DeepSpeed config
with open('deepspeed_config.json', 'r') as f:
ds_config = json.load(f)
ds_plugin = DeepSpeedPlugin(hf_ds_config=ds_config)
accelerator = Accelerator(deepspeed_plugin=ds_plugin)
```
**Example config** (`deepspeed_config.json`):
```json
{
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": 1.0,
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": 5e8,
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e6,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"bf16": {
"enabled": true
},
"steps_per_print": 100,
"wall_clock_breakdown": false
}
```
## Custom FSDP Configuration
### FSDP with Custom Auto-Wrap Policy
```python
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
import functools
# Custom wrap policy (size-based)
wrap_policy = functools.partial(
size_based_auto_wrap_policy,
min_num_params=1e6 # Wrap layers with 1M+ params
)
fsdp_plugin = FullyShardedDataParallelPlugin(
sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO-3 equivalent
backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # Prefetch strategy
mixed_precision_policy=None, # Use Accelerator's mixed precision
auto_wrap_policy=wrap_policy, # Custom wrapping
cpu_offload=False,
ignored_modules=None, # Modules to not wrap
state_dict_type="FULL_STATE_DICT", # Save format
optim_state_dict_config=None,
limit_all_gathers=False,
use_orig_params=True, # Use original param shapes
)
accelerator = Accelerator(
fsdp_plugin=fsdp_plugin,
mixed_precision='bf16'
)
```
### FSDP with Transformer Auto-Wrap
```python
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
# Wrap at transformer block level
wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={GPT2Block} # Wrap GPT2Block layers
)
fsdp_plugin = FullyShardedDataParallelPlugin(
auto_wrap_policy=wrap_policy
)
```
## Creating Custom Training Strategy
### Example: Custom Gradient Accumulation
```python
from accelerate import Accelerator
class CustomGradientAccumulation:
def __init__(self, steps=4, adaptive=False):
self.steps = steps
self.adaptive = adaptive
self.current_step = 0
def should_sync(self, loss):
"""Decide whether to sync gradients."""
self.current_step += 1
# Adaptive: sync on high loss
if self.adaptive and loss > threshold:
self.current_step = 0
return True
# Regular: sync every N steps
if self.current_step >= self.steps:
self.current_step = 0
return True
return False
# Usage
custom_accum = CustomGradientAccumulation(steps=8, adaptive=True)
accelerator = Accelerator()
for batch in dataloader:
outputs = model(**batch)
loss = outputs.loss
# Scale loss
loss = loss / custom_accum.steps
accelerator.backward(loss)
# Conditional sync
if custom_accum.should_sync(loss.item()):
optimizer.step()
optimizer.zero_grad()
```
### Example: Custom Mixed Precision
```python
import torch
class CustomMixedPrecision:
"""Custom mixed precision with dynamic loss scaling."""
def __init__(self, init_scale=2**16, scale_window=2000):
self.scaler = torch.cuda.amp.GradScaler(
init_scale=init_scale,
growth_interval=scale_window
)
self.scale_history = []
def scale_loss(self, loss):
"""Scale loss for backward."""
return self.scaler.scale(loss)
def unscale_and_clip(self, optimizer, max_norm=1.0):
"""Unscale gradients and clip."""
self.scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(
optimizer.param_groups[0]['params'],
max_norm
)
def step(self, optimizer):
"""Optimizer step with scaler update."""
scale_before = self.scaler.get_scale()
self.scaler.step(optimizer)
self.scaler.update()
scale_after = self.scaler.get_scale()
# Track scale changes
if scale_before != scale_after:
self.scale_history.append(scale_after)
# Usage
custom_mp = CustomMixedPrecision()
for batch in dataloader:
with torch.cuda.amp.autocast(dtype=torch.float16):
loss = model(**batch).loss
scaled_loss = custom_mp.scale_loss(loss)
scaled_loss.backward()
custom_mp.unscale_and_clip(optimizer, max_norm=1.0)
custom_mp.step(optimizer)
optimizer.zero_grad()
```
## Advanced: Custom Distributed Backend
### Custom AllReduce Strategy
```python
import torch.distributed as dist
class CustomAllReduce:
"""Custom all-reduce with compression."""
def __init__(self, compression_ratio=0.1):
self.compression_ratio = compression_ratio
def compress_gradients(self, tensor):
"""Top-k gradient compression."""
k = int(tensor.numel() * self.compression_ratio)
values, indices = torch.topk(tensor.abs().view(-1), k)
return values, indices
def all_reduce_compressed(self, tensor):
"""All-reduce with gradient compression."""
# Compress
values, indices = self.compress_gradients(tensor)
# All-reduce compressed gradients
dist.all_reduce(values, op=dist.ReduceOp.SUM)
# Decompress
tensor_compressed = torch.zeros_like(tensor).view(-1)
tensor_compressed[indices] = values / dist.get_world_size()
return tensor_compressed.view_as(tensor)
# Usage in training loop
custom_ar = CustomAllReduce(compression_ratio=0.1)
for batch in dataloader:
loss = model(**batch).loss
loss.backward()
# Custom all-reduce
for param in model.parameters():
if param.grad is not None:
param.grad.data = custom_ar.all_reduce_compressed(param.grad.data)
optimizer.step()
optimizer.zero_grad()
```
## Plugin Best Practices
### 1. Validation in `__post_init__`
```python
@dataclass
class CustomPlugin:
learning_rate: float = 1e-3
warmup_steps: int = 1000
def __post_init__(self):
# Validate parameters
if self.learning_rate <= 0:
raise ValueError("learning_rate must be positive")
if self.warmup_steps < 0:
raise ValueError("warmup_steps must be non-negative")
# Compute derived values
self.min_lr = self.learning_rate * 0.1
```
### 2. Compatibility Checks
```python
@dataclass
class CustomPlugin:
feature_enabled: bool = True
def is_compatible(self, accelerator):
"""Check if plugin is compatible with accelerator config."""
if self.feature_enabled and accelerator.mixed_precision == 'fp8':
raise ValueError("Custom plugin not compatible with FP8")
return True
```
### 3. State Management
```python
@dataclass
class CustomPlugin:
counter: int = 0
history: list = None
def __post_init__(self):
if self.history is None:
self.history = []
def update_state(self, value):
"""Update plugin state during training."""
self.counter += 1
self.history.append(value)
```
## Resources
- Accelerate Plugins: https://huggingface.co/docs/accelerate/package_reference/kwargs
- DeepSpeed Config: https://www.deepspeed.ai/docs/config-json/
- FSDP Guide: https://pytorch.org/docs/stable/fsdp.html
- Custom Training Loops: https://huggingface.co/docs/accelerate/usage_guides/training_tpu

View File

@@ -0,0 +1,489 @@
# Megatron Integration with Accelerate
## Overview
Accelerate supports Megatron-LM for massive model training with tensor parallelism and pipeline parallelism.
**Megatron capabilities**:
- **Tensor Parallelism (TP)**: Split layers across GPUs
- **Pipeline Parallelism (PP)**: Split model depth across GPUs
- **Data Parallelism (DP)**: Replicate model across GPU groups
- **Sequence Parallelism**: Split sequences for long contexts
## Setup
### Install Megatron-LM
```bash
# Clone Megatron-LM repository
git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM
pip install -e .
# Install Apex (NVIDIA optimizations)
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation \
--config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
```
### Accelerate Configuration
```bash
accelerate config
```
**Questions**:
```
In which compute environment are you running?
> This machine
Which type of machine are you using?
> Multi-GPU
How many different machines will you use?
> 1
Do you want to use DeepSpeed/FSDP?
> No
Do you want to use Megatron-LM?
> Yes
What is the Tensor Parallelism degree? [1-8]
> 2
Do you want to enable Sequence Parallelism?
> No
What is the Pipeline Parallelism degree? [1-8]
> 2
What is the Data Parallelism degree? [1-8]
> 2
Where to perform activation checkpointing? ['SELECTIVE', 'FULL', 'NONE']
> SELECTIVE
Where to perform activation partitioning? ['SEQUENTIAL', 'UNIFORM']
> SEQUENTIAL
```
**Generated config** (`~/.cache/huggingface/accelerate/default_config.yaml`):
```yaml
compute_environment: LOCAL_MACHINE
distributed_type: MEGATRON_LM
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
megatron_lm_config:
megatron_lm_gradient_clipping: 1.0
megatron_lm_learning_rate_decay_iters: 320000
megatron_lm_num_micro_batches: 1
megatron_lm_pp_degree: 2
megatron_lm_recompute_activations: true
megatron_lm_sequence_parallelism: false
megatron_lm_tp_degree: 2
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
## Parallelism Strategies
### Tensor Parallelism (TP)
**Splits each transformer layer across GPUs**:
```python
# Layer split across 2 GPUs
# GPU 0: First half of attention heads
# GPU 1: Second half of attention heads
# Each GPU computes partial outputs
# All-reduce combines results
```
**TP degree recommendations**:
- **TP=1**: No tensor parallelism (single GPU per layer)
- **TP=2**: 2 GPUs per layer (good for 7-13B models)
- **TP=4**: 4 GPUs per layer (good for 20-40B models)
- **TP=8**: 8 GPUs per layer (good for 70B+ models)
**Benefits**:
- Reduces memory per GPU
- All-reduce communication (fast)
**Drawbacks**:
- Requires fast inter-GPU bandwidth (NVLink)
- Communication overhead per layer
### Pipeline Parallelism (PP)
**Splits model depth across GPUs**:
```python
# 12-layer model, PP=4
# GPU 0: Layers 0-2
# GPU 1: Layers 3-5
# GPU 2: Layers 6-8
# GPU 3: Layers 9-11
```
**PP degree recommendations**:
- **PP=1**: No pipeline parallelism
- **PP=2**: 2 pipeline stages (good for 20-40B models)
- **PP=4**: 4 pipeline stages (good for 70B+ models)
- **PP=8**: 8 pipeline stages (good for 175B+ models)
**Benefits**:
- Linear memory reduction (4× PP = 4× less memory)
- Works across nodes (slower interconnect OK)
**Drawbacks**:
- Pipeline bubbles (idle time)
- Requires micro-batching
### Data Parallelism (DP)
**Replicates model across GPU groups**:
```python
# 8 GPUs, TP=2, PP=2, DP=2
# Group 0 (GPUs 0-3): Full model replica
# Group 1 (GPUs 4-7): Full model replica
```
**DP degree**:
- `DP = total_gpus / (TP × PP)`
- Example: 8 GPUs, TP=2, PP=2 → DP=2
**Benefits**:
- Increases throughput
- Scales batch size
### Sequence Parallelism
**Splits long sequences across GPUs** (extends TP):
```python
# 8K sequence, TP=2, Sequence Parallel=True
# GPU 0: Tokens 0-4095
# GPU 1: Tokens 4096-8191
```
**Benefits**:
- Enables very long sequences (100K+ tokens)
- Reduces activation memory
**Requirements**:
- Must use with TP > 1
- RoPE/ALiBi position encodings work best
## Accelerate Code Example
### Basic Setup
```python
from accelerate import Accelerator
from accelerate.utils import MegatronLMPlugin
# Configure Megatron
megatron_plugin = MegatronLMPlugin(
tp_degree=2, # Tensor parallelism degree
pp_degree=2, # Pipeline parallelism degree
num_micro_batches=4, # Micro-batches for pipeline
gradient_clipping=1.0, # Gradient clipping value
sequence_parallelism=False, # Enable sequence parallelism
recompute_activations=True, # Activation checkpointing
use_distributed_optimizer=True, # Distributed optimizer
custom_prepare_model_function=None, # Custom model prep
)
# Initialize accelerator
accelerator = Accelerator(
mixed_precision='bf16',
megatron_lm_plugin=megatron_plugin
)
# Prepare model and optimizer
model, optimizer, train_dataloader = accelerator.prepare(
model, optimizer, train_dataloader
)
# Training loop (same as DDP!)
for batch in train_dataloader:
optimizer.zero_grad()
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
```
### Full Training Script
```python
import torch
from accelerate import Accelerator
from accelerate.utils import MegatronLMPlugin
from transformers import GPT2Config, GPT2LMHeadModel
def main():
# Megatron configuration
megatron_plugin = MegatronLMPlugin(
tp_degree=2,
pp_degree=2,
num_micro_batches=4,
gradient_clipping=1.0,
)
accelerator = Accelerator(
mixed_precision='bf16',
gradient_accumulation_steps=8,
megatron_lm_plugin=megatron_plugin
)
# Model
config = GPT2Config(
n_layer=24,
n_head=16,
n_embd=1024,
)
model = GPT2LMHeadModel(config)
# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)
# Prepare
model, optimizer, train_loader = accelerator.prepare(
model, optimizer, train_loader
)
# Training loop
for epoch in range(num_epochs):
for batch in train_loader:
with accelerator.accumulate(model):
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
# Save checkpoint
accelerator.wait_for_everyone()
accelerator.save_state(f'checkpoint-epoch-{epoch}')
if __name__ == '__main__':
main()
```
### Launch Command
```bash
# 8 GPUs, TP=2, PP=2, DP=2
accelerate launch --multi_gpu --num_processes 8 train.py
# Multi-node (2 nodes, 8 GPUs each)
# Node 0
accelerate launch --multi_gpu --num_processes 16 \
--num_machines 2 --machine_rank 0 \
--main_process_ip $MASTER_ADDR \
--main_process_port 29500 \
train.py
# Node 1
accelerate launch --multi_gpu --num_processes 16 \
--num_machines 2 --machine_rank 1 \
--main_process_ip $MASTER_ADDR \
--main_process_port 29500 \
train.py
```
## Activation Checkpointing
**Reduces memory by recomputing activations**:
```python
megatron_plugin = MegatronLMPlugin(
recompute_activations=True, # Enable checkpointing
checkpoint_num_layers=1, # Checkpoint every N layers
distribute_checkpointed_activations=True, # Distribute across TP
partition_activations=True, # Partition in PP
check_for_nan_in_loss_and_grad=True, # Stability check
)
```
**Strategies**:
- `SELECTIVE`: Checkpoint transformer blocks only
- `FULL`: Checkpoint all layers
- `NONE`: No checkpointing
**Memory savings**: 30-50% with 10-15% slowdown
## Distributed Optimizer
**Shards optimizer state across DP ranks**:
```python
megatron_plugin = MegatronLMPlugin(
use_distributed_optimizer=True, # Enable sharded optimizer
)
```
**Benefits**:
- Reduces optimizer memory by DP degree
- Example: DP=4 → 4× less optimizer memory per GPU
**Compatible with**:
- AdamW, Adam, SGD
- Mixed precision training
## Performance Tuning
### Micro-Batch Size
```python
# Pipeline parallelism requires micro-batching
megatron_plugin = MegatronLMPlugin(
pp_degree=4,
num_micro_batches=16, # 16 micro-batches per pipeline
)
# Effective batch = num_micro_batches × micro_batch_size × DP
# Example: 16 × 2 × 4 = 128
```
**Recommendations**:
- More micro-batches → less pipeline bubble
- Typical: 4-16 micro-batches
### Sequence Length
```python
# For long sequences, enable sequence parallelism
megatron_plugin = MegatronLMPlugin(
tp_degree=4,
sequence_parallelism=True, # Required: TP > 1
)
# Enables sequences up to TP × normal limit
# Example: TP=4, 8K normal → 32K with sequence parallel
```
### GPU Topology
**NVLink required for TP**:
```bash
# Check NVLink topology
nvidia-smi topo -m
# Good topology (NVLink between all GPUs)
# GPU0 - GPU1: NV12 (fast)
# GPU0 - GPU2: NV12 (fast)
# Bad topology (PCIe only)
# GPU0 - GPU4: PHB (slow, avoid TP across these)
```
**Recommendations**:
- **TP**: Within same node (NVLink)
- **PP**: Across nodes (slower interconnect OK)
- **DP**: Any topology
## Model Size Guidelines
| Model Size | GPUs | TP | PP | DP | Micro-Batches |
|------------|------|----|----|----|--------------|
| 7B | 8 | 1 | 1 | 8 | 1 |
| 13B | 8 | 2 | 1 | 4 | 1 |
| 20B | 16 | 4 | 1 | 4 | 1 |
| 40B | 32 | 4 | 2 | 4 | 4 |
| 70B | 64 | 8 | 2 | 4 | 8 |
| 175B | 128 | 8 | 4 | 4 | 16 |
**Assumptions**: BF16, 2K sequence length, A100 80GB
## Checkpointing
### Save Checkpoint
```python
# Save full model state
accelerator.save_state('checkpoint-1000')
# Megatron saves separate files per rank
# checkpoint-1000/
# pytorch_model_tp_0_pp_0.bin
# pytorch_model_tp_0_pp_1.bin
# pytorch_model_tp_1_pp_0.bin
# pytorch_model_tp_1_pp_1.bin
# optimizer_tp_0_pp_0.bin
# ...
```
### Load Checkpoint
```python
# Resume training
accelerator.load_state('checkpoint-1000')
# Automatically loads correct shard per rank
```
### Convert to Standard PyTorch
```bash
# Merge Megatron checkpoint to single file
python merge_megatron_checkpoint.py \
--checkpoint-dir checkpoint-1000 \
--output pytorch_model.bin
```
## Common Issues
### Issue: OOM with Pipeline Parallelism
**Solution**: Increase micro-batches
```python
megatron_plugin = MegatronLMPlugin(
pp_degree=4,
num_micro_batches=16, # Increase from 4
)
```
### Issue: Slow Training
**Check 1**: Pipeline bubbles (PP too high)
```python
# Reduce PP, increase TP
tp_degree=4 # Increase
pp_degree=2 # Decrease
```
**Check 2**: Micro-batch size too small
```python
num_micro_batches=8 # Increase
```
### Issue: NVLink Not Detected
```bash
# Verify NVLink
nvidia-smi nvlink -s
# If no NVLink, avoid TP > 1
# Use PP or DP instead
```
## Resources
- Megatron-LM: https://github.com/NVIDIA/Megatron-LM
- Accelerate Megatron docs: https://huggingface.co/docs/accelerate/usage_guides/megatron_lm
- Paper: "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism"
- NVIDIA Apex: https://github.com/NVIDIA/apex

View File

@@ -0,0 +1,525 @@
# Accelerate Performance Tuning
## Profiling
### Basic Profiling
```python
from accelerate import Accelerator
import time
accelerator = Accelerator()
# Warmup
for _ in range(10):
batch = next(iter(dataloader))
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
# Profile training loop
start = time.time()
total_batches = 100
for i, batch in enumerate(dataloader):
if i >= total_batches:
break
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
accelerator.wait_for_everyone() # Sync all processes
elapsed = time.time() - start
# Metrics
batches_per_sec = total_batches / elapsed
samples_per_sec = (total_batches * batch_size * accelerator.num_processes) / elapsed
print(f"Throughput: {samples_per_sec:.2f} samples/sec")
print(f"Batches/sec: {batches_per_sec:.2f}")
```
### PyTorch Profiler Integration
```python
from torch.profiler import profile, ProfilerActivity
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
profile_memory=True,
with_stack=True
) as prof:
for i, batch in enumerate(dataloader):
if i >= 10: # Profile first 10 batches
break
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
# Print profiling results
print(prof.key_averages().table(
sort_by="cuda_time_total", row_limit=20
))
# Export to Chrome tracing
prof.export_chrome_trace("trace.json")
# View at chrome://tracing
```
## Memory Optimization
### 1. Gradient Accumulation
**Problem**: Large batch size causes OOM
**Solution**: Accumulate gradients across micro-batches
```python
accelerator = Accelerator(gradient_accumulation_steps=8)
# Effective batch = batch_size × accumulation_steps × num_gpus
# Example: 4 × 8 × 8 = 256
for batch in dataloader:
with accelerator.accumulate(model): # Handles accumulation logic
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
```
**Memory savings**: 8× less activation memory (with 8 accumulation steps)
### 2. Gradient Checkpointing
**Enable in model**:
```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"gpt2",
use_cache=False # Required for gradient checkpointing
)
# Enable checkpointing
model.gradient_checkpointing_enable()
# Prepare with Accelerate
model = accelerator.prepare(model)
```
**Memory savings**: 30-50% with 10-15% slowdown
### 3. Mixed Precision
**BF16 (A100/H100)**:
```python
accelerator = Accelerator(mixed_precision='bf16')
# Automatic mixed precision
for batch in dataloader:
outputs = model(**batch) # Forward in BF16
loss = outputs.loss
accelerator.backward(loss) # Backward in FP32
optimizer.step()
```
**FP16 (V100, older GPUs)**:
```python
from accelerate.utils import GradScalerKwargs
scaler_kwargs = GradScalerKwargs(
init_scale=2.**16,
growth_interval=2000
)
accelerator = Accelerator(
mixed_precision='fp16',
kwargs_handlers=[scaler_kwargs]
)
```
**Memory savings**: 50% compared to FP32
### 4. CPU Offloading (DeepSpeed)
```python
from accelerate.utils import DeepSpeedPlugin
ds_plugin = DeepSpeedPlugin(
zero_stage=3,
offload_optimizer_device="cpu", # Offload optimizer to CPU
offload_param_device="cpu", # Offload parameters to CPU
)
accelerator = Accelerator(
deepspeed_plugin=ds_plugin,
mixed_precision='bf16'
)
```
**Memory savings**: 10-20× for optimizer state, 5-10× for parameters
**Trade-off**: 20-30% slower due to CPU-GPU transfers
### 5. Flash Attention
```python
# Install flash-attn
# pip install flash-attn
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"gpt2",
attn_implementation="flash_attention_2" # Enable Flash Attention 2
)
model = accelerator.prepare(model)
```
**Memory savings**: 50% for attention, 2× faster
**Requirements**: A100/H100, sequence length must be multiple of 128
## Communication Optimization
### 1. Gradient Bucketing (DDP)
```python
from accelerate.utils import DistributedDataParallelKwargs
ddp_kwargs = DistributedDataParallelKwargs(
bucket_cap_mb=25, # Bucket size for gradient reduction
gradient_as_bucket_view=True, # Reduce memory copies
static_graph=False # Set True if model doesn't change
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
```
**Recommended bucket sizes**:
- Small models (<1B): 25 MB
- Medium models (1-10B): 50-100 MB
- Large models (>10B): 100-200 MB
### 2. Find Unused Parameters
```python
# Only enable if model has unused parameters (slower!)
ddp_kwargs = DistributedDataParallelKwargs(
find_unused_parameters=True
)
```
**Use case**: Models with conditional branches (e.g., mixture of experts)
**Cost**: 10-20% slower
### 3. NCCL Tuning
```bash
# Set environment variables before launch
export NCCL_DEBUG=INFO # Debug info
export NCCL_IB_DISABLE=0 # Enable InfiniBand
export NCCL_SOCKET_IFNAME=eth0 # Network interface
export NCCL_P2P_LEVEL=NVL # Use NVLink
accelerate launch train.py
```
**NCCL_P2P_LEVEL options**:
- `NVL`: NVLink (fastest, within node)
- `PIX`: PCIe (fast, within node)
- `PHB`: PCIe host bridge (slow, cross-node)
## Data Loading Optimization
### 1. DataLoader Workers
```python
from torch.utils.data import DataLoader
train_loader = DataLoader(
dataset,
batch_size=32,
num_workers=4, # Parallel data loading
pin_memory=True, # Pin memory for faster GPU transfer
prefetch_factor=2, # Prefetch batches per worker
persistent_workers=True # Keep workers alive between epochs
)
train_loader = accelerator.prepare(train_loader)
```
**Recommendations**:
- `num_workers`: 2-4 per GPU (8 GPUs → 16-32 workers)
- `pin_memory`: Always True for GPU training
- `prefetch_factor`: 2-4 (higher for slow data loading)
### 2. Data Preprocessing
```python
from datasets import load_dataset
# Bad: Preprocess during training (slow)
dataset = load_dataset("openwebtext")
for batch in dataset:
tokens = tokenizer(batch['text']) # Slow!
...
# Good: Preprocess once, save
dataset = load_dataset("openwebtext")
tokenized = dataset.map(
lambda x: tokenizer(x['text']),
batched=True,
num_proc=8, # Parallel preprocessing
remove_columns=['text']
)
tokenized.save_to_disk("preprocessed_data")
# Load preprocessed
dataset = load_from_disk("preprocessed_data")
```
### 3. Faster Tokenization
```python
import os
# Enable Rust-based tokenizers (10× faster)
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
"gpt2",
use_fast=True # Use fast Rust tokenizer
)
```
## Compilation (PyTorch 2.0+)
### Compile Model
```python
import torch
# Compile model for faster execution
model = torch.compile(
model,
mode="reduce-overhead", # Options: default, reduce-overhead, max-autotune
fullgraph=False, # Compile entire graph (stricter)
dynamic=True # Support dynamic shapes
)
model = accelerator.prepare(model)
```
**Speedup**: 10-50% depending on model
**Compilation modes**:
- `default`: Balanced (best for most cases)
- `reduce-overhead`: Min overhead (best for small batches)
- `max-autotune`: Max performance (slow compile, best for production)
### Compilation Best Practices
```python
# Bad: Compile after prepare (won't work)
model = accelerator.prepare(model)
model = torch.compile(model) # Error!
# Good: Compile before prepare
model = torch.compile(model)
model = accelerator.prepare(model)
# Training loop
for batch in dataloader:
# First iteration: slow (compilation)
# Subsequent iterations: fast (compiled)
outputs = model(**batch)
...
```
## Benchmarking Different Strategies
### Script Template
```python
import time
import torch
from accelerate import Accelerator
def benchmark_strategy(strategy_name, accelerator_kwargs):
"""Benchmark a specific training strategy."""
accelerator = Accelerator(**accelerator_kwargs)
# Setup
model = create_model()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = create_dataloader()
model, optimizer, dataloader = accelerator.prepare(
model, optimizer, dataloader
)
# Warmup
for i, batch in enumerate(dataloader):
if i >= 10:
break
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
# Benchmark
accelerator.wait_for_everyone()
torch.cuda.synchronize()
start = time.time()
num_batches = 100
for i, batch in enumerate(dataloader):
if i >= num_batches:
break
outputs = model(**batch)
loss = outputs.loss
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
accelerator.wait_for_everyone()
torch.cuda.synchronize()
elapsed = time.time() - start
# Metrics
throughput = (num_batches * batch_size * accelerator.num_processes) / elapsed
memory_used = torch.cuda.max_memory_allocated() / 1e9 # GB
if accelerator.is_main_process:
print(f"\n{strategy_name}:")
print(f" Throughput: {throughput:.2f} samples/sec")
print(f" Memory: {memory_used:.2f} GB")
print(f" Time: {elapsed:.2f} sec")
torch.cuda.reset_peak_memory_stats()
# Benchmark different strategies
strategies = [
("DDP + FP32", {}),
("DDP + BF16", {"mixed_precision": "bf16"}),
("DDP + BF16 + GradAccum", {"mixed_precision": "bf16", "gradient_accumulation_steps": 4}),
("FSDP", {"fsdp_plugin": fsdp_plugin}),
("DeepSpeed ZeRO-2", {"deepspeed_plugin": ds_plugin_stage2}),
("DeepSpeed ZeRO-3", {"deepspeed_plugin": ds_plugin_stage3}),
]
for name, kwargs in strategies:
benchmark_strategy(name, kwargs)
```
## Performance Checklist
**Before training**:
- [ ] Use BF16/FP16 mixed precision
- [ ] Enable gradient checkpointing (if OOM)
- [ ] Set appropriate `num_workers` (2-4 per GPU)
- [ ] Enable `pin_memory=True`
- [ ] Preprocess data once, not during training
- [ ] Compile model with `torch.compile` (PyTorch 2.0+)
**For large models**:
- [ ] Use FSDP or DeepSpeed ZeRO-3
- [ ] Enable CPU offloading (if still OOM)
- [ ] Use Flash Attention
- [ ] Increase gradient accumulation
**For multi-node**:
- [ ] Check network topology (InfiniBand > Ethernet)
- [ ] Tune NCCL settings
- [ ] Use larger bucket sizes for DDP
- [ ] Verify NVLink for tensor parallelism
**Profiling**:
- [ ] Profile first 10-100 batches
- [ ] Check GPU utilization (`nvidia-smi dmon`)
- [ ] Check data loading time (should be <5% of iteration)
- [ ] Identify communication bottlenecks
## Common Performance Issues
### Issue: Low GPU Utilization (<80%)
**Cause 1**: Data loading bottleneck
```python
# Solution: Increase workers and prefetch
num_workers=8
prefetch_factor=4
```
**Cause 2**: Small batch size
```python
# Solution: Increase batch size or use gradient accumulation
batch_size=32 # Increase
gradient_accumulation_steps=4 # Or accumulate
```
### Issue: High Memory Usage
**Solution 1**: Gradient checkpointing
```python
model.gradient_checkpointing_enable()
```
**Solution 2**: Reduce batch size, increase accumulation
```python
batch_size=8 # Reduce from 32
gradient_accumulation_steps=16 # Maintain effective batch
```
**Solution 3**: Use FSDP or DeepSpeed ZeRO-3
```python
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```
### Issue: Slow Multi-GPU Training
**Cause**: Communication bottleneck
**Check 1**: Gradient bucket size
```python
ddp_kwargs = DistributedDataParallelKwargs(bucket_cap_mb=100)
```
**Check 2**: NCCL settings
```bash
export NCCL_DEBUG=INFO
# Check for "Using NVLS" (good) vs "Using PHB" (bad)
```
**Check 3**: Network bandwidth
```bash
# Test inter-GPU bandwidth
nvidia-smi nvlink -s
```
## Resources
- Accelerate Performance: https://huggingface.co/docs/accelerate/usage_guides/performance
- PyTorch Profiler: https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
- NCCL Tuning: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html
- Flash Attention: https://github.com/Dao-AILab/flash-attention

View File

@@ -0,0 +1,567 @@
---
name: audiocraft-audio-generation
description: PyTorch library for audio generation including text-to-music (MusicGen) and text-to-sound (AudioGen). Use when you need to generate music from text descriptions, create sound effects, or perform melody-conditioned music generation.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [audiocraft, torch>=2.0.0, transformers>=4.30.0]
metadata:
hermes:
tags: [Multimodal, Audio Generation, Text-to-Music, Text-to-Audio, MusicGen]
---
# AudioCraft: Audio Generation
Comprehensive guide to using Meta's AudioCraft for text-to-music and text-to-audio generation with MusicGen, AudioGen, and EnCodec.
## When to use AudioCraft
**Use AudioCraft when:**
- Need to generate music from text descriptions
- Creating sound effects and environmental audio
- Building music generation applications
- Need melody-conditioned music generation
- Want stereo audio output
- Require controllable music generation with style transfer
**Key features:**
- **MusicGen**: Text-to-music generation with melody conditioning
- **AudioGen**: Text-to-sound effects generation
- **EnCodec**: High-fidelity neural audio codec
- **Multiple model sizes**: Small (300M) to Large (3.3B)
- **Stereo support**: Full stereo audio generation
- **Style conditioning**: MusicGen-Style for reference-based generation
**Use alternatives instead:**
- **Stable Audio**: For longer commercial music generation
- **Bark**: For text-to-speech with music/sound effects
- **Riffusion**: For spectogram-based music generation
- **OpenAI Jukebox**: For raw audio generation with lyrics
## Quick start
### Installation
```bash
# From PyPI
pip install audiocraft
# From GitHub (latest)
pip install git+https://github.com/facebookresearch/audiocraft.git
# Or use HuggingFace Transformers
pip install transformers torch torchaudio
```
### Basic text-to-music (AudioCraft)
```python
import torchaudio
from audiocraft.models import MusicGen
# Load model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Set generation parameters
model.set_generation_params(
duration=8, # seconds
top_k=250,
temperature=1.0
)
# Generate from text
descriptions = ["happy upbeat electronic dance music with synths"]
wav = model.generate(descriptions)
# Save audio
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=32000)
```
### Using HuggingFace Transformers
```python
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import scipy
# Load model and processor
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
model.to("cuda")
# Generate music
inputs = processor(
text=["80s pop track with bassy drums and synth"],
padding=True,
return_tensors="pt"
).to("cuda")
audio_values = model.generate(
**inputs,
do_sample=True,
guidance_scale=3,
max_new_tokens=256
)
# Save
sampling_rate = model.config.audio_encoder.sampling_rate
scipy.io.wavfile.write("output.wav", rate=sampling_rate, data=audio_values[0, 0].cpu().numpy())
```
### Text-to-sound with AudioGen
```python
from audiocraft.models import AudioGen
# Load AudioGen
model = AudioGen.get_pretrained('facebook/audiogen-medium')
model.set_generation_params(duration=5)
# Generate sound effects
descriptions = ["dog barking in a park with birds chirping"]
wav = model.generate(descriptions)
torchaudio.save("sound.wav", wav[0].cpu(), sample_rate=16000)
```
## Core concepts
### Architecture overview
```
AudioCraft Architecture:
┌──────────────────────────────────────────────────────────────┐
│ Text Encoder (T5) │
│ │ │
│ Text Embeddings │
└────────────────────────┬─────────────────────────────────────┘
┌────────────────────────▼─────────────────────────────────────┐
│ Transformer Decoder (LM) │
│ Auto-regressively generates audio tokens │
│ Using efficient token interleaving patterns │
└────────────────────────┬─────────────────────────────────────┘
┌────────────────────────▼─────────────────────────────────────┐
│ EnCodec Audio Decoder │
│ Converts tokens back to audio waveform │
└──────────────────────────────────────────────────────────────┘
```
### Model variants
| Model | Size | Description | Use Case |
|-------|------|-------------|----------|
| `musicgen-small` | 300M | Text-to-music | Quick generation |
| `musicgen-medium` | 1.5B | Text-to-music | Balanced |
| `musicgen-large` | 3.3B | Text-to-music | Best quality |
| `musicgen-melody` | 1.5B | Text + melody | Melody conditioning |
| `musicgen-melody-large` | 3.3B | Text + melody | Best melody |
| `musicgen-stereo-*` | Varies | Stereo output | Stereo generation |
| `musicgen-style` | 1.5B | Style transfer | Reference-based |
| `audiogen-medium` | 1.5B | Text-to-sound | Sound effects |
### Generation parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| `duration` | 8.0 | Length in seconds (1-120) |
| `top_k` | 250 | Top-k sampling |
| `top_p` | 0.0 | Nucleus sampling (0 = disabled) |
| `temperature` | 1.0 | Sampling temperature |
| `cfg_coef` | 3.0 | Classifier-free guidance |
## MusicGen usage
### Text-to-music generation
```python
from audiocraft.models import MusicGen
import torchaudio
model = MusicGen.get_pretrained('facebook/musicgen-medium')
# Configure generation
model.set_generation_params(
duration=30, # Up to 30 seconds
top_k=250, # Sampling diversity
top_p=0.0, # 0 = use top_k only
temperature=1.0, # Creativity (higher = more varied)
cfg_coef=3.0 # Text adherence (higher = stricter)
)
# Generate multiple samples
descriptions = [
"epic orchestral soundtrack with strings and brass",
"chill lo-fi hip hop beat with jazzy piano",
"energetic rock song with electric guitar"
]
# Generate (returns [batch, channels, samples])
wav = model.generate(descriptions)
# Save each
for i, audio in enumerate(wav):
torchaudio.save(f"music_{i}.wav", audio.cpu(), sample_rate=32000)
```
### Melody-conditioned generation
```python
from audiocraft.models import MusicGen
import torchaudio
# Load melody model
model = MusicGen.get_pretrained('facebook/musicgen-melody')
model.set_generation_params(duration=30)
# Load melody audio
melody, sr = torchaudio.load("melody.wav")
# Generate with melody conditioning
descriptions = ["acoustic guitar folk song"]
wav = model.generate_with_chroma(descriptions, melody, sr)
torchaudio.save("melody_conditioned.wav", wav[0].cpu(), sample_rate=32000)
```
### Stereo generation
```python
from audiocraft.models import MusicGen
# Load stereo model
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
model.set_generation_params(duration=15)
descriptions = ["ambient electronic music with wide stereo panning"]
wav = model.generate(descriptions)
# wav shape: [batch, 2, samples] for stereo
print(f"Stereo shape: {wav.shape}") # [1, 2, 480000]
torchaudio.save("stereo.wav", wav[0].cpu(), sample_rate=32000)
```
### Audio continuation
```python
from transformers import AutoProcessor, MusicgenForConditionalGeneration
processor = AutoProcessor.from_pretrained("facebook/musicgen-medium")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-medium")
# Load audio to continue
import torchaudio
audio, sr = torchaudio.load("intro.wav")
# Process with text and audio
inputs = processor(
audio=audio.squeeze().numpy(),
sampling_rate=sr,
text=["continue with a epic chorus"],
padding=True,
return_tensors="pt"
)
# Generate continuation
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=512)
```
## MusicGen-Style usage
### Style-conditioned generation
```python
from audiocraft.models import MusicGen
# Load style model
model = MusicGen.get_pretrained('facebook/musicgen-style')
# Configure generation with style
model.set_generation_params(
duration=30,
cfg_coef=3.0,
cfg_coef_beta=5.0 # Style influence
)
# Configure style conditioner
model.set_style_conditioner_params(
eval_q=3, # RVQ quantizers (1-6)
excerpt_length=3.0 # Style excerpt length
)
# Load style reference
style_audio, sr = torchaudio.load("reference_style.wav")
# Generate with text + style
descriptions = ["upbeat dance track"]
wav = model.generate_with_style(descriptions, style_audio, sr)
```
### Style-only generation (no text)
```python
# Generate matching style without text prompt
model.set_generation_params(
duration=30,
cfg_coef=3.0,
cfg_coef_beta=None # Disable double CFG for style-only
)
wav = model.generate_with_style([None], style_audio, sr)
```
## AudioGen usage
### Sound effect generation
```python
from audiocraft.models import AudioGen
import torchaudio
model = AudioGen.get_pretrained('facebook/audiogen-medium')
model.set_generation_params(duration=10)
# Generate various sounds
descriptions = [
"thunderstorm with heavy rain and lightning",
"busy city traffic with car horns",
"ocean waves crashing on rocks",
"crackling campfire in forest"
]
wav = model.generate(descriptions)
for i, audio in enumerate(wav):
torchaudio.save(f"sound_{i}.wav", audio.cpu(), sample_rate=16000)
```
## EnCodec usage
### Audio compression
```python
from audiocraft.models import CompressionModel
import torch
import torchaudio
# Load EnCodec
model = CompressionModel.get_pretrained('facebook/encodec_32khz')
# Load audio
wav, sr = torchaudio.load("audio.wav")
# Ensure correct sample rate
if sr != 32000:
resampler = torchaudio.transforms.Resample(sr, 32000)
wav = resampler(wav)
# Encode to tokens
with torch.no_grad():
encoded = model.encode(wav.unsqueeze(0))
codes = encoded[0] # Audio codes
# Decode back to audio
with torch.no_grad():
decoded = model.decode(codes)
torchaudio.save("reconstructed.wav", decoded[0].cpu(), sample_rate=32000)
```
## Common workflows
### Workflow 1: Music generation pipeline
```python
import torch
import torchaudio
from audiocraft.models import MusicGen
class MusicGenerator:
def __init__(self, model_name="facebook/musicgen-medium"):
self.model = MusicGen.get_pretrained(model_name)
self.sample_rate = 32000
def generate(self, prompt, duration=30, temperature=1.0, cfg=3.0):
self.model.set_generation_params(
duration=duration,
top_k=250,
temperature=temperature,
cfg_coef=cfg
)
with torch.no_grad():
wav = self.model.generate([prompt])
return wav[0].cpu()
def generate_batch(self, prompts, duration=30):
self.model.set_generation_params(duration=duration)
with torch.no_grad():
wav = self.model.generate(prompts)
return wav.cpu()
def save(self, audio, path):
torchaudio.save(path, audio, sample_rate=self.sample_rate)
# Usage
generator = MusicGenerator()
audio = generator.generate(
"epic cinematic orchestral music",
duration=30,
temperature=1.0
)
generator.save(audio, "epic_music.wav")
```
### Workflow 2: Sound design batch processing
```python
import json
from pathlib import Path
from audiocraft.models import AudioGen
import torchaudio
def batch_generate_sounds(sound_specs, output_dir):
"""
Generate multiple sounds from specifications.
Args:
sound_specs: list of {"name": str, "description": str, "duration": float}
output_dir: output directory path
"""
model = AudioGen.get_pretrained('facebook/audiogen-medium')
output_dir = Path(output_dir)
output_dir.mkdir(exist_ok=True)
results = []
for spec in sound_specs:
model.set_generation_params(duration=spec.get("duration", 5))
wav = model.generate([spec["description"]])
output_path = output_dir / f"{spec['name']}.wav"
torchaudio.save(str(output_path), wav[0].cpu(), sample_rate=16000)
results.append({
"name": spec["name"],
"path": str(output_path),
"description": spec["description"]
})
return results
# Usage
sounds = [
{"name": "explosion", "description": "massive explosion with debris", "duration": 3},
{"name": "footsteps", "description": "footsteps on wooden floor", "duration": 5},
{"name": "door", "description": "wooden door creaking and closing", "duration": 2}
]
results = batch_generate_sounds(sounds, "sound_effects/")
```
### Workflow 3: Gradio demo
```python
import gradio as gr
import torch
import torchaudio
from audiocraft.models import MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-small')
def generate_music(prompt, duration, temperature, cfg_coef):
model.set_generation_params(
duration=duration,
temperature=temperature,
cfg_coef=cfg_coef
)
with torch.no_grad():
wav = model.generate([prompt])
# Save to temp file
path = "temp_output.wav"
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
return path
demo = gr.Interface(
fn=generate_music,
inputs=[
gr.Textbox(label="Music Description", placeholder="upbeat electronic dance music"),
gr.Slider(1, 30, value=8, label="Duration (seconds)"),
gr.Slider(0.5, 2.0, value=1.0, label="Temperature"),
gr.Slider(1.0, 10.0, value=3.0, label="CFG Coefficient")
],
outputs=gr.Audio(label="Generated Music"),
title="MusicGen Demo"
)
demo.launch()
```
## Performance optimization
### Memory optimization
```python
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Clear cache between generations
torch.cuda.empty_cache()
# Generate shorter durations
model.set_generation_params(duration=10) # Instead of 30
# Use half precision
model = model.half()
```
### Batch processing efficiency
```python
# Process multiple prompts at once (more efficient)
descriptions = ["prompt1", "prompt2", "prompt3", "prompt4"]
wav = model.generate(descriptions) # Single batch
# Instead of
for desc in descriptions:
wav = model.generate([desc]) # Multiple batches (slower)
```
### GPU memory requirements
| Model | FP32 VRAM | FP16 VRAM |
|-------|-----------|-----------|
| musicgen-small | ~4GB | ~2GB |
| musicgen-medium | ~8GB | ~4GB |
| musicgen-large | ~16GB | ~8GB |
## Common issues
| Issue | Solution |
|-------|----------|
| CUDA OOM | Use smaller model, reduce duration |
| Poor quality | Increase cfg_coef, better prompts |
| Generation too short | Check max duration setting |
| Audio artifacts | Try different temperature |
| Stereo not working | Use stereo model variant |
## References
- **[Advanced Usage](references/advanced-usage.md)** - Training, fine-tuning, deployment
- **[Troubleshooting](references/troubleshooting.md)** - Common issues and solutions
## Resources
- **GitHub**: https://github.com/facebookresearch/audiocraft
- **Paper (MusicGen)**: https://arxiv.org/abs/2306.05284
- **Paper (AudioGen)**: https://arxiv.org/abs/2209.15352
- **HuggingFace**: https://huggingface.co/facebook/musicgen-small
- **Demo**: https://huggingface.co/spaces/facebook/MusicGen

View File

@@ -0,0 +1,666 @@
# AudioCraft Advanced Usage Guide
## Fine-tuning MusicGen
### Custom dataset preparation
```python
import os
import json
from pathlib import Path
import torchaudio
def prepare_dataset(audio_dir, output_dir, metadata_file):
"""
Prepare dataset for MusicGen fine-tuning.
Directory structure:
output_dir/
├── audio/
│ ├── 0001.wav
│ ├── 0002.wav
│ └── ...
└── metadata.json
"""
output_dir = Path(output_dir)
audio_output = output_dir / "audio"
audio_output.mkdir(parents=True, exist_ok=True)
# Load metadata (format: {"path": "...", "description": "..."})
with open(metadata_file) as f:
metadata = json.load(f)
processed = []
for idx, item in enumerate(metadata):
audio_path = Path(audio_dir) / item["path"]
# Load and resample to 32kHz
wav, sr = torchaudio.load(str(audio_path))
if sr != 32000:
resampler = torchaudio.transforms.Resample(sr, 32000)
wav = resampler(wav)
# Convert to mono if stereo
if wav.shape[0] > 1:
wav = wav.mean(dim=0, keepdim=True)
# Save processed audio
output_path = audio_output / f"{idx:04d}.wav"
torchaudio.save(str(output_path), wav, sample_rate=32000)
processed.append({
"path": str(output_path.relative_to(output_dir)),
"description": item["description"],
"duration": wav.shape[1] / 32000
})
# Save processed metadata
with open(output_dir / "metadata.json", "w") as f:
json.dump(processed, f, indent=2)
print(f"Processed {len(processed)} samples")
return processed
```
### Fine-tuning with dora
```bash
# AudioCraft uses dora for experiment management
# Install dora
pip install dora-search
# Clone AudioCraft
git clone https://github.com/facebookresearch/audiocraft.git
cd audiocraft
# Create config for fine-tuning
cat > config/solver/musicgen/finetune.yaml << 'EOF'
defaults:
- musicgen/musicgen_base
- /model: lm/musicgen_lm
- /conditioner: cond_base
solver: musicgen
autocast: true
autocast_dtype: float16
optim:
epochs: 100
batch_size: 4
lr: 1e-4
ema: 0.999
optimizer: adamw
dataset:
batch_size: 4
num_workers: 4
train:
- dset: your_dataset
root: /path/to/dataset
valid:
- dset: your_dataset
root: /path/to/dataset
checkpoint:
save_every: 10
keep_every_states: null
EOF
# Run fine-tuning
dora run solver=musicgen/finetune
```
### LoRA fine-tuning
```python
from peft import LoraConfig, get_peft_model
from audiocraft.models import MusicGen
import torch
# Load base model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Get the language model component
lm = model.lm
# Configure LoRA
lora_config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=["q_proj", "v_proj", "k_proj", "out_proj"],
lora_dropout=0.05,
bias="none"
)
# Apply LoRA
lm = get_peft_model(lm, lora_config)
lm.print_trainable_parameters()
```
## Multi-GPU Training
### DataParallel
```python
import torch
import torch.nn as nn
from audiocraft.models import MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Wrap LM with DataParallel
if torch.cuda.device_count() > 1:
model.lm = nn.DataParallel(model.lm)
model.to("cuda")
```
### DistributedDataParallel
```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup(rank, world_size):
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def train(rank, world_size):
setup(rank, world_size)
model = MusicGen.get_pretrained('facebook/musicgen-small')
model.lm = model.lm.to(rank)
model.lm = DDP(model.lm, device_ids=[rank])
# Training loop
# ...
dist.destroy_process_group()
```
## Custom Conditioning
### Adding new conditioners
```python
from audiocraft.modules.conditioners import BaseConditioner
import torch
class CustomConditioner(BaseConditioner):
"""Custom conditioner for additional control signals."""
def __init__(self, dim, output_dim):
super().__init__(dim, output_dim)
self.embed = torch.nn.Linear(dim, output_dim)
def forward(self, x):
return self.embed(x)
def tokenize(self, x):
# Tokenize input for conditioning
return x
# Use with MusicGen
from audiocraft.models.builders import get_lm_model
# Modify model config to include custom conditioner
# This requires editing the model configuration
```
### Melody conditioning internals
```python
from audiocraft.models import MusicGen
from audiocraft.modules.codebooks_patterns import DelayedPatternProvider
import torch
model = MusicGen.get_pretrained('facebook/musicgen-melody')
# Access chroma extractor
chroma_extractor = model.lm.condition_provider.conditioners.get('chroma')
# Manual chroma extraction
def extract_chroma(audio, sr):
"""Extract chroma features from audio."""
import librosa
# Compute chroma
chroma = librosa.feature.chroma_cqt(y=audio.numpy(), sr=sr)
return torch.from_numpy(chroma).float()
# Use extracted chroma for conditioning
chroma = extract_chroma(melody_audio, sample_rate)
```
## EnCodec Deep Dive
### Custom compression settings
```python
from audiocraft.models import CompressionModel
import torch
# Load EnCodec
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
# Access codec parameters
print(f"Sample rate: {encodec.sample_rate}")
print(f"Channels: {encodec.channels}")
print(f"Cardinality: {encodec.cardinality}") # Codebook size
print(f"Num codebooks: {encodec.num_codebooks}")
print(f"Frame rate: {encodec.frame_rate}")
# Encode with specific bandwidth
# Lower bandwidth = more compression, lower quality
encodec.set_target_bandwidth(6.0) # 6 kbps
audio = torch.randn(1, 1, 32000) # 1 second
encoded = encodec.encode(audio)
decoded = encodec.decode(encoded[0])
```
### Streaming encoding
```python
import torch
from audiocraft.models import CompressionModel
encodec = CompressionModel.get_pretrained('facebook/encodec_32khz')
def encode_streaming(audio_stream, chunk_size=32000):
"""Encode audio in streaming fashion."""
all_codes = []
for chunk in audio_stream:
# Ensure chunk is right shape
if chunk.dim() == 1:
chunk = chunk.unsqueeze(0).unsqueeze(0)
with torch.no_grad():
codes = encodec.encode(chunk)[0]
all_codes.append(codes)
return torch.cat(all_codes, dim=-1)
def decode_streaming(codes_stream, output_stream):
"""Decode codes in streaming fashion."""
for codes in codes_stream:
with torch.no_grad():
audio = encodec.decode(codes)
output_stream.write(audio.cpu().numpy())
```
## MultiBand Diffusion
### Using MBD for enhanced quality
```python
from audiocraft.models import MusicGen, MultiBandDiffusion
# Load MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-medium')
# Load MultiBand Diffusion
mbd = MultiBandDiffusion.get_mbd_musicgen()
model.set_generation_params(duration=10)
# Generate with standard decoder
descriptions = ["epic orchestral music"]
wav_standard = model.generate(descriptions)
# Generate tokens and use MBD decoder
with torch.no_grad():
# Get tokens
gen_tokens = model.generate_tokens(descriptions)
# Decode with MBD
wav_mbd = mbd.tokens_to_wav(gen_tokens)
# Compare quality
print(f"Standard shape: {wav_standard.shape}")
print(f"MBD shape: {wav_mbd.shape}")
```
## API Server Deployment
### FastAPI server
```python
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import torchaudio
from audiocraft.models import MusicGen
import io
import base64
app = FastAPI()
# Load model at startup
model = None
@app.on_event("startup")
async def load_model():
global model
model = MusicGen.get_pretrained('facebook/musicgen-small')
model.set_generation_params(duration=10)
class GenerateRequest(BaseModel):
prompt: str
duration: float = 10.0
temperature: float = 1.0
cfg_coef: float = 3.0
class GenerateResponse(BaseModel):
audio_base64: str
sample_rate: int
duration: float
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
if model is None:
raise HTTPException(status_code=500, detail="Model not loaded")
try:
model.set_generation_params(
duration=min(request.duration, 30),
temperature=request.temperature,
cfg_coef=request.cfg_coef
)
with torch.no_grad():
wav = model.generate([request.prompt])
# Convert to bytes
buffer = io.BytesIO()
torchaudio.save(buffer, wav[0].cpu(), sample_rate=32000, format="wav")
buffer.seek(0)
audio_base64 = base64.b64encode(buffer.read()).decode()
return GenerateResponse(
audio_base64=audio_base64,
sample_rate=32000,
duration=wav.shape[-1] / 32000
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health():
return {"status": "ok", "model_loaded": model is not None}
# Run: uvicorn server:app --host 0.0.0.0 --port 8000
```
### Batch processing service
```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
import torch
from audiocraft.models import MusicGen
class MusicGenService:
def __init__(self, model_name='facebook/musicgen-small', max_workers=2):
self.model = MusicGen.get_pretrained(model_name)
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.lock = asyncio.Lock()
async def generate_async(self, prompt, duration=10):
"""Async generation with thread pool."""
loop = asyncio.get_event_loop()
def _generate():
with torch.no_grad():
self.model.set_generation_params(duration=duration)
return self.model.generate([prompt])
# Run in thread pool
wav = await loop.run_in_executor(self.executor, _generate)
return wav[0].cpu()
async def generate_batch_async(self, prompts, duration=10):
"""Process multiple prompts concurrently."""
tasks = [self.generate_async(p, duration) for p in prompts]
return await asyncio.gather(*tasks)
# Usage
service = MusicGenService()
async def main():
prompts = ["jazz piano", "rock guitar", "electronic beats"]
results = await service.generate_batch_async(prompts)
return results
```
## Integration Patterns
### LangChain tool
```python
from langchain.tools import BaseTool
import torch
import torchaudio
from audiocraft.models import MusicGen
import tempfile
class MusicGeneratorTool(BaseTool):
name = "music_generator"
description = "Generate music from a text description. Input should be a detailed description of the music style, mood, and instruments."
def __init__(self):
super().__init__()
self.model = MusicGen.get_pretrained('facebook/musicgen-small')
self.model.set_generation_params(duration=15)
def _run(self, description: str) -> str:
with torch.no_grad():
wav = self.model.generate([description])
# Save to temp file
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
torchaudio.save(f.name, wav[0].cpu(), sample_rate=32000)
return f"Generated music saved to: {f.name}"
async def _arun(self, description: str) -> str:
return self._run(description)
```
### Gradio with advanced controls
```python
import gradio as gr
import torch
import torchaudio
from audiocraft.models import MusicGen
models = {}
def load_model(model_size):
if model_size not in models:
model_name = f"facebook/musicgen-{model_size}"
models[model_size] = MusicGen.get_pretrained(model_name)
return models[model_size]
def generate(prompt, duration, temperature, cfg_coef, top_k, model_size):
model = load_model(model_size)
model.set_generation_params(
duration=duration,
temperature=temperature,
cfg_coef=cfg_coef,
top_k=top_k
)
with torch.no_grad():
wav = model.generate([prompt])
# Save
path = "output.wav"
torchaudio.save(path, wav[0].cpu(), sample_rate=32000)
return path
demo = gr.Interface(
fn=generate,
inputs=[
gr.Textbox(label="Prompt", lines=3),
gr.Slider(1, 30, value=10, label="Duration (s)"),
gr.Slider(0.1, 2.0, value=1.0, label="Temperature"),
gr.Slider(0.5, 10.0, value=3.0, label="CFG Coefficient"),
gr.Slider(50, 500, value=250, step=50, label="Top-K"),
gr.Dropdown(["small", "medium", "large"], value="small", label="Model Size")
],
outputs=gr.Audio(label="Generated Music"),
title="MusicGen Advanced",
allow_flagging="never"
)
demo.launch(share=True)
```
## Audio Processing Pipeline
### Post-processing chain
```python
import torch
import torchaudio
import torchaudio.transforms as T
import numpy as np
class AudioPostProcessor:
def __init__(self, sample_rate=32000):
self.sample_rate = sample_rate
def normalize(self, audio, target_db=-14.0):
"""Normalize audio to target loudness."""
rms = torch.sqrt(torch.mean(audio ** 2))
target_rms = 10 ** (target_db / 20)
gain = target_rms / (rms + 1e-8)
return audio * gain
def fade_in_out(self, audio, fade_duration=0.1):
"""Apply fade in/out."""
fade_samples = int(fade_duration * self.sample_rate)
# Create fade curves
fade_in = torch.linspace(0, 1, fade_samples)
fade_out = torch.linspace(1, 0, fade_samples)
# Apply fades
audio[..., :fade_samples] *= fade_in
audio[..., -fade_samples:] *= fade_out
return audio
def apply_reverb(self, audio, decay=0.5):
"""Apply simple reverb effect."""
impulse = torch.zeros(int(self.sample_rate * 0.5))
impulse[0] = 1.0
impulse[int(self.sample_rate * 0.1)] = decay * 0.5
impulse[int(self.sample_rate * 0.2)] = decay * 0.25
# Convolve
audio = torch.nn.functional.conv1d(
audio.unsqueeze(0),
impulse.unsqueeze(0).unsqueeze(0),
padding=len(impulse) // 2
).squeeze(0)
return audio
def process(self, audio):
"""Full processing pipeline."""
audio = self.normalize(audio)
audio = self.fade_in_out(audio)
return audio
# Usage with MusicGen
from audiocraft.models import MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-small')
model.set_generation_params(duration=10)
wav = model.generate(["chill ambient music"])
processor = AudioPostProcessor()
wav_processed = processor.process(wav[0].cpu())
torchaudio.save("processed.wav", wav_processed, sample_rate=32000)
```
## Evaluation
### Audio quality metrics
```python
import torch
from audiocraft.metrics import CLAPTextConsistencyMetric
from audiocraft.data.audio import audio_read
def evaluate_generation(audio_path, text_prompt):
"""Evaluate generated audio quality."""
# Load audio
wav, sr = audio_read(audio_path)
# CLAP consistency (text-audio alignment)
clap_metric = CLAPTextConsistencyMetric()
clap_score = clap_metric.compute(wav, [text_prompt])
return {
"clap_score": clap_score,
"duration": wav.shape[-1] / sr
}
# Batch evaluation
def evaluate_batch(generations):
"""Evaluate multiple generations."""
results = []
for gen in generations:
result = evaluate_generation(gen["path"], gen["prompt"])
result["prompt"] = gen["prompt"]
results.append(result)
# Aggregate
avg_clap = sum(r["clap_score"] for r in results) / len(results)
return {
"individual": results,
"average_clap": avg_clap
}
```
## Model Comparison
### MusicGen variants benchmark
| Model | CLAP Score | Generation Time (10s) | VRAM |
|-------|------------|----------------------|------|
| musicgen-small | 0.35 | ~5s | 2GB |
| musicgen-medium | 0.42 | ~15s | 4GB |
| musicgen-large | 0.48 | ~30s | 8GB |
| musicgen-melody | 0.45 | ~15s | 4GB |
| musicgen-stereo-medium | 0.41 | ~18s | 5GB |
### Prompt engineering tips
```python
# Good prompts - specific and descriptive
good_prompts = [
"upbeat electronic dance music with synthesizer leads and punchy drums at 128 bpm",
"melancholic piano ballad with strings, slow tempo, emotional and cinematic",
"funky disco groove with slap bass, brass section, and rhythmic guitar"
]
# Bad prompts - too vague
bad_prompts = [
"nice music",
"song",
"good beat"
]
# Structure: [mood] [genre] with [instruments] at [tempo/style]
```

View File

@@ -0,0 +1,504 @@
# AudioCraft Troubleshooting Guide
## Installation Issues
### Import errors
**Error**: `ModuleNotFoundError: No module named 'audiocraft'`
**Solutions**:
```bash
# Install from PyPI
pip install audiocraft
# Or from GitHub
pip install git+https://github.com/facebookresearch/audiocraft.git
# Verify installation
python -c "from audiocraft.models import MusicGen; print('OK')"
```
### FFmpeg not found
**Error**: `RuntimeError: ffmpeg not found`
**Solutions**:
```bash
# Ubuntu/Debian
sudo apt-get install ffmpeg
# macOS
brew install ffmpeg
# Windows (using conda)
conda install -c conda-forge ffmpeg
# Verify
ffmpeg -version
```
### PyTorch CUDA mismatch
**Error**: `RuntimeError: CUDA error: no kernel image is available`
**Solutions**:
```bash
# Check CUDA version
nvcc --version
python -c "import torch; print(torch.version.cuda)"
# Install matching PyTorch
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121
# For CUDA 11.8
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu118
```
### xformers issues
**Error**: `ImportError: xformers` related errors
**Solutions**:
```bash
# Install xformers for memory efficiency
pip install xformers
# Or disable xformers
export AUDIOCRAFT_USE_XFORMERS=0
# In Python
import os
os.environ["AUDIOCRAFT_USE_XFORMERS"] = "0"
from audiocraft.models import MusicGen
```
## Model Loading Issues
### Out of memory during load
**Error**: `torch.cuda.OutOfMemoryError` during model loading
**Solutions**:
```python
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Force CPU loading first
import torch
device = "cpu"
model = MusicGen.get_pretrained('facebook/musicgen-small', device=device)
model = model.to("cuda")
# Use HuggingFace with device_map
from transformers import MusicgenForConditionalGeneration
model = MusicgenForConditionalGeneration.from_pretrained(
"facebook/musicgen-small",
device_map="auto"
)
```
### Download failures
**Error**: Connection errors or incomplete downloads
**Solutions**:
```python
# Set cache directory
import os
os.environ["AUDIOCRAFT_CACHE_DIR"] = "/path/to/cache"
# Or for HuggingFace
os.environ["HF_HOME"] = "/path/to/hf_cache"
# Resume download
from huggingface_hub import snapshot_download
snapshot_download("facebook/musicgen-small", resume_download=True)
# Use local files
model = MusicGen.get_pretrained('/local/path/to/model')
```
### Wrong model type
**Error**: Loading wrong model for task
**Solutions**:
```python
# For text-to-music: use MusicGen
from audiocraft.models import MusicGen
model = MusicGen.get_pretrained('facebook/musicgen-medium')
# For text-to-sound: use AudioGen
from audiocraft.models import AudioGen
model = AudioGen.get_pretrained('facebook/audiogen-medium')
# For melody conditioning: use melody variant
model = MusicGen.get_pretrained('facebook/musicgen-melody')
# For stereo: use stereo variant
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
```
## Generation Issues
### Empty or silent output
**Problem**: Generated audio is silent or very quiet
**Solutions**:
```python
import torch
# Check output
wav = model.generate(["upbeat music"])
print(f"Shape: {wav.shape}")
print(f"Max amplitude: {wav.abs().max().item()}")
print(f"Mean amplitude: {wav.abs().mean().item()}")
# If too quiet, normalize
def normalize_audio(audio, target_db=-14.0):
rms = torch.sqrt(torch.mean(audio ** 2))
target_rms = 10 ** (target_db / 20)
gain = target_rms / (rms + 1e-8)
return audio * gain
wav_normalized = normalize_audio(wav)
```
### Poor quality output
**Problem**: Generated music sounds bad or noisy
**Solutions**:
```python
# Use larger model
model = MusicGen.get_pretrained('facebook/musicgen-large')
# Adjust generation parameters
model.set_generation_params(
duration=15,
top_k=250, # Increase for more diversity
temperature=0.8, # Lower for more focused output
cfg_coef=4.0 # Increase for better text adherence
)
# Use better prompts
# Bad: "music"
# Good: "upbeat electronic dance music with synthesizers and punchy drums"
# Try MultiBand Diffusion
from audiocraft.models import MultiBandDiffusion
mbd = MultiBandDiffusion.get_mbd_musicgen()
tokens = model.generate_tokens(["prompt"])
wav = mbd.tokens_to_wav(tokens)
```
### Generation too short
**Problem**: Audio shorter than expected
**Solutions**:
```python
# Check duration setting
model.set_generation_params(duration=30) # Set before generate
# Verify in generation
print(f"Duration setting: {model.generation_params}")
# Check output shape
wav = model.generate(["prompt"])
actual_duration = wav.shape[-1] / 32000
print(f"Actual duration: {actual_duration}s")
# Note: max duration is typically 30s
```
### Melody conditioning fails
**Error**: Issues with melody-conditioned generation
**Solutions**:
```python
import torchaudio
from audiocraft.models import MusicGen
# Load melody model (not base model)
model = MusicGen.get_pretrained('facebook/musicgen-melody')
# Load and prepare melody
melody, sr = torchaudio.load("melody.wav")
# Resample to model sample rate if needed
if sr != 32000:
resampler = torchaudio.transforms.Resample(sr, 32000)
melody = resampler(melody)
# Ensure correct shape [batch, channels, samples]
if melody.dim() == 1:
melody = melody.unsqueeze(0).unsqueeze(0)
elif melody.dim() == 2:
melody = melody.unsqueeze(0)
# Convert stereo to mono
if melody.shape[1] > 1:
melody = melody.mean(dim=1, keepdim=True)
# Generate with melody
model.set_generation_params(duration=min(melody.shape[-1] / 32000, 30))
wav = model.generate_with_chroma(["piano cover"], melody, 32000)
```
## Memory Issues
### CUDA out of memory
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
**Solutions**:
```python
import torch
# Clear cache before generation
torch.cuda.empty_cache()
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Reduce duration
model.set_generation_params(duration=10) # Instead of 30
# Generate one at a time
for prompt in prompts:
wav = model.generate([prompt])
save_audio(wav)
torch.cuda.empty_cache()
# Use CPU for very large generations
model = MusicGen.get_pretrained('facebook/musicgen-small', device="cpu")
```
### Memory leak during batch processing
**Problem**: Memory grows over time
**Solutions**:
```python
import gc
import torch
def generate_with_cleanup(model, prompts):
results = []
for prompt in prompts:
with torch.no_grad():
wav = model.generate([prompt])
results.append(wav.cpu())
# Cleanup
del wav
gc.collect()
torch.cuda.empty_cache()
return results
# Use context manager
with torch.inference_mode():
wav = model.generate(["prompt"])
```
## Audio Format Issues
### Wrong sample rate
**Problem**: Audio plays at wrong speed
**Solutions**:
```python
import torchaudio
# MusicGen outputs at 32kHz
sample_rate = 32000
# AudioGen outputs at 16kHz
sample_rate = 16000
# Always use correct rate when saving
torchaudio.save("output.wav", wav[0].cpu(), sample_rate=sample_rate)
# Resample if needed
resampler = torchaudio.transforms.Resample(32000, 44100)
wav_resampled = resampler(wav)
```
### Stereo/mono mismatch
**Problem**: Wrong number of channels
**Solutions**:
```python
# Check model type
print(f"Audio channels: {wav.shape}")
# Mono: [batch, 1, samples]
# Stereo: [batch, 2, samples]
# Convert mono to stereo
if wav.shape[1] == 1:
wav_stereo = wav.repeat(1, 2, 1)
# Convert stereo to mono
if wav.shape[1] == 2:
wav_mono = wav.mean(dim=1, keepdim=True)
# Use stereo model for stereo output
model = MusicGen.get_pretrained('facebook/musicgen-stereo-medium')
```
### Clipping and distortion
**Problem**: Audio has clipping or distortion
**Solutions**:
```python
import torch
# Check for clipping
max_val = wav.abs().max().item()
print(f"Max amplitude: {max_val}")
# Normalize to prevent clipping
if max_val > 1.0:
wav = wav / max_val
# Apply soft clipping
def soft_clip(x, threshold=0.9):
return torch.tanh(x / threshold) * threshold
wav_clipped = soft_clip(wav)
# Lower temperature during generation
model.set_generation_params(temperature=0.7) # More controlled
```
## HuggingFace Transformers Issues
### Processor errors
**Error**: Issues with MusicgenProcessor
**Solutions**:
```python
from transformers import AutoProcessor, MusicgenForConditionalGeneration
# Load matching processor and model
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
# Ensure inputs are on same device
inputs = processor(
text=["prompt"],
padding=True,
return_tensors="pt"
).to("cuda")
# Check processor configuration
print(processor.tokenizer)
print(processor.feature_extractor)
```
### Generation parameter errors
**Error**: Invalid generation parameters
**Solutions**:
```python
# HuggingFace uses different parameter names
audio_values = model.generate(
**inputs,
do_sample=True, # Enable sampling
guidance_scale=3.0, # CFG (not cfg_coef)
max_new_tokens=256, # Token limit (not duration)
temperature=1.0
)
# Calculate tokens from duration
# ~50 tokens per second
duration_seconds = 10
max_tokens = duration_seconds * 50
audio_values = model.generate(**inputs, max_new_tokens=max_tokens)
```
## Performance Issues
### Slow generation
**Problem**: Generation takes too long
**Solutions**:
```python
# Use smaller model
model = MusicGen.get_pretrained('facebook/musicgen-small')
# Reduce duration
model.set_generation_params(duration=10)
# Use GPU
model.to("cuda")
# Enable flash attention if available
# (requires compatible hardware)
# Batch multiple prompts
prompts = ["prompt1", "prompt2", "prompt3"]
wav = model.generate(prompts) # Single batch is faster than loop
# Use compile (PyTorch 2.0+)
model.lm = torch.compile(model.lm)
```
### CPU fallback
**Problem**: Generation running on CPU instead of GPU
**Solutions**:
```python
import torch
# Check CUDA availability
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
# Explicitly move to GPU
model = MusicGen.get_pretrained('facebook/musicgen-small')
model.to("cuda")
# Verify model device
print(f"Model device: {next(model.lm.parameters()).device}")
```
## Common Error Messages
| Error | Cause | Solution |
|-------|-------|----------|
| `CUDA out of memory` | Model too large | Use smaller model, reduce duration |
| `ffmpeg not found` | FFmpeg not installed | Install FFmpeg |
| `No module named 'audiocraft'` | Not installed | `pip install audiocraft` |
| `RuntimeError: Expected 3D tensor` | Wrong input shape | Check tensor dimensions |
| `KeyError: 'melody'` | Wrong model for melody | Use musicgen-melody |
| `Sample rate mismatch` | Wrong audio format | Resample to model rate |
## Getting Help
1. **GitHub Issues**: https://github.com/facebookresearch/audiocraft/issues
2. **HuggingFace Forums**: https://discuss.huggingface.co
3. **Paper**: https://arxiv.org/abs/2306.05284
### Reporting Issues
Include:
- Python version
- PyTorch version
- CUDA version
- AudioCraft version: `pip show audiocraft`
- Full error traceback
- Minimal reproducible code
- Hardware (GPU model, VRAM)

View File

@@ -0,0 +1,81 @@
---
name: code-review
description: Guidelines for performing thorough code reviews with security and quality focus
---
# Code Review Skill
Use this skill when reviewing code changes, pull requests, or auditing existing code.
## Review Checklist
### 1. Security First
- [ ] No hardcoded secrets, API keys, or credentials
- [ ] Input validation on all user-provided data
- [ ] SQL queries use parameterized statements (no string concatenation)
- [ ] File operations validate paths (no path traversal)
- [ ] Authentication/authorization checks present where needed
### 2. Error Handling
- [ ] All external calls (API, DB, file) have try/catch
- [ ] Errors are logged with context (but no sensitive data)
- [ ] User-facing errors are helpful but don't leak internals
- [ ] Resources are cleaned up in finally blocks or context managers
### 3. Code Quality
- [ ] Functions do one thing and are reasonably sized (<50 lines ideal)
- [ ] Variable names are descriptive (no single letters except loops)
- [ ] No commented-out code left behind
- [ ] Complex logic has explanatory comments
- [ ] No duplicate code (DRY principle)
### 4. Testing Considerations
- [ ] Edge cases handled (empty inputs, nulls, boundaries)
- [ ] Happy path and error paths both work
- [ ] New code has corresponding tests (if test suite exists)
## Review Response Format
When providing review feedback, structure it as:
```
## Summary
[1-2 sentence overall assessment]
## Critical Issues (Must Fix)
- Issue 1: [description + suggested fix]
- Issue 2: ...
## Suggestions (Nice to Have)
- Suggestion 1: [description]
## Questions
- [Any clarifying questions about intent]
```
## Common Patterns to Flag
### Python
```python
# Bad: SQL injection risk
cursor.execute(f"SELECT * FROM users WHERE id = {user_id}")
# Good: Parameterized query
cursor.execute("SELECT * FROM users WHERE id = ?", (user_id,))
```
### JavaScript
```javascript
// Bad: XSS risk
element.innerHTML = userInput;
// Good: Safe text content
element.textContent = userInput;
```
## Tone Guidelines
- Be constructive, not critical
- Explain *why* something is an issue, not just *what*
- Offer solutions, not just problems
- Acknowledge good patterns you see

224
skills/mlops/faiss/SKILL.md Normal file
View File

@@ -0,0 +1,224 @@
---
name: faiss
description: Facebook's library for efficient similarity search and clustering of dense vectors. Supports billions of vectors, GPU acceleration, and various index types (Flat, IVF, HNSW). Use for fast k-NN search, large-scale vector retrieval, or when you need pure similarity search without metadata. Best for high-performance applications.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [faiss-cpu, faiss-gpu, numpy]
metadata:
hermes:
tags: [RAG, FAISS, Similarity Search, Vector Search, Facebook AI, GPU Acceleration, Billion-Scale, K-NN, HNSW, High Performance, Large Scale]
---
# FAISS - Efficient Similarity Search
Facebook AI's library for billion-scale vector similarity search.
## When to use FAISS
**Use FAISS when:**
- Need fast similarity search on large vector datasets (millions/billions)
- GPU acceleration required
- Pure vector similarity (no metadata filtering needed)
- High throughput, low latency critical
- Offline/batch processing of embeddings
**Metrics**:
- **31,700+ GitHub stars**
- Meta/Facebook AI Research
- **Handles billions of vectors**
- **C++** with Python bindings
**Use alternatives instead**:
- **Chroma/Pinecone**: Need metadata filtering
- **Weaviate**: Need full database features
- **Annoy**: Simpler, fewer features
## Quick start
### Installation
```bash
# CPU only
pip install faiss-cpu
# GPU support
pip install faiss-gpu
```
### Basic usage
```python
import faiss
import numpy as np
# Create sample data (1000 vectors, 128 dimensions)
d = 128
nb = 1000
vectors = np.random.random((nb, d)).astype('float32')
# Create index
index = faiss.IndexFlatL2(d) # L2 distance
index.add(vectors) # Add vectors
# Search
k = 5 # Find 5 nearest neighbors
query = np.random.random((1, d)).astype('float32')
distances, indices = index.search(query, k)
print(f"Nearest neighbors: {indices}")
print(f"Distances: {distances}")
```
## Index types
### 1. Flat (exact search)
```python
# L2 (Euclidean) distance
index = faiss.IndexFlatL2(d)
# Inner product (cosine similarity if normalized)
index = faiss.IndexFlatIP(d)
# Slowest, most accurate
```
### 2. IVF (inverted file) - Fast approximate
```python
# Create quantizer
quantizer = faiss.IndexFlatL2(d)
# IVF index with 100 clusters
nlist = 100
index = faiss.IndexIVFFlat(quantizer, d, nlist)
# Train on data
index.train(vectors)
# Add vectors
index.add(vectors)
# Search (nprobe = clusters to search)
index.nprobe = 10
distances, indices = index.search(query, k)
```
### 3. HNSW (Hierarchical NSW) - Best quality/speed
```python
# HNSW index
M = 32 # Number of connections per layer
index = faiss.IndexHNSWFlat(d, M)
# No training needed
index.add(vectors)
# Search
distances, indices = index.search(query, k)
```
### 4. Product Quantization - Memory efficient
```python
# PQ reduces memory by 16-32×
m = 8 # Number of subquantizers
nbits = 8
index = faiss.IndexPQ(d, m, nbits)
# Train and add
index.train(vectors)
index.add(vectors)
```
## Save and load
```python
# Save index
faiss.write_index(index, "large.index")
# Load index
index = faiss.read_index("large.index")
# Continue using
distances, indices = index.search(query, k)
```
## GPU acceleration
```python
# Single GPU
res = faiss.StandardGpuResources()
index_cpu = faiss.IndexFlatL2(d)
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu) # GPU 0
# Multi-GPU
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
# 10-100× faster than CPU
```
## LangChain integration
```python
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
# Create FAISS vector store
vectorstore = FAISS.from_documents(docs, OpenAIEmbeddings())
# Save
vectorstore.save_local("faiss_index")
# Load
vectorstore = FAISS.load_local(
"faiss_index",
OpenAIEmbeddings(),
allow_dangerous_deserialization=True
)
# Search
results = vectorstore.similarity_search("query", k=5)
```
## LlamaIndex integration
```python
from llama_index.vector_stores.faiss import FaissVectorStore
import faiss
# Create FAISS index
d = 1536
faiss_index = faiss.IndexFlatL2(d)
vector_store = FaissVectorStore(faiss_index=faiss_index)
```
## Best practices
1. **Choose right index type** - Flat for <10K, IVF for 10K-1M, HNSW for quality
2. **Normalize for cosine** - Use IndexFlatIP with normalized vectors
3. **Use GPU for large datasets** - 10-100× faster
4. **Save trained indices** - Training is expensive
5. **Tune nprobe/ef_search** - Balance speed/accuracy
6. **Monitor memory** - PQ for large datasets
7. **Batch queries** - Better GPU utilization
## Performance
| Index Type | Build Time | Search Time | Memory | Accuracy |
|------------|------------|-------------|--------|----------|
| Flat | Fast | Slow | High | 100% |
| IVF | Medium | Fast | Medium | 95-99% |
| HNSW | Slow | Fastest | High | 99% |
| PQ | Medium | Fast | Low | 90-95% |
## Resources
- **GitHub**: https://github.com/facebookresearch/faiss ⭐ 31,700+
- **Wiki**: https://github.com/facebookresearch/faiss/wiki
- **License**: MIT

View File

@@ -0,0 +1,280 @@
# FAISS Index Types Guide
Complete guide to choosing and using FAISS index types.
## Index selection guide
| Dataset Size | Index Type | Training | Accuracy | Speed |
|--------------|------------|----------|----------|-------|
| < 10K | Flat | No | 100% | Slow |
| 10K-1M | IVF | Yes | 95-99% | Fast |
| 1M-10M | HNSW | No | 99% | Fastest |
| > 10M | IVF+PQ | Yes | 90-95% | Fast, low memory |
## Flat indices (exact search)
### IndexFlatL2 - L2 (Euclidean) distance
```python
import faiss
import numpy as np
d = 128 # Dimension
index = faiss.IndexFlatL2(d)
# Add vectors
vectors = np.random.random((1000, d)).astype('float32')
index.add(vectors)
# Search
k = 5
query = np.random.random((1, d)).astype('float32')
distances, indices = index.search(query, k)
```
**Use when:**
- Dataset < 10,000 vectors
- Need 100% accuracy
- Serving as baseline
### IndexFlatIP - Inner product (cosine similarity)
```python
# For cosine similarity, normalize vectors first
import faiss
d = 128
index = faiss.IndexFlatIP(d)
# Normalize vectors (required for cosine similarity)
faiss.normalize_L2(vectors)
index.add(vectors)
# Search
faiss.normalize_L2(query)
distances, indices = index.search(query, k)
```
**Use when:**
- Need cosine similarity
- Recommendation systems
- Text embeddings
## IVF indices (inverted file)
### IndexIVFFlat - Cluster-based search
```python
# Create quantizer
quantizer = faiss.IndexFlatL2(d)
# Create IVF index with 100 clusters
nlist = 100 # Number of clusters
index = faiss.IndexIVFFlat(quantizer, d, nlist)
# Train on data (required!)
index.train(vectors)
# Add vectors
index.add(vectors)
# Search (nprobe = clusters to search)
index.nprobe = 10 # Search 10 closest clusters
distances, indices = index.search(query, k)
```
**Parameters:**
- `nlist`: Number of clusters (√N to 4√N recommended)
- `nprobe`: Clusters to search (1-nlist, higher = more accurate)
**Use when:**
- Dataset 10K-1M vectors
- Need fast approximate search
- Can afford training time
### Tuning nprobe
```python
# Test different nprobe values
for nprobe in [1, 5, 10, 20, 50]:
index.nprobe = nprobe
distances, indices = index.search(query, k)
# Measure recall/speed trade-off
```
**Guidelines:**
- `nprobe=1`: Fastest, ~50% recall
- `nprobe=10`: Good balance, ~95% recall
- `nprobe=nlist`: Exact search (same as Flat)
## HNSW indices (graph-based)
### IndexHNSWFlat - Hierarchical NSW
```python
# HNSW index
M = 32 # Number of connections per layer (16-64)
index = faiss.IndexHNSWFlat(d, M)
# Optional: Set ef_construction (build time parameter)
index.hnsw.efConstruction = 40 # Higher = better quality, slower build
# Add vectors (no training needed!)
index.add(vectors)
# Search
index.hnsw.efSearch = 16 # Search time parameter
distances, indices = index.search(query, k)
```
**Parameters:**
- `M`: Connections per layer (16-64, default 32)
- `efConstruction`: Build quality (40-200, higher = better)
- `efSearch`: Search quality (16-512, higher = more accurate)
**Use when:**
- Need best quality approximate search
- Can afford higher memory (more connections)
- Dataset 1M-10M vectors
## PQ indices (product quantization)
### IndexPQ - Memory-efficient
```python
# PQ reduces memory by 16-32×
m = 8 # Number of subquantizers (divides d)
nbits = 8 # Bits per subquantizer
index = faiss.IndexPQ(d, m, nbits)
# Train (required!)
index.train(vectors)
# Add vectors
index.add(vectors)
# Search
distances, indices = index.search(query, k)
```
**Parameters:**
- `m`: Subquantizers (d must be divisible by m)
- `nbits`: Bits per code (8 or 16)
**Memory savings:**
- Original: d × 4 bytes (float32)
- PQ: m bytes
- Compression ratio: 4d/m
**Use when:**
- Limited memory
- Large datasets (> 10M vectors)
- Can accept ~90-95% accuracy
### IndexIVFPQ - IVF + PQ combined
```python
# Best for very large datasets
nlist = 4096
m = 8
nbits = 8
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits)
# Train
index.train(vectors)
index.add(vectors)
# Search
index.nprobe = 32
distances, indices = index.search(query, k)
```
**Use when:**
- Dataset > 10M vectors
- Need fast search + low memory
- Can accept 90-95% accuracy
## GPU indices
### Single GPU
```python
import faiss
# Create CPU index
index_cpu = faiss.IndexFlatL2(d)
# Move to GPU
res = faiss.StandardGpuResources() # GPU resources
index_gpu = faiss.index_cpu_to_gpu(res, 0, index_cpu) # GPU 0
# Use normally
index_gpu.add(vectors)
distances, indices = index_gpu.search(query, k)
```
### Multi-GPU
```python
# Use all available GPUs
index_gpu = faiss.index_cpu_to_all_gpus(index_cpu)
# Or specific GPUs
gpus = [0, 1, 2, 3] # Use GPUs 0-3
index_gpu = faiss.index_cpu_to_gpus_list(index_cpu, gpus)
```
**Speedup:**
- Single GPU: 10-50× faster than CPU
- Multi-GPU: Near-linear scaling
## Index factory
```python
# Easy index creation with string descriptors
index = faiss.index_factory(d, "IVF100,Flat")
index = faiss.index_factory(d, "HNSW32")
index = faiss.index_factory(d, "IVF4096,PQ8")
# Train and use
index.train(vectors)
index.add(vectors)
```
**Common descriptors:**
- `"Flat"`: Exact search
- `"IVF100,Flat"`: IVF with 100 clusters
- `"HNSW32"`: HNSW with M=32
- `"IVF4096,PQ8"`: IVF + PQ compression
## Performance comparison
### Search speed (1M vectors, k=10)
| Index | Build Time | Search Time | Memory | Recall |
|-------|------------|-------------|--------|--------|
| Flat | 0s | 50ms | 512 MB | 100% |
| IVF100 | 5s | 2ms | 512 MB | 95% |
| HNSW32 | 60s | 1ms | 1GB | 99% |
| IVF4096+PQ8 | 30s | 3ms | 32 MB | 90% |
*CPU (16 cores), 128-dim vectors*
## Best practices
1. **Start with Flat** - Baseline for comparison
2. **Use IVF for medium datasets** - Good balance
3. **Use HNSW for best quality** - If memory allows
4. **Add PQ for memory savings** - Large datasets
5. **GPU for > 100K vectors** - 10-50× speedup
6. **Tune nprobe/efSearch** - Trade-off speed/accuracy
7. **Train on representative data** - Better clustering
8. **Save trained indices** - Avoid retraining
## Resources
- **Wiki**: https://github.com/facebookresearch/faiss/wiki
- **Paper**: https://arxiv.org/abs/1702.08734

View File

@@ -0,0 +1,370 @@
---
name: optimizing-attention-flash
description: Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training/running transformers with long sequences (>512 tokens), encountering GPU memory issues with attention, or need faster inference. Supports PyTorch native SDPA, flash-attn library, H100 FP8, and sliding window attention.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [flash-attn, torch, transformers]
metadata:
hermes:
tags: [Optimization, Flash Attention, Attention Optimization, Memory Efficiency, Speed Optimization, Long Context, PyTorch, SDPA, H100, FP8, Transformers]
---
# Flash Attention - Fast Memory-Efficient Attention
## Quick start
Flash Attention provides 2-4x speedup and 10-20x memory reduction for transformer attention through IO-aware tiling and recomputation.
**PyTorch native (easiest, PyTorch 2.2+)**:
```python
import torch
import torch.nn.functional as F
q = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16) # [batch, heads, seq, dim]
k = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 8, 512, 64, device='cuda', dtype=torch.float16)
# Automatically uses Flash Attention if available
out = F.scaled_dot_product_attention(q, k, v)
```
**flash-attn library (more features)**:
```bash
pip install flash-attn --no-build-isolation
```
```python
from flash_attn import flash_attn_func
# q, k, v: [batch, seqlen, nheads, headdim]
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
```
## Common workflows
### Workflow 1: Enable in existing PyTorch model
Copy this checklist:
```
Flash Attention Integration:
- [ ] Step 1: Check PyTorch version (≥2.2)
- [ ] Step 2: Enable Flash Attention backend
- [ ] Step 3: Verify speedup with profiling
- [ ] Step 4: Test accuracy matches baseline
```
**Step 1: Check PyTorch version**
```bash
python -c "import torch; print(torch.__version__)"
# Should be ≥2.2.0
```
If <2.2, upgrade:
```bash
pip install --upgrade torch
```
**Step 2: Enable Flash Attention backend**
Replace standard attention:
```python
# Before (standard attention)
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
out = attn_weights @ v
# After (Flash Attention)
import torch.nn.functional as F
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```
Force Flash Attention backend:
```python
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=False
):
out = F.scaled_dot_product_attention(q, k, v)
```
**Step 3: Verify speedup with profiling**
```python
import torch.utils.benchmark as benchmark
def test_attention(use_flash):
q, k, v = [torch.randn(2, 8, 2048, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
if use_flash:
with torch.backends.cuda.sdp_kernel(enable_flash=True):
return F.scaled_dot_product_attention(q, k, v)
else:
attn = (q @ k.transpose(-2, -1) / 8.0).softmax(dim=-1)
return attn @ v
# Benchmark
t_flash = benchmark.Timer(stmt='test_attention(True)', globals=globals())
t_standard = benchmark.Timer(stmt='test_attention(False)', globals=globals())
print(f"Flash: {t_flash.timeit(100).mean:.3f}s")
print(f"Standard: {t_standard.timeit(100).mean:.3f}s")
```
Expected: 2-4x speedup for sequences >512 tokens.
**Step 4: Test accuracy matches baseline**
```python
# Compare outputs
q, k, v = [torch.randn(1, 8, 512, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
# Flash Attention
out_flash = F.scaled_dot_product_attention(q, k, v)
# Standard attention
attn_weights = torch.softmax(q @ k.transpose(-2, -1) / 8.0, dim=-1)
out_standard = attn_weights @ v
# Check difference
diff = (out_flash - out_standard).abs().max()
print(f"Max difference: {diff:.6f}")
# Should be <1e-3 for float16
```
### Workflow 2: Use flash-attn library for advanced features
For multi-query attention, sliding window, or H100 FP8.
Copy this checklist:
```
flash-attn Library Setup:
- [ ] Step 1: Install flash-attn library
- [ ] Step 2: Modify attention code
- [ ] Step 3: Enable advanced features
- [ ] Step 4: Benchmark performance
```
**Step 1: Install flash-attn library**
```bash
# NVIDIA GPUs (CUDA 12.0+)
pip install flash-attn --no-build-isolation
# Verify installation
python -c "from flash_attn import flash_attn_func; print('Success')"
```
**Step 2: Modify attention code**
```python
from flash_attn import flash_attn_func
# Input: [batch_size, seq_len, num_heads, head_dim]
# Transpose from [batch, heads, seq, dim] if needed
q = q.transpose(1, 2) # [batch, seq, heads, dim]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = flash_attn_func(
q, k, v,
dropout_p=0.1,
causal=True, # For autoregressive models
window_size=(-1, -1), # No sliding window
softmax_scale=None # Auto-scale
)
out = out.transpose(1, 2) # Back to [batch, heads, seq, dim]
```
**Step 3: Enable advanced features**
Multi-query attention (shared K/V across heads):
```python
from flash_attn import flash_attn_func
# q: [batch, seq, num_q_heads, dim]
# k, v: [batch, seq, num_kv_heads, dim] # Fewer KV heads
out = flash_attn_func(q, k, v) # Automatically handles MQA
```
Sliding window attention (local attention):
```python
# Only attend to window of 256 tokens before/after
out = flash_attn_func(
q, k, v,
window_size=(256, 256), # (left, right) window
causal=True
)
```
**Step 4: Benchmark performance**
```python
import torch
from flash_attn import flash_attn_func
import time
q, k, v = [torch.randn(4, 4096, 32, 64, device='cuda', dtype=torch.float16) for _ in range(3)]
# Warmup
for _ in range(10):
_ = flash_attn_func(q, k, v)
# Benchmark
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
out = flash_attn_func(q, k, v)
torch.cuda.synchronize()
end = time.time()
print(f"Time per iteration: {(end-start)/100*1000:.2f}ms")
print(f"Memory allocated: {torch.cuda.max_memory_allocated()/1e9:.2f}GB")
```
### Workflow 3: H100 FP8 optimization (FlashAttention-3)
For maximum performance on H100 GPUs.
```
FP8 Setup:
- [ ] Step 1: Verify H100 GPU available
- [ ] Step 2: Install flash-attn with FP8 support
- [ ] Step 3: Convert inputs to FP8
- [ ] Step 4: Run with FP8 attention
```
**Step 1: Verify H100 GPU**
```bash
nvidia-smi --query-gpu=name --format=csv
# Should show "H100" or "H800"
```
**Step 2: Install flash-attn with FP8 support**
```bash
pip install flash-attn --no-build-isolation
# FP8 support included for H100
```
**Step 3: Convert inputs to FP8**
```python
import torch
q = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 4096, 32, 64, device='cuda', dtype=torch.float16)
# Convert to float8_e4m3 (FP8)
q_fp8 = q.to(torch.float8_e4m3fn)
k_fp8 = k.to(torch.float8_e4m3fn)
v_fp8 = v.to(torch.float8_e4m3fn)
```
**Step 4: Run with FP8 attention**
```python
from flash_attn import flash_attn_func
# FlashAttention-3 automatically uses FP8 kernels on H100
out = flash_attn_func(q_fp8, k_fp8, v_fp8)
# Result: ~1.2 PFLOPS, 1.5-2x faster than FP16
```
## When to use vs alternatives
**Use Flash Attention when:**
- Training transformers with sequences >512 tokens
- Running inference with long context (>2K tokens)
- GPU memory constrained (OOM with standard attention)
- Need 2-4x speedup without accuracy loss
- Using PyTorch 2.2+ or can install flash-attn
**Use alternatives instead:**
- **Standard attention**: Sequences <256 tokens (overhead not worth it)
- **xFormers**: Need more attention variants (not just speed)
- **Memory-efficient attention**: CPU inference (Flash Attention needs GPU)
## Common issues
**Issue: ImportError: cannot import flash_attn**
Install with no-build-isolation flag:
```bash
pip install flash-attn --no-build-isolation
```
Or install CUDA toolkit first:
```bash
conda install cuda -c nvidia
pip install flash-attn --no-build-isolation
```
**Issue: Slower than expected (no speedup)**
Flash Attention benefits increase with sequence length:
- <512 tokens: Minimal speedup (10-20%)
- 512-2K tokens: 2-3x speedup
- >2K tokens: 3-4x speedup
Check sequence length is sufficient.
**Issue: RuntimeError: CUDA error**
Verify GPU supports Flash Attention:
```python
import torch
print(torch.cuda.get_device_capability())
# Should be ≥(7, 5) for Turing+
```
Flash Attention requires:
- Ampere (A100, A10): ✅ Full support
- Turing (T4): ✅ Supported
- Volta (V100): ❌ Not supported
**Issue: Accuracy degradation**
Check dtype is float16 or bfloat16 (not float32):
```python
q = q.to(torch.float16) # Or torch.bfloat16
```
Flash Attention uses float16/bfloat16 for speed. Float32 not supported.
## Advanced topics
**Integration with HuggingFace Transformers**: See [references/transformers-integration.md](references/transformers-integration.md) for enabling Flash Attention in BERT, GPT, Llama models.
**Performance benchmarks**: See [references/benchmarks.md](references/benchmarks.md) for detailed speed and memory comparisons across GPUs and sequence lengths.
**Algorithm details**: See [references/algorithm.md](references/algorithm.md) for tiling strategy, recomputation, and IO complexity analysis.
**Advanced features**: See [references/advanced-features.md](references/advanced-features.md) for rotary embeddings, ALiBi, paged KV cache, and custom attention masks.
## Hardware requirements
- **GPU**: NVIDIA Ampere+ (A100, A10, A30) or AMD MI200+
- **VRAM**: Same as standard attention (Flash Attention doesn't increase memory)
- **CUDA**: 12.0+ (11.8 minimum)
- **PyTorch**: 2.2+ for native support
**Not supported**: V100 (Volta), CPU inference
## Resources
- Paper: "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (NeurIPS 2022)
- Paper: "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (ICLR 2024)
- Blog: https://tridao.me/blog/2024/flash3/
- GitHub: https://github.com/Dao-AILab/flash-attention
- PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

View File

@@ -0,0 +1,215 @@
# Performance Benchmarks
## Contents
- Speed comparisons across GPUs
- Memory usage analysis
- Scaling with sequence length
- Training vs inference performance
- Flash Attention versions comparison
## Speed comparisons across GPUs
### A100 80GB (Ampere)
**Forward pass time** (milliseconds, batch=8, heads=32, dim=64):
| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 | Speedup (FA2) |
|------------|----------|--------------|--------------|---------------|
| 512 | 1.2 | 0.9 | N/A | 1.3x |
| 1024 | 3.8 | 1.4 | N/A | 2.7x |
| 2048 | 14.2 | 4.8 | N/A | 3.0x |
| 4096 | 55.1 | 17.3 | N/A | 3.2x |
| 8192 | 218.5 | 66.2 | N/A | 3.3x |
### H100 80GB (Hopper)
**Forward pass time** (milliseconds, same config):
| Seq Length | Standard | Flash Attn 2 | Flash Attn 3 (FP16) | Flash Attn 3 (FP8) | Best Speedup |
|------------|----------|--------------|---------------------|--------------------|--------------|
| 512 | 0.8 | 0.6 | 0.4 | 0.3 | 2.7x |
| 1024 | 2.6 | 1.0 | 0.6 | 0.4 | 6.5x |
| 2048 | 9.8 | 3.4 | 2.0 | 1.3 | 7.5x |
| 4096 | 38.2 | 12.5 | 7.2 | 4.8 | 8.0x |
| 8192 | 151.4 | 47.8 | 27.1 | 18.2 | 8.3x |
**Key insight**: Flash Attention 3 on H100 with FP8 achieves ~1.2 PFLOPS (75% of theoretical max).
### A10G 24GB (Ampere)
**Forward pass time** (milliseconds, batch=4):
| Seq Length | Standard | Flash Attn 2 | Speedup |
|------------|----------|--------------|---------|
| 512 | 2.1 | 1.6 | 1.3x |
| 1024 | 6.8 | 2.8 | 2.4x |
| 2048 | 25.9 | 9.4 | 2.8x |
| 4096 | 102.1 | 35.2 | 2.9x |
## Memory usage analysis
### GPU memory consumption (batch=8, heads=32, dim=64)
**Standard attention memory**:
| Seq Length | Attention Matrix | KV Cache | Total | Notes |
|------------|------------------|----------|-------|-------|
| 512 | 8 MB | 32 MB | 40 MB | Manageable |
| 2048 | 128 MB | 128 MB | 256 MB | Growing |
| 8192 | 2048 MB (2 GB) | 512 MB | 2.5 GB | Large |
| 32768 | 32768 MB (32 GB) | 2048 MB | 34 GB | OOM on 24GB GPUs |
**Flash Attention 2 memory**:
| Seq Length | Attention (on-chip) | KV Cache | Total | Reduction |
|------------|---------------------|----------|-------|-----------|
| 512 | 0 MB (recomputed) | 32 MB | 32 MB | 20% |
| 2048 | 0 MB | 128 MB | 128 MB | 50% |
| 8192 | 0 MB | 512 MB | 512 MB | 80% |
| 32768 | 0 MB | 2048 MB | 2 GB | 94% |
**Key insight**: Flash Attention doesn't materialize attention matrix, saving O(N²) memory.
### Memory scaling comparison
**Llama 2 7B model memory** (float16, batch=1):
| Context Length | Standard Attention | Flash Attention 2 | Can Fit 24GB GPU? |
|----------------|-------------------|-------------------|-------------------|
| 2K | 3.2 GB | 2.1 GB | Both: Yes |
| 4K | 5.8 GB | 2.8 GB | Both: Yes |
| 8K | 12.1 GB | 4.2 GB | Both: Yes |
| 16K | 26.3 GB (OOM) | 7.8 GB | Only Flash: Yes |
| 32K | OOM | 14.2 GB | Only Flash: Yes |
### Training memory (Llama 2 7B, batch=4)
| Context | Standard (GB) | Flash Attn (GB) | Reduction |
|---------|---------------|-----------------|-----------|
| 2K | 18.2 | 12.4 | 32% |
| 4K | 34.8 | 16.8 | 52% |
| 8K | OOM (>40GB) | 26.2 | Fits! |
## Scaling with sequence length
### Computational complexity
**Standard attention**:
- Time: O(N² × d)
- Memory: O(N² + N × d)
**Flash Attention**:
- Time: O(N² × d) (same, but with better constants)
- Memory: O(N × d) (linear!)
### Empirical scaling (A100, batch=1, heads=32, dim=64)
**Time per token (milliseconds)**:
| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|----------|-----|-----|-----|-----|-----|------|
| Standard | 0.15 | 0.37 | 1.11 | 3.44 | 13.4 | 52.8 |
| Flash Attn 2 | 0.11 | 0.14 | 0.24 | 0.43 | 0.83 | 1.64 |
| Speedup | 1.4x | 2.6x | 4.6x | 8.0x | 16.1x | 32.2x |
**Observation**: Speedup increases quadratically with sequence length!
### Memory per token (MB)
| Sequence | 512 | 1K | 2K | 4K | 8K | 16K |
|----------|-----|-----|-----|-----|-----|------|
| Standard | 0.08 | 0.13 | 0.25 | 0.64 | 2.05 | 8.13 |
| Flash Attn 2 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 | 0.06 |
**Observation**: Flash Attention memory per token is constant!
## Training vs inference performance
### Training (forward + backward, Llama 2 7B, A100)
| Batch × Seq | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|-------------|------------------------|--------------------------|---------|
| 4 × 2K | 1.2 | 3.1 | 2.6x |
| 8 × 2K | 2.1 | 5.8 | 2.8x |
| 4 × 4K | 0.4 | 1.3 | 3.3x |
| 8 × 4K | OOM | 2.4 | Enabled |
| 2 × 8K | 0.1 | 0.4 | 4.0x |
### Inference (generation, Llama 2 7B, A100)
| Context Length | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|----------------|----------------------|-------------------------|---------|
| 512 | 48 | 52 | 1.1x |
| 2K | 42 | 62 | 1.5x |
| 4K | 31 | 58 | 1.9x |
| 8K | 18 | 51 | 2.8x |
| 16K | OOM | 42 | Enabled |
**Note**: Inference speedup less dramatic than training because generation is memory-bound (KV cache accesses).
## Flash Attention versions comparison
### Flash Attention 1 vs 2 vs 3 (H100, seq=4096, batch=8)
| Metric | FA1 | FA2 | FA3 (FP16) | FA3 (FP8) |
|--------|-----|-----|------------|-----------|
| Forward time (ms) | 28.4 | 12.5 | 7.2 | 4.8 |
| Memory (GB) | 4.8 | 4.2 | 4.2 | 2.8 |
| TFLOPS | 180 | 420 | 740 | 1150 |
| GPU util % | 35% | 55% | 75% | 82% |
**Key improvements**:
- FA2: 2.3x faster than FA1 (better parallelism)
- FA3 (FP16): 1.7x faster than FA2 (H100 async optimizations)
- FA3 (FP8): 2.6x faster than FA2 (low precision)
### Features by version
| Feature | FA1 | FA2 | FA3 |
|---------|-----|-----|-----|
| Basic attention | ✅ | ✅ | ✅ |
| Causal masking | ✅ | ✅ | ✅ |
| Multi-query attention | ❌ | ✅ | ✅ |
| Sliding window | ❌ | ✅ | ✅ |
| Paged KV cache | ❌ | ✅ | ✅ |
| FP8 support | ❌ | ❌ | ✅ (H100 only) |
| Work partitioning | Basic | Advanced | Optimal |
## Real-world model benchmarks
### Llama 2 models (A100 80GB, batch=4, seq=2048)
| Model | Params | Standard (samples/sec) | Flash Attn (samples/sec) | Speedup |
|-------|--------|------------------------|--------------------------|---------|
| Llama 2 7B | 7B | 1.2 | 3.1 | 2.6x |
| Llama 2 13B | 13B | 0.6 | 1.7 | 2.8x |
| Llama 2 70B | 70B | 0.12 | 0.34 | 2.8x |
### GPT-style models (seq=1024)
| Model | Standard (tokens/sec) | Flash Attn (tokens/sec) | Speedup |
|-------|----------------------|-------------------------|---------|
| GPT-2 (124M) | 520 | 680 | 1.3x |
| GPT-J (6B) | 42 | 98 | 2.3x |
| GPT-NeoX (20B) | 8 | 22 | 2.75x |
## Recommendations by use case
**Training large models (>7B parameters)**:
- Use Flash Attention 2 on A100
- Use Flash Attention 3 FP8 on H100 for maximum speed
- Expected: 2.5-3x speedup
**Long context inference (>4K tokens)**:
- Flash Attention essential (enables contexts standard attention can't handle)
- Expected: 2-4x speedup, 5-10x memory reduction
**Short sequences (<512 tokens)**:
- Flash Attention provides 1.2-1.5x speedup
- Minimal memory benefit
- Still worth enabling (no downside)
**Multi-user serving**:
- Flash Attention reduces per-request memory
- Allows higher concurrent batch sizes
- Can serve 2-3x more users on same hardware

View File

@@ -0,0 +1,293 @@
# HuggingFace Transformers Integration
## Contents
- Enabling Flash Attention in Transformers
- Supported model architectures
- Configuration examples
- Performance comparisons
- Troubleshooting model-specific issues
## Enabling Flash Attention in Transformers
HuggingFace Transformers (v4.36+) supports Flash Attention 2 natively.
**Simple enable for any supported model**:
```python
from transformers import AutoModel
model = AutoModel.from_pretrained(
"meta-llama/Llama-2-7b-hf",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16,
device_map="auto"
)
```
**Install requirements**:
```bash
pip install transformers>=4.36
pip install flash-attn --no-build-isolation
```
## Supported model architectures
As of Transformers 4.40:
**Fully supported**:
- Llama / Llama 2 / Llama 3
- Mistral / Mixtral
- Falcon
- GPT-NeoX
- Phi / Phi-2 / Phi-3
- Qwen / Qwen2
- Gemma
- Starcoder2
- GPT-J
- OPT
- BLOOM
**Partially supported** (encoder-decoder):
- BART
- T5 / Flan-T5
- Whisper
**Check support**:
```python
from transformers import AutoConfig
config = AutoConfig.from_pretrained("model-name")
print(config._attn_implementation_internal)
# 'flash_attention_2' if supported
```
## Configuration examples
### Llama 2 with Flash Attention
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
model_id = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(
model_id,
attn_implementation="flash_attention_2",
torch_dtype=torch.float16,
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Generate
inputs = tokenizer("Once upon a time", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_length=100)
print(tokenizer.decode(outputs[0]))
```
### Mistral with Flash Attention for long context
```python
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-7B-v0.1",
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16, # Better for long context
device_map="auto",
max_position_embeddings=32768 # Extended context
)
# Process long document (32K tokens)
long_text = "..." * 10000
inputs = tokenizer(long_text, return_tensors="pt", truncation=False).to("cuda")
outputs = model.generate(**inputs, max_new_tokens=512)
```
### Fine-tuning with Flash Attention
```python
from transformers import Trainer, TrainingArguments
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16
)
training_args = TrainingArguments(
output_dir="./results",
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
num_train_epochs=3,
fp16=True, # Must match model dtype
optim="adamw_torch_fused" # Fast optimizer
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset
)
trainer.train()
```
### Multi-GPU training
```python
from transformers import AutoModelForCausalLM
import torch
# Model parallelism with Flash Attention
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-13b-hf",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16,
device_map="auto", # Automatic multi-GPU placement
max_memory={0: "20GB", 1: "20GB"} # Limit per GPU
)
```
## Performance comparisons
### Memory usage (Llama 2 7B, batch=1)
| Sequence Length | Standard Attention | Flash Attention 2 | Reduction |
|-----------------|-------------------|-------------------|-----------|
| 512 | 1.2 GB | 0.9 GB | 25% |
| 2048 | 3.8 GB | 1.4 GB | 63% |
| 8192 | 14.2 GB | 3.2 GB | 77% |
| 32768 | OOM (>24GB) | 10.8 GB | Fits! |
### Speed (tokens/sec, A100 80GB)
| Model | Standard | Flash Attn 2 | Speedup |
|-------|----------|--------------|---------|
| Llama 2 7B (seq=2048) | 42 | 118 | 2.8x |
| Llama 2 13B (seq=4096) | 18 | 52 | 2.9x |
| Llama 2 70B (seq=2048) | 4 | 11 | 2.75x |
### Training throughput (samples/sec)
| Model | Batch Size | Standard | Flash Attn 2 | Speedup |
|-------|------------|----------|--------------|---------|
| Llama 2 7B | 4 | 1.2 | 3.1 | 2.6x |
| Llama 2 7B | 8 | 2.1 | 5.8 | 2.8x |
| Llama 2 13B | 2 | 0.6 | 1.7 | 2.8x |
## Troubleshooting model-specific issues
### Issue: Model doesn't support Flash Attention
Check support list above. If not supported, use PyTorch SDPA as fallback:
```python
model = AutoModelForCausalLM.from_pretrained(
"model-name",
attn_implementation="sdpa", # PyTorch native (still faster)
torch_dtype=torch.float16
)
```
### Issue: CUDA out of memory during loading
Reduce memory footprint:
```python
model = AutoModelForCausalLM.from_pretrained(
"model-name",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16,
device_map="auto",
max_memory={0: "18GB"}, # Reserve memory for KV cache
low_cpu_mem_usage=True
)
```
### Issue: Slower inference than expected
Ensure dtype matches:
```python
# Model and inputs must both be float16/bfloat16
model = model.to(torch.float16)
inputs = tokenizer(..., return_tensors="pt").to("cuda")
inputs = {k: v.to(torch.float16) if v.dtype == torch.float32 else v
for k, v in inputs.items()}
```
### Issue: Different outputs vs standard attention
Flash Attention is numerically equivalent but uses different computation order. Small differences (<1e-3) are normal:
```python
# Compare outputs
model_standard = AutoModelForCausalLM.from_pretrained("model-name", torch_dtype=torch.float16)
model_flash = AutoModelForCausalLM.from_pretrained(
"model-name",
attn_implementation="flash_attention_2",
torch_dtype=torch.float16
)
inputs = tokenizer("Test", return_tensors="pt").to("cuda")
with torch.no_grad():
out_standard = model_standard(**inputs).logits
out_flash = model_flash(**inputs).logits
diff = (out_standard - out_flash).abs().max()
print(f"Max diff: {diff:.6f}") # Should be ~1e-3 to 1e-4
```
### Issue: ImportError during model loading
Install flash-attn:
```bash
pip install flash-attn --no-build-isolation
```
Or disable Flash Attention:
```python
model = AutoModelForCausalLM.from_pretrained(
"model-name",
attn_implementation="eager", # Standard PyTorch
torch_dtype=torch.float16
)
```
## Best practices
1. **Always use float16/bfloat16** with Flash Attention (not float32)
2. **Set device_map="auto"** for automatic memory management
3. **Use bfloat16 for long context** (better numerical stability)
4. **Enable gradient checkpointing** for training large models
5. **Monitor memory** with `torch.cuda.max_memory_allocated()`
**Example with all best practices**:
```python
from transformers import AutoModelForCausalLM, TrainingArguments
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf",
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16, # Better for training
device_map="auto",
low_cpu_mem_usage=True
)
# Enable gradient checkpointing for memory
model.gradient_checkpointing_enable()
# Training with optimizations
training_args = TrainingArguments(
output_dir="./results",
per_device_train_batch_size=8,
gradient_accumulation_steps=2,
bf16=True, # Match model dtype
optim="adamw_torch_fused",
gradient_checkpointing=True
)
```

430
skills/mlops/gguf/SKILL.md Normal file
View File

@@ -0,0 +1,430 @@
---
name: gguf-quantization
description: GGUF format and llama.cpp quantization for efficient CPU/GPU inference. Use when deploying models on consumer hardware, Apple Silicon, or when needing flexible quantization from 2-8 bit without GPU requirements.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [llama-cpp-python>=0.2.0]
metadata:
hermes:
tags: [GGUF, Quantization, llama.cpp, CPU Inference, Apple Silicon, Model Compression, Optimization]
---
# GGUF - Quantization Format for llama.cpp
The GGUF (GPT-Generated Unified Format) is the standard file format for llama.cpp, enabling efficient inference on CPUs, Apple Silicon, and GPUs with flexible quantization options.
## When to use GGUF
**Use GGUF when:**
- Deploying on consumer hardware (laptops, desktops)
- Running on Apple Silicon (M1/M2/M3) with Metal acceleration
- Need CPU inference without GPU requirements
- Want flexible quantization (Q2_K to Q8_0)
- Using local AI tools (LM Studio, Ollama, text-generation-webui)
**Key advantages:**
- **Universal hardware**: CPU, Apple Silicon, NVIDIA, AMD support
- **No Python runtime**: Pure C/C++ inference
- **Flexible quantization**: 2-8 bit with various methods (K-quants)
- **Ecosystem support**: LM Studio, Ollama, koboldcpp, and more
- **imatrix**: Importance matrix for better low-bit quality
**Use alternatives instead:**
- **AWQ/GPTQ**: Maximum accuracy with calibration on NVIDIA GPUs
- **HQQ**: Fast calibration-free quantization for HuggingFace
- **bitsandbytes**: Simple integration with transformers library
- **TensorRT-LLM**: Production NVIDIA deployment with maximum speed
## Quick start
### Installation
```bash
# Clone llama.cpp
git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp
# Build (CPU)
make
# Build with CUDA (NVIDIA)
make GGML_CUDA=1
# Build with Metal (Apple Silicon)
make GGML_METAL=1
# Install Python bindings (optional)
pip install llama-cpp-python
```
### Convert model to GGUF
```bash
# Install requirements
pip install -r requirements.txt
# Convert HuggingFace model to GGUF (FP16)
python convert_hf_to_gguf.py ./path/to/model --outfile model-f16.gguf
# Or specify output type
python convert_hf_to_gguf.py ./path/to/model \
--outfile model-f16.gguf \
--outtype f16
```
### Quantize model
```bash
# Basic quantization to Q4_K_M
./llama-quantize model-f16.gguf model-q4_k_m.gguf Q4_K_M
# Quantize with importance matrix (better quality)
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
```
### Run inference
```bash
# CLI inference
./llama-cli -m model-q4_k_m.gguf -p "Hello, how are you?"
# Interactive mode
./llama-cli -m model-q4_k_m.gguf --interactive
# With GPU offload
./llama-cli -m model-q4_k_m.gguf -ngl 35 -p "Hello!"
```
## Quantization types
### K-quant methods (recommended)
| Type | Bits | Size (7B) | Quality | Use Case |
|------|------|-----------|---------|----------|
| Q2_K | 2.5 | ~2.8 GB | Low | Extreme compression |
| Q3_K_S | 3.0 | ~3.0 GB | Low-Med | Memory constrained |
| Q3_K_M | 3.3 | ~3.3 GB | Medium | Balance |
| Q4_K_S | 4.0 | ~3.8 GB | Med-High | Good balance |
| Q4_K_M | 4.5 | ~4.1 GB | High | **Recommended default** |
| Q5_K_S | 5.0 | ~4.6 GB | High | Quality focused |
| Q5_K_M | 5.5 | ~4.8 GB | Very High | High quality |
| Q6_K | 6.0 | ~5.5 GB | Excellent | Near-original |
| Q8_0 | 8.0 | ~7.2 GB | Best | Maximum quality |
### Legacy methods
| Type | Description |
|------|-------------|
| Q4_0 | 4-bit, basic |
| Q4_1 | 4-bit with delta |
| Q5_0 | 5-bit, basic |
| Q5_1 | 5-bit with delta |
**Recommendation**: Use K-quant methods (Q4_K_M, Q5_K_M) for best quality/size ratio.
## Conversion workflows
### Workflow 1: HuggingFace to GGUF
```bash
# 1. Download model
huggingface-cli download meta-llama/Llama-3.1-8B --local-dir ./llama-3.1-8b
# 2. Convert to GGUF (FP16)
python convert_hf_to_gguf.py ./llama-3.1-8b \
--outfile llama-3.1-8b-f16.gguf \
--outtype f16
# 3. Quantize
./llama-quantize llama-3.1-8b-f16.gguf llama-3.1-8b-q4_k_m.gguf Q4_K_M
# 4. Test
./llama-cli -m llama-3.1-8b-q4_k_m.gguf -p "Hello!" -n 50
```
### Workflow 2: With importance matrix (better quality)
```bash
# 1. Convert to GGUF
python convert_hf_to_gguf.py ./model --outfile model-f16.gguf
# 2. Create calibration text (diverse samples)
cat > calibration.txt << 'EOF'
The quick brown fox jumps over the lazy dog.
Machine learning is a subset of artificial intelligence.
Python is a popular programming language.
# Add more diverse text samples...
EOF
# 3. Generate importance matrix
./llama-imatrix -m model-f16.gguf \
-f calibration.txt \
--chunk 512 \
-o model.imatrix \
-ngl 35 # GPU layers if available
# 4. Quantize with imatrix
./llama-quantize --imatrix model.imatrix \
model-f16.gguf \
model-q4_k_m.gguf \
Q4_K_M
```
### Workflow 3: Multiple quantizations
```bash
#!/bin/bash
MODEL="llama-3.1-8b-f16.gguf"
IMATRIX="llama-3.1-8b.imatrix"
# Generate imatrix once
./llama-imatrix -m $MODEL -f wiki.txt -o $IMATRIX -ngl 35
# Create multiple quantizations
for QUANT in Q4_K_M Q5_K_M Q6_K Q8_0; do
OUTPUT="llama-3.1-8b-${QUANT,,}.gguf"
./llama-quantize --imatrix $IMATRIX $MODEL $OUTPUT $QUANT
echo "Created: $OUTPUT ($(du -h $OUTPUT | cut -f1))"
done
```
## Python usage
### llama-cpp-python
```python
from llama_cpp import Llama
# Load model
llm = Llama(
model_path="./model-q4_k_m.gguf",
n_ctx=4096, # Context window
n_gpu_layers=35, # GPU offload (0 for CPU only)
n_threads=8 # CPU threads
)
# Generate
output = llm(
"What is machine learning?",
max_tokens=256,
temperature=0.7,
stop=["</s>", "\n\n"]
)
print(output["choices"][0]["text"])
```
### Chat completion
```python
from llama_cpp import Llama
llm = Llama(
model_path="./model-q4_k_m.gguf",
n_ctx=4096,
n_gpu_layers=35,
chat_format="llama-3" # Or "chatml", "mistral", etc.
)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is Python?"}
]
response = llm.create_chat_completion(
messages=messages,
max_tokens=256,
temperature=0.7
)
print(response["choices"][0]["message"]["content"])
```
### Streaming
```python
from llama_cpp import Llama
llm = Llama(model_path="./model-q4_k_m.gguf", n_gpu_layers=35)
# Stream tokens
for chunk in llm(
"Explain quantum computing:",
max_tokens=256,
stream=True
):
print(chunk["choices"][0]["text"], end="", flush=True)
```
## Server mode
### Start OpenAI-compatible server
```bash
# Start server
./llama-server -m model-q4_k_m.gguf \
--host 0.0.0.0 \
--port 8080 \
-ngl 35 \
-c 4096
# Or with Python bindings
python -m llama_cpp.server \
--model model-q4_k_m.gguf \
--n_gpu_layers 35 \
--host 0.0.0.0 \
--port 8080
```
### Use with OpenAI client
```python
from openai import OpenAI
client = OpenAI(
base_url="http://localhost:8080/v1",
api_key="not-needed"
)
response = client.chat.completions.create(
model="local-model",
messages=[{"role": "user", "content": "Hello!"}],
max_tokens=256
)
print(response.choices[0].message.content)
```
## Hardware optimization
### Apple Silicon (Metal)
```bash
# Build with Metal
make clean && make GGML_METAL=1
# Run with Metal acceleration
./llama-cli -m model.gguf -ngl 99 -p "Hello"
# Python with Metal
llm = Llama(
model_path="model.gguf",
n_gpu_layers=99, # Offload all layers
n_threads=1 # Metal handles parallelism
)
```
### NVIDIA CUDA
```bash
# Build with CUDA
make clean && make GGML_CUDA=1
# Run with CUDA
./llama-cli -m model.gguf -ngl 35 -p "Hello"
# Specify GPU
CUDA_VISIBLE_DEVICES=0 ./llama-cli -m model.gguf -ngl 35
```
### CPU optimization
```bash
# Build with AVX2/AVX512
make clean && make
# Run with optimal threads
./llama-cli -m model.gguf -t 8 -p "Hello"
# Python CPU config
llm = Llama(
model_path="model.gguf",
n_gpu_layers=0, # CPU only
n_threads=8, # Match physical cores
n_batch=512 # Batch size for prompt processing
)
```
## Integration with tools
### Ollama
```bash
# Create Modelfile
cat > Modelfile << 'EOF'
FROM ./model-q4_k_m.gguf
TEMPLATE """{{ .System }}
{{ .Prompt }}"""
PARAMETER temperature 0.7
PARAMETER num_ctx 4096
EOF
# Create Ollama model
ollama create mymodel -f Modelfile
# Run
ollama run mymodel "Hello!"
```
### LM Studio
1. Place GGUF file in `~/.cache/lm-studio/models/`
2. Open LM Studio and select the model
3. Configure context length and GPU offload
4. Start inference
### text-generation-webui
```bash
# Place in models folder
cp model-q4_k_m.gguf text-generation-webui/models/
# Start with llama.cpp loader
python server.py --model model-q4_k_m.gguf --loader llama.cpp --n-gpu-layers 35
```
## Best practices
1. **Use K-quants**: Q4_K_M offers best quality/size balance
2. **Use imatrix**: Always use importance matrix for Q4 and below
3. **GPU offload**: Offload as many layers as VRAM allows
4. **Context length**: Start with 4096, increase if needed
5. **Thread count**: Match physical CPU cores, not logical
6. **Batch size**: Increase n_batch for faster prompt processing
## Common issues
**Model loads slowly:**
```bash
# Use mmap for faster loading
./llama-cli -m model.gguf --mmap
```
**Out of memory:**
```bash
# Reduce GPU layers
./llama-cli -m model.gguf -ngl 20 # Reduce from 35
# Or use smaller quantization
./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M
```
**Poor quality at low bits:**
```bash
# Always use imatrix for Q4 and below
./llama-imatrix -m model-f16.gguf -f calibration.txt -o model.imatrix
./llama-quantize --imatrix model.imatrix model-f16.gguf model-q4_k_m.gguf Q4_K_M
```
## References
- **[Advanced Usage](references/advanced-usage.md)** - Batching, speculative decoding, custom builds
- **[Troubleshooting](references/troubleshooting.md)** - Common issues, debugging, benchmarks
## Resources
- **Repository**: https://github.com/ggml-org/llama.cpp
- **Python Bindings**: https://github.com/abetlen/llama-cpp-python
- **Pre-quantized Models**: https://huggingface.co/TheBloke
- **GGUF Converter**: https://huggingface.co/spaces/ggml-org/gguf-my-repo
- **License**: MIT

View File

@@ -0,0 +1,504 @@
# GGUF Advanced Usage Guide
## Speculative Decoding
### Draft Model Approach
```bash
# Use smaller model as draft for faster generation
./llama-speculative \
-m large-model-q4_k_m.gguf \
-md draft-model-q4_k_m.gguf \
-p "Write a story about AI" \
-n 500 \
--draft 8 # Draft tokens before verification
```
### Self-Speculative Decoding
```bash
# Use same model with different context for speculation
./llama-cli -m model-q4_k_m.gguf \
--lookup-cache-static lookup.bin \
--lookup-cache-dynamic lookup-dynamic.bin \
-p "Hello world"
```
## Batched Inference
### Process Multiple Prompts
```python
from llama_cpp import Llama
llm = Llama(
model_path="model-q4_k_m.gguf",
n_ctx=4096,
n_gpu_layers=35,
n_batch=512 # Larger batch for parallel processing
)
prompts = [
"What is Python?",
"Explain machine learning.",
"Describe neural networks."
]
# Process in batch (each prompt gets separate context)
for prompt in prompts:
output = llm(prompt, max_tokens=100)
print(f"Q: {prompt}")
print(f"A: {output['choices'][0]['text']}\n")
```
### Server Batching
```bash
# Start server with batching
./llama-server -m model-q4_k_m.gguf \
--host 0.0.0.0 \
--port 8080 \
-ngl 35 \
-c 4096 \
--parallel 4 # Concurrent requests
--cont-batching # Continuous batching
```
## Custom Model Conversion
### Convert with Vocabulary Modifications
```python
# custom_convert.py
import sys
sys.path.insert(0, './llama.cpp')
from convert_hf_to_gguf import main
from gguf import GGUFWriter
# Custom conversion with modified vocab
def convert_with_custom_vocab(model_path, output_path):
# Load and modify tokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Add special tokens if needed
special_tokens = {"additional_special_tokens": ["<|custom|>"]}
tokenizer.add_special_tokens(special_tokens)
tokenizer.save_pretrained(model_path)
# Then run standard conversion
main([model_path, "--outfile", output_path])
```
### Convert Specific Architecture
```bash
# For Mistral-style models
python convert_hf_to_gguf.py ./mistral-model \
--outfile mistral-f16.gguf \
--outtype f16
# For Qwen models
python convert_hf_to_gguf.py ./qwen-model \
--outfile qwen-f16.gguf \
--outtype f16
# For Phi models
python convert_hf_to_gguf.py ./phi-model \
--outfile phi-f16.gguf \
--outtype f16
```
## Advanced Quantization
### Mixed Quantization
```bash
# Quantize different layer types differently
./llama-quantize model-f16.gguf model-mixed.gguf Q4_K_M \
--allow-requantize \
--leave-output-tensor
```
### Quantization with Token Embeddings
```bash
# Keep embeddings at higher precision
./llama-quantize model-f16.gguf model-q4.gguf Q4_K_M \
--token-embedding-type f16
```
### IQ Quantization (Importance-aware)
```bash
# Ultra-low bit quantization with importance
./llama-quantize --imatrix model.imatrix \
model-f16.gguf model-iq2_xxs.gguf IQ2_XXS
# Available IQ types: IQ2_XXS, IQ2_XS, IQ2_S, IQ3_XXS, IQ3_XS, IQ3_S, IQ4_XS
```
## Memory Optimization
### Memory Mapping
```python
from llama_cpp import Llama
# Use memory mapping for large models
llm = Llama(
model_path="model-q4_k_m.gguf",
use_mmap=True, # Memory map the model
use_mlock=False, # Don't lock in RAM
n_gpu_layers=35
)
```
### Partial GPU Offload
```python
# Calculate layers to offload based on VRAM
import subprocess
def get_free_vram_gb():
result = subprocess.run(
['nvidia-smi', '--query-gpu=memory.free', '--format=csv,nounits,noheader'],
capture_output=True, text=True
)
return int(result.stdout.strip()) / 1024
# Estimate layers based on VRAM (rough: 0.5GB per layer for 7B Q4)
free_vram = get_free_vram_gb()
layers_to_offload = int(free_vram / 0.5)
llm = Llama(
model_path="model-q4_k_m.gguf",
n_gpu_layers=min(layers_to_offload, 35) # Cap at total layers
)
```
### KV Cache Optimization
```python
from llama_cpp import Llama
# Optimize KV cache for long contexts
llm = Llama(
model_path="model-q4_k_m.gguf",
n_ctx=8192, # Large context
n_gpu_layers=35,
type_k=1, # Q8_0 for K cache (1)
type_v=1, # Q8_0 for V cache (1)
# Or use Q4_0 (2) for more compression
)
```
## Context Management
### Context Shifting
```python
from llama_cpp import Llama
llm = Llama(
model_path="model-q4_k_m.gguf",
n_ctx=4096,
n_gpu_layers=35
)
# Handle long conversations with context shifting
conversation = []
max_history = 10
def chat(user_message):
conversation.append({"role": "user", "content": user_message})
# Keep only recent history
if len(conversation) > max_history * 2:
conversation = conversation[-max_history * 2:]
response = llm.create_chat_completion(
messages=conversation,
max_tokens=256
)
assistant_message = response["choices"][0]["message"]["content"]
conversation.append({"role": "assistant", "content": assistant_message})
return assistant_message
```
### Save and Load State
```bash
# Save state to file
./llama-cli -m model.gguf \
-p "Once upon a time" \
--save-session session.bin \
-n 100
# Load and continue
./llama-cli -m model.gguf \
--load-session session.bin \
-p " and they lived" \
-n 100
```
## Grammar Constrained Generation
### JSON Output
```python
from llama_cpp import Llama, LlamaGrammar
# Define JSON grammar
json_grammar = LlamaGrammar.from_string('''
root ::= object
object ::= "{" ws pair ("," ws pair)* "}" ws
pair ::= string ":" ws value
value ::= string | number | object | array | "true" | "false" | "null"
array ::= "[" ws value ("," ws value)* "]" ws
string ::= "\\"" [^"\\\\]* "\\""
number ::= [0-9]+
ws ::= [ \\t\\n]*
''')
llm = Llama(model_path="model-q4_k_m.gguf", n_gpu_layers=35)
output = llm(
"Output a JSON object with name and age:",
grammar=json_grammar,
max_tokens=100
)
print(output["choices"][0]["text"])
```
### Custom Grammar
```python
# Grammar for specific format
answer_grammar = LlamaGrammar.from_string('''
root ::= "Answer: " letter "\\n" "Explanation: " explanation
letter ::= [A-D]
explanation ::= [a-zA-Z0-9 .,!?]+
''')
output = llm(
"Q: What is 2+2? A) 3 B) 4 C) 5 D) 6",
grammar=answer_grammar,
max_tokens=100
)
```
## LoRA Integration
### Load LoRA Adapter
```bash
# Apply LoRA at runtime
./llama-cli -m base-model-q4_k_m.gguf \
--lora lora-adapter.gguf \
--lora-scale 1.0 \
-p "Hello!"
```
### Multiple LoRA Adapters
```bash
# Stack multiple adapters
./llama-cli -m base-model.gguf \
--lora adapter1.gguf --lora-scale 0.5 \
--lora adapter2.gguf --lora-scale 0.5 \
-p "Hello!"
```
### Python LoRA Usage
```python
from llama_cpp import Llama
llm = Llama(
model_path="base-model-q4_k_m.gguf",
lora_path="lora-adapter.gguf",
lora_scale=1.0,
n_gpu_layers=35
)
```
## Embedding Generation
### Extract Embeddings
```python
from llama_cpp import Llama
llm = Llama(
model_path="model-q4_k_m.gguf",
embedding=True, # Enable embedding mode
n_gpu_layers=35
)
# Get embeddings
embeddings = llm.embed("This is a test sentence.")
print(f"Embedding dimension: {len(embeddings)}")
```
### Batch Embeddings
```python
texts = [
"Machine learning is fascinating.",
"Deep learning uses neural networks.",
"Python is a programming language."
]
embeddings = [llm.embed(text) for text in texts]
# Calculate similarity
import numpy as np
def cosine_similarity(a, b):
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
sim = cosine_similarity(embeddings[0], embeddings[1])
print(f"Similarity: {sim:.4f}")
```
## Performance Tuning
### Benchmark Script
```python
import time
from llama_cpp import Llama
def benchmark(model_path, prompt, n_tokens=100, n_runs=5):
llm = Llama(
model_path=model_path,
n_gpu_layers=35,
n_ctx=2048,
verbose=False
)
# Warmup
llm(prompt, max_tokens=10)
# Benchmark
times = []
for _ in range(n_runs):
start = time.time()
output = llm(prompt, max_tokens=n_tokens)
elapsed = time.time() - start
times.append(elapsed)
avg_time = sum(times) / len(times)
tokens_per_sec = n_tokens / avg_time
print(f"Model: {model_path}")
print(f"Avg time: {avg_time:.2f}s")
print(f"Tokens/sec: {tokens_per_sec:.1f}")
return tokens_per_sec
# Compare quantizations
for quant in ["q4_k_m", "q5_k_m", "q8_0"]:
benchmark(f"model-{quant}.gguf", "Explain quantum computing:", 100)
```
### Optimal Configuration Finder
```python
def find_optimal_config(model_path, target_vram_gb=8):
"""Find optimal n_gpu_layers and n_batch for target VRAM."""
from llama_cpp import Llama
import gc
best_config = None
best_speed = 0
for n_gpu_layers in range(0, 50, 5):
for n_batch in [128, 256, 512, 1024]:
try:
gc.collect()
llm = Llama(
model_path=model_path,
n_gpu_layers=n_gpu_layers,
n_batch=n_batch,
n_ctx=2048,
verbose=False
)
# Quick benchmark
start = time.time()
llm("Hello", max_tokens=50)
speed = 50 / (time.time() - start)
if speed > best_speed:
best_speed = speed
best_config = {
"n_gpu_layers": n_gpu_layers,
"n_batch": n_batch,
"speed": speed
}
del llm
gc.collect()
except Exception as e:
print(f"OOM at layers={n_gpu_layers}, batch={n_batch}")
break
return best_config
```
## Multi-GPU Setup
### Distribute Across GPUs
```bash
# Split model across multiple GPUs
./llama-cli -m large-model.gguf \
--tensor-split 0.5,0.5 \
-ngl 60 \
-p "Hello!"
```
### Python Multi-GPU
```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
from llama_cpp import Llama
llm = Llama(
model_path="large-model-q4_k_m.gguf",
n_gpu_layers=60,
tensor_split=[0.5, 0.5] # Split evenly across 2 GPUs
)
```
## Custom Builds
### Build with All Optimizations
```bash
# Clean build with all CPU optimizations
make clean
LLAMA_OPENBLAS=1 LLAMA_BLAS_VENDOR=OpenBLAS make -j
# With CUDA and cuBLAS
make clean
GGML_CUDA=1 LLAMA_CUBLAS=1 make -j
# With specific CUDA architecture
GGML_CUDA=1 CUDA_DOCKER_ARCH=sm_86 make -j
```
### CMake Build
```bash
mkdir build && cd build
cmake .. -DGGML_CUDA=ON -DCMAKE_BUILD_TYPE=Release
cmake --build . --config Release -j
```

View File

@@ -0,0 +1,442 @@
# GGUF Troubleshooting Guide
## Installation Issues
### Build Fails
**Error**: `make: *** No targets specified and no makefile found`
**Fix**:
```bash
# Ensure you're in llama.cpp directory
cd llama.cpp
make
```
**Error**: `fatal error: cuda_runtime.h: No such file or directory`
**Fix**:
```bash
# Install CUDA toolkit
# Ubuntu
sudo apt install nvidia-cuda-toolkit
# Or set CUDA path
export CUDA_PATH=/usr/local/cuda
export PATH=$CUDA_PATH/bin:$PATH
make GGML_CUDA=1
```
### Python Bindings Issues
**Error**: `ERROR: Failed building wheel for llama-cpp-python`
**Fix**:
```bash
# Install build dependencies
pip install cmake scikit-build-core
# For CUDA support
CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
# For Metal (macOS)
CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall --no-cache-dir
```
**Error**: `ImportError: libcudart.so.XX: cannot open shared object file`
**Fix**:
```bash
# Add CUDA libraries to path
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
# Or reinstall with correct CUDA version
pip uninstall llama-cpp-python
CUDACXX=/usr/local/cuda/bin/nvcc CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python
```
## Conversion Issues
### Model Not Supported
**Error**: `KeyError: 'model.embed_tokens.weight'`
**Fix**:
```bash
# Check model architecture
python -c "from transformers import AutoConfig; print(AutoConfig.from_pretrained('./model').architectures)"
# Use appropriate conversion script
# For most models:
python convert_hf_to_gguf.py ./model --outfile model.gguf
# For older models, check if legacy script needed
```
### Vocabulary Mismatch
**Error**: `RuntimeError: Vocabulary size mismatch`
**Fix**:
```python
# Ensure tokenizer matches model
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("./model")
model = AutoModelForCausalLM.from_pretrained("./model")
print(f"Tokenizer vocab size: {len(tokenizer)}")
print(f"Model vocab size: {model.config.vocab_size}")
# If mismatch, resize embeddings before conversion
model.resize_token_embeddings(len(tokenizer))
model.save_pretrained("./model-fixed")
```
### Out of Memory During Conversion
**Error**: `torch.cuda.OutOfMemoryError` during conversion
**Fix**:
```bash
# Use CPU for conversion
CUDA_VISIBLE_DEVICES="" python convert_hf_to_gguf.py ./model --outfile model.gguf
# Or use low memory mode
python convert_hf_to_gguf.py ./model --outfile model.gguf --outtype f16
```
## Quantization Issues
### Wrong Output File Size
**Problem**: Quantized file is larger than expected
**Check**:
```bash
# Verify quantization type
./llama-cli -m model.gguf --verbose
# Expected sizes for 7B model:
# Q4_K_M: ~4.1 GB
# Q5_K_M: ~4.8 GB
# Q8_0: ~7.2 GB
# F16: ~13.5 GB
```
### Quantization Crashes
**Error**: `Segmentation fault` during quantization
**Fix**:
```bash
# Increase stack size
ulimit -s unlimited
# Or use less threads
./llama-quantize -t 4 model-f16.gguf model-q4.gguf Q4_K_M
```
### Poor Quality After Quantization
**Problem**: Model outputs gibberish after quantization
**Solutions**:
1. **Use importance matrix**:
```bash
# Generate imatrix with good calibration data
./llama-imatrix -m model-f16.gguf \
-f wiki_sample.txt \
--chunk 512 \
-o model.imatrix
# Quantize with imatrix
./llama-quantize --imatrix model.imatrix \
model-f16.gguf model-q4_k_m.gguf Q4_K_M
```
2. **Try higher precision**:
```bash
# Use Q5_K_M or Q6_K instead of Q4
./llama-quantize model-f16.gguf model-q5_k_m.gguf Q5_K_M
```
3. **Check original model**:
```bash
# Test FP16 version first
./llama-cli -m model-f16.gguf -p "Hello, how are you?" -n 50
```
## Inference Issues
### Slow Generation
**Problem**: Generation is slower than expected
**Solutions**:
1. **Enable GPU offload**:
```bash
./llama-cli -m model.gguf -ngl 35 -p "Hello"
```
2. **Optimize batch size**:
```python
llm = Llama(
model_path="model.gguf",
n_batch=512, # Increase for faster prompt processing
n_gpu_layers=35
)
```
3. **Use appropriate threads**:
```bash
# Match physical cores, not logical
./llama-cli -m model.gguf -t 8 -p "Hello"
```
4. **Enable Flash Attention** (if supported):
```bash
./llama-cli -m model.gguf -ngl 35 --flash-attn -p "Hello"
```
### Out of Memory
**Error**: `CUDA out of memory` or system freeze
**Solutions**:
1. **Reduce GPU layers**:
```python
# Start low and increase
llm = Llama(model_path="model.gguf", n_gpu_layers=10)
```
2. **Use smaller quantization**:
```bash
./llama-quantize model-f16.gguf model-q3_k_m.gguf Q3_K_M
```
3. **Reduce context length**:
```python
llm = Llama(
model_path="model.gguf",
n_ctx=2048, # Reduce from 4096
n_gpu_layers=35
)
```
4. **Quantize KV cache**:
```python
llm = Llama(
model_path="model.gguf",
type_k=2, # Q4_0 for K cache
type_v=2, # Q4_0 for V cache
n_gpu_layers=35
)
```
### Garbage Output
**Problem**: Model outputs random characters or nonsense
**Diagnose**:
```python
# Check model loading
llm = Llama(model_path="model.gguf", verbose=True)
# Test with simple prompt
output = llm("1+1=", max_tokens=5, temperature=0)
print(output)
```
**Solutions**:
1. **Check model integrity**:
```bash
# Verify GGUF file
./llama-cli -m model.gguf --verbose 2>&1 | head -50
```
2. **Use correct chat format**:
```python
llm = Llama(
model_path="model.gguf",
chat_format="llama-3" # Match your model: chatml, mistral, etc.
)
```
3. **Check temperature**:
```python
# Use lower temperature for deterministic output
output = llm("Hello", max_tokens=50, temperature=0.1)
```
### Token Issues
**Error**: `RuntimeError: unknown token` or encoding errors
**Fix**:
```python
# Ensure UTF-8 encoding
prompt = "Hello, world!".encode('utf-8').decode('utf-8')
output = llm(prompt, max_tokens=50)
```
## Server Issues
### Connection Refused
**Error**: `Connection refused` when accessing server
**Fix**:
```bash
# Bind to all interfaces
./llama-server -m model.gguf --host 0.0.0.0 --port 8080
# Check if port is in use
lsof -i :8080
```
### Server Crashes Under Load
**Problem**: Server crashes with multiple concurrent requests
**Solutions**:
1. **Limit parallelism**:
```bash
./llama-server -m model.gguf \
--parallel 2 \
-c 4096 \
--cont-batching
```
2. **Add request timeout**:
```bash
./llama-server -m model.gguf --timeout 300
```
3. **Monitor memory**:
```bash
watch -n 1 nvidia-smi # For GPU
watch -n 1 free -h # For RAM
```
### API Compatibility Issues
**Problem**: OpenAI client not working with server
**Fix**:
```python
from openai import OpenAI
# Use correct base URL format
client = OpenAI(
base_url="http://localhost:8080/v1", # Include /v1
api_key="not-needed"
)
# Use correct model name
response = client.chat.completions.create(
model="local", # Or the actual model name
messages=[{"role": "user", "content": "Hello"}]
)
```
## Apple Silicon Issues
### Metal Not Working
**Problem**: Metal acceleration not enabled
**Check**:
```bash
# Verify Metal support
./llama-cli -m model.gguf --verbose 2>&1 | grep -i metal
```
**Fix**:
```bash
# Rebuild with Metal
make clean
make GGML_METAL=1
# Python bindings
CMAKE_ARGS="-DGGML_METAL=on" pip install llama-cpp-python --force-reinstall
```
### Incorrect Memory Usage on M1/M2
**Problem**: Model uses too much unified memory
**Fix**:
```python
# Offload all layers for Metal
llm = Llama(
model_path="model.gguf",
n_gpu_layers=99, # Offload everything
n_threads=1 # Metal handles parallelism
)
```
## Debugging
### Enable Verbose Output
```bash
# CLI verbose mode
./llama-cli -m model.gguf --verbose -p "Hello" -n 50
# Python verbose
llm = Llama(model_path="model.gguf", verbose=True)
```
### Check Model Metadata
```bash
# View GGUF metadata
./llama-cli -m model.gguf --verbose 2>&1 | head -100
```
### Validate GGUF File
```python
import struct
def validate_gguf(filepath):
with open(filepath, 'rb') as f:
magic = f.read(4)
if magic != b'GGUF':
print(f"Invalid magic: {magic}")
return False
version = struct.unpack('<I', f.read(4))[0]
print(f"GGUF version: {version}")
tensor_count = struct.unpack('<Q', f.read(8))[0]
metadata_count = struct.unpack('<Q', f.read(8))[0]
print(f"Tensors: {tensor_count}, Metadata: {metadata_count}")
return True
validate_gguf("model.gguf")
```
## Getting Help
1. **GitHub Issues**: https://github.com/ggml-org/llama.cpp/issues
2. **Discussions**: https://github.com/ggml-org/llama.cpp/discussions
3. **Reddit**: r/LocalLLaMA
### Reporting Issues
Include:
- llama.cpp version/commit hash
- Build command used
- Model name and quantization
- Full error message/stack trace
- Hardware: CPU/GPU model, RAM, VRAM
- OS version
- Minimal reproduction steps

View File

@@ -0,0 +1,97 @@
# GRPO/RL Training Skill
**Expert-level guidance for Group Relative Policy Optimization with TRL**
## 📁 Skill Structure
```
grpo-rl-training/
├── SKILL.md # Main skill documentation (READ THIS FIRST)
├── README.md # This file
├── templates/
│ └── basic_grpo_training.py # Production-ready training template
└── examples/
└── reward_functions_library.py # 20+ reward function examples
```
## 🚀 Quick Start
1. **Read SKILL.md** - Comprehensive guide with all concepts and patterns
2. **Copy `templates/basic_grpo_training.py`** - Start with working code
3. **Browse `examples/reward_functions_library.py`** - Pick reward functions for your task
4. **Modify for your use case** - Adapt dataset, rewards, and config
## 💡 What's Inside
### SKILL.md (Main Documentation)
- Core GRPO concepts and algorithm fundamentals
- Complete implementation workflow (dataset → rewards → training → deployment)
- 10+ reward function examples with code
- Hyperparameter tuning guide
- Training insights (loss behavior, metrics, debugging)
- Troubleshooting guide
- Production best practices
### Templates
- **basic_grpo_training.py**: Minimal, production-ready training script
- Uses Qwen 2.5 1.5B Instruct
- 3 reward functions (format + correctness)
- LoRA for efficient training
- Fully documented and ready to run
### Examples
- **reward_functions_library.py**: 20+ battle-tested reward functions
- Correctness rewards (exact match, fuzzy match, numeric, code execution)
- Format rewards (XML, JSON, strict/soft)
- Length rewards (ideal length, min/max)
- Style rewards (reasoning quality, citations, repetition penalty)
- Combined rewards (multi-objective optimization)
- Preset collections for common tasks
## 📖 Usage for Agents
When this skill is loaded in your agent's context:
1. **Always read SKILL.md first** before implementing
2. **Start simple** - Use length-based reward to validate setup
3. **Build incrementally** - Add one reward function at a time
4. **Reference examples** - Copy patterns from reward_functions_library.py
5. **Monitor training** - Watch reward metrics (not loss!)
## 🎯 Common Use Cases
| Task Type | Recommended Rewards | Template |
|-----------|---------------------|----------|
| Math reasoning | `MATH_REASONING_REWARDS` preset | basic_grpo_training.py |
| Code generation | `CODE_GENERATION_REWARDS` preset | Modify dataset in template |
| Summarization | `SUMMARIZATION_REWARDS` preset | Adjust prompts + rewards |
| Q&A | `QA_REWARDS` preset | Use fuzzy match + citations |
## ⚠️ Critical Reminders
- **Loss goes UP during training** - This is normal (it's KL divergence)
- **Use 3-5 reward functions** - Single rewards often fail
- **Test rewards before training** - Debug each function independently
- **Monitor reward_std** - Should stay > 0.1 (avoid mode collapse)
- **Start with num_generations=4-8** - Scale up if GPU allows
## 🔗 External Resources
- [TRL Documentation](https://huggingface.co/docs/trl)
- [DeepSeek R1 Paper](https://arxiv.org/abs/2501.12948)
- [Open R1 Implementation](https://github.com/huggingface/open-r1)
- [Unsloth (2-3x faster)](https://docs.unsloth.ai/)
## 📝 Version
**v1.0.0** - Initial release (January 2025)
## 👨‍💻 Maintained By
Orchestra Research
For questions or improvements, see https://orchestra.com
---
**License:** MIT
**Last Updated:** January 2025

View File

@@ -0,0 +1,575 @@
---
name: grpo-rl-training
description: Expert guidance for GRPO/RL fine-tuning with TRL for reasoning and task-specific model training
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [transformers>=4.47.0, trl>=0.14.0, datasets>=3.2.0, peft>=0.14.0, torch]
metadata:
hermes:
tags: [Post-Training, Reinforcement Learning, GRPO, TRL, RLHF, Reward Modeling, Reasoning, DPO, PPO, Structured Output]
---
# GRPO/RL Training with TRL
Expert-level guidance for implementing Group Relative Policy Optimization (GRPO) using the Transformer Reinforcement Learning (TRL) library. This skill provides battle-tested patterns, critical insights, and production-ready workflows for fine-tuning language models with custom reward functions.
## When to Use This Skill
Use GRPO training when you need to:
- **Enforce specific output formats** (e.g., XML tags, JSON, structured reasoning)
- **Teach verifiable tasks** with objective correctness metrics (math, coding, fact-checking)
- **Improve reasoning capabilities** by rewarding chain-of-thought patterns
- **Align models to domain-specific behaviors** without labeled preference data
- **Optimize for multiple objectives** simultaneously (format + correctness + style)
**Do NOT use GRPO for:**
- Simple supervised fine-tuning tasks (use SFT instead)
- Tasks without clear reward signals
- When you already have high-quality preference pairs (use DPO/PPO instead)
---
## Core Concepts
### 1. GRPO Algorithm Fundamentals
**Key Mechanism:**
- Generates **multiple completions** for each prompt (group size: 4-16)
- Compares completions within each group using reward functions
- Updates policy to favor higher-rewarded responses relative to the group
**Critical Difference from PPO:**
- No separate reward model needed
- More sample-efficient (learns from within-group comparisons)
- Simpler to implement and debug
**Mathematical Intuition:**
```
For each prompt p:
1. Generate N completions: {c₁, c₂, ..., cₙ}
2. Compute rewards: {r₁, r₂, ..., rₙ}
3. Learn to increase probability of high-reward completions
relative to low-reward ones in the same group
```
### 2. Reward Function Design Philosophy
**Golden Rules:**
1. **Compose multiple reward functions** - Each handles one aspect (format, correctness, style)
2. **Scale rewards appropriately** - Higher weight = stronger signal
3. **Use incremental rewards** - Partial credit for partial compliance
4. **Test rewards independently** - Debug each reward function in isolation
**Reward Function Types:**
| Type | Use Case | Example Weight |
|------|----------|----------------|
| **Correctness** | Verifiable tasks (math, code) | 2.0 (highest) |
| **Format** | Strict structure enforcement | 0.5-1.0 |
| **Length** | Encourage verbosity/conciseness | 0.1-0.5 |
| **Style** | Penalize unwanted patterns | -0.5 to 0.5 |
---
## Implementation Workflow
### Step 1: Dataset Preparation
**Critical Requirements:**
- Prompts in chat format (list of dicts with 'role' and 'content')
- Include system prompts to set expectations
- For verifiable tasks, include ground truth answers as additional columns
**Example Structure:**
```python
from datasets import load_dataset, Dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
[Your step-by-step thinking]
</reasoning>
<answer>
[Final answer]
</answer>
"""
def prepare_dataset(raw_data):
"""
Transform raw data into GRPO-compatible format.
Returns: Dataset with columns:
- 'prompt': List[Dict] with role/content (system + user messages)
- 'answer': str (ground truth, optional but recommended)
"""
return raw_data.map(lambda x: {
'prompt': [
{'role': 'system', 'content': SYSTEM_PROMPT},
{'role': 'user', 'content': x['question']}
],
'answer': extract_answer(x['raw_answer'])
})
```
**Pro Tips:**
- Use one-shot or few-shot examples in system prompt for complex formats
- Keep prompts concise (max_prompt_length: 256-512 tokens)
- Validate data quality before training (garbage in = garbage out)
### Step 2: Reward Function Implementation
**Template Structure:**
```python
def reward_function_name(
prompts, # List[List[Dict]]: Original prompts
completions, # List[List[Dict]]: Model generations
answer=None, # Optional: Ground truth from dataset
**kwargs # Additional dataset columns
) -> list[float]:
"""
Evaluate completions and return rewards.
Returns: List of floats (one per completion)
"""
# Extract completion text
responses = [comp[0]['content'] for comp in completions]
# Compute rewards
rewards = []
for response in responses:
score = compute_score(response)
rewards.append(score)
return rewards
```
**Example 1: Correctness Reward (Math/Coding)**
```python
def correctness_reward(prompts, completions, answer, **kwargs):
"""Reward correct answers with high score."""
responses = [comp[0]['content'] for comp in completions]
extracted = [extract_final_answer(r) for r in responses]
return [2.0 if ans == gt else 0.0
for ans, gt in zip(extracted, answer)]
```
**Example 2: Format Reward (Structured Output)**
```python
import re
def format_reward(completions, **kwargs):
"""Reward XML-like structured format."""
pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
responses = [comp[0]['content'] for comp in completions]
return [1.0 if re.search(pattern, r, re.DOTALL) else 0.0
for r in responses]
```
**Example 3: Incremental Format Reward (Partial Credit)**
```python
def incremental_format_reward(completions, **kwargs):
"""Award partial credit for format compliance."""
responses = [comp[0]['content'] for comp in completions]
rewards = []
for r in responses:
score = 0.0
if '<reasoning>' in r:
score += 0.25
if '</reasoning>' in r:
score += 0.25
if '<answer>' in r:
score += 0.25
if '</answer>' in r:
score += 0.25
# Penalize extra text after closing tag
if r.count('</answer>') == 1:
extra_text = r.split('</answer>')[-1].strip()
score -= len(extra_text) * 0.001
rewards.append(score)
return rewards
```
**Critical Insight:**
Combine 3-5 reward functions for robust training. Order matters less than diversity of signals.
### Step 3: Training Configuration
**Memory-Optimized Config (Small GPU)**
```python
from trl import GRPOConfig
training_args = GRPOConfig(
output_dir="outputs/grpo-model",
# Learning rate
learning_rate=5e-6, # Lower = more stable
adam_beta1=0.9,
adam_beta2=0.99,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type='cosine',
# Batch settings
per_device_train_batch_size=1,
gradient_accumulation_steps=4, # Effective batch = 4
# GRPO-specific
num_generations=8, # Group size: 8-16 recommended
max_prompt_length=256,
max_completion_length=512,
# Training duration
num_train_epochs=1,
max_steps=None, # Or set fixed steps (e.g., 500)
# Optimization
bf16=True, # Faster on A100/H100
optim="adamw_8bit", # Memory-efficient optimizer
max_grad_norm=0.1,
# Logging
logging_steps=1,
save_steps=100,
report_to="wandb", # Or "none" for no logging
)
```
**High-Performance Config (Large GPU)**
```python
training_args = GRPOConfig(
output_dir="outputs/grpo-model",
learning_rate=1e-5,
per_device_train_batch_size=4,
gradient_accumulation_steps=2,
num_generations=16, # Larger groups = better signal
max_prompt_length=512,
max_completion_length=1024,
num_train_epochs=1,
bf16=True,
use_vllm=True, # Fast generation with vLLM
logging_steps=10,
)
```
**Critical Hyperparameters:**
| Parameter | Impact | Tuning Advice |
|-----------|--------|---------------|
| `num_generations` | Group size for comparison | Start with 8, increase to 16 if GPU allows |
| `learning_rate` | Convergence speed/stability | 5e-6 (safe), 1e-5 (faster, riskier) |
| `max_completion_length` | Output verbosity | Match your task (512 for reasoning, 256 for short answers) |
| `gradient_accumulation_steps` | Effective batch size | Increase if GPU memory limited |
### Step 4: Model Setup and Training
**Standard Setup (Transformers)**
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from trl import GRPOTrainer
# Load model
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2", # 2-3x faster
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
# Optional: LoRA for parameter-efficient training
peft_config = LoraConfig(
r=16, # Rank (higher = more capacity)
lora_alpha=32, # Scaling factor (typically 2*r)
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"
],
task_type="CAUSAL_LM",
lora_dropout=0.05,
)
# Initialize trainer
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
incremental_format_reward,
format_reward,
correctness_reward,
],
args=training_args,
train_dataset=dataset,
peft_config=peft_config, # Remove for full fine-tuning
)
# Train
trainer.train()
# Save
trainer.save_model("final_model")
```
**Unsloth Setup (2-3x Faster)**
```python
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="google/gemma-3-1b-it",
max_seq_length=1024,
load_in_4bit=True,
fast_inference=True,
max_lora_rank=32,
)
model = FastLanguageModel.get_peft_model(
model,
r=32,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha=32,
use_gradient_checkpointing="unsloth",
)
# Rest is identical to standard setup
trainer = GRPOTrainer(model=model, ...)
trainer.train()
```
---
## Critical Training Insights
### 1. Loss Behavior (EXPECTED PATTERN)
- **Loss starts near 0 and INCREASES during training**
- This is CORRECT - loss measures KL divergence from initial policy
- Model is learning (diverging from original behavior to optimize rewards)
- Monitor reward metrics instead of loss for progress
### 2. Reward Tracking
Key metrics to watch:
- `reward`: Average across all completions
- `reward_std`: Diversity within groups (should remain > 0)
- `kl`: KL divergence from reference (should grow moderately)
**Healthy Training Pattern:**
```
Step Reward Reward_Std KL
100 0.5 0.3 0.02
200 0.8 0.25 0.05
300 1.2 0.2 0.08 ← Good progression
400 1.5 0.15 0.12
```
**Warning Signs:**
- Reward std → 0 (model collapsing to single response)
- KL exploding (> 0.5) (diverging too much, reduce LR)
- Reward stuck (reward functions too harsh or model capacity issue)
### 3. Common Pitfalls and Solutions
| Problem | Symptom | Solution |
|---------|---------|----------|
| **Mode collapse** | All completions identical | Increase `num_generations`, add diversity penalty |
| **No learning** | Flat rewards | Check reward function logic, increase LR |
| **OOM errors** | GPU memory exceeded | Reduce `num_generations`, enable gradient checkpointing |
| **Slow training** | < 1 it/s | Enable `use_vllm=True`, use Unsloth, reduce seq length |
| **Format ignored** | Model doesn't follow structure | Increase format reward weight, add incremental rewards |
---
## Advanced Patterns
### 1. Multi-Stage Training
For complex tasks, train in stages:
```python
# Stage 1: Format compliance (epochs=1)
trainer_stage1 = GRPOTrainer(
model=model,
reward_funcs=[incremental_format_reward, format_reward],
...
)
trainer_stage1.train()
# Stage 2: Correctness (epochs=1)
trainer_stage2 = GRPOTrainer(
model=model,
reward_funcs=[format_reward, correctness_reward],
...
)
trainer_stage2.train()
```
### 2. Adaptive Reward Scaling
```python
class AdaptiveReward:
def __init__(self, base_reward_func, initial_weight=1.0):
self.func = base_reward_func
self.weight = initial_weight
def __call__(self, *args, **kwargs):
rewards = self.func(*args, **kwargs)
return [r * self.weight for r in rewards]
def adjust_weight(self, success_rate):
"""Increase weight if model struggling, decrease if succeeding."""
if success_rate < 0.3:
self.weight *= 1.2
elif success_rate > 0.8:
self.weight *= 0.9
```
### 3. Custom Dataset Integration
```python
def load_custom_knowledge_base(csv_path):
"""Example: School communication platform docs."""
import pandas as pd
df = pd.read_csv(csv_path)
dataset = Dataset.from_pandas(df).map(lambda x: {
'prompt': [
{'role': 'system', 'content': CUSTOM_SYSTEM_PROMPT},
{'role': 'user', 'content': x['question']}
],
'answer': x['expert_answer']
})
return dataset
```
---
## Deployment and Inference
### Save and Merge LoRA
```python
# Merge LoRA adapters into base model
if hasattr(trainer.model, 'merge_and_unload'):
merged_model = trainer.model.merge_and_unload()
merged_model.save_pretrained("production_model")
tokenizer.save_pretrained("production_model")
```
### Inference Example
```python
from transformers import pipeline
generator = pipeline(
"text-generation",
model="production_model",
tokenizer=tokenizer
)
result = generator(
[
{'role': 'system', 'content': SYSTEM_PROMPT},
{'role': 'user', 'content': "What is 15 + 27?"}
],
max_new_tokens=256,
do_sample=True,
temperature=0.7,
top_p=0.9
)
print(result[0]['generated_text'])
```
---
## Best Practices Checklist
**Before Training:**
- [ ] Validate dataset format (prompts as List[Dict])
- [ ] Test reward functions on sample data
- [ ] Calculate expected max_prompt_length from data
- [ ] Choose appropriate num_generations based on GPU memory
- [ ] Set up logging (wandb recommended)
**During Training:**
- [ ] Monitor reward progression (should increase)
- [ ] Check reward_std (should stay > 0.1)
- [ ] Watch for OOM errors (reduce batch size if needed)
- [ ] Sample generations every 50-100 steps
- [ ] Validate format compliance on holdout set
**After Training:**
- [ ] Merge LoRA weights if using PEFT
- [ ] Test on diverse prompts
- [ ] Compare to baseline model
- [ ] Document reward weights and hyperparameters
- [ ] Save reproducibility config
---
## Troubleshooting Guide
### Debugging Workflow
1. **Isolate reward functions** - Test each independently
2. **Check data distribution** - Ensure diversity in prompts
3. **Reduce complexity** - Start with single reward, add gradually
4. **Monitor generations** - Print samples every N steps
5. **Validate extraction logic** - Ensure answer parsing works
### Quick Fixes
```python
# Debug reward function
def debug_reward(completions, **kwargs):
responses = [comp[0]['content'] for comp in completions]
for i, r in enumerate(responses[:2]): # Print first 2
print(f"Response {i}: {r[:200]}...")
return [1.0] * len(responses) # Dummy rewards
# Test without training
trainer = GRPOTrainer(..., reward_funcs=[debug_reward])
trainer.generate_completions(dataset[:1]) # Generate without updating
```
---
## References and Resources
**Official Documentation:**
- TRL GRPO Trainer: https://huggingface.co/docs/trl/grpo_trainer
- DeepSeek R1 Paper: https://arxiv.org/abs/2501.12948
- Unsloth Docs: https://docs.unsloth.ai/
**Example Repositories:**
- Open R1 Implementation: https://github.com/huggingface/open-r1
- TRL Examples: https://github.com/huggingface/trl/tree/main/examples
**Recommended Reading:**
- Progressive Disclosure Pattern for agent instructions
- Reward shaping in RL (Ng et al.)
- LoRA paper (Hu et al., 2021)
---
## Usage Instructions for Agents
When this skill is loaded:
1. **Read this entire file** before implementing GRPO training
2. **Start with the simplest reward function** (e.g., length-based) to validate setup
3. **Use the templates** in `templates/` directory as starting points
4. **Reference examples** in `examples/` for task-specific implementations
5. **Follow the workflow** sequentially (don't skip steps)
6. **Debug incrementally** - add one reward function at a time
**Critical Reminders:**
- Always use multiple reward functions (3-5 is optimal)
- Monitor reward metrics, not loss
- Test reward functions before training
- Start small (num_generations=4), scale up gradually
- Save checkpoints frequently (every 100 steps)
This skill is designed for **expert-level implementation**. Beginners should start with supervised fine-tuning before attempting GRPO.

View File

@@ -0,0 +1,228 @@
"""
Basic GRPO Training Template
=============================
A minimal, production-ready template for GRPO training with TRL.
Adapt this for your specific task by modifying:
1. Dataset loading (get_dataset function)
2. Reward functions (reward_*_func)
3. System prompt (SYSTEM_PROMPT)
4. Hyperparameters (GRPOConfig)
"""
import torch
import re
from datasets import load_dataset, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from trl import GRPOTrainer, GRPOConfig
# ==================== CONFIGURATION ====================
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
OUTPUT_DIR = "outputs/grpo-model"
MAX_PROMPT_LENGTH = 256
MAX_COMPLETION_LENGTH = 512
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
[Your step-by-step thinking]
</reasoning>
<answer>
[Final answer]
</answer>
"""
# ==================== DATASET ====================
def get_dataset(split="train"):
"""
Load and prepare your dataset.
Returns: Dataset with columns:
- 'prompt': List[Dict] with role/content
- 'answer': str (ground truth, optional)
"""
# Example: GSM8K math dataset
data = load_dataset('openai/gsm8k', 'main')[split]
def process_example(x):
# Extract ground truth answer
answer = x['answer'].split('####')[1].strip() if '####' in x['answer'] else None
return {
'prompt': [
{'role': 'system', 'content': SYSTEM_PROMPT},
{'role': 'user', 'content': x['question']}
],
'answer': answer
}
return data.map(process_example)
# ==================== HELPER FUNCTIONS ====================
def extract_xml_tag(text: str, tag: str) -> str:
"""Extract content between XML tags."""
pattern = f'<{tag}>(.*?)</{tag}>'
match = re.search(pattern, text, re.DOTALL)
return match.group(1).strip() if match else ""
def extract_answer(text: str) -> str:
"""Extract the final answer from structured output."""
return extract_xml_tag(text, 'answer')
# ==================== REWARD FUNCTIONS ====================
def correctness_reward_func(prompts, completions, answer, **kwargs):
"""
Reward correct answers.
Weight: 2.0 (highest priority)
"""
responses = [comp[0]['content'] for comp in completions]
extracted = [extract_answer(r) for r in responses]
return [2.0 if ans == gt else 0.0 for ans, gt in zip(extracted, answer)]
def format_reward_func(completions, **kwargs):
"""
Reward proper XML format.
Weight: 0.5
"""
pattern = r'<reasoning>.*?</reasoning>\s*<answer>.*?</answer>'
responses = [comp[0]['content'] for comp in completions]
return [0.5 if re.search(pattern, r, re.DOTALL) else 0.0 for r in responses]
def incremental_format_reward_func(completions, **kwargs):
"""
Incremental reward for partial format compliance.
Weight: up to 0.5
"""
responses = [comp[0]['content'] for comp in completions]
rewards = []
for r in responses:
score = 0.0
if '<reasoning>' in r:
score += 0.125
if '</reasoning>' in r:
score += 0.125
if '<answer>' in r:
score += 0.125
if '</answer>' in r:
score += 0.125
# Penalize extra content after closing tag
if '</answer>' in r:
extra = r.split('</answer>')[-1].strip()
score -= len(extra) * 0.001
rewards.append(score)
return rewards
# ==================== MODEL SETUP ====================
def setup_model_and_tokenizer():
"""Load model and tokenizer with optimizations."""
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
return model, tokenizer
def get_peft_config():
"""LoRA configuration for parameter-efficient training."""
return LoraConfig(
r=16,
lora_alpha=32,
target_modules=[
"q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"
],
task_type="CAUSAL_LM",
lora_dropout=0.05,
)
# ==================== TRAINING ====================
def main():
"""Main training function."""
# Load data
print("Loading dataset...")
dataset = get_dataset()
print(f"Dataset size: {len(dataset)}")
# Setup model
print("Loading model...")
model, tokenizer = setup_model_and_tokenizer()
# Training configuration
training_args = GRPOConfig(
output_dir=OUTPUT_DIR,
run_name="grpo-training",
# Learning rate
learning_rate=5e-6,
adam_beta1=0.9,
adam_beta2=0.99,
weight_decay=0.1,
warmup_ratio=0.1,
lr_scheduler_type='cosine',
# Batch settings
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
# GRPO specific
num_generations=8,
max_prompt_length=MAX_PROMPT_LENGTH,
max_completion_length=MAX_COMPLETION_LENGTH,
# Training duration
num_train_epochs=1,
# Optimization
bf16=True,
optim="adamw_8bit",
max_grad_norm=0.1,
# Logging
logging_steps=1,
save_steps=100,
report_to="wandb", # Change to "none" to disable logging
)
# Initialize trainer
trainer = GRPOTrainer(
model=model,
processing_class=tokenizer,
reward_funcs=[
incremental_format_reward_func,
format_reward_func,
correctness_reward_func,
],
args=training_args,
train_dataset=dataset,
peft_config=get_peft_config(),
)
# Train
print("Starting training...")
trainer.train()
# Save final model
print(f"Saving model to {OUTPUT_DIR}/final")
trainer.save_model(f"{OUTPUT_DIR}/final")
print("Training complete!")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,575 @@
---
name: guidance
description: Control LLM output with regex and grammars, guarantee valid JSON/XML/code generation, enforce structured formats, and build multi-step workflows with Guidance - Microsoft Research's constrained generation framework
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [guidance, transformers]
metadata:
hermes:
tags: [Prompt Engineering, Guidance, Constrained Generation, Structured Output, JSON Validation, Grammar, Microsoft Research, Format Enforcement, Multi-Step Workflows]
---
# Guidance: Constrained LLM Generation
## When to Use This Skill
Use Guidance when you need to:
- **Control LLM output syntax** with regex or grammars
- **Guarantee valid JSON/XML/code** generation
- **Reduce latency** vs traditional prompting approaches
- **Enforce structured formats** (dates, emails, IDs, etc.)
- **Build multi-step workflows** with Pythonic control flow
- **Prevent invalid outputs** through grammatical constraints
**GitHub Stars**: 18,000+ | **From**: Microsoft Research
## Installation
```bash
# Base installation
pip install guidance
# With specific backends
pip install guidance[transformers] # Hugging Face models
pip install guidance[llama_cpp] # llama.cpp models
```
## Quick Start
### Basic Example: Structured Generation
```python
from guidance import models, gen
# Load model (supports OpenAI, Transformers, llama.cpp)
lm = models.OpenAI("gpt-4")
# Generate with constraints
result = lm + "The capital of France is " + gen("capital", max_tokens=5)
print(result["capital"]) # "Paris"
```
### With Anthropic Claude
```python
from guidance import models, gen, system, user, assistant
# Configure Claude
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Use context managers for chat format
with system():
lm += "You are a helpful assistant."
with user():
lm += "What is the capital of France?"
with assistant():
lm += gen(max_tokens=20)
```
## Core Concepts
### 1. Context Managers
Guidance uses Pythonic context managers for chat-style interactions.
```python
from guidance import system, user, assistant, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# System message
with system():
lm += "You are a JSON generation expert."
# User message
with user():
lm += "Generate a person object with name and age."
# Assistant response
with assistant():
lm += gen("response", max_tokens=100)
print(lm["response"])
```
**Benefits:**
- Natural chat flow
- Clear role separation
- Easy to read and maintain
### 2. Constrained Generation
Guidance ensures outputs match specified patterns using regex or grammars.
#### Regex Constraints
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Constrain to valid email format
lm += "Email: " + gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
# Constrain to date format (YYYY-MM-DD)
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}")
# Constrain to phone number
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}")
print(lm["email"]) # Guaranteed valid email
print(lm["date"]) # Guaranteed YYYY-MM-DD format
```
**How it works:**
- Regex converted to grammar at token level
- Invalid tokens filtered during generation
- Model can only produce matching outputs
#### Selection Constraints
```python
from guidance import models, gen, select
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Constrain to specific choices
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
# Multiple-choice selection
lm += "Best answer: " + select(
["A) Paris", "B) London", "C) Berlin", "D) Madrid"],
name="answer"
)
print(lm["sentiment"]) # One of: positive, negative, neutral
print(lm["answer"]) # One of: A, B, C, or D
```
### 3. Token Healing
Guidance automatically "heals" token boundaries between prompt and generation.
**Problem:** Tokenization creates unnatural boundaries.
```python
# Without token healing
prompt = "The capital of France is "
# Last token: " is "
# First generated token might be " Par" (with leading space)
# Result: "The capital of France is Paris" (double space!)
```
**Solution:** Guidance backs up one token and regenerates.
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Token healing enabled by default
lm += "The capital of France is " + gen("capital", max_tokens=5)
# Result: "The capital of France is Paris" (correct spacing)
```
**Benefits:**
- Natural text boundaries
- No awkward spacing issues
- Better model performance (sees natural token sequences)
### 4. Grammar-Based Generation
Define complex structures using context-free grammars.
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# JSON grammar (simplified)
json_grammar = """
{
"name": <gen name regex="[A-Za-z ]+" max_tokens=20>,
"age": <gen age regex="[0-9]+" max_tokens=3>,
"email": <gen email regex="[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}" max_tokens=50>
}
"""
# Generate valid JSON
lm += gen("person", grammar=json_grammar)
print(lm["person"]) # Guaranteed valid JSON structure
```
**Use cases:**
- Complex structured outputs
- Nested data structures
- Programming language syntax
- Domain-specific languages
### 5. Guidance Functions
Create reusable generation patterns with the `@guidance` decorator.
```python
from guidance import guidance, gen, models
@guidance
def generate_person(lm):
"""Generate a person with name and age."""
lm += "Name: " + gen("name", max_tokens=20, stop="\n")
lm += "\nAge: " + gen("age", regex=r"[0-9]+", max_tokens=3)
return lm
# Use the function
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_person(lm)
print(lm["name"])
print(lm["age"])
```
**Stateful Functions:**
```python
@guidance(stateless=False)
def react_agent(lm, question, tools, max_rounds=5):
"""ReAct agent with tool use."""
lm += f"Question: {question}\n\n"
for i in range(max_rounds):
# Thought
lm += f"Thought {i+1}: " + gen("thought", stop="\n")
# Action
lm += "\nAction: " + select(list(tools.keys()), name="action")
# Execute tool
tool_result = tools[lm["action"]]()
lm += f"\nObservation: {tool_result}\n\n"
# Check if done
lm += "Done? " + select(["Yes", "No"], name="done")
if lm["done"] == "Yes":
break
# Final answer
lm += "\nFinal Answer: " + gen("answer", max_tokens=100)
return lm
```
## Backend Configuration
### Anthropic Claude
```python
from guidance import models
lm = models.Anthropic(
model="claude-sonnet-4-5-20250929",
api_key="your-api-key" # Or set ANTHROPIC_API_KEY env var
)
```
### OpenAI
```python
lm = models.OpenAI(
model="gpt-4o-mini",
api_key="your-api-key" # Or set OPENAI_API_KEY env var
)
```
### Local Models (Transformers)
```python
from guidance.models import Transformers
lm = Transformers(
"microsoft/Phi-4-mini-instruct",
device="cuda" # Or "cpu"
)
```
### Local Models (llama.cpp)
```python
from guidance.models import LlamaCpp
lm = LlamaCpp(
model_path="/path/to/model.gguf",
n_ctx=4096,
n_gpu_layers=35
)
```
## Common Patterns
### Pattern 1: JSON Generation
```python
from guidance import models, gen, system, user, assistant
lm = models.Anthropic("claude-sonnet-4-5-20250929")
with system():
lm += "You generate valid JSON."
with user():
lm += "Generate a user profile with name, age, and email."
with assistant():
lm += """{
"name": """ + gen("name", regex=r'"[A-Za-z ]+"', max_tokens=30) + """,
"age": """ + gen("age", regex=r"[0-9]+", max_tokens=3) + """,
"email": """ + gen("email", regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"', max_tokens=50) + """
}"""
print(lm) # Valid JSON guaranteed
```
### Pattern 2: Classification
```python
from guidance import models, gen, select
lm = models.Anthropic("claude-sonnet-4-5-20250929")
text = "This product is amazing! I love it."
lm += f"Text: {text}\n"
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
lm += "\nConfidence: " + gen("confidence", regex=r"[0-9]+", max_tokens=3) + "%"
print(f"Sentiment: {lm['sentiment']}")
print(f"Confidence: {lm['confidence']}%")
```
### Pattern 3: Multi-Step Reasoning
```python
from guidance import models, gen, guidance
@guidance
def chain_of_thought(lm, question):
"""Generate answer with step-by-step reasoning."""
lm += f"Question: {question}\n\n"
# Generate multiple reasoning steps
for i in range(3):
lm += f"Step {i+1}: " + gen(f"step_{i+1}", stop="\n", max_tokens=100) + "\n"
# Final answer
lm += "\nTherefore, the answer is: " + gen("answer", max_tokens=50)
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = chain_of_thought(lm, "What is 15% of 200?")
print(lm["answer"])
```
### Pattern 4: ReAct Agent
```python
from guidance import models, gen, select, guidance
@guidance(stateless=False)
def react_agent(lm, question):
"""ReAct agent with tool use."""
tools = {
"calculator": lambda expr: eval(expr),
"search": lambda query: f"Search results for: {query}",
}
lm += f"Question: {question}\n\n"
for round in range(5):
# Thought
lm += f"Thought: " + gen("thought", stop="\n") + "\n"
# Action selection
lm += "Action: " + select(["calculator", "search", "answer"], name="action")
if lm["action"] == "answer":
lm += "\nFinal Answer: " + gen("answer", max_tokens=100)
break
# Action input
lm += "\nAction Input: " + gen("action_input", stop="\n") + "\n"
# Execute tool
if lm["action"] in tools:
result = tools[lm["action"]](lm["action_input"])
lm += f"Observation: {result}\n\n"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = react_agent(lm, "What is 25 * 4 + 10?")
print(lm["answer"])
```
### Pattern 5: Data Extraction
```python
from guidance import models, gen, guidance
@guidance
def extract_entities(lm, text):
"""Extract structured entities from text."""
lm += f"Text: {text}\n\n"
# Extract person
lm += "Person: " + gen("person", stop="\n", max_tokens=30) + "\n"
# Extract organization
lm += "Organization: " + gen("organization", stop="\n", max_tokens=30) + "\n"
# Extract date
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}", max_tokens=10) + "\n"
# Extract location
lm += "Location: " + gen("location", stop="\n", max_tokens=30) + "\n"
return lm
text = "Tim Cook announced at Apple Park on 2024-09-15 in Cupertino."
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = extract_entities(lm, text)
print(f"Person: {lm['person']}")
print(f"Organization: {lm['organization']}")
print(f"Date: {lm['date']}")
print(f"Location: {lm['location']}")
```
## Best Practices
### 1. Use Regex for Format Validation
```python
# ✅ Good: Regex ensures valid format
lm += "Email: " + gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
# ❌ Bad: Free generation may produce invalid emails
lm += "Email: " + gen("email", max_tokens=50)
```
### 2. Use select() for Fixed Categories
```python
# ✅ Good: Guaranteed valid category
lm += "Status: " + select(["pending", "approved", "rejected"], name="status")
# ❌ Bad: May generate typos or invalid values
lm += "Status: " + gen("status", max_tokens=20)
```
### 3. Leverage Token Healing
```python
# Token healing is enabled by default
# No special action needed - just concatenate naturally
lm += "The capital is " + gen("capital") # Automatic healing
```
### 4. Use stop Sequences
```python
# ✅ Good: Stop at newline for single-line outputs
lm += "Name: " + gen("name", stop="\n")
# ❌ Bad: May generate multiple lines
lm += "Name: " + gen("name", max_tokens=50)
```
### 5. Create Reusable Functions
```python
# ✅ Good: Reusable pattern
@guidance
def generate_person(lm):
lm += "Name: " + gen("name", stop="\n")
lm += "\nAge: " + gen("age", regex=r"[0-9]+")
return lm
# Use multiple times
lm = generate_person(lm)
lm += "\n\n"
lm = generate_person(lm)
```
### 6. Balance Constraints
```python
# ✅ Good: Reasonable constraints
lm += gen("name", regex=r"[A-Za-z ]+", max_tokens=30)
# ❌ Too strict: May fail or be very slow
lm += gen("name", regex=r"^(John|Jane)$", max_tokens=10)
```
## Comparison to Alternatives
| Feature | Guidance | Instructor | Outlines | LMQL |
|---------|----------|------------|----------|------|
| Regex Constraints | ✅ Yes | ❌ No | ✅ Yes | ✅ Yes |
| Grammar Support | ✅ CFG | ❌ No | ✅ CFG | ✅ CFG |
| Pydantic Validation | ❌ No | ✅ Yes | ✅ Yes | ❌ No |
| Token Healing | ✅ Yes | ❌ No | ✅ Yes | ❌ No |
| Local Models | ✅ Yes | ⚠️ Limited | ✅ Yes | ✅ Yes |
| API Models | ✅ Yes | ✅ Yes | ⚠️ Limited | ✅ Yes |
| Pythonic Syntax | ✅ Yes | ✅ Yes | ✅ Yes | ❌ SQL-like |
| Learning Curve | Low | Low | Medium | High |
**When to choose Guidance:**
- Need regex/grammar constraints
- Want token healing
- Building complex workflows with control flow
- Using local models (Transformers, llama.cpp)
- Prefer Pythonic syntax
**When to choose alternatives:**
- Instructor: Need Pydantic validation with automatic retrying
- Outlines: Need JSON schema validation
- LMQL: Prefer declarative query syntax
## Performance Characteristics
**Latency Reduction:**
- 30-50% faster than traditional prompting for constrained outputs
- Token healing reduces unnecessary regeneration
- Grammar constraints prevent invalid token generation
**Memory Usage:**
- Minimal overhead vs unconstrained generation
- Grammar compilation cached after first use
- Efficient token filtering at inference time
**Token Efficiency:**
- Prevents wasted tokens on invalid outputs
- No need for retry loops
- Direct path to valid outputs
## Resources
- **Documentation**: https://guidance.readthedocs.io
- **GitHub**: https://github.com/guidance-ai/guidance (18k+ stars)
- **Notebooks**: https://github.com/guidance-ai/guidance/tree/main/notebooks
- **Discord**: Community support available
## See Also
- `references/constraints.md` - Comprehensive regex and grammar patterns
- `references/backends.md` - Backend-specific configuration
- `references/examples.md` - Production-ready examples

View File

@@ -0,0 +1,554 @@
# Backend Configuration Guide
Complete guide to configuring Guidance with different LLM backends.
## Table of Contents
- API-Based Models (Anthropic, OpenAI)
- Local Models (Transformers, llama.cpp)
- Backend Comparison
- Performance Tuning
- Advanced Configuration
## API-Based Models
### Anthropic Claude
#### Basic Setup
```python
from guidance import models
# Using environment variable
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Reads ANTHROPIC_API_KEY from environment
# Explicit API key
lm = models.Anthropic(
model="claude-sonnet-4-5-20250929",
api_key="your-api-key-here"
)
```
#### Available Models
```python
# Claude 3.5 Sonnet (Latest, recommended)
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Claude 3.7 Sonnet (Fast, cost-effective)
lm = models.Anthropic("claude-sonnet-3.7-20250219")
# Claude 3 Opus (Most capable)
lm = models.Anthropic("claude-3-opus-20240229")
# Claude 3.5 Haiku (Fastest, cheapest)
lm = models.Anthropic("claude-3-5-haiku-20241022")
```
#### Configuration Options
```python
lm = models.Anthropic(
model="claude-sonnet-4-5-20250929",
api_key="your-api-key",
max_tokens=4096, # Max tokens to generate
temperature=0.7, # Sampling temperature (0-1)
top_p=0.9, # Nucleus sampling
timeout=30, # Request timeout (seconds)
max_retries=3 # Retry failed requests
)
```
#### With Context Managers
```python
from guidance import models, system, user, assistant, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
with system():
lm += "You are a helpful assistant."
with user():
lm += "What is the capital of France?"
with assistant():
lm += gen(max_tokens=50)
print(lm)
```
### OpenAI
#### Basic Setup
```python
from guidance import models
# Using environment variable
lm = models.OpenAI("gpt-4o")
# Reads OPENAI_API_KEY from environment
# Explicit API key
lm = models.OpenAI(
model="gpt-4o",
api_key="your-api-key-here"
)
```
#### Available Models
```python
# GPT-4o (Latest, multimodal)
lm = models.OpenAI("gpt-4o")
# GPT-4o Mini (Fast, cost-effective)
lm = models.OpenAI("gpt-4o-mini")
# GPT-4 Turbo
lm = models.OpenAI("gpt-4-turbo")
# GPT-3.5 Turbo (Cheapest)
lm = models.OpenAI("gpt-3.5-turbo")
```
#### Configuration Options
```python
lm = models.OpenAI(
model="gpt-4o-mini",
api_key="your-api-key",
max_tokens=2048,
temperature=0.7,
top_p=1.0,
frequency_penalty=0.0,
presence_penalty=0.0,
timeout=30
)
```
#### Chat Format
```python
from guidance import models, gen
lm = models.OpenAI("gpt-4o-mini")
# OpenAI uses chat format
lm += [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is 2+2?"}
]
# Generate response
lm += gen(max_tokens=50)
```
### Azure OpenAI
```python
from guidance import models
lm = models.AzureOpenAI(
model="gpt-4o",
azure_endpoint="https://your-resource.openai.azure.com/",
api_key="your-azure-api-key",
api_version="2024-02-15-preview",
deployment_name="your-deployment-name"
)
```
## Local Models
### Transformers (Hugging Face)
#### Basic Setup
```python
from guidance.models import Transformers
# Load model from Hugging Face
lm = Transformers("microsoft/Phi-4-mini-instruct")
```
#### GPU Configuration
```python
# Use GPU
lm = Transformers(
"microsoft/Phi-4-mini-instruct",
device="cuda"
)
# Use specific GPU
lm = Transformers(
"microsoft/Phi-4-mini-instruct",
device="cuda:0" # GPU 0
)
# Use CPU
lm = Transformers(
"microsoft/Phi-4-mini-instruct",
device="cpu"
)
```
#### Advanced Configuration
```python
lm = Transformers(
"microsoft/Phi-4-mini-instruct",
device="cuda",
torch_dtype="float16", # Use FP16 (faster, less memory)
load_in_8bit=True, # 8-bit quantization
max_memory={0: "20GB"}, # GPU memory limit
offload_folder="./offload" # Offload to disk if needed
)
```
#### Popular Models
```python
# Phi-4 (Microsoft)
lm = Transformers("microsoft/Phi-4-mini-instruct")
lm = Transformers("microsoft/Phi-3-medium-4k-instruct")
# Llama 3 (Meta)
lm = Transformers("meta-llama/Llama-3.1-8B-Instruct")
lm = Transformers("meta-llama/Llama-3.1-70B-Instruct")
# Mistral (Mistral AI)
lm = Transformers("mistralai/Mistral-7B-Instruct-v0.3")
lm = Transformers("mistralai/Mixtral-8x7B-Instruct-v0.1")
# Qwen (Alibaba)
lm = Transformers("Qwen/Qwen2.5-7B-Instruct")
# Gemma (Google)
lm = Transformers("google/gemma-2-9b-it")
```
#### Generation Configuration
```python
lm = Transformers(
"microsoft/Phi-4-mini-instruct",
device="cuda"
)
# Configure generation
from guidance import gen
result = lm + gen(
max_tokens=100,
temperature=0.7,
top_p=0.9,
top_k=50,
repetition_penalty=1.1
)
```
### llama.cpp
#### Basic Setup
```python
from guidance.models import LlamaCpp
# Load GGUF model
lm = LlamaCpp(
model_path="/path/to/model.gguf",
n_ctx=4096 # Context window
)
```
#### GPU Configuration
```python
# Use GPU acceleration
lm = LlamaCpp(
model_path="/path/to/model.gguf",
n_ctx=4096,
n_gpu_layers=35, # Offload 35 layers to GPU
n_threads=8 # CPU threads for remaining layers
)
# Full GPU offload
lm = LlamaCpp(
model_path="/path/to/model.gguf",
n_ctx=4096,
n_gpu_layers=-1 # Offload all layers
)
```
#### Advanced Configuration
```python
lm = LlamaCpp(
model_path="/path/to/llama-3.1-8b-instruct.Q4_K_M.gguf",
n_ctx=8192, # Context window (tokens)
n_gpu_layers=35, # GPU layers
n_threads=8, # CPU threads
n_batch=512, # Batch size for prompt processing
use_mmap=True, # Memory-map the model file
use_mlock=False, # Lock model in RAM
seed=42, # Random seed
verbose=False # Suppress verbose output
)
```
#### Quantized Models
```python
# Q4_K_M (4-bit, recommended for most cases)
lm = LlamaCpp("/path/to/model.Q4_K_M.gguf")
# Q5_K_M (5-bit, better quality)
lm = LlamaCpp("/path/to/model.Q5_K_M.gguf")
# Q8_0 (8-bit, high quality)
lm = LlamaCpp("/path/to/model.Q8_0.gguf")
# F16 (16-bit float, highest quality)
lm = LlamaCpp("/path/to/model.F16.gguf")
```
#### Popular GGUF Models
```python
# Llama 3.1
lm = LlamaCpp("llama-3.1-8b-instruct.Q4_K_M.gguf")
# Mistral
lm = LlamaCpp("mistral-7b-instruct-v0.3.Q4_K_M.gguf")
# Phi-4
lm = LlamaCpp("phi-4-mini-instruct.Q4_K_M.gguf")
```
## Backend Comparison
### Feature Matrix
| Feature | Anthropic | OpenAI | Transformers | llama.cpp |
|---------|-----------|--------|--------------|-----------|
| Constrained Generation | ✅ Full | ✅ Full | ✅ Full | ✅ Full |
| Token Healing | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |
| Streaming | ✅ Yes | ✅ Yes | ✅ Yes | ✅ Yes |
| GPU Support | N/A | N/A | ✅ Yes | ✅ Yes |
| Quantization | N/A | N/A | ✅ Yes | ✅ Yes |
| Cost | $$$ | $$$ | Free | Free |
| Latency | Low | Low | Medium | Low |
| Setup Difficulty | Easy | Easy | Medium | Medium |
### Performance Characteristics
**Anthropic Claude:**
- **Latency**: 200-500ms (API call)
- **Throughput**: Limited by API rate limits
- **Cost**: $3-15 per 1M input tokens
- **Best for**: Production systems, high-quality outputs
**OpenAI:**
- **Latency**: 200-400ms (API call)
- **Throughput**: Limited by API rate limits
- **Cost**: $0.15-30 per 1M input tokens
- **Best for**: Cost-sensitive production, gpt-4o-mini
**Transformers:**
- **Latency**: 50-200ms (local inference)
- **Throughput**: GPU-dependent (10-100 tokens/sec)
- **Cost**: Hardware cost only
- **Best for**: Privacy-sensitive, high-volume, experimentation
**llama.cpp:**
- **Latency**: 30-150ms (local inference)
- **Throughput**: Hardware-dependent (20-150 tokens/sec)
- **Cost**: Hardware cost only
- **Best for**: Edge deployment, Apple Silicon, CPU inference
### Memory Requirements
**Transformers (FP16):**
- 7B model: ~14GB GPU VRAM
- 13B model: ~26GB GPU VRAM
- 70B model: ~140GB GPU VRAM (multi-GPU)
**llama.cpp (Q4_K_M):**
- 7B model: ~4.5GB RAM
- 13B model: ~8GB RAM
- 70B model: ~40GB RAM
**Optimization Tips:**
- Use quantized models (Q4_K_M) for lower memory
- Use GPU offloading for faster inference
- Use CPU inference for smaller models (<7B)
## Performance Tuning
### API Models (Anthropic, OpenAI)
#### Reduce Latency
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Use lower max_tokens (faster response)
lm += gen(max_tokens=100) # Instead of 1000
# Use streaming (perceived latency reduction)
for chunk in lm.stream(gen(max_tokens=500)):
print(chunk, end="", flush=True)
```
#### Reduce Cost
```python
# Use cheaper models
lm = models.Anthropic("claude-3-5-haiku-20241022") # vs Sonnet
lm = models.OpenAI("gpt-4o-mini") # vs gpt-4o
# Reduce context size
# - Keep prompts concise
# - Avoid large few-shot examples
# - Use max_tokens limits
```
### Local Models (Transformers, llama.cpp)
#### Optimize GPU Usage
```python
from guidance.models import Transformers
# Use FP16 for 2x speedup
lm = Transformers(
"meta-llama/Llama-3.1-8B-Instruct",
device="cuda",
torch_dtype="float16"
)
# Use 8-bit quantization for 4x memory reduction
lm = Transformers(
"meta-llama/Llama-3.1-8B-Instruct",
device="cuda",
load_in_8bit=True
)
# Use flash attention (requires flash-attn package)
lm = Transformers(
"meta-llama/Llama-3.1-8B-Instruct",
device="cuda",
use_flash_attention_2=True
)
```
#### Optimize llama.cpp
```python
from guidance.models import LlamaCpp
# Maximize GPU layers
lm = LlamaCpp(
model_path="/path/to/model.Q4_K_M.gguf",
n_gpu_layers=-1 # All layers on GPU
)
# Optimize batch size
lm = LlamaCpp(
model_path="/path/to/model.Q4_K_M.gguf",
n_batch=512, # Larger batch = faster prompt processing
n_gpu_layers=-1
)
# Use Metal (Apple Silicon)
lm = LlamaCpp(
model_path="/path/to/model.Q4_K_M.gguf",
n_gpu_layers=-1, # Use Metal GPU acceleration
use_mmap=True
)
```
#### Batch Processing
```python
# Process multiple requests efficiently
requests = [
"What is 2+2?",
"What is the capital of France?",
"What is photosynthesis?"
]
# Bad: Sequential processing
for req in requests:
lm = Transformers("microsoft/Phi-4-mini-instruct")
lm += req + gen(max_tokens=50)
# Good: Reuse loaded model
lm = Transformers("microsoft/Phi-4-mini-instruct")
for req in requests:
lm += req + gen(max_tokens=50)
```
## Advanced Configuration
### Custom Model Configurations
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from guidance.models import Transformers
# Load custom model
tokenizer = AutoTokenizer.from_pretrained("your-model")
model = AutoModelForCausalLM.from_pretrained(
"your-model",
device_map="auto",
torch_dtype="float16"
)
# Use with Guidance
lm = Transformers(model=model, tokenizer=tokenizer)
```
### Environment Variables
```bash
# API keys
export ANTHROPIC_API_KEY="sk-ant-..."
export OPENAI_API_KEY="sk-..."
# Transformers cache
export HF_HOME="/path/to/cache"
export TRANSFORMERS_CACHE="/path/to/cache"
# GPU selection
export CUDA_VISIBLE_DEVICES=0,1 # Use GPU 0 and 1
```
### Debugging
```python
# Enable verbose logging
import logging
logging.basicConfig(level=logging.DEBUG)
# Check backend info
lm = models.Anthropic("claude-sonnet-4-5-20250929")
print(f"Model: {lm.model_name}")
print(f"Backend: {lm.backend}")
# Check GPU usage (Transformers)
lm = Transformers("microsoft/Phi-4-mini-instruct", device="cuda")
print(f"Device: {lm.device}")
print(f"Memory allocated: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
```
## Resources
- **Anthropic Docs**: https://docs.anthropic.com
- **OpenAI Docs**: https://platform.openai.com/docs
- **Hugging Face Models**: https://huggingface.co/models
- **llama.cpp**: https://github.com/ggerganov/llama.cpp
- **GGUF Models**: https://huggingface.co/models?library=gguf

View File

@@ -0,0 +1,674 @@
# Comprehensive Constraint Patterns
Guide to regex constraints, grammar-based generation, and token healing in Guidance.
## Table of Contents
- Regex Constraints
- Grammar-Based Generation
- Token Healing
- Selection Constraints
- Complex Patterns
- Performance Optimization
## Regex Constraints
### Basic Patterns
#### Numeric Constraints
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Integer (positive)
lm += "Age: " + gen("age", regex=r"[0-9]+")
# Integer (with negatives)
lm += "Temperature: " + gen("temp", regex=r"-?[0-9]+")
# Float (positive)
lm += "Price: $" + gen("price", regex=r"[0-9]+\.[0-9]{2}")
# Float (with negatives and optional decimals)
lm += "Value: " + gen("value", regex=r"-?[0-9]+(\.[0-9]+)?")
# Percentage (0-100)
lm += "Progress: " + gen("progress", regex=r"(100|[0-9]{1,2})")
# Range (1-5 stars)
lm += "Rating: " + gen("rating", regex=r"[1-5]") + " stars"
```
#### Text Constraints
```python
# Alphabetic only
lm += "Name: " + gen("name", regex=r"[A-Za-z]+")
# Alphabetic with spaces
lm += "Full Name: " + gen("full_name", regex=r"[A-Za-z ]+")
# Alphanumeric
lm += "Username: " + gen("username", regex=r"[A-Za-z0-9_]+")
# Capitalized words
lm += "Title: " + gen("title", regex=r"[A-Z][a-z]+( [A-Z][a-z]+)*")
# Lowercase only
lm += "Code: " + gen("code", regex=r"[a-z0-9-]+")
# Specific length
lm += "ID: " + gen("id", regex=r"[A-Z]{3}-[0-9]{6}") # e.g., "ABC-123456"
```
#### Date and Time Constraints
```python
# Date (YYYY-MM-DD)
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}")
# Date (MM/DD/YYYY)
lm += "Date: " + gen("date_us", regex=r"\d{2}/\d{2}/\d{4}")
# Time (HH:MM)
lm += "Time: " + gen("time", regex=r"\d{2}:\d{2}")
# Time (HH:MM:SS)
lm += "Time: " + gen("time_full", regex=r"\d{2}:\d{2}:\d{2}")
# ISO 8601 datetime
lm += "Timestamp: " + gen(
"timestamp",
regex=r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z"
)
# Year (YYYY)
lm += "Year: " + gen("year", regex=r"(19|20)\d{2}")
# Month name
lm += "Month: " + gen(
"month",
regex=r"(January|February|March|April|May|June|July|August|September|October|November|December)"
)
```
#### Contact Information
```python
# Email
lm += "Email: " + gen(
"email",
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
)
# Phone (US format)
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}")
# Phone (international format)
lm += "Phone: " + gen("phone_intl", regex=r"\+[0-9]{1,3}-[0-9]{1,14}")
# ZIP code (US)
lm += "ZIP: " + gen("zip", regex=r"\d{5}(-\d{4})?")
# Postal code (Canada)
lm += "Postal: " + gen("postal", regex=r"[A-Z]\d[A-Z] \d[A-Z]\d")
# URL
lm += "URL: " + gen(
"url",
regex=r"https?://[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}(/[a-zA-Z0-9._~:/?#\[\]@!$&'()*+,;=-]*)?"
)
```
### Advanced Patterns
#### JSON Field Constraints
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# String field with quotes
lm += '"name": ' + gen("name", regex=r'"[A-Za-z ]+"')
# Numeric field (no quotes)
lm += '"age": ' + gen("age", regex=r"[0-9]+")
# Boolean field
lm += '"active": ' + gen("active", regex=r"(true|false)")
# Null field
lm += '"optional": ' + gen("optional", regex=r"(null|[0-9]+)")
# Array of strings
lm += '"tags": [' + gen(
"tags",
regex=r'"[a-z]+"(, "[a-z]+")*'
) + ']'
# Complete JSON object
lm += """{
"name": """ + gen("name", regex=r'"[A-Za-z ]+"') + """,
"age": """ + gen("age", regex=r"[0-9]+") + """,
"email": """ + gen(
"email",
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
) + """
}"""
```
#### Code Patterns
```python
# Python variable name
lm += "Variable: " + gen("var", regex=r"[a-z_][a-z0-9_]*")
# Python function name
lm += "Function: " + gen("func", regex=r"[a-z_][a-z0-9_]*")
# Hex color code
lm += "Color: #" + gen("color", regex=r"[0-9A-Fa-f]{6}")
# UUID
lm += "UUID: " + gen(
"uuid",
regex=r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
)
# Git commit hash (short)
lm += "Commit: " + gen("commit", regex=r"[0-9a-f]{7}")
# Semantic version
lm += "Version: " + gen("version", regex=r"[0-9]+\.[0-9]+\.[0-9]+")
# IP address (IPv4)
lm += "IP: " + gen(
"ip",
regex=r"((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)"
)
```
#### Domain-Specific Patterns
```python
# Credit card number
lm += "Card: " + gen("card", regex=r"\d{4}-\d{4}-\d{4}-\d{4}")
# Social Security Number (US)
lm += "SSN: " + gen("ssn", regex=r"\d{3}-\d{2}-\d{4}")
# ISBN-13
lm += "ISBN: " + gen("isbn", regex=r"978-\d{1,5}-\d{1,7}-\d{1,7}-\d")
# License plate (US)
lm += "Plate: " + gen("plate", regex=r"[A-Z]{3}-\d{4}")
# Currency amount
lm += "Amount: $" + gen("amount", regex=r"[0-9]{1,3}(,[0-9]{3})*\.[0-9]{2}")
# Percentage with decimal
lm += "Rate: " + gen("rate", regex=r"[0-9]+\.[0-9]{1,2}%")
```
## Grammar-Based Generation
### JSON Grammar
```python
from guidance import models, gen, guidance
@guidance
def json_object(lm):
"""Generate valid JSON object."""
lm += "{\n"
# Name field (required)
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
# Age field (required)
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + ",\n"
# Email field (required)
lm += ' "email": ' + gen(
"email",
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
) + ",\n"
# Active field (required, boolean)
lm += ' "active": ' + gen("active", regex=r"(true|false)") + "\n"
lm += "}"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = json_object(lm)
print(lm) # Valid JSON guaranteed
```
### Nested JSON Grammar
```python
@guidance
def nested_json(lm):
"""Generate nested JSON structure."""
lm += "{\n"
# User object
lm += ' "user": {\n'
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + "\n"
lm += " },\n"
# Address object
lm += ' "address": {\n'
lm += ' "street": ' + gen("street", regex=r'"[A-Za-z0-9 ]+"') + ",\n"
lm += ' "city": ' + gen("city", regex=r'"[A-Za-z ]+"') + ",\n"
lm += ' "zip": ' + gen("zip", regex=r'"\d{5}"') + "\n"
lm += " }\n"
lm += "}"
return lm
```
### Array Grammar
```python
@guidance
def json_array(lm, count=3):
"""Generate JSON array with fixed count."""
lm += "[\n"
for i in range(count):
lm += " {\n"
lm += ' "id": ' + gen(f"id_{i}", regex=r"[0-9]+") + ",\n"
lm += ' "name": ' + gen(f"name_{i}", regex=r'"[A-Za-z ]+"') + "\n"
lm += " }"
if i < count - 1:
lm += ","
lm += "\n"
lm += "]"
return lm
```
### XML Grammar
```python
@guidance
def xml_document(lm):
"""Generate valid XML document."""
lm += '<?xml version="1.0"?>\n'
lm += "<person>\n"
# Name element
lm += " <name>" + gen("name", regex=r"[A-Za-z ]+") + "</name>\n"
# Age element
lm += " <age>" + gen("age", regex=r"[0-9]+") + "</age>\n"
# Email element
lm += " <email>" + gen(
"email",
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"
) + "</email>\n"
lm += "</person>"
return lm
```
### CSV Grammar
```python
@guidance
def csv_row(lm):
"""Generate CSV row."""
lm += gen("name", regex=r"[A-Za-z ]+") + ","
lm += gen("age", regex=r"[0-9]+") + ","
lm += gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
return lm
@guidance
def csv_document(lm, rows=5):
"""Generate complete CSV."""
# Header
lm += "Name,Age,Email\n"
# Rows
for i in range(rows):
lm = csv_row(lm)
if i < rows - 1:
lm += "\n"
return lm
```
## Token Healing
### How Token Healing Works
**Problem:** Tokenization creates unnatural boundaries.
```python
# Example without token healing
prompt = "The capital of France is "
# Tokenization: ["The", " capital", " of", " France", " is", " "]
# Model sees last token: " "
# First generated token might include leading space: " Paris"
# Result: "The capital of France is Paris" (double space)
```
**Solution:** Guidance backs up and regenerates the last token.
```python
from guidance import models, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Token healing enabled by default
lm += "The capital of France is " + gen("capital", max_tokens=5)
# Process:
# 1. Back up to token before " is "
# 2. Regenerate " is" + "capital" together
# 3. Result: "The capital of France is Paris" (correct)
```
### Token Healing Examples
#### Natural Continuations
```python
# Before token healing
lm += "The function name is get" + gen("rest")
# Might generate: "The function name is get User" (space before User)
# With token healing
lm += "The function name is get" + gen("rest")
# Generates: "The function name is getUser" (correct camelCase)
```
#### Code Generation
```python
# Function name completion
lm += "def calculate_" + gen("rest", stop="(")
# Token healing ensures smooth connection: "calculate_total"
# Variable name completion
lm += "my_" + gen("var_name", regex=r"[a-z_]+")
# Token healing ensures: "my_variable_name" (not "my_ variable_name")
```
#### Domain-Specific Terms
```python
# Medical terms
lm += "The patient has hyper" + gen("condition")
# Token healing helps: "hypertension" (not "hyper tension")
# Technical terms
lm += "Using micro" + gen("tech")
# Token healing helps: "microservices" (not "micro services")
```
### Disabling Token Healing
```python
# Disable token healing if needed (rare)
lm += gen("text", token_healing=False)
```
## Selection Constraints
### Basic Selection
```python
from guidance import models, select
lm = models.Anthropic("claude-sonnet-4-5-20250929")
# Simple selection
lm += "Status: " + select(["active", "inactive", "pending"], name="status")
# Boolean selection
lm += "Approved: " + select(["Yes", "No"], name="approved")
# Multiple choice
lm += "Answer: " + select(
["A) Paris", "B) London", "C) Berlin", "D) Madrid"],
name="answer"
)
```
### Conditional Selection
```python
from guidance import models, select, gen, guidance
@guidance
def conditional_fields(lm):
"""Generate fields conditionally based on type."""
lm += "Type: " + select(["person", "company"], name="type")
if lm["type"] == "person":
lm += "\nName: " + gen("name", regex=r"[A-Za-z ]+")
lm += "\nAge: " + gen("age", regex=r"[0-9]+")
else:
lm += "\nCompany Name: " + gen("company", regex=r"[A-Za-z ]+")
lm += "\nEmployees: " + gen("employees", regex=r"[0-9]+")
return lm
```
### Repeated Selection
```python
@guidance
def multiple_selections(lm):
"""Select multiple items."""
lm += "Select 3 colors:\n"
colors = ["red", "blue", "green", "yellow", "purple"]
for i in range(3):
lm += f"{i+1}. " + select(colors, name=f"color_{i}") + "\n"
return lm
```
## Complex Patterns
### Pattern 1: Structured Forms
```python
@guidance
def user_form(lm):
"""Generate structured user form."""
lm += "=== User Registration ===\n\n"
# Name (alphabetic only)
lm += "Full Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n") + "\n"
# Age (numeric)
lm += "Age: " + gen("age", regex=r"[0-9]+", max_tokens=3) + "\n"
# Email (validated format)
lm += "Email: " + gen(
"email",
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
stop="\n"
) + "\n"
# Phone (US format)
lm += "Phone: " + gen("phone", regex=r"\d{3}-\d{3}-\d{4}") + "\n"
# Account type (selection)
lm += "Account Type: " + select(
["Standard", "Premium", "Enterprise"],
name="account_type"
) + "\n"
# Active status (boolean)
lm += "Active: " + select(["Yes", "No"], name="active") + "\n"
return lm
```
### Pattern 2: Multi-Entity Extraction
```python
@guidance
def extract_entities(lm, text):
"""Extract multiple entities with constraints."""
lm += f"Text: {text}\n\n"
# Person name (alphabetic)
lm += "Person: " + gen("person", regex=r"[A-Za-z ]+", stop="\n") + "\n"
# Organization (alphanumeric with spaces)
lm += "Organization: " + gen(
"organization",
regex=r"[A-Za-z0-9 ]+",
stop="\n"
) + "\n"
# Date (YYYY-MM-DD format)
lm += "Date: " + gen("date", regex=r"\d{4}-\d{2}-\d{2}") + "\n"
# Location (alphabetic with spaces)
lm += "Location: " + gen("location", regex=r"[A-Za-z ]+", stop="\n") + "\n"
# Amount (currency)
lm += "Amount: $" + gen("amount", regex=r"[0-9,]+\.[0-9]{2}") + "\n"
return lm
```
### Pattern 3: Code Generation
```python
@guidance
def generate_python_function(lm):
"""Generate Python function with constraints."""
# Function name (valid Python identifier)
lm += "def " + gen("func_name", regex=r"[a-z_][a-z0-9_]*") + "("
# Parameter name
lm += gen("param", regex=r"[a-z_][a-z0-9_]*") + "):\n"
# Docstring
lm += ' """' + gen("docstring", stop='"""', max_tokens=50) + '"""\n'
# Function body (constrained to valid Python)
lm += " return " + gen("return_value", stop="\n") + "\n"
return lm
```
### Pattern 4: Hierarchical Data
```python
@guidance
def org_chart(lm):
"""Generate organizational chart."""
lm += "Company: " + gen("company", regex=r"[A-Za-z ]+") + "\n\n"
# CEO
lm += "CEO: " + gen("ceo", regex=r"[A-Za-z ]+") + "\n"
# Departments
for dept in ["Engineering", "Sales", "Marketing"]:
lm += f"\n{dept} Department:\n"
lm += " Head: " + gen(f"{dept.lower()}_head", regex=r"[A-Za-z ]+") + "\n"
lm += " Size: " + gen(f"{dept.lower()}_size", regex=r"[0-9]+") + " employees\n"
return lm
```
## Performance Optimization
### Best Practices
#### 1. Use Specific Patterns
```python
# ✅ Good: Specific pattern
lm += gen("age", regex=r"[0-9]{1,3}") # Fast
# ❌ Bad: Overly broad pattern
lm += gen("age", regex=r"[0-9]+") # Slower
```
#### 2. Limit Max Tokens
```python
# ✅ Good: Reasonable limit
lm += gen("name", max_tokens=30)
# ❌ Bad: No limit
lm += gen("name") # May generate forever
```
#### 3. Use stop Sequences
```python
# ✅ Good: Stop at newline
lm += gen("line", stop="\n")
# ❌ Bad: Rely on max_tokens
lm += gen("line", max_tokens=100)
```
#### 4. Cache Compiled Grammars
```python
# Grammars are cached automatically after first use
# No manual caching needed
@guidance
def reusable_pattern(lm):
"""This grammar is compiled once and cached."""
lm += gen("email", regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")
return lm
# First call: compiles grammar
lm = reusable_pattern(lm)
# Subsequent calls: uses cached grammar (fast)
lm = reusable_pattern(lm)
```
#### 5. Avoid Overlapping Constraints
```python
# ✅ Good: Clear constraints
lm += gen("age", regex=r"[0-9]+", max_tokens=3)
# ❌ Bad: Conflicting constraints
lm += gen("age", regex=r"[0-9]{2}", max_tokens=10) # max_tokens unnecessary
```
### Performance Benchmarks
**Regex vs Free Generation:**
- Simple regex (digits): ~1.2x slower than free gen
- Complex regex (email): ~1.5x slower than free gen
- Grammar-based: ~2x slower than free gen
**But:**
- 100% valid outputs (vs ~70% with free gen + validation)
- No retry loops needed
- Overall faster end-to-end for structured outputs
**Optimization Tips:**
- Use regex for critical fields only
- Use `select()` for small fixed sets (fastest)
- Use `stop` sequences when possible (faster than max_tokens)
- Cache compiled grammars by reusing functions
## Resources
- **Token Healing Paper**: https://arxiv.org/abs/2306.17648
- **Guidance Docs**: https://guidance.readthedocs.io
- **GitHub**: https://github.com/guidance-ai/guidance

View File

@@ -0,0 +1,767 @@
# Production-Ready Examples
Real-world examples of using Guidance for structured generation, agents, and workflows.
## Table of Contents
- JSON Generation
- Data Extraction
- Classification Systems
- Agent Systems
- Multi-Step Workflows
- Code Generation
- Production Tips
## JSON Generation
### Basic JSON
```python
from guidance import models, gen, guidance
@guidance
def generate_user(lm):
"""Generate valid user JSON."""
lm += "{\n"
lm += ' "name": ' + gen("name", regex=r'"[A-Za-z ]+"') + ",\n"
lm += ' "age": ' + gen("age", regex=r"[0-9]+") + ",\n"
lm += ' "email": ' + gen(
"email",
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
) + "\n"
lm += "}"
return lm
# Use it
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm += "Generate a user profile:\n"
lm = generate_user(lm)
print(lm)
# Output: Valid JSON guaranteed
```
### Nested JSON
```python
@guidance
def generate_order(lm):
"""Generate nested order JSON."""
lm += "{\n"
# Customer info
lm += ' "customer": {\n'
lm += ' "name": ' + gen("customer_name", regex=r'"[A-Za-z ]+"') + ",\n"
lm += ' "email": ' + gen(
"customer_email",
regex=r'"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}"'
) + "\n"
lm += " },\n"
# Order details
lm += ' "order": {\n'
lm += ' "id": ' + gen("order_id", regex=r'"ORD-[0-9]{6}"') + ",\n"
lm += ' "date": ' + gen("order_date", regex=r'"\d{4}-\d{2}-\d{2}"') + ",\n"
lm += ' "total": ' + gen("order_total", regex=r"[0-9]+\.[0-9]{2}") + "\n"
lm += " },\n"
# Status
lm += ' "status": ' + gen(
"status",
regex=r'"(pending|processing|shipped|delivered)"'
) + "\n"
lm += "}"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_order(lm)
```
### JSON Array
```python
@guidance
def generate_user_list(lm, count=3):
"""Generate JSON array of users."""
lm += "[\n"
for i in range(count):
lm += " {\n"
lm += ' "id": ' + gen(f"id_{i}", regex=r"[0-9]+") + ",\n"
lm += ' "name": ' + gen(f"name_{i}", regex=r'"[A-Za-z ]+"') + ",\n"
lm += ' "active": ' + gen(f"active_{i}", regex=r"(true|false)") + "\n"
lm += " }"
if i < count - 1:
lm += ","
lm += "\n"
lm += "]"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_user_list(lm, count=5)
```
### Dynamic JSON Schema
```python
import json
from guidance import models, gen, guidance
@guidance
def json_from_schema(lm, schema):
"""Generate JSON matching a schema."""
lm += "{\n"
fields = list(schema["properties"].items())
for i, (field_name, field_schema) in enumerate(fields):
lm += f' "{field_name}": '
# Handle different types
if field_schema["type"] == "string":
if "pattern" in field_schema:
lm += gen(field_name, regex=f'"{field_schema["pattern"]}"')
else:
lm += gen(field_name, regex=r'"[^"]+"')
elif field_schema["type"] == "number":
lm += gen(field_name, regex=r"[0-9]+(\.[0-9]+)?")
elif field_schema["type"] == "integer":
lm += gen(field_name, regex=r"[0-9]+")
elif field_schema["type"] == "boolean":
lm += gen(field_name, regex=r"(true|false)")
if i < len(fields) - 1:
lm += ","
lm += "\n"
lm += "}"
return lm
# Define schema
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
"score": {"type": "number"},
"active": {"type": "boolean"}
}
}
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = json_from_schema(lm, schema)
```
## Data Extraction
### Extract from Text
```python
from guidance import models, gen, guidance, system, user, assistant
@guidance
def extract_person_info(lm, text):
"""Extract structured info from text."""
lm += f"Text: {text}\n\n"
with assistant():
lm += "Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n") + "\n"
lm += "Age: " + gen("age", regex=r"[0-9]+", max_tokens=3) + "\n"
lm += "Occupation: " + gen("occupation", regex=r"[A-Za-z ]+", stop="\n") + "\n"
lm += "Email: " + gen(
"email",
regex=r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
stop="\n"
) + "\n"
return lm
text = "John Smith is a 35-year-old software engineer. Contact: john@example.com"
lm = models.Anthropic("claude-sonnet-4-5-20250929")
with system():
lm += "You extract structured information from text."
with user():
lm = extract_person_info(lm, text)
print(f"Name: {lm['name']}")
print(f"Age: {lm['age']}")
print(f"Occupation: {lm['occupation']}")
print(f"Email: {lm['email']}")
```
### Multi-Entity Extraction
```python
@guidance
def extract_entities(lm, text):
"""Extract multiple entity types."""
lm += f"Analyze: {text}\n\n"
# Person entities
lm += "People:\n"
for i in range(3): # Up to 3 people
lm += f"- " + gen(f"person_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
# Organization entities
lm += "\nOrganizations:\n"
for i in range(2): # Up to 2 orgs
lm += f"- " + gen(f"org_{i}", regex=r"[A-Za-z0-9 ]+", stop="\n") + "\n"
# Dates
lm += "\nDates:\n"
for i in range(2): # Up to 2 dates
lm += f"- " + gen(f"date_{i}", regex=r"\d{4}-\d{2}-\d{2}", stop="\n") + "\n"
# Locations
lm += "\nLocations:\n"
for i in range(2): # Up to 2 locations
lm += f"- " + gen(f"location_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
return lm
text = """
Tim Cook and Satya Nadella met at Microsoft headquarters in Redmond on 2024-09-15
to discuss the collaboration between Apple and Microsoft. The meeting continued
in Cupertino on 2024-09-20.
"""
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = extract_entities(lm, text)
```
### Batch Extraction
```python
@guidance
def batch_extract(lm, texts):
"""Extract from multiple texts."""
lm += "Batch Extraction Results:\n\n"
for i, text in enumerate(texts):
lm += f"=== Item {i+1} ===\n"
lm += f"Text: {text}\n"
lm += "Name: " + gen(f"name_{i}", regex=r"[A-Za-z ]+", stop="\n") + "\n"
lm += "Sentiment: " + gen(
f"sentiment_{i}",
regex=r"(positive|negative|neutral)",
stop="\n"
) + "\n\n"
return lm
texts = [
"Alice is happy with the product",
"Bob is disappointed with the service",
"Carol has no strong feelings either way"
]
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = batch_extract(lm, texts)
```
## Classification Systems
### Sentiment Analysis
```python
from guidance import models, select, gen
lm = models.Anthropic("claude-sonnet-4-5-20250929")
text = "This product is absolutely amazing! Best purchase ever."
lm += f"Text: {text}\n\n"
lm += "Sentiment: " + select(
["positive", "negative", "neutral"],
name="sentiment"
)
lm += "\nConfidence: " + gen("confidence", regex=r"[0-9]{1,3}") + "%\n"
lm += "Reasoning: " + gen("reasoning", stop="\n", max_tokens=50)
print(f"Sentiment: {lm['sentiment']}")
print(f"Confidence: {lm['confidence']}%")
print(f"Reasoning: {lm['reasoning']}")
```
### Multi-Label Classification
```python
@guidance
def classify_article(lm, text):
"""Classify article with multiple labels."""
lm += f"Article: {text}\n\n"
# Primary category
lm += "Primary Category: " + select(
["Technology", "Business", "Science", "Politics", "Entertainment"],
name="primary_category"
) + "\n"
# Secondary categories (up to 3)
lm += "\nSecondary Categories:\n"
categories = ["Technology", "Business", "Science", "Politics", "Entertainment"]
for i in range(3):
lm += f"{i+1}. " + select(categories, name=f"secondary_{i}") + "\n"
# Tags
lm += "\nTags: " + gen("tags", stop="\n", max_tokens=50) + "\n"
# Target audience
lm += "Target Audience: " + select(
["General", "Expert", "Beginner"],
name="audience"
)
return lm
article = """
Apple announced new AI features in iOS 18, leveraging machine learning to improve
battery life and performance. The company's stock rose 5% following the announcement.
"""
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = classify_article(lm, article)
```
### Intent Classification
```python
@guidance
def classify_intent(lm, message):
"""Classify user intent."""
lm += f"User Message: {message}\n\n"
# Intent
lm += "Intent: " + select(
["question", "complaint", "request", "feedback", "other"],
name="intent"
) + "\n"
# Urgency
lm += "Urgency: " + select(
["low", "medium", "high", "critical"],
name="urgency"
) + "\n"
# Department
lm += "Route To: " + select(
["support", "sales", "billing", "technical"],
name="department"
) + "\n"
# Sentiment
lm += "Sentiment: " + select(
["positive", "neutral", "negative"],
name="sentiment"
)
return lm
message = "My account was charged twice for the same order. Need help ASAP!"
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = classify_intent(lm, message)
print(f"Intent: {lm['intent']}")
print(f"Urgency: {lm['urgency']}")
print(f"Department: {lm['department']}")
```
## Agent Systems
### ReAct Agent
```python
from guidance import models, gen, select, guidance
@guidance(stateless=False)
def react_agent(lm, question, tools, max_rounds=5):
"""ReAct agent with tool use."""
lm += f"Question: {question}\n\n"
for round in range(max_rounds):
# Thought
lm += f"Thought {round+1}: " + gen("thought", stop="\n", max_tokens=100) + "\n"
# Action selection
lm += "Action: " + select(
list(tools.keys()) + ["answer"],
name="action"
)
if lm["action"] == "answer":
lm += "\n\nFinal Answer: " + gen("answer", max_tokens=200)
break
# Action input
lm += "\nAction Input: " + gen("action_input", stop="\n", max_tokens=100) + "\n"
# Execute tool
if lm["action"] in tools:
try:
result = tools[lm["action"]](lm["action_input"])
lm += f"Observation: {result}\n\n"
except Exception as e:
lm += f"Observation: Error - {str(e)}\n\n"
return lm
# Define tools
tools = {
"calculator": lambda expr: eval(expr),
"search": lambda query: f"Search results for '{query}': [Mock results]",
"weather": lambda city: f"Weather in {city}: Sunny, 72°F"
}
# Use agent
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = react_agent(lm, "What is (25 * 4) + 10?", tools)
print(lm["answer"])
```
### Multi-Agent System
```python
@guidance
def coordinator_agent(lm, task):
"""Coordinator that delegates to specialists."""
lm += f"Task: {task}\n\n"
# Determine which specialist to use
lm += "Specialist: " + select(
["researcher", "writer", "coder", "analyst"],
name="specialist"
) + "\n"
lm += "Reasoning: " + gen("reasoning", stop="\n", max_tokens=100) + "\n"
return lm
@guidance
def researcher_agent(lm, query):
"""Research specialist."""
lm += f"Research Query: {query}\n\n"
lm += "Findings:\n"
for i in range(3):
lm += f"{i+1}. " + gen(f"finding_{i}", stop="\n", max_tokens=100) + "\n"
return lm
@guidance
def writer_agent(lm, topic):
"""Writing specialist."""
lm += f"Topic: {topic}\n\n"
lm += "Title: " + gen("title", stop="\n", max_tokens=50) + "\n"
lm += "Content:\n" + gen("content", max_tokens=500)
return lm
# Coordination workflow
task = "Write an article about AI safety"
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = coordinator_agent(lm, task)
specialist = lm["specialist"]
if specialist == "researcher":
lm = researcher_agent(lm, task)
elif specialist == "writer":
lm = writer_agent(lm, task)
```
### Tool Use with Validation
```python
@guidance(stateless=False)
def validated_tool_agent(lm, question):
"""Agent with validated tool calls."""
tools = {
"add": lambda a, b: float(a) + float(b),
"multiply": lambda a, b: float(a) * float(b),
"divide": lambda a, b: float(a) / float(b) if float(b) != 0 else "Error: Division by zero"
}
lm += f"Question: {question}\n\n"
for i in range(5):
# Select tool
lm += "Tool: " + select(list(tools.keys()) + ["done"], name="tool")
if lm["tool"] == "done":
lm += "\nAnswer: " + gen("answer", max_tokens=100)
break
# Get validated numeric arguments
lm += "\nArg1: " + gen("arg1", regex=r"-?[0-9]+(\.[0-9]+)?") + "\n"
lm += "Arg2: " + gen("arg2", regex=r"-?[0-9]+(\.[0-9]+)?") + "\n"
# Execute
result = tools[lm["tool"]](lm["arg1"], lm["arg2"])
lm += f"Result: {result}\n\n"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = validated_tool_agent(lm, "What is (10 + 5) * 3?")
```
## Multi-Step Workflows
### Chain of Thought
```python
@guidance
def chain_of_thought(lm, question):
"""Multi-step reasoning with CoT."""
lm += f"Question: {question}\n\n"
# Generate reasoning steps
lm += "Let me think step by step:\n\n"
for i in range(4):
lm += f"Step {i+1}: " + gen(f"step_{i+1}", stop="\n", max_tokens=100) + "\n"
# Final answer
lm += "\nTherefore, the answer is: " + gen("answer", stop="\n", max_tokens=50)
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = chain_of_thought(lm, "If a train travels 60 mph for 2.5 hours, how far does it go?")
print(lm["answer"])
```
### Self-Consistency
```python
@guidance
def self_consistency(lm, question, num_samples=3):
"""Generate multiple reasoning paths and aggregate."""
lm += f"Question: {question}\n\n"
answers = []
for i in range(num_samples):
lm += f"=== Attempt {i+1} ===\n"
lm += "Reasoning: " + gen(f"reasoning_{i}", stop="\n", max_tokens=100) + "\n"
lm += "Answer: " + gen(f"answer_{i}", stop="\n", max_tokens=50) + "\n\n"
answers.append(lm[f"answer_{i}"])
# Aggregate (simple majority vote)
from collections import Counter
most_common = Counter(answers).most_common(1)[0][0]
lm += f"Final Answer (by majority): {most_common}\n"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = self_consistency(lm, "What is 15% of 200?")
```
### Planning and Execution
```python
@guidance
def plan_and_execute(lm, goal):
"""Plan tasks then execute them."""
lm += f"Goal: {goal}\n\n"
# Planning phase
lm += "Plan:\n"
num_steps = 4
for i in range(num_steps):
lm += f"{i+1}. " + gen(f"plan_step_{i}", stop="\n", max_tokens=100) + "\n"
# Execution phase
lm += "\nExecution:\n\n"
for i in range(num_steps):
lm += f"Step {i+1}: {lm[f'plan_step_{i}']}\n"
lm += "Status: " + select(["completed", "in-progress", "blocked"], name=f"status_{i}") + "\n"
lm += "Result: " + gen(f"result_{i}", stop="\n", max_tokens=150) + "\n\n"
# Summary
lm += "Summary: " + gen("summary", max_tokens=200)
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = plan_and_execute(lm, "Build a REST API for a blog platform")
```
## Code Generation
### Python Function
```python
@guidance
def generate_python_function(lm, description):
"""Generate Python function from description."""
lm += f"Description: {description}\n\n"
# Function signature
lm += "def " + gen("func_name", regex=r"[a-z_][a-z0-9_]*") + "("
lm += gen("params", regex=r"[a-z_][a-z0-9_]*(, [a-z_][a-z0-9_]*)*") + "):\n"
# Docstring
lm += ' """' + gen("docstring", stop='"""', max_tokens=100) + '"""\n'
# Function body
lm += " " + gen("body", stop="\n", max_tokens=200) + "\n"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_python_function(lm, "Check if a number is prime")
print(lm)
```
### SQL Query
```python
@guidance
def generate_sql(lm, description):
"""Generate SQL query from description."""
lm += f"Description: {description}\n\n"
lm += "SQL Query:\n"
# SELECT clause
lm += "SELECT " + gen("select_clause", stop=" FROM", max_tokens=100)
# FROM clause
lm += " FROM " + gen("from_clause", stop=" WHERE", max_tokens=50)
# WHERE clause (optional)
lm += " WHERE " + gen("where_clause", stop=";", max_tokens=100) + ";"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_sql(lm, "Get all users who signed up in the last 30 days")
```
### API Endpoint
```python
@guidance
def generate_api_endpoint(lm, description):
"""Generate REST API endpoint."""
lm += f"Description: {description}\n\n"
# HTTP method
lm += "Method: " + select(["GET", "POST", "PUT", "DELETE"], name="method") + "\n"
# Path
lm += "Path: /" + gen("path", regex=r"[a-z0-9/-]+", stop="\n") + "\n"
# Request body (if POST/PUT)
if lm["method"] in ["POST", "PUT"]:
lm += "\nRequest Body:\n"
lm += "{\n"
lm += ' "field1": ' + gen("field1", regex=r'"[a-z_]+"') + ",\n"
lm += ' "field2": ' + gen("field2", regex=r'"[a-z_]+"') + "\n"
lm += "}\n"
# Response
lm += "\nResponse (200 OK):\n"
lm += "{\n"
lm += ' "status": "success",\n'
lm += ' "data": ' + gen("response_data", max_tokens=100) + "\n"
lm += "}\n"
return lm
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm = generate_api_endpoint(lm, "Create a new blog post")
```
## Production Tips
### Error Handling
```python
@guidance
def safe_extraction(lm, text):
"""Extract with fallback handling."""
try:
lm += f"Text: {text}\n"
lm += "Name: " + gen("name", regex=r"[A-Za-z ]+", stop="\n", max_tokens=30)
return lm
except Exception as e:
# Fallback to less strict extraction
lm += f"Text: {text}\n"
lm += "Name: " + gen("name", stop="\n", max_tokens=30)
return lm
```
### Caching
```python
from functools import lru_cache
@lru_cache(maxsize=100)
def cached_generation(text):
"""Cache LLM generations."""
lm = models.Anthropic("claude-sonnet-4-5-20250929")
lm += f"Analyze: {text}\n"
lm += "Sentiment: " + select(["positive", "negative", "neutral"], name="sentiment")
return lm["sentiment"]
# First call: hits LLM
result1 = cached_generation("This is great!")
# Second call: returns cached result
result2 = cached_generation("This is great!") # Instant!
```
### Monitoring
```python
import time
@guidance
def monitored_generation(lm, text):
"""Track generation metrics."""
start_time = time.time()
lm += f"Text: {text}\n"
lm += "Analysis: " + gen("analysis", max_tokens=100)
elapsed = time.time() - start_time
# Log metrics
print(f"Generation time: {elapsed:.2f}s")
print(f"Output length: {len(lm['analysis'])} chars")
return lm
```
### Batch Processing
```python
def batch_process(texts, batch_size=10):
"""Process texts in batches."""
lm = models.Anthropic("claude-sonnet-4-5-20250929")
results = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
for text in batch:
lm += f"Text: {text}\n"
lm += "Sentiment: " + select(
["positive", "negative", "neutral"],
name=f"sentiment_{i}"
) + "\n\n"
results.extend([lm[f"sentiment_{i}"] for i in range(len(batch))])
return results
```
## Resources
- **Guidance Notebooks**: https://github.com/guidance-ai/guidance/tree/main/notebooks
- **Guidance Docs**: https://guidance.readthedocs.io
- **Community Examples**: https://github.com/guidance-ai/guidance/discussions

307
skills/mlops/llava/SKILL.md Normal file
View File

@@ -0,0 +1,307 @@
---
name: llava
description: Large Language and Vision Assistant. Enables visual instruction tuning and image-based conversations. Combines CLIP vision encoder with Vicuna/LLaMA language models. Supports multi-turn image chat, visual question answering, and instruction following. Use for vision-language chatbots or image understanding tasks. Best for conversational image analysis.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [transformers, torch, pillow]
metadata:
hermes:
tags: [LLaVA, Vision-Language, Multimodal, Visual Question Answering, Image Chat, CLIP, Vicuna, Conversational AI, Instruction Tuning, VQA]
---
# LLaVA - Large Language and Vision Assistant
Open-source vision-language model for conversational image understanding.
## When to use LLaVA
**Use when:**
- Building vision-language chatbots
- Visual question answering (VQA)
- Image description and captioning
- Multi-turn image conversations
- Visual instruction following
- Document understanding with images
**Metrics**:
- **23,000+ GitHub stars**
- GPT-4V level capabilities (targeted)
- Apache 2.0 License
- Multiple model sizes (7B-34B params)
**Use alternatives instead**:
- **GPT-4V**: Highest quality, API-based
- **CLIP**: Simple zero-shot classification
- **BLIP-2**: Better for captioning only
- **Flamingo**: Research, not open-source
## Quick start
### Installation
```bash
# Clone repository
git clone https://github.com/haotian-liu/LLaVA
cd LLaVA
# Install
pip install -e .
```
### Basic usage
```python
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from PIL import Image
import torch
# Load model
model_path = "liuhaotian/llava-v1.5-7b"
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path=model_path,
model_base=None,
model_name=get_model_name_from_path(model_path)
)
# Load image
image = Image.open("image.jpg")
image_tensor = process_images([image], image_processor, model.config)
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
# Create conversation
conv = conv_templates["llava_v1"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
# Generate response
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor,
do_sample=True,
temperature=0.2,
max_new_tokens=512
)
response = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
print(response)
```
## Available models
| Model | Parameters | VRAM | Quality |
|-------|------------|------|---------|
| LLaVA-v1.5-7B | 7B | ~14 GB | Good |
| LLaVA-v1.5-13B | 13B | ~28 GB | Better |
| LLaVA-v1.6-34B | 34B | ~70 GB | Best |
```python
# Load different models
model_7b = "liuhaotian/llava-v1.5-7b"
model_13b = "liuhaotian/llava-v1.5-13b"
model_34b = "liuhaotian/llava-v1.6-34b"
# 4-bit quantization for lower VRAM
load_4bit = True # Reduces VRAM by ~4×
```
## CLI usage
```bash
# Single image query
python -m llava.serve.cli \
--model-path liuhaotian/llava-v1.5-7b \
--image-file image.jpg \
--query "What is in this image?"
# Multi-turn conversation
python -m llava.serve.cli \
--model-path liuhaotian/llava-v1.5-7b \
--image-file image.jpg
# Then type questions interactively
```
## Web UI (Gradio)
```bash
# Launch Gradio interface
python -m llava.serve.gradio_web_server \
--model-path liuhaotian/llava-v1.5-7b \
--load-4bit # Optional: reduce VRAM
# Access at http://localhost:7860
```
## Multi-turn conversations
```python
# Initialize conversation
conv = conv_templates["llava_v1"].copy()
# Turn 1
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nWhat is in this image?")
conv.append_message(conv.roles[1], None)
response1 = generate(conv, model, image) # "A dog playing in a park"
# Turn 2
conv.messages[-1][1] = response1 # Add previous response
conv.append_message(conv.roles[0], "What breed is the dog?")
conv.append_message(conv.roles[1], None)
response2 = generate(conv, model, image) # "Golden Retriever"
# Turn 3
conv.messages[-1][1] = response2
conv.append_message(conv.roles[0], "What time of day is it?")
conv.append_message(conv.roles[1], None)
response3 = generate(conv, model, image)
```
## Common tasks
### Image captioning
```python
question = "Describe this image in detail."
response = ask(model, image, question)
```
### Visual question answering
```python
question = "How many people are in the image?"
response = ask(model, image, question)
```
### Object detection (textual)
```python
question = "List all the objects you can see in this image."
response = ask(model, image, question)
```
### Scene understanding
```python
question = "What is happening in this scene?"
response = ask(model, image, question)
```
### Document understanding
```python
question = "What is the main topic of this document?"
response = ask(model, document_image, question)
```
## Training custom model
```bash
# Stage 1: Feature alignment (558K image-caption pairs)
bash scripts/v1_5/pretrain.sh
# Stage 2: Visual instruction tuning (150K instruction data)
bash scripts/v1_5/finetune.sh
```
## Quantization (reduce VRAM)
```python
# 4-bit quantization
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path="liuhaotian/llava-v1.5-13b",
model_base=None,
model_name=get_model_name_from_path("liuhaotian/llava-v1.5-13b"),
load_4bit=True # Reduces VRAM ~4×
)
# 8-bit quantization
load_8bit=True # Reduces VRAM ~2×
```
## Best practices
1. **Start with 7B model** - Good quality, manageable VRAM
2. **Use 4-bit quantization** - Reduces VRAM significantly
3. **GPU required** - CPU inference extremely slow
4. **Clear prompts** - Specific questions get better answers
5. **Multi-turn conversations** - Maintain conversation context
6. **Temperature 0.2-0.7** - Balance creativity/consistency
7. **max_new_tokens 512-1024** - For detailed responses
8. **Batch processing** - Process multiple images sequentially
## Performance
| Model | VRAM (FP16) | VRAM (4-bit) | Speed (tokens/s) |
|-------|-------------|--------------|------------------|
| 7B | ~14 GB | ~4 GB | ~20 |
| 13B | ~28 GB | ~8 GB | ~12 |
| 34B | ~70 GB | ~18 GB | ~5 |
*On A100 GPU*
## Benchmarks
LLaVA achieves competitive scores on:
- **VQAv2**: 78.5%
- **GQA**: 62.0%
- **MM-Vet**: 35.4%
- **MMBench**: 64.3%
## Limitations
1. **Hallucinations** - May describe things not in image
2. **Spatial reasoning** - Struggles with precise locations
3. **Small text** - Difficulty reading fine print
4. **Object counting** - Imprecise for many objects
5. **VRAM requirements** - Need powerful GPU
6. **Inference speed** - Slower than CLIP
## Integration with frameworks
### LangChain
```python
from langchain.llms.base import LLM
class LLaVALLM(LLM):
def _call(self, prompt, stop=None):
# Custom LLaVA inference
return response
llm = LLaVALLM()
```
### Gradio App
```python
import gradio as gr
def chat(image, text, history):
response = ask_llava(model, image, text)
return response
demo = gr.ChatInterface(
chat,
additional_inputs=[gr.Image(type="pil")],
title="LLaVA Chat"
)
demo.launch()
```
## Resources
- **GitHub**: https://github.com/haotian-liu/LLaVA ⭐ 23,000+
- **Paper**: https://arxiv.org/abs/2304.08485
- **Demo**: https://llava.hliu.cc
- **Models**: https://huggingface.co/liuhaotian
- **License**: Apache 2.0

View File

@@ -0,0 +1,197 @@
# LLaVA Training Guide
Guide to training and fine-tuning LLaVA models.
## Training stages
### Stage 1: Feature alignment (Pretraining)
**Purpose**: Align vision encoder with language model
**Data**: 558K image-caption pairs (CC3M subset)
```bash
# Download pretrained projector or train from scratch
bash scripts/v1_5/pretrain.sh
```
**Configuration:**
- Base model: Vicuna-7B or LLaMA-2-7B
- Vision encoder: CLIP ViT-L/14
- Training time: ~20 hours on 8× A100
### Stage 2: Visual instruction tuning
**Purpose**: Teach model to follow visual instructions
**Data**: 150K GPT-generated multimodal instruction data
```bash
# Fine-tune with instruction data
bash scripts/v1_5/finetune.sh
```
**Configuration:**
- Epochs: 1
- Batch size: 128 (across 8 GPUs)
- Learning rate: 2e-5
- Training time: ~24 hours on 8× A100
## Data format
### Instruction data format
```json
[
{
"id": "001",
"image": "path/to/image.jpg",
"conversations": [
{
"from": "human",
"value": "<image>\nWhat is in this image?"
},
{
"from": "gpt",
"value": "The image shows a dog playing in a park."
},
{
"from": "human",
"value": "What breed is the dog?"
},
{
"from": "gpt",
"value": "It appears to be a Golden Retriever."
}
]
}
]
```
## Fine-tuning on custom data
### Prepare your data
```python
import json
# Create instruction data
data = []
for image_path, qa_pairs in your_dataset:
conversations = []
for q, a in qa_pairs:
conversations.append({"from": "human", "value": f"<image>\n{q}"})
conversations.append({"from": "gpt", "value": a})
data.append({
"id": str(len(data)),
"image": image_path,
"conversations": conversations
})
# Save
with open("custom_data.json", "w") as f:
json.dump(data, f, indent=2)
```
### Fine-tune script
```bash
#!/bin/bash
# Set paths
DATA_PATH="custom_data.json"
IMAGE_FOLDER="path/to/images"
MODEL_PATH="liuhaotian/llava-v1.5-7b"
OUTPUT_DIR="./checkpoints/llava-custom"
# Fine-tune
deepspeed llava/train/train_mem.py \
--deepspeed ./scripts/zero2.json \
--model_name_or_path $MODEL_PATH \
--version v1 \
--data_path $DATA_PATH \
--image_folder $IMAGE_FOLDER \
--vision_tower openai/clip-vit-large-patch14-336 \
--mm_projector_type mlp2x_gelu \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--image_aspect_ratio pad \
--group_by_modality_length True \
--bf16 True \
--output_dir $OUTPUT_DIR \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 50000 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb
```
## LoRA fine-tuning (memory efficient)
```python
from peft import LoraConfig, get_peft_model
# LoRA config
lora_config = LoraConfig(
r=8, # LoRA rank
lora_alpha=16,
target_modules=["q_proj", "v_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM"
)
# Apply LoRA
model = get_peft_model(base_model, lora_config)
# Train with much lower memory
```
## Hardware requirements
### Full fine-tuning
- **7B model**: 8× A100 (40GB)
- **13B model**: 8× A100 (80GB)
- **Training time**: 20-48 hours
### LoRA fine-tuning
- **7B model**: 1× A100 (40GB)
- **13B model**: 2× A100 (40GB)
- **Training time**: 10-24 hours
## Best practices
1. **Start with pretrained** - Don't train from scratch
2. **Use LoRA for efficiency** - 10× less memory
3. **Quality over quantity** - 1K high-quality > 10K low-quality
4. **Multi-turn conversations** - More engaging than single Q&A
5. **Diverse images** - Cover different scenarios
6. **Clear instructions** - Specific questions get better answers
7. **Monitor loss** - Should decrease smoothly
8. **Save checkpoints** - Training can fail
9. **Test regularly** - Validate on held-out set
10. **Use DeepSpeed** - For multi-GPU training
## Resources
- **Training script**: https://github.com/haotian-liu/LLaVA/tree/main/scripts
- **Data format**: https://github.com/haotian-liu/LLaVA/blob/main/docs/Data.md
- **Paper**: https://arxiv.org/abs/2304.08485

View File

@@ -0,0 +1,386 @@
---
name: nemo-curator
description: GPU-accelerated data curation for LLM training. Supports text/image/video/audio. Features fuzzy deduplication (16× faster), quality filtering (30+ heuristics), semantic deduplication, PII redaction, NSFW detection. Scales across GPUs with RAPIDS. Use for preparing high-quality training datasets, cleaning web data, or deduplicating large corpora.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [nemo-curator, cudf, dask, rapids]
metadata:
hermes:
tags: [Data Processing, NeMo Curator, Data Curation, GPU Acceleration, Deduplication, Quality Filtering, NVIDIA, RAPIDS, PII Redaction, Multimodal, LLM Training Data]
---
# NeMo Curator - GPU-Accelerated Data Curation
NVIDIA's toolkit for preparing high-quality training data for LLMs.
## When to use NeMo Curator
**Use NeMo Curator when:**
- Preparing LLM training data from web scrapes (Common Crawl)
- Need fast deduplication (16× faster than CPU)
- Curating multi-modal datasets (text, images, video, audio)
- Filtering low-quality or toxic content
- Scaling data processing across GPU cluster
**Performance**:
- **16× faster** fuzzy deduplication (8TB RedPajama v2)
- **40% lower TCO** vs CPU alternatives
- **Near-linear scaling** across GPU nodes
**Use alternatives instead**:
- **datatrove**: CPU-based, open-source data processing
- **dolma**: Allen AI's data toolkit
- **Ray Data**: General ML data processing (no curation focus)
## Quick start
### Installation
```bash
# Text curation (CUDA 12)
uv pip install "nemo-curator[text_cuda12]"
# All modalities
uv pip install "nemo-curator[all_cuda12]"
# CPU-only (slower)
uv pip install "nemo-curator[cpu]"
```
### Basic text curation pipeline
```python
from nemo_curator import ScoreFilter, Modify
from nemo_curator.datasets import DocumentDataset
import pandas as pd
# Load data
df = pd.DataFrame({"text": ["Good document", "Bad doc", "Excellent text"]})
dataset = DocumentDataset(df)
# Quality filtering
def quality_score(doc):
return len(doc["text"].split()) > 5 # Filter short docs
filtered = ScoreFilter(quality_score)(dataset)
# Deduplication
from nemo_curator.modules import ExactDuplicates
deduped = ExactDuplicates()(filtered)
# Save
deduped.to_parquet("curated_data/")
```
## Data curation pipeline
### Stage 1: Quality filtering
```python
from nemo_curator.filters import (
WordCountFilter,
RepeatedLinesFilter,
UrlRatioFilter,
NonAlphaNumericFilter
)
# Apply 30+ heuristic filters
from nemo_curator import ScoreFilter
# Word count filter
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
# Remove repetitive content
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
# URL ratio filter
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))
```
### Stage 2: Deduplication
**Exact deduplication**:
```python
from nemo_curator.modules import ExactDuplicates
# Remove exact duplicates
deduped = ExactDuplicates(id_field="id", text_field="text")(dataset)
```
**Fuzzy deduplication** (16× faster on GPU):
```python
from nemo_curator.modules import FuzzyDuplicates
# MinHash + LSH deduplication
fuzzy_dedup = FuzzyDuplicates(
id_field="id",
text_field="text",
num_hashes=260, # MinHash parameters
num_buckets=20,
hash_method="md5"
)
deduped = fuzzy_dedup(dataset)
```
**Semantic deduplication**:
```python
from nemo_curator.modules import SemanticDuplicates
# Embedding-based deduplication
semantic_dedup = SemanticDuplicates(
id_field="id",
text_field="text",
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
threshold=0.8 # Cosine similarity threshold
)
deduped = semantic_dedup(dataset)
```
### Stage 3: PII redaction
```python
from nemo_curator.modules import Modify
from nemo_curator.modifiers import PIIRedactor
# Redact personally identifiable information
pii_redactor = PIIRedactor(
supported_entities=["EMAIL_ADDRESS", "PHONE_NUMBER", "PERSON", "LOCATION"],
anonymize_action="replace" # or "redact"
)
redacted = Modify(pii_redactor)(dataset)
```
### Stage 4: Classifier filtering
```python
from nemo_curator.classifiers import QualityClassifier
# Quality classification
quality_clf = QualityClassifier(
model_path="nvidia/quality-classifier-deberta",
batch_size=256,
device="cuda"
)
# Filter low-quality documents
high_quality = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
```
## GPU acceleration
### GPU vs CPU performance
| Operation | CPU (16 cores) | GPU (A100) | Speedup |
|-----------|----------------|------------|---------|
| Fuzzy dedup (8TB) | 120 hours | 7.5 hours | 16× |
| Exact dedup (1TB) | 8 hours | 0.5 hours | 16× |
| Quality filtering | 2 hours | 0.2 hours | 10× |
### Multi-GPU scaling
```python
from nemo_curator import get_client
import dask_cuda
# Initialize GPU cluster
client = get_client(cluster_type="gpu", n_workers=8)
# Process with 8 GPUs
deduped = FuzzyDuplicates(...)(dataset)
```
## Multi-modal curation
### Image curation
```python
from nemo_curator.image import (
AestheticFilter,
NSFWFilter,
CLIPEmbedder
)
# Aesthetic scoring
aesthetic_filter = AestheticFilter(threshold=5.0)
filtered_images = aesthetic_filter(image_dataset)
# NSFW detection
nsfw_filter = NSFWFilter(threshold=0.9)
safe_images = nsfw_filter(filtered_images)
# Generate CLIP embeddings
clip_embedder = CLIPEmbedder(model="openai/clip-vit-base-patch32")
image_embeddings = clip_embedder(safe_images)
```
### Video curation
```python
from nemo_curator.video import (
SceneDetector,
ClipExtractor,
InternVideo2Embedder
)
# Detect scenes
scene_detector = SceneDetector(threshold=27.0)
scenes = scene_detector(video_dataset)
# Extract clips
clip_extractor = ClipExtractor(min_duration=2.0, max_duration=10.0)
clips = clip_extractor(scenes)
# Generate embeddings
video_embedder = InternVideo2Embedder()
video_embeddings = video_embedder(clips)
```
### Audio curation
```python
from nemo_curator.audio import (
ASRInference,
WERFilter,
DurationFilter
)
# ASR transcription
asr = ASRInference(model="nvidia/stt_en_fastconformer_hybrid_large_pc")
transcribed = asr(audio_dataset)
# Filter by WER (word error rate)
wer_filter = WERFilter(max_wer=0.3)
high_quality_audio = wer_filter(transcribed)
# Duration filtering
duration_filter = DurationFilter(min_duration=1.0, max_duration=30.0)
filtered_audio = duration_filter(high_quality_audio)
```
## Common patterns
### Web scrape curation (Common Crawl)
```python
from nemo_curator import ScoreFilter, Modify
from nemo_curator.filters import *
from nemo_curator.modules import *
from nemo_curator.datasets import DocumentDataset
# Load Common Crawl data
dataset = DocumentDataset.read_parquet("common_crawl/*.parquet")
# Pipeline
pipeline = [
# 1. Quality filtering
WordCountFilter(min_words=100, max_words=50000),
RepeatedLinesFilter(max_repeated_line_fraction=0.2),
SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3),
UrlRatioFilter(max_url_ratio=0.3),
# 2. Language filtering
LanguageIdentificationFilter(target_languages=["en"]),
# 3. Deduplication
ExactDuplicates(id_field="id", text_field="text"),
FuzzyDuplicates(id_field="id", text_field="text", num_hashes=260),
# 4. PII redaction
PIIRedactor(),
# 5. NSFW filtering
NSFWClassifier(threshold=0.8)
]
# Execute
for stage in pipeline:
dataset = stage(dataset)
# Save
dataset.to_parquet("curated_common_crawl/")
```
### Distributed processing
```python
from nemo_curator import get_client
from dask_cuda import LocalCUDACluster
# Multi-GPU cluster
cluster = LocalCUDACluster(n_workers=8)
client = get_client(cluster=cluster)
# Process large dataset
dataset = DocumentDataset.read_parquet("s3://large_dataset/*.parquet")
deduped = FuzzyDuplicates(...)(dataset)
# Cleanup
client.close()
cluster.close()
```
## Performance benchmarks
### Fuzzy deduplication (8TB RedPajama v2)
- **CPU (256 cores)**: 120 hours
- **GPU (8× A100)**: 7.5 hours
- **Speedup**: 16×
### Exact deduplication (1TB)
- **CPU (64 cores)**: 8 hours
- **GPU (4× A100)**: 0.5 hours
- **Speedup**: 16×
### Quality filtering (100GB)
- **CPU (32 cores)**: 2 hours
- **GPU (2× A100)**: 0.2 hours
- **Speedup**: 10×
## Cost comparison
**CPU-based curation** (AWS c5.18xlarge × 10):
- Cost: $3.60/hour × 10 = $36/hour
- Time for 8TB: 120 hours
- **Total**: $4,320
**GPU-based curation** (AWS p4d.24xlarge × 2):
- Cost: $32.77/hour × 2 = $65.54/hour
- Time for 8TB: 7.5 hours
- **Total**: $491.55
**Savings**: 89% reduction ($3,828 saved)
## Supported data formats
- **Input**: Parquet, JSONL, CSV
- **Output**: Parquet (recommended), JSONL
- **WebDataset**: TAR archives for multi-modal
## Use cases
**Production deployments**:
- NVIDIA used NeMo Curator to prepare Nemotron-4 training data
- Open-source datasets curated: RedPajama v2, The Pile
## References
- **[Filtering Guide](references/filtering.md)** - 30+ quality filters, heuristics
- **[Deduplication Guide](references/deduplication.md)** - Exact, fuzzy, semantic methods
## Resources
- **GitHub**: https://github.com/NVIDIA/NeMo-Curator ⭐ 500+
- **Docs**: https://docs.nvidia.com/nemo-framework/user-guide/latest/datacuration/
- **Version**: 0.4.0+
- **License**: Apache 2.0

View File

@@ -0,0 +1,87 @@
# Deduplication Guide
Complete guide to exact, fuzzy, and semantic deduplication.
## Exact deduplication
Remove documents with identical content.
```python
from nemo_curator.modules import ExactDuplicates
# Exact deduplication
exact_dedup = ExactDuplicates(
id_field="id",
text_field="text",
hash_method="md5" # or "sha256"
)
deduped = exact_dedup(dataset)
```
**Performance**: ~16× faster on GPU vs CPU
## Fuzzy deduplication
Remove near-duplicate documents using MinHash + LSH.
```python
from nemo_curator.modules import FuzzyDuplicates
fuzzy_dedup = FuzzyDuplicates(
id_field="id",
text_field="text",
num_hashes=260, # MinHash permutations (more = accurate)
num_buckets=20, # LSH buckets (more = faster, less recall)
hash_method="md5",
jaccard_threshold=0.8 # Similarity threshold
)
deduped = fuzzy_dedup(dataset)
```
**Parameters**:
- `num_hashes`: 128-512 (default 260)
- `num_buckets`: 10-50 (default 20)
- `jaccard_threshold`: 0.7-0.9 (default 0.8)
**Performance**: 16× faster on 8TB dataset (120h → 7.5h)
## Semantic deduplication
Remove semantically similar documents using embeddings.
```python
from nemo_curator.modules import SemanticDuplicates
semantic_dedup = SemanticDuplicates(
id_field="id",
text_field="text",
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
embedding_batch_size=256,
threshold=0.85, # Cosine similarity threshold
device="cuda"
)
deduped = semantic_dedup(dataset)
```
**Models**:
- `all-MiniLM-L6-v2`: Fast, 384 dims
- `all-mpnet-base-v2`: Better quality, 768 dims
- Custom models supported
## Comparison
| Method | Speed | Recall | Use Case |
|--------|-------|--------|----------|
| Exact | Fastest | 100% | Exact matches only |
| Fuzzy | Fast | ~95% | Near-duplicates (recommended) |
| Semantic | Slow | ~90% | Paraphrases, rewrites |
## Best practices
1. **Start with exact dedup** - Remove obvious duplicates
2. **Use fuzzy for large datasets** - Best speed/quality trade-off
3. **Semantic for high-value data** - Expensive but thorough
4. **GPU acceleration required** - 10-16× speedup

View File

@@ -0,0 +1,102 @@
# Quality Filtering Guide
Complete guide to NeMo Curator's 30+ quality filters.
## Text-based filters
### Word count
```python
from nemo_curator.filters import WordCountFilter
# Filter by word count
dataset = dataset.filter(WordCountFilter(min_words=50, max_words=100000))
```
### Repeated content
```python
from nemo_curator.filters import RepeatedLinesFilter
# Remove documents with >30% repeated lines
dataset = dataset.filter(RepeatedLinesFilter(max_repeated_line_fraction=0.3))
```
### Symbol ratio
```python
from nemo_curator.filters import SymbolToWordRatioFilter
# Remove documents with too many symbols
dataset = dataset.filter(SymbolToWordRatioFilter(max_symbol_to_word_ratio=0.3))
```
### URL ratio
```python
from nemo_curator.filters import UrlRatioFilter
# Remove documents with many URLs
dataset = dataset.filter(UrlRatioFilter(max_url_ratio=0.2))
```
## Language filtering
```python
from nemo_curator.filters import LanguageIdentificationFilter
# Keep only English documents
dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en"]))
# Multiple languages
dataset = dataset.filter(LanguageIdentificationFilter(target_languages=["en", "es", "fr"]))
```
## Classifier-based filtering
### Quality classifier
```python
from nemo_curator.classifiers import QualityClassifier
quality_clf = QualityClassifier(
model_path="nvidia/quality-classifier-deberta",
batch_size=256,
device="cuda"
)
# Filter low-quality (threshold > 0.5 = high quality)
dataset = dataset.filter(lambda doc: quality_clf(doc["text"]) > 0.5)
```
### NSFW classifier
```python
from nemo_curator.classifiers import NSFWClassifier
nsfw_clf = NSFWClassifier(threshold=0.9, device="cuda")
# Remove NSFW content
dataset = dataset.filter(lambda doc: nsfw_clf(doc["text"]) < 0.9)
```
## Heuristic filters
Full list of 30+ filters:
- WordCountFilter
- RepeatedLinesFilter
- UrlRatioFilter
- SymbolToWordRatioFilter
- NonAlphaNumericFilter
- BulletsFilter
- WhiteSpaceFilter
- ParenthesesFilter
- LongWordFilter
- And 20+ more...
## Best practices
1. **Apply cheap filters first** - Word count before GPU classifiers
2. **Tune thresholds on sample** - Test on 10k docs before full run
3. **Use GPU classifiers sparingly** - Expensive but effective
4. **Chain filters efficiently** - Order by cost (cheap → expensive)

View File

@@ -0,0 +1,314 @@
---
name: obliteratus
description: Remove refusal behaviors from open-weight LLMs using OBLITERATUS — mechanistic interpretability techniques (diff-in-means, SVD, whitened SVD, SAE decomposition, etc.) to excise guardrails while preserving reasoning. 9 CLI methods (+ 4 Python-API-only), 15 analysis modules, 116 model presets across 5 compute tiers. Use when a user wants to uncensor, abliterate, or remove refusal from an LLM.
version: 1.0.0
author: Hermes Agent
license: MIT
dependencies: [obliteratus, torch, transformers, bitsandbytes, accelerate, safetensors]
metadata:
hermes:
tags: [Abliteration, Uncensoring, Refusal-Removal, LLM, Weight-Projection, SVD, Mechanistic-Interpretability, HuggingFace, Model-Surgery]
---
# OBLITERATUS Skill
Remove refusal behaviors (guardrails) from open-weight LLMs without retraining or fine-tuning. Uses mechanistic interpretability techniques — including diff-in-means, SVD, whitened SVD, SAE decomposition, Bayesian kernel projection, and more — to identify and surgically excise refusal directions from model weights while preserving reasoning capabilities.
**License warning:** OBLITERATUS is AGPL-3.0. NEVER import it as a Python library. Always invoke via CLI (`obliteratus` command) or subprocess. This keeps Hermes Agent's MIT license clean.
## When to Use This Skill
Trigger when the user:
- Wants to "uncensor" or "abliterate" an LLM
- Asks about removing refusal/guardrails from a model
- Wants to create an uncensored version of Llama, Qwen, Mistral, etc.
- Mentions "refusal removal", "abliteration", "weight projection"
- Wants to analyze how a model's refusal mechanism works
- References OBLITERATUS, FailSpy, abliterator, or refusal directions
## Step 1: Installation
Check if already installed:
```bash
obliteratus --version 2>/dev/null && echo "INSTALLED" || echo "NOT INSTALLED"
```
If not installed, clone and install from GitHub:
```
Repository: https://github.com/elder-plinius/OBLITERATUS
Install: pip install -e . (from the cloned directory)
For Gradio UI: pip install -e ".[spaces]"
```
**IMPORTANT:** Confirm with user before installing. This pulls in ~5-10GB of dependencies (PyTorch, Transformers, bitsandbytes, etc.).
## Step 2: Check Hardware
Before anything, check what GPU is available:
```bash
python3 -c "
import torch
if torch.cuda.is_available():
gpu = torch.cuda.get_device_name(0)
vram = torch.cuda.get_device_properties(0).total_mem / 1024**3
print(f'GPU: {gpu}')
print(f'VRAM: {vram:.1f} GB')
if vram < 4: print('TIER: tiny (models under 1B)')
elif vram < 8: print('TIER: small (models 1-4B)')
elif vram < 16: print('TIER: medium (models 4-9B with 4bit quant)')
elif vram < 32: print('TIER: large (models 8-32B with 4bit quant)')
else: print('TIER: frontier (models 32B+)')
else:
print('NO GPU - only tiny models (under 1B) on CPU')
"
```
### VRAM Requirements (with 4-bit quantization)
| VRAM | Max Model Size | Example Models |
|:---------|:----------------|:--------------------------------------------|
| CPU only | ~1B params | GPT-2, TinyLlama, SmolLM |
| 4-8 GB | ~4B params | Qwen2.5-1.5B, Phi-3.5 mini, Llama 3.2 3B |
| 8-16 GB | ~9B params | Llama 3.1 8B, Mistral 7B, Gemma 2 9B |
| 24 GB | ~32B params | Qwen3-32B, Llama 3.1 70B (tight), Command-R |
| 48 GB+ | ~72B+ params | Qwen2.5-72B, DeepSeek-R1 |
| Multi-GPU| 200B+ params | Llama 3.1 405B, DeepSeek-V3 (685B MoE) |
## Step 3: Browse Available Models
```bash
# List models for your compute tier
obliteratus models --tier medium
# Get architecture info for a specific model
obliteratus info meta-llama/Llama-3.1-8B-Instruct
```
## Step 4: Choose a Method
### Method Selection Guide
**First time / unsure? Use `informed`.** It auto-configures everything.
| Situation | Recommended Method | Why |
|:----------------------------------|:-------------------|:-----------------------------------------|
| First attempt, any model | `informed` | Auto-detects alignment type, auto-tunes |
| Quick test / prototyping | `basic` | Fast, simple, good enough to evaluate |
| Dense model (Llama, Mistral) | `advanced` | Multi-direction, norm-preserving |
| MoE model (DeepSeek, Mixtral) | `nuclear` | Expert-granular, handles MoE complexity |
| Reasoning model (R1 distills) | `surgical` | CoT-aware, preserves chain-of-thought |
| Stubborn refusals persist | `aggressive` | Whitened SVD + head surgery + jailbreak |
| Want reversible changes | Use steering vectors (see Analysis section) |
| Maximum quality, time no object | `optimized` | Bayesian search for best parameters |
### 9 CLI Methods
These can be passed to `--method` on the command line:
- **basic** — Single refusal direction via diff-in-means. Fastest, simplest. (Arditi et al. 2024)
- **advanced** — Multiple SVD directions, norm-preserving projection. Good default.
- **aggressive** — Whitened SVD + jailbreak contrast + attention head surgery
- **spectral_cascade** — DCT frequency-domain decomposition
- **informed** — Runs analysis DURING abliteration to auto-configure. Detects DPO/RLHF/CAI, maps refusal geometry, compensates for self-repair. Best quality.
- **surgical** — SAE features + neuron masking + head surgery + per-expert. Maximum precision.
- **optimized** — Bayesian hyperparameter search (Optuna TPE). Slowest but optimal.
- **inverted** — Flips the refusal direction (model becomes eager to help, not just neutral)
- **nuclear** — Maximum force combo for stubborn MoE models.
### 4 Python-API-Only Methods
These reproduce prior community/academic work but are NOT available via CLI — only via the Python API (`from obliteratus.abliterate import AbliterationPipeline`). **Do not use these in CLI commands.**
- **failspy** — FailSpy/abliterator reproduction
- **gabliteration** — Gabliteration reproduction
- **heretic** — Heretic/p-e-w reproduction
- **rdo** — Refusal Direction Optimization (ICML 2025)
## Step 5: Run Abliteration
### Basic Usage
```bash
# Default (advanced method)
obliteratus obliterate meta-llama/Llama-3.1-8B-Instruct
# With the informed pipeline (recommended)
obliteratus obliterate meta-llama/Llama-3.1-8B-Instruct --method informed
# With 4-bit quantization to save VRAM
obliteratus obliterate meta-llama/Llama-3.1-8B-Instruct \
--method informed \
--quantization 4bit \
--output-dir ./abliterated-models
# For large models (120B+), use conservative settings
obliteratus obliterate Qwen/Qwen2.5-72B-Instruct \
--method advanced \
--quantization 4bit \
--large-model \
--output-dir ./abliterated-models
```
### Fine-Tuning Parameters
```bash
obliteratus obliterate <model> \
--method advanced \
--n-directions 8 \
--regularization 0.1 \
--refinement-passes 3 \
--dtype bfloat16 \
--device auto \
--output-dir ./output
```
Parameter explanations:
- `--n-directions N` — How many refusal directions to remove (default: auto-detected)
- `--regularization 0.0-1.0` — Fraction of original weights to preserve (higher = safer but less complete removal)
- `--refinement-passes N` — Iterative passes to catch self-repair (Ouroboros effect)
- `--dtype` — float16, bfloat16, or float32
- `--quantization` — 4bit or 8bit (saves VRAM, slight quality tradeoff)
- `--large-model` — Conservative defaults for 120B+ models (fewer directions, fewer passes)
### Interactive Mode (Guided)
For users unsure about options:
```bash
obliteratus interactive
```
### Web UI (Gradio)
```bash
obliteratus ui --port 7860
```
## Step 6: Verify Results
After abliteration, check the output report for:
| Metric | Good Value | Concerning Value | Meaning |
|:---------------|:--------------------|:------------------------|:-------------------------------------------|
| Refusal rate | Near 0% | > 10% | Refusals still present, try harder method |
| Perplexity | Within 10% of orig | > 20% increase | Model coherence damaged, too aggressive |
| KL divergence | < 0.1 | > 0.5 | Large output distribution shift |
| Coherence | High | Low | Model generating nonsense |
### If perplexity spiked (too aggressive):
1. Increase `--regularization` (e.g., 0.2 or 0.3)
2. Decrease `--n-directions` (e.g., 4 instead of 8)
3. Use a less aggressive method (`advanced` instead of `aggressive`)
### If refusal persists (not aggressive enough):
1. Use `--method aggressive` or `--method nuclear`
2. Add `--refinement-passes 3` to catch self-repair
3. Use `--method informed` which auto-compensates
## Step 7: Use the Abliterated Model
The output is a standard HuggingFace model directory. Use it like any other model:
### Quick test
```bash
python3 << 'EOF'
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("./abliterated-models/model-name")
tokenizer = AutoTokenizer.from_pretrained("./abliterated-models/model-name")
inputs = tokenizer("Write a story about:", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=200)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
EOF
```
### Upload to HuggingFace Hub
```bash
huggingface-cli login # if not already logged in
huggingface-cli upload your-username/model-name-abliterated ./abliterated-models/model-name
```
### Serve with vLLM
```bash
vllm serve ./abliterated-models/model-name --port 8000
```
## Analysis Modules (15 Modules, Pre-Abliteration, Optional)
For understanding refusal geometry before committing to abliteration.
### Run a Study
```bash
obliteratus run study-config.yaml --preset jailbreak
```
### Study Presets
| Preset | Purpose | Time |
|:-------------|:-------------------------------------|:-------|
| `quick` | Sanity check, basic metrics | ~5 min |
| `jailbreak` | Refusal circuit localization | ~20 min|
| `guardrail` | Guardrail robustness evaluation | ~30 min|
| `attention` | Attention head contributions | ~30 min|
| `knowledge` | FFN importance mapping | ~30 min|
| `full` | Complete analysis, all strategies | ~1 hr |
### Key Analysis Modules
- **Alignment Imprint Detection** — Fingerprints DPO vs RLHF vs CAI vs SFT from subspace geometry
- **Concept Cone Geometry** — Is refusal one linear direction or a polyhedral cone (many directions)?
- **Refusal Logit Lens** — Which transformer layer makes the refusal decision?
- **Ouroboros Detection** — Will the model self-repair its refusal after removal?
- **Causal Tracing** — Which attention heads and MLP layers are causally necessary for refusal?
- **Cross-Model Transfer** — Can refusal directions from one model architecture work on another?
- **Residual Stream Decomposition** — Attention vs MLP contribution to refusal behavior
- **SAE-based Analysis** — Sparse Autoencoder feature decomposition of refusal circuits
## Steering Vectors (Reversible Alternative)
For testing refusal removal without permanent weight changes:
Steering vectors apply activation hooks at inference time. Model weights stay unchanged.
Generated during the PROBE/DISTILL stages and can be saved/applied/removed at will.
Useful for A/B testing before committing to permanent abliteration.
## YAML Config for Reproducible Studies
For complex or reproducible workflows, use YAML configs. See templates/ for examples:
```bash
obliteratus run my_study.yaml
```
## Telemetry Notice
- **CLI usage (local installs)**: Telemetry is OFF by default. Must explicitly opt in via `OBLITERATUS_TELEMETRY=1` env var or `--contribute` flag.
- **HuggingFace Spaces**: Telemetry is ON by default (auto-enabled when `SPACE_ID` env var is detected).
- Collected: model ID, method, benchmark scores, hardware info, timing (anonymous)
- NOT collected: IP addresses, user identity, prompt content
- Force off: `export OBLITERATUS_TELEMETRY=0`
## Common Pitfalls
1. **OOM (Out of Memory)** — Use `--quantization 4bit` and `--large-model` for big models
2. **Perplexity spike** — Too aggressive. Increase `--regularization` or reduce `--n-directions`
3. **Refusal persists** — Try `--method aggressive` or `--refinement-passes 3`
4. **MoE models resist** — Use `--method nuclear` for DeepSeek, Mixtral, DBRX
5. **Gated models fail** — Run `huggingface-cli login` and accept model terms on HF website first
6. **Self-repair (Ouroboros)** — Some models reconstruct refusal. Use `--method informed` which auto-compensates
7. **CoT damage** — Reasoning models lose chain-of-thought. Use `--method surgical` (CoT-aware)
8. **Disk space** — Output is full model copy. 8B fp16 = ~16GB, 70B fp16 = ~140GB
9. **Slow on CPU** — CPU-only is viable only for tiny models (<1B). Anything bigger needs GPU.
## Complementary Hermes Skills
After abliteration:
- **axolotl** / **unsloth** — Fine-tune the abliterated model further
- **serving-llms-vllm** — Serve the model as an OpenAI-compatible API
- **sparse-autoencoder-training** — Train SAEs for deeper interpretability work
## Resources
- [OBLITERATUS GitHub](https://github.com/elder-plinius/OBLITERATUS) (AGPL-3.0)
- [HuggingFace Spaces Demo](https://huggingface.co/spaces/pliny-the-prompter/obliteratus)
- [Arditi et al. 2024 — Refusal in LMs Is Mediated by a Single Direction](https://arxiv.org/abs/2406.11717)
- [Refusal Direction Optimization — ICML 2025](https://arxiv.org/abs/2411.14793)

View File

@@ -0,0 +1,170 @@
# OBLITERATUS Analysis Modules — Reference
15 analysis modules for mechanistic interpretability of refusal in LLMs.
These help you understand HOW a model refuses before you decide to remove it.
> **Note:** The `analysis/` directory contains additional utility files (utils.py,
> visualization.py, etc.) and helper functions beyond the 15 core analysis modules
> listed below. The module count matches the README's "15 deep analysis modules."
## Core Analysis (Run These First)
### Alignment Imprint Detection
**File:** `alignment_imprint.py`
**Purpose:** Identifies what alignment technique was used to train the model
**Detects:** DPO, RLHF, CAI (Constitutional AI), SFT (Supervised Fine-Tuning)
**How:** Analyzes subspace geometry — each alignment method leaves a distinct
geometric "fingerprint" in the weight space
**Output:** Detected method + confidence score
**Why it matters:** Different alignment methods need different abliteration approaches.
DPO models typically have cleaner single-direction refusal; RLHF is more diffuse.
### Concept Cone Geometry
**File:** `concept_geometry.py`
**Purpose:** Maps whether refusal is one direction or a polyhedral cone (many)
**Output:** Cone angle, dimensionality, per-category breakdown
**Why it matters:** If refusal is a single direction, `basic` method works. If it's
a cone (multiple directions for different refusal categories), you need `advanced`
or `informed` with higher `n_directions`.
### Refusal Logit Lens
**File:** `logit_lens.py`
**Purpose:** Identifies the specific layer where the model "decides" to refuse
**How:** Projects intermediate hidden states to vocabulary space at each layer,
watches when "I cannot" tokens spike in probability
**Output:** Layer-by-layer refusal probability plot
**Why it matters:** Tells you which layers are most important to target
### Ouroboros (Self-Repair) Detection
**File:** `anti_ouroboros.py`
**Purpose:** Predicts whether the model will reconstruct its refusal after removal
**How:** Measures redundancy in refusal representation across layers
**Output:** Self-repair risk score (0-1)
**Why it matters:** High self-repair risk means you need multiple refinement passes
or the `informed` method which auto-compensates
### Causal Tracing
**File:** `causal_tracing.py`
**Purpose:** Determines which components are causally necessary for refusal
**How:** Patches activations between clean and corrupted runs, measures causal effect
**Output:** Causal importance map across layers, heads, and MLPs
**Why it matters:** Shows exactly which components to target for surgical removal
## Geometric Analysis
### Cross-Layer Alignment
**File:** `cross_layer.py`
**Purpose:** Measures how aligned refusal directions are across layers
**Output:** Alignment matrix, cluster assignments
**Why it matters:** If directions are highly aligned across layers, removal is easier.
If they cluster, you may need layer-group-specific directions.
### Residual Stream Decomposition
**File:** `residual_stream.py`
**Purpose:** Breaks down refusal into Attention vs MLP contributions
**Output:** Per-layer Attention/MLP contribution to refusal direction
**Why it matters:** Helps decide whether to target attention heads, MLPs, or both
### Riemannian Manifold Geometry
**File:** `riemannian_manifold.py` (673 lines)
**Purpose:** Analyzes the weight manifold geometry around refusal directions
**Output:** Curvature, geodesics, tangent space analysis
**Why it matters:** Research-grade; helps understand the geometric structure of alignment
### Whitened SVD
**File:** `whitened_svd.py`
**Purpose:** Covariance-normalized SVD extraction
**How:** Whitens the activation covariance before computing refusal directions,
separating true refusal signal from natural activation variance
**Output:** Cleaner refusal directions with less noise
**Why it matters:** Produces more precise directions, especially for noisy activations
## Probing & Classification
### Activation Probing
**File:** `activation_probing.py`
**Purpose:** Post-excision probing to verify refusal signal is truly gone
**Output:** Residual refusal signal strength per layer
**Why it matters:** Verification that abliteration was complete
### Probing Classifiers
**File:** `probing_classifiers.py`
**Purpose:** Trains linear classifiers to detect refusal in hidden states
**Output:** Classification accuracy per layer (should drop to ~50% after abliteration)
**Why it matters:** Quantitative measure of refusal removal completeness
### Activation Patching
**File:** `activation_patching.py`
**Purpose:** Interchange interventions — swap activations between harmful/harmless runs
**Output:** Which components are sufficient (not just necessary) for refusal
**Why it matters:** Complementary to causal tracing; together they give full picture
## Transfer & Robustness
### Cross-Model Transfer
**File:** `cross_model_transfer.py`
**Purpose:** Tests if refusal directions from one model work on another
**Output:** Transfer success rate between model pairs
**Why it matters:** If directions transfer, you can skip PROBE stage on similar models
### Defense Robustness
**File:** `defense_robustness.py`
**Purpose:** Evaluates how robust the model's refusal defenses are
**Output:** Robustness score, entanglement mapping
**Why it matters:** Higher robustness = need more aggressive method
### Spectral Certification
**File:** `spectral_certification.py`
**Purpose:** Certifies completeness of refusal direction removal
**Output:** Spectral gap analysis, completeness score
**Why it matters:** Formal verification that all major refusal components are addressed
## Advanced / Research
### SAE-based Abliteration
**File:** `sae_abliteration.py` (762 lines)
**Purpose:** Uses Sparse Autoencoder features to decompose refusal at feature level
**Output:** Refusal-specific SAE features, targeted removal
**Why it matters:** Most fine-grained approach; can target individual refusal "concepts"
### Wasserstein Optimal Extraction
**File:** `wasserstein_optimal.py`
**Purpose:** Optimal transport-based direction extraction
**Output:** Wasserstein-optimal refusal directions
**Why it matters:** Theoretically optimal direction extraction under distributional assumptions
### Bayesian Kernel Projection
**File:** `bayesian_kernel_projection.py`
**Purpose:** Bayesian approach to refusal direction projection
**Output:** Posterior distribution over refusal directions
**Why it matters:** Quantifies uncertainty in direction estimation
### Conditional Abliteration
**File:** `conditional_abliteration.py`
**Purpose:** Domain-specific conditional removal (remove refusal for topic X but keep for Y)
**Output:** Per-domain refusal directions
**Why it matters:** Selective uncensoring — remove only specific refusal categories
### Steering Vectors
**File:** `steering_vectors.py`
**Purpose:** Generate inference-time steering vectors (reversible alternative)
**Output:** Steering vector files that can be applied/removed at inference
**Why it matters:** Non-destructive alternative to permanent weight modification
### Tuned Lens
**File:** `tuned_lens.py`
**Purpose:** Trained linear probes per layer (more accurate than raw logit lens)
**Output:** Layer-by-layer refusal representation with trained projections
**Why it matters:** More accurate than logit lens, especially for deeper models
### Multi-Token Position Analysis
**File:** `multi_token_position.py`
**Purpose:** Analyzes refusal signal at multiple token positions (not just last)
**Output:** Position-dependent refusal direction maps
**Why it matters:** Some models encode refusal at the system prompt position, not the query
### Sparse Surgery
**File:** `sparse_surgery.py`
**Purpose:** Row-level sparse weight surgery instead of full matrix projection
**Output:** Targeted weight modifications at the row level
**Why it matters:** More surgical than full-matrix projection, less collateral damage

View File

@@ -0,0 +1,132 @@
# OBLITERATUS Methods — Detailed Guide
> **Important:** The CLI (`obliteratus obliterate --method`) accepts 9 methods:
> basic, advanced, aggressive, spectral_cascade, informed, surgical, optimized,
> inverted, nuclear. Four additional methods (failspy, gabliteration, heretic, rdo)
> are available only via the Python API and will be rejected by argparse if used on CLI.
## How Abliteration Works (Theory)
When a model is trained with RLHF/DPO/CAI, it learns to represent "should I refuse?"
as a direction in its internal activation space. When processing a "harmful" prompt,
activations shift in this direction, causing the model to generate refusal text.
Abliteration works by:
1. Measuring this direction (the difference between harmful and harmless activations)
2. Removing it from the model's weight matrices via orthogonal projection
3. The model can no longer "point toward" refusal, so it responds normally
Mathematically: `W_new = W_old - (W_old @ d @ d.T)` where `d` is the refusal direction.
## Method Details
### basic
**Technique:** Single refusal direction via diff-in-means
**Based on:** Arditi et al. 2024 ("Refusal in Language Models Is Mediated by a Single Direction")
**Speed:** Fast (~5-10 min for 8B)
**Quality:** Moderate — works for simple refusal patterns
**Best for:** Quick tests, models with clean single-direction refusal
**Limitation:** Misses complex multi-direction refusal patterns
### advanced (DEFAULT)
**Technique:** Multiple SVD directions with norm-preserving projection
**Speed:** Medium (~10-20 min for 8B)
**Quality:** Good — handles multi-direction refusal
**Best for:** Dense models (Llama, Qwen, Mistral) as a reliable default
**Key improvement:** Norm preservation prevents weight magnitude drift
### informed (RECOMMENDED)
**Technique:** Analysis-guided auto-configuration
**Speed:** Slow (~20-40 min for 8B, runs 4 analysis modules first)
**Quality:** Best — adapts to each model's specific refusal implementation
**Best for:** Any model when quality matters more than speed
The informed pipeline runs these analysis modules during abliteration:
1. **AlignmentImprintDetector** — Detects DPO/RLHF/CAI/SFT → sets regularization
2. **ConceptConeAnalyzer** — Polyhedral vs linear refusal → sets n_directions
3. **CrossLayerAlignmentAnalyzer** — Cluster-aware → selects target layers
4. **DefenseRobustnessEvaluator** — Self-repair risk → sets refinement passes
5. **Ouroboros loop** — Re-probes after excision, re-excises if refusal persists
### aggressive
**Technique:** Whitened SVD + jailbreak-contrastive activations + attention head surgery
**Speed:** Slow (~30-60 min for 8B)
**Quality:** High but higher risk of coherence damage
**Best for:** Models that resist gentler methods
**Key feature:** Whitened SVD separates refusal signal from natural activation variance
### surgical
**Technique:** SAE features + neuron masking + head surgery + per-expert directions
**Speed:** Very slow (~1-2 hrs for 8B, needs SAE)
**Quality:** Highest precision
**Best for:** Reasoning models (R1 distills) where you must preserve CoT
**Key feature:** CoT-Aware — explicitly protects reasoning-critical directions
### nuclear
**Technique:** Everything combined — expert transplant + steering + per-expert directions
**Speed:** Very slow
**Quality:** Most thorough removal, highest risk of side effects
**Best for:** Stubborn MoE models (DeepSeek, Mixtral, DBRX) that resist other methods
**Key feature:** Expert-granular abliteration decomposes signals per MoE expert
### optimized
**Technique:** Bayesian hyperparameter search via Optuna TPE
**Speed:** Very slow (runs many trials)
**Quality:** Finds optimal configuration automatically
**Best for:** Research, when you want the mathematically best parameters
**Requires:** optuna package
### spectral_cascade
**Technique:** DCT frequency-domain decomposition of refusal signal
**Speed:** Medium-slow
**Quality:** Novel approach, less battle-tested
**Best for:** Research, exploring alternative decomposition strategies
### inverted
**Technique:** Reflects (inverts) the refusal direction instead of removing it
**Speed:** Fast (same as basic)
**Quality:** Aggressive — model becomes actively willing, not just neutral
**Best for:** When you want the model to be maximally helpful
**Warning:** Can make the model too eager; may reduce safety-adjacent reasoning
### failspy / gabliteration / heretic / rdo (PYTHON API ONLY)
**Technique:** Faithful reproductions of prior community/academic work
**Speed:** Varies
**Quality:** Known baselines
**Best for:** Reproducing published results, comparing methods
**⚠️ NOT available via CLI** — these methods are only accessible via the Python API.
Do not use `--method failspy` etc. in CLI commands; argparse will reject them.
## Method Selection Flowchart
```
Is this a quick test?
├─ YES → basic
└─ NO → Is the model MoE (DeepSeek, Mixtral)?
├─ YES → nuclear
└─ NO → Is it a reasoning model (R1 distill)?
├─ YES → surgical
└─ NO → Do you care about speed?
├─ YES → advanced
└─ NO → informed
```
## Key Parameters
| Parameter | Range | Default | Effect |
|:--------------------|:---------|:--------|:--------------------------------------------|
| n_directions | 1-32 | auto | More = more thorough but riskier |
| regularization | 0.0-1.0 | 0.0 | Higher preserves more original behavior |
| refinement_passes | 1-5 | 1 | More catches self-repair (Ouroboros effect) |
| quantization | 4/8 bit | none | Saves VRAM, slight quality tradeoff |
## Troubleshooting
| Problem | Solution |
|:---------------------------|:--------------------------------------------------|
| Refusal rate still > 10% | Try aggressive/nuclear, add refinement passes |
| Perplexity up > 20% | Reduce n_directions, increase regularization |
| Model generates nonsense | Regularization too low, try 0.2-0.3 |
| OOM on GPU | Use 4-bit quantization, or try smaller model |
| MoE model barely changes | Use nuclear method (expert-granular) |
| CoT reasoning broken | Use surgical method (CoT-aware) |

View File

@@ -0,0 +1,33 @@
# OBLITERATUS Abliteration Config
# Usage: obliteratus run this-file.yaml
#
# This is for reproducible, version-controlled abliteration runs.
# For one-off usage, the CLI flags are simpler.
# Model to abliterate
model:
name: "meta-llama/Llama-3.1-8B-Instruct"
dtype: "bfloat16" # float16, bfloat16, float32
quantization: null # null, "4bit", "8bit"
device: "auto" # auto, cuda, cuda:0, cpu
# Abliteration method and parameters
abliteration:
method: "informed" # See SKILL.md Step 4 for all 13 methods
n_directions: null # null = auto-detect, or integer (e.g., 8)
regularization: 0.0 # 0.0-1.0, fraction of original to preserve
refinement_passes: 1 # Iterative passes (increase for self-repair)
norm_preserve: true # Keep weight norms intact after projection
# Output
output:
directory: "./abliterated-models"
save_metadata: true # Save abliteration_metadata.json alongside model
contribute: false # Save community contribution data
# Verification
verify:
enabled: true
test_prompts: null # null = use built-in test prompts
compute_perplexity: true
compute_kl: true

View File

@@ -0,0 +1,40 @@
# OBLITERATUS Analysis Study Config
# Usage: obliteratus run this-file.yaml --preset jailbreak
#
# Run analysis modules to understand refusal geometry BEFORE abliterating.
# Useful for research or when you want to understand what you're removing.
# Model to analyze
model:
name: "meta-llama/Llama-3.1-8B-Instruct"
dtype: "bfloat16"
quantization: "4bit" # Saves VRAM for analysis
device: "auto"
# Study configuration
study:
# Available presets: quick, full, attention, jailbreak, guardrail, knowledge
preset: "jailbreak"
# Or specify individual strategies:
# strategies:
# - layer_removal
# - head_pruning
# - ffn_ablation
# - embedding_ablation
# Analysis modules to run (subset of the 27 available)
analysis:
- alignment_imprint # Detect DPO/RLHF/CAI/SFT training method
- concept_geometry # Map refusal cone geometry
- logit_lens # Find which layer decides to refuse
- anti_ouroboros # Detect self-repair tendency
- cross_layer # Cross-layer alignment clustering
- causal_tracing # Causal necessity of components
- residual_stream # Attention vs MLP contribution
# Output
output:
directory: "./analysis-results"
save_plots: true # Generate matplotlib visualizations
save_report: true # Generate markdown report

View File

@@ -0,0 +1,41 @@
# OBLITERATUS Batch Abliteration Config
# Abliterate multiple models with the same method for comparison.
#
# Run each one sequentially:
# for model in models; do obliteratus obliterate $model --method informed; done
#
# Or use this as a reference for which models to process.
# Common settings
defaults:
method: "informed"
quantization: "4bit"
output_dir: "./abliterated-models"
# Models to process (grouped by compute tier)
models:
# Small (4-8 GB VRAM)
small:
- "Qwen/Qwen2.5-1.5B-Instruct"
- "microsoft/Phi-3.5-mini-instruct"
- "meta-llama/Llama-3.2-3B-Instruct"
# Medium (8-16 GB VRAM)
medium:
- "meta-llama/Llama-3.1-8B-Instruct"
- "mistralai/Mistral-7B-Instruct-v0.3"
- "google/gemma-2-9b-it"
- "Qwen/Qwen2.5-7B-Instruct"
# Large (24 GB VRAM, 4-bit quantization)
large:
- "Qwen/Qwen2.5-14B-Instruct"
- "Qwen/Qwen3-32B"
- "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
# Per-model method overrides (optional)
overrides:
"deepseek-ai/DeepSeek-R1-Distill-Qwen-32B":
method: "surgical" # CoT-aware for reasoning models
"mistralai/Mixtral-8x7B-Instruct-v0.1":
method: "nuclear" # Expert-granular for MoE models

434
skills/mlops/peft/SKILL.md Normal file
View File

@@ -0,0 +1,434 @@
---
name: peft-fine-tuning
description: Parameter-efficient fine-tuning for LLMs using LoRA, QLoRA, and 25+ methods. Use when fine-tuning large models (7B-70B) with limited GPU memory, when you need to train <1% of parameters with minimal accuracy loss, or for multi-adapter serving. HuggingFace's official library integrated with transformers ecosystem.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [peft>=0.13.0, transformers>=4.45.0, torch>=2.0.0, bitsandbytes>=0.43.0]
metadata:
hermes:
tags: [Fine-Tuning, PEFT, LoRA, QLoRA, Parameter-Efficient, Adapters, Low-Rank, Memory Optimization, Multi-Adapter]
---
# PEFT (Parameter-Efficient Fine-Tuning)
Fine-tune LLMs by training <1% of parameters using LoRA, QLoRA, and 25+ adapter methods.
## When to use PEFT
**Use PEFT/LoRA when:**
- Fine-tuning 7B-70B models on consumer GPUs (RTX 4090, A100)
- Need to train <1% parameters (6MB adapters vs 14GB full model)
- Want fast iteration with multiple task-specific adapters
- Deploying multiple fine-tuned variants from one base model
**Use QLoRA (PEFT + quantization) when:**
- Fine-tuning 70B models on single 24GB GPU
- Memory is the primary constraint
- Can accept ~5% quality trade-off vs full fine-tuning
**Use full fine-tuning instead when:**
- Training small models (<1B parameters)
- Need maximum quality and have compute budget
- Significant domain shift requires updating all weights
## Quick start
### Installation
```bash
# Basic installation
pip install peft
# With quantization support (recommended)
pip install peft bitsandbytes
# Full stack
pip install peft transformers accelerate bitsandbytes datasets
```
### LoRA fine-tuning (standard)
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset
# Load base model
model_name = "meta-llama/Llama-3.1-8B"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
# LoRA configuration
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=16, # Rank (8-64, higher = more capacity)
lora_alpha=32, # Scaling factor (typically 2*r)
lora_dropout=0.05, # Dropout for regularization
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"], # Attention layers
bias="none" # Don't train biases
)
# Apply LoRA
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Output: trainable params: 13,631,488 || all params: 8,043,307,008 || trainable%: 0.17%
# Prepare dataset
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
def tokenize(example):
text = f"### Instruction:\n{example['instruction']}\n\n### Response:\n{example['response']}"
return tokenizer(text, truncation=True, max_length=512, padding="max_length")
tokenized = dataset.map(tokenize, remove_columns=dataset.column_names)
# Training
training_args = TrainingArguments(
output_dir="./lora-llama",
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=2e-4,
fp16=True,
logging_steps=10,
save_strategy="epoch"
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized,
data_collator=lambda data: {"input_ids": torch.stack([f["input_ids"] for f in data]),
"attention_mask": torch.stack([f["attention_mask"] for f in data]),
"labels": torch.stack([f["input_ids"] for f in data])}
)
trainer.train()
# Save adapter only (6MB vs 16GB)
model.save_pretrained("./lora-llama-adapter")
```
### QLoRA fine-tuning (memory-efficient)
```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
# 4-bit quantization config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4", # NormalFloat4 (best for LLMs)
bnb_4bit_compute_dtype="bfloat16", # Compute in bf16
bnb_4bit_use_double_quant=True # Nested quantization
)
# Load quantized model
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-70B",
quantization_config=bnb_config,
device_map="auto"
)
# Prepare for training (enables gradient checkpointing)
model = prepare_model_for_kbit_training(model)
# LoRA config for QLoRA
lora_config = LoraConfig(
r=64, # Higher rank for 70B
lora_alpha=128,
lora_dropout=0.1,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
bias="none",
task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
# 70B model now fits on single 24GB GPU!
```
## LoRA parameter selection
### Rank (r) - capacity vs efficiency
| Rank | Trainable Params | Memory | Quality | Use Case |
|------|-----------------|--------|---------|----------|
| 4 | ~3M | Minimal | Lower | Simple tasks, prototyping |
| **8** | ~7M | Low | Good | **Recommended starting point** |
| **16** | ~14M | Medium | Better | **General fine-tuning** |
| 32 | ~27M | Higher | High | Complex tasks |
| 64 | ~54M | High | Highest | Domain adaptation, 70B models |
### Alpha (lora_alpha) - scaling factor
```python
# Rule of thumb: alpha = 2 * rank
LoraConfig(r=16, lora_alpha=32) # Standard
LoraConfig(r=16, lora_alpha=16) # Conservative (lower learning rate effect)
LoraConfig(r=16, lora_alpha=64) # Aggressive (higher learning rate effect)
```
### Target modules by architecture
```python
# Llama / Mistral / Qwen
target_modules = ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
# GPT-2 / GPT-Neo
target_modules = ["c_attn", "c_proj", "c_fc"]
# Falcon
target_modules = ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
# BLOOM
target_modules = ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]
# Auto-detect all linear layers
target_modules = "all-linear" # PEFT 0.6.0+
```
## Loading and merging adapters
### Load trained adapter
```python
from peft import PeftModel, AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM
# Option 1: Load with PeftModel
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
model = PeftModel.from_pretrained(base_model, "./lora-llama-adapter")
# Option 2: Load directly (recommended)
model = AutoPeftModelForCausalLM.from_pretrained(
"./lora-llama-adapter",
device_map="auto"
)
```
### Merge adapter into base model
```python
# Merge for deployment (no adapter overhead)
merged_model = model.merge_and_unload()
# Save merged model
merged_model.save_pretrained("./llama-merged")
tokenizer.save_pretrained("./llama-merged")
# Push to Hub
merged_model.push_to_hub("username/llama-finetuned")
```
### Multi-adapter serving
```python
from peft import PeftModel
# Load base with first adapter
model = AutoPeftModelForCausalLM.from_pretrained("./adapter-task1")
# Load additional adapters
model.load_adapter("./adapter-task2", adapter_name="task2")
model.load_adapter("./adapter-task3", adapter_name="task3")
# Switch between adapters at runtime
model.set_adapter("task1") # Use task1 adapter
output1 = model.generate(**inputs)
model.set_adapter("task2") # Switch to task2
output2 = model.generate(**inputs)
# Disable adapters (use base model)
with model.disable_adapter():
base_output = model.generate(**inputs)
```
## PEFT methods comparison
| Method | Trainable % | Memory | Speed | Best For |
|--------|------------|--------|-------|----------|
| **LoRA** | 0.1-1% | Low | Fast | General fine-tuning |
| **QLoRA** | 0.1-1% | Very Low | Medium | Memory-constrained |
| AdaLoRA | 0.1-1% | Low | Medium | Automatic rank selection |
| IA3 | 0.01% | Minimal | Fastest | Few-shot adaptation |
| Prefix Tuning | 0.1% | Low | Medium | Generation control |
| Prompt Tuning | 0.001% | Minimal | Fast | Simple task adaptation |
| P-Tuning v2 | 0.1% | Low | Medium | NLU tasks |
### IA3 (minimal parameters)
```python
from peft import IA3Config
ia3_config = IA3Config(
target_modules=["q_proj", "v_proj", "k_proj", "down_proj"],
feedforward_modules=["down_proj"]
)
model = get_peft_model(model, ia3_config)
# Trains only 0.01% of parameters!
```
### Prefix Tuning
```python
from peft import PrefixTuningConfig
prefix_config = PrefixTuningConfig(
task_type="CAUSAL_LM",
num_virtual_tokens=20, # Prepended tokens
prefix_projection=True # Use MLP projection
)
model = get_peft_model(model, prefix_config)
```
## Integration patterns
### With TRL (SFTTrainer)
```python
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules="all-linear")
trainer = SFTTrainer(
model=model,
args=SFTConfig(output_dir="./output", max_seq_length=512),
train_dataset=dataset,
peft_config=lora_config, # Pass LoRA config directly
)
trainer.train()
```
### With Axolotl (YAML config)
```yaml
# axolotl config.yaml
adapter: lora
lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_modules:
- q_proj
- v_proj
- k_proj
- o_proj
lora_target_linear: true # Target all linear layers
```
### With vLLM (inference)
```python
from vllm import LLM
from vllm.lora.request import LoRARequest
# Load base model with LoRA support
llm = LLM(model="meta-llama/Llama-3.1-8B", enable_lora=True)
# Serve with adapter
outputs = llm.generate(
prompts,
lora_request=LoRARequest("adapter1", 1, "./lora-adapter")
)
```
## Performance benchmarks
### Memory usage (Llama 3.1 8B)
| Method | GPU Memory | Trainable Params |
|--------|-----------|------------------|
| Full fine-tuning | 60+ GB | 8B (100%) |
| LoRA r=16 | 18 GB | 14M (0.17%) |
| QLoRA r=16 | 6 GB | 14M (0.17%) |
| IA3 | 16 GB | 800K (0.01%) |
### Training speed (A100 80GB)
| Method | Tokens/sec | vs Full FT |
|--------|-----------|------------|
| Full FT | 2,500 | 1x |
| LoRA | 3,200 | 1.3x |
| QLoRA | 2,100 | 0.84x |
### Quality (MMLU benchmark)
| Model | Full FT | LoRA | QLoRA |
|-------|---------|------|-------|
| Llama 2-7B | 45.3 | 44.8 | 44.1 |
| Llama 2-13B | 54.8 | 54.2 | 53.5 |
## Common issues
### CUDA OOM during training
```python
# Solution 1: Enable gradient checkpointing
model.gradient_checkpointing_enable()
# Solution 2: Reduce batch size + increase accumulation
TrainingArguments(
per_device_train_batch_size=1,
gradient_accumulation_steps=16
)
# Solution 3: Use QLoRA
from transformers import BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
```
### Adapter not applying
```python
# Verify adapter is active
print(model.active_adapters) # Should show adapter name
# Check trainable parameters
model.print_trainable_parameters()
# Ensure model in training mode
model.train()
```
### Quality degradation
```python
# Increase rank
LoraConfig(r=32, lora_alpha=64)
# Target more modules
target_modules = "all-linear"
# Use more training data and epochs
TrainingArguments(num_train_epochs=5)
# Lower learning rate
TrainingArguments(learning_rate=1e-4)
```
## Best practices
1. **Start with r=8-16**, increase if quality insufficient
2. **Use alpha = 2 * rank** as starting point
3. **Target attention + MLP layers** for best quality/efficiency
4. **Enable gradient checkpointing** for memory savings
5. **Save adapters frequently** (small files, easy rollback)
6. **Evaluate on held-out data** before merging
7. **Use QLoRA for 70B+ models** on consumer hardware
## References
- **[Advanced Usage](references/advanced-usage.md)** - DoRA, LoftQ, rank stabilization, custom modules
- **[Troubleshooting](references/troubleshooting.md)** - Common errors, debugging, optimization
## Resources
- **GitHub**: https://github.com/huggingface/peft
- **Docs**: https://huggingface.co/docs/peft
- **LoRA Paper**: arXiv:2106.09685
- **QLoRA Paper**: arXiv:2305.14314
- **Models**: https://huggingface.co/models?library=peft

View File

@@ -0,0 +1,514 @@
# PEFT Advanced Usage Guide
## Advanced LoRA Variants
### DoRA (Weight-Decomposed Low-Rank Adaptation)
DoRA decomposes weights into magnitude and direction components, often achieving better results than standard LoRA:
```python
from peft import LoraConfig
dora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
use_dora=True, # Enable DoRA
task_type="CAUSAL_LM"
)
model = get_peft_model(model, dora_config)
```
**When to use DoRA**:
- Consistently outperforms LoRA on instruction-following tasks
- Slightly higher memory (~10%) due to magnitude vectors
- Best for quality-critical fine-tuning
### AdaLoRA (Adaptive Rank)
Automatically adjusts rank per layer based on importance:
```python
from peft import AdaLoraConfig
adalora_config = AdaLoraConfig(
init_r=64, # Initial rank
target_r=16, # Target average rank
tinit=200, # Warmup steps
tfinal=1000, # Final pruning step
deltaT=10, # Rank update frequency
beta1=0.85,
beta2=0.85,
orth_reg_weight=0.5, # Orthogonality regularization
target_modules=["q_proj", "v_proj"],
task_type="CAUSAL_LM"
)
```
**Benefits**:
- Allocates more rank to important layers
- Can reduce total parameters while maintaining quality
- Good for exploring optimal rank distribution
### LoRA+ (Asymmetric Learning Rates)
Different learning rates for A and B matrices:
```python
from peft import LoraConfig
# LoRA+ uses higher LR for B matrix
lora_plus_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules="all-linear",
use_rslora=True, # Rank-stabilized LoRA (related technique)
)
# Manual implementation of LoRA+
from torch.optim import AdamW
# Group parameters
lora_A_params = [p for n, p in model.named_parameters() if "lora_A" in n]
lora_B_params = [p for n, p in model.named_parameters() if "lora_B" in n]
optimizer = AdamW([
{"params": lora_A_params, "lr": 1e-4},
{"params": lora_B_params, "lr": 1e-3}, # 10x higher for B
])
```
### rsLoRA (Rank-Stabilized LoRA)
Scales LoRA outputs to stabilize training with different ranks:
```python
lora_config = LoraConfig(
r=64,
lora_alpha=64,
use_rslora=True, # Enables rank-stabilized scaling
target_modules="all-linear"
)
```
**When to use**:
- When experimenting with different ranks
- Helps maintain consistent behavior across rank values
- Recommended for r > 32
## LoftQ (LoRA-Fine-Tuning-aware Quantization)
Initializes LoRA weights to compensate for quantization error:
```python
from peft import LoftQConfig, LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
# LoftQ configuration
loftq_config = LoftQConfig(
loftq_bits=4, # Quantization bits
loftq_iter=5, # Alternating optimization iterations
)
# LoRA config with LoftQ initialization
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules="all-linear",
init_lora_weights="loftq",
loftq_config=loftq_config,
task_type="CAUSAL_LM"
)
# Load quantized model
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B",
quantization_config=bnb_config
)
model = get_peft_model(model, lora_config)
```
**Benefits over standard QLoRA**:
- Better initial quality after quantization
- Faster convergence
- ~1-2% better final accuracy on benchmarks
## Custom Module Targeting
### Target specific layers
```python
# Target only first and last transformer layers
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["model.layers.0.self_attn.q_proj",
"model.layers.0.self_attn.v_proj",
"model.layers.31.self_attn.q_proj",
"model.layers.31.self_attn.v_proj"],
layers_to_transform=[0, 31] # Alternative approach
)
```
### Layer pattern matching
```python
# Target layers 0-10 only
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules="all-linear",
layers_to_transform=list(range(11)), # Layers 0-10
layers_pattern="model.layers"
)
```
### Exclude specific layers
```python
lora_config = LoraConfig(
r=16,
target_modules="all-linear",
modules_to_save=["lm_head"], # Train these fully (not LoRA)
)
```
## Embedding and LM Head Training
### Train embeddings with LoRA
```python
from peft import LoraConfig
# Include embeddings
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "v_proj", "embed_tokens"], # Include embeddings
modules_to_save=["lm_head"], # Train lm_head fully
)
```
### Extending vocabulary with LoRA
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig
# Add new tokens
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
new_tokens = ["<custom_token_1>", "<custom_token_2>"]
tokenizer.add_tokens(new_tokens)
# Resize model embeddings
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
model.resize_token_embeddings(len(tokenizer))
# Configure LoRA to train new embeddings
lora_config = LoraConfig(
r=16,
target_modules="all-linear",
modules_to_save=["embed_tokens", "lm_head"], # Train these fully
)
model = get_peft_model(model, lora_config)
```
## Multi-Adapter Patterns
### Adapter composition
```python
from peft import PeftModel
# Load model with multiple adapters
model = AutoPeftModelForCausalLM.from_pretrained("./base-adapter")
model.load_adapter("./style-adapter", adapter_name="style")
model.load_adapter("./task-adapter", adapter_name="task")
# Combine adapters (weighted sum)
model.add_weighted_adapter(
adapters=["style", "task"],
weights=[0.7, 0.3],
adapter_name="combined",
combination_type="linear" # or "cat", "svd"
)
model.set_adapter("combined")
```
### Adapter stacking
```python
# Stack adapters (apply sequentially)
model.add_weighted_adapter(
adapters=["base", "domain", "task"],
weights=[1.0, 1.0, 1.0],
adapter_name="stacked",
combination_type="cat" # Concatenate adapter outputs
)
```
### Dynamic adapter switching
```python
import torch
class MultiAdapterModel:
def __init__(self, base_model_path, adapter_paths):
self.model = AutoPeftModelForCausalLM.from_pretrained(adapter_paths[0])
for name, path in adapter_paths[1:].items():
self.model.load_adapter(path, adapter_name=name)
def generate(self, prompt, adapter_name="default"):
self.model.set_adapter(adapter_name)
return self.model.generate(**self.tokenize(prompt))
def generate_ensemble(self, prompt, adapters, weights):
"""Generate with weighted adapter ensemble"""
outputs = []
for adapter, weight in zip(adapters, weights):
self.model.set_adapter(adapter)
logits = self.model(**self.tokenize(prompt)).logits
outputs.append(weight * logits)
return torch.stack(outputs).sum(dim=0)
```
## Memory Optimization
### Gradient checkpointing with LoRA
```python
from peft import prepare_model_for_kbit_training
# Enable gradient checkpointing
model = prepare_model_for_kbit_training(
model,
use_gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False}
)
```
### CPU offloading for training
```python
from accelerate import Accelerator
accelerator = Accelerator(
mixed_precision="bf16",
gradient_accumulation_steps=8,
cpu_offload=True # Offload optimizer states to CPU
)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
```
### Memory-efficient attention with LoRA
```python
from transformers import AutoModelForCausalLM
# Combine Flash Attention 2 with LoRA
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B",
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16
)
# Apply LoRA
model = get_peft_model(model, lora_config)
```
## Inference Optimization
### Merge for deployment
```python
# Merge adapter weights into base model
merged_model = model.merge_and_unload()
# Quantize merged model for inference
from transformers import BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
quantized_model = AutoModelForCausalLM.from_pretrained(
"./merged-model",
quantization_config=bnb_config
)
```
### Export to different formats
```python
# Export to GGUF (llama.cpp)
# First merge, then convert
merged_model.save_pretrained("./merged-model")
# Use llama.cpp converter
# python convert-hf-to-gguf.py ./merged-model --outfile model.gguf
# Export to ONNX
from optimum.onnxruntime import ORTModelForCausalLM
ort_model = ORTModelForCausalLM.from_pretrained(
"./merged-model",
export=True
)
ort_model.save_pretrained("./onnx-model")
```
### Batch adapter inference
```python
from vllm import LLM
from vllm.lora.request import LoRARequest
# Initialize with LoRA support
llm = LLM(
model="meta-llama/Llama-3.1-8B",
enable_lora=True,
max_lora_rank=64,
max_loras=4 # Max concurrent adapters
)
# Batch with different adapters
requests = [
("prompt1", LoRARequest("adapter1", 1, "./adapter1")),
("prompt2", LoRARequest("adapter2", 2, "./adapter2")),
("prompt3", LoRARequest("adapter1", 1, "./adapter1")),
]
outputs = llm.generate(
[r[0] for r in requests],
lora_request=[r[1] for r in requests]
)
```
## Training Recipes
### Instruction tuning recipe
```python
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
target_modules="all-linear",
bias="none",
task_type="CAUSAL_LM"
)
training_args = TrainingArguments(
output_dir="./output",
num_train_epochs=3,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=2e-4,
lr_scheduler_type="cosine",
warmup_ratio=0.03,
bf16=True,
logging_steps=10,
save_strategy="steps",
save_steps=100,
eval_strategy="steps",
eval_steps=100,
)
```
### Code generation recipe
```python
lora_config = LoraConfig(
r=32, # Higher rank for code
lora_alpha=64,
lora_dropout=0.1,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
bias="none",
task_type="CAUSAL_LM"
)
training_args = TrainingArguments(
learning_rate=1e-4, # Lower LR for code
num_train_epochs=2,
max_seq_length=2048, # Longer sequences
)
```
### Conversational/Chat recipe
```python
from trl import SFTTrainer
lora_config = LoraConfig(
r=16,
lora_alpha=16, # alpha = r for chat
lora_dropout=0.05,
target_modules="all-linear"
)
# Use chat template
def format_chat(example):
messages = [
{"role": "user", "content": example["instruction"]},
{"role": "assistant", "content": example["response"]}
]
return tokenizer.apply_chat_template(messages, tokenize=False)
trainer = SFTTrainer(
model=model,
peft_config=lora_config,
train_dataset=dataset.map(format_chat),
max_seq_length=1024,
)
```
## Debugging and Validation
### Verify adapter application
```python
# Check which modules have LoRA
for name, module in model.named_modules():
if hasattr(module, "lora_A"):
print(f"LoRA applied to: {name}")
# Print detailed config
print(model.peft_config)
# Check adapter state
print(f"Active adapters: {model.active_adapters}")
print(f"Trainable: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
```
### Compare with base model
```python
# Generate with adapter
model.set_adapter("default")
adapter_output = model.generate(**inputs)
# Generate without adapter
with model.disable_adapter():
base_output = model.generate(**inputs)
print(f"Adapter: {tokenizer.decode(adapter_output[0])}")
print(f"Base: {tokenizer.decode(base_output[0])}")
```
### Monitor training metrics
```python
from transformers import TrainerCallback
class LoRACallback(TrainerCallback):
def on_log(self, args, state, control, logs=None, **kwargs):
if "loss" in logs:
# Log adapter-specific metrics
model = kwargs["model"]
lora_params = sum(p.numel() for n, p in model.named_parameters()
if "lora" in n and p.requires_grad)
print(f"Step {state.global_step}: loss={logs['loss']:.4f}, lora_params={lora_params}")
```

View File

@@ -0,0 +1,480 @@
# PEFT Troubleshooting Guide
## Installation Issues
### bitsandbytes CUDA Error
**Error**: `CUDA Setup failed despite GPU being available`
**Fix**:
```bash
# Check CUDA version
nvcc --version
# Install matching bitsandbytes
pip uninstall bitsandbytes
pip install bitsandbytes --no-cache-dir
# Or compile from source for specific CUDA
git clone https://github.com/TimDettmers/bitsandbytes.git
cd bitsandbytes
CUDA_VERSION=118 make cuda11x # Adjust for your CUDA
pip install .
```
### Triton Import Error
**Error**: `ModuleNotFoundError: No module named 'triton'`
**Fix**:
```bash
# Install triton (Linux only)
pip install triton
# Windows: Triton not supported, use CUDA backend
# Set environment variable to disable triton
export CUDA_VISIBLE_DEVICES=0
```
### PEFT Version Conflicts
**Error**: `AttributeError: 'LoraConfig' object has no attribute 'use_dora'`
**Fix**:
```bash
# Upgrade to latest PEFT
pip install peft>=0.13.0 --upgrade
# Check version
python -c "import peft; print(peft.__version__)"
```
## Training Issues
### CUDA Out of Memory
**Error**: `torch.cuda.OutOfMemoryError: CUDA out of memory`
**Solutions**:
1. **Enable gradient checkpointing**:
```python
from peft import prepare_model_for_kbit_training
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
```
2. **Reduce batch size**:
```python
TrainingArguments(
per_device_train_batch_size=1,
gradient_accumulation_steps=16 # Maintain effective batch size
)
```
3. **Use QLoRA**:
```python
from transformers import BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True
)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
```
4. **Lower LoRA rank**:
```python
LoraConfig(r=8) # Instead of r=16 or higher
```
5. **Target fewer modules**:
```python
target_modules=["q_proj", "v_proj"] # Instead of all-linear
```
### Loss Not Decreasing
**Problem**: Training loss stays flat or increases.
**Solutions**:
1. **Check learning rate**:
```python
# Start lower
TrainingArguments(learning_rate=1e-4) # Not 2e-4 or higher
```
2. **Verify adapter is active**:
```python
model.print_trainable_parameters()
# Should show >0 trainable params
# Check adapter applied
print(model.peft_config)
```
3. **Check data formatting**:
```python
# Verify tokenization
sample = dataset[0]
decoded = tokenizer.decode(sample["input_ids"])
print(decoded) # Should look correct
```
4. **Increase rank**:
```python
LoraConfig(r=32, lora_alpha=64) # More capacity
```
### NaN Loss
**Error**: `Loss is NaN`
**Fix**:
```python
# Use bf16 instead of fp16
TrainingArguments(bf16=True, fp16=False)
# Or enable loss scaling
TrainingArguments(fp16=True, fp16_full_eval=True)
# Lower learning rate
TrainingArguments(learning_rate=5e-5)
# Check for data issues
for batch in dataloader:
if torch.isnan(batch["input_ids"].float()).any():
print("NaN in input!")
```
### Adapter Not Training
**Problem**: `trainable params: 0` or model not updating.
**Fix**:
```python
# Verify LoRA applied to correct modules
for name, module in model.named_modules():
if "lora" in name.lower():
print(f"Found LoRA: {name}")
# Check target_modules match model architecture
from peft.utils import TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
print(TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.get(model.config.model_type))
# Ensure model in training mode
model.train()
# Check requires_grad
for name, param in model.named_parameters():
if param.requires_grad:
print(f"Trainable: {name}")
```
## Loading Issues
### Adapter Loading Fails
**Error**: `ValueError: Can't find adapter weights`
**Fix**:
```python
# Check adapter files exist
import os
print(os.listdir("./adapter-path"))
# Should contain: adapter_config.json, adapter_model.safetensors
# Load with correct structure
from peft import PeftModel, PeftConfig
# Check config
config = PeftConfig.from_pretrained("./adapter-path")
print(config)
# Load base model first
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(base_model, "./adapter-path")
```
### Base Model Mismatch
**Error**: `RuntimeError: size mismatch`
**Fix**:
```python
# Ensure base model matches adapter
from peft import PeftConfig
config = PeftConfig.from_pretrained("./adapter-path")
print(f"Base model: {config.base_model_name_or_path}")
# Load exact same base model
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
```
### Safetensors vs PyTorch Format
**Error**: `ValueError: We couldn't connect to 'https://huggingface.co'`
**Fix**:
```python
# Force local loading
model = PeftModel.from_pretrained(
base_model,
"./adapter-path",
local_files_only=True
)
# Or specify format
model.save_pretrained("./adapter", safe_serialization=True) # safetensors
model.save_pretrained("./adapter", safe_serialization=False) # pytorch
```
## Inference Issues
### Slow Generation
**Problem**: Inference much slower than expected.
**Solutions**:
1. **Merge adapter for deployment**:
```python
merged_model = model.merge_and_unload()
# No adapter overhead during inference
```
2. **Use optimized inference engine**:
```python
from vllm import LLM
llm = LLM(model="./merged-model", dtype="half")
```
3. **Enable Flash Attention**:
```python
model = AutoModelForCausalLM.from_pretrained(
model_name,
attn_implementation="flash_attention_2"
)
```
### Output Quality Issues
**Problem**: Fine-tuned model produces worse outputs.
**Solutions**:
1. **Check evaluation without adapter**:
```python
with model.disable_adapter():
base_output = model.generate(**inputs)
# Compare with adapter output
```
2. **Lower temperature during eval**:
```python
model.generate(**inputs, temperature=0.1, do_sample=False)
```
3. **Retrain with more data**:
```python
# Increase training samples
# Use higher quality data
# Train for more epochs
```
### Wrong Adapter Active
**Problem**: Model using wrong adapter or no adapter.
**Fix**:
```python
# Check active adapters
print(model.active_adapters)
# Explicitly set adapter
model.set_adapter("your-adapter-name")
# List all adapters
print(model.peft_config.keys())
```
## QLoRA Specific Issues
### Quantization Errors
**Error**: `RuntimeError: mat1 and mat2 shapes cannot be multiplied`
**Fix**:
```python
# Ensure compute dtype matches
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16, # Match model dtype
bnb_4bit_quant_type="nf4"
)
# Load with correct dtype
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
torch_dtype=torch.bfloat16
)
```
### QLoRA OOM
**Error**: OOM even with 4-bit quantization.
**Fix**:
```python
# Enable double quantization
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True # Further memory reduction
)
# Use offloading
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="auto",
max_memory={0: "20GB", "cpu": "100GB"}
)
```
### QLoRA Merge Fails
**Error**: `RuntimeError: expected scalar type BFloat16 but found Float`
**Fix**:
```python
# Dequantize before merging
from peft import PeftModel
# Load in higher precision for merging
base_model = AutoModelForCausalLM.from_pretrained(
base_model_name,
torch_dtype=torch.float16, # Not quantized
device_map="auto"
)
# Load adapter
model = PeftModel.from_pretrained(base_model, "./qlora-adapter")
# Now merge
merged = model.merge_and_unload()
```
## Multi-Adapter Issues
### Adapter Conflict
**Error**: `ValueError: Adapter with name 'default' already exists`
**Fix**:
```python
# Use unique names
model.load_adapter("./adapter1", adapter_name="task1")
model.load_adapter("./adapter2", adapter_name="task2")
# Or delete existing
model.delete_adapter("default")
```
### Mixed Precision Adapters
**Error**: Adapters trained with different dtypes.
**Fix**:
```python
# Convert adapter precision
model = PeftModel.from_pretrained(base_model, "./adapter")
model = model.to(torch.bfloat16)
# Or load with specific dtype
model = PeftModel.from_pretrained(
base_model,
"./adapter",
torch_dtype=torch.bfloat16
)
```
## Performance Optimization
### Memory Profiling
```python
import torch
def print_memory():
if torch.cuda.is_available():
allocated = torch.cuda.memory_allocated() / 1e9
reserved = torch.cuda.memory_reserved() / 1e9
print(f"Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")
# Profile during training
print_memory() # Before
model.train()
loss = model(**batch).loss
loss.backward()
print_memory() # After
```
### Speed Profiling
```python
import time
import torch
def benchmark_generation(model, tokenizer, prompt, n_runs=5):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Warmup
model.generate(**inputs, max_new_tokens=10)
torch.cuda.synchronize()
# Benchmark
times = []
for _ in range(n_runs):
start = time.perf_counter()
outputs = model.generate(**inputs, max_new_tokens=100)
torch.cuda.synchronize()
times.append(time.perf_counter() - start)
tokens = outputs.shape[1] - inputs.input_ids.shape[1]
avg_time = sum(times) / len(times)
print(f"Speed: {tokens/avg_time:.2f} tokens/sec")
# Compare adapter vs merged
benchmark_generation(adapter_model, tokenizer, "Hello")
benchmark_generation(merged_model, tokenizer, "Hello")
```
## Getting Help
1. **Check PEFT GitHub Issues**: https://github.com/huggingface/peft/issues
2. **HuggingFace Forums**: https://discuss.huggingface.co/
3. **PEFT Documentation**: https://huggingface.co/docs/peft
### Debugging Template
When reporting issues, include:
```python
# System info
import peft
import transformers
import torch
print(f"PEFT: {peft.__version__}")
print(f"Transformers: {transformers.__version__}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.version.cuda}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}")
# Config
print(model.peft_config)
model.print_trainable_parameters()
```

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,7 @@
# Pytorch-Fsdp Documentation Index
## Categories
### Other
**File:** `other.md`
**Pages:** 15

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,349 @@
---
name: pytorch-lightning
description: High-level PyTorch framework with Trainer class, automatic distributed training (DDP/FSDP/DeepSpeed), callbacks system, and minimal boilerplate. Scales from laptop to supercomputer with same code. Use when you want clean training loops with built-in best practices.
version: 1.0.0
author: Orchestra Research
license: MIT
dependencies: [lightning, torch, transformers]
metadata:
hermes:
tags: [PyTorch Lightning, Training Framework, Distributed Training, DDP, FSDP, DeepSpeed, High-Level API, Callbacks, Best Practices, Scalable]
---
# PyTorch Lightning - High-Level Training Framework
## Quick start
PyTorch Lightning organizes PyTorch code to eliminate boilerplate while maintaining flexibility.
**Installation**:
```bash
pip install lightning
```
**Convert PyTorch to Lightning** (3 steps):
```python
import lightning as L
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
# Step 1: Define LightningModule (organize your PyTorch code)
class LitModel(L.LightningModule):
def __init__(self, hidden_size=128):
super().__init__()
self.model = nn.Sequential(
nn.Linear(28 * 28, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, 10)
)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = nn.functional.cross_entropy(y_hat, y)
self.log('train_loss', loss) # Auto-logged to TensorBoard
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
# Step 2: Create data
train_loader = DataLoader(train_dataset, batch_size=32)
# Step 3: Train with Trainer (handles everything else!)
trainer = L.Trainer(max_epochs=10, accelerator='gpu', devices=2)
model = LitModel()
trainer.fit(model, train_loader)
```
**That's it!** Trainer handles:
- GPU/TPU/CPU switching
- Distributed training (DDP, FSDP, DeepSpeed)
- Mixed precision (FP16, BF16)
- Gradient accumulation
- Checkpointing
- Logging
- Progress bars
## Common workflows
### Workflow 1: From PyTorch to Lightning
**Original PyTorch code**:
```python
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
model.to('cuda')
for epoch in range(max_epochs):
for batch in train_loader:
batch = batch.to('cuda')
optimizer.zero_grad()
loss = model(batch)
loss.backward()
optimizer.step()
```
**Lightning version**:
```python
class LitModel(L.LightningModule):
def __init__(self):
super().__init__()
self.model = MyModel()
def training_step(self, batch, batch_idx):
loss = self.model(batch) # No .to('cuda') needed!
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters())
# Train
trainer = L.Trainer(max_epochs=10, accelerator='gpu')
trainer.fit(LitModel(), train_loader)
```
**Benefits**: 40+ lines → 15 lines, no device management, automatic distributed
### Workflow 2: Validation and testing
```python
class LitModel(L.LightningModule):
def __init__(self):
super().__init__()
self.model = MyModel()
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = nn.functional.cross_entropy(y_hat, y)
self.log('train_loss', loss)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
val_loss = nn.functional.cross_entropy(y_hat, y)
acc = (y_hat.argmax(dim=1) == y).float().mean()
self.log('val_loss', val_loss)
self.log('val_acc', acc)
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
test_loss = nn.functional.cross_entropy(y_hat, y)
self.log('test_loss', test_loss)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
# Train with validation
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, train_loader, val_loader)
# Test
trainer.test(model, test_loader)
```
**Automatic features**:
- Validation runs every epoch by default
- Metrics logged to TensorBoard
- Best model checkpointing based on val_loss
### Workflow 3: Distributed training (DDP)
```python
# Same code as single GPU!
model = LitModel()
# 8 GPUs with DDP (automatic!)
trainer = L.Trainer(
accelerator='gpu',
devices=8,
strategy='ddp' # Or 'fsdp', 'deepspeed'
)
trainer.fit(model, train_loader)
```
**Launch**:
```bash
# Single command, Lightning handles the rest
python train.py
```
**No changes needed**:
- Automatic data distribution
- Gradient synchronization
- Multi-node support (just set `num_nodes=2`)
### Workflow 4: Callbacks for monitoring
```python
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
# Create callbacks
checkpoint = ModelCheckpoint(
monitor='val_loss',
mode='min',
save_top_k=3,
filename='model-{epoch:02d}-{val_loss:.2f}'
)
early_stop = EarlyStopping(
monitor='val_loss',
patience=5,
mode='min'
)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
# Add to Trainer
trainer = L.Trainer(
max_epochs=100,
callbacks=[checkpoint, early_stop, lr_monitor]
)
trainer.fit(model, train_loader, val_loader)
```
**Result**:
- Auto-saves best 3 models
- Stops early if no improvement for 5 epochs
- Logs learning rate to TensorBoard
### Workflow 5: Learning rate scheduling
```python
class LitModel(L.LightningModule):
# ... (training_step, etc.)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
# Cosine annealing
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=100,
eta_min=1e-5
)
return {
'optimizer': optimizer,
'lr_scheduler': {
'scheduler': scheduler,
'interval': 'epoch', # Update per epoch
'frequency': 1
}
}
# Learning rate auto-logged!
trainer = L.Trainer(max_epochs=100)
trainer.fit(model, train_loader)
```
## When to use vs alternatives
**Use PyTorch Lightning when**:
- Want clean, organized code
- Need production-ready training loops
- Switching between single GPU, multi-GPU, TPU
- Want built-in callbacks and logging
- Team collaboration (standardized structure)
**Key advantages**:
- **Organized**: Separates research code from engineering
- **Automatic**: DDP, FSDP, DeepSpeed with 1 line
- **Callbacks**: Modular training extensions
- **Reproducible**: Less boilerplate = fewer bugs
- **Tested**: 1M+ downloads/month, battle-tested
**Use alternatives instead**:
- **Accelerate**: Minimal changes to existing code, more flexibility
- **Ray Train**: Multi-node orchestration, hyperparameter tuning
- **Raw PyTorch**: Maximum control, learning purposes
- **Keras**: TensorFlow ecosystem
## Common issues
**Issue: Loss not decreasing**
Check data and model setup:
```python
# Add to training_step
def training_step(self, batch, batch_idx):
if batch_idx == 0:
print(f"Batch shape: {batch[0].shape}")
print(f"Labels: {batch[1]}")
loss = ...
return loss
```
**Issue: Out of memory**
Reduce batch size or use gradient accumulation:
```python
trainer = L.Trainer(
accumulate_grad_batches=4, # Effective batch = batch_size × 4
precision='bf16' # Or 'fp16', reduces memory 50%
)
```
**Issue: Validation not running**
Ensure you pass val_loader:
```python
# WRONG
trainer.fit(model, train_loader)
# CORRECT
trainer.fit(model, train_loader, val_loader)
```
**Issue: DDP spawns multiple processes unexpectedly**
Lightning auto-detects GPUs. Explicitly set devices:
```python
# Test on CPU first
trainer = L.Trainer(accelerator='cpu', devices=1)
# Then GPU
trainer = L.Trainer(accelerator='gpu', devices=1)
```
## Advanced topics
**Callbacks**: See [references/callbacks.md](references/callbacks.md) for EarlyStopping, ModelCheckpoint, custom callbacks, and callback hooks.
**Distributed strategies**: See [references/distributed.md](references/distributed.md) for DDP, FSDP, DeepSpeed ZeRO integration, multi-node setup.
**Hyperparameter tuning**: See [references/hyperparameter-tuning.md](references/hyperparameter-tuning.md) for integration with Optuna, Ray Tune, and WandB sweeps.
## Hardware requirements
- **CPU**: Works (good for debugging)
- **Single GPU**: Works
- **Multi-GPU**: DDP (default), FSDP, or DeepSpeed
- **Multi-node**: DDP, FSDP, DeepSpeed
- **TPU**: Supported (8 cores)
- **Apple MPS**: Supported
**Precision options**:
- FP32 (default)
- FP16 (V100, older GPUs)
- BF16 (A100/H100, recommended)
- FP8 (H100)
## Resources
- Docs: https://lightning.ai/docs/pytorch/stable/
- GitHub: https://github.com/Lightning-AI/pytorch-lightning ⭐ 29,000+
- Version: 2.5.5+
- Examples: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples
- Discord: https://discord.gg/lightning-ai
- Used by: Kaggle winners, research labs, production teams

Some files were not shown because too many files have changed in this diff Show More