Compare commits
161 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8416bc2142 | |||
| 48b5bc6038 | |||
| 4ff73fb32c | |||
| 73a88a02fe | |||
| f9c2565ab4 | |||
| ad5f973a8d | |||
| 0791efe2c3 | |||
| 934fbe3c06 | |||
| 6302e56e7c | |||
| 868b3c07e3 | |||
| 9d6148316c | |||
| 7da0822456 | |||
| d35df0db71 | |||
| 93dc5dee6f | |||
| 2d8fad8230 | |||
| ca2958ff98 | |||
| f60ebc7bf2 | |||
| b072737193 | |||
| 3b509da571 | |||
| 5ddb6a191f | |||
| 1b5fb36c9d | |||
| 942f6eac94 | |||
| 2b3c1d81f0 | |||
| 1f21ef7488 | |||
| b799bca7a3 | |||
| b2b4a9ee7d | |||
| ed805f57ff | |||
| fa6f069577 | |||
| cd2280d1a3 | |||
| 5e5ad634a1 | |||
| 55a27a3fb8 | |||
| 8587cddd6c | |||
| 2bd8e5cb23 | |||
| bfe4baa6ed | |||
| 72a6d7dffe | |||
| afe2f0abe1 | |||
| 09fd007c6e | |||
| 24cf2a7954 | |||
| be3eb62047 | |||
| 9c32fed184 | |||
| 6435d69a6d | |||
| a2276177a3 | |||
| ebd0291ef2 | |||
| 0510ee056d | |||
| 44b572a9e0 | |||
| f9c2ad48c2 | |||
| c275aa4732 | |||
| ff071fc74c | |||
| 8d528e0045 | |||
| fd32e3d6e8 | |||
| 34be3f8be6 | |||
| 3037450c77 | |||
| b7091f93b1 | |||
| ab3cbfc99d | |||
| 26030266d2 | |||
| edda0e324b | |||
| 5407d12bc6 | |||
| 2de42ba690 | |||
| f3301a31d5 | |||
| e6a708aa04 | |||
| e80489135b | |||
| a53db44d40 | |||
| 0698ddb496 | |||
| 0962cbb2e5 | |||
| f69c47d9ae | |||
| 027fc1a85a | |||
| f84230527c | |||
| 0e64a48743 | |||
| ffa8b562e9 | |||
| 56b0104154 | |||
| c0c13e4ed4 | |||
| 89befcaf33 | |||
| 0f1c970179 | |||
| 57d3ac0c0b | |||
| a9f9c60efd | |||
| e109a8b502 | |||
| b81926def6 | |||
| 8cb7864110 | |||
| 7cd9f9ed48 | |||
| 2c2334d4db | |||
| 21ffadc2a6 | |||
| 241f966b1a | |||
| 7d0e4510b8 | |||
| 306e67f32d | |||
| 5c8d7d5d6f | |||
| 0b370f2dd9 | |||
| 887e8a8d84 | |||
| 189214a69d | |||
| cd6d24f111 | |||
| c01cfe4f9a | |||
| fbbe9e6030 | |||
| 43bca6d107 | |||
| 669c60a6bb | |||
| dd39003a9b | |||
| 4bded44b6a | |||
| ec22635b47 | |||
| 29d0541ac9 | |||
| a0f411c87d | |||
| 862d5224dd | |||
| e664bc7632 | |||
| f9052d7ecf | |||
| 7dff34ba4e | |||
| dbc25a386e | |||
| 0ea7d0ec80 | |||
| 1d28b4699b | |||
| e0ca46cd73 | |||
| 5454a55269 | |||
| 40c9a13476 | |||
| bd49bce278 | |||
| 52dd479214 | |||
| c57d5cbdde | |||
| 525caadd8c | |||
| f9fa7421cb | |||
| 342096b4bd | |||
| 55510cbad2 | |||
| 3ab50376b0 | |||
| f8fb61d4ad | |||
| 0d68446323 | |||
| 81dbf4309a | |||
| febfe1c268 | |||
| 2a5f86ed6d | |||
| d3659c8ca0 | |||
| f7f75de7c3 | |||
| f58902818d | |||
| 8da410ed95 | |||
| da44c196b6 | |||
| 36079c6646 | |||
| 135448f513 | |||
| 2e143fd15c | |||
| 0b9526b476 | |||
| f304bc63b8 | |||
| decc7851f2 | |||
| 97108db038 | |||
| 1f1fa71d0c | |||
| 2988334fe5 | |||
| 292d12bed4 | |||
| 509cff6e5c | |||
| 29520df44f | |||
| 9be42e49f9 | |||
| 42cef9c282 | |||
| 3a71099dac | |||
| 356122e990 | |||
| aefcdd6f7f | |||
| 3835a8d5df | |||
| e8188a56c7 | |||
| c42a18e9e5 | |||
| b73d221324 | |||
| cc51ffdb57 | |||
| c8971db435 | |||
| fb48b8f0c5 | |||
| 67600d0a0b | |||
| 5a9ab09bc3 | |||
| 2c06ec5f51 | |||
| fff7203049 | |||
| 5663980015 | |||
| 8304a7716d | |||
| 523d8c38f9 | |||
| e6299960cc | |||
| fb6d41237c | |||
| e183744cb5 | |||
| f4a74d3ac7 |
@@ -0,0 +1,422 @@
|
||||
# Hermes Agent v0.4.0 (v2026.3.23)
|
||||
|
||||
**Release Date:** March 23, 2026
|
||||
|
||||
> The biggest release yet — 300 merged PRs in one week. Streaming output, native browser tools, Skills Hub, plugin system, 7 new messaging platforms, MCP server management, @ context references, prompt caching, API server, and a sweeping reliability overhaul across every subsystem.
|
||||
|
||||
---
|
||||
|
||||
## ✨ Highlights
|
||||
|
||||
- **Streaming CLI output** — Real-time token streaming enabled by default in CLI mode with proper tool progress spinners during streaming ([#2251](https://github.com/NousResearch/hermes-agent/pull/2251), [#2340](https://github.com/NousResearch/hermes-agent/pull/2340), [#2161](https://github.com/NousResearch/hermes-agent/pull/2161))
|
||||
- **Native browser tools** — Full Browserbase-powered browser automation: navigate, click, type, screenshot, scrape — plus an interactive `/browser` CLI command ([#2270](https://github.com/NousResearch/hermes-agent/pull/2270), [#2273](https://github.com/NousResearch/hermes-agent/pull/2273))
|
||||
- **Skills Hub** — Discover, install, and manage skills from curated community taps with `/skills` commands ([#2235](https://github.com/NousResearch/hermes-agent/pull/2235))
|
||||
- **Plugin system** — TUI extension hooks for building custom CLIs on top of Hermes, plus `hermes plugins install/remove/list` commands and slash command registration for plugins ([#2333](https://github.com/NousResearch/hermes-agent/pull/2333), [#2337](https://github.com/NousResearch/hermes-agent/pull/2337), [#2359](https://github.com/NousResearch/hermes-agent/pull/2359))
|
||||
- **7 new messaging platforms** — Signal, DingTalk, SMS (Twilio), Mattermost, Matrix, WhatsApp bridge, and Webhook adapters join Telegram and Discord ([#2206](https://github.com/NousResearch/hermes-agent/pull/2206), [#1685](https://github.com/NousResearch/hermes-agent/pull/1685), [#1688](https://github.com/NousResearch/hermes-agent/pull/1688), [#1683](https://github.com/NousResearch/hermes-agent/pull/1683), [#2168](https://github.com/NousResearch/hermes-agent/pull/2168), [#2166](https://github.com/NousResearch/hermes-agent/pull/2166))
|
||||
- **@ context references** — Claude Code-style `@file` and `@url` context injection with tab completions ([#2343](https://github.com/NousResearch/hermes-agent/pull/2343), [#2482](https://github.com/NousResearch/hermes-agent/pull/2482))
|
||||
- **OpenAI-compatible API server** — Expose Hermes as an API endpoint with `/api/jobs` for cron management ([#1756](https://github.com/NousResearch/hermes-agent/pull/1756), [#2450](https://github.com/NousResearch/hermes-agent/pull/2450))
|
||||
|
||||
---
|
||||
|
||||
## 🏗️ Core Agent & Architecture
|
||||
|
||||
### Provider & Model Support
|
||||
- **GitHub Copilot provider** — Full OAuth auth, API routing, token validation, and documentation. Copilot context now correctly resolves to 400k ([#1924](https://github.com/NousResearch/hermes-agent/pull/1924), [#1896](https://github.com/NousResearch/hermes-agent/pull/1896), [#1879](https://github.com/NousResearch/hermes-agent/pull/1879) by @mchzimm, [#2507](https://github.com/NousResearch/hermes-agent/pull/2507))
|
||||
- **Claude Code OAuth provider** — Anthropic-native API mode with dynamic version detection for OAuth user-agent ([#2199](https://github.com/NousResearch/hermes-agent/pull/2199), [#1663](https://github.com/NousResearch/hermes-agent/pull/1663), [#1670](https://github.com/NousResearch/hermes-agent/pull/1670))
|
||||
- **Alibaba Cloud / DashScope provider** — Full integration with DashScope v1 runtime mode, model dot preservation, and 401 auth fixes ([#1673](https://github.com/NousResearch/hermes-agent/pull/1673), [#2332](https://github.com/NousResearch/hermes-agent/pull/2332), [#2459](https://github.com/NousResearch/hermes-agent/pull/2459))
|
||||
- **Kilo Code provider** — Added as first-class inference provider ([#1666](https://github.com/NousResearch/hermes-agent/pull/1666))
|
||||
- **OpenCode Zen and OpenCode Go providers** — New provider backends with custom endpoint support ([#1650](https://github.com/NousResearch/hermes-agent/pull/1650), [#2393](https://github.com/NousResearch/hermes-agent/pull/2393) by @0xbyt4)
|
||||
- **Multi-provider architecture** — Automatic fallback, OpenRouter routing backend, Mistral native tool calling, Google Gemini integration ([#2090](https://github.com/NousResearch/hermes-agent/pull/2090), [#2100](https://github.com/NousResearch/hermes-agent/pull/2100), [#2098](https://github.com/NousResearch/hermes-agent/pull/2098), [#2094](https://github.com/NousResearch/hermes-agent/pull/2094), [#2092](https://github.com/NousResearch/hermes-agent/pull/2092))
|
||||
- **Eager fallback to backup model** on rate-limit errors ([#1730](https://github.com/NousResearch/hermes-agent/pull/1730))
|
||||
- **Endpoint metadata** for custom model context and pricing; query local servers for actual context window size ([#1906](https://github.com/NousResearch/hermes-agent/pull/1906), [#2091](https://github.com/NousResearch/hermes-agent/pull/2091) by @dusterbloom)
|
||||
- **Context length detection overhaul** — models.dev integration, provider-aware resolution, fuzzy matching for custom endpoints, `/v1/props` for llama.cpp ([#2158](https://github.com/NousResearch/hermes-agent/pull/2158), [#2051](https://github.com/NousResearch/hermes-agent/pull/2051), [#2403](https://github.com/NousResearch/hermes-agent/pull/2403))
|
||||
- **Model catalog updates** — gpt-5.4-mini, gpt-5.4-nano, healer-alpha, haiku-4.5, minimax-m2.7, claude 4.6 at 1M context ([#1913](https://github.com/NousResearch/hermes-agent/pull/1913), [#1915](https://github.com/NousResearch/hermes-agent/pull/1915), [#1900](https://github.com/NousResearch/hermes-agent/pull/1900), [#2155](https://github.com/NousResearch/hermes-agent/pull/2155), [#2474](https://github.com/NousResearch/hermes-agent/pull/2474))
|
||||
- **Custom endpoint improvements** — config.yaml `model.base_url` support, custom endpoints use responses API via `api_mode` override, allow custom/local endpoints without API key, fail fast when explicit provider has no key ([#2330](https://github.com/NousResearch/hermes-agent/pull/2330), [#1651](https://github.com/NousResearch/hermes-agent/pull/1651), [#2556](https://github.com/NousResearch/hermes-agent/pull/2556), [#2445](https://github.com/NousResearch/hermes-agent/pull/2445), [#1994](https://github.com/NousResearch/hermes-agent/pull/1994), [#1998](https://github.com/NousResearch/hermes-agent/pull/1998))
|
||||
- Inject model and provider into system prompt ([#1929](https://github.com/NousResearch/hermes-agent/pull/1929))
|
||||
- Fix: prevent Anthropic token leaking to third-party `anthropic_messages` providers ([#2389](https://github.com/NousResearch/hermes-agent/pull/2389))
|
||||
- Fix: prevent Anthropic fallback from inheriting non-Anthropic `base_url` ([#2388](https://github.com/NousResearch/hermes-agent/pull/2388))
|
||||
- Fix: `auxiliary_is_nous` flag never resets — leaked Nous tags to other providers ([#1713](https://github.com/NousResearch/hermes-agent/pull/1713))
|
||||
- Fix: Anthropic `tool_choice 'none'` still allowed tool calls ([#1714](https://github.com/NousResearch/hermes-agent/pull/1714))
|
||||
- Fix: Mistral parser nested JSON fallback extraction ([#2335](https://github.com/NousResearch/hermes-agent/pull/2335))
|
||||
- Fix: MiniMax 401 auth error resolved by defaulting to `anthropic_messages` ([#2103](https://github.com/NousResearch/hermes-agent/pull/2103))
|
||||
- Fix: case-insensitive model family matching ([#2350](https://github.com/NousResearch/hermes-agent/pull/2350))
|
||||
- Fix: ignore placeholder provider keys in activation checks ([#2358](https://github.com/NousResearch/hermes-agent/pull/2358))
|
||||
- Fix: Copilot models response decoding and provider bootstrap error logging ([#2202](https://github.com/NousResearch/hermes-agent/pull/2202))
|
||||
- Fix: Preserve Ollama model:tag colons in context length detection ([#2149](https://github.com/NousResearch/hermes-agent/pull/2149))
|
||||
|
||||
### Agent Loop & Conversation
|
||||
- **Streaming output** — CLI streaming with proper linebreak handling, iteration boundary prevention, and blank line stacking fixes ([#2251](https://github.com/NousResearch/hermes-agent/pull/2251), [#2340](https://github.com/NousResearch/hermes-agent/pull/2340), [#2258](https://github.com/NousResearch/hermes-agent/pull/2258), [#2413](https://github.com/NousResearch/hermes-agent/pull/2413), [#2473](https://github.com/NousResearch/hermes-agent/pull/2473))
|
||||
- **Context compression overhaul** — Structured summaries, iterative updates, token-budget tail protection, fallback model support ([#2323](https://github.com/NousResearch/hermes-agent/pull/2323), [#2128](https://github.com/NousResearch/hermes-agent/pull/2128), [#2224](https://github.com/NousResearch/hermes-agent/pull/2224), [#1727](https://github.com/NousResearch/hermes-agent/pull/1727))
|
||||
- **Context pressure warnings** for CLI and gateway ([#2159](https://github.com/NousResearch/hermes-agent/pull/2159))
|
||||
- **Prompt caching for gateway** — Cache AIAgent per session, keep assistant turns, fix session restore ([#2282](https://github.com/NousResearch/hermes-agent/pull/2282), [#2284](https://github.com/NousResearch/hermes-agent/pull/2284), [#2361](https://github.com/NousResearch/hermes-agent/pull/2361))
|
||||
- **Show reasoning/thinking blocks** when `show_reasoning` is enabled ([#2118](https://github.com/NousResearch/hermes-agent/pull/2118))
|
||||
- **Subagent delegation** for parallel task execution with thread safety ([#2119](https://github.com/NousResearch/hermes-agent/pull/2119), [#1672](https://github.com/NousResearch/hermes-agent/pull/1672), [#1778](https://github.com/NousResearch/hermes-agent/pull/1778))
|
||||
- **Pre-call sanitization and post-call tool guardrails** ([#1732](https://github.com/NousResearch/hermes-agent/pull/1732))
|
||||
- **Auto-recover** from provider-rejected `tool_choice` by retrying without ([#2174](https://github.com/NousResearch/hermes-agent/pull/2174))
|
||||
- **Rate limit handling** with exponential backoff retry ([#2071](https://github.com/NousResearch/hermes-agent/pull/2071))
|
||||
- Fix: prevent silent tool result loss during context compression ([#1993](https://github.com/NousResearch/hermes-agent/pull/1993))
|
||||
- Fix: handle empty/null function arguments in tool call recovery ([#2163](https://github.com/NousResearch/hermes-agent/pull/2163))
|
||||
- Fix: handle API refusal responses gracefully instead of crashing ([#2156](https://github.com/NousResearch/hermes-agent/pull/2156))
|
||||
- Fix: prevent stuck agent loop on malformed tool calls ([#2114](https://github.com/NousResearch/hermes-agent/pull/2114))
|
||||
- Fix: return JSON parse error to model instead of dispatching with empty args ([#2342](https://github.com/NousResearch/hermes-agent/pull/2342))
|
||||
- Fix: consecutive assistant message merge drops content on mixed types ([#1703](https://github.com/NousResearch/hermes-agent/pull/1703))
|
||||
- Fix: message role alternation violations in JSON recovery and error handler ([#1722](https://github.com/NousResearch/hermes-agent/pull/1722))
|
||||
- Fix: `compression_attempts` resets each iteration — allowed unlimited compressions ([#1723](https://github.com/NousResearch/hermes-agent/pull/1723))
|
||||
- Fix: `length_continue_retries` never resets — later truncations got fewer retries ([#1717](https://github.com/NousResearch/hermes-agent/pull/1717))
|
||||
- Fix: compressor summary role violated consecutive-role constraint ([#1720](https://github.com/NousResearch/hermes-agent/pull/1720), [#1743](https://github.com/NousResearch/hermes-agent/pull/1743))
|
||||
- Fix: correctly handle empty tool results ([#2201](https://github.com/NousResearch/hermes-agent/pull/2201))
|
||||
- Fix: crash on None entry in `tool_calls` list during Anthropic conversion ([#2209](https://github.com/NousResearch/hermes-agent/pull/2209) by @0xbyt4, [#2316](https://github.com/NousResearch/hermes-agent/pull/2316))
|
||||
- Fix: per-thread persistent event loops in worker threads ([#2214](https://github.com/NousResearch/hermes-agent/pull/2214) by @jquesnelle)
|
||||
- Fix: prevent 'event loop already running' when async tools run in parallel ([#2207](https://github.com/NousResearch/hermes-agent/pull/2207))
|
||||
- Fix: strip ANSI escape codes from terminal output before sending to model ([#2115](https://github.com/NousResearch/hermes-agent/pull/2115), [#2585](https://github.com/NousResearch/hermes-agent/pull/2585))
|
||||
- Fix: skip top-level `cache_control` on role:tool for OpenRouter ([#2391](https://github.com/NousResearch/hermes-agent/pull/2391))
|
||||
- Fix: delegate tool — save parent tool names before child construction mutates global ([#2083](https://github.com/NousResearch/hermes-agent/pull/2083) by @ygd58, [#1894](https://github.com/NousResearch/hermes-agent/pull/1894))
|
||||
- Fix: only strip last assistant message if empty string ([#2326](https://github.com/NousResearch/hermes-agent/pull/2326))
|
||||
|
||||
### Session & Memory
|
||||
- **Honcho long-term memory backend** integration ([#2276](https://github.com/NousResearch/hermes-agent/pull/2276))
|
||||
- **Per-session SQLite persistence** for gateway ([#2134](https://github.com/NousResearch/hermes-agent/pull/2134))
|
||||
- **`--resume` flag** for CLI session persistence across restarts + `/resume` and `/sessions` commands ([#2135](https://github.com/NousResearch/hermes-agent/pull/2135), [#2143](https://github.com/NousResearch/hermes-agent/pull/2143))
|
||||
- **Session search** and management slash commands ([#2198](https://github.com/NousResearch/hermes-agent/pull/2198))
|
||||
- **Auto session titles** and `.hermes.md` project config ([#1712](https://github.com/NousResearch/hermes-agent/pull/1712))
|
||||
- **Background memory/skill review** replaces inline nudges ([#2235](https://github.com/NousResearch/hermes-agent/pull/2235))
|
||||
- **SOUL.md** as primary agent identity instead of hardcoded default ([#1922](https://github.com/NousResearch/hermes-agent/pull/1922), [#1927](https://github.com/NousResearch/hermes-agent/pull/1927))
|
||||
- **Priority-based context file selection** + CLAUDE.md support ([#2301](https://github.com/NousResearch/hermes-agent/pull/2301))
|
||||
- Fix: concurrent memory writes silently drop entries — added file locking ([#1726](https://github.com/NousResearch/hermes-agent/pull/1726))
|
||||
- Fix: search all sources by default in `session_search` ([#1892](https://github.com/NousResearch/hermes-agent/pull/1892))
|
||||
- Fix: handle hyphenated FTS5 queries and preserve quoted literals ([#1776](https://github.com/NousResearch/hermes-agent/pull/1776))
|
||||
- Fix: skip corrupt lines in `load_transcript` instead of crashing ([#1744](https://github.com/NousResearch/hermes-agent/pull/1744))
|
||||
- Fix: normalize session keys to prevent case-sensitive duplicates ([#2157](https://github.com/NousResearch/hermes-agent/pull/2157))
|
||||
- Fix: prevent `session_search` crash when no sessions exist ([#2194](https://github.com/NousResearch/hermes-agent/pull/2194))
|
||||
- Fix: reset token counters on new session for accurate usage display ([#2101](https://github.com/NousResearch/hermes-agent/pull/2101) by @InB4DevOps)
|
||||
- Fix: prevent stale memory overwrites by flush agent ([#2687](https://github.com/NousResearch/hermes-agent/pull/2687))
|
||||
- Fix: remove synthetic error message injection, fix session resume after repeated failures ([#2303](https://github.com/NousResearch/hermes-agent/pull/2303))
|
||||
|
||||
---
|
||||
|
||||
## 📱 Messaging Platforms (Gateway)
|
||||
|
||||
### New Platform Adapters
|
||||
- **Signal Messenger** adapter ([#2206](https://github.com/NousResearch/hermes-agent/pull/2206)) with attachment handling fix ([#2400](https://github.com/NousResearch/hermes-agent/pull/2400)), group message filtering ([#2297](https://github.com/NousResearch/hermes-agent/pull/2297)), and Note to Self echo-back protection ([#2156](https://github.com/NousResearch/hermes-agent/pull/2156))
|
||||
- **DingTalk** adapter with gateway wiring and setup docs ([#1685](https://github.com/NousResearch/hermes-agent/pull/1685), [#1690](https://github.com/NousResearch/hermes-agent/pull/1690), [#1692](https://github.com/NousResearch/hermes-agent/pull/1692))
|
||||
- **SMS (Twilio)** adapter ([#1688](https://github.com/NousResearch/hermes-agent/pull/1688))
|
||||
- **Mattermost and Matrix** adapters with @-mention-only filter for Mattermost channels ([#1683](https://github.com/NousResearch/hermes-agent/pull/1683), [#2443](https://github.com/NousResearch/hermes-agent/pull/2443))
|
||||
- **WhatsApp bridge** adapter ([#2168](https://github.com/NousResearch/hermes-agent/pull/2168))
|
||||
- **Webhook** platform adapter for external event triggers ([#2166](https://github.com/NousResearch/hermes-agent/pull/2166))
|
||||
- **OpenAI-compatible API server** platform adapter with `/api/jobs` cron management endpoints ([#1756](https://github.com/NousResearch/hermes-agent/pull/1756), [#2450](https://github.com/NousResearch/hermes-agent/pull/2450), [#2456](https://github.com/NousResearch/hermes-agent/pull/2456))
|
||||
|
||||
### Telegram
|
||||
- Auto-detect HTML tags and use `parse_mode=HTML` in `send_message` ([#1709](https://github.com/NousResearch/hermes-agent/pull/1709))
|
||||
- Telegram group vision support + thread-based sessions ([#2153](https://github.com/NousResearch/hermes-agent/pull/2153))
|
||||
- MarkdownV2 support — strikethrough, spoiler, blockquotes, escape parentheses/braces/backslashes/backticks ([#2199](https://github.com/NousResearch/hermes-agent/pull/2199), [#2200](https://github.com/NousResearch/hermes-agent/pull/2200) by @llbn, [#2386](https://github.com/NousResearch/hermes-agent/pull/2386))
|
||||
- Auto-reconnect polling after network interruption ([#2517](https://github.com/NousResearch/hermes-agent/pull/2517))
|
||||
- Aggregate split text messages before dispatching ([#1674](https://github.com/NousResearch/hermes-agent/pull/1674))
|
||||
- Fix: streaming config bridge, not-modified, flood control ([#1782](https://github.com/NousResearch/hermes-agent/pull/1782), [#1783](https://github.com/NousResearch/hermes-agent/pull/1783))
|
||||
- Fix: edited_message event crashes ([#2074](https://github.com/NousResearch/hermes-agent/pull/2074))
|
||||
- Fix: retry 409 polling conflicts before giving up ([#2312](https://github.com/NousResearch/hermes-agent/pull/2312))
|
||||
- Fix: Telegram topic delivery via `platform:chat_id:thread_id` format ([#2455](https://github.com/NousResearch/hermes-agent/pull/2455))
|
||||
|
||||
### Discord
|
||||
- Document caching and text-file injection ([#2503](https://github.com/NousResearch/hermes-agent/pull/2503))
|
||||
- Persistent typing indicator for DMs ([#2468](https://github.com/NousResearch/hermes-agent/pull/2468))
|
||||
- Discord DM vision support — inline images + attachment analysis ([#2186](https://github.com/NousResearch/hermes-agent/pull/2186))
|
||||
- Persist thread participation across gateway restarts ([#1661](https://github.com/NousResearch/hermes-agent/pull/1661))
|
||||
- Fix: prevent gateway crash on non-ASCII guild names ([#2302](https://github.com/NousResearch/hermes-agent/pull/2302))
|
||||
- Fix: handle thread permission errors gracefully ([#2073](https://github.com/NousResearch/hermes-agent/pull/2073))
|
||||
- Fix: properly route slash event handling in threads ([#2460](https://github.com/NousResearch/hermes-agent/pull/2460))
|
||||
- Fix: remove bugged followup messages + remove `/ask` command ([#1836](https://github.com/NousResearch/hermes-agent/pull/1836))
|
||||
- Fix: handle graceful reconnection on WebSocket errors ([#2127](https://github.com/NousResearch/hermes-agent/pull/2127))
|
||||
- Fix: voice channel TTS not working when streaming enabled ([#2322](https://github.com/NousResearch/hermes-agent/pull/2322))
|
||||
|
||||
### Other Platforms
|
||||
- WhatsApp: outbound `send_message` routing ([#1769](https://github.com/NousResearch/hermes-agent/pull/1769) by @sai-samarth), LID format self-chat support ([#1667](https://github.com/NousResearch/hermes-agent/pull/1667)), `reply_prefix` config bridging fix ([#1923](https://github.com/NousResearch/hermes-agent/pull/1923)), restart on bridge child exit ([#2334](https://github.com/NousResearch/hermes-agent/pull/2334)), image/bridge improvements ([#2181](https://github.com/NousResearch/hermes-agent/pull/2181))
|
||||
- Matrix: duplicate messages and image caching for vision support ([#2520](https://github.com/NousResearch/hermes-agent/pull/2520)), correct `reply_to_message_id` parameter ([#1895](https://github.com/NousResearch/hermes-agent/pull/1895)), bare media types fix ([#1736](https://github.com/NousResearch/hermes-agent/pull/1736))
|
||||
- Mattermost: MIME types for media attachments ([#2329](https://github.com/NousResearch/hermes-agent/pull/2329))
|
||||
|
||||
### Gateway Core
|
||||
- **Multi-platform gateway** support (Discord + Telegram + all adapters) ([#2125](https://github.com/NousResearch/hermes-agent/pull/2125))
|
||||
- **Auto-reconnect** failed platforms with exponential backoff ([#2584](https://github.com/NousResearch/hermes-agent/pull/2584))
|
||||
- **Notify users when session auto-resets** ([#2519](https://github.com/NousResearch/hermes-agent/pull/2519))
|
||||
- **`/queue` command** to queue prompts without interrupting ([#2191](https://github.com/NousResearch/hermes-agent/pull/2191))
|
||||
- **Inject reply-to message context** for out-of-session replies ([#1662](https://github.com/NousResearch/hermes-agent/pull/1662))
|
||||
- **Replace bare text approval** with `/approve` and `/deny` commands ([#2002](https://github.com/NousResearch/hermes-agent/pull/2002))
|
||||
- **Support ignoring unauthorized gateway DMs** ([#1919](https://github.com/NousResearch/hermes-agent/pull/1919))
|
||||
- **Configurable approvals** in gateway + `/cost` command with live pricing ([#2180](https://github.com/NousResearch/hermes-agent/pull/2180))
|
||||
- Fix: prevent duplicate session-key collision in multi-platform gateway ([#2171](https://github.com/NousResearch/hermes-agent/pull/2171))
|
||||
- Fix: `/reset` in thread-mode resets global session instead of thread ([#2254](https://github.com/NousResearch/hermes-agent/pull/2254))
|
||||
- Fix: deliver MEDIA: files after streaming responses ([#2382](https://github.com/NousResearch/hermes-agent/pull/2382))
|
||||
- Fix: cap interrupt recursion depth to prevent resource exhaustion ([#1659](https://github.com/NousResearch/hermes-agent/pull/1659))
|
||||
- Fix: detect stopped processes and release stale locks on `--replace` ([#2406](https://github.com/NousResearch/hermes-agent/pull/2406), [#1908](https://github.com/NousResearch/hermes-agent/pull/1908))
|
||||
- Fix: PID-based wait with force-kill for gateway restart ([#1902](https://github.com/NousResearch/hermes-agent/pull/1902))
|
||||
- Fix: prevent `--replace` mode from killing the caller process ([#2185](https://github.com/NousResearch/hermes-agent/pull/2185))
|
||||
- Fix: `/model` shows active fallback model instead of config default ([#1660](https://github.com/NousResearch/hermes-agent/pull/1660))
|
||||
- Fix: `/title` command fails when session doesn't exist in SQLite yet ([#2379](https://github.com/NousResearch/hermes-agent/pull/2379) by @ten-jampa)
|
||||
- Fix: process `/queue`'d messages after agent completion ([#2469](https://github.com/NousResearch/hermes-agent/pull/2469))
|
||||
- Fix: strip orphaned `tool_results` + let `/reset` bypass running agent ([#2180](https://github.com/NousResearch/hermes-agent/pull/2180))
|
||||
- Fix: prevent agents from starting gateway outside systemd management ([#2617](https://github.com/NousResearch/hermes-agent/pull/2617))
|
||||
- Fix: prevent systemd restart storm on gateway connection failure ([#2327](https://github.com/NousResearch/hermes-agent/pull/2327))
|
||||
- Fix: include resolved node path in systemd unit ([#1767](https://github.com/NousResearch/hermes-agent/pull/1767) by @sai-samarth)
|
||||
- Fix: send error details to user in gateway outer exception handler ([#1966](https://github.com/NousResearch/hermes-agent/pull/1966))
|
||||
- Fix: improve gateway error handling for 429 usage limits and 500 context overflow ([#1839](https://github.com/NousResearch/hermes-agent/pull/1839))
|
||||
- Fix: add all missing platform allowlist env vars to startup warning check ([#2628](https://github.com/NousResearch/hermes-agent/pull/2628))
|
||||
- Fix: show startup banner with all env vars when `verbose_logging=true` ([#2298](https://github.com/NousResearch/hermes-agent/pull/2298))
|
||||
- Fix: webhook platform config loading from config.yaml ([#2328](https://github.com/NousResearch/hermes-agent/pull/2328))
|
||||
- Fix: media-group aggregation on rapid successive photo messages ([#2160](https://github.com/NousResearch/hermes-agent/pull/2160))
|
||||
- Fix: media delivery fails for file paths containing spaces ([#2621](https://github.com/NousResearch/hermes-agent/pull/2621))
|
||||
- Fix: Matrix and Mattermost never report as connected ([#1711](https://github.com/NousResearch/hermes-agent/pull/1711))
|
||||
- Fix: PII redaction config never read — missing yaml import ([#1701](https://github.com/NousResearch/hermes-agent/pull/1701))
|
||||
- Fix: NameError on skill slash commands ([#1697](https://github.com/NousResearch/hermes-agent/pull/1697))
|
||||
- Fix: persist watcher metadata in checkpoint for crash recovery ([#1706](https://github.com/NousResearch/hermes-agent/pull/1706))
|
||||
- Fix: pass `message_thread_id` in `send_image_file`, `send_document`, `send_video` ([#2339](https://github.com/NousResearch/hermes-agent/pull/2339))
|
||||
|
||||
---
|
||||
|
||||
## 🖥️ CLI & User Experience
|
||||
|
||||
### Interactive CLI
|
||||
- **@ context completions** — Claude Code-style `@file`/`@url` references with tab completion ([#2482](https://github.com/NousResearch/hermes-agent/pull/2482), [#2343](https://github.com/NousResearch/hermes-agent/pull/2343))
|
||||
- **Persistent config bar** in prompt with model + provider info + `/statusbar` toggle ([#2240](https://github.com/NousResearch/hermes-agent/pull/2240), [#1917](https://github.com/NousResearch/hermes-agent/pull/1917))
|
||||
- **`/permission` command** for dynamic approval mode switching ([#2207](https://github.com/NousResearch/hermes-agent/pull/2207))
|
||||
- **`/browser` command** for interactive browser sessions ([#2273](https://github.com/NousResearch/hermes-agent/pull/2273))
|
||||
- **`/tools` disable/enable/list** slash commands with session reset ([#1652](https://github.com/NousResearch/hermes-agent/pull/1652))
|
||||
- **`/model` command** for runtime model switching with live API probe for custom endpoints ([#2110](https://github.com/NousResearch/hermes-agent/pull/2110), [#1645](https://github.com/NousResearch/hermes-agent/pull/1645), [#2078](https://github.com/NousResearch/hermes-agent/pull/2078))
|
||||
- **Real-time config reload** — config.yaml changes apply without restart ([#2210](https://github.com/NousResearch/hermes-agent/pull/2210))
|
||||
- **Kitty keyboard protocol** Shift+Enter handling for Ghostty/WezTerm (reverted due to prompt_toolkit crash) ([#2345](https://github.com/NousResearch/hermes-agent/pull/2345), [#2349](https://github.com/NousResearch/hermes-agent/pull/2349))
|
||||
- Fix: prevent 'Press ENTER to continue...' on exit ([#2555](https://github.com/NousResearch/hermes-agent/pull/2555))
|
||||
- Fix: flush stdout during agent loop to prevent macOS display freeze ([#1654](https://github.com/NousResearch/hermes-agent/pull/1654))
|
||||
- Fix: show human-readable error when `hermes setup` hits permissions error ([#2196](https://github.com/NousResearch/hermes-agent/pull/2196))
|
||||
- Fix: `/stop` command crash + UnboundLocalError in streaming media delivery ([#2463](https://github.com/NousResearch/hermes-agent/pull/2463))
|
||||
- Fix: resolve garbled ANSI escape codes in status printouts ([#2448](https://github.com/NousResearch/hermes-agent/pull/2448))
|
||||
- Fix: normalize toolset labels and use skin colors in banner ([#1912](https://github.com/NousResearch/hermes-agent/pull/1912))
|
||||
- Fix: update gold ANSI color to true-color format ([#2246](https://github.com/NousResearch/hermes-agent/pull/2246))
|
||||
- Fix: suppress spinner animation in non-TTY environments ([#2216](https://github.com/NousResearch/hermes-agent/pull/2216))
|
||||
- Fix: display provider and endpoint in API error messages ([#2266](https://github.com/NousResearch/hermes-agent/pull/2266))
|
||||
|
||||
### Setup & Configuration
|
||||
- **YAML-based config** with backward-compatible env var fallback ([#2172](https://github.com/NousResearch/hermes-agent/pull/2172))
|
||||
- **`${ENV_VAR}` substitution** in config.yaml ([#2684](https://github.com/NousResearch/hermes-agent/pull/2684))
|
||||
- **`custom_models.yaml`** for user-managed model additions ([#2214](https://github.com/NousResearch/hermes-agent/pull/2214))
|
||||
- **Merge nested YAML sections** instead of replacing ([#2213](https://github.com/NousResearch/hermes-agent/pull/2213))
|
||||
- Fix: log warning instead of silently swallowing config.yaml errors ([#2683](https://github.com/NousResearch/hermes-agent/pull/2683))
|
||||
- Fix: config.yaml provider key overrides env var silently ([#2272](https://github.com/NousResearch/hermes-agent/pull/2272))
|
||||
- Fix: `hermes update` use `.[all]` extras with fallback ([#1728](https://github.com/NousResearch/hermes-agent/pull/1728))
|
||||
- Fix: `hermes update` prompt before resetting working tree on stash conflicts ([#2390](https://github.com/NousResearch/hermes-agent/pull/2390))
|
||||
- Fix: add zprofile fallback and create zshrc on fresh macOS installs ([#2320](https://github.com/NousResearch/hermes-agent/pull/2320))
|
||||
- Fix: use git pull --rebase in update/install to avoid divergent branch error ([#2274](https://github.com/NousResearch/hermes-agent/pull/2274))
|
||||
- Fix: disabled toolsets re-enable themselves after `hermes tools` ([#2268](https://github.com/NousResearch/hermes-agent/pull/2268))
|
||||
- Fix: platform default toolsets silently override tool deselection ([#2624](https://github.com/NousResearch/hermes-agent/pull/2624))
|
||||
- Fix: honor bare YAML `approvals.mode: off` ([#2620](https://github.com/NousResearch/hermes-agent/pull/2620))
|
||||
- Fix: remove `ANTHROPIC_BASE_URL` env var to avoid collisions ([#1675](https://github.com/NousResearch/hermes-agent/pull/1675))
|
||||
- Fix: don't ask IMAP password if already in keyring or env ([#2212](https://github.com/NousResearch/hermes-agent/pull/2212))
|
||||
- Fix: prevent `/model` crash when provider list is empty ([#2209](https://github.com/NousResearch/hermes-agent/pull/2209))
|
||||
- Fix: OpenCode Zen/Go show OpenRouter models instead of their own ([#2277](https://github.com/NousResearch/hermes-agent/pull/2277))
|
||||
|
||||
---
|
||||
|
||||
## 🔧 Tool System
|
||||
|
||||
### Browser Tools
|
||||
- **Native Hermes browser tools** — navigate, click, type, screenshot, scrape via Browserbase ([#2270](https://github.com/NousResearch/hermes-agent/pull/2270))
|
||||
- Fix: race condition in session creation orphans cloud sessions ([#1721](https://github.com/NousResearch/hermes-agent/pull/1721))
|
||||
- Fix: browser handlers TypeError on unexpected LLM params ([#1735](https://github.com/NousResearch/hermes-agent/pull/1735))
|
||||
- Fix: add `/browser` to COMMAND_REGISTRY for help + autocomplete ([#1814](https://github.com/NousResearch/hermes-agent/pull/1814))
|
||||
|
||||
### MCP (Model Context Protocol)
|
||||
- **MCP server management CLI** + OAuth 2.1 PKCE auth ([#2465](https://github.com/NousResearch/hermes-agent/pull/2465))
|
||||
- **Interactive MCP tool configuration** in `hermes tools` ([#1694](https://github.com/NousResearch/hermes-agent/pull/1694))
|
||||
- **Expose MCP servers as standalone toolsets** ([#1907](https://github.com/NousResearch/hermes-agent/pull/1907))
|
||||
- **Optional FastMCP skill** ([#2113](https://github.com/NousResearch/hermes-agent/pull/2113))
|
||||
- Fix: MCP-OAuth port mismatch, path traversal, and shared handler state ([#2552](https://github.com/NousResearch/hermes-agent/pull/2552))
|
||||
- Fix: preserve MCP tool registrations across session resets ([#2124](https://github.com/NousResearch/hermes-agent/pull/2124))
|
||||
- Fix: concurrent file access crash + duplicate MCP registration ([#2154](https://github.com/NousResearch/hermes-agent/pull/2154))
|
||||
- Fix: normalise MCP schemas + expand session list columns ([#2102](https://github.com/NousResearch/hermes-agent/pull/2102))
|
||||
- Fix: `tool_choice` `mcp_` prefix handling ([#1775](https://github.com/NousResearch/hermes-agent/pull/1775))
|
||||
|
||||
### Web Tools
|
||||
- **Configurable web backend** — Firecrawl/BeautifulSoup/Playwright ([#2256](https://github.com/NousResearch/hermes-agent/pull/2256))
|
||||
- **Parallel** as alternative web search/extract backend ([#1696](https://github.com/NousResearch/hermes-agent/pull/1696))
|
||||
- **Tavily** as web search/extract/crawl backend ([#1731](https://github.com/NousResearch/hermes-agent/pull/1731))
|
||||
- Fix: whitespace-only env vars bypass web backend detection ([#2341](https://github.com/NousResearch/hermes-agent/pull/2341))
|
||||
|
||||
### Other Tools
|
||||
- **Vision analysis tool** for image understanding with configurable timeout ([#2182](https://github.com/NousResearch/hermes-agent/pull/2182), [#2480](https://github.com/NousResearch/hermes-agent/pull/2480))
|
||||
- **Code execution tool** for containerized Python/Node/Bash execution ([#2299](https://github.com/NousResearch/hermes-agent/pull/2299))
|
||||
- **TTS tool** using OpenAI API with `base_url` support ([#2118](https://github.com/NousResearch/hermes-agent/pull/2118), [#2064](https://github.com/NousResearch/hermes-agent/pull/2064) by @hanai)
|
||||
- **STT (speech-to-text) tool** using Whisper API ([#2072](https://github.com/NousResearch/hermes-agent/pull/2072))
|
||||
- **IMAP email** reading and sending tools ([#2173](https://github.com/NousResearch/hermes-agent/pull/2173))
|
||||
- **RL training data generation tool** ([#2225](https://github.com/NousResearch/hermes-agent/pull/2225))
|
||||
- **Route-aware pricing estimates** ([#1695](https://github.com/NousResearch/hermes-agent/pull/1695))
|
||||
- Fix: chunk long messages in `send_message_tool` before platform dispatch ([#1646](https://github.com/NousResearch/hermes-agent/pull/1646))
|
||||
- Fix: make concurrent tool batching path-aware for file mutations ([#1914](https://github.com/NousResearch/hermes-agent/pull/1914))
|
||||
- Fix: tool result truncation on large outputs ([#2088](https://github.com/NousResearch/hermes-agent/pull/2088))
|
||||
- Fix: concurrent file writes safely with atomic operations ([#2086](https://github.com/NousResearch/hermes-agent/pull/2086))
|
||||
- Fix: improve fuzzy matching accuracy for file search + position calculation refactor ([#2096](https://github.com/NousResearch/hermes-agent/pull/2096), [#1681](https://github.com/NousResearch/hermes-agent/pull/1681))
|
||||
- Fix: `search_files` wrong line numbers for multi-line matches ([#2069](https://github.com/NousResearch/hermes-agent/pull/2069))
|
||||
- Fix: include pagination args in repeated search key ([#1824](https://github.com/NousResearch/hermes-agent/pull/1824) by @cutepawss)
|
||||
- Fix: strip ANSI escape codes from write_file and patch content ([#2532](https://github.com/NousResearch/hermes-agent/pull/2532))
|
||||
- Fix: expand tilde (~) in vision_analyze local file paths ([#2585](https://github.com/NousResearch/hermes-agent/pull/2585))
|
||||
- Fix: resource leak and double socket close in `code_execution_tool` ([#2381](https://github.com/NousResearch/hermes-agent/pull/2381))
|
||||
- Fix: resolve vision analysis race condition and path handling ([#2191](https://github.com/NousResearch/hermes-agent/pull/2191))
|
||||
- Fix: DM vision — handle multiple images and base64 fallback ([#2211](https://github.com/NousResearch/hermes-agent/pull/2211))
|
||||
- Fix: `model_supports_images` for custom `base_url` providers returns wrong value ([#2278](https://github.com/NousResearch/hermes-agent/pull/2278))
|
||||
- Fix: add missing 'messaging' toolset — couldn't enable/disable `send_message` ([#1718](https://github.com/NousResearch/hermes-agent/pull/1718))
|
||||
- Fix: prevent unavailable tool names from leaking into model schemas ([#2072](https://github.com/NousResearch/hermes-agent/pull/2072))
|
||||
- Fix: disabled toolsets re-enable themselves after `hermes tools` ([#2268](https://github.com/NousResearch/hermes-agent/pull/2268))
|
||||
- Fix: pass visited set by reference to prevent diamond dependency duplication ([#2311](https://github.com/NousResearch/hermes-agent/pull/2311))
|
||||
- Fix: Daytona sandbox lookup migrated from `find_one` to `get/list` ([#2063](https://github.com/NousResearch/hermes-agent/pull/2063) by @rovle)
|
||||
|
||||
---
|
||||
|
||||
## 🧩 Skills Ecosystem
|
||||
|
||||
### Skills System
|
||||
- **Skills Hub** — discover, install, and manage skills from curated taps ([#2235](https://github.com/NousResearch/hermes-agent/pull/2235))
|
||||
- **Agent-created persistent skills** with caution-level findings allowed, dangerous skills ask instead of block ([#2116](https://github.com/NousResearch/hermes-agent/pull/2116), [#1840](https://github.com/NousResearch/hermes-agent/pull/1840), [#2446](https://github.com/NousResearch/hermes-agent/pull/2446))
|
||||
- **`--yes` flag** to bypass confirmation in `/skills install` and uninstall ([#1647](https://github.com/NousResearch/hermes-agent/pull/1647))
|
||||
- **Disabled skills respected** across banner, system prompt, and slash commands ([#1897](https://github.com/NousResearch/hermes-agent/pull/1897))
|
||||
- Fix: skills custom_tools import crash + sandbox file_tools integration ([#2239](https://github.com/NousResearch/hermes-agent/pull/2239))
|
||||
- Fix: agent-created skills with pip requirements crash on install ([#2145](https://github.com/NousResearch/hermes-agent/pull/2145))
|
||||
- Fix: race condition in `Skills.__init__` when `hub.yaml` missing ([#2242](https://github.com/NousResearch/hermes-agent/pull/2242))
|
||||
- Fix: validate skill metadata before install and block duplicates ([#2241](https://github.com/NousResearch/hermes-agent/pull/2241))
|
||||
- Fix: skills hub inspect/resolve — 4 bugs in inspect, redirects, discovery, tap list ([#2447](https://github.com/NousResearch/hermes-agent/pull/2447))
|
||||
- Fix: agent-created skills keep working after session reset ([#2121](https://github.com/NousResearch/hermes-agent/pull/2121))
|
||||
|
||||
### New Skills
|
||||
- **OCR-and-documents** — PDF/DOCX/XLS/PPTX/image OCR with optional GPU ([#2236](https://github.com/NousResearch/hermes-agent/pull/2236))
|
||||
- **Huggingface-hub** bundled skill ([#1921](https://github.com/NousResearch/hermes-agent/pull/1921))
|
||||
- **Sherlock OSINT** username search skill ([#1671](https://github.com/NousResearch/hermes-agent/pull/1671))
|
||||
- **Inference.sh** skill (terminal-based) ([#1686](https://github.com/NousResearch/hermes-agent/pull/1686))
|
||||
- **Meme-generation** — real image generator with Pillow ([#2344](https://github.com/NousResearch/hermes-agent/pull/2344))
|
||||
- **Bioinformatics** gateway skill — index to 400+ bio skills ([#2387](https://github.com/NousResearch/hermes-agent/pull/2387))
|
||||
- **Base blockchain** optional skill ([#1643](https://github.com/NousResearch/hermes-agent/pull/1643))
|
||||
- **3D-model-viewer** optional skill ([#2226](https://github.com/NousResearch/hermes-agent/pull/2226))
|
||||
- **Hermes-agent-setup** skill ([#1905](https://github.com/NousResearch/hermes-agent/pull/1905))
|
||||
|
||||
---
|
||||
|
||||
## 🔒 Security & Reliability
|
||||
|
||||
### Security Hardening
|
||||
- **SSRF protection** for vision_tools and web_tools (hardened) ([#2679](https://github.com/NousResearch/hermes-agent/pull/2679))
|
||||
- **Shell injection prevention** in `_expand_path` via `~user` path suffix ([#2685](https://github.com/NousResearch/hermes-agent/pull/2685))
|
||||
- **Block untrusted browser-origin** API server access ([#2451](https://github.com/NousResearch/hermes-agent/pull/2451))
|
||||
- **Block sandbox backend creds** from subprocess env ([#1658](https://github.com/NousResearch/hermes-agent/pull/1658))
|
||||
- **Block @ references** from reading secrets outside workspace ([#2601](https://github.com/NousResearch/hermes-agent/pull/2601) by @Gutslabs)
|
||||
- **Require opt-in** for project plugin discovery ([#2215](https://github.com/NousResearch/hermes-agent/pull/2215))
|
||||
- **Malicious code pattern pre-exec scanner** for terminal_tool ([#2245](https://github.com/NousResearch/hermes-agent/pull/2245))
|
||||
- **Harden terminal safety** and sandbox file writes ([#1653](https://github.com/NousResearch/hermes-agent/pull/1653))
|
||||
- **PKCE verifier leak** fix, OAuth refresh Content-Type fix ([#1775](https://github.com/NousResearch/hermes-agent/pull/1775))
|
||||
- **Eliminate SQL string formatting** in `execute()` calls ([#2061](https://github.com/NousResearch/hermes-agent/pull/2061) by @dusterbloom)
|
||||
- **Harden jobs API** — input limits, field whitelist, startup check ([#2456](https://github.com/NousResearch/hermes-agent/pull/2456))
|
||||
- Fix: OAuth flag stale after refresh/fallback ([#1890](https://github.com/NousResearch/hermes-agent/pull/1890))
|
||||
- Fix: auxiliary client skips expired Codex JWT ([#2397](https://github.com/NousResearch/hermes-agent/pull/2397))
|
||||
|
||||
### Reliability
|
||||
- **Concurrent tool safety** — path-aware file mutation batching, thread locks on SessionDB methods, file locking for memory writes ([#1914](https://github.com/NousResearch/hermes-agent/pull/1914), [#1704](https://github.com/NousResearch/hermes-agent/pull/1704), [#1726](https://github.com/NousResearch/hermes-agent/pull/1726))
|
||||
- **Error recovery** — handle OpenRouter errors gracefully, guard print() calls against OSError ([#2112](https://github.com/NousResearch/hermes-agent/pull/2112), [#1668](https://github.com/NousResearch/hermes-agent/pull/1668))
|
||||
- **Redacting formatter** — safely handle non-string inputs, NameError fix when verbose_logging=True ([#2392](https://github.com/NousResearch/hermes-agent/pull/2392), [#1700](https://github.com/NousResearch/hermes-agent/pull/1700))
|
||||
- **ACP** — preserve session provider when switching models, persist sessions to disk, preserve leading whitespace in streaming chunks ([#2380](https://github.com/NousResearch/hermes-agent/pull/2380), [#2071](https://github.com/NousResearch/hermes-agent/pull/2071), [#2192](https://github.com/NousResearch/hermes-agent/pull/2192))
|
||||
- **API server** — persist ResponseStore to SQLite across restarts ([#2472](https://github.com/NousResearch/hermes-agent/pull/2472))
|
||||
- Fix: `fetch_nous_models` called with positional args — always TypeError ([#1699](https://github.com/NousResearch/hermes-agent/pull/1699))
|
||||
- Fix: `make_is_write_denied` robust to Path objects ([#1678](https://github.com/NousResearch/hermes-agent/pull/1678))
|
||||
- Fix: resolve merge conflict markers in cli.py breaking hermes startup ([#2347](https://github.com/NousResearch/hermes-agent/pull/2347))
|
||||
- Fix: `minisweagent_path.py` missing from wheel ([#2098](https://github.com/NousResearch/hermes-agent/pull/2098) by @JiwaniZakir)
|
||||
|
||||
### Cron System
|
||||
- **Cron job scheduling** for gateway ([#2140](https://github.com/NousResearch/hermes-agent/pull/2140))
|
||||
- **`[SILENT]` response** — cron agents can suppress delivery ([#1833](https://github.com/NousResearch/hermes-agent/pull/1833))
|
||||
- **Scale missed-job grace window** with schedule frequency ([#2449](https://github.com/NousResearch/hermes-agent/pull/2449))
|
||||
- **Recover recent one-shot jobs** ([#1918](https://github.com/NousResearch/hermes-agent/pull/1918))
|
||||
- Fix: normalize `repeat<=0` to None — cron jobs deleted after first run when LLM passes -1 ([#2612](https://github.com/NousResearch/hermes-agent/pull/2612) by @Mibayy)
|
||||
- Fix: Matrix added to scheduler delivery platform_map ([#2167](https://github.com/NousResearch/hermes-agent/pull/2167) by @buntingszn)
|
||||
- Fix: naive ISO timestamps stored without timezone — jobs fire at wrong time ([#1729](https://github.com/NousResearch/hermes-agent/pull/1729))
|
||||
- Fix: `get_due_jobs` reads `jobs.json` twice — race condition ([#1716](https://github.com/NousResearch/hermes-agent/pull/1716))
|
||||
- Fix: silent jobs return empty response for delivery skip ([#2442](https://github.com/NousResearch/hermes-agent/pull/2442))
|
||||
- Fix: stop injecting cron outputs into gateway session history ([#2313](https://github.com/NousResearch/hermes-agent/pull/2313))
|
||||
- Fix: close abandoned coroutine when `asyncio.run()` raises RuntimeError ([#2317](https://github.com/NousResearch/hermes-agent/pull/2317))
|
||||
|
||||
---
|
||||
|
||||
## 🐛 Notable Bug Fixes
|
||||
|
||||
- Fix: show full command in dangerous command approval prompt ([#1649](https://github.com/NousResearch/hermes-agent/pull/1649))
|
||||
- Fix: Telegram streaming message length overflow ([#1783](https://github.com/NousResearch/hermes-agent/pull/1783))
|
||||
- Fix: prevent `/model` crash when provider list is empty ([#2209](https://github.com/NousResearch/hermes-agent/pull/2209))
|
||||
- Fix: batch of 5 small contributor fixes — PortAudio, SafeWriter, IMAP, thread lock, prefill ([#2466](https://github.com/NousResearch/hermes-agent/pull/2466))
|
||||
- Fix: `dingtalk-stream` added to optional dependencies ([#2452](https://github.com/NousResearch/hermes-agent/pull/2452))
|
||||
- Fix: remove hardcoded `gemini-3-flash-preview` as default summary model ([#2464](https://github.com/NousResearch/hermes-agent/pull/2464))
|
||||
- Fix: remove post-compression file-read history injection ([#2226](https://github.com/NousResearch/hermes-agent/pull/2226))
|
||||
- Fix: truncated `AUXILIARY_WEB_EXTRACT_API_KEY` env var name ([#2309](https://github.com/NousResearch/hermes-agent/pull/2309))
|
||||
- Fix: update validator does not stop ([#2204](https://github.com/NousResearch/hermes-agent/pull/2204), [#2067](https://github.com/NousResearch/hermes-agent/pull/2067))
|
||||
- Fix: log disk warning check failures at debug level ([#2394](https://github.com/NousResearch/hermes-agent/pull/2394))
|
||||
- Fix: quiet mode with `--resume` now passes conversation_history ([#2357](https://github.com/NousResearch/hermes-agent/pull/2357))
|
||||
- Fix: unify resume logic in batch mode for consistent `--resume` behavior ([#2331](https://github.com/NousResearch/hermes-agent/pull/2331))
|
||||
- Fix: prevent unavailable tool names from leaking into model schemas ([#2072](https://github.com/NousResearch/hermes-agent/pull/2072))
|
||||
- Fix: remove `_is_special_key` hack and fix `/skills` path completion ([#2271](https://github.com/NousResearch/hermes-agent/pull/2271))
|
||||
- Fix: use home-relative state paths if XDG dirs don't exist ([#2325](https://github.com/NousResearch/hermes-agent/pull/2325))
|
||||
- Fix: inject model identity for Alibaba Coding Plan ([#2314](https://github.com/NousResearch/hermes-agent/pull/2314))
|
||||
- Fix: OpenClaw migration warns when API keys are skipped ([#1655](https://github.com/NousResearch/hermes-agent/pull/1655))
|
||||
- Fix: email `send_typing` metadata + ☤ Hermes staff symbol ([#1665](https://github.com/NousResearch/hermes-agent/pull/1665))
|
||||
- Fix: replace production `print()` calls with logger in rl_training_tool ([#2462](https://github.com/NousResearch/hermes-agent/pull/2462))
|
||||
- Fix: restore opencode-go provider config corrupted by secret redaction ([#2393](https://github.com/NousResearch/hermes-agent/pull/2393) by @0xbyt4)
|
||||
|
||||
---
|
||||
|
||||
## 🧪 Testing
|
||||
|
||||
- Resolve all consistently failing tests ([#2488](https://github.com/NousResearch/hermes-agent/pull/2488))
|
||||
- Replace `FakePath` with `monkeypatch` for Python 3.12 compat ([#2444](https://github.com/NousResearch/hermes-agent/pull/2444))
|
||||
- Align Hermes setup and full-suite expectations ([#1710](https://github.com/NousResearch/hermes-agent/pull/1710))
|
||||
- Add tests for API server jobs API hardening ([#2456](https://github.com/NousResearch/hermes-agent/pull/2456))
|
||||
|
||||
---
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
- Comprehensive documentation update for recent features ([#1693](https://github.com/NousResearch/hermes-agent/pull/1693), [#2183](https://github.com/NousResearch/hermes-agent/pull/2183))
|
||||
- Alibaba Cloud and DingTalk setup guide ([#1687](https://github.com/NousResearch/hermes-agent/pull/1687), [#1692](https://github.com/NousResearch/hermes-agent/pull/1692))
|
||||
- SOUL.md as primary agent identity documentation ([#1927](https://github.com/NousResearch/hermes-agent/pull/1927))
|
||||
- Detailed skills documentation ([#2244](https://github.com/NousResearch/hermes-agent/pull/2244))
|
||||
- Honcho self-hosted / Docker configuration section ([#2475](https://github.com/NousResearch/hermes-agent/pull/2475))
|
||||
- Context length detection references in FAQ and quickstart ([#2179](https://github.com/NousResearch/hermes-agent/pull/2179))
|
||||
- Fix documentation inconsistencies across reference and user guides ([#1995](https://github.com/NousResearch/hermes-agent/pull/1995))
|
||||
- Fix MCP install commands — use uv, not bare pip ([#1909](https://github.com/NousResearch/hermes-agent/pull/1909))
|
||||
- Fix MDX build error in api-server.md ([#1787](https://github.com/NousResearch/hermes-agent/pull/1787))
|
||||
- Replace ASCII diagrams with Mermaid/lists ([#2402](https://github.com/NousResearch/hermes-agent/pull/2402))
|
||||
- Add missing gateway commands and correct examples ([#2329](https://github.com/NousResearch/hermes-agent/pull/2329))
|
||||
- Clarify self-hosted Firecrawl setup ([#1669](https://github.com/NousResearch/hermes-agent/pull/1669))
|
||||
- NeuTTS provider documentation ([#1903](https://github.com/NousResearch/hermes-agent/pull/1903))
|
||||
- Gemini OAuth provider implementation plan ([#2467](https://github.com/NousResearch/hermes-agent/pull/2467))
|
||||
- Discord Server Members Intent marked as required ([#2330](https://github.com/NousResearch/hermes-agent/pull/2330))
|
||||
- Align venv path to match installer (venv/ not .venv/) ([#2114](https://github.com/NousResearch/hermes-agent/pull/2114))
|
||||
- New skills added to hub index ([#2281](https://github.com/NousResearch/hermes-agent/pull/2281))
|
||||
- OCR-and-documents skill — split, merge, search examples ([#2461](https://github.com/NousResearch/hermes-agent/pull/2461))
|
||||
|
||||
---
|
||||
|
||||
## 👥 Contributors
|
||||
|
||||
### Core
|
||||
- **@teknium1** (Teknium) — 280 PRs
|
||||
|
||||
### Community Contributors
|
||||
- **@mchzimm** (to_the_max) — GitHub Copilot provider integration across Hermes ([#1879](https://github.com/NousResearch/hermes-agent/pull/1879))
|
||||
- **@jquesnelle** (Jeffrey Quesnelle) — Per-thread persistent event loops in worker threads ([#2214](https://github.com/NousResearch/hermes-agent/pull/2214))
|
||||
- **@llbn** (lbn) — Telegram MarkdownV2 support: strikethrough, spoiler, blockquotes, and escape fixes ([#2199](https://github.com/NousResearch/hermes-agent/pull/2199), [#2200](https://github.com/NousResearch/hermes-agent/pull/2200))
|
||||
- **@dusterbloom** — SQL injection prevention + local server context window querying ([#2061](https://github.com/NousResearch/hermes-agent/pull/2061), [#2091](https://github.com/NousResearch/hermes-agent/pull/2091))
|
||||
- **@0xbyt4** — Anthropic tool_calls None guard + OpenCode-Go provider config fix ([#2209](https://github.com/NousResearch/hermes-agent/pull/2209), [#2393](https://github.com/NousResearch/hermes-agent/pull/2393))
|
||||
- **@sai-samarth** (Saisamarth) — WhatsApp send_message routing + systemd node path fix ([#1769](https://github.com/NousResearch/hermes-agent/pull/1769), [#1767](https://github.com/NousResearch/hermes-agent/pull/1767))
|
||||
- **@Gutslabs** (Guts) — Block @ references from reading secrets outside workspace ([#2601](https://github.com/NousResearch/hermes-agent/pull/2601))
|
||||
- **@Mibayy** (Mibay) — Cron job repeat normalization fix ([#2612](https://github.com/NousResearch/hermes-agent/pull/2612))
|
||||
- **@ten-jampa** (Tenzin Jampa) — Gateway /title command session fix ([#2379](https://github.com/NousResearch/hermes-agent/pull/2379))
|
||||
- **@cutepawss** (lila) — File tools search pagination fix ([#1824](https://github.com/NousResearch/hermes-agent/pull/1824))
|
||||
- **@hanai** (Hanai) — OpenAI TTS base_url support ([#2064](https://github.com/NousResearch/hermes-agent/pull/2064))
|
||||
- **@rovle** (Lovre Pešut) — Daytona sandbox API migration ([#2063](https://github.com/NousResearch/hermes-agent/pull/2063))
|
||||
- **@buntingszn** (bunting szn) — Matrix cron delivery support ([#2167](https://github.com/NousResearch/hermes-agent/pull/2167))
|
||||
- **@InB4DevOps** — Token counter reset on new session ([#2101](https://github.com/NousResearch/hermes-agent/pull/2101))
|
||||
- **@JiwaniZakir** (Zakir Jiwani) — Missing file in wheel fix ([#2098](https://github.com/NousResearch/hermes-agent/pull/2098))
|
||||
- **@ygd58** (buray) — Delegate tool parent tool names fix ([#2083](https://github.com/NousResearch/hermes-agent/pull/2083))
|
||||
|
||||
---
|
||||
|
||||
**Full Changelog**: [v2026.3.17...v2026.3.23](https://github.com/NousResearch/hermes-agent/compare/v2026.3.17...v2026.3.23)
|
||||
@@ -383,11 +383,11 @@ class HermesACPAgent(acp.Agent):
|
||||
|
||||
new_model = args.strip()
|
||||
target_provider = None
|
||||
current_provider = getattr(state.agent, "provider", None) or "openrouter"
|
||||
|
||||
# Auto-detect provider for the requested model
|
||||
try:
|
||||
from hermes_cli.models import parse_model_input, detect_provider_for_model
|
||||
current_provider = getattr(state.agent, "provider", None) or "openrouter"
|
||||
target_provider, new_model = parse_model_input(new_model, current_provider)
|
||||
if target_provider == current_provider:
|
||||
detected = detect_provider_for_model(new_model, current_provider)
|
||||
@@ -401,9 +401,10 @@ class HermesACPAgent(acp.Agent):
|
||||
session_id=state.session_id,
|
||||
cwd=state.cwd,
|
||||
model=new_model,
|
||||
requested_provider=target_provider or current_provider,
|
||||
)
|
||||
self.session_manager.save_session(state.session_id)
|
||||
provider_label = target_provider or getattr(state.agent, "provider", "auto")
|
||||
provider_label = getattr(state.agent, "provider", None) or target_provider or current_provider
|
||||
logger.info("Session %s: model switched to %s", state.session_id, new_model)
|
||||
return f"Model switched to: {new_model}\nProvider: {provider_label}"
|
||||
|
||||
@@ -475,10 +476,16 @@ class HermesACPAgent(acp.Agent):
|
||||
state = self.session_manager.get_session(session_id)
|
||||
if state:
|
||||
state.model = model_id
|
||||
current_provider = getattr(state.agent, "provider", None)
|
||||
current_base_url = getattr(state.agent, "base_url", None)
|
||||
current_api_mode = getattr(state.agent, "api_mode", None)
|
||||
state.agent = self.session_manager._make_agent(
|
||||
session_id=session_id,
|
||||
cwd=state.cwd,
|
||||
model=model_id,
|
||||
requested_provider=current_provider,
|
||||
base_url=current_base_url,
|
||||
api_mode=current_api_mode,
|
||||
)
|
||||
self.session_manager.save_session(session_id)
|
||||
logger.info("Session %s: model switched to %s", session_id, model_id)
|
||||
|
||||
+36
-8
@@ -270,7 +270,17 @@ class SessionManager:
|
||||
|
||||
# Ensure model is a plain string (not a MagicMock or other proxy).
|
||||
model_str = str(state.model) if state.model else None
|
||||
cwd_json = json.dumps({"cwd": state.cwd})
|
||||
session_meta = {"cwd": state.cwd}
|
||||
provider = getattr(state.agent, "provider", None)
|
||||
base_url = getattr(state.agent, "base_url", None)
|
||||
api_mode = getattr(state.agent, "api_mode", None)
|
||||
if isinstance(provider, str) and provider.strip():
|
||||
session_meta["provider"] = provider.strip()
|
||||
if isinstance(base_url, str) and base_url.strip():
|
||||
session_meta["base_url"] = base_url.strip()
|
||||
if isinstance(api_mode, str) and api_mode.strip():
|
||||
session_meta["api_mode"] = api_mode.strip()
|
||||
cwd_json = json.dumps(session_meta)
|
||||
|
||||
try:
|
||||
# Ensure the session record exists.
|
||||
@@ -331,10 +341,18 @@ class SessionManager:
|
||||
|
||||
# Extract cwd from model_config.
|
||||
cwd = "."
|
||||
requested_provider = row.get("billing_provider")
|
||||
restored_base_url = row.get("billing_base_url")
|
||||
restored_api_mode = None
|
||||
mc = row.get("model_config")
|
||||
if mc:
|
||||
try:
|
||||
cwd = json.loads(mc).get("cwd", ".")
|
||||
meta = json.loads(mc)
|
||||
if isinstance(meta, dict):
|
||||
cwd = meta.get("cwd", ".")
|
||||
requested_provider = meta.get("provider") or requested_provider
|
||||
restored_base_url = meta.get("base_url") or restored_base_url
|
||||
restored_api_mode = meta.get("api_mode") or restored_api_mode
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
@@ -348,7 +366,14 @@ class SessionManager:
|
||||
history = []
|
||||
|
||||
try:
|
||||
agent = self._make_agent(session_id=session_id, cwd=cwd, model=model)
|
||||
agent = self._make_agent(
|
||||
session_id=session_id,
|
||||
cwd=cwd,
|
||||
model=model,
|
||||
requested_provider=requested_provider,
|
||||
base_url=restored_base_url,
|
||||
api_mode=restored_api_mode,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to recreate agent for ACP session %s", session_id, exc_info=True)
|
||||
return None
|
||||
@@ -386,6 +411,9 @@ class SessionManager:
|
||||
session_id: str,
|
||||
cwd: str,
|
||||
model: str | None = None,
|
||||
requested_provider: str | None = None,
|
||||
base_url: str | None = None,
|
||||
api_mode: str | None = None,
|
||||
):
|
||||
if self._agent_factory is not None:
|
||||
return self._agent_factory()
|
||||
@@ -397,10 +425,10 @@ class SessionManager:
|
||||
config = load_config()
|
||||
model_cfg = config.get("model")
|
||||
default_model = "anthropic/claude-opus-4.6"
|
||||
requested_provider = None
|
||||
config_provider = None
|
||||
if isinstance(model_cfg, dict):
|
||||
default_model = str(model_cfg.get("default") or default_model)
|
||||
requested_provider = model_cfg.get("provider")
|
||||
config_provider = model_cfg.get("provider")
|
||||
elif isinstance(model_cfg, str) and model_cfg.strip():
|
||||
default_model = model_cfg.strip()
|
||||
|
||||
@@ -413,12 +441,12 @@ class SessionManager:
|
||||
}
|
||||
|
||||
try:
|
||||
runtime = resolve_runtime_provider(requested=requested_provider)
|
||||
runtime = resolve_runtime_provider(requested=requested_provider or config_provider)
|
||||
kwargs.update(
|
||||
{
|
||||
"provider": runtime.get("provider"),
|
||||
"api_mode": runtime.get("api_mode"),
|
||||
"base_url": runtime.get("base_url"),
|
||||
"api_mode": api_mode or runtime.get("api_mode"),
|
||||
"base_url": base_url or runtime.get("base_url"),
|
||||
"api_key": runtime.get("api_key"),
|
||||
"command": runtime.get("command"),
|
||||
"args": list(runtime.get("args") or []),
|
||||
|
||||
@@ -656,19 +656,21 @@ def refresh_hermes_oauth_token() -> Optional[str]:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def normalize_model_name(model: str) -> str:
|
||||
def normalize_model_name(model: str, preserve_dots: bool = False) -> str:
|
||||
"""Normalize a model name for the Anthropic API.
|
||||
|
||||
- Strips 'anthropic/' prefix (OpenRouter format, case-insensitive)
|
||||
- Converts dots to hyphens in version numbers (OpenRouter uses dots,
|
||||
Anthropic uses hyphens: claude-opus-4.6 → claude-opus-4-6)
|
||||
Anthropic uses hyphens: claude-opus-4.6 → claude-opus-4-6), unless
|
||||
preserve_dots is True (e.g. for Alibaba/DashScope: qwen3.5-plus).
|
||||
"""
|
||||
lower = model.lower()
|
||||
if lower.startswith("anthropic/"):
|
||||
model = model[len("anthropic/"):]
|
||||
# OpenRouter uses dots for version separators (claude-opus-4.6),
|
||||
# Anthropic uses hyphens (claude-opus-4-6). Convert dots to hyphens.
|
||||
model = model.replace(".", "-")
|
||||
if not preserve_dots:
|
||||
# OpenRouter uses dots for version separators (claude-opus-4.6),
|
||||
# Anthropic uses hyphens (claude-opus-4-6). Convert dots to hyphens.
|
||||
model = model.replace(".", "-")
|
||||
return model
|
||||
|
||||
|
||||
@@ -1006,16 +1008,20 @@ def build_anthropic_kwargs(
|
||||
reasoning_config: Optional[Dict[str, Any]],
|
||||
tool_choice: Optional[str] = None,
|
||||
is_oauth: bool = False,
|
||||
preserve_dots: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build kwargs for anthropic.messages.create().
|
||||
|
||||
When *is_oauth* is True, applies Claude Code compatibility transforms:
|
||||
system prompt prefix, tool name prefixing, and prompt sanitization.
|
||||
|
||||
When *preserve_dots* is True, model name dots are not converted to hyphens
|
||||
(for Alibaba/DashScope anthropic-compatible endpoints: qwen3.5-plus).
|
||||
"""
|
||||
system, anthropic_messages = convert_messages_to_anthropic(messages)
|
||||
anthropic_tools = convert_tools_to_anthropic(tools) if tools else []
|
||||
|
||||
model = normalize_model_name(model)
|
||||
model = normalize_model_name(model, preserve_dots=preserve_dots)
|
||||
effective_max_tokens = max_tokens or 16384
|
||||
|
||||
# ── OAuth: Claude Code identity ──────────────────────────────────
|
||||
|
||||
+105
-16
@@ -40,6 +40,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
@@ -325,9 +326,10 @@ class AsyncCodexAuxiliaryClient:
|
||||
class _AnthropicCompletionsAdapter:
|
||||
"""OpenAI-client-compatible adapter for Anthropic Messages API."""
|
||||
|
||||
def __init__(self, real_client: Any, model: str):
|
||||
def __init__(self, real_client: Any, model: str, is_oauth: bool = False):
|
||||
self._client = real_client
|
||||
self._model = model
|
||||
self._is_oauth = is_oauth
|
||||
|
||||
def create(self, **kwargs) -> Any:
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs, normalize_anthropic_response
|
||||
@@ -356,6 +358,7 @@ class _AnthropicCompletionsAdapter:
|
||||
max_tokens=max_tokens,
|
||||
reasoning_config=None,
|
||||
tool_choice=normalized_tool_choice,
|
||||
is_oauth=self._is_oauth,
|
||||
)
|
||||
if temperature is not None:
|
||||
anthropic_kwargs["temperature"] = temperature
|
||||
@@ -394,9 +397,9 @@ class _AnthropicChatShim:
|
||||
class AnthropicAuxiliaryClient:
|
||||
"""OpenAI-client-compatible wrapper over a native Anthropic client."""
|
||||
|
||||
def __init__(self, real_client: Any, model: str, api_key: str, base_url: str):
|
||||
def __init__(self, real_client: Any, model: str, api_key: str, base_url: str, is_oauth: bool = False):
|
||||
self._real_client = real_client
|
||||
adapter = _AnthropicCompletionsAdapter(real_client, model)
|
||||
adapter = _AnthropicCompletionsAdapter(real_client, model, is_oauth=is_oauth)
|
||||
self.chat = _AnthropicChatShim(adapter)
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
@@ -463,15 +466,30 @@ def _nous_base_url() -> str:
|
||||
|
||||
|
||||
def _read_codex_access_token() -> Optional[str]:
|
||||
"""Read a valid Codex OAuth access token from Hermes auth store (~/.hermes/auth.json)."""
|
||||
"""Read a valid, non-expired Codex OAuth access token from Hermes auth store."""
|
||||
try:
|
||||
from hermes_cli.auth import _read_codex_tokens
|
||||
data = _read_codex_tokens()
|
||||
tokens = data.get("tokens", {})
|
||||
access_token = tokens.get("access_token")
|
||||
if isinstance(access_token, str) and access_token.strip():
|
||||
return access_token.strip()
|
||||
return None
|
||||
if not isinstance(access_token, str) or not access_token.strip():
|
||||
return None
|
||||
|
||||
# Check JWT expiry — expired tokens block the auto chain and
|
||||
# prevent fallback to working providers (e.g. Anthropic).
|
||||
try:
|
||||
import base64
|
||||
payload = access_token.split(".")[1]
|
||||
payload += "=" * (-len(payload) % 4)
|
||||
claims = json.loads(base64.urlsafe_b64decode(payload))
|
||||
exp = claims.get("exp", 0)
|
||||
if exp and time.time() > exp:
|
||||
logger.debug("Codex access token expired (exp=%s), skipping", exp)
|
||||
return None
|
||||
except Exception:
|
||||
pass # Non-JWT token or decode error — use as-is
|
||||
|
||||
return access_token.strip()
|
||||
except Exception as exc:
|
||||
logger.debug("Could not read Codex auth for auxiliary client: %s", exc)
|
||||
return None
|
||||
@@ -654,23 +672,29 @@ def _try_anthropic() -> Tuple[Optional[Any], Optional[str]]:
|
||||
if not token:
|
||||
return None, None
|
||||
|
||||
# Allow base URL override from config.yaml model.base_url
|
||||
# Allow base URL override from config.yaml model.base_url, but only
|
||||
# when the configured provider is anthropic — otherwise a non-Anthropic
|
||||
# base_url (e.g. Codex endpoint) would leak into Anthropic requests.
|
||||
base_url = _ANTHROPIC_DEFAULT_BASE_URL
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
cfg = load_config()
|
||||
model_cfg = cfg.get("model")
|
||||
if isinstance(model_cfg, dict):
|
||||
cfg_base_url = (model_cfg.get("base_url") or "").strip().rstrip("/")
|
||||
if cfg_base_url:
|
||||
base_url = cfg_base_url
|
||||
cfg_provider = str(model_cfg.get("provider") or "").strip().lower()
|
||||
if cfg_provider == "anthropic":
|
||||
cfg_base_url = (model_cfg.get("base_url") or "").strip().rstrip("/")
|
||||
if cfg_base_url:
|
||||
base_url = cfg_base_url
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from agent.anthropic_adapter import _is_oauth_token
|
||||
is_oauth = _is_oauth_token(token)
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get("anthropic", "claude-haiku-4-5-20251001")
|
||||
logger.debug("Auxiliary client: Anthropic native (%s) at %s", model, base_url)
|
||||
logger.debug("Auxiliary client: Anthropic native (%s) at %s (oauth=%s)", model, base_url, is_oauth)
|
||||
real_client = build_anthropic_client(token, base_url)
|
||||
return AnthropicAuxiliaryClient(real_client, model, token, base_url), model
|
||||
return AnthropicAuxiliaryClient(real_client, model, token, base_url, is_oauth=is_oauth), model
|
||||
|
||||
|
||||
def _resolve_forced_provider(forced: str) -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
@@ -1180,6 +1204,53 @@ _client_cache: Dict[tuple, tuple] = {}
|
||||
_client_cache_lock = threading.Lock()
|
||||
|
||||
|
||||
def _force_close_async_httpx(client: Any) -> None:
|
||||
"""Mark the httpx AsyncClient inside an AsyncOpenAI client as closed.
|
||||
|
||||
This prevents ``AsyncHttpxClientWrapper.__del__`` from scheduling
|
||||
``aclose()`` on a (potentially closed) event loop, which causes
|
||||
``RuntimeError: Event loop is closed`` → prompt_toolkit's
|
||||
"Press ENTER to continue..." handler.
|
||||
|
||||
We intentionally do NOT run the full async close path — the
|
||||
connections will be dropped by the OS when the process exits.
|
||||
"""
|
||||
try:
|
||||
from httpx._client import ClientState
|
||||
inner = getattr(client, "_client", None)
|
||||
if inner is not None and not getattr(inner, "is_closed", True):
|
||||
inner._state = ClientState.CLOSED
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def shutdown_cached_clients() -> None:
|
||||
"""Close all cached clients (sync and async) to prevent event-loop errors.
|
||||
|
||||
Call this during CLI shutdown, *before* the event loop is closed, to
|
||||
avoid ``AsyncHttpxClientWrapper.__del__`` raising on a dead loop.
|
||||
"""
|
||||
import inspect
|
||||
|
||||
with _client_cache_lock:
|
||||
for key, entry in list(_client_cache.items()):
|
||||
client = entry[0]
|
||||
if client is None:
|
||||
continue
|
||||
# Mark any async httpx transport as closed first (prevents __del__
|
||||
# from scheduling aclose() on a dead event loop).
|
||||
_force_close_async_httpx(client)
|
||||
# Sync clients: close the httpx connection pool cleanly.
|
||||
# Async clients: skip — we already neutered __del__ above.
|
||||
try:
|
||||
close_fn = getattr(client, "close", None)
|
||||
if close_fn and not inspect.iscoroutinefunction(close_fn):
|
||||
close_fn()
|
||||
except Exception:
|
||||
pass
|
||||
_client_cache.clear()
|
||||
|
||||
|
||||
def _get_cached_client(
|
||||
provider: str,
|
||||
model: str = None,
|
||||
@@ -1198,6 +1269,7 @@ def _get_cached_client(
|
||||
# "Event loop is closed" when httpx tries to clean up its
|
||||
# transport. Discard the stale client and create a fresh one.
|
||||
if cached_loop is not None and cached_loop.is_closed():
|
||||
_force_close_async_httpx(cached_client)
|
||||
del _client_cache[cache_key]
|
||||
else:
|
||||
return cached_client, model or cached_default
|
||||
@@ -1427,8 +1499,18 @@ def call_llm(
|
||||
api_key=resolved_api_key,
|
||||
)
|
||||
if client is None:
|
||||
# Fallback: try openrouter
|
||||
if resolved_provider != "openrouter" and not resolved_base_url:
|
||||
# When the user explicitly chose a non-OpenRouter provider but no
|
||||
# credentials were found, fail fast instead of silently routing
|
||||
# through OpenRouter (which causes confusing 404s).
|
||||
_explicit = (resolved_provider or "").strip().lower()
|
||||
if _explicit and _explicit not in ("auto", "openrouter", "custom"):
|
||||
raise RuntimeError(
|
||||
f"Provider '{_explicit}' is set in config.yaml but no API key "
|
||||
f"was found. Set the {_explicit.upper()}_API_KEY environment "
|
||||
f"variable, or switch to a different provider with `hermes model`."
|
||||
)
|
||||
# For auto/custom, fall back to OpenRouter
|
||||
if not resolved_base_url:
|
||||
logger.warning("Provider %s unavailable, falling back to openrouter",
|
||||
resolved_provider)
|
||||
client, final_model = _get_cached_client(
|
||||
@@ -1510,7 +1592,14 @@ async def async_call_llm(
|
||||
api_key=resolved_api_key,
|
||||
)
|
||||
if client is None:
|
||||
if resolved_provider != "openrouter" and not resolved_base_url:
|
||||
_explicit = (resolved_provider or "").strip().lower()
|
||||
if _explicit and _explicit not in ("auto", "openrouter", "custom"):
|
||||
raise RuntimeError(
|
||||
f"Provider '{_explicit}' is set in config.yaml but no API key "
|
||||
f"was found. Set the {_explicit.upper()}_API_KEY environment "
|
||||
f"variable, or switch to a different provider with `hermes model`."
|
||||
)
|
||||
if not resolved_base_url:
|
||||
logger.warning("Provider %s unavailable, falling back to openrouter",
|
||||
resolved_provider)
|
||||
client, final_model = _get_cached_client(
|
||||
|
||||
@@ -93,6 +93,14 @@ class ContextCompressor:
|
||||
)
|
||||
self.threshold_tokens = int(self.context_length * threshold_percent)
|
||||
self.compression_count = 0
|
||||
|
||||
if not quiet_mode:
|
||||
logger.info(
|
||||
"Context compressor initialized: model=%s context_length=%d "
|
||||
"threshold=%d (%.0f%%) provider=%s base_url=%s",
|
||||
model, self.context_length, self.threshold_tokens,
|
||||
threshold_percent * 100, provider or "none", base_url or "none",
|
||||
)
|
||||
self._context_probed = False # True after a step-down from context error
|
||||
|
||||
self.last_prompt_tokens = 0
|
||||
|
||||
@@ -0,0 +1,485 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from agent.model_metadata import estimate_tokens_rough
|
||||
|
||||
REFERENCE_PATTERN = re.compile(
|
||||
r"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>\S+))"
|
||||
)
|
||||
TRAILING_PUNCTUATION = ",.;!?"
|
||||
_SENSITIVE_HOME_DIRS = (".ssh", ".aws", ".gnupg", ".kube")
|
||||
_SENSITIVE_HERMES_DIRS = (Path("skills") / ".hub",)
|
||||
_SENSITIVE_HOME_FILES = (
|
||||
Path(".ssh") / "authorized_keys",
|
||||
Path(".ssh") / "id_rsa",
|
||||
Path(".ssh") / "id_ed25519",
|
||||
Path(".ssh") / "config",
|
||||
Path(".bashrc"),
|
||||
Path(".zshrc"),
|
||||
Path(".profile"),
|
||||
Path(".bash_profile"),
|
||||
Path(".zprofile"),
|
||||
Path(".netrc"),
|
||||
Path(".pgpass"),
|
||||
Path(".npmrc"),
|
||||
Path(".pypirc"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ContextReference:
|
||||
raw: str
|
||||
kind: str
|
||||
target: str
|
||||
start: int
|
||||
end: int
|
||||
line_start: int | None = None
|
||||
line_end: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextReferenceResult:
|
||||
message: str
|
||||
original_message: str
|
||||
references: list[ContextReference] = field(default_factory=list)
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
injected_tokens: int = 0
|
||||
expanded: bool = False
|
||||
blocked: bool = False
|
||||
|
||||
|
||||
def parse_context_references(message: str) -> list[ContextReference]:
|
||||
refs: list[ContextReference] = []
|
||||
if not message:
|
||||
return refs
|
||||
|
||||
for match in REFERENCE_PATTERN.finditer(message):
|
||||
simple = match.group("simple")
|
||||
if simple:
|
||||
refs.append(
|
||||
ContextReference(
|
||||
raw=match.group(0),
|
||||
kind=simple,
|
||||
target="",
|
||||
start=match.start(),
|
||||
end=match.end(),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
kind = match.group("kind")
|
||||
value = _strip_trailing_punctuation(match.group("value") or "")
|
||||
line_start = None
|
||||
line_end = None
|
||||
target = value
|
||||
|
||||
if kind == "file":
|
||||
range_match = re.match(r"^(?P<path>.+?):(?P<start>\d+)(?:-(?P<end>\d+))?$", value)
|
||||
if range_match:
|
||||
target = range_match.group("path")
|
||||
line_start = int(range_match.group("start"))
|
||||
line_end = int(range_match.group("end") or range_match.group("start"))
|
||||
|
||||
refs.append(
|
||||
ContextReference(
|
||||
raw=match.group(0),
|
||||
kind=kind,
|
||||
target=target,
|
||||
start=match.start(),
|
||||
end=match.end(),
|
||||
line_start=line_start,
|
||||
line_end=line_end,
|
||||
)
|
||||
)
|
||||
|
||||
return refs
|
||||
|
||||
|
||||
def preprocess_context_references(
|
||||
message: str,
|
||||
*,
|
||||
cwd: str | Path,
|
||||
context_length: int,
|
||||
url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
|
||||
allowed_root: str | Path | None = None,
|
||||
) -> ContextReferenceResult:
|
||||
coro = preprocess_context_references_async(
|
||||
message,
|
||||
cwd=cwd,
|
||||
context_length=context_length,
|
||||
url_fetcher=url_fetcher,
|
||||
allowed_root=allowed_root,
|
||||
)
|
||||
# Safe for both CLI (no loop) and gateway (loop already running).
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
if loop and loop.is_running():
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
||||
return pool.submit(asyncio.run, coro).result()
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
async def preprocess_context_references_async(
|
||||
message: str,
|
||||
*,
|
||||
cwd: str | Path,
|
||||
context_length: int,
|
||||
url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
|
||||
allowed_root: str | Path | None = None,
|
||||
) -> ContextReferenceResult:
|
||||
refs = parse_context_references(message)
|
||||
if not refs:
|
||||
return ContextReferenceResult(message=message, original_message=message)
|
||||
|
||||
cwd_path = Path(cwd).expanduser().resolve()
|
||||
# Default to the current working directory so @ references cannot escape
|
||||
# the active workspace unless a caller explicitly widens the root.
|
||||
allowed_root_path = (
|
||||
Path(allowed_root).expanduser().resolve() if allowed_root is not None else cwd_path
|
||||
)
|
||||
warnings: list[str] = []
|
||||
blocks: list[str] = []
|
||||
injected_tokens = 0
|
||||
|
||||
for ref in refs:
|
||||
warning, block = await _expand_reference(
|
||||
ref,
|
||||
cwd_path,
|
||||
url_fetcher=url_fetcher,
|
||||
allowed_root=allowed_root_path,
|
||||
)
|
||||
if warning:
|
||||
warnings.append(warning)
|
||||
if block:
|
||||
blocks.append(block)
|
||||
injected_tokens += estimate_tokens_rough(block)
|
||||
|
||||
hard_limit = max(1, int(context_length * 0.50))
|
||||
soft_limit = max(1, int(context_length * 0.25))
|
||||
if injected_tokens > hard_limit:
|
||||
warnings.append(
|
||||
f"@ context injection refused: {injected_tokens} tokens exceeds the 50% hard limit ({hard_limit})."
|
||||
)
|
||||
return ContextReferenceResult(
|
||||
message=message,
|
||||
original_message=message,
|
||||
references=refs,
|
||||
warnings=warnings,
|
||||
injected_tokens=injected_tokens,
|
||||
expanded=False,
|
||||
blocked=True,
|
||||
)
|
||||
|
||||
if injected_tokens > soft_limit:
|
||||
warnings.append(
|
||||
f"@ context injection warning: {injected_tokens} tokens exceeds the 25% soft limit ({soft_limit})."
|
||||
)
|
||||
|
||||
stripped = _remove_reference_tokens(message, refs)
|
||||
final = stripped
|
||||
if warnings:
|
||||
final = f"{final}\n\n--- Context Warnings ---\n" + "\n".join(f"- {warning}" for warning in warnings)
|
||||
if blocks:
|
||||
final = f"{final}\n\n--- Attached Context ---\n\n" + "\n\n".join(blocks)
|
||||
|
||||
return ContextReferenceResult(
|
||||
message=final.strip(),
|
||||
original_message=message,
|
||||
references=refs,
|
||||
warnings=warnings,
|
||||
injected_tokens=injected_tokens,
|
||||
expanded=bool(blocks or warnings),
|
||||
blocked=False,
|
||||
)
|
||||
|
||||
|
||||
async def _expand_reference(
|
||||
ref: ContextReference,
|
||||
cwd: Path,
|
||||
*,
|
||||
url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
|
||||
allowed_root: Path | None = None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
try:
|
||||
if ref.kind == "file":
|
||||
return _expand_file_reference(ref, cwd, allowed_root=allowed_root)
|
||||
if ref.kind == "folder":
|
||||
return _expand_folder_reference(ref, cwd, allowed_root=allowed_root)
|
||||
if ref.kind == "diff":
|
||||
return _expand_git_reference(ref, cwd, ["diff"], "git diff")
|
||||
if ref.kind == "staged":
|
||||
return _expand_git_reference(ref, cwd, ["diff", "--staged"], "git diff --staged")
|
||||
if ref.kind == "git":
|
||||
count = max(1, min(int(ref.target or "1"), 10))
|
||||
return _expand_git_reference(ref, cwd, ["log", f"-{count}", "-p"], f"git log -{count} -p")
|
||||
if ref.kind == "url":
|
||||
content = await _fetch_url_content(ref.target, url_fetcher=url_fetcher)
|
||||
if not content:
|
||||
return f"{ref.raw}: no content extracted", None
|
||||
return None, f"🌐 {ref.raw} ({estimate_tokens_rough(content)} tokens)\n{content}"
|
||||
except Exception as exc:
|
||||
return f"{ref.raw}: {exc}", None
|
||||
|
||||
return f"{ref.raw}: unsupported reference type", None
|
||||
|
||||
|
||||
def _expand_file_reference(
|
||||
ref: ContextReference,
|
||||
cwd: Path,
|
||||
*,
|
||||
allowed_root: Path | None = None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
path = _resolve_path(cwd, ref.target, allowed_root=allowed_root)
|
||||
_ensure_reference_path_allowed(path)
|
||||
if not path.exists():
|
||||
return f"{ref.raw}: file not found", None
|
||||
if not path.is_file():
|
||||
return f"{ref.raw}: path is not a file", None
|
||||
if _is_binary_file(path):
|
||||
return f"{ref.raw}: binary files are not supported", None
|
||||
|
||||
text = path.read_text(encoding="utf-8")
|
||||
if ref.line_start is not None:
|
||||
lines = text.splitlines()
|
||||
start_idx = max(ref.line_start - 1, 0)
|
||||
end_idx = min(ref.line_end or ref.line_start, len(lines))
|
||||
text = "\n".join(lines[start_idx:end_idx])
|
||||
|
||||
lang = _code_fence_language(path)
|
||||
label = ref.raw
|
||||
return None, f"📄 {label} ({estimate_tokens_rough(text)} tokens)\n```{lang}\n{text}\n```"
|
||||
|
||||
|
||||
def _expand_folder_reference(
|
||||
ref: ContextReference,
|
||||
cwd: Path,
|
||||
*,
|
||||
allowed_root: Path | None = None,
|
||||
) -> tuple[str | None, str | None]:
|
||||
path = _resolve_path(cwd, ref.target, allowed_root=allowed_root)
|
||||
_ensure_reference_path_allowed(path)
|
||||
if not path.exists():
|
||||
return f"{ref.raw}: folder not found", None
|
||||
if not path.is_dir():
|
||||
return f"{ref.raw}: path is not a folder", None
|
||||
|
||||
listing = _build_folder_listing(path, cwd)
|
||||
return None, f"📁 {ref.raw} ({estimate_tokens_rough(listing)} tokens)\n{listing}"
|
||||
|
||||
|
||||
def _expand_git_reference(
|
||||
ref: ContextReference,
|
||||
cwd: Path,
|
||||
args: list[str],
|
||||
label: str,
|
||||
) -> tuple[str | None, str | None]:
|
||||
result = subprocess.run(
|
||||
["git", *args],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
stderr = (result.stderr or "").strip() or "git command failed"
|
||||
return f"{ref.raw}: {stderr}", None
|
||||
content = result.stdout.strip()
|
||||
if not content:
|
||||
content = "(no output)"
|
||||
return None, f"🧾 {label} ({estimate_tokens_rough(content)} tokens)\n```diff\n{content}\n```"
|
||||
|
||||
|
||||
async def _fetch_url_content(
|
||||
url: str,
|
||||
*,
|
||||
url_fetcher: Callable[[str], str | Awaitable[str]] | None = None,
|
||||
) -> str:
|
||||
fetcher = url_fetcher or _default_url_fetcher
|
||||
content = fetcher(url)
|
||||
if inspect.isawaitable(content):
|
||||
content = await content
|
||||
return str(content or "").strip()
|
||||
|
||||
|
||||
async def _default_url_fetcher(url: str) -> str:
|
||||
from tools.web_tools import web_extract_tool
|
||||
|
||||
raw = await web_extract_tool([url], format="markdown", use_llm_processing=True)
|
||||
payload = json.loads(raw)
|
||||
docs = payload.get("data", {}).get("documents", [])
|
||||
if not docs:
|
||||
return ""
|
||||
doc = docs[0]
|
||||
return str(doc.get("content") or doc.get("raw_content") or "").strip()
|
||||
|
||||
|
||||
def _resolve_path(cwd: Path, target: str, *, allowed_root: Path | None = None) -> Path:
|
||||
path = Path(os.path.expanduser(target))
|
||||
if not path.is_absolute():
|
||||
path = cwd / path
|
||||
resolved = path.resolve()
|
||||
if allowed_root is not None:
|
||||
try:
|
||||
resolved.relative_to(allowed_root)
|
||||
except ValueError as exc:
|
||||
raise ValueError("path is outside the allowed workspace") from exc
|
||||
return resolved
|
||||
|
||||
|
||||
def _ensure_reference_path_allowed(path: Path) -> None:
|
||||
home = Path(os.path.expanduser("~")).resolve()
|
||||
hermes_home = Path(
|
||||
os.getenv("HERMES_HOME", str(home / ".hermes"))
|
||||
).expanduser().resolve()
|
||||
|
||||
blocked_exact = {home / rel for rel in _SENSITIVE_HOME_FILES}
|
||||
blocked_exact.add(hermes_home / ".env")
|
||||
blocked_dirs = [home / rel for rel in _SENSITIVE_HOME_DIRS]
|
||||
blocked_dirs.extend(hermes_home / rel for rel in _SENSITIVE_HERMES_DIRS)
|
||||
|
||||
if path in blocked_exact:
|
||||
raise ValueError("path is a sensitive credential file and cannot be attached")
|
||||
|
||||
for blocked_dir in blocked_dirs:
|
||||
try:
|
||||
path.relative_to(blocked_dir)
|
||||
except ValueError:
|
||||
continue
|
||||
raise ValueError("path is a sensitive credential or internal Hermes path and cannot be attached")
|
||||
|
||||
|
||||
def _strip_trailing_punctuation(value: str) -> str:
|
||||
stripped = value.rstrip(TRAILING_PUNCTUATION)
|
||||
while stripped.endswith((")", "]", "}")):
|
||||
closer = stripped[-1]
|
||||
opener = {")": "(", "]": "[", "}": "{"}[closer]
|
||||
if stripped.count(closer) > stripped.count(opener):
|
||||
stripped = stripped[:-1]
|
||||
continue
|
||||
break
|
||||
return stripped
|
||||
|
||||
|
||||
def _remove_reference_tokens(message: str, refs: list[ContextReference]) -> str:
|
||||
pieces: list[str] = []
|
||||
cursor = 0
|
||||
for ref in refs:
|
||||
pieces.append(message[cursor:ref.start])
|
||||
cursor = ref.end
|
||||
pieces.append(message[cursor:])
|
||||
text = "".join(pieces)
|
||||
text = re.sub(r"\s{2,}", " ", text)
|
||||
text = re.sub(r"\s+([,.;:!?])", r"\1", text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
def _is_binary_file(path: Path) -> bool:
|
||||
mime, _ = mimetypes.guess_type(path.name)
|
||||
if mime and not mime.startswith("text/") and not any(
|
||||
path.name.endswith(ext) for ext in (".py", ".md", ".txt", ".json", ".yaml", ".yml", ".toml", ".js", ".ts")
|
||||
):
|
||||
return True
|
||||
chunk = path.read_bytes()[:4096]
|
||||
return b"\x00" in chunk
|
||||
|
||||
|
||||
def _build_folder_listing(path: Path, cwd: Path, limit: int = 200) -> str:
|
||||
lines = [f"{path.relative_to(cwd)}/"]
|
||||
entries = _iter_visible_entries(path, cwd, limit=limit)
|
||||
for entry in entries:
|
||||
rel = entry.relative_to(cwd)
|
||||
indent = " " * max(len(rel.parts) - len(path.relative_to(cwd).parts) - 1, 0)
|
||||
if entry.is_dir():
|
||||
lines.append(f"{indent}- {entry.name}/")
|
||||
else:
|
||||
meta = _file_metadata(entry)
|
||||
lines.append(f"{indent}- {entry.name} ({meta})")
|
||||
if len(entries) >= limit:
|
||||
lines.append("- ...")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _iter_visible_entries(path: Path, cwd: Path, limit: int) -> list[Path]:
|
||||
rg_entries = _rg_files(path, cwd, limit=limit)
|
||||
if rg_entries is not None:
|
||||
output: list[Path] = []
|
||||
seen_dirs: set[Path] = set()
|
||||
for rel in rg_entries:
|
||||
full = cwd / rel
|
||||
for parent in full.parents:
|
||||
if parent == cwd or parent in seen_dirs or path not in {parent, *parent.parents}:
|
||||
continue
|
||||
seen_dirs.add(parent)
|
||||
output.append(parent)
|
||||
output.append(full)
|
||||
return sorted({p for p in output if p.exists()}, key=lambda p: (not p.is_dir(), str(p)))
|
||||
|
||||
output = []
|
||||
for root, dirs, files in os.walk(path):
|
||||
dirs[:] = sorted(d for d in dirs if not d.startswith(".") and d != "__pycache__")
|
||||
files = sorted(f for f in files if not f.startswith("."))
|
||||
root_path = Path(root)
|
||||
for d in dirs:
|
||||
output.append(root_path / d)
|
||||
if len(output) >= limit:
|
||||
return output
|
||||
for f in files:
|
||||
output.append(root_path / f)
|
||||
if len(output) >= limit:
|
||||
return output
|
||||
return output
|
||||
|
||||
|
||||
def _rg_files(path: Path, cwd: Path, limit: int) -> list[Path] | None:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["rg", "--files", str(path.relative_to(cwd))],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
files = [Path(line.strip()) for line in result.stdout.splitlines() if line.strip()]
|
||||
return files[:limit]
|
||||
|
||||
|
||||
def _file_metadata(path: Path) -> str:
|
||||
if _is_binary_file(path):
|
||||
return f"{path.stat().st_size} bytes"
|
||||
try:
|
||||
line_count = path.read_text(encoding="utf-8").count("\n") + 1
|
||||
except Exception:
|
||||
return f"{path.stat().st_size} bytes"
|
||||
return f"{line_count} lines"
|
||||
|
||||
|
||||
def _code_fence_language(path: Path) -> str:
|
||||
mapping = {
|
||||
".py": "python",
|
||||
".js": "javascript",
|
||||
".ts": "typescript",
|
||||
".tsx": "tsx",
|
||||
".jsx": "jsx",
|
||||
".json": "json",
|
||||
".md": "markdown",
|
||||
".sh": "bash",
|
||||
".yml": "yaml",
|
||||
".yaml": "yaml",
|
||||
".toml": "toml",
|
||||
}
|
||||
return mapping.get(path.suffix.lower(), "")
|
||||
+19
-7
@@ -164,6 +164,8 @@ _URL_TO_PROVIDER: Dict[str, str] = {
|
||||
"openrouter.ai": "openrouter",
|
||||
"inference-api.nousresearch.com": "nous",
|
||||
"api.deepseek.com": "deepseek",
|
||||
"api.githubcopilot.com": "copilot",
|
||||
"models.github.ai": "copilot",
|
||||
}
|
||||
|
||||
|
||||
@@ -260,9 +262,11 @@ def detect_local_server_type(base_url: str) -> Optional[str]:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
# llama.cpp exposes /props
|
||||
# llama.cpp exposes /v1/props (older builds used /props without the /v1 prefix)
|
||||
try:
|
||||
r = client.get(f"{server_url}/props")
|
||||
r = client.get(f"{server_url}/v1/props")
|
||||
if r.status_code != 200:
|
||||
r = client.get(f"{server_url}/props") # fallback for older builds
|
||||
if r.status_code == 200 and "default_generation_settings" in r.text:
|
||||
return "llamacpp"
|
||||
except Exception:
|
||||
@@ -455,8 +459,11 @@ def fetch_endpoint_model_metadata(
|
||||
)
|
||||
if is_llamacpp:
|
||||
try:
|
||||
props_url = candidate.rstrip("/").replace("/v1", "") + "/props"
|
||||
props_resp = requests.get(props_url, headers=headers, timeout=5)
|
||||
# Try /v1/props first (current llama.cpp); fall back to /props for older builds
|
||||
base = candidate.rstrip("/").replace("/v1", "")
|
||||
props_resp = requests.get(base + "/v1/props", headers=headers, timeout=5)
|
||||
if not props_resp.ok:
|
||||
props_resp = requests.get(base + "/props", headers=headers, timeout=5)
|
||||
if props_resp.ok:
|
||||
props = props_resp.json()
|
||||
gen_settings = props.get("default_generation_settings", {})
|
||||
@@ -783,8 +790,12 @@ def get_model_context_length(
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
# 2. Active endpoint metadata for explicit custom routes
|
||||
if _is_custom_endpoint(base_url):
|
||||
# 2. Active endpoint metadata for truly custom/unknown endpoints.
|
||||
# Known providers (Copilot, OpenAI, Anthropic, etc.) skip this — their
|
||||
# /models endpoint may report a provider-imposed limit (e.g. Copilot
|
||||
# returns 128k) instead of the model's full context (400k). models.dev
|
||||
# has the correct per-provider values and is checked at step 5+.
|
||||
if _is_custom_endpoint(base_url) and not _is_known_provider_base_url(base_url):
|
||||
endpoint_metadata = fetch_endpoint_model_metadata(base_url, api_key=api_key)
|
||||
matched = endpoint_metadata.get(model)
|
||||
if not matched:
|
||||
@@ -855,10 +866,11 @@ def get_model_context_length(
|
||||
# Only check `default_model in model` (is the key a substring of the input).
|
||||
# The reverse (`model in default_model`) causes shorter names like
|
||||
# "claude-sonnet-4" to incorrectly match "claude-sonnet-4-6" and return 1M.
|
||||
model_lower = model.lower()
|
||||
for default_model, length in sorted(
|
||||
DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True
|
||||
):
|
||||
if default_model in model:
|
||||
if default_model in model_lower:
|
||||
return length
|
||||
|
||||
# 9. Query local server as last resort
|
||||
|
||||
@@ -12,13 +12,14 @@ import copy
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
def _apply_cache_marker(msg: dict, cache_marker: dict) -> None:
|
||||
def _apply_cache_marker(msg: dict, cache_marker: dict, native_anthropic: bool = False) -> None:
|
||||
"""Add cache_control to a single message, handling all format variations."""
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content")
|
||||
|
||||
if role == "tool":
|
||||
msg["cache_control"] = cache_marker
|
||||
if native_anthropic:
|
||||
msg["cache_control"] = cache_marker
|
||||
return
|
||||
|
||||
if content is None or content == "":
|
||||
@@ -40,6 +41,7 @@ def _apply_cache_marker(msg: dict, cache_marker: dict) -> None:
|
||||
def apply_anthropic_cache_control(
|
||||
api_messages: List[Dict[str, Any]],
|
||||
cache_ttl: str = "5m",
|
||||
native_anthropic: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Apply system_and_3 caching strategy to messages for Anthropic models.
|
||||
|
||||
@@ -59,12 +61,12 @@ def apply_anthropic_cache_control(
|
||||
breakpoints_used = 0
|
||||
|
||||
if messages[0].get("role") == "system":
|
||||
_apply_cache_marker(messages[0], marker)
|
||||
_apply_cache_marker(messages[0], marker, native_anthropic=native_anthropic)
|
||||
breakpoints_used += 1
|
||||
|
||||
remaining = 4 - breakpoints_used
|
||||
non_sys = [i for i in range(len(messages)) if messages[i].get("role") != "system"]
|
||||
for idx in non_sys[-remaining:]:
|
||||
_apply_cache_marker(messages[idx], marker)
|
||||
_apply_cache_marker(messages[idx], marker, native_anthropic=native_anthropic)
|
||||
|
||||
return messages
|
||||
|
||||
@@ -100,6 +100,10 @@ def redact_sensitive_text(text: str) -> str:
|
||||
Safe to call on any string -- non-matching text passes through unchanged.
|
||||
Disabled when security.redact_secrets is false in config.yaml.
|
||||
"""
|
||||
if text is None:
|
||||
return None
|
||||
if not isinstance(text, str):
|
||||
text = str(text)
|
||||
if not text:
|
||||
return text
|
||||
if os.getenv("HERMES_REDACT_SECRETS", "").lower() in ("0", "false", "no", "off"):
|
||||
|
||||
@@ -165,10 +165,10 @@ def load_cli_config() -> Dict[str, Any]:
|
||||
"cwd": ".", # "." is resolved to os.getcwd() at runtime
|
||||
"timeout": 60,
|
||||
"lifetime_seconds": 300,
|
||||
"docker_image": "python:3.11",
|
||||
"docker_image": "nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
"docker_forward_env": [],
|
||||
"singularity_image": "docker://python:3.11",
|
||||
"modal_image": "python:3.11",
|
||||
"singularity_image": "docker://nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
"modal_image": "nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
"daytona_image": "nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
"docker_volumes": [], # host:container volume mounts for Docker backend
|
||||
"docker_mount_cwd_to_workspace": False, # explicit opt-in only; default off for sandbox isolation
|
||||
@@ -180,7 +180,7 @@ def load_cli_config() -> Dict[str, Any]:
|
||||
"compression": {
|
||||
"enabled": True, # Auto-compress when approaching context limit
|
||||
"threshold": 0.50, # Compress at 50% of model's context limit
|
||||
"summary_model": "google/gemini-3-flash-preview", # Fast/cheap model for summaries
|
||||
"summary_model": "", # Model for summaries (empty = use main model)
|
||||
},
|
||||
"smart_model_routing": {
|
||||
"enabled": False,
|
||||
@@ -301,7 +301,11 @@ def load_cli_config() -> Dict[str, Any]:
|
||||
defaults["agent"]["max_turns"] = file_config["max_turns"]
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load cli-config.yaml: %s", e)
|
||||
|
||||
|
||||
# Expand ${ENV_VAR} references in config values before bridging to env vars.
|
||||
from hermes_cli.config import _expand_env_vars
|
||||
defaults = _expand_env_vars(defaults)
|
||||
|
||||
# Apply terminal config to environment variables (so terminal_tool picks them up)
|
||||
terminal_config = defaults.get("terminal", {})
|
||||
|
||||
@@ -448,7 +452,6 @@ from rich import box as rich_box
|
||||
from rich.console import Console
|
||||
from rich.markup import escape as _escape
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from rich.text import Text as _RichText
|
||||
|
||||
import fire
|
||||
@@ -460,12 +463,12 @@ from model_tools import get_tool_definitions, get_toolset_for_tool
|
||||
# Extracted CLI modules (Phase 3)
|
||||
from hermes_cli.banner import (
|
||||
cprint as _cprint, _GOLD, _BOLD, _DIM, _RST,
|
||||
VERSION, RELEASE_DATE, HERMES_AGENT_LOGO, HERMES_CADUCEUS, COMPACT_BANNER,
|
||||
HERMES_AGENT_LOGO, HERMES_CADUCEUS, COMPACT_BANNER,
|
||||
build_welcome_banner,
|
||||
)
|
||||
from hermes_cli.commands import COMMANDS, SlashCommandCompleter, SlashCommandAutoSuggest
|
||||
from hermes_cli import callbacks as _callbacks
|
||||
from toolsets import get_all_toolsets, get_toolset_info, resolve_toolset, validate_toolset
|
||||
from toolsets import get_all_toolsets, get_toolset_info, validate_toolset
|
||||
|
||||
# Cron job system for scheduled tasks (execution is handled by the gateway)
|
||||
from cron import get_job
|
||||
@@ -499,6 +502,14 @@ def _run_cleanup():
|
||||
shutdown_mcp_servers()
|
||||
except Exception:
|
||||
pass
|
||||
# Close cached auxiliary LLM clients (sync + async) so that
|
||||
# AsyncHttpxClientWrapper.__del__ doesn't fire on a closed event loop
|
||||
# and trigger prompt_toolkit's "Press ENTER to continue..." handler.
|
||||
try:
|
||||
from agent.auxiliary_client import shutdown_cached_clients
|
||||
shutdown_cached_clients()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -884,7 +895,6 @@ def _build_compact_banner() -> str:
|
||||
|
||||
from agent.skill_commands import (
|
||||
scan_skill_commands,
|
||||
get_skill_commands,
|
||||
build_skill_invocation_message,
|
||||
build_plan_path,
|
||||
build_preloaded_skills_prompt,
|
||||
@@ -893,6 +903,15 @@ from agent.skill_commands import (
|
||||
_skill_commands = scan_skill_commands()
|
||||
|
||||
|
||||
def _get_plugin_cmd_handler_names() -> set:
|
||||
"""Return plugin command names (without slash prefix) for dispatch matching."""
|
||||
try:
|
||||
from hermes_cli.plugins import get_plugin_manager
|
||||
return set(get_plugin_manager()._plugin_commands.keys())
|
||||
except Exception:
|
||||
return set()
|
||||
|
||||
|
||||
def _parse_skills_argument(skills: str | list[str] | tuple[str, ...] | None) -> list[str]:
|
||||
"""Normalize a CLI skills flag into a deduplicated list of skill identifiers."""
|
||||
if not skills:
|
||||
@@ -1747,8 +1766,22 @@ class HermesCLI:
|
||||
resolved_acp_command = runtime.get("command")
|
||||
resolved_acp_args = list(runtime.get("args") or [])
|
||||
if not isinstance(api_key, str) or not api_key:
|
||||
self.console.print("[bold red]Provider resolver returned an empty API key.[/]")
|
||||
return False
|
||||
# Custom / local endpoints (llama.cpp, ollama, vLLM, etc.) often
|
||||
# don't require authentication. When a base_url IS configured but
|
||||
# no API key was found, use a placeholder so the OpenAI SDK
|
||||
# doesn't reject the request and local servers just ignore it.
|
||||
_source = runtime.get("source", "")
|
||||
_has_custom_base = isinstance(base_url, str) and base_url and "openrouter.ai" not in base_url
|
||||
if _has_custom_base:
|
||||
api_key = "no-key-required"
|
||||
logger.debug(
|
||||
"No API key for custom endpoint %s (source=%s), "
|
||||
"using placeholder — local servers typically ignore auth",
|
||||
base_url, _source,
|
||||
)
|
||||
else:
|
||||
self.console.print("[bold red]Provider resolver returned an empty API key.[/]")
|
||||
return False
|
||||
if not isinstance(base_url, str) or not base_url:
|
||||
self.console.print("[bold red]Provider resolver returned an empty base URL.[/]")
|
||||
return False
|
||||
@@ -1906,6 +1939,9 @@ class HermesCLI:
|
||||
tool_progress_callback=self._on_tool_progress,
|
||||
stream_delta_callback=self._stream_delta if self.streaming_enabled else None,
|
||||
)
|
||||
# Route agent status output through prompt_toolkit so ANSI escape
|
||||
# sequences aren't garbled by patch_stdout's StdoutProxy (#2262).
|
||||
self.agent._print_fn = _cprint
|
||||
self._active_agent_route_signature = (
|
||||
effective_model,
|
||||
runtime.get("provider"),
|
||||
@@ -1931,13 +1967,6 @@ class HermesCLI:
|
||||
def show_banner(self):
|
||||
"""Display the welcome banner in Claude Code style."""
|
||||
self.console.clear()
|
||||
if self.preloaded_skills and not self._startup_skills_line_shown:
|
||||
skills_label = ", ".join(self.preloaded_skills)
|
||||
self.console.print(
|
||||
f"[bold {_accent_hex()}]Activated skills:[/] {skills_label}"
|
||||
)
|
||||
self.console.print()
|
||||
self._startup_skills_line_shown = True
|
||||
|
||||
# Auto-compact for narrow terminals — the full banner with caduceus
|
||||
# + tool list needs ~80 columns minimum to render without wrapping.
|
||||
@@ -2316,10 +2345,9 @@ class HermesCLI:
|
||||
Inspired by OpenAI Codex's separation of interrupt (stop current turn)
|
||||
from /stop (clean up background processes). See openai/codex#14602.
|
||||
"""
|
||||
from tools.process_registry import get_registry
|
||||
from tools.process_registry import process_registry
|
||||
|
||||
registry = get_registry()
|
||||
processes = registry.list_processes()
|
||||
processes = process_registry.list_sessions()
|
||||
running = [p for p in processes if p.get("status") == "running"]
|
||||
|
||||
if not running:
|
||||
@@ -2327,7 +2355,7 @@ class HermesCLI:
|
||||
return
|
||||
|
||||
print(f" Stopping {len(running)} background process(es)...")
|
||||
killed = registry.kill_all()
|
||||
killed = process_registry.kill_all()
|
||||
print(f" ✅ Stopped {killed} process(es).")
|
||||
|
||||
def _handle_paste_command(self):
|
||||
@@ -3759,6 +3787,18 @@ class HermesCLI:
|
||||
self.console.print(f"[bold red]Quick command '{base_cmd}' has no target defined[/]")
|
||||
else:
|
||||
self.console.print(f"[bold red]Quick command '{base_cmd}' has unsupported type (supported: 'exec', 'alias')[/]")
|
||||
# Check for plugin-registered slash commands
|
||||
elif base_cmd.lstrip("/") in _get_plugin_cmd_handler_names():
|
||||
from hermes_cli.plugins import get_plugin_command_handler
|
||||
plugin_handler = get_plugin_command_handler(base_cmd.lstrip("/"))
|
||||
if plugin_handler:
|
||||
user_args = cmd_original[len(base_cmd):].strip()
|
||||
try:
|
||||
result = plugin_handler(user_args)
|
||||
if result:
|
||||
_cprint(str(result))
|
||||
except Exception as e:
|
||||
_cprint(f"\033[1;31mPlugin command error: {e}{_RST}")
|
||||
# Check for skill slash commands (/gif-search, /axolotl, etc.)
|
||||
elif base_cmd in _skill_commands:
|
||||
user_instruction = cmd_original[len(base_cmd):].strip()
|
||||
@@ -4217,13 +4257,18 @@ class HermesCLI:
|
||||
elif not self.show_reasoning:
|
||||
self.agent.reasoning_callback = None
|
||||
|
||||
# Use raw ANSI codes via _cprint so the output is routed through
|
||||
# prompt_toolkit's renderer. self.console.print() with Rich markup
|
||||
# writes directly to stdout which patch_stdout's StdoutProxy mangles
|
||||
# into garbled sequences like '?[33mTool progress: NEW?[0m' (#2262).
|
||||
from hermes_cli.colors import Colors as _Colors
|
||||
labels = {
|
||||
"off": "[dim]Tool progress: OFF[/] — silent mode, just the final response.",
|
||||
"new": "[yellow]Tool progress: NEW[/] — show each new tool (skip repeats).",
|
||||
"all": "[green]Tool progress: ALL[/] — show every tool call.",
|
||||
"verbose": "[bold green]Tool progress: VERBOSE[/] — full args, results, think blocks, and debug logs.",
|
||||
"off": f"{_Colors.DIM}Tool progress: OFF{_Colors.RESET} — silent mode, just the final response.",
|
||||
"new": f"{_Colors.YELLOW}Tool progress: NEW{_Colors.RESET} — show each new tool (skip repeats).",
|
||||
"all": f"{_Colors.GREEN}Tool progress: ALL{_Colors.RESET} — show every tool call.",
|
||||
"verbose": f"{_Colors.BOLD}{_Colors.GREEN}Tool progress: VERBOSE{_Colors.RESET} — full args, results, think blocks, and debug logs.",
|
||||
}
|
||||
self.console.print(labels.get(self.tool_progress_mode, ""))
|
||||
_cprint(labels.get(self.tool_progress_mode, ""))
|
||||
|
||||
def _handle_reasoning_command(self, cmd: str):
|
||||
"""Handle /reasoning — manage effort level and display toggle.
|
||||
@@ -5352,6 +5397,28 @@ class HermesCLI:
|
||||
message if isinstance(message, str) else "", images
|
||||
)
|
||||
|
||||
# Expand @ context references (e.g. @file:main.py, @diff, @folder:src/)
|
||||
if isinstance(message, str) and "@" in message:
|
||||
try:
|
||||
from agent.context_references import preprocess_context_references
|
||||
from agent.model_metadata import get_model_context_length
|
||||
_ctx_len = get_model_context_length(
|
||||
self.model, base_url=self.base_url or "", api_key=self.api_key or "")
|
||||
_ctx_result = preprocess_context_references(
|
||||
message, cwd=os.getcwd(), context_length=_ctx_len)
|
||||
if _ctx_result.expanded or _ctx_result.blocked:
|
||||
if _ctx_result.references:
|
||||
_cprint(
|
||||
f" {_DIM}[@ context: {len(_ctx_result.references)} ref(s), "
|
||||
f"{_ctx_result.injected_tokens} tokens]{_RST}")
|
||||
for w in _ctx_result.warnings:
|
||||
_cprint(f" {_DIM}⚠ {w}{_RST}")
|
||||
if _ctx_result.blocked:
|
||||
return "\n".join(_ctx_result.warnings) or "Context injection refused."
|
||||
message = _ctx_result.message
|
||||
except Exception as e:
|
||||
logging.debug("@ context reference expansion failed: %s", e)
|
||||
|
||||
# Add user message to history
|
||||
self.conversation_history.append({"role": "user", "content": message})
|
||||
|
||||
@@ -5850,12 +5917,14 @@ class HermesCLI:
|
||||
"""Run the interactive CLI loop with persistent input at bottom."""
|
||||
self.show_banner()
|
||||
|
||||
# One-line Honcho session indicator (TTY-only, not captured by agent)
|
||||
# One-line Honcho session indicator (TTY-only, not captured by agent).
|
||||
# Only show when the user explicitly configured Honcho for Hermes
|
||||
# (not auto-enabled from a stray HONCHO_API_KEY env var).
|
||||
try:
|
||||
from honcho_integration.client import HonchoClientConfig
|
||||
from agent.display import honcho_session_line, write_tty
|
||||
hcfg = HonchoClientConfig.from_global_config()
|
||||
if hcfg.enabled and hcfg.api_key:
|
||||
if hcfg.enabled and hcfg.api_key and hcfg.explicitly_configured:
|
||||
sname = hcfg.resolve_session_name(session_id=self.session_id)
|
||||
if sname:
|
||||
write_tty(honcho_session_line(hcfg.workspace_id, sname) + "\n")
|
||||
@@ -5877,6 +5946,12 @@ class HermesCLI:
|
||||
_welcome_text = "Welcome to Hermes Agent! Type your message or /help for commands."
|
||||
_welcome_color = "#FFF8DC"
|
||||
self.console.print(f"[{_welcome_color}]{_welcome_text}[/]")
|
||||
if self.preloaded_skills and not self._startup_skills_line_shown:
|
||||
skills_label = ", ".join(self.preloaded_skills)
|
||||
self.console.print(
|
||||
f"[bold {_accent_hex()}]Activated skills:[/] {skills_label}"
|
||||
)
|
||||
self._startup_skills_line_shown = True
|
||||
self.console.print()
|
||||
|
||||
# State for async operation
|
||||
@@ -7267,7 +7342,10 @@ def main(
|
||||
route_label=turn_route["label"],
|
||||
):
|
||||
cli.agent.quiet_mode = True
|
||||
result = cli.agent.run_conversation(query)
|
||||
result = cli.agent.run_conversation(
|
||||
user_message=query,
|
||||
conversation_history=cli.conversation_history,
|
||||
)
|
||||
response = result.get("final_response", "") if isinstance(result, dict) else str(result)
|
||||
if response:
|
||||
print(response)
|
||||
|
||||
+43
-5
@@ -248,6 +248,38 @@ def _recoverable_oneshot_run_at(
|
||||
return None
|
||||
|
||||
|
||||
def _compute_grace_seconds(schedule: dict) -> int:
|
||||
"""Compute how late a job can be and still catch up instead of fast-forwarding.
|
||||
|
||||
Uses half the schedule period, clamped between 120 seconds and 2 hours.
|
||||
This ensures daily jobs can catch up if missed by up to 2 hours,
|
||||
while frequent jobs (every 5-10 min) still fast-forward quickly.
|
||||
"""
|
||||
MIN_GRACE = 120
|
||||
MAX_GRACE = 7200 # 2 hours
|
||||
|
||||
kind = schedule.get("kind")
|
||||
|
||||
if kind == "interval":
|
||||
period_seconds = schedule.get("minutes", 1) * 60
|
||||
grace = period_seconds // 2
|
||||
return max(MIN_GRACE, min(grace, MAX_GRACE))
|
||||
|
||||
if kind == "cron" and HAS_CRONITER:
|
||||
try:
|
||||
now = _hermes_now()
|
||||
cron = croniter(schedule["expr"], now)
|
||||
first = cron.get_next(datetime)
|
||||
second = cron.get_next(datetime)
|
||||
period_seconds = int((second - first).total_seconds())
|
||||
grace = period_seconds // 2
|
||||
return max(MIN_GRACE, min(grace, MAX_GRACE))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return MIN_GRACE
|
||||
|
||||
|
||||
def compute_next_run(schedule: Dict[str, Any], last_run_at: Optional[str] = None) -> Optional[str]:
|
||||
"""
|
||||
Compute the next run time for a schedule.
|
||||
@@ -351,6 +383,10 @@ def create_job(
|
||||
"""
|
||||
parsed_schedule = parse_schedule(schedule)
|
||||
|
||||
# Normalize repeat: treat 0 or negative values as None (infinite)
|
||||
if repeat is not None and repeat <= 0:
|
||||
repeat = None
|
||||
|
||||
# Auto-set repeat=1 for one-shot schedules if not specified
|
||||
if parsed_schedule["kind"] == "once" and repeat is None:
|
||||
repeat = 1
|
||||
@@ -539,7 +575,7 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
|
||||
# Check if we've hit the repeat limit
|
||||
times = job["repeat"].get("times")
|
||||
completed = job["repeat"]["completed"]
|
||||
if times is not None and completed >= times:
|
||||
if times is not None and times > 0 and completed >= times:
|
||||
# Remove the job (limit reached)
|
||||
jobs.pop(i)
|
||||
save_jobs(jobs)
|
||||
@@ -610,16 +646,18 @@ def get_due_jobs() -> List[Dict[str, Any]]:
|
||||
# For recurring jobs, check if the scheduled time is stale
|
||||
# (gateway was down and missed the window). Fast-forward to
|
||||
# the next future occurrence instead of firing a stale run.
|
||||
if kind in ("cron", "interval") and (now - next_run_dt).total_seconds() > 120:
|
||||
# More than 2 minutes late — this is a missed run, not a current one.
|
||||
# Recompute next_run_at to the next future occurrence.
|
||||
grace = _compute_grace_seconds(schedule)
|
||||
if kind in ("cron", "interval") and (now - next_run_dt).total_seconds() > grace:
|
||||
# Job is past its catch-up grace window — this is a stale missed run.
|
||||
# Grace scales with schedule period: daily=2h, hourly=30m, 10min=5m.
|
||||
new_next = compute_next_run(schedule, now.isoformat())
|
||||
if new_next:
|
||||
logger.info(
|
||||
"Job '%s' missed its scheduled time (%s). "
|
||||
"Job '%s' missed its scheduled time (%s, grace=%ds). "
|
||||
"Fast-forwarding to next run: %s",
|
||||
job.get("name", job["id"]),
|
||||
next_run,
|
||||
grace,
|
||||
new_next,
|
||||
)
|
||||
# Update the job in storage
|
||||
|
||||
+12
-6
@@ -80,11 +80,16 @@ def _resolve_delivery_target(job: dict) -> Optional[dict]:
|
||||
}
|
||||
|
||||
if ":" in deliver:
|
||||
platform_name, chat_id = deliver.split(":", 1)
|
||||
platform_name, rest = deliver.split(":", 1)
|
||||
# Check for thread_id suffix (e.g. "telegram:-1003724596514:17")
|
||||
if ":" in rest:
|
||||
chat_id, thread_id = rest.split(":", 1)
|
||||
else:
|
||||
chat_id, thread_id = rest, None
|
||||
return {
|
||||
"platform": platform_name,
|
||||
"chat_id": chat_id,
|
||||
"thread_id": None,
|
||||
"thread_id": thread_id,
|
||||
}
|
||||
|
||||
platform_name = deliver
|
||||
@@ -412,9 +417,10 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
|
||||
result = agent.run_conversation(prompt)
|
||||
|
||||
final_response = result.get("final_response", "")
|
||||
if not final_response:
|
||||
final_response = "(No response generated)"
|
||||
final_response = result.get("final_response", "") or ""
|
||||
# Use a separate variable for log display; keep final_response clean
|
||||
# for delivery logic (empty response = no delivery).
|
||||
logged_response = final_response if final_response else "(No response generated)"
|
||||
|
||||
output = f"""# Cron Job: {job_name}
|
||||
|
||||
@@ -428,7 +434,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
|
||||
## Response
|
||||
|
||||
{final_response}
|
||||
{logged_response}
|
||||
"""
|
||||
|
||||
logger.info("Job '%s' completed successfully", job_name)
|
||||
|
||||
+72
-61
@@ -346,78 +346,89 @@ class HermesAgentLoop:
|
||||
tool_name, turn + 1,
|
||||
)
|
||||
else:
|
||||
# Parse arguments and dispatch
|
||||
# Parse arguments
|
||||
try:
|
||||
args = json.loads(tool_args_raw)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
except json.JSONDecodeError as e:
|
||||
args = None
|
||||
tool_result = json.dumps(
|
||||
{"error": f"Invalid JSON in tool arguments: {e}. Please retry with valid JSON."}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"Invalid JSON: {e}",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.warning(
|
||||
"Invalid JSON in tool call arguments for '%s': %s",
|
||||
tool_name, tool_args_raw[:200],
|
||||
)
|
||||
|
||||
try:
|
||||
if tool_name == "terminal":
|
||||
backend = os.getenv("TERMINAL_ENV", "local")
|
||||
cmd_preview = args.get("command", "")[:80]
|
||||
logger.info(
|
||||
"[%s] $ %s", self.task_id[:8], cmd_preview,
|
||||
)
|
||||
# Dispatch tool only if arguments parsed successfully
|
||||
if args is not None:
|
||||
try:
|
||||
if tool_name == "terminal":
|
||||
backend = os.getenv("TERMINAL_ENV", "local")
|
||||
cmd_preview = args.get("command", "")[:80]
|
||||
logger.info(
|
||||
"[%s] $ %s", self.task_id[:8], cmd_preview,
|
||||
)
|
||||
|
||||
tool_submit_time = _time.monotonic()
|
||||
tool_submit_time = _time.monotonic()
|
||||
|
||||
# Todo tool -- handle locally (needs per-loop TodoStore)
|
||||
if tool_name == "todo":
|
||||
tool_result = _todo_tool(
|
||||
todos=args.get("todos"),
|
||||
merge=args.get("merge", False),
|
||||
store=_todo_store,
|
||||
)
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
elif tool_name == "memory":
|
||||
tool_result = json.dumps({"error": "Memory is not available in RL environments."})
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
elif tool_name == "session_search":
|
||||
tool_result = json.dumps({"error": "Session search is not available in RL environments."})
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
else:
|
||||
# Run tool calls in a thread pool so backends that
|
||||
# use asyncio.run() internally (modal, docker, daytona) get
|
||||
# a clean event loop instead of deadlocking.
|
||||
loop = asyncio.get_event_loop()
|
||||
# Capture current tool_name/args for the lambda
|
||||
_tn, _ta, _tid = tool_name, args, self.task_id
|
||||
tool_result = await loop.run_in_executor(
|
||||
_tool_executor,
|
||||
lambda: handle_function_call(
|
||||
_tn, _ta, task_id=_tid,
|
||||
user_task=_user_task,
|
||||
),
|
||||
)
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
# Todo tool -- handle locally (needs per-loop TodoStore)
|
||||
if tool_name == "todo":
|
||||
tool_result = _todo_tool(
|
||||
todos=args.get("todos"),
|
||||
merge=args.get("merge", False),
|
||||
store=_todo_store,
|
||||
)
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
elif tool_name == "memory":
|
||||
tool_result = json.dumps({"error": "Memory is not available in RL environments."})
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
elif tool_name == "session_search":
|
||||
tool_result = json.dumps({"error": "Session search is not available in RL environments."})
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
else:
|
||||
# Run tool calls in a thread pool so backends that
|
||||
# use asyncio.run() internally (modal, docker, daytona) get
|
||||
# a clean event loop instead of deadlocking.
|
||||
loop = asyncio.get_event_loop()
|
||||
# Capture current tool_name/args for the lambda
|
||||
_tn, _ta, _tid = tool_name, args, self.task_id
|
||||
tool_result = await loop.run_in_executor(
|
||||
_tool_executor,
|
||||
lambda: handle_function_call(
|
||||
_tn, _ta, task_id=_tid,
|
||||
user_task=_user_task,
|
||||
),
|
||||
)
|
||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
||||
|
||||
# Log slow tools and thread pool stats for debugging
|
||||
pool_active = _tool_executor._work_queue.qsize()
|
||||
if tool_elapsed > 30:
|
||||
logger.warning(
|
||||
"[%s] turn %d: %s took %.1fs (pool queue=%d)",
|
||||
self.task_id[:8], turn + 1, tool_name,
|
||||
tool_elapsed, pool_active,
|
||||
# Log slow tools and thread pool stats for debugging
|
||||
pool_active = _tool_executor._work_queue.qsize()
|
||||
if tool_elapsed > 30:
|
||||
logger.warning(
|
||||
"[%s] turn %d: %s took %.1fs (pool queue=%d)",
|
||||
self.task_id[:8], turn + 1, tool_name,
|
||||
tool_elapsed, pool_active,
|
||||
)
|
||||
except Exception as e:
|
||||
tool_result = json.dumps(
|
||||
{"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"{type(e).__name__}: {str(e)}",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.error(
|
||||
"Tool '%s' execution failed on turn %d: %s",
|
||||
tool_name, turn + 1, e,
|
||||
)
|
||||
except Exception as e:
|
||||
tool_result = json.dumps(
|
||||
{"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
|
||||
)
|
||||
tool_errors.append(ToolError(
|
||||
turn=turn + 1, tool_name=tool_name,
|
||||
arguments=tool_args_raw[:200],
|
||||
error=f"{type(e).__name__}: {str(e)}",
|
||||
tool_result=tool_result,
|
||||
))
|
||||
logger.error(
|
||||
"Tool '%s' execution failed on turn %d: %s",
|
||||
tool_name, turn + 1, e,
|
||||
)
|
||||
|
||||
# Also check if the tool returned an error in its JSON result
|
||||
try:
|
||||
|
||||
@@ -10,7 +10,6 @@ The [TOOL_CALLS] token is the bot_token used by Mistral models.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -42,9 +41,6 @@ class MistralToolCallParser(ToolCallParser):
|
||||
# The [TOOL_CALLS] token -- may appear as different strings depending on tokenizer
|
||||
BOT_TOKEN = "[TOOL_CALLS]"
|
||||
|
||||
# Fallback regex for pre-v11 format when JSON parsing fails
|
||||
TOOL_CALL_REGEX = re.compile(r"\[?\s*(\{.*?\})\s*\]?", re.DOTALL)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.BOT_TOKEN not in text:
|
||||
return text, None
|
||||
@@ -71,6 +67,13 @@ class MistralToolCallParser(ToolCallParser):
|
||||
tool_name = raw[:brace_idx].strip()
|
||||
args_str = raw[brace_idx:]
|
||||
|
||||
# Validate and clean the JSON arguments
|
||||
try:
|
||||
parsed_args = json.loads(args_str)
|
||||
args_str = json.dumps(parsed_args, ensure_ascii=False)
|
||||
except json.JSONDecodeError:
|
||||
pass # Keep raw if parsing fails
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=_generate_mistral_id(),
|
||||
@@ -100,13 +103,14 @@ class MistralToolCallParser(ToolCallParser):
|
||||
)
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# Fallback regex extraction
|
||||
match = self.TOOL_CALL_REGEX.findall(first_raw)
|
||||
if match:
|
||||
for raw_json in match:
|
||||
try:
|
||||
tc = json.loads(raw_json)
|
||||
args = tc.get("arguments", {})
|
||||
# Fallback: extract JSON objects using raw_decode
|
||||
decoder = json.JSONDecoder()
|
||||
idx = 0
|
||||
while idx < len(first_raw):
|
||||
try:
|
||||
obj, end_idx = decoder.raw_decode(first_raw, idx)
|
||||
if isinstance(obj, dict) and "name" in obj:
|
||||
args = obj.get("arguments", {})
|
||||
if isinstance(args, dict):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
tool_calls.append(
|
||||
@@ -114,12 +118,13 @@ class MistralToolCallParser(ToolCallParser):
|
||||
id=_generate_mistral_id(),
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc["name"], arguments=args
|
||||
name=obj["name"], arguments=args
|
||||
),
|
||||
)
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
idx = end_idx
|
||||
except json.JSONDecodeError:
|
||||
idx += 1
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
+20
-3
@@ -101,12 +101,16 @@ class SessionResetPolicy:
|
||||
mode: str = "both" # "daily", "idle", "both", or "none"
|
||||
at_hour: int = 4 # Hour for daily reset (0-23, local time)
|
||||
idle_minutes: int = 1440 # Minutes of inactivity before reset (24 hours)
|
||||
notify: bool = True # Send a notification to the user when auto-reset occurs
|
||||
notify_exclude_platforms: tuple = ("api_server", "webhook") # Platforms that don't get reset notifications
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"mode": self.mode,
|
||||
"at_hour": self.at_hour,
|
||||
"idle_minutes": self.idle_minutes,
|
||||
"notify": self.notify,
|
||||
"notify_exclude_platforms": list(self.notify_exclude_platforms),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -115,10 +119,14 @@ class SessionResetPolicy:
|
||||
mode = data.get("mode")
|
||||
at_hour = data.get("at_hour")
|
||||
idle_minutes = data.get("idle_minutes")
|
||||
notify = data.get("notify")
|
||||
exclude = data.get("notify_exclude_platforms")
|
||||
return cls(
|
||||
mode=mode if mode is not None else "both",
|
||||
at_hour=at_hour if at_hour is not None else 4,
|
||||
idle_minutes=idle_minutes if idle_minutes is not None else 1440,
|
||||
notify=notify if notify is not None else True,
|
||||
notify_exclude_platforms=tuple(exclude) if exclude is not None else ("api_server", "webhook"),
|
||||
)
|
||||
|
||||
|
||||
@@ -515,8 +523,13 @@ def load_gateway_config() -> GatewayConfig:
|
||||
os.environ["DISCORD_FREE_RESPONSE_CHANNELS"] = str(frc)
|
||||
if "auto_thread" in discord_cfg and not os.getenv("DISCORD_AUTO_THREAD"):
|
||||
os.environ["DISCORD_AUTO_THREAD"] = str(discord_cfg["auto_thread"]).lower()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to process config.yaml — falling back to .env / gateway.json values. "
|
||||
"Check %s for syntax errors. Error: %s",
|
||||
_home / "config.yaml",
|
||||
e,
|
||||
)
|
||||
|
||||
config = GatewayConfig.from_dict(gw_data)
|
||||
|
||||
@@ -738,6 +751,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
# API Server
|
||||
api_server_enabled = os.getenv("API_SERVER_ENABLED", "").lower() in ("true", "1", "yes")
|
||||
api_server_key = os.getenv("API_SERVER_KEY", "")
|
||||
api_server_cors_origins = os.getenv("API_SERVER_CORS_ORIGINS", "")
|
||||
api_server_port = os.getenv("API_SERVER_PORT")
|
||||
api_server_host = os.getenv("API_SERVER_HOST")
|
||||
if api_server_enabled or api_server_key:
|
||||
@@ -746,6 +760,10 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
config.platforms[Platform.API_SERVER].enabled = True
|
||||
if api_server_key:
|
||||
config.platforms[Platform.API_SERVER].extra["key"] = api_server_key
|
||||
if api_server_cors_origins:
|
||||
origins = [origin.strip() for origin in api_server_cors_origins.split(",") if origin.strip()]
|
||||
if origins:
|
||||
config.platforms[Platform.API_SERVER].extra["cors_origins"] = origins
|
||||
if api_server_port:
|
||||
try:
|
||||
config.platforms[Platform.API_SERVER].extra["port"] = int(api_server_port)
|
||||
@@ -786,4 +804,3 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
+395
-27
@@ -18,10 +18,10 @@ Requires:
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sqlite3
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -54,41 +54,109 @@ def check_api_server_requirements() -> bool:
|
||||
|
||||
class ResponseStore:
|
||||
"""
|
||||
In-memory LRU store for Responses API state.
|
||||
SQLite-backed LRU store for Responses API state.
|
||||
|
||||
Each stored response includes the full internal conversation history
|
||||
(with tool calls and results) so it can be reconstructed on subsequent
|
||||
requests via previous_response_id.
|
||||
|
||||
Persists across gateway restarts. Falls back to in-memory SQLite
|
||||
if the on-disk path is unavailable.
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int = MAX_STORED_RESPONSES):
|
||||
self._store: collections.OrderedDict[str, Dict[str, Any]] = collections.OrderedDict()
|
||||
def __init__(self, max_size: int = MAX_STORED_RESPONSES, db_path: str = None):
|
||||
self._max_size = max_size
|
||||
if db_path is None:
|
||||
try:
|
||||
from hermes_cli.config import get_hermes_home
|
||||
db_path = str(get_hermes_home() / "response_store.db")
|
||||
except Exception:
|
||||
db_path = ":memory:"
|
||||
try:
|
||||
self._conn = sqlite3.connect(db_path, check_same_thread=False)
|
||||
except Exception:
|
||||
self._conn = sqlite3.connect(":memory:", check_same_thread=False)
|
||||
self._conn.execute("PRAGMA journal_mode=WAL")
|
||||
self._conn.execute(
|
||||
"""CREATE TABLE IF NOT EXISTS responses (
|
||||
response_id TEXT PRIMARY KEY,
|
||||
data TEXT NOT NULL,
|
||||
accessed_at REAL NOT NULL
|
||||
)"""
|
||||
)
|
||||
self._conn.execute(
|
||||
"""CREATE TABLE IF NOT EXISTS conversations (
|
||||
name TEXT PRIMARY KEY,
|
||||
response_id TEXT NOT NULL
|
||||
)"""
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def get(self, response_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Retrieve a stored response by ID (moves to end for LRU)."""
|
||||
if response_id in self._store:
|
||||
self._store.move_to_end(response_id)
|
||||
return self._store[response_id]
|
||||
return None
|
||||
"""Retrieve a stored response by ID (updates access time for LRU)."""
|
||||
row = self._conn.execute(
|
||||
"SELECT data FROM responses WHERE response_id = ?", (response_id,)
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
import time
|
||||
self._conn.execute(
|
||||
"UPDATE responses SET accessed_at = ? WHERE response_id = ?",
|
||||
(time.time(), response_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
return json.loads(row[0])
|
||||
|
||||
def put(self, response_id: str, data: Dict[str, Any]) -> None:
|
||||
"""Store a response, evicting the oldest if at capacity."""
|
||||
if response_id in self._store:
|
||||
self._store.move_to_end(response_id)
|
||||
self._store[response_id] = data
|
||||
while len(self._store) > self._max_size:
|
||||
self._store.popitem(last=False)
|
||||
import time
|
||||
self._conn.execute(
|
||||
"INSERT OR REPLACE INTO responses (response_id, data, accessed_at) VALUES (?, ?, ?)",
|
||||
(response_id, json.dumps(data, default=str), time.time()),
|
||||
)
|
||||
# Evict oldest entries beyond max_size
|
||||
count = self._conn.execute("SELECT COUNT(*) FROM responses").fetchone()[0]
|
||||
if count > self._max_size:
|
||||
self._conn.execute(
|
||||
"DELETE FROM responses WHERE response_id IN "
|
||||
"(SELECT response_id FROM responses ORDER BY accessed_at ASC LIMIT ?)",
|
||||
(count - self._max_size,),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def delete(self, response_id: str) -> bool:
|
||||
"""Remove a response from the store. Returns True if found and deleted."""
|
||||
if response_id in self._store:
|
||||
del self._store[response_id]
|
||||
return True
|
||||
return False
|
||||
cursor = self._conn.execute(
|
||||
"DELETE FROM responses WHERE response_id = ?", (response_id,)
|
||||
)
|
||||
self._conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def get_conversation(self, name: str) -> Optional[str]:
|
||||
"""Get the latest response_id for a conversation name."""
|
||||
row = self._conn.execute(
|
||||
"SELECT response_id FROM conversations WHERE name = ?", (name,)
|
||||
).fetchone()
|
||||
return row[0] if row else None
|
||||
|
||||
def set_conversation(self, name: str, response_id: str) -> None:
|
||||
"""Map a conversation name to its latest response_id."""
|
||||
self._conn.execute(
|
||||
"INSERT OR REPLACE INTO conversations (name, response_id) VALUES (?, ?)",
|
||||
(name, response_id),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the database connection."""
|
||||
try:
|
||||
self._conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._store)
|
||||
row = self._conn.execute("SELECT COUNT(*) FROM responses").fetchone()
|
||||
return row[0] if row else 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -96,7 +164,6 @@ class ResponseStore:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CORS_HEADERS = {
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "GET, POST, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Authorization, Content-Type",
|
||||
}
|
||||
@@ -105,11 +172,23 @@ _CORS_HEADERS = {
|
||||
if AIOHTTP_AVAILABLE:
|
||||
@web.middleware
|
||||
async def cors_middleware(request, handler):
|
||||
"""Add CORS headers to every response; handle OPTIONS preflight."""
|
||||
"""Add CORS headers for explicitly allowed origins; handle OPTIONS preflight."""
|
||||
adapter = request.app.get("api_server_adapter")
|
||||
origin = request.headers.get("Origin", "")
|
||||
cors_headers = None
|
||||
if adapter is not None:
|
||||
if not adapter._origin_allowed(origin):
|
||||
return web.Response(status=403)
|
||||
cors_headers = adapter._cors_headers_for_origin(origin)
|
||||
|
||||
if request.method == "OPTIONS":
|
||||
return web.Response(status=200, headers=_CORS_HEADERS)
|
||||
if cors_headers is None:
|
||||
return web.Response(status=403)
|
||||
return web.Response(status=200, headers=cors_headers)
|
||||
|
||||
response = await handler(request)
|
||||
response.headers.update(_CORS_HEADERS)
|
||||
if cors_headers is not None:
|
||||
response.headers.update(cors_headers)
|
||||
return response
|
||||
else:
|
||||
cors_middleware = None # type: ignore[assignment]
|
||||
@@ -129,12 +208,56 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
self._host: str = extra.get("host", os.getenv("API_SERVER_HOST", DEFAULT_HOST))
|
||||
self._port: int = int(extra.get("port", os.getenv("API_SERVER_PORT", str(DEFAULT_PORT))))
|
||||
self._api_key: str = extra.get("key", os.getenv("API_SERVER_KEY", ""))
|
||||
self._cors_origins: tuple[str, ...] = self._parse_cors_origins(
|
||||
extra.get("cors_origins", os.getenv("API_SERVER_CORS_ORIGINS", "")),
|
||||
)
|
||||
self._app: Optional["web.Application"] = None
|
||||
self._runner: Optional["web.AppRunner"] = None
|
||||
self._site: Optional["web.TCPSite"] = None
|
||||
self._response_store = ResponseStore()
|
||||
# Conversation name → latest response_id mapping
|
||||
self._conversations: Dict[str, str] = {}
|
||||
|
||||
@staticmethod
|
||||
def _parse_cors_origins(value: Any) -> tuple[str, ...]:
|
||||
"""Normalize configured CORS origins into a stable tuple."""
|
||||
if not value:
|
||||
return ()
|
||||
|
||||
if isinstance(value, str):
|
||||
items = value.split(",")
|
||||
elif isinstance(value, (list, tuple, set)):
|
||||
items = value
|
||||
else:
|
||||
items = [str(value)]
|
||||
|
||||
return tuple(str(item).strip() for item in items if str(item).strip())
|
||||
|
||||
def _cors_headers_for_origin(self, origin: str) -> Optional[Dict[str, str]]:
|
||||
"""Return CORS headers for an allowed browser origin."""
|
||||
if not origin or not self._cors_origins:
|
||||
return None
|
||||
|
||||
if "*" in self._cors_origins:
|
||||
headers = dict(_CORS_HEADERS)
|
||||
headers["Access-Control-Allow-Origin"] = "*"
|
||||
return headers
|
||||
|
||||
if origin not in self._cors_origins:
|
||||
return None
|
||||
|
||||
headers = dict(_CORS_HEADERS)
|
||||
headers["Access-Control-Allow-Origin"] = origin
|
||||
headers["Vary"] = "Origin"
|
||||
return headers
|
||||
|
||||
def _origin_allowed(self, origin: str) -> bool:
|
||||
"""Allow non-browser clients and explicitly configured browser origins."""
|
||||
if not origin:
|
||||
return True
|
||||
|
||||
if not self._cors_origins:
|
||||
return False
|
||||
|
||||
return "*" in self._cors_origins or origin in self._cors_origins
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Auth helper
|
||||
@@ -463,7 +586,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
|
||||
# Resolve conversation name to latest response_id
|
||||
if conversation:
|
||||
previous_response_id = self._conversations.get(conversation)
|
||||
previous_response_id = self._response_store.get_conversation(conversation)
|
||||
# No error if conversation doesn't exist yet — it's a new conversation
|
||||
|
||||
# Normalize input to message list
|
||||
@@ -586,7 +709,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
# Update conversation mapping so the next request with the same
|
||||
# conversation name automatically chains to this response
|
||||
if conversation:
|
||||
self._conversations[conversation] = response_id
|
||||
self._response_store.set_conversation(conversation, response_id)
|
||||
|
||||
return web.json_response(response_data)
|
||||
|
||||
@@ -630,6 +753,241 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
"deleted": True,
|
||||
})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cron jobs API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# Check cron module availability once (not per-request)
|
||||
_CRON_AVAILABLE = False
|
||||
try:
|
||||
from cron.jobs import (
|
||||
list_jobs as _cron_list,
|
||||
get_job as _cron_get,
|
||||
create_job as _cron_create,
|
||||
update_job as _cron_update,
|
||||
remove_job as _cron_remove,
|
||||
pause_job as _cron_pause,
|
||||
resume_job as _cron_resume,
|
||||
trigger_job as _cron_trigger,
|
||||
)
|
||||
_CRON_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
_JOB_ID_RE = __import__("re").compile(r"[a-f0-9]{12}")
|
||||
# Allowed fields for update — prevents clients injecting arbitrary keys
|
||||
_UPDATE_ALLOWED_FIELDS = {"name", "schedule", "prompt", "deliver", "skills", "skill", "repeat", "enabled"}
|
||||
_MAX_NAME_LENGTH = 200
|
||||
_MAX_PROMPT_LENGTH = 5000
|
||||
|
||||
def _check_jobs_available(self) -> Optional["web.Response"]:
|
||||
"""Return error response if cron module isn't available."""
|
||||
if not self._CRON_AVAILABLE:
|
||||
return web.json_response(
|
||||
{"error": "Cron module not available"}, status=501,
|
||||
)
|
||||
return None
|
||||
|
||||
def _check_job_id(self, request: "web.Request") -> tuple:
|
||||
"""Validate and extract job_id. Returns (job_id, error_response)."""
|
||||
job_id = request.match_info["job_id"]
|
||||
if not self._JOB_ID_RE.fullmatch(job_id):
|
||||
return job_id, web.json_response(
|
||||
{"error": "Invalid job ID format"}, status=400,
|
||||
)
|
||||
return job_id, None
|
||||
|
||||
async def _handle_list_jobs(self, request: "web.Request") -> "web.Response":
|
||||
"""GET /api/jobs — list all cron jobs."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
return auth_err
|
||||
cron_err = self._check_jobs_available()
|
||||
if cron_err:
|
||||
return cron_err
|
||||
try:
|
||||
include_disabled = request.query.get("include_disabled", "").lower() in ("true", "1")
|
||||
jobs = self._cron_list(include_disabled=include_disabled)
|
||||
return web.json_response({"jobs": jobs})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_create_job(self, request: "web.Request") -> "web.Response":
|
||||
"""POST /api/jobs — create a new cron job."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
return auth_err
|
||||
cron_err = self._check_jobs_available()
|
||||
if cron_err:
|
||||
return cron_err
|
||||
try:
|
||||
body = await request.json()
|
||||
name = (body.get("name") or "").strip()
|
||||
schedule = (body.get("schedule") or "").strip()
|
||||
prompt = body.get("prompt", "")
|
||||
deliver = body.get("deliver", "local")
|
||||
skills = body.get("skills")
|
||||
repeat = body.get("repeat")
|
||||
|
||||
if not name:
|
||||
return web.json_response({"error": "Name is required"}, status=400)
|
||||
if len(name) > self._MAX_NAME_LENGTH:
|
||||
return web.json_response(
|
||||
{"error": f"Name must be ≤ {self._MAX_NAME_LENGTH} characters"}, status=400,
|
||||
)
|
||||
if not schedule:
|
||||
return web.json_response({"error": "Schedule is required"}, status=400)
|
||||
if len(prompt) > self._MAX_PROMPT_LENGTH:
|
||||
return web.json_response(
|
||||
{"error": f"Prompt must be ≤ {self._MAX_PROMPT_LENGTH} characters"}, status=400,
|
||||
)
|
||||
if repeat is not None and (not isinstance(repeat, int) or repeat < 1):
|
||||
return web.json_response({"error": "Repeat must be a positive integer"}, status=400)
|
||||
|
||||
kwargs = {
|
||||
"prompt": prompt,
|
||||
"schedule": schedule,
|
||||
"name": name,
|
||||
"deliver": deliver,
|
||||
}
|
||||
if skills:
|
||||
kwargs["skills"] = skills
|
||||
if repeat is not None:
|
||||
kwargs["repeat"] = repeat
|
||||
|
||||
job = self._cron_create(**kwargs)
|
||||
return web.json_response({"job": job})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_get_job(self, request: "web.Request") -> "web.Response":
|
||||
"""GET /api/jobs/{job_id} — get a single cron job."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
return auth_err
|
||||
cron_err = self._check_jobs_available()
|
||||
if cron_err:
|
||||
return cron_err
|
||||
job_id, id_err = self._check_job_id(request)
|
||||
if id_err:
|
||||
return id_err
|
||||
try:
|
||||
job = self._cron_get(job_id)
|
||||
if not job:
|
||||
return web.json_response({"error": "Job not found"}, status=404)
|
||||
return web.json_response({"job": job})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_update_job(self, request: "web.Request") -> "web.Response":
|
||||
"""PATCH /api/jobs/{job_id} — update a cron job."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
return auth_err
|
||||
cron_err = self._check_jobs_available()
|
||||
if cron_err:
|
||||
return cron_err
|
||||
job_id, id_err = self._check_job_id(request)
|
||||
if id_err:
|
||||
return id_err
|
||||
try:
|
||||
body = await request.json()
|
||||
# Whitelist allowed fields to prevent arbitrary key injection
|
||||
sanitized = {k: v for k, v in body.items() if k in self._UPDATE_ALLOWED_FIELDS}
|
||||
if not sanitized:
|
||||
return web.json_response({"error": "No valid fields to update"}, status=400)
|
||||
# Validate lengths if present
|
||||
if "name" in sanitized and len(sanitized["name"]) > self._MAX_NAME_LENGTH:
|
||||
return web.json_response(
|
||||
{"error": f"Name must be ≤ {self._MAX_NAME_LENGTH} characters"}, status=400,
|
||||
)
|
||||
if "prompt" in sanitized and len(sanitized["prompt"]) > self._MAX_PROMPT_LENGTH:
|
||||
return web.json_response(
|
||||
{"error": f"Prompt must be ≤ {self._MAX_PROMPT_LENGTH} characters"}, status=400,
|
||||
)
|
||||
job = self._cron_update(job_id, sanitized)
|
||||
if not job:
|
||||
return web.json_response({"error": "Job not found"}, status=404)
|
||||
return web.json_response({"job": job})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_delete_job(self, request: "web.Request") -> "web.Response":
|
||||
"""DELETE /api/jobs/{job_id} — delete a cron job."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
return auth_err
|
||||
cron_err = self._check_jobs_available()
|
||||
if cron_err:
|
||||
return cron_err
|
||||
job_id, id_err = self._check_job_id(request)
|
||||
if id_err:
|
||||
return id_err
|
||||
try:
|
||||
success = self._cron_remove(job_id)
|
||||
if not success:
|
||||
return web.json_response({"error": "Job not found"}, status=404)
|
||||
return web.json_response({"ok": True})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_pause_job(self, request: "web.Request") -> "web.Response":
|
||||
"""POST /api/jobs/{job_id}/pause — pause a cron job."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
return auth_err
|
||||
cron_err = self._check_jobs_available()
|
||||
if cron_err:
|
||||
return cron_err
|
||||
job_id, id_err = self._check_job_id(request)
|
||||
if id_err:
|
||||
return id_err
|
||||
try:
|
||||
job = self._cron_pause(job_id)
|
||||
if not job:
|
||||
return web.json_response({"error": "Job not found"}, status=404)
|
||||
return web.json_response({"job": job})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_resume_job(self, request: "web.Request") -> "web.Response":
|
||||
"""POST /api/jobs/{job_id}/resume — resume a paused cron job."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
return auth_err
|
||||
cron_err = self._check_jobs_available()
|
||||
if cron_err:
|
||||
return cron_err
|
||||
job_id, id_err = self._check_job_id(request)
|
||||
if id_err:
|
||||
return id_err
|
||||
try:
|
||||
job = self._cron_resume(job_id)
|
||||
if not job:
|
||||
return web.json_response({"error": "Job not found"}, status=404)
|
||||
return web.json_response({"job": job})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def _handle_run_job(self, request: "web.Request") -> "web.Response":
|
||||
"""POST /api/jobs/{job_id}/run — trigger immediate execution."""
|
||||
auth_err = self._check_auth(request)
|
||||
if auth_err:
|
||||
return auth_err
|
||||
cron_err = self._check_jobs_available()
|
||||
if cron_err:
|
||||
return cron_err
|
||||
job_id, id_err = self._check_job_id(request)
|
||||
if id_err:
|
||||
return id_err
|
||||
try:
|
||||
job = self._cron_trigger(job_id)
|
||||
if not job:
|
||||
return web.json_response({"error": "Job not found"}, status=404)
|
||||
return web.json_response({"job": job})
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Output extraction helper
|
||||
# ------------------------------------------------------------------
|
||||
@@ -733,12 +1091,22 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
|
||||
try:
|
||||
self._app = web.Application(middlewares=[cors_middleware])
|
||||
self._app["api_server_adapter"] = self
|
||||
self._app.router.add_get("/health", self._handle_health)
|
||||
self._app.router.add_get("/v1/models", self._handle_models)
|
||||
self._app.router.add_post("/v1/chat/completions", self._handle_chat_completions)
|
||||
self._app.router.add_post("/v1/responses", self._handle_responses)
|
||||
self._app.router.add_get("/v1/responses/{response_id}", self._handle_get_response)
|
||||
self._app.router.add_delete("/v1/responses/{response_id}", self._handle_delete_response)
|
||||
# Cron jobs management API
|
||||
self._app.router.add_get("/api/jobs", self._handle_list_jobs)
|
||||
self._app.router.add_post("/api/jobs", self._handle_create_job)
|
||||
self._app.router.add_get("/api/jobs/{job_id}", self._handle_get_job)
|
||||
self._app.router.add_patch("/api/jobs/{job_id}", self._handle_update_job)
|
||||
self._app.router.add_delete("/api/jobs/{job_id}", self._handle_delete_job)
|
||||
self._app.router.add_post("/api/jobs/{job_id}/pause", self._handle_pause_job)
|
||||
self._app.router.add_post("/api/jobs/{job_id}/resume", self._handle_resume_job)
|
||||
self._app.router.add_post("/api/jobs/{job_id}/run", self._handle_run_job)
|
||||
|
||||
self._runner = web.AppRunner(self._app)
|
||||
await self._runner.setup()
|
||||
|
||||
@@ -504,6 +504,14 @@ class BasePlatformAdapter(ABC):
|
||||
metadata: optional dict with platform-specific context (e.g. thread_id for Slack).
|
||||
"""
|
||||
pass
|
||||
|
||||
async def stop_typing(self, chat_id: str) -> None:
|
||||
"""Stop a persistent typing indicator (if the platform uses one).
|
||||
|
||||
Override in subclasses that start background typing loops.
|
||||
Default is a no-op for platforms with one-shot typing indicators.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def send_image(
|
||||
self,
|
||||
@@ -713,7 +721,7 @@ class BasePlatformAdapter(ABC):
|
||||
# Extract MEDIA:<path> tags, allowing optional whitespace after the colon
|
||||
# and quoted/backticked paths for LLM-formatted outputs.
|
||||
media_pattern = re.compile(
|
||||
r'''[`"']?MEDIA:\s*(?P<path>`[^`\n]+`|"[^"\n]+"|'[^'\n]+'|\S+)[`"']?'''
|
||||
r'''[`"']?MEDIA:\s*(?P<path>`[^`\n]+`|"[^"\n]+"|'[^'\n]+'|(?:~/|/)\S+(?:[^\S\n]+\S+)*?\.(?:png|jpe?g|gif|webp|mp4|mov|avi|mkv|webm|ogg|opus|mp3|wav|m4a)(?=[\s`"',;:)\]}]|$)|\S+)[`"']?'''
|
||||
)
|
||||
for match in media_pattern.finditer(content):
|
||||
path = match.group("path").strip()
|
||||
|
||||
+134
-13
@@ -43,6 +43,8 @@ from pathlib import Path as _Path
|
||||
sys.path.insert(0, str(_Path(__file__).resolve().parents[2]))
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
import re
|
||||
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
@@ -50,6 +52,8 @@ from gateway.platforms.base import (
|
||||
SendResult,
|
||||
cache_image_from_url,
|
||||
cache_audio_from_url,
|
||||
cache_document_from_bytes,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
|
||||
|
||||
@@ -439,6 +443,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# in those threads don't require @mention. Persisted to disk so the
|
||||
# set survives gateway restarts.
|
||||
self._bot_participated_threads: set = self._load_participated_threads()
|
||||
# Persistent typing indicator loops per channel (DMs don't reliably
|
||||
# show the standard typing gateway event for bots)
|
||||
self._typing_tasks: Dict[str, asyncio.Task] = {}
|
||||
# Cap to prevent unbounded growth (Discord threads get archived).
|
||||
self._MAX_TRACKED_THREADS = 500
|
||||
|
||||
@@ -524,6 +531,11 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if message.author == self._client.user:
|
||||
return
|
||||
|
||||
# Ignore Discord system messages (thread renames, pins, member joins, etc.)
|
||||
# Allow both default and reply types — replies have a distinct MessageType.
|
||||
if message.type not in (discord.MessageType.default, discord.MessageType.reply):
|
||||
return
|
||||
|
||||
# Bot message filtering (DISCORD_ALLOW_BOTS):
|
||||
# "none" — ignore all other bots (default)
|
||||
# "mentions" — accept bot messages only when they @mention us
|
||||
@@ -1239,14 +1251,48 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
return await super().send_document(chat_id, file_path, caption, file_name, reply_to, metadata=metadata)
|
||||
|
||||
async def send_typing(self, chat_id: str, metadata=None) -> None:
|
||||
"""Send typing indicator."""
|
||||
if self._client:
|
||||
"""Start a persistent typing indicator for a channel.
|
||||
|
||||
Discord's TYPING_START gateway event is unreliable in DMs for bots.
|
||||
Instead, start a background loop that hits the typing endpoint every
|
||||
8 seconds (typing indicator lasts ~10s). The loop is cancelled when
|
||||
stop_typing() is called (after the response is sent).
|
||||
"""
|
||||
if not self._client:
|
||||
return
|
||||
# Don't start a duplicate loop
|
||||
if chat_id in self._typing_tasks:
|
||||
return
|
||||
|
||||
async def _typing_loop() -> None:
|
||||
try:
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if channel:
|
||||
await channel.typing()
|
||||
except Exception:
|
||||
pass # Ignore typing indicator failures
|
||||
while True:
|
||||
try:
|
||||
route = discord.http.Route(
|
||||
"POST", "/channels/{channel_id}/typing",
|
||||
channel_id=chat_id,
|
||||
)
|
||||
await self._client.http.request(route)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
logger.debug("Discord typing indicator failed for %s: %s", chat_id, e)
|
||||
return
|
||||
await asyncio.sleep(8)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self._typing_tasks[chat_id] = asyncio.create_task(_typing_loop())
|
||||
|
||||
async def stop_typing(self, chat_id: str) -> None:
|
||||
"""Stop the persistent typing indicator for a channel."""
|
||||
task = self._typing_tasks.pop(chat_id, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
|
||||
"""Get information about a Discord channel."""
|
||||
@@ -1500,7 +1546,17 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
def _build_slash_event(self, interaction: discord.Interaction, text: str) -> MessageEvent:
|
||||
"""Build a MessageEvent from a Discord slash command interaction."""
|
||||
is_dm = isinstance(interaction.channel, discord.DMChannel)
|
||||
chat_type = "dm" if is_dm else "group"
|
||||
is_thread = isinstance(interaction.channel, discord.Thread)
|
||||
thread_id = None
|
||||
|
||||
if is_dm:
|
||||
chat_type = "dm"
|
||||
elif is_thread:
|
||||
chat_type = "thread"
|
||||
thread_id = str(interaction.channel_id)
|
||||
else:
|
||||
chat_type = "group"
|
||||
|
||||
chat_name = ""
|
||||
if not is_dm and hasattr(interaction.channel, "name"):
|
||||
chat_name = interaction.channel.name
|
||||
@@ -1516,6 +1572,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
chat_type=chat_type,
|
||||
user_id=str(interaction.user.id),
|
||||
user_name=interaction.user.display_name,
|
||||
thread_id=thread_id,
|
||||
chat_topic=chat_topic,
|
||||
)
|
||||
|
||||
@@ -1902,7 +1959,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
elif att.content_type.startswith("audio/"):
|
||||
msg_type = MessageType.AUDIO
|
||||
else:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
doc_ext = ""
|
||||
if att.filename:
|
||||
_, doc_ext = os.path.splitext(att.filename)
|
||||
doc_ext = doc_ext.lower()
|
||||
if doc_ext in SUPPORTED_DOCUMENT_TYPES:
|
||||
msg_type = MessageType.DOCUMENT
|
||||
break
|
||||
|
||||
# When auto-threading kicked in, route responses to the new thread
|
||||
@@ -1939,6 +2001,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# vision tool can access them reliably (Discord CDN URLs can expire).
|
||||
media_urls = []
|
||||
media_types = []
|
||||
pending_text_injection: Optional[str] = None
|
||||
for att in message.attachments:
|
||||
content_type = att.content_type or "unknown"
|
||||
if content_type.startswith("image/"):
|
||||
@@ -1970,12 +2033,70 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
media_urls.append(att.url)
|
||||
media_types.append(content_type)
|
||||
else:
|
||||
# Other attachments: keep the original URL
|
||||
media_urls.append(att.url)
|
||||
media_types.append(content_type)
|
||||
# Document attachments: download, cache, and optionally inject text
|
||||
ext = ""
|
||||
if att.filename:
|
||||
_, ext = os.path.splitext(att.filename)
|
||||
ext = ext.lower()
|
||||
if not ext and content_type:
|
||||
mime_to_ext = {v: k for k, v in SUPPORTED_DOCUMENT_TYPES.items()}
|
||||
ext = mime_to_ext.get(content_type, "")
|
||||
if ext not in SUPPORTED_DOCUMENT_TYPES:
|
||||
logger.warning(
|
||||
"[Discord] Unsupported document type '%s' (%s), skipping",
|
||||
ext or "unknown", content_type,
|
||||
)
|
||||
else:
|
||||
MAX_DOC_BYTES = 20 * 1024 * 1024
|
||||
if att.size and att.size > MAX_DOC_BYTES:
|
||||
logger.warning(
|
||||
"[Discord] Document too large (%s bytes), skipping: %s",
|
||||
att.size, att.filename,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
import aiohttp
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
att.url,
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"HTTP {resp.status}")
|
||||
raw_bytes = await resp.read()
|
||||
cached_path = cache_document_from_bytes(
|
||||
raw_bytes, att.filename or f"document{ext}"
|
||||
)
|
||||
doc_mime = SUPPORTED_DOCUMENT_TYPES[ext]
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(doc_mime)
|
||||
logger.info("[Discord] Cached user document: %s", cached_path)
|
||||
# Inject text content for .txt/.md files (capped at 100 KB)
|
||||
MAX_TEXT_INJECT_BYTES = 100 * 1024
|
||||
if ext in (".md", ".txt") and len(raw_bytes) <= MAX_TEXT_INJECT_BYTES:
|
||||
try:
|
||||
text_content = raw_bytes.decode("utf-8")
|
||||
display_name = att.filename or f"document{ext}"
|
||||
display_name = re.sub(r'[^\w.\- ]', '_', display_name)
|
||||
injection = f"[Content of {display_name}]:\n{text_content}"
|
||||
if pending_text_injection:
|
||||
pending_text_injection = f"{pending_text_injection}\n\n{injection}"
|
||||
else:
|
||||
pending_text_injection = injection
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"[Discord] Failed to cache document %s: %s",
|
||||
att.filename, e, exc_info=True,
|
||||
)
|
||||
|
||||
event_text = message.content
|
||||
if pending_text_injection:
|
||||
event_text = f"{pending_text_injection}\n\n{event_text}" if event_text else pending_text_injection
|
||||
|
||||
event = MessageEvent(
|
||||
text=message.content,
|
||||
text=event_text,
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=message,
|
||||
|
||||
@@ -230,7 +230,7 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
# Mark all existing messages as seen so we only process new ones
|
||||
imap.select("INBOX")
|
||||
status, data = imap.uid("search", None, "ALL")
|
||||
if status == "OK" and data[0]:
|
||||
if status == "OK" and data and data[0]:
|
||||
for uid in data[0].split():
|
||||
self._seen_uids.add(uid)
|
||||
imap.logout()
|
||||
@@ -295,7 +295,7 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
imap.select("INBOX")
|
||||
|
||||
status, data = imap.uid("search", None, "UNSEEN")
|
||||
if status != "OK" or not data[0]:
|
||||
if status != "OK" or not data or not data[0]:
|
||||
imap.logout()
|
||||
return results
|
||||
|
||||
|
||||
@@ -103,6 +103,23 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
self._dm_rooms: Dict[str, bool] = {}
|
||||
# Set of room IDs we've joined
|
||||
self._joined_rooms: Set[str] = set()
|
||||
# Event deduplication (bounded deque keeps newest entries)
|
||||
from collections import deque
|
||||
self._processed_events: deque = deque(maxlen=1000)
|
||||
self._processed_events_set: set = set()
|
||||
|
||||
def _is_duplicate_event(self, event_id) -> bool:
|
||||
"""Return True if this event was already processed. Tracks the ID otherwise."""
|
||||
if not event_id:
|
||||
return False
|
||||
if event_id in self._processed_events_set:
|
||||
return True
|
||||
if len(self._processed_events) == self._processed_events.maxlen:
|
||||
evicted = self._processed_events[0]
|
||||
self._processed_events_set.discard(evicted)
|
||||
self._processed_events.append(event_id)
|
||||
self._processed_events_set.add(event_id)
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Required overrides
|
||||
@@ -188,7 +205,6 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
|
||||
# Register event callbacks.
|
||||
client.add_event_callback(self._on_room_message, nio.RoomMessageText)
|
||||
client.add_event_callback(self._on_room_message_media, nio.RoomMessageMedia)
|
||||
client.add_event_callback(self._on_room_message_media, nio.RoomMessageImage)
|
||||
client.add_event_callback(self._on_room_message_media, nio.RoomMessageAudio)
|
||||
client.add_event_callback(self._on_room_message_media, nio.RoomMessageVideo)
|
||||
@@ -559,6 +575,10 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
if event.sender == self._user_id:
|
||||
return
|
||||
|
||||
# Deduplicate by event ID (nio can fire the same event more than once).
|
||||
if self._is_duplicate_event(getattr(event, "event_id", None)):
|
||||
return
|
||||
|
||||
# Startup grace: ignore old messages from initial sync.
|
||||
event_ts = getattr(event, "server_timestamp", 0) / 1000.0
|
||||
if event_ts and event_ts < self._startup_ts - _STARTUP_GRACE_SECONDS:
|
||||
@@ -648,6 +668,10 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
if event.sender == self._user_id:
|
||||
return
|
||||
|
||||
# Deduplicate by event ID.
|
||||
if self._is_duplicate_event(getattr(event, "event_id", None)):
|
||||
return
|
||||
|
||||
# Startup grace.
|
||||
event_ts = getattr(event, "server_timestamp", 0) / 1000.0
|
||||
if event_ts and event_ts < self._startup_ts - _STARTUP_GRACE_SECONDS:
|
||||
@@ -681,6 +705,24 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
elif event_mimetype:
|
||||
media_type = event_mimetype
|
||||
|
||||
# For images, download and cache locally so vision tools can access them.
|
||||
# Matrix MXC URLs require authentication, so direct URL access fails.
|
||||
cached_path = None
|
||||
if msg_type == MessageType.PHOTO and url:
|
||||
try:
|
||||
ext_map = {
|
||||
"image/jpeg": ".jpg", "image/png": ".png",
|
||||
"image/gif": ".gif", "image/webp": ".webp",
|
||||
}
|
||||
ext = ext_map.get(event_mimetype, ".jpg")
|
||||
download_resp = await self._client.download(url)
|
||||
if isinstance(download_resp, nio.DownloadResponse):
|
||||
from gateway.platforms.base import cache_image_from_bytes
|
||||
cached_path = cache_image_from_bytes(download_resp.body, ext=ext)
|
||||
logger.info("[Matrix] Cached user image at %s", cached_path)
|
||||
except Exception as e:
|
||||
logger.warning("[Matrix] Failed to cache image: %s", e)
|
||||
|
||||
is_dm = self._dm_rooms.get(room.room_id, False)
|
||||
if not is_dm and room.member_count == 2:
|
||||
is_dm = True
|
||||
@@ -701,14 +743,18 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
# Use cached local path for images, HTTP URL for other media types
|
||||
media_urls = [cached_path] if cached_path else ([http_url] if http_url else None)
|
||||
media_types = [media_type] if media_urls else None
|
||||
|
||||
msg_event = MessageEvent(
|
||||
text=body,
|
||||
message_type=msg_type,
|
||||
source=source,
|
||||
raw_message=getattr(event, "source", {}),
|
||||
message_id=event.event_id,
|
||||
media_urls=[http_url] if http_url else None,
|
||||
media_types=[media_type] if http_url else None,
|
||||
media_urls=media_urls,
|
||||
media_types=media_types,
|
||||
)
|
||||
|
||||
await self.handle_message(msg_event)
|
||||
|
||||
@@ -580,6 +580,24 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
# For DMs, user_id is sufficient. For channels, check for @mention.
|
||||
message_text = post.get("message", "")
|
||||
|
||||
# Mention-only mode: skip channel messages that don't @mention the bot.
|
||||
# DMs (type "D") are always processed.
|
||||
if channel_type_raw != "D":
|
||||
mention_patterns = [
|
||||
f"@{self._bot_username}",
|
||||
f"@{self._bot_user_id}",
|
||||
]
|
||||
has_mention = any(
|
||||
pattern.lower() in message_text.lower()
|
||||
for pattern in mention_patterns
|
||||
)
|
||||
if not has_mention:
|
||||
logger.debug(
|
||||
"Mattermost: skipping non-DM message without @mention (channel=%s)",
|
||||
channel_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Resolve sender info.
|
||||
sender_id = post.get("user_id", "")
|
||||
sender_name = data.get("sender_name", "").lstrip("@") or sender_id
|
||||
|
||||
@@ -478,7 +478,7 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
if any(mt.startswith("audio/") for mt in media_types):
|
||||
msg_type = MessageType.VOICE
|
||||
elif any(mt.startswith("image/") for mt in media_types):
|
||||
msg_type = MessageType.IMAGE
|
||||
msg_type = MessageType.PHOTO
|
||||
|
||||
# Parse timestamp from envelope data (milliseconds since epoch)
|
||||
ts_ms = envelope_data.get("timestamp", 0)
|
||||
@@ -519,6 +519,13 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
if not result:
|
||||
return None, ""
|
||||
|
||||
# Handle dict response (signal-cli returns {"data": "base64..."})
|
||||
if isinstance(result, dict):
|
||||
result = result.get("data")
|
||||
if not result:
|
||||
logger.warning("Signal: attachment response missing 'data' key")
|
||||
return None, ""
|
||||
|
||||
# Result is base64-encoded file content
|
||||
raw_data = base64.b64decode(result)
|
||||
ext = _guess_extension(raw_data)
|
||||
|
||||
@@ -130,6 +130,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
self._token_lock_identity: Optional[str] = None
|
||||
self._polling_error_task: Optional[asyncio.Task] = None
|
||||
self._polling_conflict_count: int = 0
|
||||
self._polling_network_error_count: int = 0
|
||||
self._polling_error_callback_ref = None
|
||||
|
||||
@staticmethod
|
||||
@@ -141,6 +142,80 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
or "another bot instance is running" in text
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _looks_like_network_error(error: Exception) -> bool:
|
||||
"""Return True for transient network errors that warrant a reconnect attempt."""
|
||||
name = error.__class__.__name__.lower()
|
||||
if name in ("networkerror", "timedout", "connectionerror"):
|
||||
return True
|
||||
try:
|
||||
from telegram.error import NetworkError, TimedOut
|
||||
if isinstance(error, (NetworkError, TimedOut)):
|
||||
return True
|
||||
except ImportError:
|
||||
pass
|
||||
return isinstance(error, OSError)
|
||||
|
||||
async def _handle_polling_network_error(self, error: Exception) -> None:
|
||||
"""Reconnect polling after a transient network interruption.
|
||||
|
||||
Triggered by NetworkError/TimedOut in the polling error callback, which
|
||||
happen when the host loses connectivity (Mac sleep, WiFi switch, VPN
|
||||
reconnect, etc.). The gateway process stays alive but the long-poll
|
||||
connection silently dies; without this handler the bot never recovers.
|
||||
|
||||
Strategy: exponential back-off (5s, 10s, 20s, 40s, 60s cap) up to
|
||||
MAX_NETWORK_RETRIES attempts, then mark the adapter retryable-fatal so
|
||||
the supervisor restarts the gateway process.
|
||||
"""
|
||||
if self.has_fatal_error:
|
||||
return
|
||||
|
||||
MAX_NETWORK_RETRIES = 10
|
||||
BASE_DELAY = 5
|
||||
MAX_DELAY = 60
|
||||
|
||||
self._polling_network_error_count += 1
|
||||
attempt = self._polling_network_error_count
|
||||
|
||||
if attempt > MAX_NETWORK_RETRIES:
|
||||
message = (
|
||||
"Telegram polling could not reconnect after %d network error retries. "
|
||||
"Restarting gateway." % MAX_NETWORK_RETRIES
|
||||
)
|
||||
logger.error("[%s] %s Last error: %s", self.name, message, error)
|
||||
self._set_fatal_error("telegram_network_error", message, retryable=True)
|
||||
await self._notify_fatal_error()
|
||||
return
|
||||
|
||||
delay = min(BASE_DELAY * (2 ** (attempt - 1)), MAX_DELAY)
|
||||
logger.warning(
|
||||
"[%s] Telegram network error (attempt %d/%d), reconnecting in %ds. Error: %s",
|
||||
self.name, attempt, MAX_NETWORK_RETRIES, delay, error,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
try:
|
||||
if self._app and self._app.updater and self._app.updater.running:
|
||||
await self._app.updater.stop()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
await self._app.updater.start_polling(
|
||||
allowed_updates=Update.ALL_TYPES,
|
||||
drop_pending_updates=False,
|
||||
error_callback=self._polling_error_callback_ref,
|
||||
)
|
||||
logger.info(
|
||||
"[%s] Telegram polling resumed after network error (attempt %d)",
|
||||
self.name, attempt,
|
||||
)
|
||||
self._polling_network_error_count = 0
|
||||
except Exception as retry_err:
|
||||
logger.warning("[%s] Telegram polling reconnect failed: %s", self.name, retry_err)
|
||||
# The next network error will trigger another attempt.
|
||||
|
||||
async def _handle_polling_conflict(self, error: Exception) -> None:
|
||||
if self.has_fatal_error and self.fatal_error_code == "telegram_polling_conflict":
|
||||
return
|
||||
@@ -276,12 +351,15 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
def _polling_error_callback(error: Exception) -> None:
|
||||
if not self._looks_like_polling_conflict(error):
|
||||
logger.error("[%s] Telegram polling error: %s", self.name, error, exc_info=True)
|
||||
return
|
||||
if self._polling_error_task and not self._polling_error_task.done():
|
||||
return
|
||||
self._polling_error_task = loop.create_task(self._handle_polling_conflict(error))
|
||||
if self._looks_like_polling_conflict(error):
|
||||
self._polling_error_task = loop.create_task(self._handle_polling_conflict(error))
|
||||
elif self._looks_like_network_error(error):
|
||||
logger.warning("[%s] Telegram network error, scheduling reconnect: %s", self.name, error)
|
||||
self._polling_error_task = loop.create_task(self._handle_polling_network_error(error))
|
||||
else:
|
||||
logger.error("[%s] Telegram polling error: %s", self.name, error, exc_info=True)
|
||||
|
||||
# Store reference for retry use in _handle_polling_conflict
|
||||
self._polling_error_callback_ref = _polling_error_callback
|
||||
@@ -578,23 +656,26 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
image_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a local image file natively as a Telegram photo."""
|
||||
if not self._bot:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
|
||||
try:
|
||||
import os
|
||||
if not os.path.exists(image_path):
|
||||
return SendResult(success=False, error=f"Image file not found: {image_path}")
|
||||
|
||||
|
||||
_thread = metadata.get("thread_id") if metadata else None
|
||||
with open(image_path, "rb") as image_file:
|
||||
msg = await self._bot.send_photo(
|
||||
chat_id=int(chat_id),
|
||||
photo=image_file,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
message_thread_id=int(_thread) if _thread else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
@@ -613,6 +694,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
caption: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a document/file natively as a Telegram file attachment."""
|
||||
@@ -624,6 +706,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
return SendResult(success=False, error=f"File not found: {file_path}")
|
||||
|
||||
display_name = file_name or os.path.basename(file_path)
|
||||
_thread = metadata.get("thread_id") if metadata else None
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
msg = await self._bot.send_document(
|
||||
@@ -632,6 +715,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
filename=display_name,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
message_thread_id=int(_thread) if _thread else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
@@ -644,6 +728,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
video_path: str,
|
||||
caption: Optional[str] = None,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> SendResult:
|
||||
"""Send a video natively as a Telegram video message."""
|
||||
@@ -654,12 +739,14 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
if not os.path.exists(video_path):
|
||||
return SendResult(success=False, error=f"Video file not found: {video_path}")
|
||||
|
||||
_thread = metadata.get("thread_id") if metadata else None
|
||||
with open(video_path, "rb") as f:
|
||||
msg = await self._bot.send_video(
|
||||
chat_id=int(chat_id),
|
||||
video=f,
|
||||
caption=caption[:1024] if caption else None,
|
||||
reply_to_message_id=int(reply_to) if reply_to else None,
|
||||
message_thread_id=int(_thread) if _thread else None,
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
except Exception as e:
|
||||
@@ -926,6 +1013,45 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
for key in reversed(list(placeholders.keys())):
|
||||
text = text.replace(key, placeholders[key])
|
||||
|
||||
# 12) Safety net: escape unescaped ( ) { } that slipped through
|
||||
# placeholder processing. Split the text into code/non-code
|
||||
# segments so we never touch content inside ``` or ` spans.
|
||||
_code_split = re.split(r'(```[\s\S]*?```|`[^`]+`)', text)
|
||||
_safe_parts = []
|
||||
for _idx, _seg in enumerate(_code_split):
|
||||
if _idx % 2 == 1:
|
||||
# Inside code span/block — leave untouched
|
||||
_safe_parts.append(_seg)
|
||||
else:
|
||||
# Outside code — escape bare ( ) { }
|
||||
def _esc_bare(m, _seg=_seg):
|
||||
s = m.start()
|
||||
ch = m.group(0)
|
||||
# Already escaped
|
||||
if s > 0 and _seg[s - 1] == '\\':
|
||||
return ch
|
||||
# ( that opens a MarkdownV2 link [text](url)
|
||||
if ch == '(' and s > 0 and _seg[s - 1] == ']':
|
||||
return ch
|
||||
# ) that closes a link URL
|
||||
if ch == ')':
|
||||
before = _seg[:s]
|
||||
if '](http' in before or '](' in before:
|
||||
# Check depth
|
||||
depth = 0
|
||||
for j in range(s - 1, max(s - 2000, -1), -1):
|
||||
if _seg[j] == '(':
|
||||
depth -= 1
|
||||
if depth < 0:
|
||||
if j > 0 and _seg[j - 1] == ']':
|
||||
return ch
|
||||
break
|
||||
elif _seg[j] == ')':
|
||||
depth += 1
|
||||
return '\\' + ch
|
||||
_safe_parts.append(re.sub(r'[(){}]', _esc_bare, _seg))
|
||||
text = ''.join(_safe_parts)
|
||||
|
||||
return text
|
||||
|
||||
async def _handle_text_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
|
||||
@@ -196,7 +196,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
bridge_status = data.get("status", "unknown")
|
||||
if bridge_status == "connected":
|
||||
print(f"[{self.name}] Using existing bridge (status: {bridge_status})")
|
||||
self._running = True
|
||||
self._mark_connected()
|
||||
self._bridge_process = None # Not managed by us
|
||||
asyncio.create_task(self._poll_messages())
|
||||
return True
|
||||
@@ -306,7 +306,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
# Start message polling task
|
||||
asyncio.create_task(self._poll_messages())
|
||||
|
||||
self._running = True
|
||||
self._mark_connected()
|
||||
print(f"[{self.name}] Bridge started on port {self._bridge_port}")
|
||||
return True
|
||||
|
||||
@@ -324,6 +324,23 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
pass
|
||||
self._bridge_log_fh = None
|
||||
|
||||
async def _check_managed_bridge_exit(self) -> Optional[str]:
|
||||
"""Return a fatal error message if the managed bridge child exited."""
|
||||
if self._bridge_process is None:
|
||||
return None
|
||||
|
||||
returncode = self._bridge_process.poll()
|
||||
if returncode is None:
|
||||
return None
|
||||
|
||||
message = f"WhatsApp bridge process exited unexpectedly (code {returncode})."
|
||||
if not self.has_fatal_error:
|
||||
logger.error("[%s] %s", self.name, message)
|
||||
self._set_fatal_error("whatsapp_bridge_exited", message, retryable=True)
|
||||
self._close_bridge_log()
|
||||
await self._notify_fatal_error()
|
||||
return self.fatal_error_message or message
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Stop the WhatsApp bridge and clean up any orphaned processes."""
|
||||
if self._bridge_process:
|
||||
@@ -352,7 +369,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
# Bridge was not started by us, don't kill it
|
||||
print(f"[{self.name}] Disconnecting (external bridge left running)")
|
||||
|
||||
self._running = False
|
||||
self._mark_disconnected()
|
||||
self._bridge_process = None
|
||||
self._close_bridge_log()
|
||||
print(f"[{self.name}] Disconnected")
|
||||
@@ -367,6 +384,9 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""Send a message via the WhatsApp bridge."""
|
||||
if not self._running:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
return SendResult(success=False, error=bridge_exit)
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
@@ -412,6 +432,9 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""Edit a previously sent message via the WhatsApp bridge."""
|
||||
if not self._running:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
return SendResult(success=False, error=bridge_exit)
|
||||
try:
|
||||
import aiohttp
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@@ -443,6 +466,9 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""Send any media file via bridge /send-media endpoint."""
|
||||
if not self._running:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
return SendResult(success=False, error=bridge_exit)
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
@@ -531,6 +557,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""Send typing indicator via bridge."""
|
||||
if not self._running:
|
||||
return
|
||||
if await self._check_managed_bridge_exit():
|
||||
return
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
@@ -548,6 +576,8 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
"""Get information about a WhatsApp chat."""
|
||||
if not self._running:
|
||||
return {"name": "Unknown", "type": "dm"}
|
||||
if await self._check_managed_bridge_exit():
|
||||
return {"name": chat_id, "type": "dm"}
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
@@ -578,6 +608,10 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
return
|
||||
|
||||
while self._running:
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
print(f"[{self.name}] {bridge_exit}")
|
||||
break
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(
|
||||
@@ -593,6 +627,10 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
bridge_exit = await self._check_managed_bridge_exit()
|
||||
if bridge_exit:
|
||||
print(f"[{self.name}] {bridge_exit}")
|
||||
break
|
||||
print(f"[{self.name}] Poll error: {e}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
@@ -674,4 +712,3 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error building event: {e}")
|
||||
return None
|
||||
|
||||
|
||||
+619
-65
@@ -93,6 +93,9 @@ if _config_path.exists():
|
||||
import yaml as _yaml
|
||||
with open(_config_path, encoding="utf-8") as _f:
|
||||
_cfg = _yaml.safe_load(_f) or {}
|
||||
# Expand ${ENV_VAR} references before bridging to env vars.
|
||||
from hermes_cli.config import _expand_env_vars
|
||||
_cfg = _expand_env_vars(_cfg)
|
||||
# Top-level simple values (fallback only — don't override .env)
|
||||
for _key, _val in _cfg.items():
|
||||
if isinstance(_val, (str, int, float, bool)) and _key not in os.environ:
|
||||
@@ -336,6 +339,7 @@ class GatewayRunner:
|
||||
self._running = False
|
||||
self._shutdown_event = asyncio.Event()
|
||||
self._exit_cleanly = False
|
||||
self._exit_with_failure = False
|
||||
self._exit_reason: Optional[str] = None
|
||||
|
||||
# Track running agents per session for interrupt support
|
||||
@@ -343,6 +347,15 @@ class GatewayRunner:
|
||||
self._running_agents: Dict[str, Any] = {}
|
||||
self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt
|
||||
|
||||
# Cache AIAgent instances per session to preserve prompt caching.
|
||||
# Without this, a new AIAgent is created per message, rebuilding the
|
||||
# system prompt (including memory) every turn — breaking prefix cache
|
||||
# and costing ~10x more on providers with prompt caching (Anthropic).
|
||||
# Key: session_key, Value: (AIAgent, config_signature_str)
|
||||
import threading as _threading
|
||||
self._agent_cache: Dict[str, tuple] = {}
|
||||
self._agent_cache_lock = _threading.Lock()
|
||||
|
||||
# Track active fallback model/provider when primary is rate-limited.
|
||||
# Set after an agent run where fallback was activated; cleared when
|
||||
# the primary model succeeds again or the user switches via /model.
|
||||
@@ -353,6 +366,10 @@ class GatewayRunner:
|
||||
# Key: session_key, Value: {"command": str, "pattern_key": str, ...}
|
||||
self._pending_approvals: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# Track platforms that failed to connect for background reconnection.
|
||||
# Key: Platform enum, Value: {"config": platform_config, "attempts": int, "next_retry": float}
|
||||
self._failed_platforms: Dict[Platform, Dict[str, Any]] = {}
|
||||
|
||||
# Persistent Honcho managers keyed by gateway session key.
|
||||
# This preserves write_frequency="session" semantics across short-lived
|
||||
# per-message AIAgent instances.
|
||||
@@ -511,6 +528,12 @@ class GatewayRunner:
|
||||
Synchronous worker — meant to be called via run_in_executor from
|
||||
an async context so it doesn't block the event loop.
|
||||
"""
|
||||
# Skip cron sessions — they run headless with no meaningful user
|
||||
# conversation to extract memories from.
|
||||
if old_session_id and old_session_id.startswith("cron_"):
|
||||
logger.debug("Skipping memory flush for cron session: %s", old_session_id)
|
||||
return
|
||||
|
||||
try:
|
||||
history = self.session_store.load_transcript(old_session_id)
|
||||
if not history or len(history) < 4:
|
||||
@@ -543,6 +566,23 @@ class GatewayRunner:
|
||||
if m.get("role") in ("user", "assistant") and m.get("content")
|
||||
]
|
||||
|
||||
# Read live memory state from disk so the flush agent can see
|
||||
# what's already saved and avoid overwriting newer entries.
|
||||
_current_memory = ""
|
||||
try:
|
||||
from tools.memory_tool import MEMORY_DIR
|
||||
for fname, label in [
|
||||
("MEMORY.md", "MEMORY (your personal notes)"),
|
||||
("USER.md", "USER PROFILE (who the user is)"),
|
||||
]:
|
||||
fpath = MEMORY_DIR / fname
|
||||
if fpath.exists():
|
||||
content = fpath.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
_current_memory += f"\n\n## Current {label}:\n{content}"
|
||||
except Exception:
|
||||
pass # Non-fatal — flush still works, just without the guard
|
||||
|
||||
# Give the agent a real turn to think about what to save
|
||||
flush_prompt = (
|
||||
"[System: This session is about to be automatically reset due to "
|
||||
@@ -554,6 +594,20 @@ class GatewayRunner:
|
||||
"2. If you discovered a reusable workflow or solved a non-trivial "
|
||||
"problem, consider saving it as a skill.\n"
|
||||
"3. If nothing is worth saving, that's fine — just skip.\n\n"
|
||||
)
|
||||
|
||||
if _current_memory:
|
||||
flush_prompt += (
|
||||
"IMPORTANT — here is the current live state of memory. Other "
|
||||
"sessions, cron jobs, or the user may have updated it since this "
|
||||
"conversation ended. Do NOT overwrite or remove entries unless "
|
||||
"the conversation above reveals something that genuinely "
|
||||
"supersedes them. Only add new information that is not already "
|
||||
"captured below."
|
||||
f"{_current_memory}\n\n"
|
||||
)
|
||||
|
||||
flush_prompt += (
|
||||
"Do NOT respond to the user. Just use the memory and skill_manage "
|
||||
"tools if needed, then stop.]"
|
||||
)
|
||||
@@ -591,6 +645,10 @@ class GatewayRunner:
|
||||
def should_exit_cleanly(self) -> bool:
|
||||
return self._exit_cleanly
|
||||
|
||||
@property
|
||||
def should_exit_with_failure(self) -> bool:
|
||||
return self._exit_with_failure
|
||||
|
||||
@property
|
||||
def exit_reason(self) -> Optional[str]:
|
||||
return self._exit_reason
|
||||
@@ -625,7 +683,11 @@ class GatewayRunner:
|
||||
return resolve_turn_route(user_message, getattr(self, "_smart_model_routing", {}), primary)
|
||||
|
||||
async def _handle_adapter_fatal_error(self, adapter: BasePlatformAdapter) -> None:
|
||||
"""React to a non-retryable adapter failure after startup."""
|
||||
"""React to an adapter failure after startup.
|
||||
|
||||
If the error is retryable (e.g. network blip, DNS failure), queue the
|
||||
platform for background reconnection instead of giving up permanently.
|
||||
"""
|
||||
logger.error(
|
||||
"Fatal %s adapter error (%s): %s",
|
||||
adapter.platform.value,
|
||||
@@ -641,10 +703,33 @@ class GatewayRunner:
|
||||
self.adapters.pop(adapter.platform, None)
|
||||
self.delivery_router.adapters = self.adapters
|
||||
|
||||
if not self.adapters:
|
||||
# Queue retryable failures for background reconnection
|
||||
if adapter.fatal_error_retryable:
|
||||
platform_config = self.config.platforms.get(adapter.platform)
|
||||
if platform_config and adapter.platform not in self._failed_platforms:
|
||||
self._failed_platforms[adapter.platform] = {
|
||||
"config": platform_config,
|
||||
"attempts": 0,
|
||||
"next_retry": time.monotonic() + 30,
|
||||
}
|
||||
logger.info(
|
||||
"%s queued for background reconnection",
|
||||
adapter.platform.value,
|
||||
)
|
||||
|
||||
if not self.adapters and not self._failed_platforms:
|
||||
self._exit_reason = adapter.fatal_error_message or "All messaging adapters disconnected"
|
||||
logger.error("No connected messaging platforms remain. Shutting down gateway cleanly.")
|
||||
if adapter.fatal_error_retryable:
|
||||
self._exit_with_failure = True
|
||||
logger.error("No connected messaging platforms remain. Shutting down gateway for service restart.")
|
||||
else:
|
||||
logger.error("No connected messaging platforms remain. Shutting down gateway cleanly.")
|
||||
await self.stop()
|
||||
elif not self.adapters and self._failed_platforms:
|
||||
logger.warning(
|
||||
"No connected messaging platforms remain, but %d platform(s) queued for reconnection",
|
||||
len(self._failed_platforms),
|
||||
)
|
||||
|
||||
def _request_clean_exit(self, reason: str) -> None:
|
||||
self._exit_cleanly = True
|
||||
@@ -859,7 +944,9 @@ class GatewayRunner:
|
||||
os.getenv(v)
|
||||
for v in ("TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS",
|
||||
"WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS",
|
||||
"SMS_ALLOWED_USERS",
|
||||
"SIGNAL_ALLOWED_USERS", "EMAIL_ALLOWED_USERS",
|
||||
"SMS_ALLOWED_USERS", "MATTERMOST_ALLOWED_USERS",
|
||||
"MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS",
|
||||
"GATEWAY_ALLOWED_USERS")
|
||||
)
|
||||
_allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes")
|
||||
@@ -922,13 +1009,32 @@ class GatewayRunner:
|
||||
target.append(
|
||||
f"{platform.value}: {adapter.fatal_error_message}"
|
||||
)
|
||||
# Queue for reconnection if the error is retryable
|
||||
if adapter.fatal_error_retryable:
|
||||
self._failed_platforms[platform] = {
|
||||
"config": platform_config,
|
||||
"attempts": 1,
|
||||
"next_retry": time.monotonic() + 30,
|
||||
}
|
||||
else:
|
||||
startup_retryable_errors.append(
|
||||
f"{platform.value}: failed to connect"
|
||||
)
|
||||
# No fatal error info means likely a transient issue — queue for retry
|
||||
self._failed_platforms[platform] = {
|
||||
"config": platform_config,
|
||||
"attempts": 1,
|
||||
"next_retry": time.monotonic() + 30,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("✗ %s error: %s", platform.value, e)
|
||||
startup_retryable_errors.append(f"{platform.value}: {e}")
|
||||
# Unexpected exceptions are typically transient — queue for retry
|
||||
self._failed_platforms[platform] = {
|
||||
"config": platform_config,
|
||||
"attempts": 1,
|
||||
"next_retry": time.monotonic() + 30,
|
||||
}
|
||||
|
||||
if connected_count == 0:
|
||||
if startup_nonretryable_errors:
|
||||
@@ -1008,6 +1114,15 @@ class GatewayRunner:
|
||||
# Start background session expiry watcher for proactive memory flushing
|
||||
asyncio.create_task(self._session_expiry_watcher())
|
||||
|
||||
# Start background reconnection watcher for platforms that failed at startup
|
||||
if self._failed_platforms:
|
||||
logger.info(
|
||||
"Starting reconnection watcher for %d failed platform(s): %s",
|
||||
len(self._failed_platforms),
|
||||
", ".join(p.value for p in self._failed_platforms),
|
||||
)
|
||||
asyncio.create_task(self._platform_reconnect_watcher())
|
||||
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
return True
|
||||
@@ -1050,6 +1165,107 @@ class GatewayRunner:
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def _platform_reconnect_watcher(self) -> None:
|
||||
"""Background task that periodically retries connecting failed platforms.
|
||||
|
||||
Uses exponential backoff: 30s → 60s → 120s → 240s → 300s (cap).
|
||||
Stops retrying a platform after 20 failed attempts or if the error
|
||||
is non-retryable (e.g. bad auth token).
|
||||
"""
|
||||
_MAX_ATTEMPTS = 20
|
||||
_BACKOFF_CAP = 300 # 5 minutes max between retries
|
||||
|
||||
await asyncio.sleep(10) # initial delay — let startup finish
|
||||
while self._running:
|
||||
if not self._failed_platforms:
|
||||
# Nothing to reconnect — sleep and check again
|
||||
for _ in range(30):
|
||||
if not self._running:
|
||||
return
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
now = time.monotonic()
|
||||
for platform in list(self._failed_platforms.keys()):
|
||||
if not self._running:
|
||||
return
|
||||
info = self._failed_platforms[platform]
|
||||
if now < info["next_retry"]:
|
||||
continue # not time yet
|
||||
|
||||
if info["attempts"] >= _MAX_ATTEMPTS:
|
||||
logger.warning(
|
||||
"Giving up reconnecting %s after %d attempts",
|
||||
platform.value, info["attempts"],
|
||||
)
|
||||
del self._failed_platforms[platform]
|
||||
continue
|
||||
|
||||
platform_config = info["config"]
|
||||
attempt = info["attempts"] + 1
|
||||
logger.info(
|
||||
"Reconnecting %s (attempt %d/%d)...",
|
||||
platform.value, attempt, _MAX_ATTEMPTS,
|
||||
)
|
||||
|
||||
try:
|
||||
adapter = self._create_adapter(platform, platform_config)
|
||||
if not adapter:
|
||||
logger.warning(
|
||||
"Reconnect %s: adapter creation returned None, removing from retry queue",
|
||||
platform.value,
|
||||
)
|
||||
del self._failed_platforms[platform]
|
||||
continue
|
||||
|
||||
adapter.set_message_handler(self._handle_message)
|
||||
adapter.set_fatal_error_handler(self._handle_adapter_fatal_error)
|
||||
|
||||
success = await adapter.connect()
|
||||
if success:
|
||||
self.adapters[platform] = adapter
|
||||
self._sync_voice_mode_state_to_adapter(adapter)
|
||||
self.delivery_router.adapters = self.adapters
|
||||
del self._failed_platforms[platform]
|
||||
logger.info("✓ %s reconnected successfully", platform.value)
|
||||
|
||||
# Rebuild channel directory with the new adapter
|
||||
try:
|
||||
from gateway.channel_directory import build_channel_directory
|
||||
build_channel_directory(self.adapters)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
# Check if the failure is non-retryable
|
||||
if adapter.has_fatal_error and not adapter.fatal_error_retryable:
|
||||
logger.warning(
|
||||
"Reconnect %s: non-retryable error (%s), removing from retry queue",
|
||||
platform.value, adapter.fatal_error_message,
|
||||
)
|
||||
del self._failed_platforms[platform]
|
||||
else:
|
||||
backoff = min(30 * (2 ** (attempt - 1)), _BACKOFF_CAP)
|
||||
info["attempts"] = attempt
|
||||
info["next_retry"] = time.monotonic() + backoff
|
||||
logger.info(
|
||||
"Reconnect %s failed, next retry in %ds",
|
||||
platform.value, backoff,
|
||||
)
|
||||
except Exception as e:
|
||||
backoff = min(30 * (2 ** (attempt - 1)), _BACKOFF_CAP)
|
||||
info["attempts"] = attempt
|
||||
info["next_retry"] = time.monotonic() + backoff
|
||||
logger.warning(
|
||||
"Reconnect %s error: %s, next retry in %ds",
|
||||
platform.value, e, backoff,
|
||||
)
|
||||
|
||||
# Check every 10 seconds for platforms that need reconnection
|
||||
for _ in range(10):
|
||||
if not self._running:
|
||||
return
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the gateway and disconnect all adapters."""
|
||||
logger.info("Stopping gateway...")
|
||||
@@ -1579,6 +1795,21 @@ class GatewayRunner:
|
||||
else:
|
||||
return f"Quick command '/{command}' has unsupported type (supported: 'exec', 'alias')."
|
||||
|
||||
# Plugin-registered slash commands
|
||||
if command:
|
||||
try:
|
||||
from hermes_cli.plugins import get_plugin_command_handler
|
||||
plugin_handler = get_plugin_command_handler(command)
|
||||
if plugin_handler:
|
||||
user_args = event.get_command_args().strip()
|
||||
import asyncio as _aio
|
||||
result = plugin_handler(user_args)
|
||||
if _aio.iscoroutine(result):
|
||||
result = await result
|
||||
return str(result) if result else None
|
||||
except Exception as e:
|
||||
logger.debug("Plugin command dispatch failed (non-fatal): %s", e)
|
||||
|
||||
# Skill slash commands: /skill-name loads the skill and sends to agent
|
||||
if command:
|
||||
try:
|
||||
@@ -1661,12 +1892,54 @@ class GatewayRunner:
|
||||
# If the previous session expired and was auto-reset, prepend a notice
|
||||
# so the agent knows this is a fresh conversation (not an intentional /reset).
|
||||
if getattr(session_entry, 'was_auto_reset', False):
|
||||
context_prompt = (
|
||||
"[System note: The user's previous session expired due to inactivity. "
|
||||
"This is a fresh conversation with no prior context.]\n\n"
|
||||
+ context_prompt
|
||||
)
|
||||
reset_reason = getattr(session_entry, 'auto_reset_reason', None) or 'idle'
|
||||
if reset_reason == "daily":
|
||||
context_note = "[System note: The user's session was automatically reset by the daily schedule. This is a fresh conversation with no prior context.]"
|
||||
else:
|
||||
context_note = "[System note: The user's previous session expired due to inactivity. This is a fresh conversation with no prior context.]"
|
||||
context_prompt = context_note + "\n\n" + context_prompt
|
||||
|
||||
# Send a user-facing notification explaining the reset, unless:
|
||||
# - notifications are disabled in config
|
||||
# - the platform is excluded (e.g. api_server, webhook)
|
||||
# - the expired session had no activity (nothing was cleared)
|
||||
try:
|
||||
policy = self.session_store.config.get_reset_policy(
|
||||
platform=source.platform,
|
||||
session_type=getattr(source, 'chat_type', 'dm'),
|
||||
)
|
||||
platform_name = source.platform.value if source.platform else ""
|
||||
had_activity = getattr(session_entry, 'reset_had_activity', False)
|
||||
should_notify = (
|
||||
policy.notify
|
||||
and had_activity
|
||||
and platform_name not in policy.notify_exclude_platforms
|
||||
)
|
||||
if should_notify:
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter:
|
||||
if reset_reason == "daily":
|
||||
reason_text = f"daily schedule at {policy.at_hour}:00"
|
||||
else:
|
||||
hours = policy.idle_minutes // 60
|
||||
mins = policy.idle_minutes % 60
|
||||
duration = f"{hours}h" if not mins else f"{hours}h {mins}m" if hours else f"{mins}m"
|
||||
reason_text = f"inactive for {duration}"
|
||||
notice = (
|
||||
f"◐ Session automatically reset ({reason_text}). "
|
||||
f"Conversation history cleared.\n"
|
||||
f"Use /resume to browse and restore a previous session.\n"
|
||||
f"Adjust reset timing in config.yaml under session_reset."
|
||||
)
|
||||
await adapter.send(
|
||||
source.chat_id, notice,
|
||||
metadata=getattr(event, 'metadata', None),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Auto-reset notification failed (non-fatal): %s", e)
|
||||
|
||||
session_entry.was_auto_reset = False
|
||||
session_entry.auto_reset_reason = None
|
||||
|
||||
# Load conversation history from transcript
|
||||
history = self.session_store.load_transcript(session_entry.session_id)
|
||||
@@ -1682,9 +1955,9 @@ class GatewayRunner:
|
||||
# Token source priority:
|
||||
# 1. Actual API-reported prompt_tokens from the last turn
|
||||
# (stored in session_entry.last_prompt_tokens)
|
||||
# 2. Rough char-based estimate (str(msg)//4) with a 1.4x
|
||||
# safety factor to account for overestimation on tool-heavy
|
||||
# conversations (code/JSON tokenizes at 5-7+ chars/token).
|
||||
# 2. Rough char-based estimate (str(msg)//4). Overestimates
|
||||
# by 30-50% on code/JSON-heavy sessions, but that just
|
||||
# means hygiene fires a bit early — safe and harmless.
|
||||
# -----------------------------------------------------------------
|
||||
if history and len(history) >= 4:
|
||||
from agent.model_metadata import (
|
||||
@@ -1703,6 +1976,10 @@ class GatewayRunner:
|
||||
_hyg_model = "anthropic/claude-sonnet-4.6"
|
||||
_hyg_threshold_pct = 0.85
|
||||
_hyg_compression_enabled = True
|
||||
_hyg_config_context_length = None
|
||||
_hyg_provider = None
|
||||
_hyg_base_url = None
|
||||
_hyg_api_key = None
|
||||
try:
|
||||
_hyg_cfg_path = _hermes_home / "config.yaml"
|
||||
if _hyg_cfg_path.exists():
|
||||
@@ -1716,6 +1993,17 @@ class GatewayRunner:
|
||||
_hyg_model = _model_cfg
|
||||
elif isinstance(_model_cfg, dict):
|
||||
_hyg_model = _model_cfg.get("default", _hyg_model)
|
||||
# Read explicit context_length override from model config
|
||||
# (same as run_agent.py lines 995-1005)
|
||||
_raw_ctx = _model_cfg.get("context_length")
|
||||
if _raw_ctx is not None:
|
||||
try:
|
||||
_hyg_config_context_length = int(_raw_ctx)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
# Read provider for accurate context detection
|
||||
_hyg_provider = _model_cfg.get("provider") or None
|
||||
_hyg_base_url = _model_cfg.get("base_url") or None
|
||||
|
||||
# Read compression settings — only use enabled flag.
|
||||
# The threshold is intentionally separate from the agent's
|
||||
@@ -1725,11 +2013,27 @@ class GatewayRunner:
|
||||
_hyg_compression_enabled = str(
|
||||
_comp_cfg.get("enabled", True)
|
||||
).lower() in ("true", "1", "yes")
|
||||
|
||||
# Resolve provider/base_url from runtime if not in config
|
||||
if not _hyg_provider or not _hyg_base_url:
|
||||
try:
|
||||
_hyg_runtime = _resolve_runtime_agent_kwargs()
|
||||
_hyg_provider = _hyg_provider or _hyg_runtime.get("provider")
|
||||
_hyg_base_url = _hyg_base_url or _hyg_runtime.get("base_url")
|
||||
_hyg_api_key = _hyg_runtime.get("api_key")
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if _hyg_compression_enabled:
|
||||
_hyg_context_length = get_model_context_length(_hyg_model)
|
||||
_hyg_context_length = get_model_context_length(
|
||||
_hyg_model,
|
||||
base_url=_hyg_base_url or "",
|
||||
api_key=_hyg_api_key or "",
|
||||
config_context_length=_hyg_config_context_length,
|
||||
provider=_hyg_provider or "",
|
||||
)
|
||||
_compress_token_threshold = int(
|
||||
_hyg_context_length * _hyg_threshold_pct
|
||||
)
|
||||
@@ -1739,20 +2043,20 @@ class GatewayRunner:
|
||||
|
||||
# Prefer actual API-reported tokens from the last turn
|
||||
# (stored in session entry) over the rough char-based estimate.
|
||||
# The rough estimate (str(msg)//4) overestimates by 30-50% on
|
||||
# tool-heavy/code-heavy conversations, causing premature compression.
|
||||
_stored_tokens = session_entry.last_prompt_tokens
|
||||
if _stored_tokens > 0:
|
||||
_approx_tokens = _stored_tokens
|
||||
_token_source = "actual"
|
||||
else:
|
||||
_approx_tokens = estimate_messages_tokens_rough(history)
|
||||
# Apply safety factor only for rough estimates
|
||||
_compress_token_threshold = int(
|
||||
_compress_token_threshold * 1.4
|
||||
)
|
||||
_warn_token_threshold = int(_warn_token_threshold * 1.4)
|
||||
_token_source = "estimated"
|
||||
# Note: rough estimates overestimate by 30-50% for code/JSON-heavy
|
||||
# sessions, but that just means hygiene fires a bit early — which
|
||||
# is safe and harmless. The 85% threshold already provides ample
|
||||
# headroom (agent's own compressor runs at 50%). A previous 1.4x
|
||||
# multiplier tried to compensate by inflating the threshold, but
|
||||
# 85% * 1.4 = 119% of context — which exceeds the model's limit
|
||||
# and prevented hygiene from ever firing for ~200K models (GLM-5).
|
||||
|
||||
_needs_compress = _approx_tokens >= _compress_token_threshold
|
||||
|
||||
@@ -2058,7 +2362,31 @@ class GatewayRunner:
|
||||
"message": message_text[:500],
|
||||
}
|
||||
await self.hooks.emit("agent:start", hook_ctx)
|
||||
|
||||
|
||||
# Expand @ context references (@file:, @folder:, @diff, etc.)
|
||||
if "@" in message_text:
|
||||
try:
|
||||
from agent.context_references import preprocess_context_references_async
|
||||
from agent.model_metadata import get_model_context_length
|
||||
_msg_cwd = os.environ.get("MESSAGING_CWD", os.path.expanduser("~"))
|
||||
_msg_ctx_len = get_model_context_length(
|
||||
self._model, base_url=self._base_url or "")
|
||||
_ctx_result = await preprocess_context_references_async(
|
||||
message_text, cwd=_msg_cwd,
|
||||
context_length=_msg_ctx_len, allowed_root=_msg_cwd)
|
||||
if _ctx_result.blocked:
|
||||
_adapter = self.adapters.get(source.platform)
|
||||
if _adapter:
|
||||
await _adapter.send(
|
||||
source.chat_id,
|
||||
"\n".join(_ctx_result.warnings) or "Context injection refused.",
|
||||
)
|
||||
return
|
||||
if _ctx_result.expanded:
|
||||
message_text = _ctx_result.message
|
||||
except Exception as exc:
|
||||
logger.debug("@ context reference expansion failed: %s", exc)
|
||||
|
||||
# Run the agent
|
||||
agent_result = await self._run_agent(
|
||||
message=message_text,
|
||||
@@ -2068,7 +2396,15 @@ class GatewayRunner:
|
||||
session_id=session_entry.session_id,
|
||||
session_key=session_key
|
||||
)
|
||||
|
||||
|
||||
# Stop persistent typing indicator now that the agent is done
|
||||
try:
|
||||
_typing_adapter = self.adapters.get(source.platform)
|
||||
if _typing_adapter and hasattr(_typing_adapter, "stop_typing"):
|
||||
await _typing_adapter.stop_typing(source.chat_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
response = agent_result.get("final_response") or ""
|
||||
agent_messages = agent_result.get("messages", [])
|
||||
|
||||
@@ -2252,14 +2588,31 @@ class GatewayRunner:
|
||||
if self._should_send_voice_reply(event, response, agent_messages, already_sent=_already_sent):
|
||||
await self._send_voice_reply(event, response)
|
||||
|
||||
# If streaming already delivered the response, return None so
|
||||
# _process_message_background doesn't send it again.
|
||||
# If streaming already delivered the response, extract and
|
||||
# deliver any MEDIA: files before returning None. Streaming
|
||||
# sends raw text chunks that include MEDIA: tags — the normal
|
||||
# post-processing in _process_message_background is skipped
|
||||
# when already_sent is True, so media files would never be
|
||||
# delivered without this.
|
||||
if agent_result.get("already_sent"):
|
||||
if response:
|
||||
_media_adapter = self.adapters.get(source.platform)
|
||||
if _media_adapter:
|
||||
await self._deliver_media_from_response(
|
||||
response, event, _media_adapter,
|
||||
)
|
||||
return None
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
# Stop typing indicator on error too
|
||||
try:
|
||||
_err_adapter = self.adapters.get(source.platform)
|
||||
if _err_adapter and hasattr(_err_adapter, "stop_typing"):
|
||||
await _err_adapter.stop_typing(source.chat_id)
|
||||
except Exception:
|
||||
pass
|
||||
logger.exception("Agent error in session %s", session_key)
|
||||
error_type = type(e).__name__
|
||||
error_detail = str(e)[:300] if str(e) else "no details available"
|
||||
@@ -2330,6 +2683,7 @@ class GatewayRunner:
|
||||
logger.debug("Gateway memory flush on reset failed: %s", e)
|
||||
|
||||
self._shutdown_gateway_honcho(session_key)
|
||||
self._evict_cached_agent(session_key)
|
||||
|
||||
# Reset the session
|
||||
new_entry = self.session_store.reset_session(session_key)
|
||||
@@ -3162,6 +3516,82 @@ class GatewayRunner:
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
async def _deliver_media_from_response(
|
||||
self,
|
||||
response: str,
|
||||
event: MessageEvent,
|
||||
adapter,
|
||||
) -> None:
|
||||
"""Extract MEDIA: tags and local file paths from a response and deliver them.
|
||||
|
||||
Called after streaming has already sent the text to the user, so the
|
||||
text itself is already delivered — this only handles file attachments
|
||||
that the normal _process_message_background path would have caught.
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
media_files, _ = adapter.extract_media(response)
|
||||
_, cleaned = adapter.extract_images(response)
|
||||
local_files, _ = adapter.extract_local_files(cleaned)
|
||||
|
||||
_thread_meta = {"thread_id": event.source.thread_id} if event.source.thread_id else None
|
||||
|
||||
_AUDIO_EXTS = {'.ogg', '.opus', '.mp3', '.wav', '.m4a'}
|
||||
_VIDEO_EXTS = {'.mp4', '.mov', '.avi', '.mkv', '.webm', '.3gp'}
|
||||
_IMAGE_EXTS = {'.jpg', '.jpeg', '.png', '.webp', '.gif'}
|
||||
|
||||
for media_path, is_voice in media_files:
|
||||
try:
|
||||
ext = Path(media_path).suffix.lower()
|
||||
if ext in _AUDIO_EXTS:
|
||||
await adapter.send_voice(
|
||||
chat_id=event.source.chat_id,
|
||||
audio_path=media_path,
|
||||
metadata=_thread_meta,
|
||||
)
|
||||
elif ext in _VIDEO_EXTS:
|
||||
await adapter.send_video(
|
||||
chat_id=event.source.chat_id,
|
||||
video_path=media_path,
|
||||
metadata=_thread_meta,
|
||||
)
|
||||
elif ext in _IMAGE_EXTS:
|
||||
await adapter.send_image_file(
|
||||
chat_id=event.source.chat_id,
|
||||
image_path=media_path,
|
||||
metadata=_thread_meta,
|
||||
)
|
||||
else:
|
||||
await adapter.send_document(
|
||||
chat_id=event.source.chat_id,
|
||||
file_path=media_path,
|
||||
metadata=_thread_meta,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Post-stream media delivery failed: %s", adapter.name, e)
|
||||
|
||||
for file_path in local_files:
|
||||
try:
|
||||
ext = Path(file_path).suffix.lower()
|
||||
if ext in _IMAGE_EXTS:
|
||||
await adapter.send_image_file(
|
||||
chat_id=event.source.chat_id,
|
||||
image_path=file_path,
|
||||
metadata=_thread_meta,
|
||||
)
|
||||
else:
|
||||
await adapter.send_document(
|
||||
chat_id=event.source.chat_id,
|
||||
file_path=file_path,
|
||||
metadata=_thread_meta,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Post-stream file delivery failed: %s", adapter.name, e)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Post-stream media extraction failed: %s", e)
|
||||
|
||||
async def _handle_rollback_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /rollback command — list or restore filesystem checkpoints."""
|
||||
from tools.checkpoint_manager import CheckpointManager, format_checkpoint_list
|
||||
@@ -3569,6 +3999,20 @@ class GatewayRunner:
|
||||
if not self._session_db:
|
||||
return "Session database not available."
|
||||
|
||||
# Ensure session exists in SQLite DB (it may only exist in session_store
|
||||
# if this is the first command in a new session)
|
||||
existing_title = self._session_db.get_session_title(session_id)
|
||||
if existing_title is None:
|
||||
# Session doesn't exist in DB yet — create it
|
||||
try:
|
||||
self._session_db.create_session(
|
||||
session_id=session_id,
|
||||
source=source.platform.value if source.platform else "unknown",
|
||||
user_id=source.user_id,
|
||||
)
|
||||
except Exception:
|
||||
pass # Session might already exist, ignore errors
|
||||
|
||||
title_arg = event.get_command_args().strip()
|
||||
if title_arg:
|
||||
# Sanitize the title before setting
|
||||
@@ -4355,6 +4799,45 @@ class GatewayRunner:
|
||||
|
||||
_MAX_INTERRUPT_DEPTH = 3 # Cap recursive interrupt handling (#816)
|
||||
|
||||
@staticmethod
|
||||
def _agent_config_signature(
|
||||
model: str,
|
||||
runtime: dict,
|
||||
enabled_toolsets: list,
|
||||
ephemeral_prompt: str,
|
||||
) -> str:
|
||||
"""Compute a stable string key from agent config values.
|
||||
|
||||
When this signature changes between messages, the cached AIAgent is
|
||||
discarded and rebuilt. When it stays the same, the cached agent is
|
||||
reused — preserving the frozen system prompt and tool schemas for
|
||||
prompt cache hits.
|
||||
"""
|
||||
import hashlib, json as _j
|
||||
blob = _j.dumps(
|
||||
[
|
||||
model,
|
||||
runtime.get("api_key", "")[:8], # first 8 chars only
|
||||
runtime.get("base_url", ""),
|
||||
runtime.get("provider", ""),
|
||||
runtime.get("api_mode", ""),
|
||||
sorted(enabled_toolsets) if enabled_toolsets else [],
|
||||
# reasoning_config excluded — it's set per-message on the
|
||||
# cached agent and doesn't affect system prompt or tools.
|
||||
ephemeral_prompt or "",
|
||||
],
|
||||
sort_keys=True,
|
||||
default=str,
|
||||
)
|
||||
return hashlib.sha256(blob.encode()).hexdigest()[:16]
|
||||
|
||||
def _evict_cached_agent(self, session_key: str) -> None:
|
||||
"""Remove a cached agent for a session (called on /new, /model, etc)."""
|
||||
_lock = getattr(self, "_agent_cache_lock", None)
|
||||
if _lock:
|
||||
with _lock:
|
||||
self._agent_cache.pop(session_key, None)
|
||||
|
||||
async def _run_agent(
|
||||
self,
|
||||
message: str,
|
||||
@@ -4704,34 +5187,64 @@ class GatewayRunner:
|
||||
logger.debug("Could not set up stream consumer: %s", _sc_err)
|
||||
|
||||
turn_route = self._resolve_turn_agent_config(message, model, runtime_kwargs)
|
||||
agent = AIAgent(
|
||||
model=turn_route["model"],
|
||||
**turn_route["runtime"],
|
||||
max_iterations=max_iterations,
|
||||
quiet_mode=True,
|
||||
verbose_logging=False,
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
ephemeral_system_prompt=combined_ephemeral or None,
|
||||
prefill_messages=self._prefill_messages or None,
|
||||
reasoning_config=reasoning_config,
|
||||
providers_allowed=pr.get("only"),
|
||||
providers_ignored=pr.get("ignore"),
|
||||
providers_order=pr.get("order"),
|
||||
provider_sort=pr.get("sort"),
|
||||
provider_require_parameters=pr.get("require_parameters", False),
|
||||
provider_data_collection=pr.get("data_collection"),
|
||||
session_id=session_id,
|
||||
tool_progress_callback=progress_callback if tool_progress_enabled else None,
|
||||
step_callback=_step_callback_sync if _hooks_ref.loaded_hooks else None,
|
||||
stream_delta_callback=_stream_delta_cb,
|
||||
status_callback=_status_callback_sync,
|
||||
platform=platform_key,
|
||||
honcho_session_key=session_key,
|
||||
honcho_manager=honcho_manager,
|
||||
honcho_config=honcho_config,
|
||||
session_db=self._session_db,
|
||||
fallback_model=self._fallback_model,
|
||||
|
||||
# Check agent cache — reuse the AIAgent from the previous message
|
||||
# in this session to preserve the frozen system prompt and tool
|
||||
# schemas for prompt cache hits.
|
||||
_sig = self._agent_config_signature(
|
||||
turn_route["model"],
|
||||
turn_route["runtime"],
|
||||
enabled_toolsets,
|
||||
combined_ephemeral,
|
||||
)
|
||||
agent = None
|
||||
_cache_lock = getattr(self, "_agent_cache_lock", None)
|
||||
_cache = getattr(self, "_agent_cache", None)
|
||||
if _cache_lock and _cache is not None:
|
||||
with _cache_lock:
|
||||
cached = _cache.get(session_key)
|
||||
if cached and cached[1] == _sig:
|
||||
agent = cached[0]
|
||||
logger.debug("Reusing cached agent for session %s", session_key)
|
||||
|
||||
if agent is None:
|
||||
# Config changed or first message — create fresh agent
|
||||
agent = AIAgent(
|
||||
model=turn_route["model"],
|
||||
**turn_route["runtime"],
|
||||
max_iterations=max_iterations,
|
||||
quiet_mode=True,
|
||||
verbose_logging=False,
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
ephemeral_system_prompt=combined_ephemeral or None,
|
||||
prefill_messages=self._prefill_messages or None,
|
||||
reasoning_config=reasoning_config,
|
||||
providers_allowed=pr.get("only"),
|
||||
providers_ignored=pr.get("ignore"),
|
||||
providers_order=pr.get("order"),
|
||||
provider_sort=pr.get("sort"),
|
||||
provider_require_parameters=pr.get("require_parameters", False),
|
||||
provider_data_collection=pr.get("data_collection"),
|
||||
session_id=session_id,
|
||||
platform=platform_key,
|
||||
honcho_session_key=session_key,
|
||||
honcho_manager=honcho_manager,
|
||||
honcho_config=honcho_config,
|
||||
session_db=self._session_db,
|
||||
fallback_model=self._fallback_model,
|
||||
)
|
||||
if _cache_lock and _cache is not None:
|
||||
with _cache_lock:
|
||||
_cache[session_key] = (agent, _sig)
|
||||
logger.debug("Created new agent for session %s (sig=%s)", session_key, _sig)
|
||||
|
||||
# Per-message state — callbacks and reasoning config change every
|
||||
# turn and must not be baked into the cached agent constructor.
|
||||
agent.tool_progress_callback = progress_callback if tool_progress_enabled else None
|
||||
agent.step_callback = _step_callback_sync if _hooks_ref.loaded_hooks else None
|
||||
agent.stream_delta_callback = _stream_delta_cb
|
||||
agent.status_callback = _status_callback_sync
|
||||
agent.reasoning_config = reasoning_config
|
||||
|
||||
# Store agent reference for interrupt support
|
||||
agent_holder[0] = agent
|
||||
@@ -4976,27 +5489,39 @@ class GatewayRunner:
|
||||
if _agent.model != _cfg_model:
|
||||
self._effective_model = _agent.model
|
||||
self._effective_provider = getattr(_agent, 'provider', None)
|
||||
# Fallback activated — evict cached agent so the next
|
||||
# message starts fresh and retries the primary model.
|
||||
self._evict_cached_agent(session_key)
|
||||
else:
|
||||
# Primary model worked — clear any stale fallback state
|
||||
self._effective_model = None
|
||||
self._effective_provider = None
|
||||
|
||||
# Check if we were interrupted and have a pending message
|
||||
# Check if we were interrupted OR have a queued message (/queue).
|
||||
result = result_holder[0]
|
||||
adapter = self.adapters.get(source.platform)
|
||||
|
||||
# Get pending message from adapter if interrupted.
|
||||
# Get pending message from adapter.
|
||||
# Use session_key (not source.chat_id) to match adapter's storage keys.
|
||||
pending = None
|
||||
if result and result.get("interrupted") and adapter:
|
||||
pending_event = adapter.get_pending_message(session_key) if session_key else None
|
||||
if pending_event:
|
||||
pending = pending_event.text
|
||||
elif result.get("interrupt_message"):
|
||||
pending = result.get("interrupt_message")
|
||||
if result and adapter and session_key:
|
||||
if result.get("interrupted"):
|
||||
# Interrupted — consume the interrupt message
|
||||
pending_event = adapter.get_pending_message(session_key)
|
||||
if pending_event:
|
||||
pending = pending_event.text
|
||||
elif result.get("interrupt_message"):
|
||||
pending = result.get("interrupt_message")
|
||||
else:
|
||||
# Normal completion — check for /queue'd messages that were
|
||||
# stored without triggering an interrupt.
|
||||
pending_event = adapter.get_pending_message(session_key)
|
||||
if pending_event:
|
||||
pending = pending_event.text
|
||||
logger.debug("Processing queued message after agent completion: '%s...'", pending[:40])
|
||||
|
||||
if pending:
|
||||
logger.debug("Processing interrupted message: '%s...'", pending[:40])
|
||||
logger.debug("Processing pending message: '%s...'", pending[:40])
|
||||
|
||||
# Clear the adapter's interrupt event so the next _run_agent call
|
||||
# doesn't immediately re-trigger the interrupt before the new agent
|
||||
@@ -5018,11 +5543,25 @@ class GatewayRunner:
|
||||
adapter.queue_message(session_key, pending)
|
||||
return result_holder[0] or {"final_response": response, "messages": history}
|
||||
|
||||
# Don't send the interrupted response to the user — it's just noise
|
||||
# like "Operation interrupted." They already know they sent a new
|
||||
# message, so go straight to processing it.
|
||||
|
||||
# Now process the pending message with updated history
|
||||
was_interrupted = result.get("interrupted")
|
||||
if not was_interrupted:
|
||||
# Queued message after normal completion — deliver the first
|
||||
# response before processing the queued follow-up.
|
||||
# Skip if streaming already delivered it.
|
||||
_sc = stream_consumer_holder[0]
|
||||
_already_streamed = _sc and getattr(_sc, "already_sent", False)
|
||||
first_response = result.get("final_response", "")
|
||||
if first_response and not _already_streamed:
|
||||
try:
|
||||
await adapter.send(source.chat_id, first_response,
|
||||
metadata=getattr(event, "metadata", None))
|
||||
except Exception as e:
|
||||
logger.warning("Failed to send first response before queued message: %s", e)
|
||||
# else: interrupted — discard the interrupted response ("Operation
|
||||
# interrupted." is just noise; the user already knows they sent a
|
||||
# new message).
|
||||
|
||||
# Process the pending message with updated history
|
||||
updated_history = result.get("messages", history)
|
||||
return await self._run_agent(
|
||||
message=pending,
|
||||
@@ -5180,6 +5719,16 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
||||
except (ProcessLookupError, PermissionError):
|
||||
pass
|
||||
remove_pid_file()
|
||||
# Also release all scoped locks left by the old process.
|
||||
# Stopped (Ctrl+Z) processes don't release locks on exit,
|
||||
# leaving stale lock files that block the new gateway from starting.
|
||||
try:
|
||||
from gateway.status import release_all_scoped_locks
|
||||
_released = release_all_scoped_locks()
|
||||
if _released:
|
||||
logger.info("Released %d stale scoped lock(s) from old gateway.", _released)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
hermes_home = os.getenv("HERMES_HOME", "~/.hermes")
|
||||
logger.error(
|
||||
@@ -5266,6 +5815,11 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
||||
|
||||
# Wait for shutdown
|
||||
await runner.wait_for_shutdown()
|
||||
|
||||
if runner.should_exit_with_failure:
|
||||
if runner.exit_reason:
|
||||
logger.error("Gateway exiting with failure: %s", runner.exit_reason)
|
||||
return False
|
||||
|
||||
# Stop cron ticker cleanly
|
||||
cron_stop.set()
|
||||
|
||||
+20
-7
@@ -355,6 +355,8 @@ class SessionEntry:
|
||||
# Set when a session was created because the previous one expired;
|
||||
# consumed once by the message handler to inject a notice into context
|
||||
was_auto_reset: bool = False
|
||||
auto_reset_reason: Optional[str] = None # "idle" or "daily"
|
||||
reset_had_activity: bool = False # whether the expired session had any messages
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result = {
|
||||
@@ -573,16 +575,19 @@ class SessionStore:
|
||||
|
||||
return False
|
||||
|
||||
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool:
|
||||
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> Optional[str]:
|
||||
"""
|
||||
Check if a session should be reset based on policy.
|
||||
|
||||
Returns the reset reason ("idle" or "daily") if a reset is needed,
|
||||
or None if the session is still valid.
|
||||
|
||||
Sessions with active background processes are never reset.
|
||||
"""
|
||||
if self._has_active_processes_fn:
|
||||
session_key = self._generate_session_key(source)
|
||||
if self._has_active_processes_fn(session_key):
|
||||
return False
|
||||
return None
|
||||
|
||||
policy = self.config.get_reset_policy(
|
||||
platform=source.platform,
|
||||
@@ -590,14 +595,14 @@ class SessionStore:
|
||||
)
|
||||
|
||||
if policy.mode == "none":
|
||||
return False
|
||||
return None
|
||||
|
||||
now = datetime.now()
|
||||
|
||||
if policy.mode in ("idle", "both"):
|
||||
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
|
||||
if now > idle_deadline:
|
||||
return True
|
||||
return "idle"
|
||||
|
||||
if policy.mode in ("daily", "both"):
|
||||
today_reset = now.replace(
|
||||
@@ -610,9 +615,9 @@ class SessionStore:
|
||||
today_reset -= timedelta(days=1)
|
||||
|
||||
if entry.updated_at < today_reset:
|
||||
return True
|
||||
return "daily"
|
||||
|
||||
return False
|
||||
return None
|
||||
|
||||
def has_any_sessions(self) -> bool:
|
||||
"""Check if any sessions have ever been created (across all platforms).
|
||||
@@ -654,7 +659,8 @@ class SessionStore:
|
||||
if session_key in self._entries and not force_new:
|
||||
entry = self._entries[session_key]
|
||||
|
||||
if not self._should_reset(entry, source):
|
||||
reset_reason = self._should_reset(entry, source)
|
||||
if not reset_reason:
|
||||
entry.updated_at = now
|
||||
self._save()
|
||||
return entry
|
||||
@@ -663,6 +669,9 @@ class SessionStore:
|
||||
# should have already flushed memories proactively; discard
|
||||
# the marker so it doesn't accumulate.
|
||||
was_auto_reset = True
|
||||
auto_reset_reason = reset_reason
|
||||
# Track whether the expired session had any real conversation
|
||||
reset_had_activity = entry.total_tokens > 0
|
||||
self._pre_flushed_sessions.discard(entry.session_id)
|
||||
if self._db:
|
||||
try:
|
||||
@@ -671,6 +680,8 @@ class SessionStore:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
else:
|
||||
was_auto_reset = False
|
||||
auto_reset_reason = None
|
||||
reset_had_activity = False
|
||||
|
||||
# Create new session
|
||||
session_id = f"{now.strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
|
||||
@@ -685,6 +696,8 @@ class SessionStore:
|
||||
platform=source.platform,
|
||||
chat_type=source.chat_type,
|
||||
was_auto_reset=was_auto_reset,
|
||||
auto_reset_reason=auto_reset_reason,
|
||||
reset_had_activity=reset_had_activity,
|
||||
)
|
||||
|
||||
self._entries[session_key] = entry
|
||||
|
||||
@@ -274,6 +274,21 @@ def acquire_scoped_lock(scope: str, identity: str, metadata: Optional[dict[str,
|
||||
and current_start != existing.get("start_time")
|
||||
):
|
||||
stale = True
|
||||
# Check if process is stopped (Ctrl+Z / SIGTSTP) — stopped
|
||||
# processes still respond to os.kill(pid, 0) but are not
|
||||
# actually running. Treat them as stale so --replace works.
|
||||
if not stale:
|
||||
try:
|
||||
_proc_status = Path(f"/proc/{existing_pid}/status")
|
||||
if _proc_status.exists():
|
||||
for _line in _proc_status.read_text().splitlines():
|
||||
if _line.startswith("State:"):
|
||||
_state = _line.split()[1]
|
||||
if _state in ("T", "t"): # stopped or tracing stop
|
||||
stale = True
|
||||
break
|
||||
except (OSError, PermissionError):
|
||||
pass
|
||||
if stale:
|
||||
try:
|
||||
lock_path.unlink(missing_ok=True)
|
||||
@@ -314,6 +329,25 @@ def release_scoped_lock(scope: str, identity: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
def release_all_scoped_locks() -> int:
|
||||
"""Remove all scoped lock files in the lock directory.
|
||||
|
||||
Called during --replace to clean up stale locks left by stopped/killed
|
||||
gateway processes that did not release their locks gracefully.
|
||||
Returns the number of lock files removed.
|
||||
"""
|
||||
lock_dir = _get_lock_dir()
|
||||
removed = 0
|
||||
if lock_dir.exists():
|
||||
for lock_file in lock_dir.glob("*.lock"):
|
||||
try:
|
||||
lock_file.unlink(missing_ok=True)
|
||||
removed += 1
|
||||
except OSError:
|
||||
pass
|
||||
return removed
|
||||
|
||||
|
||||
def get_running_pid() -> Optional[int]:
|
||||
"""Return the PID of a running gateway instance, or ``None``.
|
||||
|
||||
|
||||
@@ -12,4 +12,4 @@ Provides subcommands for:
|
||||
"""
|
||||
|
||||
__version__ = "0.4.0"
|
||||
__release_date__ = "2026.3.18"
|
||||
__release_date__ = "2026.3.23"
|
||||
|
||||
+32
-5
@@ -199,9 +199,9 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
"opencode-go": ProviderConfig(
|
||||
id="opencode-go",
|
||||
name="OpenCode Go",
|
||||
auth_type="***",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://opencode.ai/zen/go/v1",
|
||||
api_key_env_vars=("OPEN...",),
|
||||
api_key_env_vars=("OPENCODE_GO_API_KEY",),
|
||||
base_url_env_var="OPENCODE_GO_BASE_URL",
|
||||
),
|
||||
"kilocode": ProviderConfig(
|
||||
@@ -278,6 +278,33 @@ def _try_gh_cli_token() -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
_PLACEHOLDER_SECRET_VALUES = {
|
||||
"*",
|
||||
"**",
|
||||
"***",
|
||||
"changeme",
|
||||
"your_api_key",
|
||||
"your-api-key",
|
||||
"placeholder",
|
||||
"example",
|
||||
"dummy",
|
||||
"null",
|
||||
"none",
|
||||
}
|
||||
|
||||
|
||||
def has_usable_secret(value: Any, *, min_length: int = 4) -> bool:
|
||||
"""Return True when a configured secret looks usable, not empty/placeholder."""
|
||||
if not isinstance(value, str):
|
||||
return False
|
||||
cleaned = value.strip()
|
||||
if len(cleaned) < min_length:
|
||||
return False
|
||||
if cleaned.lower() in _PLACEHOLDER_SECRET_VALUES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _resolve_api_key_provider_secret(
|
||||
provider_id: str, pconfig: ProviderConfig
|
||||
) -> tuple[str, str]:
|
||||
@@ -297,7 +324,7 @@ def _resolve_api_key_provider_secret(
|
||||
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
val = os.getenv(env_var, "").strip()
|
||||
if val:
|
||||
if has_usable_secret(val):
|
||||
return val, env_var
|
||||
|
||||
return "", ""
|
||||
@@ -688,7 +715,7 @@ def resolve_provider(
|
||||
except Exception as e:
|
||||
logger.debug("Could not detect active auth provider: %s", e)
|
||||
|
||||
if os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY"):
|
||||
if has_usable_secret(os.getenv("OPENAI_API_KEY")) or has_usable_secret(os.getenv("OPENROUTER_API_KEY")):
|
||||
return "openrouter"
|
||||
|
||||
# Auto-detect API-key providers by checking their env vars
|
||||
@@ -701,7 +728,7 @@ def resolve_provider(
|
||||
if pid == "copilot":
|
||||
continue
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
if os.getenv(env_var, "").strip():
|
||||
if has_usable_secret(os.getenv(env_var, "")):
|
||||
return pid
|
||||
|
||||
return "openrouter"
|
||||
|
||||
+180
-1
@@ -137,7 +137,7 @@ COMMAND_REGISTRY: list[CommandDef] = [
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Derived lookups -- rebuilt once at import time
|
||||
# Derived lookups -- rebuilt once at import time, refreshed by rebuild_lookups()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _build_command_lookup() -> dict[str, CommandDef]:
|
||||
@@ -161,6 +161,58 @@ def resolve_command(name: str) -> CommandDef | None:
|
||||
return _COMMAND_LOOKUP.get(name.lower().lstrip("/"))
|
||||
|
||||
|
||||
def register_plugin_command(cmd: CommandDef) -> None:
|
||||
"""Append a plugin-defined command to the registry and refresh lookups."""
|
||||
COMMAND_REGISTRY.append(cmd)
|
||||
rebuild_lookups()
|
||||
|
||||
|
||||
def rebuild_lookups() -> None:
|
||||
"""Rebuild all derived lookup dicts from the current COMMAND_REGISTRY.
|
||||
|
||||
Called after plugin commands are registered so they appear in help,
|
||||
autocomplete, gateway dispatch, Telegram menu, and Slack mapping.
|
||||
"""
|
||||
global GATEWAY_KNOWN_COMMANDS
|
||||
|
||||
_COMMAND_LOOKUP.clear()
|
||||
_COMMAND_LOOKUP.update(_build_command_lookup())
|
||||
|
||||
COMMANDS.clear()
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
if not cmd.gateway_only:
|
||||
COMMANDS[f"/{cmd.name}"] = _build_description(cmd)
|
||||
for alias in cmd.aliases:
|
||||
COMMANDS[f"/{alias}"] = f"{cmd.description} (alias for /{cmd.name})"
|
||||
|
||||
COMMANDS_BY_CATEGORY.clear()
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
if not cmd.gateway_only:
|
||||
cat = COMMANDS_BY_CATEGORY.setdefault(cmd.category, {})
|
||||
cat[f"/{cmd.name}"] = COMMANDS[f"/{cmd.name}"]
|
||||
for alias in cmd.aliases:
|
||||
cat[f"/{alias}"] = COMMANDS[f"/{alias}"]
|
||||
|
||||
SUBCOMMANDS.clear()
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
if cmd.subcommands:
|
||||
SUBCOMMANDS[f"/{cmd.name}"] = list(cmd.subcommands)
|
||||
for cmd in COMMAND_REGISTRY:
|
||||
key = f"/{cmd.name}"
|
||||
if key in SUBCOMMANDS or not cmd.args_hint:
|
||||
continue
|
||||
m = _PIPE_SUBS_RE.search(cmd.args_hint)
|
||||
if m:
|
||||
SUBCOMMANDS[key] = m.group(0).split("|")
|
||||
|
||||
GATEWAY_KNOWN_COMMANDS = frozenset(
|
||||
name
|
||||
for cmd in COMMAND_REGISTRY
|
||||
if not cmd.cli_only
|
||||
for name in (cmd.name, *cmd.aliases)
|
||||
)
|
||||
|
||||
|
||||
def _build_description(cmd: CommandDef) -> str:
|
||||
"""Build a CLI-facing description string including usage hint."""
|
||||
if cmd.args_hint:
|
||||
@@ -397,9 +449,136 @@ class SlashCommandCompleter(Completer):
|
||||
)
|
||||
count += 1
|
||||
|
||||
@staticmethod
|
||||
def _extract_context_word(text: str) -> str | None:
|
||||
"""Extract a bare ``@`` token for context reference completions."""
|
||||
if not text:
|
||||
return None
|
||||
# Walk backwards to find the start of the current word
|
||||
i = len(text) - 1
|
||||
while i >= 0 and text[i] != " ":
|
||||
i -= 1
|
||||
word = text[i + 1:]
|
||||
if not word.startswith("@"):
|
||||
return None
|
||||
return word
|
||||
|
||||
@staticmethod
|
||||
def _context_completions(word: str, limit: int = 30):
|
||||
"""Yield Claude Code-style @ context completions.
|
||||
|
||||
Bare ``@`` or ``@partial`` shows static references and matching
|
||||
files/folders. ``@file:path`` and ``@folder:path`` are handled
|
||||
by the existing path completion path.
|
||||
"""
|
||||
lowered = word.lower()
|
||||
|
||||
# Static context references
|
||||
_STATIC_REFS = (
|
||||
("@diff", "Git working tree diff"),
|
||||
("@staged", "Git staged diff"),
|
||||
("@file:", "Attach a file"),
|
||||
("@folder:", "Attach a folder"),
|
||||
("@git:", "Git log with diffs (e.g. @git:5)"),
|
||||
("@url:", "Fetch web content"),
|
||||
)
|
||||
for candidate, meta in _STATIC_REFS:
|
||||
if candidate.lower().startswith(lowered) and candidate.lower() != lowered:
|
||||
yield Completion(
|
||||
candidate,
|
||||
start_position=-len(word),
|
||||
display=candidate,
|
||||
display_meta=meta,
|
||||
)
|
||||
|
||||
# If the user typed @file: or @folder:, delegate to path completions
|
||||
for prefix in ("@file:", "@folder:"):
|
||||
if word.startswith(prefix):
|
||||
path_part = word[len(prefix):] or "."
|
||||
expanded = os.path.expanduser(path_part)
|
||||
if expanded.endswith("/"):
|
||||
search_dir, match_prefix = expanded, ""
|
||||
else:
|
||||
search_dir = os.path.dirname(expanded) or "."
|
||||
match_prefix = os.path.basename(expanded)
|
||||
|
||||
try:
|
||||
entries = os.listdir(search_dir)
|
||||
except OSError:
|
||||
return
|
||||
|
||||
count = 0
|
||||
prefix_lower = match_prefix.lower()
|
||||
for entry in sorted(entries):
|
||||
if match_prefix and not entry.lower().startswith(prefix_lower):
|
||||
continue
|
||||
if count >= limit:
|
||||
break
|
||||
full_path = os.path.join(search_dir, entry)
|
||||
is_dir = os.path.isdir(full_path)
|
||||
display_path = os.path.relpath(full_path)
|
||||
suffix = "/" if is_dir else ""
|
||||
kind = "folder" if is_dir else "file"
|
||||
meta = "dir" if is_dir else _file_size_label(full_path)
|
||||
completion = f"@{kind}:{display_path}{suffix}"
|
||||
yield Completion(
|
||||
completion,
|
||||
start_position=-len(word),
|
||||
display=entry + suffix,
|
||||
display_meta=meta,
|
||||
)
|
||||
count += 1
|
||||
return
|
||||
|
||||
# Bare @ or @partial — show matching files/folders from cwd
|
||||
query = word[1:] # strip the @
|
||||
if not query:
|
||||
search_dir, match_prefix = ".", ""
|
||||
else:
|
||||
expanded = os.path.expanduser(query)
|
||||
if expanded.endswith("/"):
|
||||
search_dir, match_prefix = expanded, ""
|
||||
else:
|
||||
search_dir = os.path.dirname(expanded) or "."
|
||||
match_prefix = os.path.basename(expanded)
|
||||
|
||||
try:
|
||||
entries = os.listdir(search_dir)
|
||||
except OSError:
|
||||
return
|
||||
|
||||
count = 0
|
||||
prefix_lower = match_prefix.lower()
|
||||
for entry in sorted(entries):
|
||||
if match_prefix and not entry.lower().startswith(prefix_lower):
|
||||
continue
|
||||
if entry.startswith("."):
|
||||
continue # skip hidden files in bare @ mode
|
||||
if count >= limit:
|
||||
break
|
||||
full_path = os.path.join(search_dir, entry)
|
||||
is_dir = os.path.isdir(full_path)
|
||||
display_path = os.path.relpath(full_path)
|
||||
suffix = "/" if is_dir else ""
|
||||
kind = "folder" if is_dir else "file"
|
||||
meta = "dir" if is_dir else _file_size_label(full_path)
|
||||
completion = f"@{kind}:{display_path}{suffix}"
|
||||
yield Completion(
|
||||
completion,
|
||||
start_position=-len(word),
|
||||
display=entry + suffix,
|
||||
display_meta=meta,
|
||||
)
|
||||
count += 1
|
||||
|
||||
def get_completions(self, document, complete_event):
|
||||
text = document.text_before_cursor
|
||||
if not text.startswith("/"):
|
||||
# Try @ context completion (Claude Code-style)
|
||||
ctx_word = self._extract_context_word(text)
|
||||
if ctx_word is not None:
|
||||
yield from self._context_completions(ctx_word)
|
||||
return
|
||||
# Try file path completion for non-slash input
|
||||
path_word = self._extract_path_word(text)
|
||||
if path_word is not None:
|
||||
|
||||
+28
-6
@@ -159,7 +159,7 @@ DEFAULT_CONFIG = {
|
||||
"compression": {
|
||||
"enabled": True,
|
||||
"threshold": 0.50,
|
||||
"summary_model": "google/gemini-3-flash-preview",
|
||||
"summary_model": "", # empty = use main configured model
|
||||
"summary_provider": "auto",
|
||||
"summary_base_url": None,
|
||||
},
|
||||
@@ -182,6 +182,7 @@ DEFAULT_CONFIG = {
|
||||
"model": "", # e.g. "google/gemini-2.5-flash", "gpt-4o"
|
||||
"base_url": "", # direct OpenAI-compatible endpoint (takes precedence over provider)
|
||||
"api_key": "", # API key for base_url (falls back to OPENAI_API_KEY)
|
||||
"timeout": 30, # seconds — increase for slow local vision models
|
||||
},
|
||||
"web_extract": {
|
||||
"provider": "auto",
|
||||
@@ -1171,6 +1172,26 @@ def _deep_merge(base: dict, override: dict) -> dict:
|
||||
return result
|
||||
|
||||
|
||||
def _expand_env_vars(obj):
|
||||
"""Recursively expand ``${VAR}`` references in config values.
|
||||
|
||||
Only string values are processed; dict keys, numbers, booleans, and
|
||||
None are left untouched. Unresolved references (variable not in
|
||||
``os.environ``) are kept verbatim so callers can detect them.
|
||||
"""
|
||||
if isinstance(obj, str):
|
||||
return re.sub(
|
||||
r"\${([^}]+)}",
|
||||
lambda m: os.environ.get(m.group(1), m.group(0)),
|
||||
obj,
|
||||
)
|
||||
if isinstance(obj, dict):
|
||||
return {k: _expand_env_vars(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_expand_env_vars(item) for item in obj]
|
||||
return obj
|
||||
|
||||
|
||||
def _normalize_max_turns_config(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Normalize legacy root-level max_turns into agent.max_turns."""
|
||||
config = dict(config)
|
||||
@@ -1212,7 +1233,7 @@ def load_config() -> Dict[str, Any]:
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to load config: {e}")
|
||||
|
||||
return _normalize_max_turns_config(config)
|
||||
return _expand_env_vars(_normalize_max_turns_config(config))
|
||||
|
||||
|
||||
_SECURITY_COMMENT = """
|
||||
@@ -1625,11 +1646,11 @@ def show_config():
|
||||
print(f" Timeout: {terminal.get('timeout', 60)}s")
|
||||
|
||||
if terminal.get('backend') == 'docker':
|
||||
print(f" Docker image: {terminal.get('docker_image', 'python:3.11-slim')}")
|
||||
print(f" Docker image: {terminal.get('docker_image', 'nikolaik/python-nodejs:python3.11-nodejs20')}")
|
||||
elif terminal.get('backend') == 'singularity':
|
||||
print(f" Image: {terminal.get('singularity_image', 'docker://python:3.11')}")
|
||||
print(f" Image: {terminal.get('singularity_image', 'docker://nikolaik/python-nodejs:python3.11-nodejs20')}")
|
||||
elif terminal.get('backend') == 'modal':
|
||||
print(f" Modal image: {terminal.get('modal_image', 'python:3.11')}")
|
||||
print(f" Modal image: {terminal.get('modal_image', 'nikolaik/python-nodejs:python3.11-nodejs20')}")
|
||||
modal_token = get_env_value('MODAL_TOKEN_ID')
|
||||
print(f" Modal token: {'configured' if modal_token else '(not set)'}")
|
||||
elif terminal.get('backend') == 'daytona':
|
||||
@@ -1659,7 +1680,8 @@ def show_config():
|
||||
print(f" Enabled: {'yes' if enabled else 'no'}")
|
||||
if enabled:
|
||||
print(f" Threshold: {compression.get('threshold', 0.50) * 100:.0f}%")
|
||||
print(f" Model: {compression.get('summary_model', 'google/gemini-3-flash-preview')}")
|
||||
_sm = compression.get('summary_model', '') or '(main model)'
|
||||
print(f" Model: {_sm}")
|
||||
comp_provider = compression.get('summary_provider', 'auto')
|
||||
if comp_provider != 'auto':
|
||||
print(f" Provider: {comp_provider}")
|
||||
|
||||
@@ -717,13 +717,14 @@ def run_doctor(args):
|
||||
print(color("◆ Honcho Memory", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
try:
|
||||
from honcho_integration.client import HonchoClientConfig, GLOBAL_CONFIG_PATH
|
||||
from honcho_integration.client import HonchoClientConfig, resolve_config_path
|
||||
hcfg = HonchoClientConfig.from_global_config()
|
||||
_honcho_cfg_path = resolve_config_path()
|
||||
|
||||
if not GLOBAL_CONFIG_PATH.exists():
|
||||
if not _honcho_cfg_path.exists():
|
||||
check_warn("Honcho config not found", f"run: hermes honcho setup")
|
||||
elif not hcfg.enabled:
|
||||
check_info("Honcho disabled (set enabled: true in ~/.honcho/config.json to activate)")
|
||||
check_info(f"Honcho disabled (set enabled: true in {_honcho_cfg_path} to activate)")
|
||||
elif not hcfg.api_key:
|
||||
check_fail("Honcho API key not set", "run: hermes honcho setup")
|
||||
issues.append("No Honcho API key — run 'hermes honcho setup'")
|
||||
|
||||
+125
-5
@@ -2559,14 +2559,55 @@ def _restore_stashed_changes(
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if restore.returncode != 0:
|
||||
print("✗ Update pulled new code, but restoring local changes failed.")
|
||||
|
||||
# Check for unmerged (conflicted) files — can happen even when returncode is 0
|
||||
unmerged = subprocess.run(
|
||||
git_cmd + ["diff", "--name-only", "--diff-filter=U"],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
has_conflicts = bool(unmerged.stdout.strip())
|
||||
|
||||
if restore.returncode != 0 or has_conflicts:
|
||||
print("✗ Update pulled new code, but restoring local changes hit conflicts.")
|
||||
if restore.stdout.strip():
|
||||
print(restore.stdout.strip())
|
||||
if restore.stderr.strip():
|
||||
print(restore.stderr.strip())
|
||||
print("Your changes are still preserved in git stash.")
|
||||
print(f"Resolve manually with: git stash apply {stash_ref}")
|
||||
|
||||
# Show which files conflicted
|
||||
conflicted_files = unmerged.stdout.strip()
|
||||
if conflicted_files:
|
||||
print("\nConflicted files:")
|
||||
for f in conflicted_files.splitlines():
|
||||
print(f" • {f}")
|
||||
|
||||
print("\nYour stashed changes are preserved — nothing is lost.")
|
||||
print(f" Stash ref: {stash_ref}")
|
||||
|
||||
# Ask before resetting (if interactive)
|
||||
do_reset = True
|
||||
if prompt_user:
|
||||
print("\nReset working tree to clean state so Hermes can run?")
|
||||
print(" (You can re-apply your changes later with: git stash apply)")
|
||||
print("[Y/n] ", end="", flush=True)
|
||||
response = input().strip().lower()
|
||||
if response not in ("", "y", "yes"):
|
||||
do_reset = False
|
||||
|
||||
if do_reset:
|
||||
subprocess.run(
|
||||
git_cmd + ["reset", "--hard", "HEAD"],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
)
|
||||
print("Working tree reset to clean state.")
|
||||
else:
|
||||
print("Working tree left as-is (may have conflict markers).")
|
||||
print("Resolve conflicts manually, then run: git stash drop")
|
||||
|
||||
print(f"Restore your changes with: git stash apply {stash_ref}")
|
||||
sys.exit(1)
|
||||
|
||||
stash_selector = _resolve_stash_selector(git_cmd, cwd, stash_ref)
|
||||
@@ -2941,7 +2982,7 @@ def _coalesce_session_name_args(argv: list) -> list:
|
||||
_SUBCOMMANDS = {
|
||||
"chat", "model", "gateway", "setup", "whatsapp", "login", "logout",
|
||||
"status", "cron", "doctor", "config", "pairing", "skills", "tools",
|
||||
"sessions", "insights", "version", "update", "uninstall",
|
||||
"mcp", "sessions", "insights", "version", "update", "uninstall",
|
||||
}
|
||||
_SESSION_FLAGS = {"-c", "--continue", "-r", "--resume"}
|
||||
|
||||
@@ -3529,6 +3570,46 @@ For more help on a command:
|
||||
|
||||
skills_parser.set_defaults(func=cmd_skills)
|
||||
|
||||
# =========================================================================
|
||||
# plugins command
|
||||
# =========================================================================
|
||||
plugins_parser = subparsers.add_parser(
|
||||
"plugins",
|
||||
help="Manage plugins — install, update, remove, list",
|
||||
description="Install plugins from Git repositories, update, remove, or list them.",
|
||||
)
|
||||
plugins_subparsers = plugins_parser.add_subparsers(dest="plugins_action")
|
||||
|
||||
plugins_install = plugins_subparsers.add_parser(
|
||||
"install", help="Install a plugin from a Git URL or owner/repo"
|
||||
)
|
||||
plugins_install.add_argument(
|
||||
"identifier",
|
||||
help="Git URL or owner/repo shorthand (e.g. anpicasso/hermes-plugin-chrome-profiles)",
|
||||
)
|
||||
plugins_install.add_argument(
|
||||
"--force", "-f", action="store_true",
|
||||
help="Remove existing plugin and reinstall",
|
||||
)
|
||||
|
||||
plugins_update = plugins_subparsers.add_parser(
|
||||
"update", help="Pull latest changes for an installed plugin"
|
||||
)
|
||||
plugins_update.add_argument("name", help="Plugin name to update")
|
||||
|
||||
plugins_remove = plugins_subparsers.add_parser(
|
||||
"remove", aliases=["rm", "uninstall"], help="Remove an installed plugin"
|
||||
)
|
||||
plugins_remove.add_argument("name", help="Plugin directory name to remove")
|
||||
|
||||
plugins_subparsers.add_parser("list", aliases=["ls"], help="List installed plugins")
|
||||
|
||||
def cmd_plugins(args):
|
||||
from hermes_cli.plugins_cmd import plugins_command
|
||||
plugins_command(args)
|
||||
|
||||
plugins_parser.set_defaults(func=cmd_plugins)
|
||||
|
||||
# =========================================================================
|
||||
# honcho command
|
||||
# =========================================================================
|
||||
@@ -3685,6 +3766,45 @@ For more help on a command:
|
||||
tools_command(args)
|
||||
|
||||
tools_parser.set_defaults(func=cmd_tools)
|
||||
# =========================================================================
|
||||
# mcp command — manage MCP server connections
|
||||
# =========================================================================
|
||||
mcp_parser = subparsers.add_parser(
|
||||
"mcp",
|
||||
help="Manage MCP server connections",
|
||||
description=(
|
||||
"Add, remove, list, test, and configure MCP server connections.\n\n"
|
||||
"MCP servers provide additional tools via the Model Context Protocol.\n"
|
||||
"Use 'hermes mcp add' to connect to a new server with interactive\n"
|
||||
"tool discovery. Run 'hermes mcp' with no subcommand to list servers."
|
||||
),
|
||||
)
|
||||
mcp_sub = mcp_parser.add_subparsers(dest="mcp_action")
|
||||
|
||||
mcp_add_p = mcp_sub.add_parser("add", help="Add an MCP server (discovery-first install)")
|
||||
mcp_add_p.add_argument("name", help="Server name (used as config key)")
|
||||
mcp_add_p.add_argument("--url", help="HTTP/SSE endpoint URL")
|
||||
mcp_add_p.add_argument("--command", help="Stdio command (e.g. npx)")
|
||||
mcp_add_p.add_argument("--args", nargs="*", default=[], help="Arguments for stdio command")
|
||||
mcp_add_p.add_argument("--auth", choices=["oauth", "header"], help="Auth method")
|
||||
|
||||
mcp_rm_p = mcp_sub.add_parser("remove", aliases=["rm"], help="Remove an MCP server")
|
||||
mcp_rm_p.add_argument("name", help="Server name to remove")
|
||||
|
||||
mcp_sub.add_parser("list", aliases=["ls"], help="List configured MCP servers")
|
||||
|
||||
mcp_test_p = mcp_sub.add_parser("test", help="Test MCP server connection")
|
||||
mcp_test_p.add_argument("name", help="Server name to test")
|
||||
|
||||
mcp_cfg_p = mcp_sub.add_parser("configure", aliases=["config"], help="Toggle tool selection")
|
||||
mcp_cfg_p.add_argument("name", help="Server name to configure")
|
||||
|
||||
def cmd_mcp(args):
|
||||
from hermes_cli.mcp_config import mcp_command
|
||||
mcp_command(args)
|
||||
|
||||
mcp_parser.set_defaults(func=cmd_mcp)
|
||||
|
||||
# =========================================================================
|
||||
# sessions command
|
||||
# =========================================================================
|
||||
|
||||
@@ -0,0 +1,635 @@
|
||||
"""
|
||||
MCP Server Management CLI — ``hermes mcp`` subcommand.
|
||||
|
||||
Implements ``hermes mcp add/remove/list/test/configure`` for interactive
|
||||
MCP server lifecycle management (issue #690 Phase 2).
|
||||
|
||||
Relies on tools/mcp_tool.py for connection/discovery and keeps
|
||||
configuration in ~/.hermes/config.yaml under the ``mcp_servers`` key.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import getpass
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from hermes_cli.config import (
|
||||
load_config,
|
||||
save_config,
|
||||
get_env_value,
|
||||
save_env_value,
|
||||
get_hermes_home,
|
||||
)
|
||||
from hermes_cli.colors import Colors, color
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ─── UI Helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
def _info(text: str):
|
||||
print(color(f" {text}", Colors.DIM))
|
||||
|
||||
def _success(text: str):
|
||||
print(color(f" ✓ {text}", Colors.GREEN))
|
||||
|
||||
def _warning(text: str):
|
||||
print(color(f" ⚠ {text}", Colors.YELLOW))
|
||||
|
||||
def _error(text: str):
|
||||
print(color(f" ✗ {text}", Colors.RED))
|
||||
|
||||
|
||||
def _confirm(question: str, default: bool = True) -> bool:
|
||||
default_str = "Y/n" if default else "y/N"
|
||||
try:
|
||||
val = input(color(f" {question} [{default_str}]: ", Colors.YELLOW)).strip().lower()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default
|
||||
if not val:
|
||||
return default
|
||||
return val in ("y", "yes")
|
||||
|
||||
|
||||
def _prompt(question: str, *, password: bool = False, default: str = "") -> str:
|
||||
display = f" {question}"
|
||||
if default:
|
||||
display += f" [{default}]"
|
||||
display += ": "
|
||||
try:
|
||||
if password:
|
||||
value = getpass.getpass(color(display, Colors.YELLOW))
|
||||
else:
|
||||
value = input(color(display, Colors.YELLOW))
|
||||
return value.strip() or default
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return default
|
||||
|
||||
|
||||
# ─── Config Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _get_mcp_servers(config: Optional[dict] = None) -> Dict[str, dict]:
|
||||
"""Return the ``mcp_servers`` dict from config, or empty dict."""
|
||||
if config is None:
|
||||
config = load_config()
|
||||
servers = config.get("mcp_servers")
|
||||
if not servers or not isinstance(servers, dict):
|
||||
return {}
|
||||
return servers
|
||||
|
||||
|
||||
def _save_mcp_server(name: str, server_config: dict):
|
||||
"""Add or update a server entry in config.yaml."""
|
||||
config = load_config()
|
||||
config.setdefault("mcp_servers", {})[name] = server_config
|
||||
save_config(config)
|
||||
|
||||
|
||||
def _remove_mcp_server(name: str) -> bool:
|
||||
"""Remove a server from config.yaml. Returns True if it existed."""
|
||||
config = load_config()
|
||||
servers = config.get("mcp_servers", {})
|
||||
if name not in servers:
|
||||
return False
|
||||
del servers[name]
|
||||
if not servers:
|
||||
config.pop("mcp_servers", None)
|
||||
save_config(config)
|
||||
return True
|
||||
|
||||
|
||||
def _env_key_for_server(name: str) -> str:
|
||||
"""Convert server name to an env-var key like ``MCP_MYSERVER_API_KEY``."""
|
||||
return f"MCP_{name.upper().replace('-', '_')}_API_KEY"
|
||||
|
||||
|
||||
# ─── Discovery (temporary connect) ───────────────────────────────────────────
|
||||
|
||||
def _probe_single_server(
|
||||
name: str, config: dict, connect_timeout: float = 30
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""Temporarily connect to one MCP server, list its tools, disconnect.
|
||||
|
||||
Returns list of ``(tool_name, description)`` tuples.
|
||||
Raises on connection failure.
|
||||
"""
|
||||
from tools.mcp_tool import (
|
||||
_ensure_mcp_loop,
|
||||
_run_on_mcp_loop,
|
||||
_connect_server,
|
||||
_stop_mcp_loop,
|
||||
)
|
||||
|
||||
_ensure_mcp_loop()
|
||||
|
||||
tools_found: List[Tuple[str, str]] = []
|
||||
|
||||
async def _probe():
|
||||
server = await asyncio.wait_for(
|
||||
_connect_server(name, config), timeout=connect_timeout
|
||||
)
|
||||
for t in server._tools:
|
||||
desc = getattr(t, "description", "") or ""
|
||||
# Truncate long descriptions for display
|
||||
if len(desc) > 80:
|
||||
desc = desc[:77] + "..."
|
||||
tools_found.append((t.name, desc))
|
||||
await server.shutdown()
|
||||
|
||||
try:
|
||||
_run_on_mcp_loop(_probe(), timeout=connect_timeout + 10)
|
||||
except BaseException as exc:
|
||||
raise _unwrap_exception_group(exc) from None
|
||||
finally:
|
||||
_stop_mcp_loop()
|
||||
|
||||
return tools_found
|
||||
|
||||
|
||||
def _unwrap_exception_group(exc: BaseException) -> Exception:
|
||||
"""Extract the root-cause exception from anyio TaskGroup wrappers.
|
||||
|
||||
The MCP SDK uses anyio task groups, which wrap errors in
|
||||
``BaseExceptionGroup`` / ``ExceptionGroup``. This makes error
|
||||
messages opaque ("unhandled errors in a TaskGroup"). We unwrap
|
||||
to surface the real cause (e.g. "401 Unauthorized").
|
||||
"""
|
||||
while isinstance(exc, BaseExceptionGroup) and exc.exceptions:
|
||||
exc = exc.exceptions[0]
|
||||
# Return a plain Exception so callers can catch normally
|
||||
if isinstance(exc, Exception):
|
||||
return exc
|
||||
return RuntimeError(str(exc))
|
||||
|
||||
|
||||
# ─── hermes mcp add ──────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_add(args):
|
||||
"""Add a new MCP server with discovery-first tool selection."""
|
||||
name = args.name
|
||||
url = getattr(args, "url", None)
|
||||
command = getattr(args, "command", None)
|
||||
cmd_args = getattr(args, "args", None) or []
|
||||
auth_type = getattr(args, "auth", None)
|
||||
|
||||
# Validate transport
|
||||
if not url and not command:
|
||||
_error("Must specify --url <endpoint> or --command <cmd>")
|
||||
_info("Examples:")
|
||||
_info(' hermes mcp add ink --url "https://mcp.ml.ink/mcp"')
|
||||
_info(' hermes mcp add github --command npx --args @modelcontextprotocol/server-github')
|
||||
return
|
||||
|
||||
# Check if server already exists
|
||||
existing = _get_mcp_servers()
|
||||
if name in existing:
|
||||
if not _confirm(f"Server '{name}' already exists. Overwrite?", default=False):
|
||||
_info("Cancelled.")
|
||||
return
|
||||
|
||||
# Build initial config
|
||||
server_config: Dict[str, Any] = {}
|
||||
if url:
|
||||
server_config["url"] = url
|
||||
else:
|
||||
server_config["command"] = command
|
||||
if cmd_args:
|
||||
server_config["args"] = cmd_args
|
||||
|
||||
# ── Authentication ────────────────────────────────────────────────
|
||||
|
||||
if url and auth_type == "oauth":
|
||||
print()
|
||||
_info(f"Starting OAuth flow for '{name}'...")
|
||||
oauth_ok = False
|
||||
try:
|
||||
from tools.mcp_oauth import build_oauth_auth
|
||||
oauth_auth = build_oauth_auth(name, url)
|
||||
if oauth_auth:
|
||||
server_config["auth"] = "oauth"
|
||||
_success("OAuth configured (tokens will be acquired on first connection)")
|
||||
oauth_ok=True
|
||||
else:
|
||||
_warning("OAuth setup failed — MCP SDK auth module not available")
|
||||
except Exception as exc:
|
||||
_warning(f"OAuth error: {exc}")
|
||||
|
||||
if not oauth_ok:
|
||||
_info("This server may not support OAuth.")
|
||||
if _confirm("Continue without authentication?", default=True):
|
||||
# Don't store auth: oauth — server doesn't support it
|
||||
pass
|
||||
else:
|
||||
_info("Cancelled.")
|
||||
return
|
||||
|
||||
elif url:
|
||||
# Prompt for API key / Bearer token for HTTP servers
|
||||
print()
|
||||
_info(f"Connecting to {url}")
|
||||
needs_auth = _confirm("Does this server require authentication?", default=True)
|
||||
if needs_auth:
|
||||
if auth_type == "header" or not auth_type:
|
||||
env_key = _env_key_for_server(name)
|
||||
existing_key = get_env_value(env_key)
|
||||
if existing_key:
|
||||
_success(f"{env_key}: already configured")
|
||||
api_key = existing_key
|
||||
else:
|
||||
api_key = _prompt("API key / Bearer token", password=True)
|
||||
if api_key:
|
||||
save_env_value(env_key, api_key)
|
||||
_success(f"Saved to ~/.hermes/.env as {env_key}")
|
||||
|
||||
# Set header with env var interpolation
|
||||
if api_key or existing_key:
|
||||
server_config["headers"] = {
|
||||
"Authorization": f"Bearer ${{{env_key}}}"
|
||||
}
|
||||
|
||||
# ── Discovery: connect and list tools ─────────────────────────────
|
||||
|
||||
print()
|
||||
print(color(f" Connecting to '{name}'...", Colors.CYAN))
|
||||
|
||||
try:
|
||||
tools = _probe_single_server(name, server_config)
|
||||
except Exception as exc:
|
||||
_error(f"Failed to connect: {exc}")
|
||||
if _confirm("Save config anyway (you can test later)?", default=False):
|
||||
server_config["enabled"] = False
|
||||
_save_mcp_server(name, server_config)
|
||||
_success(f"Saved '{name}' to config (disabled)")
|
||||
_info("Fix the issue, then: hermes mcp test " + name)
|
||||
return
|
||||
|
||||
if not tools:
|
||||
_warning("Server connected but reported no tools.")
|
||||
if _confirm("Save config anyway?", default=True):
|
||||
_save_mcp_server(name, server_config)
|
||||
_success(f"Saved '{name}' to config")
|
||||
return
|
||||
|
||||
# ── Tool selection ────────────────────────────────────────────────
|
||||
|
||||
print()
|
||||
_success(f"Connected! Found {len(tools)} tool(s) from '{name}':")
|
||||
print()
|
||||
for tool_name, desc in tools:
|
||||
short = desc[:60] + "..." if len(desc) > 60 else desc
|
||||
print(f" {color(tool_name, Colors.GREEN):40s} {short}")
|
||||
print()
|
||||
|
||||
# Ask: enable all, select, or cancel
|
||||
try:
|
||||
choice = input(
|
||||
color(f" Enable all {len(tools)} tools? [Y/n/select]: ", Colors.YELLOW)
|
||||
).strip().lower()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
_info("Cancelled.")
|
||||
return
|
||||
|
||||
if choice in ("n", "no"):
|
||||
_info("Cancelled — server not saved.")
|
||||
return
|
||||
|
||||
if choice in ("s", "select"):
|
||||
# Interactive tool selection
|
||||
from hermes_cli.curses_ui import curses_checklist
|
||||
|
||||
labels = [f"{t[0]} — {t[1]}" for t in tools]
|
||||
pre_selected = set(range(len(tools)))
|
||||
|
||||
chosen = curses_checklist(
|
||||
f"Select tools for '{name}'",
|
||||
labels,
|
||||
pre_selected,
|
||||
)
|
||||
|
||||
if not chosen:
|
||||
_info("No tools selected — server not saved.")
|
||||
return
|
||||
|
||||
chosen_names = [tools[i][0] for i in sorted(chosen)]
|
||||
server_config.setdefault("tools", {})["include"] = chosen_names
|
||||
|
||||
tool_count = len(chosen_names)
|
||||
total = len(tools)
|
||||
else:
|
||||
# Enable all (no filter needed — default behaviour)
|
||||
tool_count = len(tools)
|
||||
total = len(tools)
|
||||
|
||||
# ── Save ──────────────────────────────────────────────────────────
|
||||
|
||||
server_config["enabled"] = True
|
||||
_save_mcp_server(name, server_config)
|
||||
|
||||
print()
|
||||
_success(f"Saved '{name}' to ~/.hermes/config.yaml ({tool_count}/{total} tools enabled)")
|
||||
_info("Start a new session to use these tools.")
|
||||
|
||||
|
||||
# ─── hermes mcp remove ───────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_remove(args):
|
||||
"""Remove an MCP server from config."""
|
||||
name = args.name
|
||||
existing = _get_mcp_servers()
|
||||
|
||||
if name not in existing:
|
||||
_error(f"Server '{name}' not found in config.")
|
||||
servers = list(existing.keys())
|
||||
if servers:
|
||||
_info(f"Available servers: {', '.join(servers)}")
|
||||
return
|
||||
|
||||
if not _confirm(f"Remove server '{name}'?", default=True):
|
||||
_info("Cancelled.")
|
||||
return
|
||||
|
||||
_remove_mcp_server(name)
|
||||
_success(f"Removed '{name}' from config")
|
||||
|
||||
# Clean up OAuth tokens if they exist
|
||||
try:
|
||||
from tools.mcp_oauth import remove_oauth_tokens
|
||||
remove_oauth_tokens(name)
|
||||
_success("Cleaned up OAuth tokens")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# ─── hermes mcp list ──────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_list(args=None):
|
||||
"""List all configured MCP servers."""
|
||||
servers = _get_mcp_servers()
|
||||
|
||||
if not servers:
|
||||
print()
|
||||
_info("No MCP servers configured.")
|
||||
print()
|
||||
_info("Add one with:")
|
||||
_info(' hermes mcp add <name> --url <endpoint>')
|
||||
_info(' hermes mcp add <name> --command <cmd> --args <args...>')
|
||||
print()
|
||||
return
|
||||
|
||||
print()
|
||||
print(color(" MCP Servers:", Colors.CYAN + Colors.BOLD))
|
||||
print()
|
||||
|
||||
# Table header
|
||||
print(f" {'Name':<16} {'Transport':<30} {'Tools':<12} {'Status':<10}")
|
||||
print(f" {'─' * 16} {'─' * 30} {'─' * 12} {'─' * 10}")
|
||||
|
||||
for name, cfg in servers.items():
|
||||
# Transport info
|
||||
if "url" in cfg:
|
||||
url = cfg["url"]
|
||||
# Truncate long URLs
|
||||
if len(url) > 28:
|
||||
url = url[:25] + "..."
|
||||
transport = url
|
||||
elif "command" in cfg:
|
||||
cmd = cfg["command"]
|
||||
cmd_args = cfg.get("args", [])
|
||||
if isinstance(cmd_args, list) and cmd_args:
|
||||
transport = f"{cmd} {' '.join(str(a) for a in cmd_args[:2])}"
|
||||
else:
|
||||
transport = cmd
|
||||
if len(transport) > 28:
|
||||
transport = transport[:25] + "..."
|
||||
else:
|
||||
transport = "?"
|
||||
|
||||
# Tool count
|
||||
tools_cfg = cfg.get("tools", {})
|
||||
if isinstance(tools_cfg, dict):
|
||||
include = tools_cfg.get("include")
|
||||
exclude = tools_cfg.get("exclude")
|
||||
if include and isinstance(include, list):
|
||||
tools_str = f"{len(include)} selected"
|
||||
elif exclude and isinstance(exclude, list):
|
||||
tools_str = f"-{len(exclude)} excluded"
|
||||
else:
|
||||
tools_str = "all"
|
||||
else:
|
||||
tools_str = "all"
|
||||
|
||||
# Enabled status
|
||||
enabled = cfg.get("enabled", True)
|
||||
if isinstance(enabled, str):
|
||||
enabled = enabled.lower() in ("true", "1", "yes")
|
||||
status = color("✓ enabled", Colors.GREEN) if enabled else color("✗ disabled", Colors.DIM)
|
||||
|
||||
print(f" {name:<16} {transport:<30} {tools_str:<12} {status}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
# ─── hermes mcp test ──────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_test(args):
|
||||
"""Test connection to an MCP server."""
|
||||
name = args.name
|
||||
servers = _get_mcp_servers()
|
||||
|
||||
if name not in servers:
|
||||
_error(f"Server '{name}' not found in config.")
|
||||
available = list(servers.keys())
|
||||
if available:
|
||||
_info(f"Available: {', '.join(available)}")
|
||||
return
|
||||
|
||||
cfg = servers[name]
|
||||
print()
|
||||
print(color(f" Testing '{name}'...", Colors.CYAN))
|
||||
|
||||
# Show transport info
|
||||
if "url" in cfg:
|
||||
_info(f"Transport: HTTP → {cfg['url']}")
|
||||
else:
|
||||
cmd = cfg.get("command", "?")
|
||||
_info(f"Transport: stdio → {cmd}")
|
||||
|
||||
# Show auth info (masked)
|
||||
auth_type = cfg.get("auth", "")
|
||||
headers = cfg.get("headers", {})
|
||||
if auth_type == "oauth":
|
||||
_info("Auth: OAuth 2.1 PKCE")
|
||||
elif headers:
|
||||
for k, v in headers.items():
|
||||
if isinstance(v, str) and ("key" in k.lower() or "auth" in k.lower()):
|
||||
# Mask the value
|
||||
resolved = _interpolate_value(v)
|
||||
if len(resolved) > 8:
|
||||
masked = resolved[:4] + "***" + resolved[-4:]
|
||||
else:
|
||||
masked = "***"
|
||||
print(f" {k}: {masked}")
|
||||
else:
|
||||
_info("Auth: none")
|
||||
|
||||
# Attempt connection
|
||||
start = time.monotonic()
|
||||
try:
|
||||
tools = _probe_single_server(name, cfg)
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
except Exception as exc:
|
||||
elapsed_ms = (time.monotonic() - start) * 1000
|
||||
_error(f"Connection failed ({elapsed_ms:.0f}ms): {exc}")
|
||||
return
|
||||
|
||||
_success(f"Connected ({elapsed_ms:.0f}ms)")
|
||||
_success(f"Tools discovered: {len(tools)}")
|
||||
|
||||
if tools:
|
||||
print()
|
||||
for tool_name, desc in tools:
|
||||
short = desc[:55] + "..." if len(desc) > 55 else desc
|
||||
print(f" {color(tool_name, Colors.GREEN):36s} {short}")
|
||||
print()
|
||||
|
||||
|
||||
def _interpolate_value(value: str) -> str:
|
||||
"""Resolve ``${ENV_VAR}`` references in a string."""
|
||||
def _replace(m):
|
||||
return os.getenv(m.group(1), "")
|
||||
return re.sub(r"\$\{(\w+)\}", _replace, value)
|
||||
|
||||
|
||||
# ─── hermes mcp configure ────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_configure(args):
|
||||
"""Reconfigure which tools are enabled for an existing MCP server."""
|
||||
name = args.name
|
||||
servers = _get_mcp_servers()
|
||||
|
||||
if name not in servers:
|
||||
_error(f"Server '{name}' not found in config.")
|
||||
available = list(servers.keys())
|
||||
if available:
|
||||
_info(f"Available: {', '.join(available)}")
|
||||
return
|
||||
|
||||
cfg = servers[name]
|
||||
|
||||
# Discover all available tools
|
||||
print()
|
||||
print(color(f" Connecting to '{name}' to discover tools...", Colors.CYAN))
|
||||
|
||||
try:
|
||||
all_tools = _probe_single_server(name, cfg)
|
||||
except Exception as exc:
|
||||
_error(f"Failed to connect: {exc}")
|
||||
return
|
||||
|
||||
if not all_tools:
|
||||
_warning("Server reports no tools.")
|
||||
return
|
||||
|
||||
# Determine which are currently enabled
|
||||
tools_cfg = cfg.get("tools", {})
|
||||
if isinstance(tools_cfg, dict):
|
||||
include = tools_cfg.get("include")
|
||||
exclude = tools_cfg.get("exclude")
|
||||
else:
|
||||
include = None
|
||||
exclude = None
|
||||
|
||||
tool_names = [t[0] for t in all_tools]
|
||||
|
||||
if include and isinstance(include, list):
|
||||
include_set = set(include)
|
||||
pre_selected = {
|
||||
i for i, tn in enumerate(tool_names) if tn in include_set
|
||||
}
|
||||
elif exclude and isinstance(exclude, list):
|
||||
exclude_set = set(exclude)
|
||||
pre_selected = {
|
||||
i for i, tn in enumerate(tool_names) if tn not in exclude_set
|
||||
}
|
||||
else:
|
||||
pre_selected = set(range(len(all_tools)))
|
||||
|
||||
currently = len(pre_selected)
|
||||
total = len(all_tools)
|
||||
_info(f"Currently {currently}/{total} tools enabled for '{name}'.")
|
||||
print()
|
||||
|
||||
# Interactive checklist
|
||||
from hermes_cli.curses_ui import curses_checklist
|
||||
|
||||
labels = [f"{t[0]} — {t[1]}" for t in all_tools]
|
||||
|
||||
chosen = curses_checklist(
|
||||
f"Select tools for '{name}'",
|
||||
labels,
|
||||
pre_selected,
|
||||
)
|
||||
|
||||
if chosen == pre_selected:
|
||||
_info("No changes made.")
|
||||
return
|
||||
|
||||
# Update config
|
||||
config = load_config()
|
||||
server_entry = config.get("mcp_servers", {}).get(name, {})
|
||||
|
||||
if len(chosen) == total:
|
||||
# All selected → remove include/exclude (register all)
|
||||
server_entry.pop("tools", None)
|
||||
else:
|
||||
chosen_names = [tool_names[i] for i in sorted(chosen)]
|
||||
server_entry.setdefault("tools", {})
|
||||
server_entry["tools"]["include"] = chosen_names
|
||||
server_entry["tools"].pop("exclude", None)
|
||||
|
||||
config.setdefault("mcp_servers", {})[name] = server_entry
|
||||
save_config(config)
|
||||
|
||||
new_count = len(chosen)
|
||||
_success(f"Updated config: {new_count}/{total} tools enabled")
|
||||
_info("Start a new session for changes to take effect.")
|
||||
|
||||
|
||||
# ─── Dispatcher ───────────────────────────────────────────────────────────────
|
||||
|
||||
def mcp_command(args):
|
||||
"""Main dispatcher for ``hermes mcp`` subcommands."""
|
||||
action = getattr(args, "mcp_action", None)
|
||||
|
||||
handlers = {
|
||||
"add": cmd_mcp_add,
|
||||
"remove": cmd_mcp_remove,
|
||||
"rm": cmd_mcp_remove,
|
||||
"list": cmd_mcp_list,
|
||||
"ls": cmd_mcp_list,
|
||||
"test": cmd_mcp_test,
|
||||
"configure": cmd_mcp_configure,
|
||||
"config": cmd_mcp_configure,
|
||||
}
|
||||
|
||||
handler = handlers.get(action)
|
||||
if handler:
|
||||
handler(args)
|
||||
else:
|
||||
# No subcommand — show list
|
||||
cmd_mcp_list()
|
||||
print(color(" Commands:", Colors.CYAN))
|
||||
_info("hermes mcp add <name> --url <endpoint> Add an MCP server")
|
||||
_info("hermes mcp add <name> --command <cmd> Add a stdio server")
|
||||
_info("hermes mcp remove <name> Remove a server")
|
||||
_info("hermes mcp list List servers")
|
||||
_info("hermes mcp test <name> Test connection")
|
||||
_info("hermes mcp configure <name> Toggle tools")
|
||||
print()
|
||||
+11
-6
@@ -31,19 +31,20 @@ OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
("anthropic/claude-haiku-4.5", ""),
|
||||
("openai/gpt-5.4", ""),
|
||||
("openai/gpt-5.4-mini", ""),
|
||||
("openrouter/hunter-alpha", "free"),
|
||||
("openrouter/healer-alpha", "free"),
|
||||
("xiaomi/mimo-v2-pro", ""),
|
||||
("openai/gpt-5.3-codex", ""),
|
||||
("google/gemini-3-pro-preview", ""),
|
||||
("google/gemini-3-flash-preview", ""),
|
||||
("qwen/qwen3.5-plus-02-15", ""),
|
||||
("qwen/qwen3.5-35b-a3b", ""),
|
||||
("stepfun/step-3.5-flash", ""),
|
||||
("minimax/minimax-m2.7", ""),
|
||||
("minimax/minimax-m2.5", ""),
|
||||
("z-ai/glm-5", ""),
|
||||
("z-ai/glm-5-turbo", ""),
|
||||
("moonshotai/kimi-k2.5", ""),
|
||||
("x-ai/grok-4.20-beta", ""),
|
||||
("nvidia/nemotron-3-super-120b-a12b", ""),
|
||||
("nvidia/nemotron-3-super-120b-a12b:free", "free"),
|
||||
("arcee-ai/trinity-large-preview:free", "free"),
|
||||
("openai/gpt-5.4-pro", ""),
|
||||
@@ -150,6 +151,7 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"gemini-3.1-pro",
|
||||
"gemini-3-pro",
|
||||
"gemini-3-flash",
|
||||
"minimax-m2.7",
|
||||
"minimax-m2.5",
|
||||
"minimax-m2.5-free",
|
||||
"minimax-m2.1",
|
||||
@@ -300,12 +302,15 @@ def list_available_providers() -> list[dict[str, str]]:
|
||||
# Check if this provider has credentials available
|
||||
has_creds = False
|
||||
try:
|
||||
from hermes_cli.auth import get_auth_status, has_usable_secret
|
||||
if pid == "custom":
|
||||
has_creds = bool(_get_custom_base_url())
|
||||
custom_base_url = _get_custom_base_url() or os.getenv("OPENAI_BASE_URL", "")
|
||||
has_creds = bool(custom_base_url.strip())
|
||||
elif pid == "openrouter":
|
||||
has_creds = has_usable_secret(os.getenv("OPENROUTER_API_KEY", ""))
|
||||
else:
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
runtime = resolve_runtime_provider(requested=pid)
|
||||
has_creds = bool(runtime.get("api_key"))
|
||||
status = get_auth_status(pid)
|
||||
has_creds = bool(status.get("logged_in") or status.get("configured"))
|
||||
except Exception:
|
||||
pass
|
||||
result.append({
|
||||
|
||||
@@ -454,3 +454,48 @@ def invoke_hook(hook_name: str, **kwargs: Any) -> None:
|
||||
def get_plugin_tool_names() -> Set[str]:
|
||||
"""Return the set of tool names registered by plugins."""
|
||||
return get_plugin_manager()._plugin_tool_names
|
||||
|
||||
|
||||
def get_plugin_toolsets() -> List[tuple]:
|
||||
"""Return plugin toolsets as ``(key, label, description)`` tuples.
|
||||
|
||||
Used by the ``hermes tools`` TUI so plugin-provided toolsets appear
|
||||
alongside the built-in ones and can be toggled on/off per platform.
|
||||
"""
|
||||
manager = get_plugin_manager()
|
||||
if not manager._plugin_tool_names:
|
||||
return []
|
||||
|
||||
try:
|
||||
from tools.registry import registry
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
# Group plugin tool names by their toolset
|
||||
toolset_tools: Dict[str, List[str]] = {}
|
||||
toolset_plugin: Dict[str, LoadedPlugin] = {}
|
||||
for tool_name in manager._plugin_tool_names:
|
||||
entry = registry._tools.get(tool_name)
|
||||
if not entry:
|
||||
continue
|
||||
ts = entry.toolset
|
||||
toolset_tools.setdefault(ts, []).append(entry.name)
|
||||
|
||||
# Map toolsets back to the plugin that registered them
|
||||
for _name, loaded in manager._plugins.items():
|
||||
for tool_name in loaded.tools_registered:
|
||||
entry = registry._tools.get(tool_name)
|
||||
if entry and entry.toolset in toolset_tools:
|
||||
toolset_plugin.setdefault(entry.toolset, loaded)
|
||||
|
||||
result = []
|
||||
for ts_key in sorted(toolset_tools):
|
||||
plugin = toolset_plugin.get(ts_key)
|
||||
label = f"🔌 {ts_key.replace('_', ' ').title()}"
|
||||
if plugin and plugin.manifest.description:
|
||||
desc = plugin.manifest.description
|
||||
else:
|
||||
desc = ", ".join(sorted(toolset_tools[ts_key]))
|
||||
result.append((ts_key, label, desc))
|
||||
|
||||
return result
|
||||
|
||||
@@ -0,0 +1,446 @@
|
||||
"""``hermes plugins`` CLI subcommand — install, update, remove, and list plugins.
|
||||
|
||||
Plugins are installed from Git repositories into ``~/.hermes/plugins/``.
|
||||
Supports full URLs and ``owner/repo`` shorthand (resolves to GitHub).
|
||||
|
||||
After install, if the plugin ships an ``after-install.md`` file it is
|
||||
rendered with Rich Markdown. Otherwise a default confirmation is shown.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Minimum manifest version this installer understands.
|
||||
# Plugins may declare ``manifest_version: 1`` in plugin.yaml;
|
||||
# future breaking changes to the manifest schema bump this.
|
||||
_SUPPORTED_MANIFEST_VERSION = 1
|
||||
|
||||
|
||||
def _plugins_dir() -> Path:
|
||||
"""Return the user plugins directory, creating it if needed."""
|
||||
hermes_home = os.environ.get("HERMES_HOME", os.path.expanduser("~/.hermes"))
|
||||
plugins = Path(hermes_home) / "plugins"
|
||||
plugins.mkdir(parents=True, exist_ok=True)
|
||||
return plugins
|
||||
|
||||
|
||||
def _sanitize_plugin_name(name: str, plugins_dir: Path) -> Path:
|
||||
"""Validate a plugin name and return the safe target path inside *plugins_dir*.
|
||||
|
||||
Raises ``ValueError`` if the name contains path-traversal sequences or would
|
||||
resolve outside the plugins directory.
|
||||
"""
|
||||
if not name:
|
||||
raise ValueError("Plugin name must not be empty.")
|
||||
|
||||
# Reject obvious traversal characters
|
||||
for bad in ("/", "\\", ".."):
|
||||
if bad in name:
|
||||
raise ValueError(f"Invalid plugin name '{name}': must not contain '{bad}'.")
|
||||
|
||||
target = (plugins_dir / name).resolve()
|
||||
plugins_resolved = plugins_dir.resolve()
|
||||
|
||||
if (
|
||||
not str(target).startswith(str(plugins_resolved) + os.sep)
|
||||
and target != plugins_resolved
|
||||
):
|
||||
raise ValueError(
|
||||
f"Invalid plugin name '{name}': resolves outside the plugins directory."
|
||||
)
|
||||
|
||||
return target
|
||||
|
||||
|
||||
def _resolve_git_url(identifier: str) -> str:
|
||||
"""Turn an identifier into a cloneable Git URL.
|
||||
|
||||
Accepted formats:
|
||||
- Full URL: https://github.com/owner/repo.git
|
||||
- Full URL: git@github.com:owner/repo.git
|
||||
- Full URL: ssh://git@github.com/owner/repo.git
|
||||
- Shorthand: owner/repo → https://github.com/owner/repo.git
|
||||
|
||||
NOTE: ``http://`` and ``file://`` schemes are accepted but will trigger a
|
||||
security warning at install time.
|
||||
"""
|
||||
# Already a URL
|
||||
if identifier.startswith(("https://", "http://", "git@", "ssh://", "file://")):
|
||||
return identifier
|
||||
|
||||
# owner/repo shorthand
|
||||
parts = identifier.strip("/").split("/")
|
||||
if len(parts) == 2:
|
||||
owner, repo = parts
|
||||
return f"https://github.com/{owner}/{repo}.git"
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid plugin identifier: '{identifier}'. "
|
||||
"Use a Git URL or owner/repo shorthand."
|
||||
)
|
||||
|
||||
|
||||
def _repo_name_from_url(url: str) -> str:
|
||||
"""Extract the repo name from a Git URL for the plugin directory name."""
|
||||
# Strip trailing .git and slashes
|
||||
name = url.rstrip("/")
|
||||
if name.endswith(".git"):
|
||||
name = name[:-4]
|
||||
# Get last path component
|
||||
name = name.rsplit("/", 1)[-1]
|
||||
# Handle ssh-style urls: git@github.com:owner/repo
|
||||
if ":" in name:
|
||||
name = name.rsplit(":", 1)[-1].rsplit("/", 1)[-1]
|
||||
return name
|
||||
|
||||
|
||||
def _read_manifest(plugin_dir: Path) -> dict:
|
||||
"""Read plugin.yaml and return the parsed dict, or empty dict."""
|
||||
manifest_file = plugin_dir / "plugin.yaml"
|
||||
if not manifest_file.exists():
|
||||
return {}
|
||||
try:
|
||||
import yaml
|
||||
|
||||
with open(manifest_file) as f:
|
||||
return yaml.safe_load(f) or {}
|
||||
except Exception as e:
|
||||
logger.warning("Failed to read plugin.yaml in %s: %s", plugin_dir, e)
|
||||
return {}
|
||||
|
||||
|
||||
def _copy_example_files(plugin_dir: Path, console) -> None:
|
||||
"""Copy any .example files to their real names if they don't already exist.
|
||||
|
||||
For example, ``config.yaml.example`` becomes ``config.yaml``.
|
||||
Skips files that already exist to avoid overwriting user config on reinstall.
|
||||
"""
|
||||
for example_file in plugin_dir.glob("*.example"):
|
||||
real_name = example_file.stem # e.g. "config.yaml" from "config.yaml.example"
|
||||
real_path = plugin_dir / real_name
|
||||
if not real_path.exists():
|
||||
try:
|
||||
shutil.copy2(example_file, real_path)
|
||||
console.print(
|
||||
f"[dim] Created {real_name} from {example_file.name}[/dim]"
|
||||
)
|
||||
except OSError as e:
|
||||
console.print(
|
||||
f"[yellow]Warning:[/yellow] Failed to copy {example_file.name}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def _display_after_install(plugin_dir: Path, identifier: str) -> None:
|
||||
"""Show after-install.md if it exists, otherwise a default message."""
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
|
||||
console = Console()
|
||||
after_install = plugin_dir / "after-install.md"
|
||||
|
||||
if after_install.exists():
|
||||
content = after_install.read_text(encoding="utf-8")
|
||||
md = Markdown(content)
|
||||
console.print()
|
||||
console.print(Panel(md, border_style="green", expand=False))
|
||||
console.print()
|
||||
else:
|
||||
console.print()
|
||||
console.print(
|
||||
Panel(
|
||||
f"[green bold]Plugin installed:[/] {identifier}\n"
|
||||
f"[dim]Location:[/] {plugin_dir}",
|
||||
border_style="green",
|
||||
title="✓ Installed",
|
||||
expand=False,
|
||||
)
|
||||
)
|
||||
console.print()
|
||||
|
||||
|
||||
def _display_removed(name: str, plugins_dir: Path) -> None:
|
||||
"""Show confirmation after removing a plugin."""
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
console.print()
|
||||
console.print(f"[red]✗[/red] Plugin [bold]{name}[/bold] removed from {plugins_dir}")
|
||||
console.print()
|
||||
|
||||
|
||||
def _require_installed_plugin(name: str, plugins_dir: Path, console) -> Path:
|
||||
"""Return the plugin path if it exists, or exit with an error listing installed plugins."""
|
||||
target = _sanitize_plugin_name(name, plugins_dir)
|
||||
if not target.exists():
|
||||
installed = ", ".join(d.name for d in plugins_dir.iterdir() if d.is_dir()) or "(none)"
|
||||
console.print(
|
||||
f"[red]Error:[/red] Plugin '{name}' not found in {plugins_dir}.\n"
|
||||
f"Installed plugins: {installed}"
|
||||
)
|
||||
sys.exit(1)
|
||||
return target
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Commands
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def cmd_install(identifier: str, force: bool = False) -> None:
|
||||
"""Install a plugin from a Git URL or owner/repo shorthand."""
|
||||
import tempfile
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
|
||||
try:
|
||||
git_url = _resolve_git_url(identifier)
|
||||
except ValueError as e:
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Warn about insecure / local URL schemes
|
||||
if git_url.startswith("http://") or git_url.startswith("file://"):
|
||||
console.print(
|
||||
"[yellow]Warning:[/yellow] Using insecure/local URL scheme. "
|
||||
"Consider using https:// or git@ for production installs."
|
||||
)
|
||||
|
||||
plugins_dir = _plugins_dir()
|
||||
|
||||
# Clone into a temp directory first so we can read plugin.yaml for the name
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
tmp_target = Path(tmp) / "plugin"
|
||||
console.print(f"[dim]Cloning {git_url}...[/dim]")
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "clone", "--depth", "1", git_url, str(tmp_target)],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
console.print("[red]Error:[/red] git is not installed or not in PATH.")
|
||||
sys.exit(1)
|
||||
except subprocess.TimeoutExpired:
|
||||
console.print("[red]Error:[/red] Git clone timed out after 60 seconds.")
|
||||
sys.exit(1)
|
||||
|
||||
if result.returncode != 0:
|
||||
console.print(
|
||||
f"[red]Error:[/red] Git clone failed:\n{result.stderr.strip()}"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Read manifest
|
||||
manifest = _read_manifest(tmp_target)
|
||||
plugin_name = manifest.get("name") or _repo_name_from_url(git_url)
|
||||
|
||||
# Sanitize plugin name against path traversal
|
||||
try:
|
||||
target = _sanitize_plugin_name(plugin_name, plugins_dir)
|
||||
except ValueError as e:
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Check manifest_version compatibility
|
||||
mv = manifest.get("manifest_version")
|
||||
if mv is not None:
|
||||
try:
|
||||
mv_int = int(mv)
|
||||
except (ValueError, TypeError):
|
||||
console.print(
|
||||
f"[red]Error:[/red] Plugin '{plugin_name}' has invalid "
|
||||
f"manifest_version '{mv}' (expected an integer)."
|
||||
)
|
||||
sys.exit(1)
|
||||
if mv_int > _SUPPORTED_MANIFEST_VERSION:
|
||||
console.print(
|
||||
f"[red]Error:[/red] Plugin '{plugin_name}' requires manifest_version "
|
||||
f"{mv}, but this installer only supports up to {_SUPPORTED_MANIFEST_VERSION}.\n"
|
||||
f"Run [bold]hermes update[/bold] to get a newer installer."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if target.exists():
|
||||
if not force:
|
||||
console.print(
|
||||
f"[red]Error:[/red] Plugin '{plugin_name}' already exists at {target}.\n"
|
||||
f"Use [bold]--force[/bold] to remove and reinstall, or "
|
||||
f"[bold]hermes plugins update {plugin_name}[/bold] to pull latest."
|
||||
)
|
||||
sys.exit(1)
|
||||
console.print(f"[dim] Removing existing {plugin_name}...[/dim]")
|
||||
shutil.rmtree(target)
|
||||
|
||||
# Move from temp to final location
|
||||
shutil.move(str(tmp_target), str(target))
|
||||
|
||||
# Validate it looks like a plugin
|
||||
if not (target / "plugin.yaml").exists() and not (target / "__init__.py").exists():
|
||||
console.print(
|
||||
f"[yellow]Warning:[/yellow] {plugin_name} doesn't contain plugin.yaml "
|
||||
f"or __init__.py. It may not be a valid Hermes plugin."
|
||||
)
|
||||
|
||||
# Copy .example files to their real names (e.g. config.yaml.example → config.yaml)
|
||||
_copy_example_files(target, console)
|
||||
|
||||
_display_after_install(target, identifier)
|
||||
|
||||
console.print("[dim]Restart the gateway for the plugin to take effect:[/dim]")
|
||||
console.print("[dim] hermes gateway restart[/dim]")
|
||||
console.print()
|
||||
|
||||
|
||||
def cmd_update(name: str) -> None:
|
||||
"""Update an installed plugin by pulling latest from its git remote."""
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
plugins_dir = _plugins_dir()
|
||||
|
||||
try:
|
||||
target = _require_installed_plugin(name, plugins_dir, console)
|
||||
except ValueError as e:
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
sys.exit(1)
|
||||
|
||||
if not (target / ".git").exists():
|
||||
console.print(
|
||||
f"[red]Error:[/red] Plugin '{name}' was not installed from git "
|
||||
f"(no .git directory). Cannot update."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
console.print(f"[dim]Updating {name}...[/dim]")
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "pull", "--ff-only"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=60,
|
||||
cwd=str(target),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
console.print("[red]Error:[/red] git is not installed or not in PATH.")
|
||||
sys.exit(1)
|
||||
except subprocess.TimeoutExpired:
|
||||
console.print("[red]Error:[/red] Git pull timed out after 60 seconds.")
|
||||
sys.exit(1)
|
||||
|
||||
if result.returncode != 0:
|
||||
console.print(f"[red]Error:[/red] Git pull failed:\n{result.stderr.strip()}")
|
||||
sys.exit(1)
|
||||
|
||||
# Copy any new .example files
|
||||
_copy_example_files(target, console)
|
||||
|
||||
output = result.stdout.strip()
|
||||
if "Already up to date" in output:
|
||||
console.print(
|
||||
f"[green]✓[/green] Plugin [bold]{name}[/bold] is already up to date."
|
||||
)
|
||||
else:
|
||||
console.print(f"[green]✓[/green] Plugin [bold]{name}[/bold] updated.")
|
||||
console.print(f"[dim]{output}[/dim]")
|
||||
|
||||
|
||||
def cmd_remove(name: str) -> None:
|
||||
"""Remove an installed plugin by name."""
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
plugins_dir = _plugins_dir()
|
||||
|
||||
try:
|
||||
target = _require_installed_plugin(name, plugins_dir, console)
|
||||
except ValueError as e:
|
||||
console.print(f"[red]Error:[/red] {e}")
|
||||
sys.exit(1)
|
||||
|
||||
shutil.rmtree(target)
|
||||
_display_removed(name, plugins_dir)
|
||||
|
||||
|
||||
def cmd_list() -> None:
|
||||
"""List installed plugins."""
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ImportError:
|
||||
yaml = None
|
||||
|
||||
console = Console()
|
||||
plugins_dir = _plugins_dir()
|
||||
|
||||
dirs = sorted(d for d in plugins_dir.iterdir() if d.is_dir())
|
||||
if not dirs:
|
||||
console.print("[dim]No plugins installed.[/dim]")
|
||||
console.print(f"[dim]Install with:[/dim] hermes plugins install owner/repo")
|
||||
return
|
||||
|
||||
table = Table(title="Installed Plugins", show_lines=False)
|
||||
table.add_column("Name", style="bold")
|
||||
table.add_column("Version", style="dim")
|
||||
table.add_column("Description")
|
||||
table.add_column("Source", style="dim")
|
||||
|
||||
for d in dirs:
|
||||
manifest_file = d / "plugin.yaml"
|
||||
name = d.name
|
||||
version = ""
|
||||
description = ""
|
||||
source = "local"
|
||||
|
||||
if manifest_file.exists() and yaml:
|
||||
try:
|
||||
with open(manifest_file) as f:
|
||||
manifest = yaml.safe_load(f) or {}
|
||||
name = manifest.get("name", d.name)
|
||||
version = manifest.get("version", "")
|
||||
description = manifest.get("description", "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check if it's a git repo (installed via hermes plugins install)
|
||||
if (d / ".git").exists():
|
||||
source = "git"
|
||||
|
||||
table.add_row(name, str(version), description, source)
|
||||
|
||||
console.print()
|
||||
console.print(table)
|
||||
console.print()
|
||||
|
||||
|
||||
def plugins_command(args) -> None:
|
||||
"""Dispatch hermes plugins subcommands."""
|
||||
action = getattr(args, "plugins_action", None)
|
||||
|
||||
if action == "install":
|
||||
cmd_install(args.identifier, force=getattr(args, "force", False))
|
||||
elif action == "update":
|
||||
cmd_update(args.name)
|
||||
elif action in ("remove", "rm", "uninstall"):
|
||||
cmd_remove(args.name)
|
||||
elif action in ("list", "ls") or action is None:
|
||||
cmd_list()
|
||||
else:
|
||||
from rich.console import Console
|
||||
|
||||
Console().print(f"[red]Unknown plugins action: {action}[/red]")
|
||||
sys.exit(1)
|
||||
@@ -15,6 +15,7 @@ from hermes_cli.auth import (
|
||||
resolve_codex_runtime_credentials,
|
||||
resolve_api_key_provider_credentials,
|
||||
resolve_external_process_provider_credentials,
|
||||
has_usable_secret,
|
||||
)
|
||||
from hermes_cli.config import load_config
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
@@ -188,12 +189,13 @@ def _resolve_named_custom_runtime(
|
||||
if not base_url:
|
||||
return None
|
||||
|
||||
api_key = (
|
||||
(explicit_api_key or "").strip()
|
||||
or custom_provider.get("api_key", "")
|
||||
or os.getenv("OPENAI_API_KEY", "").strip()
|
||||
or os.getenv("OPENROUTER_API_KEY", "").strip()
|
||||
)
|
||||
api_key_candidates = [
|
||||
(explicit_api_key or "").strip(),
|
||||
str(custom_provider.get("api_key", "") or "").strip(),
|
||||
os.getenv("OPENAI_API_KEY", "").strip(),
|
||||
os.getenv("OPENROUTER_API_KEY", "").strip(),
|
||||
]
|
||||
api_key = next((candidate for candidate in api_key_candidates if has_usable_secret(candidate)), "")
|
||||
|
||||
return {
|
||||
"provider": "openrouter",
|
||||
@@ -257,21 +259,23 @@ def _resolve_openrouter_runtime(
|
||||
# provider (issues #420, #560).
|
||||
_is_openrouter_url = "openrouter.ai" in base_url
|
||||
if _is_openrouter_url:
|
||||
api_key = (
|
||||
explicit_api_key
|
||||
or os.getenv("OPENROUTER_API_KEY")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or ""
|
||||
)
|
||||
api_key_candidates = [
|
||||
explicit_api_key,
|
||||
os.getenv("OPENROUTER_API_KEY"),
|
||||
os.getenv("OPENAI_API_KEY"),
|
||||
]
|
||||
else:
|
||||
# Custom endpoint: use api_key from config when using config base_url (#1760).
|
||||
api_key = (
|
||||
explicit_api_key
|
||||
or (cfg_api_key if use_config_base_url else "")
|
||||
or os.getenv("OPENAI_API_KEY")
|
||||
or os.getenv("OPENROUTER_API_KEY")
|
||||
or ""
|
||||
)
|
||||
api_key_candidates = [
|
||||
explicit_api_key,
|
||||
(cfg_api_key if use_config_base_url else ""),
|
||||
os.getenv("OPENAI_API_KEY"),
|
||||
os.getenv("OPENROUTER_API_KEY"),
|
||||
]
|
||||
api_key = next(
|
||||
(str(candidate or "").strip() for candidate in api_key_candidates if has_usable_secret(candidate)),
|
||||
"",
|
||||
)
|
||||
|
||||
source = "explicit" if (explicit_api_key or explicit_base_url) else "env/config"
|
||||
|
||||
@@ -359,9 +363,14 @@ def resolve_runtime_provider(
|
||||
"No Anthropic credentials found. Set ANTHROPIC_TOKEN or ANTHROPIC_API_KEY, "
|
||||
"run 'claude setup-token', or authenticate with 'claude /login'."
|
||||
)
|
||||
# Allow base URL override from config.yaml model.base_url
|
||||
# Allow base URL override from config.yaml model.base_url, but only
|
||||
# when the configured provider is anthropic — otherwise a non-Anthropic
|
||||
# base_url (e.g. Codex endpoint) would leak into Anthropic requests.
|
||||
model_cfg = _get_model_config()
|
||||
cfg_base_url = (model_cfg.get("base_url") or "").strip().rstrip("/")
|
||||
cfg_provider = str(model_cfg.get("provider") or "").strip().lower()
|
||||
cfg_base_url = ""
|
||||
if cfg_provider == "anthropic":
|
||||
cfg_base_url = (model_cfg.get("base_url") or "").strip().rstrip("/")
|
||||
base_url = cfg_base_url or "https://api.anthropic.com"
|
||||
return {
|
||||
"provider": "anthropic",
|
||||
@@ -372,19 +381,6 @@ def resolve_runtime_provider(
|
||||
"requested_provider": requested_provider,
|
||||
}
|
||||
|
||||
# Alibaba Cloud / DashScope (Anthropic-compatible endpoint)
|
||||
if provider == "alibaba":
|
||||
creds = resolve_api_key_provider_credentials(provider)
|
||||
base_url = creds.get("base_url", "").rstrip("/") or "https://dashscope-intl.aliyuncs.com/apps/anthropic"
|
||||
return {
|
||||
"provider": "alibaba",
|
||||
"api_mode": "anthropic_messages",
|
||||
"base_url": base_url,
|
||||
"api_key": creds.get("api_key", ""),
|
||||
"source": creds.get("source", "env"),
|
||||
"requested_provider": requested_provider,
|
||||
}
|
||||
|
||||
# API-key providers (z.ai/GLM, Kimi, MiniMax, MiniMax-CN)
|
||||
pconfig = PROVIDER_REGISTRY.get(provider)
|
||||
if pconfig and pconfig.auth_type == "api_key":
|
||||
|
||||
+11
-11
@@ -4,9 +4,9 @@ Interactive setup wizard for Hermes Agent.
|
||||
Modular wizard with independently-runnable sections:
|
||||
1. Model & Provider — choose your AI provider and model
|
||||
2. Terminal Backend — where your agent runs commands
|
||||
3. Messaging Platforms — connect Telegram, Discord, etc.
|
||||
4. Tools — configure TTS, web search, image generation, etc.
|
||||
5. Agent Settings — iterations, compression, session reset
|
||||
3. Agent Settings — iterations, compression, session reset
|
||||
4. Messaging Platforms — connect Telegram, Discord, etc.
|
||||
5. Tools — configure TTS, web search, image generation, etc.
|
||||
|
||||
Config files are stored in ~/.hermes/ for easy access.
|
||||
"""
|
||||
@@ -2037,7 +2037,7 @@ def setup_terminal_backend(config: dict):
|
||||
|
||||
# Docker image
|
||||
current_image = config.get("terminal", {}).get(
|
||||
"docker_image", "python:3.11-slim"
|
||||
"docker_image", "nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
)
|
||||
image = prompt(" Docker image", current_image)
|
||||
config["terminal"]["docker_image"] = image
|
||||
@@ -2059,7 +2059,7 @@ def setup_terminal_backend(config: dict):
|
||||
print_info(f"Found: {sing_bin}")
|
||||
|
||||
current_image = config.get("terminal", {}).get(
|
||||
"singularity_image", "docker://python:3.11-slim"
|
||||
"singularity_image", "docker://nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
)
|
||||
image = prompt(" Container image", current_image)
|
||||
config["terminal"]["singularity_image"] = image
|
||||
@@ -2261,7 +2261,7 @@ def setup_agent_settings(config: dict):
|
||||
)
|
||||
print_info("Maximum tool-calling iterations per conversation.")
|
||||
print_info("Higher = more complex tasks, but costs more tokens.")
|
||||
print_info("Recommended: 30-60 for most tasks, 100+ for open exploration.")
|
||||
print_info("Default is 90, which works for most tasks. Use 150+ for open exploration.")
|
||||
|
||||
max_iter_str = prompt("Max iterations", current_max)
|
||||
try:
|
||||
@@ -2303,7 +2303,7 @@ def setup_agent_settings(config: dict):
|
||||
|
||||
config.setdefault("compression", {})["enabled"] = True
|
||||
|
||||
current_threshold = config.get("compression", {}).get("threshold", 0.85)
|
||||
current_threshold = config.get("compression", {}).get("threshold", 0.50)
|
||||
threshold_str = prompt("Compression threshold (0.5-0.95)", str(current_threshold))
|
||||
try:
|
||||
threshold = float(threshold_str)
|
||||
@@ -2313,7 +2313,7 @@ def setup_agent_settings(config: dict):
|
||||
pass
|
||||
|
||||
print_success(
|
||||
f"Context compression threshold set to {config['compression'].get('threshold', 0.85)}"
|
||||
f"Context compression threshold set to {config['compression'].get('threshold', 0.50)}"
|
||||
)
|
||||
|
||||
# ── Session Reset Policy ──
|
||||
@@ -3248,9 +3248,9 @@ def run_setup_wizard(args):
|
||||
print_info("We'll walk you through:")
|
||||
print_info(" 1. Model & Provider — choose your AI provider and model")
|
||||
print_info(" 2. Terminal Backend — where your agent runs commands")
|
||||
print_info(" 3. Messaging Platforms — connect Telegram, Discord, etc.")
|
||||
print_info(" 4. Tools — configure TTS, web search, image generation, etc.")
|
||||
print_info(" 5. Agent Settings — iterations, compression, session reset")
|
||||
print_info(" 3. Agent Settings — iterations, compression, session reset")
|
||||
print_info(" 4. Messaging Platforms — connect Telegram, Discord, etc.")
|
||||
print_info(" 5. Tools — configure TTS, web search, image generation, etc.")
|
||||
print()
|
||||
print_info("Press Enter to begin, or Ctrl+C to exit.")
|
||||
try:
|
||||
|
||||
@@ -455,6 +455,8 @@ def do_inspect(identifier: str, console: Optional[Console] = None) -> None:
|
||||
|
||||
if bundle and "SKILL.md" in bundle.files:
|
||||
content = bundle.files["SKILL.md"]
|
||||
if isinstance(content, bytes):
|
||||
content = content.decode("utf-8", errors="replace")
|
||||
# Show first 50 lines as preview
|
||||
lines = content.split("\n")
|
||||
preview = "\n".join(lines[:50])
|
||||
@@ -640,7 +642,8 @@ def do_tap(action: str, repo: str = "", console: Optional[Console] = None) -> No
|
||||
table.add_column("Repo", style="bold cyan")
|
||||
table.add_column("Path", style="dim")
|
||||
for t in taps:
|
||||
table.add_row(t["repo"], t.get("path", "skills/"))
|
||||
label = t.get("repo") or t.get("name") or t.get("path", "unknown")
|
||||
table.add_row(label, t.get("path", "skills/"))
|
||||
c.print(table)
|
||||
c.print()
|
||||
|
||||
|
||||
+108
-49
@@ -101,6 +101,30 @@ CONFIGURABLE_TOOLSETS = [
|
||||
# but the setup checklist won't pre-select them for first-time users.
|
||||
_DEFAULT_OFF_TOOLSETS = {"moa", "homeassistant", "rl"}
|
||||
|
||||
|
||||
def _get_effective_configurable_toolsets():
|
||||
"""Return CONFIGURABLE_TOOLSETS + any plugin-provided toolsets.
|
||||
|
||||
Plugin toolsets are appended at the end so they appear after the
|
||||
built-in toolsets in the TUI checklist.
|
||||
"""
|
||||
result = list(CONFIGURABLE_TOOLSETS)
|
||||
try:
|
||||
from hermes_cli.plugins import get_plugin_toolsets
|
||||
result.extend(get_plugin_toolsets())
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
def _get_plugin_toolset_keys() -> set:
|
||||
"""Return the set of toolset keys provided by plugins."""
|
||||
try:
|
||||
from hermes_cli.plugins import get_plugin_toolsets
|
||||
return {ts_key for ts_key, _, _ in get_plugin_toolsets()}
|
||||
except Exception:
|
||||
return set()
|
||||
|
||||
# Platform display config
|
||||
PLATFORMS = {
|
||||
"cli": {"label": "🖥️ CLI", "default_toolset": "hermes-cli"},
|
||||
@@ -377,19 +401,36 @@ def _get_platform_tools(config: dict, platform: str) -> Set[str]:
|
||||
has_explicit_config = any(ts in configurable_keys for ts in toolset_names)
|
||||
|
||||
if has_explicit_config:
|
||||
return {ts for ts in toolset_names if ts in configurable_keys}
|
||||
enabled_toolsets = {ts for ts in toolset_names if ts in configurable_keys}
|
||||
else:
|
||||
# No explicit config — fall back to resolving composite toolset names
|
||||
# (e.g. "hermes-cli") to individual tool names and reverse-mapping.
|
||||
all_tool_names = set()
|
||||
for ts_name in toolset_names:
|
||||
all_tool_names.update(resolve_toolset(ts_name))
|
||||
|
||||
# No explicit config — fall back to resolving composite toolset names
|
||||
# (e.g. "hermes-cli") to individual tool names and reverse-mapping.
|
||||
all_tool_names = set()
|
||||
for ts_name in toolset_names:
|
||||
all_tool_names.update(resolve_toolset(ts_name))
|
||||
enabled_toolsets = set()
|
||||
for ts_key, _, _ in CONFIGURABLE_TOOLSETS:
|
||||
ts_tools = set(resolve_toolset(ts_key))
|
||||
if ts_tools and ts_tools.issubset(all_tool_names):
|
||||
enabled_toolsets.add(ts_key)
|
||||
|
||||
enabled_toolsets = set()
|
||||
for ts_key, _, _ in CONFIGURABLE_TOOLSETS:
|
||||
ts_tools = set(resolve_toolset(ts_key))
|
||||
if ts_tools and ts_tools.issubset(all_tool_names):
|
||||
enabled_toolsets.add(ts_key)
|
||||
# Plugin toolsets: enabled by default unless explicitly disabled.
|
||||
# A plugin toolset is "known" for a platform once `hermes tools`
|
||||
# has been saved for that platform (tracked via known_plugin_toolsets).
|
||||
# Unknown plugins default to enabled; known-but-absent = disabled.
|
||||
plugin_ts_keys = _get_plugin_toolset_keys()
|
||||
if plugin_ts_keys:
|
||||
known_map = config.get("known_plugin_toolsets", {})
|
||||
known_for_platform = set(known_map.get(platform, []))
|
||||
for pts in plugin_ts_keys:
|
||||
if pts in toolset_names:
|
||||
# Explicitly listed in config — enabled
|
||||
enabled_toolsets.add(pts)
|
||||
elif pts not in known_for_platform:
|
||||
# New plugin not yet seen by hermes tools — default enabled
|
||||
enabled_toolsets.add(pts)
|
||||
# else: known but not in config = user disabled it
|
||||
|
||||
return enabled_toolsets
|
||||
|
||||
@@ -397,41 +438,42 @@ def _get_platform_tools(config: dict, platform: str) -> Set[str]:
|
||||
def _save_platform_tools(config: dict, platform: str, enabled_toolset_keys: Set[str]):
|
||||
"""Save the selected toolset keys for a platform to config.
|
||||
|
||||
Preserves any non-configurable, non-composite entries (like MCP server
|
||||
names) that were already in the config for this platform.
|
||||
|
||||
Composite platform toolsets (hermes-cli, hermes-telegram, etc.) are
|
||||
dropped once the user has explicitly configured individual toolsets —
|
||||
keeping them would override the user's selections because they include
|
||||
all tools via _HERMES_CORE_TOOLS.
|
||||
Preserves any non-configurable toolset entries (like MCP server names)
|
||||
that were already in the config for this platform.
|
||||
"""
|
||||
from toolsets import TOOLSETS
|
||||
|
||||
config.setdefault("platform_toolsets", {})
|
||||
|
||||
# Keys the user can toggle in the checklist UI
|
||||
# Get the set of all configurable toolset keys (built-in + plugin)
|
||||
configurable_keys = {ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS}
|
||||
plugin_keys = _get_plugin_toolset_keys()
|
||||
configurable_keys |= plugin_keys
|
||||
|
||||
# Keys that are known composite/individual toolsets in toolsets.py
|
||||
# (hermes-cli, hermes-telegram, homeassistant, web, terminal, etc.)
|
||||
known_toolset_keys = set(TOOLSETS.keys())
|
||||
# Also exclude platform default toolsets (hermes-cli, hermes-telegram, etc.)
|
||||
# These are "super" toolsets that resolve to ALL tools, so preserving them
|
||||
# would silently override the user's unchecked selections on the next read.
|
||||
platform_default_keys = {p["default_toolset"] for p in PLATFORMS.values()}
|
||||
|
||||
# Get existing toolsets for this platform
|
||||
existing_toolsets = config.get("platform_toolsets", {}).get(platform, [])
|
||||
if not isinstance(existing_toolsets, list):
|
||||
existing_toolsets = []
|
||||
|
||||
# Preserve entries that are neither configurable toolsets nor known
|
||||
# composite toolsets — this keeps MCP server names and other custom
|
||||
# entries while dropping composites like "hermes-cli" that would
|
||||
# silently re-enable everything the user just disabled.
|
||||
# Preserve any entries that are NOT configurable toolsets and NOT platform
|
||||
# defaults (i.e. only MCP server names should be preserved)
|
||||
preserved_entries = {
|
||||
entry for entry in existing_toolsets
|
||||
if entry not in configurable_keys and entry not in known_toolset_keys
|
||||
if entry not in configurable_keys and entry not in platform_default_keys
|
||||
}
|
||||
|
||||
# Merge preserved entries with new enabled toolsets
|
||||
config["platform_toolsets"][platform] = sorted(enabled_toolset_keys | preserved_entries)
|
||||
|
||||
# Track which plugin toolsets are "known" for this platform so we can
|
||||
# distinguish "new plugin, default enabled" from "user disabled it".
|
||||
if plugin_keys:
|
||||
config.setdefault("known_plugin_toolsets", {})
|
||||
config["known_plugin_toolsets"][platform] = sorted(plugin_keys)
|
||||
|
||||
save_config(config)
|
||||
|
||||
|
||||
@@ -549,15 +591,17 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
|
||||
"""Multi-select checklist of toolsets. Returns set of selected toolset keys."""
|
||||
from hermes_cli.curses_ui import curses_checklist
|
||||
|
||||
effective = _get_effective_configurable_toolsets()
|
||||
|
||||
labels = []
|
||||
for ts_key, ts_label, ts_desc in CONFIGURABLE_TOOLSETS:
|
||||
for ts_key, ts_label, ts_desc in effective:
|
||||
suffix = ""
|
||||
if not _toolset_has_keys(ts_key) and (TOOL_CATEGORIES.get(ts_key) or TOOLSET_ENV_REQUIREMENTS.get(ts_key)):
|
||||
suffix = " [no API key]"
|
||||
labels.append(f"{ts_label} ({ts_desc}){suffix}")
|
||||
|
||||
pre_selected = {
|
||||
i for i, (ts_key, _, _) in enumerate(CONFIGURABLE_TOOLSETS)
|
||||
i for i, (ts_key, _, _) in enumerate(effective)
|
||||
if ts_key in enabled
|
||||
}
|
||||
|
||||
@@ -567,7 +611,7 @@ def _prompt_toolset_checklist(platform_label: str, enabled: Set[str]) -> Set[str
|
||||
pre_selected,
|
||||
cancel_returns=pre_selected,
|
||||
)
|
||||
return {CONFIGURABLE_TOOLSETS[i][0] for i in chosen}
|
||||
return {effective[i][0] for i in chosen}
|
||||
|
||||
|
||||
# ─── Provider-Aware Configuration ────────────────────────────────────────────
|
||||
@@ -782,7 +826,7 @@ def _configure_simple_requirements(ts_key: str):
|
||||
if not missing:
|
||||
return
|
||||
|
||||
ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
|
||||
ts_label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts_key), ts_key)
|
||||
print()
|
||||
print(color(f" {ts_label} requires configuration:", Colors.YELLOW))
|
||||
|
||||
@@ -801,7 +845,7 @@ def _reconfigure_tool(config: dict):
|
||||
"""Let user reconfigure an existing tool's provider or API key."""
|
||||
# Build list of configurable tools that are currently set up
|
||||
configurable = []
|
||||
for ts_key, ts_label, _ in CONFIGURABLE_TOOLSETS:
|
||||
for ts_key, ts_label, _ in _get_effective_configurable_toolsets():
|
||||
cat = TOOL_CATEGORIES.get(ts_key)
|
||||
reqs = TOOLSET_ENV_REQUIREMENTS.get(ts_key)
|
||||
if cat or reqs:
|
||||
@@ -915,7 +959,7 @@ def _reconfigure_simple_requirements(ts_key: str):
|
||||
if not requirements:
|
||||
return
|
||||
|
||||
ts_label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
|
||||
ts_label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts_key), ts_key)
|
||||
print()
|
||||
print(color(f" {ts_label}:", Colors.CYAN))
|
||||
|
||||
@@ -954,7 +998,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
|
||||
# Non-interactive summary mode for CLI usage
|
||||
if getattr(args, "summary", False):
|
||||
total = len(CONFIGURABLE_TOOLSETS)
|
||||
total = len(_get_effective_configurable_toolsets())
|
||||
print(color("⚕ Tool Summary", Colors.CYAN, Colors.BOLD))
|
||||
print()
|
||||
summary = _platform_toolset_summary(config, enabled_platforms)
|
||||
@@ -965,7 +1009,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
print(color(f" {pinfo['label']}", Colors.BOLD) + color(f" ({count}/{total})", Colors.DIM))
|
||||
if enabled:
|
||||
for ts_key in sorted(enabled):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts_key), ts_key)
|
||||
print(color(f" ✓ {label}", Colors.GREEN))
|
||||
else:
|
||||
print(color(" (none enabled)", Colors.DIM))
|
||||
@@ -992,11 +1036,11 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
removed = current_enabled - new_enabled
|
||||
if added:
|
||||
for ts in sorted(added):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts), ts)
|
||||
print(color(f" + {label}", Colors.GREEN))
|
||||
if removed:
|
||||
for ts in sorted(removed):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts), ts)
|
||||
print(color(f" - {label}", Colors.RED))
|
||||
|
||||
# Walk through ALL selected tools that have provider options or
|
||||
@@ -1012,7 +1056,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
print()
|
||||
print(color(f" Configuring {len(to_configure)} tool(s):", Colors.YELLOW))
|
||||
for ts_key in to_configure:
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts_key), ts_key)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts_key), ts_key)
|
||||
print(color(f" • {label}", Colors.DIM))
|
||||
print(color(" You can skip any tool you don't need right now.", Colors.DIM))
|
||||
print()
|
||||
@@ -1034,7 +1078,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
pinfo = PLATFORMS[pkey]
|
||||
current = _get_platform_tools(config, pkey)
|
||||
count = len(current)
|
||||
total = len(CONFIGURABLE_TOOLSETS)
|
||||
total = len(_get_effective_configurable_toolsets())
|
||||
platform_choices.append(f"Configure {pinfo['label']} ({count}/{total} enabled)")
|
||||
platform_keys.append(pkey)
|
||||
|
||||
@@ -1090,10 +1134,10 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
if added or removed:
|
||||
print(color(f" {pinfo_inner['label']}:", Colors.DIM))
|
||||
for ts in sorted(added):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts), ts)
|
||||
print(color(f" + {label}", Colors.GREEN))
|
||||
for ts in sorted(removed):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts), ts)
|
||||
print(color(f" - {label}", Colors.RED))
|
||||
# Configure API keys for newly enabled tools
|
||||
for ts_key in sorted(added):
|
||||
@@ -1106,7 +1150,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
# Update choice labels
|
||||
for ci, pk in enumerate(platform_keys):
|
||||
new_count = len(_get_platform_tools(config, pk))
|
||||
total = len(CONFIGURABLE_TOOLSETS)
|
||||
total = len(_get_effective_configurable_toolsets())
|
||||
platform_choices[ci] = f"Configure {PLATFORMS[pk]['label']} ({new_count}/{total} enabled)"
|
||||
else:
|
||||
print(color(" No changes", Colors.DIM))
|
||||
@@ -1128,11 +1172,11 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
|
||||
if added:
|
||||
for ts in sorted(added):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts), ts)
|
||||
print(color(f" + {label}", Colors.GREEN))
|
||||
if removed:
|
||||
for ts in sorted(removed):
|
||||
label = next((l for k, l, _ in CONFIGURABLE_TOOLSETS if k == ts), ts)
|
||||
label = next((l for k, l, _ in _get_effective_configurable_toolsets() if k == ts), ts)
|
||||
print(color(f" - {label}", Colors.RED))
|
||||
|
||||
# Configure newly enabled toolsets that need API keys
|
||||
@@ -1151,7 +1195,7 @@ def tools_command(args=None, first_install: bool = False, config: dict = None):
|
||||
|
||||
# Update the choice label with new count
|
||||
new_count = len(_get_platform_tools(config, pkey))
|
||||
total = len(CONFIGURABLE_TOOLSETS)
|
||||
total = len(_get_effective_configurable_toolsets())
|
||||
platform_choices[idx] = f"Configure {pinfo['label']} ({new_count}/{total} enabled)"
|
||||
|
||||
print()
|
||||
@@ -1331,12 +1375,27 @@ def _apply_mcp_change(config: dict, targets: List[str], action: str) -> Set[str]
|
||||
|
||||
def _print_tools_list(enabled_toolsets: set, mcp_servers: dict, platform: str = "cli"):
|
||||
"""Print a summary of enabled/disabled toolsets and MCP tool filters."""
|
||||
effective = _get_effective_configurable_toolsets()
|
||||
builtin_keys = {ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS}
|
||||
|
||||
print(f"Built-in toolsets ({platform}):")
|
||||
for ts_key, label, _ in CONFIGURABLE_TOOLSETS:
|
||||
for ts_key, label, _ in effective:
|
||||
if ts_key not in builtin_keys:
|
||||
continue
|
||||
status = (color("✓ enabled", Colors.GREEN) if ts_key in enabled_toolsets
|
||||
else color("✗ disabled", Colors.RED))
|
||||
print(f" {status} {ts_key} {color(label, Colors.DIM)}")
|
||||
|
||||
# Plugin toolsets
|
||||
plugin_entries = [(k, l) for k, l, _ in effective if k not in builtin_keys]
|
||||
if plugin_entries:
|
||||
print()
|
||||
print(f"Plugin toolsets ({platform}):")
|
||||
for ts_key, label in plugin_entries:
|
||||
status = (color("✓ enabled", Colors.GREEN) if ts_key in enabled_toolsets
|
||||
else color("✗ disabled", Colors.RED))
|
||||
print(f" {status} {ts_key} {color(label, Colors.DIM)}")
|
||||
|
||||
if mcp_servers:
|
||||
print()
|
||||
print("MCP servers:")
|
||||
@@ -1375,7 +1434,7 @@ def tools_disable_enable_command(args):
|
||||
toolset_targets = [t for t in targets if ":" not in t]
|
||||
mcp_targets = [t for t in targets if ":" in t]
|
||||
|
||||
valid_toolsets = {ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS}
|
||||
valid_toolsets = {ts_key for ts_key, _, _ in CONFIGURABLE_TOOLSETS} | _get_plugin_toolset_keys()
|
||||
unknown_toolsets = [t for t in toolset_targets if t not in valid_toolsets]
|
||||
if unknown_toolsets:
|
||||
for name in unknown_toolsets:
|
||||
|
||||
+16
-14
@@ -855,23 +855,25 @@ class SessionDB:
|
||||
|
||||
def session_count(self, source: str = None) -> int:
|
||||
"""Count sessions, optionally filtered by source."""
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions WHERE source = ?", (source,)
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM sessions")
|
||||
return cursor.fetchone()[0]
|
||||
with self._lock:
|
||||
if source:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions WHERE source = ?", (source,)
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM sessions")
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
def message_count(self, session_id: str = None) -> int:
|
||||
"""Count messages, optionally for a specific session."""
|
||||
if session_id:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM messages")
|
||||
return cursor.fetchone()[0]
|
||||
with self._lock:
|
||||
if session_id:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
else:
|
||||
cursor = self._conn.execute("SELECT COUNT(*) FROM messages")
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
# =========================================================================
|
||||
# Export and cleanup
|
||||
|
||||
+31
-16
@@ -10,22 +10,30 @@ import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
GLOBAL_CONFIG_PATH = Path.home() / ".honcho" / "config.json"
|
||||
from honcho_integration.client import resolve_config_path, GLOBAL_CONFIG_PATH
|
||||
|
||||
HOST = "hermes"
|
||||
|
||||
|
||||
def _config_path() -> Path:
|
||||
"""Return the active Honcho config path (instance-local or global)."""
|
||||
return resolve_config_path()
|
||||
|
||||
|
||||
def _read_config() -> dict:
|
||||
if GLOBAL_CONFIG_PATH.exists():
|
||||
path = _config_path()
|
||||
if path.exists():
|
||||
try:
|
||||
return json.loads(GLOBAL_CONFIG_PATH.read_text(encoding="utf-8"))
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
def _write_config(cfg: dict) -> None:
|
||||
GLOBAL_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
GLOBAL_CONFIG_PATH.write_text(
|
||||
def _write_config(cfg: dict, path: Path | None = None) -> None:
|
||||
path = path or _config_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(
|
||||
json.dumps(cfg, indent=2, ensure_ascii=False) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
@@ -87,9 +95,14 @@ def cmd_setup(args) -> None:
|
||||
"""Interactive Honcho setup wizard."""
|
||||
cfg = _read_config()
|
||||
|
||||
active_path = _config_path()
|
||||
print("\nHoncho memory setup\n" + "─" * 40)
|
||||
print(" Honcho gives Hermes persistent cross-session memory.")
|
||||
print(" Config is shared with other hosts at ~/.honcho/config.json\n")
|
||||
if active_path != GLOBAL_CONFIG_PATH:
|
||||
print(f" Instance config: {active_path}")
|
||||
else:
|
||||
print(" Config is shared with other hosts at ~/.honcho/config.json")
|
||||
print()
|
||||
|
||||
if not _ensure_sdk_installed():
|
||||
return
|
||||
@@ -162,10 +175,10 @@ def cmd_setup(args) -> None:
|
||||
hermes_host["recallMode"] = new_recall
|
||||
|
||||
# Session strategy
|
||||
current_strat = hermes_host.get("sessionStrategy") or cfg.get("sessionStrategy", "per-session")
|
||||
current_strat = hermes_host.get("sessionStrategy") or cfg.get("sessionStrategy", "per-directory")
|
||||
print(f"\n Session strategy options:")
|
||||
print(" per-session — new Honcho session each run, named by Hermes session ID (default)")
|
||||
print(" per-directory — one session per working directory")
|
||||
print(" per-directory — one session per working directory (default)")
|
||||
print(" per-session — new Honcho session each run, named by Hermes session ID")
|
||||
print(" per-repo — one session per git repository (uses repo root name)")
|
||||
print(" global — single session across all directories")
|
||||
new_strat = _prompt("Session strategy", default=current_strat)
|
||||
@@ -176,7 +189,7 @@ def cmd_setup(args) -> None:
|
||||
hermes_host.setdefault("saveMessages", True)
|
||||
|
||||
_write_config(cfg)
|
||||
print(f"\n Config written to {GLOBAL_CONFIG_PATH}")
|
||||
print(f"\n Config written to {active_path}")
|
||||
|
||||
# Test connection
|
||||
print(" Testing connection... ", end="", flush=True)
|
||||
@@ -223,8 +236,10 @@ def cmd_status(args) -> None:
|
||||
|
||||
cfg = _read_config()
|
||||
|
||||
active_path = _config_path()
|
||||
|
||||
if not cfg:
|
||||
print(" No Honcho config found at ~/.honcho/config.json")
|
||||
print(f" No Honcho config found at {active_path}")
|
||||
print(" Run 'hermes honcho setup' to configure.\n")
|
||||
return
|
||||
|
||||
@@ -243,7 +258,7 @@ def cmd_status(args) -> None:
|
||||
print(f" API key: {masked}")
|
||||
print(f" Workspace: {hcfg.workspace_id}")
|
||||
print(f" Host: {hcfg.host}")
|
||||
print(f" Config path: {GLOBAL_CONFIG_PATH}")
|
||||
print(f" Config path: {active_path}")
|
||||
print(f" AI peer: {hcfg.ai_peer}")
|
||||
print(f" User peer: {hcfg.peer_name or 'not set'}")
|
||||
print(f" Session key: {hcfg.resolve_session_name()}")
|
||||
@@ -275,7 +290,7 @@ def cmd_sessions(args) -> None:
|
||||
if not sessions:
|
||||
print(" No session mappings configured.\n")
|
||||
print(" Add one with: hermes honcho map <session-name>")
|
||||
print(" Or edit ~/.honcho/config.json directly.\n")
|
||||
print(f" Or edit {_config_path()} directly.\n")
|
||||
return
|
||||
|
||||
cwd = os.getcwd()
|
||||
@@ -361,7 +376,7 @@ def cmd_peer(args) -> None:
|
||||
|
||||
if changed:
|
||||
_write_config(cfg)
|
||||
print(f" Saved to {GLOBAL_CONFIG_PATH}\n")
|
||||
print(f" Saved to {_config_path()}\n")
|
||||
|
||||
|
||||
def cmd_mode(args) -> None:
|
||||
@@ -434,7 +449,7 @@ def cmd_tokens(args) -> None:
|
||||
|
||||
if changed:
|
||||
_write_config(cfg)
|
||||
print(f" Saved to {GLOBAL_CONFIG_PATH}\n")
|
||||
print(f" Saved to {_config_path()}\n")
|
||||
|
||||
|
||||
def cmd_identity(args) -> None:
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
"""Honcho client initialization and configuration.
|
||||
|
||||
Reads the global ~/.honcho/config.json when available, falling back
|
||||
to environment variables.
|
||||
Resolution order for config file:
|
||||
1. $HERMES_HOME/honcho.json (instance-local, enables isolated Hermes instances)
|
||||
2. ~/.honcho/config.json (global, shared across all Honcho-enabled apps)
|
||||
3. Environment variables (HONCHO_API_KEY, HONCHO_ENVIRONMENT)
|
||||
|
||||
Resolution order for host-specific settings:
|
||||
1. Explicit host block fields (always win)
|
||||
@@ -27,6 +29,24 @@ GLOBAL_CONFIG_PATH = Path.home() / ".honcho" / "config.json"
|
||||
HOST = "hermes"
|
||||
|
||||
|
||||
def _get_hermes_home() -> Path:
|
||||
"""Get HERMES_HOME without importing hermes_cli (avoids circular deps)."""
|
||||
return Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
|
||||
|
||||
def resolve_config_path() -> Path:
|
||||
"""Return the active Honcho config path.
|
||||
|
||||
Checks $HERMES_HOME/honcho.json first (instance-local), then falls back
|
||||
to ~/.honcho/config.json (global). Returns the global path if neither
|
||||
exists (for first-time setup writes).
|
||||
"""
|
||||
local_path = _get_hermes_home() / "honcho.json"
|
||||
if local_path.exists():
|
||||
return local_path
|
||||
return GLOBAL_CONFIG_PATH
|
||||
|
||||
|
||||
_RECALL_MODE_ALIASES = {"auto": "hybrid"}
|
||||
_VALID_RECALL_MODES = {"hybrid", "context", "tools"}
|
||||
|
||||
@@ -107,11 +127,15 @@ class HonchoClientConfig:
|
||||
# "tools" — Honcho tools only, no auto-injected context
|
||||
recall_mode: str = "hybrid"
|
||||
# Session resolution
|
||||
session_strategy: str = "per-session"
|
||||
session_strategy: str = "per-directory"
|
||||
session_peer_prefix: bool = False
|
||||
sessions: dict[str, str] = field(default_factory=dict)
|
||||
# Raw global config for anything else consumers need
|
||||
raw: dict[str, Any] = field(default_factory=dict)
|
||||
# True when Honcho was explicitly configured for this host (hosts.hermes
|
||||
# block exists or enabled was set explicitly), vs auto-enabled from a
|
||||
# stray HONCHO_API_KEY env var.
|
||||
explicitly_configured: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_env(cls, workspace_id: str = "hermes") -> HonchoClientConfig:
|
||||
@@ -132,11 +156,11 @@ class HonchoClientConfig:
|
||||
host: str = HOST,
|
||||
config_path: Path | None = None,
|
||||
) -> HonchoClientConfig:
|
||||
"""Create config from ~/.honcho/config.json.
|
||||
"""Create config from the resolved Honcho config path.
|
||||
|
||||
Falls back to environment variables if the file doesn't exist.
|
||||
Resolution: $HERMES_HOME/honcho.json -> ~/.honcho/config.json -> env vars.
|
||||
"""
|
||||
path = config_path or GLOBAL_CONFIG_PATH
|
||||
path = config_path or resolve_config_path()
|
||||
if not path.exists():
|
||||
logger.debug("No global Honcho config at %s, falling back to env", path)
|
||||
return cls.from_env()
|
||||
@@ -148,6 +172,9 @@ class HonchoClientConfig:
|
||||
return cls.from_env()
|
||||
|
||||
host_block = (raw.get("hosts") or {}).get(host, {})
|
||||
# A hosts.hermes block or explicit enabled flag means the user
|
||||
# intentionally configured Honcho for this host.
|
||||
_explicitly_configured = bool(host_block) or raw.get("enabled") is True
|
||||
|
||||
# Explicit host block fields win, then flat/global, then defaults
|
||||
workspace = (
|
||||
@@ -209,7 +236,7 @@ class HonchoClientConfig:
|
||||
# sessionStrategy / sessionPeerPrefix: host first, root fallback
|
||||
session_strategy = (
|
||||
host_block.get("sessionStrategy")
|
||||
or raw.get("sessionStrategy", "per-session")
|
||||
or raw.get("sessionStrategy", "per-directory")
|
||||
)
|
||||
host_prefix = host_block.get("sessionPeerPrefix")
|
||||
session_peer_prefix = (
|
||||
@@ -253,6 +280,7 @@ class HonchoClientConfig:
|
||||
session_peer_prefix=session_peer_prefix,
|
||||
sessions=raw.get("sessions", {}),
|
||||
raw=raw,
|
||||
explicitly_configured=_explicitly_configured,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -318,7 +346,7 @@ class HonchoClientConfig:
|
||||
return f"{self.peer_name}-{base}"
|
||||
return base
|
||||
|
||||
# per-directory: one Honcho session per working directory
|
||||
# per-directory: one Honcho session per working directory (default)
|
||||
if self.session_strategy in ("per-directory", "per-session"):
|
||||
base = Path(cwd).name
|
||||
if self.session_peer_prefix and self.peer_name:
|
||||
|
||||
+5
-10
@@ -22,7 +22,6 @@ Public API (signatures preserved from the original 2,400-line version):
|
||||
|
||||
import json
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
import threading
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
@@ -293,15 +292,11 @@ def get_tool_definitions(
|
||||
for ts_name in get_all_toolsets():
|
||||
tools_to_include.update(resolve_toolset(ts_name))
|
||||
|
||||
# Always include plugin-registered tools — they bypass the toolset filter
|
||||
# because their toolsets are dynamic (created at plugin load time).
|
||||
try:
|
||||
from hermes_cli.plugins import get_plugin_tool_names
|
||||
plugin_tools = get_plugin_tool_names()
|
||||
if plugin_tools:
|
||||
tools_to_include.update(plugin_tools)
|
||||
except Exception:
|
||||
pass
|
||||
# Plugin-registered tools are now resolved through the normal toolset
|
||||
# path — validate_toolset() / resolve_toolset() / get_all_toolsets()
|
||||
# all check the tool registry for plugin-provided toolsets. No bypass
|
||||
# needed; plugins respect enabled_toolsets / disabled_toolsets like any
|
||||
# other toolset.
|
||||
|
||||
# Ask the registry for schemas (only returns tools whose check_fn passes)
|
||||
filtered_tools = registry.get_definitions(tools_to_include, quiet=quiet_mode)
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
# Meme Generation Examples
|
||||
|
||||
## Example 1: Debugging at 2 AM
|
||||
|
||||
**Topic:** debugging production at 2 AM
|
||||
**Template:** this-is-fine
|
||||
|
||||
```bash
|
||||
python generate_meme.py this-is-fine /tmp/meme.png "PRODUCTION IS DOWN" "This is fine"
|
||||
```
|
||||
|
||||
## Example 2: Developer Priorities
|
||||
|
||||
**Topic:** choosing between writing tests and shipping features
|
||||
**Template:** drake
|
||||
|
||||
```bash
|
||||
python generate_meme.py drake /tmp/meme.png "Writing unit tests" "Shipping straight to prod"
|
||||
```
|
||||
|
||||
## Example 3: Exam Stress
|
||||
|
||||
**Topic:** final exam preparation
|
||||
**Template:** two-buttons
|
||||
|
||||
```bash
|
||||
python generate_meme.py two-buttons /tmp/meme.png "Study everything" "Sleep" "Me at midnight"
|
||||
```
|
||||
|
||||
## Example 4: Escalating Solutions
|
||||
|
||||
**Topic:** fixing a CSS bug
|
||||
**Template:** expanding-brain
|
||||
|
||||
```bash
|
||||
python generate_meme.py expanding-brain /tmp/meme.png "Reading the docs" "Stack Overflow" "!important on everything" "Deleting the stylesheet"
|
||||
```
|
||||
|
||||
## Example 5: Hot Take
|
||||
|
||||
**Topic:** tabs vs spaces
|
||||
**Template:** change-my-mind
|
||||
|
||||
```bash
|
||||
python generate_meme.py change-my-mind /tmp/meme.png "Tabs are just thicc spaces"
|
||||
```
|
||||
@@ -0,0 +1,129 @@
|
||||
---
|
||||
name: meme-generation
|
||||
description: Generate real meme images by picking a template and overlaying text with Pillow. Produces actual .png meme files.
|
||||
version: 2.0.0
|
||||
author: adanaleycio
|
||||
license: MIT
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [creative, memes, humor, images]
|
||||
related_skills: [ascii-art, generative-widgets]
|
||||
category: creative
|
||||
---
|
||||
|
||||
# Meme Generation
|
||||
|
||||
Generate actual meme images from a topic. Picks a template, writes captions, and renders a real .png file with text overlay.
|
||||
|
||||
## When to Use
|
||||
|
||||
- User asks you to make or generate a meme
|
||||
- User wants a meme about a specific topic, situation, or frustration
|
||||
- User says "meme this" or similar
|
||||
|
||||
## Available Templates
|
||||
|
||||
The script supports **any of the ~100 popular imgflip templates** by name or ID, plus 10 curated templates with hand-tuned text positioning.
|
||||
|
||||
### Curated Templates (custom text placement)
|
||||
|
||||
| ID | Name | Fields | Best for |
|
||||
|----|------|--------|----------|
|
||||
| `this-is-fine` | This is Fine | top, bottom | chaos, denial |
|
||||
| `drake` | Drake Hotline Bling | reject, approve | rejecting/preferring |
|
||||
| `distracted-boyfriend` | Distracted Boyfriend | distraction, current, person | temptation, shifting priorities |
|
||||
| `two-buttons` | Two Buttons | left, right, person | impossible choice |
|
||||
| `expanding-brain` | Expanding Brain | 4 levels | escalating irony |
|
||||
| `change-my-mind` | Change My Mind | statement | hot takes |
|
||||
| `woman-yelling-at-cat` | Woman Yelling at Cat | woman, cat | arguments |
|
||||
| `one-does-not-simply` | One Does Not Simply | top, bottom | deceptively hard things |
|
||||
| `grus-plan` | Gru's Plan | step1-3, realization | plans that backfire |
|
||||
| `batman-slapping-robin` | Batman Slapping Robin | robin, batman | shutting down bad ideas |
|
||||
|
||||
### Dynamic Templates (from imgflip API)
|
||||
|
||||
Any template not in the curated list can be used by name or imgflip ID. These get smart default text positioning (top/bottom for 2-field, evenly spaced for 3+). Search with:
|
||||
```bash
|
||||
python "$SKILL_DIR/scripts/generate_meme.py" --search "disaster"
|
||||
```
|
||||
|
||||
## Procedure
|
||||
|
||||
### Mode 1: Classic Template (default)
|
||||
|
||||
1. Read the user's topic and identify the core dynamic (chaos, dilemma, preference, irony, etc.)
|
||||
2. Pick the template that best matches. Use the "Best for" column, or search with `--search`.
|
||||
3. Write short captions for each field (8-12 words max per field, shorter is better).
|
||||
4. Find the skill's script directory:
|
||||
```
|
||||
SKILL_DIR=$(dirname "$(find ~/.hermes/skills -path '*/meme-generation/SKILL.md' 2>/dev/null | head -1)")
|
||||
```
|
||||
5. Run the generator:
|
||||
```bash
|
||||
python "$SKILL_DIR/scripts/generate_meme.py" <template_id> /tmp/meme.png "caption 1" "caption 2" ...
|
||||
```
|
||||
6. Return the image with `MEDIA:/tmp/meme.png`
|
||||
|
||||
### Mode 2: Custom AI Image (when image_generate is available)
|
||||
|
||||
Use this when no classic template fits, or when the user wants something original.
|
||||
|
||||
1. Write the captions first.
|
||||
2. Use `image_generate` to create a scene that matches the meme concept. Do NOT include any text in the image prompt — text will be added by the script. Describe only the visual scene.
|
||||
3. Find the generated image path from the image_generate result URL. Download it to a local path if needed.
|
||||
4. Run the script with `--image` to overlay text, choosing a mode:
|
||||
- **Overlay** (text directly on image, white with black outline):
|
||||
```bash
|
||||
python "$SKILL_DIR/scripts/generate_meme.py" --image /path/to/scene.png /tmp/meme.png "top text" "bottom text"
|
||||
```
|
||||
- **Bars** (black bars above/below with white text — cleaner, always readable):
|
||||
```bash
|
||||
python "$SKILL_DIR/scripts/generate_meme.py" --image /path/to/scene.png --bars /tmp/meme.png "top text" "bottom text"
|
||||
```
|
||||
Use `--bars` when the image is busy/detailed and text would be hard to read on top of it.
|
||||
5. **Verify with vision** (if `vision_analyze` is available): Check the result looks good:
|
||||
```
|
||||
vision_analyze(image_url="/tmp/meme.png", question="Is the text legible and well-positioned? Does the meme work visually?")
|
||||
```
|
||||
If the vision model flags issues (text hard to read, bad placement, etc.), try the other mode (switch between overlay and bars) or regenerate the scene.
|
||||
6. Return the image with `MEDIA:/tmp/meme.png`
|
||||
|
||||
## Examples
|
||||
|
||||
**"debugging production at 2 AM":**
|
||||
```bash
|
||||
python generate_meme.py this-is-fine /tmp/meme.png "SERVERS ARE ON FIRE" "This is fine"
|
||||
```
|
||||
|
||||
**"choosing between sleep and one more episode":**
|
||||
```bash
|
||||
python generate_meme.py drake /tmp/meme.png "Getting 8 hours of sleep" "One more episode at 3 AM"
|
||||
```
|
||||
|
||||
**"the stages of a Monday morning":**
|
||||
```bash
|
||||
python generate_meme.py expanding-brain /tmp/meme.png "Setting an alarm" "Setting 5 alarms" "Sleeping through all alarms" "Working from bed"
|
||||
```
|
||||
|
||||
## Listing Templates
|
||||
|
||||
To see all available templates:
|
||||
```bash
|
||||
python generate_meme.py --list
|
||||
```
|
||||
|
||||
## Pitfalls
|
||||
|
||||
- Keep captions SHORT. Memes with long text look terrible.
|
||||
- Match the number of text arguments to the template's field count.
|
||||
- Pick the template that fits the joke structure, not just the topic.
|
||||
- Do not generate hateful, abusive, or personally targeted content.
|
||||
- The script caches template images in `scripts/.cache/` after first download.
|
||||
|
||||
## Verification
|
||||
|
||||
The output is correct if:
|
||||
- A .png file was created at the output path
|
||||
- Text is legible (white with black outline) on the template
|
||||
- The joke lands — caption matches the template's intended structure
|
||||
- File can be delivered via MEDIA: path
|
||||
@@ -0,0 +1 @@
|
||||
.cache/
|
||||
@@ -0,0 +1,471 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate a meme image by overlaying text on a template.
|
||||
|
||||
Usage:
|
||||
python generate_meme.py <template_id_or_name> <output_path> <text1> [text2] [text3] [text4]
|
||||
|
||||
Example:
|
||||
python generate_meme.py drake /tmp/meme.png "Writing tests" "Shipping to prod and hoping"
|
||||
python generate_meme.py "Disaster Girl" /tmp/meme.png "Top text" "Bottom text"
|
||||
python generate_meme.py --list # show curated templates
|
||||
python generate_meme.py --search "distracted" # search all imgflip templates
|
||||
|
||||
Templates with custom text positioning are in templates.json (10 curated).
|
||||
Any of the ~100 popular imgflip templates can also be used by name or ID —
|
||||
unknown templates get smart default text positioning based on their box_count.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import textwrap
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import requests as _requests
|
||||
except ImportError:
|
||||
_requests = None
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
SCRIPT_DIR = Path(__file__).parent
|
||||
TEMPLATES_FILE = SCRIPT_DIR / "templates.json"
|
||||
CACHE_DIR = SCRIPT_DIR / ".cache"
|
||||
IMGFLIP_API = "https://api.imgflip.com/get_memes"
|
||||
IMGFLIP_CACHE_FILE = CACHE_DIR / "imgflip_memes.json"
|
||||
IMGFLIP_CACHE_MAX_AGE = 86400 # 24 hours
|
||||
|
||||
|
||||
def _fetch_url(url: str, timeout: int = 15) -> bytes:
|
||||
"""Fetch URL content, using requests if available, else urllib."""
|
||||
if _requests is not None:
|
||||
resp = _requests.get(url, timeout=timeout)
|
||||
resp.raise_for_status()
|
||||
return resp.content
|
||||
import urllib.request
|
||||
return urllib.request.urlopen(url, timeout=timeout).read()
|
||||
|
||||
|
||||
def load_curated_templates() -> dict:
|
||||
"""Load templates with hand-tuned text field positions."""
|
||||
with open(TEMPLATES_FILE) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def _default_fields(box_count: int) -> list:
|
||||
"""Generate sensible default text field positions for unknown templates."""
|
||||
if box_count <= 0:
|
||||
box_count = 2
|
||||
if box_count == 1:
|
||||
return [{"name": "text", "x_pct": 0.5, "y_pct": 0.5, "w_pct": 0.90, "align": "center"}]
|
||||
if box_count == 2:
|
||||
return [
|
||||
{"name": "top", "x_pct": 0.5, "y_pct": 0.08, "w_pct": 0.95, "align": "center"},
|
||||
{"name": "bottom", "x_pct": 0.5, "y_pct": 0.92, "w_pct": 0.95, "align": "center"},
|
||||
]
|
||||
# 3+: evenly space vertically
|
||||
fields = []
|
||||
for i in range(box_count):
|
||||
y = 0.08 + (0.84 * i / (box_count - 1)) if box_count > 1 else 0.5
|
||||
fields.append({
|
||||
"name": f"text{i+1}",
|
||||
"x_pct": 0.5,
|
||||
"y_pct": round(y, 2),
|
||||
"w_pct": 0.90,
|
||||
"align": "center",
|
||||
})
|
||||
return fields
|
||||
|
||||
|
||||
def fetch_imgflip_templates() -> list:
|
||||
"""Fetch popular meme templates from imgflip API. Cached for 24h."""
|
||||
import time
|
||||
|
||||
CACHE_DIR.mkdir(exist_ok=True)
|
||||
# Check cache
|
||||
if IMGFLIP_CACHE_FILE.exists():
|
||||
age = time.time() - IMGFLIP_CACHE_FILE.stat().st_mtime
|
||||
if age < IMGFLIP_CACHE_MAX_AGE:
|
||||
with open(IMGFLIP_CACHE_FILE) as f:
|
||||
return json.load(f)
|
||||
|
||||
try:
|
||||
data = json.loads(_fetch_url(IMGFLIP_API))
|
||||
memes = data.get("data", {}).get("memes", [])
|
||||
with open(IMGFLIP_CACHE_FILE, "w") as f:
|
||||
json.dump(memes, f)
|
||||
return memes
|
||||
except Exception as e:
|
||||
# If fetch fails and we have stale cache, use it
|
||||
if IMGFLIP_CACHE_FILE.exists():
|
||||
with open(IMGFLIP_CACHE_FILE) as f:
|
||||
return json.load(f)
|
||||
print(f"Warning: could not fetch imgflip templates: {e}", file=sys.stderr)
|
||||
return []
|
||||
|
||||
|
||||
def _slugify(name: str) -> str:
|
||||
"""Convert a template name to a slug for matching."""
|
||||
return name.lower().replace(" ", "-").replace("'", "").replace("\"", "")
|
||||
|
||||
|
||||
def resolve_template(identifier: str) -> dict:
|
||||
"""Resolve a template by curated ID, imgflip name, or imgflip ID.
|
||||
|
||||
Returns dict with: name, url, fields, source.
|
||||
"""
|
||||
curated = load_curated_templates()
|
||||
|
||||
# 1. Exact curated ID match
|
||||
if identifier in curated:
|
||||
tmpl = curated[identifier]
|
||||
return {**tmpl, "source": "curated"}
|
||||
|
||||
# 2. Slugified curated match
|
||||
slug = _slugify(identifier)
|
||||
for tid, tmpl in curated.items():
|
||||
if _slugify(tmpl["name"]) == slug or tid == slug:
|
||||
return {**tmpl, "source": "curated"}
|
||||
|
||||
# 3. Search imgflip templates
|
||||
imgflip_memes = fetch_imgflip_templates()
|
||||
slug_lower = slug.lower()
|
||||
id_lower = identifier.strip()
|
||||
|
||||
for meme in imgflip_memes:
|
||||
meme_slug = _slugify(meme["name"])
|
||||
# Check curated first for this imgflip template (custom positioning)
|
||||
for tid, ctmpl in curated.items():
|
||||
if _slugify(ctmpl["name"]) == meme_slug:
|
||||
if meme_slug == slug_lower or meme["id"] == id_lower:
|
||||
return {**ctmpl, "source": "curated"}
|
||||
|
||||
if meme_slug == slug_lower or meme["id"] == id_lower or slug_lower in meme_slug:
|
||||
return {
|
||||
"name": meme["name"],
|
||||
"url": meme["url"],
|
||||
"fields": _default_fields(meme.get("box_count", 2)),
|
||||
"source": "imgflip",
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_template_image(url: str) -> Image.Image:
|
||||
"""Download a template image, caching it locally."""
|
||||
CACHE_DIR.mkdir(exist_ok=True)
|
||||
# Use URL hash as cache key
|
||||
cache_name = url.split("/")[-1]
|
||||
cache_path = CACHE_DIR / cache_name
|
||||
|
||||
# Always cache as PNG to avoid JPEG/RGBA conflicts
|
||||
cache_path = cache_path.with_suffix(".png")
|
||||
|
||||
if cache_path.exists():
|
||||
return Image.open(cache_path).convert("RGBA")
|
||||
|
||||
data = _fetch_url(url)
|
||||
img = Image.open(BytesIO(data)).convert("RGBA")
|
||||
img.save(cache_path, "PNG")
|
||||
return img
|
||||
|
||||
|
||||
def find_font(size: int) -> ImageFont.FreeTypeFont:
|
||||
"""Find a bold font for meme text. Tries Impact, then falls back."""
|
||||
candidates = [
|
||||
"/usr/share/fonts/truetype/msttcorefonts/Impact.ttf",
|
||||
"/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf",
|
||||
"/usr/share/fonts/liberation-sans/LiberationSans-Bold.ttf",
|
||||
"/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
|
||||
"/usr/share/fonts/dejavu-sans/DejaVuSans-Bold.ttf",
|
||||
"/System/Library/Fonts/Helvetica.ttc",
|
||||
"/System/Library/Fonts/SFCompact.ttf",
|
||||
]
|
||||
for path in candidates:
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
return ImageFont.truetype(path, size)
|
||||
except (OSError, IOError):
|
||||
continue
|
||||
# Last resort: Pillow default
|
||||
try:
|
||||
return ImageFont.truetype("DejaVuSans-Bold", size)
|
||||
except (OSError, IOError):
|
||||
return ImageFont.load_default()
|
||||
|
||||
|
||||
def _wrap_text(text: str, font: ImageFont.FreeTypeFont, max_width: int) -> str:
|
||||
"""Word-wrap text to fit within max_width pixels. Never breaks mid-word."""
|
||||
words = text.split()
|
||||
if not words:
|
||||
return text
|
||||
lines = []
|
||||
current_line = words[0]
|
||||
for word in words[1:]:
|
||||
test_line = current_line + " " + word
|
||||
if font.getlength(test_line) <= max_width:
|
||||
current_line = test_line
|
||||
else:
|
||||
lines.append(current_line)
|
||||
current_line = word
|
||||
lines.append(current_line)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def draw_outlined_text(
|
||||
draw: ImageDraw.ImageDraw,
|
||||
text: str,
|
||||
x: int,
|
||||
y: int,
|
||||
font_size: int,
|
||||
max_width: int,
|
||||
align: str = "center",
|
||||
):
|
||||
"""Draw white text with black outline, auto-scaled to fit max_width."""
|
||||
# Auto-scale: reduce font size until text fits reasonably
|
||||
size = font_size
|
||||
while size > 12:
|
||||
font = find_font(size)
|
||||
wrapped = _wrap_text(text, font, max_width)
|
||||
bbox = draw.multiline_textbbox((0, 0), wrapped, font=font, align=align)
|
||||
text_w = bbox[2] - bbox[0]
|
||||
line_count = wrapped.count("\n") + 1
|
||||
# Accept if width fits and not too many lines
|
||||
if text_w <= max_width * 1.05 and line_count <= 4:
|
||||
break
|
||||
size -= 2
|
||||
else:
|
||||
font = find_font(size)
|
||||
wrapped = _wrap_text(text, font, max_width)
|
||||
|
||||
# Measure total text block
|
||||
bbox = draw.multiline_textbbox((0, 0), wrapped, font=font, align=align)
|
||||
text_w = bbox[2] - bbox[0]
|
||||
text_h = bbox[3] - bbox[1]
|
||||
|
||||
# Center horizontally at x, vertically at y
|
||||
tx = x - text_w // 2
|
||||
ty = y - text_h // 2
|
||||
|
||||
# Draw outline (black border)
|
||||
outline_range = max(2, font.size // 18)
|
||||
for dx in range(-outline_range, outline_range + 1):
|
||||
for dy in range(-outline_range, outline_range + 1):
|
||||
if dx == 0 and dy == 0:
|
||||
continue
|
||||
draw.multiline_text(
|
||||
(tx + dx, ty + dy), wrapped, font=font, fill="black", align=align
|
||||
)
|
||||
# Draw main text (white)
|
||||
draw.multiline_text((tx, ty), wrapped, font=font, fill="white", align=align)
|
||||
|
||||
|
||||
def _overlay_on_image(img: Image.Image, texts: list, fields: list) -> Image.Image:
|
||||
"""Overlay meme text directly on an image using field positions."""
|
||||
draw = ImageDraw.Draw(img)
|
||||
w, h = img.size
|
||||
base_font_size = max(16, min(w, h) // 12)
|
||||
|
||||
for i, field in enumerate(fields):
|
||||
if i >= len(texts):
|
||||
break
|
||||
text = texts[i].strip()
|
||||
if not text:
|
||||
continue
|
||||
fx = int(field["x_pct"] * w)
|
||||
fy = int(field["y_pct"] * h)
|
||||
fw = int(field["w_pct"] * w)
|
||||
draw_outlined_text(draw, text, fx, fy, base_font_size, fw, field.get("align", "center"))
|
||||
return img
|
||||
|
||||
|
||||
def _add_bars(img: Image.Image, texts: list) -> Image.Image:
|
||||
"""Add black bars with white text above/below the image.
|
||||
|
||||
Distributes texts across bars: first text on top bar, last text on
|
||||
bottom bar, any middle texts overlaid on the image center.
|
||||
"""
|
||||
w, h = img.size
|
||||
bar_font_size = max(20, w // 16)
|
||||
font = find_font(bar_font_size)
|
||||
padding = bar_font_size // 2
|
||||
|
||||
top_text = texts[0].strip() if texts else ""
|
||||
bottom_text = texts[-1].strip() if len(texts) > 1 else ""
|
||||
middle_texts = [t.strip() for t in texts[1:-1]] if len(texts) > 2 else []
|
||||
|
||||
def _measure_bar(text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
wrapped = _wrap_text(text, font, int(w * 0.92))
|
||||
bbox = ImageDraw.Draw(Image.new("RGB", (1, 1))).multiline_textbbox(
|
||||
(0, 0), wrapped, font=font, align="center"
|
||||
)
|
||||
return (bbox[3] - bbox[1]) + padding * 2
|
||||
|
||||
top_h = _measure_bar(top_text)
|
||||
bottom_h = _measure_bar(bottom_text)
|
||||
new_h = h + top_h + bottom_h
|
||||
|
||||
canvas = Image.new("RGB", (w, new_h), (0, 0, 0))
|
||||
canvas.paste(img.convert("RGB"), (0, top_h))
|
||||
draw = ImageDraw.Draw(canvas)
|
||||
|
||||
if top_text:
|
||||
wrapped = _wrap_text(top_text, font, int(w * 0.92))
|
||||
bbox = draw.multiline_textbbox((0, 0), wrapped, font=font, align="center")
|
||||
tw = bbox[2] - bbox[0]
|
||||
th = bbox[3] - bbox[1]
|
||||
tx = (w - tw) // 2
|
||||
ty = (top_h - th) // 2
|
||||
draw.multiline_text((tx, ty), wrapped, font=font, fill="white", align="center")
|
||||
|
||||
if bottom_text:
|
||||
wrapped = _wrap_text(bottom_text, font, int(w * 0.92))
|
||||
bbox = draw.multiline_textbbox((0, 0), wrapped, font=font, align="center")
|
||||
tw = bbox[2] - bbox[0]
|
||||
th = bbox[3] - bbox[1]
|
||||
tx = (w - tw) // 2
|
||||
ty = top_h + h + (bottom_h - th) // 2
|
||||
draw.multiline_text((tx, ty), wrapped, font=font, fill="white", align="center")
|
||||
|
||||
# Overlay any middle texts centered on the image
|
||||
if middle_texts:
|
||||
mid_fields = _default_fields(len(middle_texts))
|
||||
# Shift y positions to account for top bar offset
|
||||
for field in mid_fields:
|
||||
field["y_pct"] = (top_h + field["y_pct"] * h) / new_h
|
||||
field["w_pct"] = 0.90
|
||||
_overlay_on_image(canvas, middle_texts, mid_fields)
|
||||
|
||||
return canvas
|
||||
|
||||
|
||||
def generate_meme(template_id: str, texts: list[str], output_path: str) -> str:
|
||||
"""Generate a meme from a template and save it. Returns the path."""
|
||||
tmpl = resolve_template(template_id)
|
||||
|
||||
if tmpl is None:
|
||||
print(f"Unknown template: {template_id}", file=sys.stderr)
|
||||
print("Use --list to see curated templates or --search to find imgflip templates.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
fields = tmpl["fields"]
|
||||
print(f"Using template: {tmpl['name']} ({tmpl['source']}, {len(fields)} fields)", file=sys.stderr)
|
||||
|
||||
img = get_template_image(tmpl["url"])
|
||||
img = _overlay_on_image(img, texts, fields)
|
||||
|
||||
output = Path(output_path)
|
||||
if output.suffix.lower() in (".jpg", ".jpeg"):
|
||||
img = img.convert("RGB")
|
||||
img.save(str(output), quality=95)
|
||||
return str(output)
|
||||
|
||||
|
||||
def generate_from_image(
|
||||
image_path: str, texts: list[str], output_path: str, use_bars: bool = False
|
||||
) -> str:
|
||||
"""Generate a meme from a custom image (e.g. AI-generated). Returns the path."""
|
||||
img = Image.open(image_path).convert("RGBA")
|
||||
print(f"Custom image: {img.size[0]}x{img.size[1]}, {len(texts)} text(s), mode={'bars' if use_bars else 'overlay'}", file=sys.stderr)
|
||||
|
||||
if use_bars:
|
||||
result = _add_bars(img, texts)
|
||||
else:
|
||||
fields = _default_fields(len(texts))
|
||||
result = _overlay_on_image(img, texts, fields)
|
||||
|
||||
output = Path(output_path)
|
||||
if output.suffix.lower() in (".jpg", ".jpeg"):
|
||||
result = result.convert("RGB")
|
||||
result.save(str(output), quality=95)
|
||||
return str(output)
|
||||
|
||||
|
||||
def list_templates():
|
||||
"""Print curated templates with custom positioning."""
|
||||
templates = load_curated_templates()
|
||||
print(f"{'ID':<25} {'Name':<30} {'Fields':<8} Best for")
|
||||
print("-" * 90)
|
||||
for tid, tmpl in sorted(templates.items()):
|
||||
fields = len(tmpl["fields"])
|
||||
print(f"{tid:<25} {tmpl['name']:<30} {fields:<8} {tmpl['best_for']}")
|
||||
print(f"\n{len(templates)} curated templates with custom text positioning.")
|
||||
print("Use --search to find any of the ~100 popular imgflip templates.")
|
||||
|
||||
|
||||
def search_templates(query: str):
|
||||
"""Search imgflip templates by name."""
|
||||
imgflip_memes = fetch_imgflip_templates()
|
||||
curated = load_curated_templates()
|
||||
curated_slugs = {_slugify(t["name"]) for t in curated.values()}
|
||||
query_lower = query.lower()
|
||||
|
||||
matches = []
|
||||
for meme in imgflip_memes:
|
||||
if query_lower in meme["name"].lower():
|
||||
slug = _slugify(meme["name"])
|
||||
has_custom = "curated" if slug in curated_slugs else "default"
|
||||
matches.append((meme["name"], meme["id"], meme.get("box_count", 2), has_custom))
|
||||
|
||||
if not matches:
|
||||
print(f"No templates found matching '{query}'")
|
||||
return
|
||||
|
||||
print(f"{'Name':<40} {'ID':<12} {'Fields':<8} Positioning")
|
||||
print("-" * 75)
|
||||
for name, mid, boxes, positioning in matches:
|
||||
print(f"{name:<40} {mid:<12} {boxes:<8} {positioning}")
|
||||
print(f"\n{len(matches)} template(s) found. Use the name or ID as the first argument.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: generate_meme.py <template_id_or_name> <output_path> <text1> [text2] ...")
|
||||
print(" generate_meme.py --image <path> [--bars] <output_path> <text1> [text2] ...")
|
||||
print(" generate_meme.py --list # curated templates")
|
||||
print(" generate_meme.py --search <query> # search all imgflip templates")
|
||||
sys.exit(1)
|
||||
|
||||
if sys.argv[1] == "--list":
|
||||
list_templates()
|
||||
sys.exit(0)
|
||||
|
||||
if sys.argv[1] == "--search":
|
||||
if len(sys.argv) < 3:
|
||||
print("Usage: generate_meme.py --search <query>")
|
||||
sys.exit(1)
|
||||
search_templates(sys.argv[2])
|
||||
sys.exit(0)
|
||||
|
||||
if sys.argv[1] == "--image":
|
||||
# Custom image mode: --image <path> [--bars] <output> <text1> ...
|
||||
args = sys.argv[2:]
|
||||
if len(args) < 3:
|
||||
print("Usage: generate_meme.py --image <image_path> [--bars] <output_path> <text1> ...")
|
||||
sys.exit(1)
|
||||
image_path = args.pop(0)
|
||||
use_bars = False
|
||||
if args and args[0] == "--bars":
|
||||
use_bars = True
|
||||
args.pop(0)
|
||||
if len(args) < 2:
|
||||
print("Need at least: output_path and one text argument")
|
||||
sys.exit(1)
|
||||
output_path = args.pop(0)
|
||||
result = generate_from_image(image_path, args, output_path, use_bars=use_bars)
|
||||
print(f"Meme saved to: {result}")
|
||||
sys.exit(0)
|
||||
|
||||
if len(sys.argv) < 4:
|
||||
print("Need at least: template_id_or_name, output_path, and one text argument")
|
||||
sys.exit(1)
|
||||
|
||||
template_id = sys.argv[1]
|
||||
output_path = sys.argv[2]
|
||||
texts = sys.argv[3:]
|
||||
|
||||
result = generate_meme(template_id, texts, output_path)
|
||||
print(f"Meme saved to: {result}")
|
||||
@@ -0,0 +1,97 @@
|
||||
{
|
||||
"this-is-fine": {
|
||||
"name": "This is Fine",
|
||||
"url": "https://i.imgflip.com/wxica.jpg",
|
||||
"best_for": "chaos, denial, pretending things are okay",
|
||||
"fields": [
|
||||
{"name": "top", "x_pct": 0.5, "y_pct": 0.08, "w_pct": 0.95, "align": "center"},
|
||||
{"name": "bottom", "x_pct": 0.5, "y_pct": 0.92, "w_pct": 0.95, "align": "center"}
|
||||
]
|
||||
},
|
||||
"drake": {
|
||||
"name": "Drake Hotline Bling",
|
||||
"url": "https://i.imgflip.com/30b1gx.jpg",
|
||||
"best_for": "rejecting one thing, preferring another",
|
||||
"fields": [
|
||||
{"name": "reject", "x_pct": 0.73, "y_pct": 0.25, "w_pct": 0.45, "align": "center"},
|
||||
{"name": "approve", "x_pct": 0.73, "y_pct": 0.75, "w_pct": 0.45, "align": "center"}
|
||||
]
|
||||
},
|
||||
"distracted-boyfriend": {
|
||||
"name": "Distracted Boyfriend",
|
||||
"url": "https://i.imgflip.com/1ur9b0.jpg",
|
||||
"best_for": "distraction, shifting priorities, temptation",
|
||||
"fields": [
|
||||
{"name": "distraction", "x_pct": 0.18, "y_pct": 0.90, "w_pct": 0.30, "align": "center"},
|
||||
{"name": "current", "x_pct": 0.55, "y_pct": 0.90, "w_pct": 0.30, "align": "center"},
|
||||
{"name": "person", "x_pct": 0.82, "y_pct": 0.90, "w_pct": 0.30, "align": "center"}
|
||||
]
|
||||
},
|
||||
"two-buttons": {
|
||||
"name": "Two Buttons",
|
||||
"url": "https://i.imgflip.com/1g8my4.jpg",
|
||||
"best_for": "impossible choice, dilemma between two options",
|
||||
"fields": [
|
||||
{"name": "left_button", "x_pct": 0.30, "y_pct": 0.20, "w_pct": 0.28, "align": "center"},
|
||||
{"name": "right_button", "x_pct": 0.62, "y_pct": 0.12, "w_pct": 0.28, "align": "center"},
|
||||
{"name": "person", "x_pct": 0.5, "y_pct": 0.85, "w_pct": 0.90, "align": "center"}
|
||||
]
|
||||
},
|
||||
"expanding-brain": {
|
||||
"name": "Expanding Brain",
|
||||
"url": "https://i.imgflip.com/1jwhww.jpg",
|
||||
"best_for": "escalating irony, increasingly absurd ideas",
|
||||
"fields": [
|
||||
{"name": "level1", "x_pct": 0.25, "y_pct": 0.12, "w_pct": 0.45, "align": "center"},
|
||||
{"name": "level2", "x_pct": 0.25, "y_pct": 0.38, "w_pct": 0.45, "align": "center"},
|
||||
{"name": "level3", "x_pct": 0.25, "y_pct": 0.63, "w_pct": 0.45, "align": "center"},
|
||||
{"name": "level4", "x_pct": 0.25, "y_pct": 0.88, "w_pct": 0.45, "align": "center"}
|
||||
]
|
||||
},
|
||||
"change-my-mind": {
|
||||
"name": "Change My Mind",
|
||||
"url": "https://i.imgflip.com/24y43o.jpg",
|
||||
"best_for": "strong or ironic opinion, controversial take",
|
||||
"fields": [
|
||||
{"name": "statement", "x_pct": 0.58, "y_pct": 0.78, "w_pct": 0.35, "align": "center"}
|
||||
]
|
||||
},
|
||||
"woman-yelling-at-cat": {
|
||||
"name": "Woman Yelling at Cat",
|
||||
"url": "https://i.imgflip.com/345v97.jpg",
|
||||
"best_for": "argument, blame, misunderstanding",
|
||||
"fields": [
|
||||
{"name": "woman", "x_pct": 0.27, "y_pct": 0.10, "w_pct": 0.50, "align": "center"},
|
||||
{"name": "cat", "x_pct": 0.76, "y_pct": 0.10, "w_pct": 0.44, "align": "center"}
|
||||
]
|
||||
},
|
||||
"one-does-not-simply": {
|
||||
"name": "One Does Not Simply",
|
||||
"url": "https://i.imgflip.com/1bij.jpg",
|
||||
"best_for": "something that sounds easy but is actually hard",
|
||||
"fields": [
|
||||
{"name": "top", "x_pct": 0.5, "y_pct": 0.08, "w_pct": 0.95, "align": "center"},
|
||||
{"name": "bottom", "x_pct": 0.5, "y_pct": 0.92, "w_pct": 0.95, "align": "center"}
|
||||
]
|
||||
},
|
||||
"grus-plan": {
|
||||
"name": "Gru's Plan",
|
||||
"url": "https://i.imgflip.com/26jxvs.jpg",
|
||||
"best_for": "a plan that backfires, unexpected consequence",
|
||||
"fields": [
|
||||
{"name": "step1", "x_pct": 0.5, "y_pct": 0.05, "w_pct": 0.45, "align": "center"},
|
||||
{"name": "step2", "x_pct": 0.5, "y_pct": 0.30, "w_pct": 0.45, "align": "center"},
|
||||
{"name": "step3", "x_pct": 0.5, "y_pct": 0.55, "w_pct": 0.45, "align": "center"},
|
||||
{"name": "realization", "x_pct": 0.5, "y_pct": 0.80, "w_pct": 0.45, "align": "center"}
|
||||
]
|
||||
},
|
||||
"batman-slapping-robin": {
|
||||
"name": "Batman Slapping Robin",
|
||||
"url": "https://i.imgflip.com/9ehk.jpg",
|
||||
"best_for": "shutting down a bad idea, correcting someone",
|
||||
"fields": [
|
||||
{"name": "robin", "x_pct": 0.28, "y_pct": 0.08, "w_pct": 0.50, "align": "center"},
|
||||
{"name": "batman", "x_pct": 0.72, "y_pct": 0.08, "w_pct": 0.50, "align": "center"}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,235 @@
|
||||
---
|
||||
name: bioinformatics
|
||||
description: Gateway to 400+ bioinformatics skills from bioSkills and ClawBio. Covers genomics, transcriptomics, single-cell, variant calling, pharmacogenomics, metagenomics, structural biology, and more. Fetches domain-specific reference material on demand.
|
||||
version: 1.0.0
|
||||
platforms: [linux, macos]
|
||||
metadata:
|
||||
hermes:
|
||||
tags: [bioinformatics, genomics, sequencing, biology, research, science]
|
||||
category: research
|
||||
---
|
||||
|
||||
# Bioinformatics Skills Gateway
|
||||
|
||||
Use when asked about bioinformatics, genomics, sequencing, variant calling, gene expression, single-cell analysis, protein structure, pharmacogenomics, metagenomics, phylogenetics, or any computational biology task.
|
||||
|
||||
This skill is a gateway to two open-source bioinformatics skill libraries. Instead of bundling hundreds of domain-specific skills, it indexes them and fetches what you need on demand.
|
||||
|
||||
## Sources
|
||||
|
||||
◆ **bioSkills** — 385 reference skills (code patterns, parameter guides, decision trees)
|
||||
Repo: https://github.com/GPTomics/bioSkills
|
||||
Format: SKILL.md per topic with code examples. Python/R/CLI.
|
||||
|
||||
◆ **ClawBio** — 33 runnable pipeline skills (executable scripts, reproducibility bundles)
|
||||
Repo: https://github.com/ClawBio/ClawBio
|
||||
Format: Python scripts with demos. Each analysis exports report.md + commands.sh + environment.yml.
|
||||
|
||||
## How to fetch and use a skill
|
||||
|
||||
1. Identify the domain and skill name from the index below.
|
||||
2. Clone the relevant repo (shallow clone to save time):
|
||||
```bash
|
||||
# bioSkills (reference material)
|
||||
git clone --depth 1 https://github.com/GPTomics/bioSkills.git /tmp/bioSkills
|
||||
|
||||
# ClawBio (runnable pipelines)
|
||||
git clone --depth 1 https://github.com/ClawBio/ClawBio.git /tmp/ClawBio
|
||||
```
|
||||
3. Read the specific skill:
|
||||
```bash
|
||||
# bioSkills — each skill is at: <category>/<skill-name>/SKILL.md
|
||||
cat /tmp/bioSkills/variant-calling/gatk-variant-calling/SKILL.md
|
||||
|
||||
# ClawBio — each skill is at: skills/<skill-name>/
|
||||
cat /tmp/ClawBio/skills/pharmgx-reporter/README.md
|
||||
```
|
||||
4. Follow the fetched skill as reference material. These are NOT Hermes-format skills — treat them as expert domain guides. They contain correct parameters, proper tool flags, and validated pipelines.
|
||||
|
||||
## Skill Index by Domain
|
||||
|
||||
### Sequence Fundamentals
|
||||
bioSkills:
|
||||
sequence-io/ — read-sequences, write-sequences, format-conversion, batch-processing, compressed-files, fastq-quality, filter-sequences, paired-end-fastq, sequence-statistics
|
||||
sequence-manipulation/ — seq-objects, reverse-complement, transcription-translation, motif-search, codon-usage, sequence-properties, sequence-slicing
|
||||
ClawBio:
|
||||
seq-wrangler — Sequence QC, alignment, and BAM processing (wraps FastQC, BWA, SAMtools)
|
||||
|
||||
### Read QC & Alignment
|
||||
bioSkills:
|
||||
read-qc/ — quality-reports, fastp-workflow, adapter-trimming, quality-filtering, umi-processing, contamination-screening, rnaseq-qc
|
||||
read-alignment/ — bwa-alignment, star-alignment, hisat2-alignment, bowtie2-alignment
|
||||
alignment-files/ — sam-bam-basics, alignment-sorting, alignment-filtering, bam-statistics, duplicate-handling, pileup-generation
|
||||
|
||||
### Variant Calling & Annotation
|
||||
bioSkills:
|
||||
variant-calling/ — gatk-variant-calling, deepvariant, variant-calling (bcftools), joint-calling, structural-variant-calling, filtering-best-practices, variant-annotation, variant-normalization, vcf-basics, vcf-manipulation, vcf-statistics, consensus-sequences, clinical-interpretation
|
||||
ClawBio:
|
||||
vcf-annotator — VEP + ClinVar + gnomAD annotation with ancestry-aware context
|
||||
variant-annotation — Variant annotation pipeline
|
||||
|
||||
### Differential Expression (Bulk RNA-seq)
|
||||
bioSkills:
|
||||
differential-expression/ — deseq2-basics, edger-basics, batch-correction, de-results, de-visualization, timeseries-de
|
||||
rna-quantification/ — alignment-free-quant (Salmon/kallisto), featurecounts-counting, tximport-workflow, count-matrix-qc
|
||||
expression-matrix/ — counts-ingest, gene-id-mapping, metadata-joins, sparse-handling
|
||||
ClawBio:
|
||||
rnaseq-de — Full DE pipeline with QC, normalization, and visualization
|
||||
diff-visualizer — Rich visualization and reporting for DE results
|
||||
|
||||
### Single-Cell RNA-seq
|
||||
bioSkills:
|
||||
single-cell/ — preprocessing, clustering, batch-integration, cell-annotation, cell-communication, doublet-detection, markers-annotation, trajectory-inference, multimodal-integration, perturb-seq, scatac-analysis, lineage-tracing, metabolite-communication, data-io
|
||||
ClawBio:
|
||||
scrna-orchestrator — Full Scanpy pipeline (QC, clustering, markers, annotation)
|
||||
scrna-embedding — scVI-based latent embedding and batch integration
|
||||
|
||||
### Spatial Transcriptomics
|
||||
bioSkills:
|
||||
spatial-transcriptomics/ — spatial-data-io, spatial-preprocessing, spatial-domains, spatial-deconvolution, spatial-communication, spatial-neighbors, spatial-statistics, spatial-visualization, spatial-multiomics, spatial-proteomics, image-analysis
|
||||
|
||||
### Epigenomics
|
||||
bioSkills:
|
||||
chip-seq/ — peak-calling, differential-binding, motif-analysis, peak-annotation, chipseq-qc, chipseq-visualization, super-enhancers
|
||||
atac-seq/ — atac-peak-calling, atac-qc, differential-accessibility, footprinting, motif-deviation, nucleosome-positioning
|
||||
methylation-analysis/ — bismark-alignment, methylation-calling, dmr-detection, methylkit-analysis
|
||||
hi-c-analysis/ — hic-data-io, tad-detection, loop-calling, compartment-analysis, contact-pairs, matrix-operations, hic-visualization, hic-differential
|
||||
ClawBio:
|
||||
methylation-clock — Epigenetic age estimation
|
||||
|
||||
### Pharmacogenomics & Clinical
|
||||
bioSkills:
|
||||
clinical-databases/ — clinvar-lookup, gnomad-frequencies, dbsnp-queries, pharmacogenomics, polygenic-risk, hla-typing, variant-prioritization, somatic-signatures, tumor-mutational-burden, myvariant-queries
|
||||
ClawBio:
|
||||
pharmgx-reporter — PGx report from 23andMe/AncestryDNA (12 genes, 31 SNPs, 51 drugs)
|
||||
drug-photo — Photo of medication → personalized PGx dosage card (via vision)
|
||||
clinpgx — ClinPGx API for gene-drug data and CPIC guidelines
|
||||
gwas-lookup — Federated variant lookup across 9 genomic databases
|
||||
gwas-prs — Polygenic risk scores from consumer genetic data
|
||||
nutrigx_advisor — Personalized nutrition from consumer genetic data
|
||||
|
||||
### Population Genetics & GWAS
|
||||
bioSkills:
|
||||
population-genetics/ — association-testing (PLINK GWAS), plink-basics, population-structure, linkage-disequilibrium, scikit-allel-analysis, selection-statistics
|
||||
causal-genomics/ — mendelian-randomization, fine-mapping, colocalization-analysis, mediation-analysis, pleiotropy-detection
|
||||
phasing-imputation/ — haplotype-phasing, genotype-imputation, imputation-qc, reference-panels
|
||||
ClawBio:
|
||||
claw-ancestry-pca — Ancestry PCA against SGDP reference panel
|
||||
|
||||
### Metagenomics & Microbiome
|
||||
bioSkills:
|
||||
metagenomics/ — kraken-classification, metaphlan-profiling, abundance-estimation, functional-profiling, amr-detection, strain-tracking, metagenome-visualization
|
||||
microbiome/ — amplicon-processing, diversity-analysis, differential-abundance, taxonomy-assignment, functional-prediction, qiime2-workflow
|
||||
ClawBio:
|
||||
claw-metagenomics — Shotgun metagenomics profiling (taxonomy, resistome, functional pathways)
|
||||
|
||||
### Genome Assembly & Annotation
|
||||
bioSkills:
|
||||
genome-assembly/ — hifi-assembly, long-read-assembly, short-read-assembly, metagenome-assembly, assembly-polishing, assembly-qc, scaffolding, contamination-detection
|
||||
genome-annotation/ — eukaryotic-gene-prediction, prokaryotic-annotation, functional-annotation, ncrna-annotation, repeat-annotation, annotation-transfer
|
||||
long-read-sequencing/ — basecalling, long-read-alignment, long-read-qc, clair3-variants, structural-variants, medaka-polishing, nanopore-methylation, isoseq-analysis
|
||||
|
||||
### Structural Biology & Chemoinformatics
|
||||
bioSkills:
|
||||
structural-biology/ — alphafold-predictions, modern-structure-prediction, structure-io, structure-navigation, structure-modification, geometric-analysis
|
||||
chemoinformatics/ — molecular-io, molecular-descriptors, similarity-searching, substructure-search, virtual-screening, admet-prediction, reaction-enumeration
|
||||
ClawBio:
|
||||
struct-predictor — Local AlphaFold/Boltz/Chai structure prediction with comparison
|
||||
|
||||
### Proteomics
|
||||
bioSkills:
|
||||
proteomics/ — data-import, peptide-identification, protein-inference, quantification, differential-abundance, dia-analysis, ptm-analysis, proteomics-qc, spectral-libraries
|
||||
ClawBio:
|
||||
proteomics-de — Proteomics differential expression
|
||||
|
||||
### Pathway Analysis & Gene Networks
|
||||
bioSkills:
|
||||
pathway-analysis/ — go-enrichment, gsea, kegg-pathways, reactome-pathways, wikipathways, enrichment-visualization
|
||||
gene-regulatory-networks/ — scenic-regulons, coexpression-networks, differential-networks, multiomics-grn, perturbation-simulation
|
||||
|
||||
### Immunoinformatics
|
||||
bioSkills:
|
||||
immunoinformatics/ — mhc-binding-prediction, epitope-prediction, neoantigen-prediction, immunogenicity-scoring, tcr-epitope-binding
|
||||
tcr-bcr-analysis/ — mixcr-analysis, scirpy-analysis, immcantation-analysis, repertoire-visualization, vdjtools-analysis
|
||||
|
||||
### CRISPR & Genome Engineering
|
||||
bioSkills:
|
||||
crispr-screens/ — mageck-analysis, jacks-analysis, hit-calling, screen-qc, library-design, crispresso-editing, base-editing-analysis, batch-correction
|
||||
genome-engineering/ — grna-design, off-target-prediction, hdr-template-design, base-editing-design, prime-editing-design
|
||||
|
||||
### Workflow Management
|
||||
bioSkills:
|
||||
workflow-management/ — snakemake-workflows, nextflow-pipelines, cwl-workflows, wdl-workflows
|
||||
ClawBio:
|
||||
repro-enforcer — Export any analysis as reproducibility bundle (Conda env + Singularity + checksums)
|
||||
galaxy-bridge — Access 8,000+ Galaxy tools from usegalaxy.org
|
||||
|
||||
### Specialized Domains
|
||||
bioSkills:
|
||||
alternative-splicing/ — splicing-quantification, differential-splicing, isoform-switching, sashimi-plots, single-cell-splicing, splicing-qc
|
||||
ecological-genomics/ — edna-metabarcoding, landscape-genomics, conservation-genetics, biodiversity-metrics, community-ecology, species-delimitation
|
||||
epidemiological-genomics/ — pathogen-typing, variant-surveillance, phylodynamics, transmission-inference, amr-surveillance
|
||||
liquid-biopsy/ — cfdna-preprocessing, ctdna-mutation-detection, fragment-analysis, tumor-fraction-estimation, methylation-based-detection, longitudinal-monitoring
|
||||
epitranscriptomics/ — m6a-peak-calling, m6a-differential, m6anet-analysis, merip-preprocessing, modification-visualization
|
||||
metabolomics/ — xcms-preprocessing, metabolite-annotation, normalization-qc, statistical-analysis, pathway-mapping, lipidomics, targeted-analysis, msdial-preprocessing
|
||||
flow-cytometry/ — fcs-handling, gating-analysis, compensation-transformation, clustering-phenotyping, differential-analysis, cytometry-qc, doublet-detection, bead-normalization
|
||||
systems-biology/ — flux-balance-analysis, metabolic-reconstruction, gene-essentiality, context-specific-models, model-curation
|
||||
rna-structure/ — secondary-structure-prediction, ncrna-search, structure-probing
|
||||
|
||||
### Data Visualization & Reporting
|
||||
bioSkills:
|
||||
data-visualization/ — ggplot2-fundamentals, heatmaps-clustering, volcano-customization, circos-plots, genome-browser-tracks, interactive-visualization, multipanel-figures, network-visualization, upset-plots, color-palettes, specialized-omics-plots, genome-tracks
|
||||
reporting/ — rmarkdown-reports, quarto-reports, jupyter-reports, automated-qc-reports, figure-export
|
||||
ClawBio:
|
||||
profile-report — Analysis profile reporting
|
||||
data-extractor — Extract numerical data from scientific figure images (via vision)
|
||||
lit-synthesizer — PubMed/bioRxiv search, summarization, citation graphs
|
||||
pubmed-summariser — Gene/disease PubMed search with structured briefing
|
||||
|
||||
### Database Access
|
||||
bioSkills:
|
||||
database-access/ — entrez-search, entrez-fetch, entrez-link, blast-searches, local-blast, sra-data, geo-data, uniprot-access, batch-downloads, interaction-databases, sequence-similarity
|
||||
ClawBio:
|
||||
ukb-navigator — Semantic search across 12,000+ UK Biobank fields
|
||||
clinical-trial-finder — Clinical trial discovery
|
||||
|
||||
### Experimental Design
|
||||
bioSkills:
|
||||
experimental-design/ — power-analysis, sample-size, batch-design, multiple-testing
|
||||
|
||||
### Machine Learning for Omics
|
||||
bioSkills:
|
||||
machine-learning/ — omics-classifiers, biomarker-discovery, survival-analysis, model-validation, prediction-explanation, atlas-mapping
|
||||
ClawBio:
|
||||
claw-semantic-sim — Semantic similarity index for disease literature (PubMedBERT)
|
||||
omics-target-evidence-mapper — Aggregate target-level evidence across omics sources
|
||||
|
||||
## Environment Setup
|
||||
|
||||
These skills assume a bioinformatics workstation. Common dependencies:
|
||||
|
||||
```bash
|
||||
# Python
|
||||
pip install biopython pysam cyvcf2 pybedtools pyBigWig scikit-allel anndata scanpy mygene
|
||||
|
||||
# R/Bioconductor
|
||||
Rscript -e 'BiocManager::install(c("DESeq2","edgeR","Seurat","clusterProfiler","methylKit"))'
|
||||
|
||||
# CLI tools (Ubuntu/Debian)
|
||||
sudo apt install samtools bcftools ncbi-blast+ minimap2 bedtools
|
||||
|
||||
# CLI tools (macOS)
|
||||
brew install samtools bcftools blast minimap2 bedtools
|
||||
|
||||
# Or via Conda (recommended for reproducibility)
|
||||
conda install -c bioconda samtools bcftools blast minimap2 bedtools fastp kraken2
|
||||
```
|
||||
|
||||
## Pitfalls
|
||||
|
||||
- The fetched skills are NOT in Hermes SKILL.md format. They use their own structure (bioSkills: code pattern cookbooks; ClawBio: README + Python scripts). Read them as expert reference material.
|
||||
- bioSkills are reference guides — they show correct parameters and code patterns but aren't executable pipelines.
|
||||
- ClawBio skills are executable — many have `--demo` flags and can be run directly.
|
||||
- Both repos assume bioinformatics tools are installed. Check prerequisites before running pipelines.
|
||||
- For ClawBio, run `pip install -r requirements.txt` in the cloned repo first.
|
||||
- Genomic data files can be very large. Be mindful of disk space when downloading reference genomes, SRA datasets, or building indices.
|
||||
@@ -0,0 +1,80 @@
|
||||
# Gemini OAuth Provider — Implementation Plan
|
||||
|
||||
## Goal
|
||||
Add a first-class `gemini` provider that authenticates via Google OAuth, using the standard Gemini API (not Cloud Code Assist). Users who have a Google AI subscription or Gemini API access can authenticate through the browser without needing to manually copy API keys.
|
||||
|
||||
## Architecture Decision
|
||||
- **Path A (chosen):** Standard Gemini API at `generativelanguage.googleapis.com/v1beta/openai/`
|
||||
- **NOT Path B:** Cloud Code Assist (`cloudcode-pa.googleapis.com`) — rate-limited free tier, internal API, account ban risk
|
||||
- Standard `chat_completions` api_mode via OpenAI SDK — no new api_mode needed
|
||||
- Our own OAuth credentials — NOT sharing tokens with Gemini CLI
|
||||
|
||||
## OAuth Flow
|
||||
- **Type:** Authorization Code + PKCE (S256) — same pattern as clawdbot/pi-mono
|
||||
- **Auth URL:** `https://accounts.google.com/o/oauth2/v2/auth`
|
||||
- **Token URL:** `https://oauth2.googleapis.com/token`
|
||||
- **Redirect:** `http://localhost:8085/oauth2callback` (localhost callback server)
|
||||
- **Fallback:** Manual URL paste for remote/WSL/headless environments
|
||||
- **Scopes:** `https://www.googleapis.com/auth/cloud-platform`, `https://www.googleapis.com/auth/userinfo.email`
|
||||
- **PKCE:** S256 code challenge, 32-byte random verifier
|
||||
|
||||
## Client ID
|
||||
- Need to register a "Desktop app" OAuth client on a Nous Research GCP project
|
||||
- Ship client_id + client_secret in code (Google considers installed app secrets non-confidential)
|
||||
- Alternatively: accept user-provided client_id via env vars as override
|
||||
|
||||
## Token Lifecycle
|
||||
- Store at `~/.hermes/gemini_oauth.json` (NOT sharing with `~/.gemini/oauth_creds.json`)
|
||||
- Fields: `client_id`, `client_secret`, `refresh_token`, `access_token`, `expires_at`, `email`
|
||||
- File permissions: 0o600
|
||||
- Before each API call: check expiry, refresh if within 5 min of expiration
|
||||
- Refresh: POST to token URL with `grant_type=refresh_token`
|
||||
- File locking for concurrent access (multiple agent sessions)
|
||||
|
||||
## API Integration
|
||||
- Base URL: `https://generativelanguage.googleapis.com/v1beta/openai/`
|
||||
- Auth: `Authorization: Bearer <access_token>` (passed as `api_key` to OpenAI SDK)
|
||||
- api_mode: `chat_completions` (standard)
|
||||
- Models: gemini-2.5-pro, gemini-2.5-flash, gemini-2.0-flash, etc.
|
||||
|
||||
## Files to Create/Modify
|
||||
|
||||
### New files
|
||||
1. `agent/google_oauth.py` — OAuth flow (PKCE, localhost server, token exchange, refresh)
|
||||
- `start_oauth_flow()` — opens browser, starts callback server
|
||||
- `exchange_code()` — code → tokens
|
||||
- `refresh_access_token()` — refresh flow
|
||||
- `load_credentials()` / `save_credentials()` — file I/O with locking
|
||||
- `get_valid_access_token()` — check expiry, refresh if needed
|
||||
- ~200 lines
|
||||
|
||||
### Existing files to modify
|
||||
2. `hermes_cli/auth.py` — Add ProviderConfig for "gemini" with auth_type="oauth_google"
|
||||
3. `hermes_cli/models.py` — Add Gemini model catalog
|
||||
4. `hermes_cli/runtime_provider.py` — Add gemini branch (read OAuth token, build OpenAI client)
|
||||
5. `hermes_cli/main.py` — Add `_model_flow_gemini()`, add to provider choices
|
||||
6. `hermes_cli/setup.py` — Add gemini auth flow (trigger browser OAuth)
|
||||
7. `run_agent.py` — Token refresh before API calls (like Copilot pattern)
|
||||
8. `agent/auxiliary_client.py` — Add gemini to aux resolution chain
|
||||
9. `agent/model_metadata.py` — Add Gemini model context lengths
|
||||
|
||||
### Tests
|
||||
10. `tests/agent/test_google_oauth.py` — OAuth flow unit tests
|
||||
11. `tests/test_api_key_providers.py` — Add gemini provider test
|
||||
|
||||
### Docs
|
||||
12. `website/docs/getting-started/quickstart.md` — Add gemini to provider table
|
||||
13. `website/docs/user-guide/configuration.md` — Gemini setup section
|
||||
14. `website/docs/reference/environment-variables.md` — New env vars
|
||||
|
||||
## Estimated scope
|
||||
~400 lines new code, ~150 lines modifications, ~100 lines tests, ~50 lines docs = ~700 lines total
|
||||
|
||||
## Prerequisites
|
||||
- Nous Research GCP project with Desktop OAuth client registered
|
||||
- OR: accept user-provided client_id via HERMES_GEMINI_CLIENT_ID env var
|
||||
|
||||
## Reference implementations
|
||||
- clawdbot: `extensions/google/oauth.flow.ts` (PKCE + localhost server)
|
||||
- pi-mono: `packages/ai/src/utils/oauth/google-gemini-cli.ts` (same flow)
|
||||
- hermes-agent Copilot OAuth: `hermes_cli/main.py` `_copilot_device_flow()` (different flow type but same lifecycle pattern)
|
||||
@@ -60,6 +60,7 @@ mcp = ["mcp>=1.2.0"]
|
||||
homeassistant = ["aiohttp>=3.9.0"]
|
||||
sms = ["aiohttp>=3.9.0"]
|
||||
acp = ["agent-client-protocol>=0.8.1,<1.0"]
|
||||
dingtalk = ["dingtalk-stream>=0.1.0"]
|
||||
rl = [
|
||||
"atroposlib @ git+https://github.com/NousResearch/atropos.git",
|
||||
"tinker @ git+https://github.com/thinking-machines-lab/tinker.git",
|
||||
@@ -84,6 +85,7 @@ all = [
|
||||
"hermes-agent[sms]",
|
||||
"hermes-agent[acp]",
|
||||
"hermes-agent[voice]",
|
||||
"hermes-agent[dingtalk]",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
+96
-19
@@ -70,7 +70,7 @@ from tools.browser_tool import cleanup_browser
|
||||
|
||||
import requests
|
||||
|
||||
from hermes_constants import OPENROUTER_BASE_URL, OPENROUTER_MODELS_URL
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
# Agent internals extracted to agent/ package for modularity
|
||||
from agent.prompt_builder import (
|
||||
@@ -78,7 +78,7 @@ from agent.prompt_builder import (
|
||||
MEMORY_GUIDANCE, SESSION_SEARCH_GUIDANCE, SKILLS_GUIDANCE,
|
||||
)
|
||||
from agent.model_metadata import (
|
||||
fetch_model_metadata, get_model_context_length,
|
||||
fetch_model_metadata,
|
||||
estimate_tokens_rough, estimate_messages_tokens_rough,
|
||||
get_next_probe_tier, parse_context_limit_from_error,
|
||||
save_context_length,
|
||||
@@ -108,7 +108,7 @@ HONCHO_TOOL_NAMES = {
|
||||
|
||||
|
||||
class _SafeWriter:
|
||||
"""Transparent stdio wrapper that catches OSError from broken pipes.
|
||||
"""Transparent stdio wrapper that catches OSError/ValueError from broken pipes.
|
||||
|
||||
When hermes-agent runs as a systemd service, Docker container, or headless
|
||||
daemon, the stdout/stderr pipe can become unavailable (idle timeout, buffer
|
||||
@@ -117,8 +117,13 @@ class _SafeWriter:
|
||||
run_conversation() — especially via double-fault when an except handler
|
||||
also tries to print.
|
||||
|
||||
Additionally, when subagents run in ThreadPoolExecutor threads, the shared
|
||||
stdout handle can close between thread teardown and cleanup, raising
|
||||
``ValueError: I/O operation on closed file`` instead of OSError.
|
||||
|
||||
This wrapper delegates all writes to the underlying stream and silently
|
||||
catches OSError. It is transparent when the wrapped stream is healthy.
|
||||
catches both OSError and ValueError. It is transparent when the wrapped
|
||||
stream is healthy.
|
||||
"""
|
||||
|
||||
__slots__ = ("_inner",)
|
||||
@@ -129,13 +134,13 @@ class _SafeWriter:
|
||||
def write(self, data):
|
||||
try:
|
||||
return self._inner.write(data)
|
||||
except OSError:
|
||||
except (OSError, ValueError):
|
||||
return len(data) if isinstance(data, str) else 0
|
||||
|
||||
def flush(self):
|
||||
try:
|
||||
self._inner.flush()
|
||||
except OSError:
|
||||
except (OSError, ValueError):
|
||||
pass
|
||||
|
||||
def fileno(self):
|
||||
@@ -144,7 +149,7 @@ class _SafeWriter:
|
||||
def isatty(self):
|
||||
try:
|
||||
return self._inner.isatty()
|
||||
except OSError:
|
||||
except (OSError, ValueError):
|
||||
return False
|
||||
|
||||
def __getattr__(self, name):
|
||||
@@ -473,6 +478,11 @@ class AIAgent:
|
||||
self.quiet_mode = quiet_mode
|
||||
self.ephemeral_system_prompt = ephemeral_system_prompt
|
||||
self.platform = platform # "cli", "telegram", "discord", "whatsapp", etc.
|
||||
# Pluggable print function — CLI replaces this with _cprint so that
|
||||
# raw ANSI status lines are routed through prompt_toolkit's renderer
|
||||
# instead of going directly to stdout where patch_stdout's StdoutProxy
|
||||
# would mangle the escape sequences. None = use builtins.print.
|
||||
self._print_fn = None
|
||||
self.skip_context_files = skip_context_files
|
||||
self.pass_session_id = pass_session_id
|
||||
self.log_prefix_chars = log_prefix_chars
|
||||
@@ -660,6 +670,9 @@ class AIAgent:
|
||||
# Internal stream callback (set during streaming TTS).
|
||||
# Initialized here so _vprint can reference it before run_conversation.
|
||||
self._stream_callback = None
|
||||
# Deferred paragraph break flag — set after tool iterations so a
|
||||
# single "\n\n" is prepended to the next real text delta.
|
||||
self._stream_needs_break = False
|
||||
|
||||
# Optional current-turn user-message override used when the API-facing
|
||||
# user message intentionally differs from the persisted transcript
|
||||
@@ -681,7 +694,11 @@ class AIAgent:
|
||||
|
||||
if self.api_mode == "anthropic_messages":
|
||||
from agent.anthropic_adapter import build_anthropic_client, resolve_anthropic_token
|
||||
effective_key = api_key or resolve_anthropic_token() or ""
|
||||
# Only fall back to ANTHROPIC_TOKEN when the provider is actually Anthropic.
|
||||
# Other anthropic_messages providers (MiniMax, Alibaba, etc.) must use their own API key.
|
||||
# Falling back would send Anthropic credentials to third-party endpoints (Fixes #1739, #minimax-401).
|
||||
_is_native_anthropic = self.provider == "anthropic"
|
||||
effective_key = (api_key or resolve_anthropic_token() or "") if _is_native_anthropic else (api_key or "")
|
||||
self.api_key = effective_key
|
||||
self._anthropic_api_key = effective_key
|
||||
self._anthropic_base_url = base_url
|
||||
@@ -732,6 +749,16 @@ class AIAgent:
|
||||
if hasattr(_routed_client, '_default_headers') and _routed_client._default_headers:
|
||||
client_kwargs["default_headers"] = dict(_routed_client._default_headers)
|
||||
else:
|
||||
# When the user explicitly chose a non-OpenRouter provider
|
||||
# but no credentials were found, fail fast with a clear
|
||||
# message instead of silently routing through OpenRouter.
|
||||
_explicit = (self.provider or "").strip().lower()
|
||||
if _explicit and _explicit not in ("auto", "openrouter", "custom"):
|
||||
raise RuntimeError(
|
||||
f"Provider '{_explicit}' is set in config.yaml but no API key "
|
||||
f"was found. Set the {_explicit.upper()}_API_KEY environment "
|
||||
f"variable, or switch to a different provider with `hermes model`."
|
||||
)
|
||||
# Final fallback: try raw OpenRouter key
|
||||
client_kwargs = {
|
||||
"api_key": os.getenv("OPENROUTER_API_KEY", ""),
|
||||
@@ -901,7 +928,7 @@ class AIAgent:
|
||||
pass # Memory is optional -- don't break agent init
|
||||
|
||||
# Honcho AI-native memory (cross-session user modeling)
|
||||
# Reads ~/.honcho/config.json as the single source of truth.
|
||||
# Reads $HERMES_HOME/honcho.json (instance) or ~/.honcho/config.json (global).
|
||||
self._honcho = None # HonchoSessionManager | None
|
||||
self._honcho_session_key = honcho_session_key
|
||||
self._honcho_config = None # HonchoClientConfig | None
|
||||
@@ -1097,16 +1124,21 @@ class AIAgent:
|
||||
self.context_compressor.compression_count = 0
|
||||
self.context_compressor._context_probed = False
|
||||
|
||||
@staticmethod
|
||||
def _safe_print(*args, **kwargs):
|
||||
def _safe_print(self, *args, **kwargs):
|
||||
"""Print that silently handles broken pipes / closed stdout.
|
||||
|
||||
In headless environments (systemd, Docker, nohup) stdout may become
|
||||
unavailable mid-session. A raw ``print()`` raises ``OSError`` which
|
||||
can crash cron jobs and lose completed work.
|
||||
|
||||
Internally routes through ``self._print_fn`` (default: builtin
|
||||
``print``) so callers such as the CLI can inject a renderer that
|
||||
handles ANSI escape sequences properly (e.g. prompt_toolkit's
|
||||
``print_formatted_text(ANSI(...))``) without touching this method.
|
||||
"""
|
||||
try:
|
||||
print(*args, **kwargs)
|
||||
fn = self._print_fn or print
|
||||
fn(*args, **kwargs)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@@ -1373,9 +1405,11 @@ class AIAgent:
|
||||
|
||||
def _run_review():
|
||||
import contextlib, os as _os
|
||||
review_agent = None
|
||||
try:
|
||||
with open(_os.devnull, "w") as _devnull, \
|
||||
contextlib.redirect_stdout(_devnull):
|
||||
contextlib.redirect_stdout(_devnull), \
|
||||
contextlib.redirect_stderr(_devnull):
|
||||
review_agent = AIAgent(
|
||||
model=self.model,
|
||||
max_iterations=8,
|
||||
@@ -1428,6 +1462,20 @@ class AIAgent:
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("Background memory/skill review failed: %s", e)
|
||||
finally:
|
||||
# Explicitly close the OpenAI/httpx client so GC doesn't
|
||||
# try to clean it up on a dead asyncio event loop (which
|
||||
# produces "Event loop is closed" errors in the terminal).
|
||||
if review_agent is not None:
|
||||
client = getattr(review_agent, "client", None)
|
||||
if client is not None:
|
||||
try:
|
||||
review_agent._close_openai_client(
|
||||
client, reason="bg_review_done", shared=True
|
||||
)
|
||||
review_agent.client = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
t = threading.Thread(target=_run_review, daemon=True, name="bg-review")
|
||||
t.start()
|
||||
@@ -2333,7 +2381,7 @@ class AIAgent:
|
||||
# Alibaba Coding Plan API always returns "glm-4.7" as model name regardless
|
||||
# of the requested model. Inject explicit model identity into the system prompt
|
||||
# so the agent can correctly report which model it is (workaround for API bug).
|
||||
if self.provider in ("alibaba-coding-plan", "alibaba-coding-plan-anthropic"):
|
||||
if self.provider == "alibaba":
|
||||
_model_short = self.model.split("/")[-1] if "/" in self.model else self.model
|
||||
prompt_parts.append(
|
||||
f"You are powered by the model named {_model_short}. "
|
||||
@@ -2414,7 +2462,6 @@ class AIAgent:
|
||||
"Pre-call sanitizer: added %d stub tool result(s)",
|
||||
len(missing_results),
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
@@ -3337,6 +3384,10 @@ class AIAgent:
|
||||
def _try_refresh_anthropic_client_credentials(self) -> bool:
|
||||
if self.api_mode != "anthropic_messages" or not hasattr(self, "_anthropic_api_key"):
|
||||
return False
|
||||
# Only refresh credentials for the native Anthropic provider.
|
||||
# Other anthropic_messages providers (MiniMax, Alibaba, etc.) use their own keys.
|
||||
if self.provider != "anthropic":
|
||||
return False
|
||||
|
||||
try:
|
||||
from agent.anthropic_adapter import resolve_anthropic_token, build_anthropic_client
|
||||
@@ -3439,6 +3490,13 @@ class AIAgent:
|
||||
|
||||
def _fire_stream_delta(self, text: str) -> None:
|
||||
"""Fire all registered stream delta callbacks (display + TTS)."""
|
||||
# If a tool iteration set the break flag, prepend a single paragraph
|
||||
# break before the first real text delta. This prevents the original
|
||||
# problem (text concatenation across tool boundaries) without stacking
|
||||
# blank lines when multiple tool iterations run back-to-back.
|
||||
if getattr(self, "_stream_needs_break", False) and text and text.strip():
|
||||
self._stream_needs_break = False
|
||||
text = "\n\n" + text
|
||||
for cb in (self.stream_delta_callback, self._stream_callback):
|
||||
if cb is not None:
|
||||
try:
|
||||
@@ -3761,7 +3819,7 @@ class AIAgent:
|
||||
if fb_api_mode == "anthropic_messages":
|
||||
# Build native Anthropic client instead of using OpenAI client
|
||||
from agent.anthropic_adapter import build_anthropic_client, resolve_anthropic_token, _is_oauth_token
|
||||
effective_key = fb_client.api_key or resolve_anthropic_token() or ""
|
||||
effective_key = (fb_client.api_key or resolve_anthropic_token() or "") if fb_provider == "anthropic" else (fb_client.api_key or "")
|
||||
self._anthropic_api_key = effective_key
|
||||
self._anthropic_base_url = getattr(fb_client, "base_url", None)
|
||||
self._anthropic_client = build_anthropic_client(effective_key, self._anthropic_base_url)
|
||||
@@ -3940,6 +3998,13 @@ class AIAgent:
|
||||
)
|
||||
return transformed
|
||||
|
||||
def _anthropic_preserve_dots(self) -> bool:
|
||||
"""True when using Alibaba/DashScope anthropic-compatible endpoint (model names keep dots, e.g. qwen3.5-plus)."""
|
||||
if (getattr(self, "provider", "") or "").lower() == "alibaba":
|
||||
return True
|
||||
base = (getattr(self, "base_url", "") or "").lower()
|
||||
return "dashscope" in base or "aliyuncs" in base
|
||||
|
||||
def _build_api_kwargs(self, api_messages: list) -> dict:
|
||||
"""Build the keyword arguments dict for the active API mode."""
|
||||
if self.api_mode == "anthropic_messages":
|
||||
@@ -3952,6 +4017,7 @@ class AIAgent:
|
||||
max_tokens=self.max_tokens,
|
||||
reasoning_config=self.reasoning_config,
|
||||
is_oauth=getattr(self, "_is_anthropic_oauth", False),
|
||||
preserve_dots=self._anthropic_preserve_dots(),
|
||||
)
|
||||
|
||||
if self.api_mode == "codex_responses":
|
||||
@@ -4413,6 +4479,7 @@ class AIAgent:
|
||||
model=self.model, messages=api_messages,
|
||||
tools=[memory_tool_def], max_tokens=5120,
|
||||
reasoning_config=None,
|
||||
preserve_dots=self._anthropic_preserve_dots(),
|
||||
)
|
||||
response = self._anthropic_messages_create(ant_kwargs)
|
||||
elif not _aux_available:
|
||||
@@ -5221,7 +5288,8 @@ class AIAgent:
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs as _bak, normalize_anthropic_response as _nar
|
||||
_ant_kw = _bak(model=self.model, messages=api_messages, tools=None,
|
||||
max_tokens=self.max_tokens, reasoning_config=self.reasoning_config,
|
||||
is_oauth=getattr(self, '_is_anthropic_oauth', False))
|
||||
is_oauth=getattr(self, '_is_anthropic_oauth', False),
|
||||
preserve_dots=self._anthropic_preserve_dots())
|
||||
summary_response = self._anthropic_messages_create(_ant_kw)
|
||||
_msg, _ = _nar(summary_response, strip_tool_prefix=getattr(self, '_is_anthropic_oauth', False))
|
||||
final_response = (_msg.content or "").strip()
|
||||
@@ -5252,7 +5320,8 @@ class AIAgent:
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs as _bak2, normalize_anthropic_response as _nar2
|
||||
_ant_kw2 = _bak2(model=self.model, messages=api_messages, tools=None,
|
||||
is_oauth=getattr(self, '_is_anthropic_oauth', False),
|
||||
max_tokens=self.max_tokens, reasoning_config=self.reasoning_config)
|
||||
max_tokens=self.max_tokens, reasoning_config=self.reasoning_config,
|
||||
preserve_dots=self._anthropic_preserve_dots())
|
||||
retry_response = self._anthropic_messages_create(_ant_kw2)
|
||||
_retry_msg, _ = _nar2(retry_response, strip_tool_prefix=getattr(self, '_is_anthropic_oauth', False))
|
||||
final_response = (_retry_msg.content or "").strip()
|
||||
@@ -5608,7 +5677,7 @@ class AIAgent:
|
||||
# inject cache_control breakpoints (system + last 3 messages) to reduce
|
||||
# input token costs by ~75% on multi-turn conversations.
|
||||
if self._use_prompt_caching:
|
||||
api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl)
|
||||
api_messages = apply_anthropic_cache_control(api_messages, cache_ttl=self._cache_ttl, native_anthropic=(self.api_mode == 'anthropic_messages'))
|
||||
|
||||
# Safety net: strip orphaned tool results / add stubs for missing
|
||||
# results before sending to the API. Runs unconditionally — not
|
||||
@@ -6713,6 +6782,14 @@ class AIAgent:
|
||||
_msg_count_before_tools = len(messages)
|
||||
self._execute_tool_calls(assistant_message, messages, effective_task_id, api_call_count)
|
||||
|
||||
# Signal that a paragraph break is needed before the next
|
||||
# streamed text. We don't emit it immediately because
|
||||
# multiple consecutive tool iterations would stack up
|
||||
# redundant blank lines. Instead, _fire_stream_delta()
|
||||
# will prepend a single "\n\n" the next time real text
|
||||
# arrives.
|
||||
self._stream_needs_break = True
|
||||
|
||||
# Refund the iteration if the ONLY tool(s) called were
|
||||
# execute_code (programmatic tool calling). These are
|
||||
# cheap RPC-style calls that shouldn't eat the budget.
|
||||
|
||||
@@ -122,6 +122,44 @@ web_extract(urls=["https://arxiv.org/pdf/2402.03300"])
|
||||
web_search(query="arxiv GRPO reinforcement learning 2026")
|
||||
```
|
||||
|
||||
## Split, Merge & Search
|
||||
|
||||
pymupdf handles these natively — use `execute_code` or inline Python:
|
||||
|
||||
```python
|
||||
# Split: extract pages 1-5 to a new PDF
|
||||
import pymupdf
|
||||
doc = pymupdf.open("report.pdf")
|
||||
new = pymupdf.open()
|
||||
for i in range(5):
|
||||
new.insert_pdf(doc, from_page=i, to_page=i)
|
||||
new.save("pages_1-5.pdf")
|
||||
```
|
||||
|
||||
```python
|
||||
# Merge multiple PDFs
|
||||
import pymupdf
|
||||
result = pymupdf.open()
|
||||
for path in ["a.pdf", "b.pdf", "c.pdf"]:
|
||||
result.insert_pdf(pymupdf.open(path))
|
||||
result.save("merged.pdf")
|
||||
```
|
||||
|
||||
```python
|
||||
# Search for text across all pages
|
||||
import pymupdf
|
||||
doc = pymupdf.open("report.pdf")
|
||||
for i, page in enumerate(doc):
|
||||
results = page.search_for("revenue")
|
||||
if results:
|
||||
print(f"Page {i+1}: {len(results)} match(es)")
|
||||
print(page.get_text("text"))
|
||||
```
|
||||
|
||||
No extra dependencies needed — pymupdf covers split, merge, search, and text extraction in one package.
|
||||
|
||||
---
|
||||
|
||||
## Notes
|
||||
|
||||
- `web_extract` is always first choice for URLs
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -23,6 +24,7 @@ from acp.schema import (
|
||||
)
|
||||
from acp_adapter.server import HermesACPAgent, HERMES_VERSION
|
||||
from acp_adapter.session import SessionManager
|
||||
from hermes_state import SessionDB
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@@ -389,3 +391,46 @@ class TestSlashCommands:
|
||||
resp = await agent.prompt(prompt=prompt, session_id=new_resp.session_id)
|
||||
|
||||
assert resp.stop_reason == "end_turn"
|
||||
|
||||
def test_model_switch_uses_requested_provider(self, tmp_path, monkeypatch):
|
||||
"""`/model provider:model` should rebuild the ACP agent on that provider."""
|
||||
runtime_calls = []
|
||||
|
||||
def fake_resolve_runtime_provider(requested=None, **kwargs):
|
||||
runtime_calls.append(requested)
|
||||
provider = requested or "openrouter"
|
||||
return {
|
||||
"provider": provider,
|
||||
"api_mode": "anthropic_messages" if provider == "anthropic" else "chat_completions",
|
||||
"base_url": f"https://{provider}.example/v1",
|
||||
"api_key": f"{provider}-key",
|
||||
"command": None,
|
||||
"args": [],
|
||||
}
|
||||
|
||||
def fake_agent(**kwargs):
|
||||
return SimpleNamespace(
|
||||
model=kwargs.get("model"),
|
||||
provider=kwargs.get("provider"),
|
||||
base_url=kwargs.get("base_url"),
|
||||
api_mode=kwargs.get("api_mode"),
|
||||
)
|
||||
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: {
|
||||
"model": {"provider": "openrouter", "default": "openrouter/gpt-5"}
|
||||
})
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
fake_resolve_runtime_provider,
|
||||
)
|
||||
manager = SessionManager(db=SessionDB(tmp_path / "state.db"))
|
||||
|
||||
with patch("run_agent.AIAgent", side_effect=fake_agent):
|
||||
acp_agent = HermesACPAgent(session_manager=manager)
|
||||
state = manager.create_session(cwd="/tmp")
|
||||
result = acp_agent._cmd_model("anthropic:claude-sonnet-4-6", state)
|
||||
|
||||
assert "Provider: anthropic" in result
|
||||
assert state.agent.provider == "anthropic"
|
||||
assert state.agent.base_url == "https://anthropic.example/v1"
|
||||
assert runtime_calls[-1] == "anthropic"
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Tests for acp_adapter.session — SessionManager and SessionState."""
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from acp_adapter.session import SessionManager, SessionState
|
||||
from hermes_state import SessionDB
|
||||
@@ -281,3 +282,50 @@ class TestPersistence:
|
||||
assert len(restored.history) == 2
|
||||
assert restored.history[0].get("tool_calls") is not None
|
||||
assert restored.history[1].get("tool_call_id") == "tc_1"
|
||||
|
||||
def test_restore_preserves_persisted_provider_snapshot(self, tmp_path, monkeypatch):
|
||||
"""Restored ACP sessions should keep their original runtime provider."""
|
||||
runtime_choice = {"provider": "anthropic"}
|
||||
|
||||
def fake_resolve_runtime_provider(requested=None, **kwargs):
|
||||
provider = requested or runtime_choice["provider"]
|
||||
return {
|
||||
"provider": provider,
|
||||
"api_mode": "anthropic_messages" if provider == "anthropic" else "chat_completions",
|
||||
"base_url": f"https://{provider}.example/v1",
|
||||
"api_key": f"{provider}-key",
|
||||
"command": None,
|
||||
"args": [],
|
||||
}
|
||||
|
||||
def fake_agent(**kwargs):
|
||||
return SimpleNamespace(
|
||||
model=kwargs.get("model"),
|
||||
provider=kwargs.get("provider"),
|
||||
base_url=kwargs.get("base_url"),
|
||||
api_mode=kwargs.get("api_mode"),
|
||||
)
|
||||
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: {
|
||||
"model": {"provider": runtime_choice["provider"], "default": "test-model"}
|
||||
})
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
fake_resolve_runtime_provider,
|
||||
)
|
||||
db = SessionDB(tmp_path / "state.db")
|
||||
|
||||
with patch("run_agent.AIAgent", side_effect=fake_agent):
|
||||
manager = SessionManager(db=db)
|
||||
state = manager.create_session(cwd="/work")
|
||||
manager.save_session(state.session_id)
|
||||
|
||||
with manager._lock:
|
||||
del manager._sessions[state.session_id]
|
||||
|
||||
runtime_choice["provider"] = "openrouter"
|
||||
restored = manager.get_session(state.session_id)
|
||||
|
||||
assert restored is not None
|
||||
assert restored.agent.provider == "anthropic"
|
||||
assert restored.agent.base_url == "https://anthropic.example/v1"
|
||||
|
||||
@@ -112,6 +112,339 @@ class TestReadCodexAccessToken:
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_expired_jwt_returns_none(self, tmp_path, monkeypatch):
|
||||
"""Expired JWT tokens should be skipped so auto chain continues."""
|
||||
import base64
|
||||
import time as _time
|
||||
|
||||
# Build a JWT with exp in the past
|
||||
header = base64.urlsafe_b64encode(b'{"alg":"RS256","typ":"JWT"}').rstrip(b"=").decode()
|
||||
payload_data = json.dumps({"exp": int(_time.time()) - 3600}).encode()
|
||||
payload = base64.urlsafe_b64encode(payload_data).rstrip(b"=").decode()
|
||||
expired_jwt = f"{header}.{payload}.fakesig"
|
||||
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
(hermes_home / "auth.json").write_text(json.dumps({
|
||||
"version": 1,
|
||||
"providers": {
|
||||
"openai-codex": {
|
||||
"tokens": {"access_token": expired_jwt, "refresh_token": "r"},
|
||||
},
|
||||
},
|
||||
}))
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
result = _read_codex_access_token()
|
||||
assert result is None, "Expired JWT should return None"
|
||||
|
||||
def test_valid_jwt_returns_token(self, tmp_path, monkeypatch):
|
||||
"""Non-expired JWT tokens should be returned."""
|
||||
import base64
|
||||
import time as _time
|
||||
|
||||
header = base64.urlsafe_b64encode(b'{"alg":"RS256","typ":"JWT"}').rstrip(b"=").decode()
|
||||
payload_data = json.dumps({"exp": int(_time.time()) + 3600}).encode()
|
||||
payload = base64.urlsafe_b64encode(payload_data).rstrip(b"=").decode()
|
||||
valid_jwt = f"{header}.{payload}.fakesig"
|
||||
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
(hermes_home / "auth.json").write_text(json.dumps({
|
||||
"version": 1,
|
||||
"providers": {
|
||||
"openai-codex": {
|
||||
"tokens": {"access_token": valid_jwt, "refresh_token": "r"},
|
||||
},
|
||||
},
|
||||
}))
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
result = _read_codex_access_token()
|
||||
assert result == valid_jwt
|
||||
|
||||
def test_non_jwt_token_passes_through(self, tmp_path, monkeypatch):
|
||||
"""Non-JWT tokens (no dots) should be returned as-is."""
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
(hermes_home / "auth.json").write_text(json.dumps({
|
||||
"version": 1,
|
||||
"providers": {
|
||||
"openai-codex": {
|
||||
"tokens": {"access_token": "plain-token-no-jwt", "refresh_token": "r"},
|
||||
},
|
||||
},
|
||||
}))
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
result = _read_codex_access_token()
|
||||
assert result == "plain-token-no-jwt"
|
||||
|
||||
|
||||
class TestAnthropicOAuthFlag:
|
||||
"""Test that OAuth tokens get is_oauth=True in auxiliary Anthropic client."""
|
||||
|
||||
def test_oauth_token_sets_flag(self, monkeypatch):
|
||||
"""OAuth tokens (sk-ant-oat01-*) should create client with is_oauth=True."""
|
||||
monkeypatch.setenv("ANTHROPIC_TOKEN", "sk-ant-oat01-test-token")
|
||||
with patch("agent.anthropic_adapter.build_anthropic_client") as mock_build:
|
||||
mock_build.return_value = MagicMock()
|
||||
from agent.auxiliary_client import _try_anthropic, AnthropicAuxiliaryClient
|
||||
client, model = _try_anthropic()
|
||||
assert client is not None
|
||||
assert isinstance(client, AnthropicAuxiliaryClient)
|
||||
# The adapter inside should have is_oauth=True
|
||||
adapter = client.chat.completions
|
||||
assert adapter._is_oauth is True
|
||||
|
||||
def test_api_key_no_oauth_flag(self, monkeypatch):
|
||||
"""Regular API keys (sk-ant-api-*) should create client with is_oauth=False."""
|
||||
with patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-api03-testkey1234"), \
|
||||
patch("agent.anthropic_adapter.build_anthropic_client") as mock_build:
|
||||
mock_build.return_value = MagicMock()
|
||||
from agent.auxiliary_client import _try_anthropic, AnthropicAuxiliaryClient
|
||||
client, model = _try_anthropic()
|
||||
assert client is not None
|
||||
assert isinstance(client, AnthropicAuxiliaryClient)
|
||||
adapter = client.chat.completions
|
||||
assert adapter._is_oauth is False
|
||||
|
||||
|
||||
class TestExpiredCodexFallback:
|
||||
"""Test that expired Codex tokens don't block the auto chain."""
|
||||
|
||||
def test_expired_codex_falls_through_to_next(self, tmp_path, monkeypatch):
|
||||
"""When Codex token is expired, auto chain should skip it and try next provider."""
|
||||
import base64
|
||||
import time as _time
|
||||
|
||||
# Expired Codex JWT
|
||||
header = base64.urlsafe_b64encode(b'{"alg":"RS256","typ":"JWT"}').rstrip(b"=").decode()
|
||||
payload_data = json.dumps({"exp": int(_time.time()) - 3600}).encode()
|
||||
payload = base64.urlsafe_b64encode(payload_data).rstrip(b"=").decode()
|
||||
expired_jwt = f"{header}.{payload}.fakesig"
|
||||
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
(hermes_home / "auth.json").write_text(json.dumps({
|
||||
"version": 1,
|
||||
"providers": {
|
||||
"openai-codex": {
|
||||
"tokens": {"access_token": expired_jwt, "refresh_token": "r"},
|
||||
},
|
||||
},
|
||||
}))
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
# Set up Anthropic as fallback
|
||||
monkeypatch.setenv("ANTHROPIC_TOKEN", "sk-ant-oat01-test-fallback")
|
||||
with patch("agent.anthropic_adapter.build_anthropic_client") as mock_build:
|
||||
mock_build.return_value = MagicMock()
|
||||
from agent.auxiliary_client import _resolve_auto, AnthropicAuxiliaryClient
|
||||
client, model = _resolve_auto()
|
||||
# Should NOT be Codex, should be Anthropic (or another available provider)
|
||||
assert not isinstance(client, type(None)), "Should find a provider after expired Codex"
|
||||
|
||||
|
||||
def test_expired_codex_openrouter_wins(self, tmp_path, monkeypatch):
|
||||
"""With expired Codex + OpenRouter key, OpenRouter should win (1st in chain)."""
|
||||
import base64
|
||||
import time as _time
|
||||
|
||||
header = base64.urlsafe_b64encode(b'{"alg":"RS256","typ":"JWT"}').rstrip(b"=").decode()
|
||||
payload_data = json.dumps({"exp": int(_time.time()) - 3600}).encode()
|
||||
payload = base64.urlsafe_b64encode(payload_data).rstrip(b"=").decode()
|
||||
expired_jwt = f"{header}.{payload}.fakesig"
|
||||
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
(hermes_home / "auth.json").write_text(json.dumps({
|
||||
"version": 1,
|
||||
"providers": {
|
||||
"openai-codex": {
|
||||
"tokens": {"access_token": expired_jwt, "refresh_token": "r"},
|
||||
},
|
||||
},
|
||||
}))
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-test-key")
|
||||
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
mock_openai.return_value = MagicMock()
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
client, model = _resolve_auto()
|
||||
assert client is not None
|
||||
# OpenRouter is 1st in chain, should win
|
||||
mock_openai.assert_called()
|
||||
|
||||
def test_expired_codex_custom_endpoint_wins(self, tmp_path, monkeypatch):
|
||||
"""With expired Codex + custom endpoint (Ollama), custom should win (3rd in chain)."""
|
||||
import base64
|
||||
import time as _time
|
||||
|
||||
header = base64.urlsafe_b64encode(b'{"alg":"RS256","typ":"JWT"}').rstrip(b"=").decode()
|
||||
payload_data = json.dumps({"exp": int(_time.time()) - 3600}).encode()
|
||||
payload = base64.urlsafe_b64encode(payload_data).rstrip(b"=").decode()
|
||||
expired_jwt = f"{header}.{payload}.fakesig"
|
||||
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
(hermes_home / "auth.json").write_text(json.dumps({
|
||||
"version": 1,
|
||||
"providers": {
|
||||
"openai-codex": {
|
||||
"tokens": {"access_token": expired_jwt, "refresh_token": "r"},
|
||||
},
|
||||
},
|
||||
}))
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
# Simulate Ollama or custom endpoint
|
||||
with patch("agent.auxiliary_client._resolve_custom_runtime",
|
||||
return_value=("http://localhost:11434/v1", "sk-dummy")):
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
mock_openai.return_value = MagicMock()
|
||||
from agent.auxiliary_client import _resolve_auto
|
||||
client, model = _resolve_auto()
|
||||
assert client is not None
|
||||
|
||||
|
||||
def test_hermes_oauth_file_sets_oauth_flag(self, monkeypatch):
|
||||
"""Hermes OAuth credentials should get is_oauth=True (token is not sk-ant-api-*)."""
|
||||
# Mock resolve_anthropic_token to return an OAuth-style token
|
||||
# (simulates what read_hermes_oauth_credentials would return)
|
||||
with patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="hermes-oauth-jwt-token"), \
|
||||
patch("agent.anthropic_adapter.build_anthropic_client") as mock_build:
|
||||
mock_build.return_value = MagicMock()
|
||||
from agent.auxiliary_client import _try_anthropic, AnthropicAuxiliaryClient
|
||||
client, model = _try_anthropic()
|
||||
assert client is not None, "Should resolve token"
|
||||
adapter = client.chat.completions
|
||||
assert adapter._is_oauth is True, "Non-sk-ant-api token should set is_oauth=True"
|
||||
|
||||
def test_jwt_missing_exp_passes_through(self, tmp_path, monkeypatch):
|
||||
"""JWT with valid JSON but no exp claim should pass through."""
|
||||
import base64
|
||||
header = base64.urlsafe_b64encode(b'{"alg":"RS256","typ":"JWT"}').rstrip(b"=").decode()
|
||||
payload_data = json.dumps({"sub": "user123"}).encode() # no exp
|
||||
payload = base64.urlsafe_b64encode(payload_data).rstrip(b"=").decode()
|
||||
no_exp_jwt = f"{header}.{payload}.fakesig"
|
||||
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
(hermes_home / "auth.json").write_text(json.dumps({
|
||||
"version": 1,
|
||||
"providers": {
|
||||
"openai-codex": {
|
||||
"tokens": {"access_token": no_exp_jwt, "refresh_token": "r"},
|
||||
},
|
||||
},
|
||||
}))
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
result = _read_codex_access_token()
|
||||
assert result == no_exp_jwt, "JWT without exp should pass through"
|
||||
|
||||
def test_jwt_invalid_json_payload_passes_through(self, tmp_path, monkeypatch):
|
||||
"""JWT with valid base64 but invalid JSON payload should pass through."""
|
||||
import base64
|
||||
header = base64.urlsafe_b64encode(b'{"alg":"RS256"}').rstrip(b"=").decode()
|
||||
payload = base64.urlsafe_b64encode(b"not-json-content").rstrip(b"=").decode()
|
||||
bad_jwt = f"{header}.{payload}.fakesig"
|
||||
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir(parents=True, exist_ok=True)
|
||||
(hermes_home / "auth.json").write_text(json.dumps({
|
||||
"version": 1,
|
||||
"providers": {
|
||||
"openai-codex": {
|
||||
"tokens": {"access_token": bad_jwt, "refresh_token": "r"},
|
||||
},
|
||||
},
|
||||
}))
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
result = _read_codex_access_token()
|
||||
assert result == bad_jwt, "JWT with invalid JSON payload should pass through"
|
||||
|
||||
def test_claude_code_oauth_env_sets_flag(self, monkeypatch):
|
||||
"""CLAUDE_CODE_OAUTH_TOKEN env var should get is_oauth=True."""
|
||||
monkeypatch.setenv("CLAUDE_CODE_OAUTH_TOKEN", "cc-oauth-token-test")
|
||||
monkeypatch.delenv("ANTHROPIC_TOKEN", raising=False)
|
||||
with patch("agent.anthropic_adapter.build_anthropic_client") as mock_build:
|
||||
mock_build.return_value = MagicMock()
|
||||
from agent.auxiliary_client import _try_anthropic, AnthropicAuxiliaryClient
|
||||
client, model = _try_anthropic()
|
||||
assert client is not None
|
||||
adapter = client.chat.completions
|
||||
assert adapter._is_oauth is True
|
||||
|
||||
|
||||
class TestExplicitProviderRouting:
|
||||
"""Test explicit provider selection bypasses auto chain correctly."""
|
||||
|
||||
def test_explicit_anthropic_oauth(self, monkeypatch):
|
||||
"""provider='anthropic' + OAuth token should work with is_oauth=True."""
|
||||
monkeypatch.setenv("ANTHROPIC_TOKEN", "sk-ant-oat01-explicit-test")
|
||||
with patch("agent.anthropic_adapter.build_anthropic_client") as mock_build:
|
||||
mock_build.return_value = MagicMock()
|
||||
client, model = resolve_provider_client("anthropic")
|
||||
assert client is not None
|
||||
# Verify OAuth flag propagated
|
||||
adapter = client.chat.completions
|
||||
assert adapter._is_oauth is True
|
||||
|
||||
def test_explicit_anthropic_api_key(self, monkeypatch):
|
||||
"""provider='anthropic' + regular API key should work with is_oauth=False."""
|
||||
with patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-api-regular-key"), \
|
||||
patch("agent.anthropic_adapter.build_anthropic_client") as mock_build:
|
||||
mock_build.return_value = MagicMock()
|
||||
client, model = resolve_provider_client("anthropic")
|
||||
assert client is not None
|
||||
adapter = client.chat.completions
|
||||
assert adapter._is_oauth is False
|
||||
|
||||
def test_explicit_openrouter(self, monkeypatch):
|
||||
"""provider='openrouter' should use OPENROUTER_API_KEY."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-explicit")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
mock_openai.return_value = MagicMock()
|
||||
client, model = resolve_provider_client("openrouter")
|
||||
assert client is not None
|
||||
|
||||
def test_explicit_kimi(self, monkeypatch):
|
||||
"""provider='kimi-coding' should use KIMI_API_KEY."""
|
||||
monkeypatch.setenv("KIMI_API_KEY", "kimi-test-key")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
mock_openai.return_value = MagicMock()
|
||||
client, model = resolve_provider_client("kimi-coding")
|
||||
assert client is not None
|
||||
|
||||
def test_explicit_minimax(self, monkeypatch):
|
||||
"""provider='minimax' should use MINIMAX_API_KEY."""
|
||||
monkeypatch.setenv("MINIMAX_API_KEY", "mm-test-key")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
mock_openai.return_value = MagicMock()
|
||||
client, model = resolve_provider_client("minimax")
|
||||
assert client is not None
|
||||
|
||||
def test_explicit_deepseek(self, monkeypatch):
|
||||
"""provider='deepseek' should use DEEPSEEK_API_KEY."""
|
||||
monkeypatch.setenv("DEEPSEEK_API_KEY", "ds-test-key")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
mock_openai.return_value = MagicMock()
|
||||
client, model = resolve_provider_client("deepseek")
|
||||
assert client is not None
|
||||
|
||||
def test_explicit_zai(self, monkeypatch):
|
||||
"""provider='zai' should use GLM_API_KEY."""
|
||||
monkeypatch.setenv("GLM_API_KEY", "zai-test-key")
|
||||
with patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
mock_openai.return_value = MagicMock()
|
||||
client, model = resolve_provider_client("zai")
|
||||
assert client is not None
|
||||
|
||||
def test_explicit_unknown_returns_none(self, monkeypatch):
|
||||
"""Unknown provider should return None."""
|
||||
client, model = resolve_provider_client("nonexistent-provider")
|
||||
assert client is None
|
||||
|
||||
|
||||
class TestGetTextAuxiliaryClient:
|
||||
"""Test the full resolution chain for get_text_auxiliary_client."""
|
||||
|
||||
|
||||
@@ -13,11 +13,18 @@ MARKER = {"type": "ephemeral"}
|
||||
|
||||
|
||||
class TestApplyCacheMarker:
|
||||
def test_tool_message_gets_top_level_marker(self):
|
||||
def test_tool_message_gets_top_level_marker_on_native_anthropic(self):
|
||||
"""Native Anthropic path: cache_control injected top-level (adapter moves it inside tool_result)."""
|
||||
msg = {"role": "tool", "content": "result"}
|
||||
_apply_cache_marker(msg, MARKER)
|
||||
_apply_cache_marker(msg, MARKER, native_anthropic=True)
|
||||
assert msg["cache_control"] == MARKER
|
||||
|
||||
def test_tool_message_skips_marker_on_openrouter(self):
|
||||
"""OpenRouter path: top-level cache_control on role:tool is invalid and causes silent hang."""
|
||||
msg = {"role": "tool", "content": "result"}
|
||||
_apply_cache_marker(msg, MARKER, native_anthropic=False)
|
||||
assert "cache_control" not in msg
|
||||
|
||||
def test_none_content_gets_top_level_marker(self):
|
||||
msg = {"role": "assistant", "content": None}
|
||||
_apply_cache_marker(msg, MARKER)
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
"""Tests for agent.redact -- secret masking in logs and output."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.redact import redact_sensitive_text, RedactingFormatter
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _ensure_redaction_enabled(monkeypatch):
|
||||
"""Ensure HERMES_REDACT_SECRETS is not disabled by prior test imports."""
|
||||
monkeypatch.delenv("HERMES_REDACT_SECRETS", raising=False)
|
||||
|
||||
|
||||
class TestKnownPrefixes:
|
||||
def test_openai_sk_key(self):
|
||||
text = "Using key sk-proj-abc123def456ghi789jkl012"
|
||||
@@ -124,6 +131,13 @@ class TestPassthrough:
|
||||
def test_none_returns_none(self):
|
||||
assert redact_sensitive_text(None) is None
|
||||
|
||||
def test_non_string_input_int_coerced(self):
|
||||
assert redact_sensitive_text(12345) == "12345"
|
||||
|
||||
def test_non_string_input_dict_coerced_and_redacted(self):
|
||||
result = redact_sensitive_text({"token": "sk-proj-abc123def456ghi789jkl012"})
|
||||
assert "abc123def456" not in result
|
||||
|
||||
def test_normal_text_unchanged(self):
|
||||
text = "Hello world, this is a normal log message with no secrets."
|
||||
assert redact_sensitive_text(text) == text
|
||||
|
||||
+30
-6
@@ -313,6 +313,24 @@ class TestMarkJobRun:
|
||||
# Job should be removed after hitting repeat limit
|
||||
assert get_job(job["id"]) is None
|
||||
|
||||
def test_repeat_negative_one_is_infinite(self, tmp_cron_dir):
|
||||
# LLMs often pass repeat=-1 to mean "infinite/forever".
|
||||
# The job must NOT be deleted after runs when repeat <= 0.
|
||||
job = create_job(prompt="Forever", schedule="every 1h", repeat=-1)
|
||||
# -1 should be normalised to None (infinite) at create time
|
||||
assert job["repeat"]["times"] is None
|
||||
# Running it multiple times should never delete it
|
||||
for _ in range(3):
|
||||
mark_job_run(job["id"], success=True)
|
||||
assert get_job(job["id"]) is not None, "job was deleted after run despite infinite repeat"
|
||||
|
||||
def test_repeat_zero_is_infinite(self, tmp_cron_dir):
|
||||
# repeat=0 should also be treated as None (infinite), not "run zero times".
|
||||
job = create_job(prompt="ZeroRepeat", schedule="every 1h", repeat=0)
|
||||
assert job["repeat"]["times"] is None
|
||||
mark_job_run(job["id"], success=True)
|
||||
assert get_job(job["id"]) is not None
|
||||
|
||||
def test_error_status(self, tmp_cron_dir):
|
||||
job = create_job(prompt="Fail", schedule="every 1h")
|
||||
mark_job_run(job["id"], success=False, error="timeout")
|
||||
@@ -323,11 +341,14 @@ class TestMarkJobRun:
|
||||
|
||||
class TestGetDueJobs:
|
||||
def test_past_due_within_window_returned(self, tmp_cron_dir):
|
||||
"""Jobs less than 2 minutes late are still considered due (not stale)."""
|
||||
"""Jobs within the dynamic grace window are still considered due (not stale).
|
||||
|
||||
For an hourly job, grace = 30 min (half the period, clamped to [120s, 2h]).
|
||||
"""
|
||||
job = create_job(prompt="Due now", schedule="every 1h")
|
||||
# Force next_run_at to just 1 minute ago (within the 2-min window)
|
||||
# Force next_run_at to 10 minutes ago (within the 30-min grace for hourly)
|
||||
jobs = load_jobs()
|
||||
jobs[0]["next_run_at"] = (datetime.now() - timedelta(seconds=60)).isoformat()
|
||||
jobs[0]["next_run_at"] = (datetime.now() - timedelta(minutes=10)).isoformat()
|
||||
save_jobs(jobs)
|
||||
|
||||
due = get_due_jobs()
|
||||
@@ -335,11 +356,14 @@ class TestGetDueJobs:
|
||||
assert due[0]["id"] == job["id"]
|
||||
|
||||
def test_stale_past_due_skipped(self, tmp_cron_dir):
|
||||
"""Recurring jobs more than 2 minutes late are fast-forwarded, not fired."""
|
||||
"""Recurring jobs past their dynamic grace window are fast-forwarded, not fired.
|
||||
|
||||
For an hourly job, grace = 30 min. Setting 35 min late exceeds the window.
|
||||
"""
|
||||
job = create_job(prompt="Stale", schedule="every 1h")
|
||||
# Force next_run_at to 5 minutes ago (beyond the 2-min window)
|
||||
# Force next_run_at to 35 minutes ago (beyond the 30-min grace for hourly)
|
||||
jobs = load_jobs()
|
||||
jobs[0]["next_run_at"] = (datetime.now() - timedelta(minutes=5)).isoformat()
|
||||
jobs[0]["next_run_at"] = (datetime.now() - timedelta(minutes=35)).isoformat()
|
||||
save_jobs(jobs)
|
||||
|
||||
due = get_due_jobs()
|
||||
|
||||
@@ -62,6 +62,28 @@ class TestResolveDeliveryTarget:
|
||||
"thread_id": "17585",
|
||||
}
|
||||
|
||||
def test_explicit_telegram_topic_target_with_thread_id(self):
|
||||
"""deliver: 'telegram:chat_id:thread_id' parses correctly."""
|
||||
job = {
|
||||
"deliver": "telegram:-1003724596514:17",
|
||||
}
|
||||
assert _resolve_delivery_target(job) == {
|
||||
"platform": "telegram",
|
||||
"chat_id": "-1003724596514",
|
||||
"thread_id": "17",
|
||||
}
|
||||
|
||||
def test_explicit_telegram_chat_id_without_thread_id(self):
|
||||
"""deliver: 'telegram:chat_id' sets thread_id to None."""
|
||||
job = {
|
||||
"deliver": "telegram:-1003724596514",
|
||||
}
|
||||
assert _resolve_delivery_target(job) == {
|
||||
"platform": "telegram",
|
||||
"chat_id": "-1003724596514",
|
||||
"thread_id": None,
|
||||
}
|
||||
|
||||
def test_bare_platform_uses_matching_origin_chat(self):
|
||||
job = {
|
||||
"deliver": "telegram",
|
||||
@@ -234,6 +256,47 @@ class TestRunJobSessionPersistence:
|
||||
assert kwargs["session_id"].startswith("cron_test-job_")
|
||||
fake_db.close.assert_called_once()
|
||||
|
||||
def test_run_job_empty_response_returns_empty_not_placeholder(self, tmp_path):
|
||||
"""Empty final_response should stay empty for delivery logic (issue #2234).
|
||||
|
||||
The placeholder '(No response generated)' should only appear in the
|
||||
output log, not in the returned final_response that's used for delivery.
|
||||
"""
|
||||
job = {
|
||||
"id": "silent-job",
|
||||
"name": "silent test",
|
||||
"prompt": "do work via tools only",
|
||||
}
|
||||
fake_db = MagicMock()
|
||||
|
||||
with patch("cron.scheduler._hermes_home", tmp_path), \
|
||||
patch("cron.scheduler._resolve_origin", return_value=None), \
|
||||
patch("dotenv.load_dotenv"), \
|
||||
patch("hermes_state.SessionDB", return_value=fake_db), \
|
||||
patch(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
return_value={
|
||||
"api_key": "test-key",
|
||||
"base_url": "https://example.invalid/v1",
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
},
|
||||
), \
|
||||
patch("run_agent.AIAgent") as mock_agent_cls:
|
||||
mock_agent = MagicMock()
|
||||
# Agent did work via tools but returned no text
|
||||
mock_agent.run_conversation.return_value = {"final_response": ""}
|
||||
mock_agent_cls.return_value = mock_agent
|
||||
|
||||
success, output, final_response, error = run_job(job)
|
||||
|
||||
assert success is True
|
||||
assert error is None
|
||||
# final_response should be empty for delivery logic to skip
|
||||
assert final_response == ""
|
||||
# But the output log should show the placeholder
|
||||
assert "(No response generated)" in output
|
||||
|
||||
def test_run_job_sets_auto_delivery_env_from_dotenv_home_channel(self, tmp_path, monkeypatch):
|
||||
job = {
|
||||
"id": "test-job",
|
||||
|
||||
@@ -0,0 +1,238 @@
|
||||
"""Integration tests for gateway AIAgent caching.
|
||||
|
||||
Verifies that the agent cache correctly:
|
||||
- Reuses agents across messages (same config → same instance)
|
||||
- Rebuilds agents when config changes (model, provider, toolsets)
|
||||
- Updates reasoning_config in-place without rebuilding
|
||||
- Evicts on session reset
|
||||
- Evicts on fallback activation
|
||||
- Preserves frozen system prompt across turns
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import threading
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_runner():
|
||||
"""Create a minimal GatewayRunner with just the cache infrastructure."""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner._agent_cache = {}
|
||||
runner._agent_cache_lock = threading.Lock()
|
||||
return runner
|
||||
|
||||
|
||||
class TestAgentConfigSignature:
|
||||
"""Config signature produces stable, distinct keys."""
|
||||
|
||||
def test_same_config_same_signature(self):
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runtime = {"api_key": "sk-test12345678", "base_url": "https://openrouter.ai/api/v1",
|
||||
"provider": "openrouter", "api_mode": "chat_completions"}
|
||||
sig1 = GatewayRunner._agent_config_signature("claude-sonnet-4", runtime, ["hermes-telegram"], "")
|
||||
sig2 = GatewayRunner._agent_config_signature("claude-sonnet-4", runtime, ["hermes-telegram"], "")
|
||||
assert sig1 == sig2
|
||||
|
||||
def test_model_change_different_signature(self):
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runtime = {"api_key": "sk-test12345678", "base_url": "https://openrouter.ai/api/v1",
|
||||
"provider": "openrouter"}
|
||||
sig1 = GatewayRunner._agent_config_signature("claude-sonnet-4", runtime, ["hermes-telegram"], "")
|
||||
sig2 = GatewayRunner._agent_config_signature("claude-opus-4.6", runtime, ["hermes-telegram"], "")
|
||||
assert sig1 != sig2
|
||||
|
||||
def test_provider_change_different_signature(self):
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
rt1 = {"api_key": "sk-test12345678", "base_url": "https://openrouter.ai/api/v1", "provider": "openrouter"}
|
||||
rt2 = {"api_key": "sk-test12345678", "base_url": "https://api.anthropic.com", "provider": "anthropic"}
|
||||
sig1 = GatewayRunner._agent_config_signature("claude-sonnet-4", rt1, ["hermes-telegram"], "")
|
||||
sig2 = GatewayRunner._agent_config_signature("claude-sonnet-4", rt2, ["hermes-telegram"], "")
|
||||
assert sig1 != sig2
|
||||
|
||||
def test_toolset_change_different_signature(self):
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runtime = {"api_key": "sk-test12345678", "base_url": "https://openrouter.ai/api/v1", "provider": "openrouter"}
|
||||
sig1 = GatewayRunner._agent_config_signature("claude-sonnet-4", runtime, ["hermes-telegram"], "")
|
||||
sig2 = GatewayRunner._agent_config_signature("claude-sonnet-4", runtime, ["hermes-discord"], "")
|
||||
assert sig1 != sig2
|
||||
|
||||
def test_reasoning_not_in_signature(self):
|
||||
"""Reasoning config is set per-message, not part of the signature."""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runtime = {"api_key": "sk-test12345678", "base_url": "https://openrouter.ai/api/v1", "provider": "openrouter"}
|
||||
# Same config — signature should be identical regardless of what
|
||||
# reasoning_config the caller might have (it's not passed in)
|
||||
sig1 = GatewayRunner._agent_config_signature("claude-sonnet-4", runtime, ["hermes-telegram"], "")
|
||||
sig2 = GatewayRunner._agent_config_signature("claude-sonnet-4", runtime, ["hermes-telegram"], "")
|
||||
assert sig1 == sig2
|
||||
|
||||
|
||||
class TestAgentCacheLifecycle:
|
||||
"""End-to-end cache behavior with real AIAgent construction."""
|
||||
|
||||
def test_cache_hit_returns_same_agent(self):
|
||||
"""Second message with same config reuses the cached agent instance."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
runner = _make_runner()
|
||||
session_key = "telegram:12345"
|
||||
runtime = {"api_key": "test", "base_url": "https://openrouter.ai/api/v1",
|
||||
"provider": "openrouter", "api_mode": "chat_completions"}
|
||||
sig = runner._agent_config_signature("anthropic/claude-sonnet-4", runtime, ["hermes-telegram"], "")
|
||||
|
||||
# First message — create and cache
|
||||
agent1 = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True, skip_context_files=True,
|
||||
skip_memory=True, platform="telegram",
|
||||
)
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache[session_key] = (agent1, sig)
|
||||
|
||||
# Second message — cache hit
|
||||
with runner._agent_cache_lock:
|
||||
cached = runner._agent_cache.get(session_key)
|
||||
assert cached is not None
|
||||
assert cached[1] == sig
|
||||
assert cached[0] is agent1 # same instance
|
||||
|
||||
def test_cache_miss_on_model_change(self):
|
||||
"""Model change produces different signature → cache miss."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
runner = _make_runner()
|
||||
session_key = "telegram:12345"
|
||||
runtime = {"api_key": "test", "base_url": "https://openrouter.ai/api/v1",
|
||||
"provider": "openrouter", "api_mode": "chat_completions"}
|
||||
|
||||
old_sig = runner._agent_config_signature("anthropic/claude-sonnet-4", runtime, ["hermes-telegram"], "")
|
||||
agent1 = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True, skip_context_files=True,
|
||||
skip_memory=True, platform="telegram",
|
||||
)
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache[session_key] = (agent1, old_sig)
|
||||
|
||||
# New model → different signature
|
||||
new_sig = runner._agent_config_signature("anthropic/claude-opus-4.6", runtime, ["hermes-telegram"], "")
|
||||
assert new_sig != old_sig
|
||||
|
||||
with runner._agent_cache_lock:
|
||||
cached = runner._agent_cache.get(session_key)
|
||||
assert cached[1] != new_sig # signature mismatch → would create new agent
|
||||
|
||||
def test_evict_on_session_reset(self):
|
||||
"""_evict_cached_agent removes the entry."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
runner = _make_runner()
|
||||
session_key = "telegram:12345"
|
||||
|
||||
agent = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True, skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache[session_key] = (agent, "sig123")
|
||||
|
||||
runner._evict_cached_agent(session_key)
|
||||
|
||||
with runner._agent_cache_lock:
|
||||
assert session_key not in runner._agent_cache
|
||||
|
||||
def test_evict_does_not_affect_other_sessions(self):
|
||||
"""Evicting one session leaves other sessions cached."""
|
||||
runner = _make_runner()
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache["session-A"] = ("agent-A", "sig-A")
|
||||
runner._agent_cache["session-B"] = ("agent-B", "sig-B")
|
||||
|
||||
runner._evict_cached_agent("session-A")
|
||||
|
||||
with runner._agent_cache_lock:
|
||||
assert "session-A" not in runner._agent_cache
|
||||
assert "session-B" in runner._agent_cache
|
||||
|
||||
def test_reasoning_config_updates_in_place(self):
|
||||
"""Reasoning config can be set on a cached agent without eviction."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True, skip_context_files=True,
|
||||
skip_memory=True,
|
||||
reasoning_config={"enabled": True, "effort": "medium"},
|
||||
)
|
||||
|
||||
# Simulate per-message reasoning update
|
||||
agent.reasoning_config = {"enabled": True, "effort": "high"}
|
||||
assert agent.reasoning_config["effort"] == "high"
|
||||
|
||||
# System prompt should not be affected by reasoning change
|
||||
prompt1 = agent._build_system_prompt()
|
||||
agent._cached_system_prompt = prompt1 # simulate run_conversation caching
|
||||
agent.reasoning_config = {"enabled": True, "effort": "low"}
|
||||
prompt2 = agent._cached_system_prompt
|
||||
assert prompt1 is prompt2 # same object — not invalidated by reasoning change
|
||||
|
||||
def test_system_prompt_frozen_across_cache_reuse(self):
|
||||
"""The cached agent's system prompt stays identical across turns."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True, skip_context_files=True,
|
||||
skip_memory=True, platform="telegram",
|
||||
)
|
||||
|
||||
# Build system prompt (simulates first run_conversation)
|
||||
prompt1 = agent._build_system_prompt()
|
||||
agent._cached_system_prompt = prompt1
|
||||
|
||||
# Simulate second turn — prompt should be frozen
|
||||
prompt2 = agent._cached_system_prompt
|
||||
assert prompt1 is prompt2 # same object, not rebuilt
|
||||
|
||||
def test_callbacks_update_without_cache_eviction(self):
|
||||
"""Per-message callbacks can be set on cached agent."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
agent = AIAgent(
|
||||
model="anthropic/claude-sonnet-4", api_key="test",
|
||||
base_url="https://openrouter.ai/api/v1", provider="openrouter",
|
||||
max_iterations=5, quiet_mode=True, skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
|
||||
# Set callbacks like the gateway does per-message
|
||||
cb1 = lambda *a: None
|
||||
cb2 = lambda *a: None
|
||||
agent.tool_progress_callback = cb1
|
||||
agent.step_callback = cb2
|
||||
agent.stream_delta_callback = None
|
||||
agent.status_callback = None
|
||||
|
||||
assert agent.tool_progress_callback is cb1
|
||||
assert agent.step_callback is cb2
|
||||
|
||||
# Update for next message
|
||||
cb3 = lambda *a: None
|
||||
agent.tool_progress_callback = cb3
|
||||
assert agent.tool_progress_callback is cb3
|
||||
@@ -119,22 +119,33 @@ class TestAdapterInit:
|
||||
def test_custom_config_from_extra(self):
|
||||
config = PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"host": "0.0.0.0", "port": 9999, "key": "sk-test"},
|
||||
extra={
|
||||
"host": "0.0.0.0",
|
||||
"port": 9999,
|
||||
"key": "sk-test",
|
||||
"cors_origins": ["http://localhost:3000"],
|
||||
},
|
||||
)
|
||||
adapter = APIServerAdapter(config)
|
||||
assert adapter._host == "0.0.0.0"
|
||||
assert adapter._port == 9999
|
||||
assert adapter._api_key == "sk-test"
|
||||
assert adapter._cors_origins == ("http://localhost:3000",)
|
||||
|
||||
def test_config_from_env(self, monkeypatch):
|
||||
monkeypatch.setenv("API_SERVER_HOST", "10.0.0.1")
|
||||
monkeypatch.setenv("API_SERVER_PORT", "7777")
|
||||
monkeypatch.setenv("API_SERVER_KEY", "sk-env")
|
||||
monkeypatch.setenv("API_SERVER_CORS_ORIGINS", "http://localhost:3000, http://127.0.0.1:3000")
|
||||
config = PlatformConfig(enabled=True)
|
||||
adapter = APIServerAdapter(config)
|
||||
assert adapter._host == "10.0.0.1"
|
||||
assert adapter._port == 7777
|
||||
assert adapter._api_key == "sk-env"
|
||||
assert adapter._cors_origins == (
|
||||
"http://localhost:3000",
|
||||
"http://127.0.0.1:3000",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -190,11 +201,13 @@ class TestAuth:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_adapter(api_key: str = "") -> APIServerAdapter:
|
||||
def _make_adapter(api_key: str = "", cors_origins=None) -> APIServerAdapter:
|
||||
"""Create an adapter with optional API key."""
|
||||
extra = {}
|
||||
if api_key:
|
||||
extra["key"] = api_key
|
||||
if cors_origins is not None:
|
||||
extra["cors_origins"] = cors_origins
|
||||
config = PlatformConfig(enabled=True, extra=extra)
|
||||
return APIServerAdapter(config)
|
||||
|
||||
@@ -202,6 +215,7 @@ def _make_adapter(api_key: str = "") -> APIServerAdapter:
|
||||
def _create_app(adapter: APIServerAdapter) -> web.Application:
|
||||
"""Create the aiohttp app from the adapter (without starting the full server)."""
|
||||
app = web.Application(middlewares=[cors_middleware])
|
||||
app["api_server_adapter"] = adapter
|
||||
app.router.add_get("/health", adapter._handle_health)
|
||||
app.router.add_get("/v1/models", adapter._handle_models)
|
||||
app.router.add_post("/v1/chat/completions", adapter._handle_chat_completions)
|
||||
@@ -788,6 +802,19 @@ class TestConfigIntegration:
|
||||
assert config.platforms[Platform.API_SERVER].extra.get("port") == 9999
|
||||
assert config.platforms[Platform.API_SERVER].extra.get("host") == "0.0.0.0"
|
||||
|
||||
def test_env_override_cors_origins(self, monkeypatch):
|
||||
monkeypatch.setenv("API_SERVER_ENABLED", "true")
|
||||
monkeypatch.setenv(
|
||||
"API_SERVER_CORS_ORIGINS",
|
||||
"http://localhost:3000, http://127.0.0.1:3000",
|
||||
)
|
||||
from gateway.config import load_gateway_config
|
||||
config = load_gateway_config()
|
||||
assert config.platforms[Platform.API_SERVER].extra.get("cors_origins") == [
|
||||
"http://localhost:3000",
|
||||
"http://127.0.0.1:3000",
|
||||
]
|
||||
|
||||
def test_api_server_in_connected_platforms(self):
|
||||
config = GatewayConfig()
|
||||
config.platforms[Platform.API_SERVER] = PlatformConfig(enabled=True)
|
||||
@@ -1156,26 +1183,91 @@ class TestTruncation:
|
||||
|
||||
|
||||
class TestCORS:
|
||||
def test_origin_allowed_for_non_browser_client(self, adapter):
|
||||
assert adapter._origin_allowed("") is True
|
||||
|
||||
def test_origin_rejected_by_default(self, adapter):
|
||||
assert adapter._origin_allowed("http://evil.example") is False
|
||||
|
||||
def test_origin_allowed_for_allowlist_match(self):
|
||||
adapter = _make_adapter(cors_origins=["http://localhost:3000"])
|
||||
assert adapter._origin_allowed("http://localhost:3000") is True
|
||||
|
||||
def test_cors_headers_for_origin_disabled_by_default(self, adapter):
|
||||
assert adapter._cors_headers_for_origin("http://localhost:3000") is None
|
||||
|
||||
def test_cors_headers_for_origin_matches_allowlist(self):
|
||||
adapter = _make_adapter(cors_origins=["http://localhost:3000"])
|
||||
headers = adapter._cors_headers_for_origin("http://localhost:3000")
|
||||
assert headers is not None
|
||||
assert headers["Access-Control-Allow-Origin"] == "http://localhost:3000"
|
||||
assert "POST" in headers["Access-Control-Allow-Methods"]
|
||||
|
||||
def test_cors_headers_for_origin_rejects_unknown_origin(self):
|
||||
adapter = _make_adapter(cors_origins=["http://localhost:3000"])
|
||||
assert adapter._cors_headers_for_origin("http://evil.example") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cors_headers_on_get(self, adapter):
|
||||
"""CORS headers present on normal responses."""
|
||||
async def test_cors_headers_not_present_by_default(self, adapter):
|
||||
"""CORS is disabled unless explicitly configured."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.get("/health")
|
||||
assert resp.status == 200
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") == "*"
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_browser_origin_rejected_by_default(self, adapter):
|
||||
"""Browser-originated requests are rejected unless explicitly allowed."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.get("/health", headers={"Origin": "http://evil.example"})
|
||||
assert resp.status == 403
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cors_options_preflight_rejected_by_default(self, adapter):
|
||||
"""Browser preflight is rejected unless CORS is explicitly configured."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.options(
|
||||
"/v1/chat/completions",
|
||||
headers={
|
||||
"Origin": "http://evil.example",
|
||||
"Access-Control-Request-Method": "POST",
|
||||
},
|
||||
)
|
||||
assert resp.status == 403
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cors_headers_present_for_allowed_origin(self):
|
||||
"""Allowed origins receive explicit CORS headers."""
|
||||
adapter = _make_adapter(cors_origins=["http://localhost:3000"])
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.get("/health", headers={"Origin": "http://localhost:3000"})
|
||||
assert resp.status == 200
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000"
|
||||
assert "POST" in resp.headers.get("Access-Control-Allow-Methods", "")
|
||||
assert "DELETE" in resp.headers.get("Access-Control-Allow-Methods", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cors_options_preflight(self, adapter):
|
||||
"""OPTIONS preflight request returns CORS headers."""
|
||||
async def test_cors_options_preflight_allowed_for_configured_origin(self):
|
||||
"""Configured origins can complete browser preflight."""
|
||||
adapter = _make_adapter(cors_origins=["http://localhost:3000"])
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
# OPTIONS to a known path — aiohttp will route through middleware
|
||||
resp = await cli.options("/health")
|
||||
resp = await cli.options(
|
||||
"/v1/chat/completions",
|
||||
headers={
|
||||
"Origin": "http://localhost:3000",
|
||||
"Access-Control-Request-Method": "POST",
|
||||
"Access-Control-Request-Headers": "Authorization, Content-Type",
|
||||
},
|
||||
)
|
||||
assert resp.status == 200
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") == "*"
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000"
|
||||
assert "Authorization" in resp.headers.get("Access-Control-Allow-Headers", "")
|
||||
|
||||
|
||||
@@ -1203,7 +1295,7 @@ class TestConversationParameter:
|
||||
data = await resp.json()
|
||||
assert data["status"] == "completed"
|
||||
# Conversation mapping should be set
|
||||
assert "my-chat" in adapter._conversations
|
||||
assert adapter._response_store.get_conversation("my-chat") is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_chains_automatically(self, adapter):
|
||||
@@ -1277,7 +1369,7 @@ class TestConversationParameter:
|
||||
await cli.post("/v1/responses", json={"input": "conv-b msg", "conversation": "conv-b"})
|
||||
|
||||
# They should have different response IDs in the mapping
|
||||
assert adapter._conversations["conv-a"] != adapter._conversations["conv-b"]
|
||||
assert adapter._response_store.get_conversation("conv-a") != adapter._response_store.get_conversation("conv-b")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conversation_store_false_no_mapping(self, adapter):
|
||||
@@ -1296,4 +1388,4 @@ class TestConversationParameter:
|
||||
})
|
||||
assert resp.status == 200
|
||||
# Conversation mapping should NOT be set since store=false
|
||||
assert "ephemeral-chat" not in adapter._conversations
|
||||
assert adapter._response_store.get_conversation("ephemeral-chat") is None
|
||||
|
||||
@@ -0,0 +1,597 @@
|
||||
"""
|
||||
Tests for the Cron Jobs API endpoints on the API server adapter.
|
||||
|
||||
Covers:
|
||||
- CRUD operations for cron jobs (list, create, get, update, delete)
|
||||
- Pause / resume / run (trigger) actions
|
||||
- Input validation (missing name, name too long, prompt too long, invalid repeat)
|
||||
- Job ID validation (invalid hex)
|
||||
- Auth enforcement (401 when API_SERVER_KEY is set)
|
||||
- Cron module unavailability (501 when _CRON_AVAILABLE is False)
|
||||
"""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.api_server import APIServerAdapter, cors_middleware
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SAMPLE_JOB = {
|
||||
"id": "aabbccddeeff",
|
||||
"name": "test-job",
|
||||
"schedule": "*/5 * * * *",
|
||||
"prompt": "do something",
|
||||
"deliver": "local",
|
||||
"enabled": True,
|
||||
}
|
||||
|
||||
VALID_JOB_ID = "aabbccddeeff"
|
||||
|
||||
|
||||
def _make_adapter(api_key: str = "") -> APIServerAdapter:
|
||||
"""Create an adapter with optional API key."""
|
||||
extra = {}
|
||||
if api_key:
|
||||
extra["key"] = api_key
|
||||
config = PlatformConfig(enabled=True, extra=extra)
|
||||
return APIServerAdapter(config)
|
||||
|
||||
|
||||
def _create_app(adapter: APIServerAdapter) -> web.Application:
|
||||
"""Create the aiohttp app with jobs routes registered."""
|
||||
app = web.Application(middlewares=[cors_middleware])
|
||||
app["api_server_adapter"] = adapter
|
||||
# Register only job routes (plus health for sanity)
|
||||
app.router.add_get("/health", adapter._handle_health)
|
||||
app.router.add_get("/api/jobs", adapter._handle_list_jobs)
|
||||
app.router.add_post("/api/jobs", adapter._handle_create_job)
|
||||
app.router.add_get("/api/jobs/{job_id}", adapter._handle_get_job)
|
||||
app.router.add_patch("/api/jobs/{job_id}", adapter._handle_update_job)
|
||||
app.router.add_delete("/api/jobs/{job_id}", adapter._handle_delete_job)
|
||||
app.router.add_post("/api/jobs/{job_id}/pause", adapter._handle_pause_job)
|
||||
app.router.add_post("/api/jobs/{job_id}/resume", adapter._handle_resume_job)
|
||||
app.router.add_post("/api/jobs/{job_id}/run", adapter._handle_run_job)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def adapter():
|
||||
return _make_adapter()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_adapter():
|
||||
return _make_adapter(api_key="sk-secret")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. test_list_jobs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestListJobs:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_jobs(self, adapter):
|
||||
"""GET /api/jobs returns job list."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_list", return_value=[SAMPLE_JOB]
|
||||
):
|
||||
resp = await cli.get("/api/jobs")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert "jobs" in data
|
||||
assert data["jobs"] == [SAMPLE_JOB]
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
# 2. test_list_jobs_include_disabled
|
||||
# -------------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_jobs_include_disabled(self, adapter):
|
||||
"""GET /api/jobs?include_disabled=true passes the flag."""
|
||||
app = _create_app(adapter)
|
||||
mock_list = MagicMock(return_value=[SAMPLE_JOB])
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_list", mock_list
|
||||
):
|
||||
resp = await cli.get("/api/jobs?include_disabled=true")
|
||||
assert resp.status == 200
|
||||
mock_list.assert_called_once_with(include_disabled=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_jobs_default_excludes_disabled(self, adapter):
|
||||
"""GET /api/jobs without flag passes include_disabled=False."""
|
||||
app = _create_app(adapter)
|
||||
mock_list = MagicMock(return_value=[])
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_list", mock_list
|
||||
):
|
||||
resp = await cli.get("/api/jobs")
|
||||
assert resp.status == 200
|
||||
mock_list.assert_called_once_with(include_disabled=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3-7. test_create_job and validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCreateJob:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_job(self, adapter):
|
||||
"""POST /api/jobs with valid body returns created job."""
|
||||
app = _create_app(adapter)
|
||||
mock_create = MagicMock(return_value=SAMPLE_JOB)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_create", mock_create
|
||||
):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "test-job",
|
||||
"schedule": "*/5 * * * *",
|
||||
"prompt": "do something",
|
||||
})
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["job"] == SAMPLE_JOB
|
||||
mock_create.assert_called_once()
|
||||
call_kwargs = mock_create.call_args[1]
|
||||
assert call_kwargs["name"] == "test-job"
|
||||
assert call_kwargs["schedule"] == "*/5 * * * *"
|
||||
assert call_kwargs["prompt"] == "do something"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_job_missing_name(self, adapter):
|
||||
"""POST /api/jobs without name returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"schedule": "*/5 * * * *",
|
||||
"prompt": "do something",
|
||||
})
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert "name" in data["error"].lower() or "Name" in data["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_job_name_too_long(self, adapter):
|
||||
"""POST /api/jobs with name > 200 chars returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "x" * 201,
|
||||
"schedule": "*/5 * * * *",
|
||||
})
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert "200" in data["error"] or "Name" in data["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_job_prompt_too_long(self, adapter):
|
||||
"""POST /api/jobs with prompt > 5000 chars returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "test-job",
|
||||
"schedule": "*/5 * * * *",
|
||||
"prompt": "x" * 5001,
|
||||
})
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert "5000" in data["error"] or "Prompt" in data["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_job_invalid_repeat(self, adapter):
|
||||
"""POST /api/jobs with repeat=0 returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "test-job",
|
||||
"schedule": "*/5 * * * *",
|
||||
"repeat": 0,
|
||||
})
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert "repeat" in data["error"].lower() or "Repeat" in data["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_job_missing_schedule(self, adapter):
|
||||
"""POST /api/jobs without schedule returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "test-job",
|
||||
})
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert "schedule" in data["error"].lower() or "Schedule" in data["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8-10. test_get_job
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestGetJob:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_job(self, adapter):
|
||||
"""GET /api/jobs/{id} returns job."""
|
||||
app = _create_app(adapter)
|
||||
mock_get = MagicMock(return_value=SAMPLE_JOB)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_get", mock_get
|
||||
):
|
||||
resp = await cli.get(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["job"] == SAMPLE_JOB
|
||||
mock_get.assert_called_once_with(VALID_JOB_ID)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_job_not_found(self, adapter):
|
||||
"""GET /api/jobs/{id} returns 404 when job doesn't exist."""
|
||||
app = _create_app(adapter)
|
||||
mock_get = MagicMock(return_value=None)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_get", mock_get
|
||||
):
|
||||
resp = await cli.get(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_job_invalid_id(self, adapter):
|
||||
"""GET /api/jobs/{id} with non-hex id returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.get("/api/jobs/not-a-valid-hex!")
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert "Invalid" in data["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 11-12. test_update_job
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUpdateJob:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_job(self, adapter):
|
||||
"""PATCH /api/jobs/{id} updates with whitelisted fields."""
|
||||
app = _create_app(adapter)
|
||||
updated_job = {**SAMPLE_JOB, "name": "updated-name"}
|
||||
mock_update = MagicMock(return_value=updated_job)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_update", mock_update
|
||||
):
|
||||
resp = await cli.patch(
|
||||
f"/api/jobs/{VALID_JOB_ID}",
|
||||
json={"name": "updated-name", "schedule": "0 * * * *"},
|
||||
)
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["job"] == updated_job
|
||||
mock_update.assert_called_once()
|
||||
call_args = mock_update.call_args
|
||||
assert call_args[0][0] == VALID_JOB_ID
|
||||
sanitized = call_args[0][1]
|
||||
assert "name" in sanitized
|
||||
assert "schedule" in sanitized
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_job_rejects_unknown_fields(self, adapter):
|
||||
"""PATCH /api/jobs/{id} — only allowed fields pass through."""
|
||||
app = _create_app(adapter)
|
||||
updated_job = {**SAMPLE_JOB, "name": "new-name"}
|
||||
mock_update = MagicMock(return_value=updated_job)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_update", mock_update
|
||||
):
|
||||
resp = await cli.patch(
|
||||
f"/api/jobs/{VALID_JOB_ID}",
|
||||
json={
|
||||
"name": "new-name",
|
||||
"evil_field": "malicious",
|
||||
"__proto__": "hack",
|
||||
},
|
||||
)
|
||||
assert resp.status == 200
|
||||
call_args = mock_update.call_args
|
||||
sanitized = call_args[0][1]
|
||||
assert "name" in sanitized
|
||||
assert "evil_field" not in sanitized
|
||||
assert "__proto__" not in sanitized
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_job_no_valid_fields(self, adapter):
|
||||
"""PATCH /api/jobs/{id} with only unknown fields returns 400."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.patch(
|
||||
f"/api/jobs/{VALID_JOB_ID}",
|
||||
json={"evil_field": "malicious"},
|
||||
)
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert "No valid fields" in data["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 13. test_delete_job
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDeleteJob:
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_job(self, adapter):
|
||||
"""DELETE /api/jobs/{id} returns ok."""
|
||||
app = _create_app(adapter)
|
||||
mock_remove = MagicMock(return_value=True)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_remove", mock_remove
|
||||
):
|
||||
resp = await cli.delete(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["ok"] is True
|
||||
mock_remove.assert_called_once_with(VALID_JOB_ID)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_job_not_found(self, adapter):
|
||||
"""DELETE /api/jobs/{id} returns 404 when job doesn't exist."""
|
||||
app = _create_app(adapter)
|
||||
mock_remove = MagicMock(return_value=False)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_remove", mock_remove
|
||||
):
|
||||
resp = await cli.delete(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 14. test_pause_job
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPauseJob:
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_job(self, adapter):
|
||||
"""POST /api/jobs/{id}/pause returns updated job."""
|
||||
app = _create_app(adapter)
|
||||
paused_job = {**SAMPLE_JOB, "enabled": False}
|
||||
mock_pause = MagicMock(return_value=paused_job)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_pause", mock_pause
|
||||
):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/pause")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["job"] == paused_job
|
||||
assert data["job"]["enabled"] is False
|
||||
mock_pause.assert_called_once_with(VALID_JOB_ID)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 15. test_resume_job
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResumeJob:
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_job(self, adapter):
|
||||
"""POST /api/jobs/{id}/resume returns updated job."""
|
||||
app = _create_app(adapter)
|
||||
resumed_job = {**SAMPLE_JOB, "enabled": True}
|
||||
mock_resume = MagicMock(return_value=resumed_job)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_resume", mock_resume
|
||||
):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/resume")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["job"] == resumed_job
|
||||
assert data["job"]["enabled"] is True
|
||||
mock_resume.assert_called_once_with(VALID_JOB_ID)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 16. test_run_job
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRunJob:
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_job(self, adapter):
|
||||
"""POST /api/jobs/{id}/run returns triggered job."""
|
||||
app = _create_app(adapter)
|
||||
triggered_job = {**SAMPLE_JOB, "last_run": "2025-01-01T00:00:00Z"}
|
||||
mock_trigger = MagicMock(return_value=triggered_job)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_trigger", mock_trigger
|
||||
):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/run")
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["job"] == triggered_job
|
||||
mock_trigger.assert_called_once_with(VALID_JOB_ID)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 17. test_auth_required
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAuthRequired:
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_required_list_jobs(self, auth_adapter):
|
||||
"""GET /api/jobs without API key returns 401 when key is set."""
|
||||
app = _create_app(auth_adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.get("/api/jobs")
|
||||
assert resp.status == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_required_create_job(self, auth_adapter):
|
||||
"""POST /api/jobs without API key returns 401 when key is set."""
|
||||
app = _create_app(auth_adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "test", "schedule": "* * * * *",
|
||||
})
|
||||
assert resp.status == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_required_get_job(self, auth_adapter):
|
||||
"""GET /api/jobs/{id} without API key returns 401 when key is set."""
|
||||
app = _create_app(auth_adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.get(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_required_delete_job(self, auth_adapter):
|
||||
"""DELETE /api/jobs/{id} without API key returns 401."""
|
||||
app = _create_app(auth_adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", True):
|
||||
resp = await cli.delete(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_passes_with_valid_key(self, auth_adapter):
|
||||
"""GET /api/jobs with correct API key succeeds."""
|
||||
app = _create_app(auth_adapter)
|
||||
mock_list = MagicMock(return_value=[])
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(
|
||||
APIServerAdapter, "_CRON_AVAILABLE", True
|
||||
), patch.object(
|
||||
APIServerAdapter, "_cron_list", mock_list
|
||||
):
|
||||
resp = await cli.get(
|
||||
"/api/jobs",
|
||||
headers={"Authorization": "Bearer sk-secret"},
|
||||
)
|
||||
assert resp.status == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 18. test_cron_unavailable
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCronUnavailable:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cron_unavailable_list(self, adapter):
|
||||
"""GET /api/jobs returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
resp = await cli.get("/api/jobs")
|
||||
assert resp.status == 501
|
||||
data = await resp.json()
|
||||
assert "not available" in data["error"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cron_unavailable_create(self, adapter):
|
||||
"""POST /api/jobs returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
resp = await cli.post("/api/jobs", json={
|
||||
"name": "test", "schedule": "* * * * *",
|
||||
})
|
||||
assert resp.status == 501
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cron_unavailable_get(self, adapter):
|
||||
"""GET /api/jobs/{id} returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
resp = await cli.get(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 501
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cron_unavailable_delete(self, adapter):
|
||||
"""DELETE /api/jobs/{id} returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
resp = await cli.delete(f"/api/jobs/{VALID_JOB_ID}")
|
||||
assert resp.status == 501
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cron_unavailable_pause(self, adapter):
|
||||
"""POST /api/jobs/{id}/pause returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/pause")
|
||||
assert resp.status == 501
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cron_unavailable_resume(self, adapter):
|
||||
"""POST /api/jobs/{id}/resume returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/resume")
|
||||
assert resp.status == 501
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cron_unavailable_run(self, adapter):
|
||||
"""POST /api/jobs/{id}/run returns 501 when _CRON_AVAILABLE is False."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
with patch.object(APIServerAdapter, "_CRON_AVAILABLE", False):
|
||||
resp = await cli.post(f"/api/jobs/{VALID_JOB_ID}/run")
|
||||
assert resp.status == 501
|
||||
@@ -0,0 +1,347 @@
|
||||
"""Tests for Discord incoming document/file attachment handling.
|
||||
|
||||
Covers the document branch in DiscordAdapter._handle_message() —
|
||||
the `else` clause of the attachment content-type loop that was added
|
||||
to download, cache, and optionally inject text from non-image/audio files.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.base import MessageType
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discord mock setup (copied from test_discord_free_response.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ensure_discord_mock():
|
||||
"""Install a mock discord module when discord.py isn't available."""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
|
||||
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, danger=3, green=1, blurple=2, red=3)
|
||||
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4)
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
import gateway.platforms.discord as discord_platform # noqa: E402
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fake channel / thread types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class FakeDMChannel:
|
||||
def __init__(self, channel_id: int = 1):
|
||||
self.id = channel_id
|
||||
self.name = "dm"
|
||||
|
||||
|
||||
class FakeThread:
|
||||
def __init__(self, channel_id: int = 10):
|
||||
self.id = channel_id
|
||||
self.name = "thread"
|
||||
self.parent = None
|
||||
self.parent_id = None
|
||||
self.guild = SimpleNamespace(name="TestServer")
|
||||
self.topic = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _redirect_cache(tmp_path, monkeypatch):
|
||||
"""Point document cache to tmp_path so tests never write to ~/.hermes."""
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.base.DOCUMENT_CACHE_DIR", tmp_path / "doc_cache"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(monkeypatch):
|
||||
monkeypatch.setattr(discord_platform.discord, "DMChannel", FakeDMChannel, raising=False)
|
||||
monkeypatch.setattr(discord_platform.discord, "Thread", FakeThread, raising=False)
|
||||
|
||||
config = PlatformConfig(enabled=True, token="fake-token")
|
||||
a = DiscordAdapter(config)
|
||||
a._client = SimpleNamespace(user=SimpleNamespace(id=999))
|
||||
a.handle_message = AsyncMock()
|
||||
return a
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_attachment(
|
||||
*,
|
||||
filename: str,
|
||||
content_type: str,
|
||||
size: int = 1024,
|
||||
url: str = "https://cdn.discordapp.com/attachments/fake/file",
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
filename=filename,
|
||||
content_type=content_type,
|
||||
size=size,
|
||||
url=url,
|
||||
)
|
||||
|
||||
|
||||
def make_message(attachments: list, content: str = "") -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
id=123,
|
||||
content=content,
|
||||
attachments=attachments,
|
||||
mentions=[],
|
||||
reference=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
channel=FakeDMChannel(),
|
||||
author=SimpleNamespace(id=42, display_name="Tester", name="Tester"),
|
||||
)
|
||||
|
||||
|
||||
def _mock_aiohttp_download(raw_bytes: bytes):
|
||||
"""Return a patch context manager that makes aiohttp return raw_bytes."""
|
||||
resp = AsyncMock()
|
||||
resp.status = 200
|
||||
resp.read = AsyncMock(return_value=raw_bytes)
|
||||
resp.__aenter__ = AsyncMock(return_value=resp)
|
||||
resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
session = AsyncMock()
|
||||
session.get = MagicMock(return_value=resp)
|
||||
session.__aenter__ = AsyncMock(return_value=session)
|
||||
session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
return patch("aiohttp.ClientSession", return_value=session)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIncomingDocumentHandling:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pdf_document_cached(self, adapter):
|
||||
"""A PDF attachment should be downloaded, cached, typed as DOCUMENT."""
|
||||
pdf_bytes = b"%PDF-1.4 fake content"
|
||||
|
||||
with _mock_aiohttp_download(pdf_bytes):
|
||||
msg = make_message([make_attachment(filename="report.pdf", content_type="application/pdf")])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.message_type == MessageType.DOCUMENT
|
||||
assert len(event.media_urls) == 1
|
||||
assert os.path.exists(event.media_urls[0])
|
||||
assert event.media_types == ["application/pdf"]
|
||||
assert "[Content of" not in (event.text or "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_txt_content_injected(self, adapter):
|
||||
""".txt file under 100KB should have its content injected into event.text."""
|
||||
file_content = b"Hello from a text file"
|
||||
|
||||
with _mock_aiohttp_download(file_content):
|
||||
msg = make_message(
|
||||
attachments=[make_attachment(filename="notes.txt", content_type="text/plain")],
|
||||
content="summarize this",
|
||||
)
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "[Content of notes.txt]:" in event.text
|
||||
assert "Hello from a text file" in event.text
|
||||
assert "summarize this" in event.text
|
||||
# injection prepended before caption
|
||||
assert event.text.index("[Content of") < event.text.index("summarize this")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_md_content_injected(self, adapter):
|
||||
""".md file under 100KB should have its content injected."""
|
||||
file_content = b"# Title\nSome markdown content"
|
||||
|
||||
with _mock_aiohttp_download(file_content):
|
||||
msg = make_message(
|
||||
attachments=[make_attachment(filename="readme.md", content_type="text/markdown")],
|
||||
content="",
|
||||
)
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "[Content of readme.md]:" in event.text
|
||||
assert "# Title" in event.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oversized_document_skipped(self, adapter):
|
||||
"""A document over 20MB should be skipped — media_urls stays empty."""
|
||||
msg = make_message([
|
||||
make_attachment(
|
||||
filename="huge.pdf",
|
||||
content_type="application/pdf",
|
||||
size=25 * 1024 * 1024,
|
||||
)
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.media_urls == []
|
||||
# handler must still be called
|
||||
adapter.handle_message.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_type_skipped(self, adapter):
|
||||
"""An unsupported file type (.zip) should be skipped silently."""
|
||||
msg = make_message([
|
||||
make_attachment(filename="archive.zip", content_type="application/zip")
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.media_urls == []
|
||||
assert event.message_type == MessageType.TEXT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_error_handled(self, adapter):
|
||||
"""If the HTTP download raises, the handler should not crash."""
|
||||
resp = AsyncMock()
|
||||
resp.__aenter__ = AsyncMock(side_effect=RuntimeError("connection reset"))
|
||||
resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
session = AsyncMock()
|
||||
session.get = MagicMock(return_value=resp)
|
||||
session.__aenter__ = AsyncMock(return_value=session)
|
||||
session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=session):
|
||||
msg = make_message([
|
||||
make_attachment(filename="report.pdf", content_type="application/pdf")
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
# Must still deliver an event
|
||||
adapter.handle_message.assert_called_once()
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.media_urls == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_txt_cached_not_injected(self, adapter):
|
||||
""".txt over 100KB should be cached but NOT injected into event.text."""
|
||||
large_content = b"x" * (200 * 1024)
|
||||
|
||||
with _mock_aiohttp_download(large_content):
|
||||
msg = make_message(
|
||||
attachments=[make_attachment(filename="big.txt", content_type="text/plain", size=len(large_content))],
|
||||
content="",
|
||||
)
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert len(event.media_urls) == 1
|
||||
assert os.path.exists(event.media_urls[0])
|
||||
assert "[Content of" not in (event.text or "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_text_files_both_injected(self, adapter):
|
||||
"""Two text file attachments should both be injected into event.text in order."""
|
||||
content1 = b"First file content"
|
||||
content2 = b"Second file content"
|
||||
|
||||
call_count = 0
|
||||
responses = [content1, content2]
|
||||
|
||||
def make_session(_responses):
|
||||
idx = 0
|
||||
|
||||
class FakeSession:
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *_):
|
||||
pass
|
||||
|
||||
def get(self, url, **kwargs):
|
||||
nonlocal idx
|
||||
data = _responses[idx % len(_responses)]
|
||||
idx += 1
|
||||
|
||||
resp = AsyncMock()
|
||||
resp.status = 200
|
||||
resp.read = AsyncMock(return_value=data)
|
||||
resp.__aenter__ = AsyncMock(return_value=resp)
|
||||
resp.__aexit__ = AsyncMock(return_value=False)
|
||||
return resp
|
||||
|
||||
return FakeSession()
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=make_session([content1, content2])):
|
||||
msg = make_message(
|
||||
attachments=[
|
||||
make_attachment(filename="file1.txt", content_type="text/plain"),
|
||||
make_attachment(filename="file2.txt", content_type="text/plain"),
|
||||
],
|
||||
content="",
|
||||
)
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert "[Content of file1.txt]:" in event.text
|
||||
assert "First file content" in event.text
|
||||
assert "[Content of file2.txt]:" in event.text
|
||||
assert "Second file content" in event.text
|
||||
assert event.text.index("file1") < event.text.index("file2")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_attachment_unaffected(self, adapter):
|
||||
"""Image attachments should still go through the image path, not the document path."""
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_image_from_url",
|
||||
new_callable=AsyncMock,
|
||||
return_value="/tmp/cached_image.png",
|
||||
):
|
||||
msg = make_message([
|
||||
make_attachment(filename="photo.png", content_type="image/png")
|
||||
])
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.message_type == MessageType.PHOTO
|
||||
assert event.media_urls == ["/tmp/cached_image.png"]
|
||||
assert event.media_types == ["image/png"]
|
||||
@@ -241,6 +241,42 @@ async def test_dispatch_thread_session_builds_thread_event(adapter):
|
||||
assert "TestGuild" in event.source.chat_name
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _build_slash_event — preserve thread context for native slash commands
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_slash_event_preserves_thread_context(adapter):
|
||||
interaction = SimpleNamespace(
|
||||
channel=_FakeThreadChannel(channel_id=555, name="Planning"),
|
||||
channel_id=555,
|
||||
user=SimpleNamespace(display_name="Jezza", id=42),
|
||||
)
|
||||
|
||||
event = adapter._build_slash_event(interaction, "/status")
|
||||
|
||||
assert event.text == "/status"
|
||||
assert event.source.chat_id == "555"
|
||||
assert event.source.chat_type == "thread"
|
||||
assert event.source.thread_id == "555"
|
||||
assert "TestGuild" in event.source.chat_name
|
||||
|
||||
|
||||
def test_build_slash_event_uses_group_context_for_channels(adapter):
|
||||
interaction = SimpleNamespace(
|
||||
channel=_FakeTextChannel(channel_id=123, name="general"),
|
||||
channel_id=123,
|
||||
user=SimpleNamespace(display_name="Jezza", id=42),
|
||||
)
|
||||
|
||||
event = adapter._build_slash_event(interaction, "/status")
|
||||
|
||||
assert event.source.chat_id == "123"
|
||||
assert event.source.chat_type == "group"
|
||||
assert event.source.thread_id is None
|
||||
assert "TestGuild / #general" == event.source.chat_name
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Auto-thread: _auto_create_thread
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
"""Tests for Discord system message filtering (thread renames, pins, etc.)."""
|
||||
|
||||
import pytest
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
discord = pytest.importorskip("discord")
|
||||
|
||||
|
||||
def _make_author(*, bot: bool = False, is_self: bool = False):
|
||||
"""Create a mock Discord author."""
|
||||
author = MagicMock()
|
||||
author.bot = bot
|
||||
author.id = 99999 if is_self else 12345
|
||||
author.name = "TestBot" if bot else "TestUser"
|
||||
author.display_name = author.name
|
||||
return author
|
||||
|
||||
|
||||
def _make_message(*, author=None, content="hello", msg_type=None):
|
||||
"""Create a mock Discord message with a specific type."""
|
||||
msg = MagicMock()
|
||||
msg.author = author or _make_author()
|
||||
msg.content = content
|
||||
msg.attachments = []
|
||||
msg.mentions = []
|
||||
msg.type = msg_type if msg_type is not None else discord.MessageType.default
|
||||
msg.channel = MagicMock()
|
||||
msg.channel.id = 222
|
||||
msg.channel.name = "test-channel"
|
||||
msg.channel.guild = MagicMock()
|
||||
msg.channel.guild.name = "TestServer"
|
||||
return msg
|
||||
|
||||
|
||||
class TestDiscordSystemMessageFilter(unittest.TestCase):
|
||||
"""Test that Discord system messages (thread renames, pins, etc.) are ignored."""
|
||||
|
||||
def _run_filter(self, message, client_user=None):
|
||||
"""Simulate the on_message filter logic and return whether message was accepted.
|
||||
|
||||
Replicates the guard added to discord.py:
|
||||
if message.type not in (discord.MessageType.default, discord.MessageType.reply):
|
||||
return # ignored
|
||||
"""
|
||||
# Own messages always ignored
|
||||
if message.author == client_user:
|
||||
return False
|
||||
|
||||
# System message filter (the fix being tested)
|
||||
if message.type not in (discord.MessageType.default, discord.MessageType.reply):
|
||||
return False
|
||||
|
||||
return True # message accepted
|
||||
|
||||
def test_default_messages_accepted(self):
|
||||
"""Regular user messages (type=default) should be accepted."""
|
||||
msg = _make_message(msg_type=discord.MessageType.default)
|
||||
self.assertTrue(self._run_filter(msg))
|
||||
|
||||
def test_reply_messages_accepted(self):
|
||||
"""Reply messages (type=reply) should be accepted — users reply to bot messages."""
|
||||
msg = _make_message(msg_type=discord.MessageType.reply)
|
||||
self.assertTrue(self._run_filter(msg))
|
||||
|
||||
def test_thread_rename_ignored(self):
|
||||
"""Thread rename system messages should be ignored."""
|
||||
msg = _make_message(msg_type=discord.MessageType.channel_name_change)
|
||||
self.assertFalse(self._run_filter(msg))
|
||||
|
||||
def test_pins_add_ignored(self):
|
||||
"""Pin notifications should be ignored."""
|
||||
msg = _make_message(msg_type=discord.MessageType.pins_add)
|
||||
self.assertFalse(self._run_filter(msg))
|
||||
|
||||
def test_new_member_ignored(self):
|
||||
"""New member join messages should be ignored."""
|
||||
msg = _make_message(msg_type=discord.MessageType.new_member)
|
||||
self.assertFalse(self._run_filter(msg))
|
||||
|
||||
def test_premium_guild_subscription_ignored(self):
|
||||
"""Boost messages should be ignored."""
|
||||
msg = _make_message(msg_type=discord.MessageType.premium_guild_subscription)
|
||||
self.assertFalse(self._run_filter(msg))
|
||||
|
||||
def test_recipient_add_ignored(self):
|
||||
"""Group DM recipient add messages should be ignored."""
|
||||
msg = _make_message(msg_type=discord.MessageType.recipient_add)
|
||||
self.assertFalse(self._run_filter(msg))
|
||||
|
||||
def test_own_default_messages_still_ignored(self):
|
||||
"""Bot's own messages should still be ignored even if type is default."""
|
||||
bot_user = _make_author(is_self=True)
|
||||
msg = _make_message(author=bot_user, msg_type=discord.MessageType.default)
|
||||
self.assertFalse(self._run_filter(msg, client_user=bot_user))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -0,0 +1,167 @@
|
||||
"""Tests for memory flush stale-overwrite prevention (#2670).
|
||||
|
||||
Verifies that:
|
||||
1. Cron sessions are skipped (no flush for headless cron runs)
|
||||
2. Current memory state is injected into the flush prompt so the
|
||||
flush agent can see what's already saved and avoid overwrites
|
||||
3. The flush still works normally when memory files don't exist
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
|
||||
def _make_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner._honcho_managers = {}
|
||||
runner._honcho_configs = {}
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner.adapters = {}
|
||||
runner.hooks = MagicMock()
|
||||
runner.session_store = MagicMock()
|
||||
return runner
|
||||
|
||||
|
||||
_TRANSCRIPT_4_MSGS = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi there"},
|
||||
{"role": "user", "content": "remember my name is Alice"},
|
||||
{"role": "assistant", "content": "Got it, Alice!"},
|
||||
]
|
||||
|
||||
|
||||
class TestCronSessionBypass:
|
||||
"""Cron sessions should never trigger a memory flush."""
|
||||
|
||||
def test_cron_session_skipped(self):
|
||||
runner = _make_runner()
|
||||
runner._flush_memories_for_session("cron_job123_20260323_120000")
|
||||
# session_store.load_transcript should never be called
|
||||
runner.session_store.load_transcript.assert_not_called()
|
||||
|
||||
def test_cron_session_with_honcho_key_skipped(self):
|
||||
runner = _make_runner()
|
||||
runner._flush_memories_for_session("cron_daily_20260323", "some-honcho-key")
|
||||
runner.session_store.load_transcript.assert_not_called()
|
||||
|
||||
def test_non_cron_session_proceeds(self):
|
||||
"""Non-cron sessions should still attempt the flush."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = []
|
||||
runner._flush_memories_for_session("session_abc123")
|
||||
runner.session_store.load_transcript.assert_called_once_with("session_abc123")
|
||||
|
||||
|
||||
class TestMemoryInjection:
|
||||
"""The flush prompt should include current memory state from disk."""
|
||||
|
||||
def test_memory_content_injected_into_flush_prompt(self, tmp_path):
|
||||
"""When memory files exist, their content appears in the flush prompt."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
memory_dir = tmp_path / "memories"
|
||||
memory_dir.mkdir()
|
||||
(memory_dir / "MEMORY.md").write_text("Agent knows Python\n§\nUser prefers dark mode")
|
||||
(memory_dir / "USER.md").write_text("Name: Alice\n§\nTimezone: PST")
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
# Intercept `from tools.memory_tool import MEMORY_DIR` inside the function
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_123")
|
||||
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
call_kwargs = tmp_agent.run_conversation.call_args.kwargs
|
||||
flush_prompt = call_kwargs.get("user_message", "")
|
||||
|
||||
# Verify both memory sections appear in the prompt
|
||||
assert "Agent knows Python" in flush_prompt
|
||||
assert "User prefers dark mode" in flush_prompt
|
||||
assert "Name: Alice" in flush_prompt
|
||||
assert "Timezone: PST" in flush_prompt
|
||||
# Verify the stale-overwrite warning is present
|
||||
assert "Do NOT overwrite or remove entries" in flush_prompt
|
||||
assert "current live state of memory" in flush_prompt
|
||||
|
||||
def test_flush_works_without_memory_files(self, tmp_path):
|
||||
"""When no memory files exist, flush still runs without the guard."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
empty_dir = tmp_path / "no_memories"
|
||||
empty_dir.mkdir()
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=empty_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_456")
|
||||
|
||||
# Should still run, just without the memory guard section
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
assert "Do NOT overwrite or remove entries" not in flush_prompt
|
||||
assert "Review the conversation above" in flush_prompt
|
||||
|
||||
def test_empty_memory_files_no_injection(self, tmp_path):
|
||||
"""Empty memory files should not trigger the guard section."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
memory_dir = tmp_path / "memories"
|
||||
memory_dir.mkdir()
|
||||
(memory_dir / "MEMORY.md").write_text("")
|
||||
(memory_dir / "USER.md").write_text(" \n ") # whitespace only
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_789")
|
||||
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
# No memory content → no guard section
|
||||
assert "current live state of memory" not in flush_prompt
|
||||
|
||||
|
||||
class TestFlushPromptStructure:
|
||||
"""Verify the flush prompt retains its core instructions."""
|
||||
|
||||
def test_core_instructions_present(self):
|
||||
"""The flush prompt should still contain the original guidance."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
# Make the import fail gracefully so we test without memory files
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=Path("/nonexistent"))}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_struct")
|
||||
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
assert "automatically reset" in flush_prompt
|
||||
assert "Save any important facts" in flush_prompt
|
||||
assert "consider saving it as a skill" in flush_prompt
|
||||
assert "Do NOT respond to the user" in flush_prompt
|
||||
@@ -279,7 +279,7 @@ class TestMattermostWebSocketParsing:
|
||||
"id": "post_abc",
|
||||
"user_id": "user_123",
|
||||
"channel_id": "chan_456",
|
||||
"message": "Hello from Matrix!",
|
||||
"message": "@bot_user_id Hello from Matrix!",
|
||||
}
|
||||
event = {
|
||||
"event": "posted",
|
||||
@@ -293,7 +293,7 @@ class TestMattermostWebSocketParsing:
|
||||
await self.adapter._handle_ws_event(event)
|
||||
assert self.adapter.handle_message.called
|
||||
msg_event = self.adapter.handle_message.call_args[0][0]
|
||||
assert msg_event.text == "Hello from Matrix!"
|
||||
assert msg_event.text == "@bot_user_id Hello from Matrix!"
|
||||
assert msg_event.message_id == "post_abc"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -378,7 +378,7 @@ class TestMattermostWebSocketParsing:
|
||||
"id": "post_reply",
|
||||
"user_id": "user_123",
|
||||
"channel_id": "chan_456",
|
||||
"message": "Thread reply",
|
||||
"message": "@bot_user_id Thread reply",
|
||||
"root_id": "root_post_123",
|
||||
}
|
||||
event = {
|
||||
@@ -487,7 +487,7 @@ class TestMattermostDedup:
|
||||
"id": "post_dup",
|
||||
"user_id": "user_123",
|
||||
"channel_id": "chan_456",
|
||||
"message": "Hello!",
|
||||
"message": "@bot_user_id Hello!",
|
||||
}
|
||||
event = {
|
||||
"event": "posted",
|
||||
@@ -514,7 +514,7 @@ class TestMattermostDedup:
|
||||
"id": pid,
|
||||
"user_id": "user_123",
|
||||
"channel_id": "chan_456",
|
||||
"message": f"Message {i}",
|
||||
"message": f"@bot_user_id Message {i}",
|
||||
}
|
||||
event = {
|
||||
"event": "posted",
|
||||
@@ -593,7 +593,7 @@ class TestMattermostMediaTypes:
|
||||
"id": "post_media",
|
||||
"user_id": "user_123",
|
||||
"channel_id": "chan_456",
|
||||
"message": "file attached",
|
||||
"message": "@bot_user_id file attached",
|
||||
"file_ids": file_ids,
|
||||
}
|
||||
return {
|
||||
|
||||
@@ -0,0 +1,401 @@
|
||||
"""Tests for the gateway platform reconnection watcher."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter, MessageEvent, SendResult
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
|
||||
class StubAdapter(BasePlatformAdapter):
|
||||
"""Adapter whose connect() result can be controlled."""
|
||||
|
||||
def __init__(self, *, succeed=True, fatal_error=None, fatal_retryable=True):
|
||||
super().__init__(PlatformConfig(enabled=True, token="test"), Platform.TELEGRAM)
|
||||
self._succeed = succeed
|
||||
self._fatal_error = fatal_error
|
||||
self._fatal_retryable = fatal_retryable
|
||||
|
||||
async def connect(self):
|
||||
if self._fatal_error:
|
||||
self._set_fatal_error("test_error", self._fatal_error, retryable=self._fatal_retryable)
|
||||
return False
|
||||
return self._succeed
|
||||
|
||||
async def disconnect(self):
|
||||
return None
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
return SendResult(success=True, message_id="1")
|
||||
|
||||
async def send_typing(self, chat_id, metadata=None):
|
||||
return None
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
def _make_runner():
|
||||
"""Create a minimal GatewayRunner via object.__new__ to skip __init__."""
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="test")}
|
||||
)
|
||||
runner._running = True
|
||||
runner._shutdown_event = asyncio.Event()
|
||||
runner._exit_reason = None
|
||||
runner._exit_with_failure = False
|
||||
runner._exit_cleanly = False
|
||||
runner._failed_platforms = {}
|
||||
runner.adapters = {}
|
||||
runner.delivery_router = MagicMock()
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._honcho_managers = {}
|
||||
runner._honcho_configs = {}
|
||||
runner._shutdown_all_gateway_honcho = lambda: None
|
||||
return runner
|
||||
|
||||
|
||||
# --- Startup queueing ---
|
||||
|
||||
class TestStartupFailureQueuing:
|
||||
"""Verify that failed platforms are queued during startup."""
|
||||
|
||||
def test_failed_platform_queued_on_connect_failure(self):
|
||||
"""When adapter.connect() returns False without fatal error, queue for retry."""
|
||||
runner = _make_runner()
|
||||
platform_config = PlatformConfig(enabled=True, token="test")
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": platform_config,
|
||||
"attempts": 1,
|
||||
"next_retry": time.monotonic() + 30,
|
||||
}
|
||||
assert Platform.TELEGRAM in runner._failed_platforms
|
||||
assert runner._failed_platforms[Platform.TELEGRAM]["attempts"] == 1
|
||||
|
||||
def test_failed_platform_not_queued_for_nonretryable(self):
|
||||
"""Non-retryable errors should not be in the retry queue."""
|
||||
runner = _make_runner()
|
||||
# Simulate: adapter had a non-retryable error, wasn't queued
|
||||
assert Platform.TELEGRAM not in runner._failed_platforms
|
||||
|
||||
|
||||
# --- Reconnect watcher ---
|
||||
|
||||
class TestPlatformReconnectWatcher:
|
||||
"""Test the _platform_reconnect_watcher background task."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_succeeds_on_retry(self):
|
||||
"""Watcher should reconnect a failed platform when connect() succeeds."""
|
||||
runner = _make_runner()
|
||||
runner._sync_voice_mode_state_to_adapter = MagicMock()
|
||||
|
||||
platform_config = PlatformConfig(enabled=True, token="test")
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": platform_config,
|
||||
"attempts": 1,
|
||||
"next_retry": time.monotonic() - 1, # Already past retry time
|
||||
}
|
||||
|
||||
succeed_adapter = StubAdapter(succeed=True)
|
||||
real_sleep = asyncio.sleep
|
||||
|
||||
with patch.object(runner, "_create_adapter", return_value=succeed_adapter):
|
||||
with patch("gateway.run.build_channel_directory", create=True):
|
||||
# Run one iteration of the watcher then stop
|
||||
async def run_one_iteration():
|
||||
runner._running = True
|
||||
# Patch the sleep to exit after first check
|
||||
call_count = 0
|
||||
|
||||
async def fake_sleep(n):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count > 1:
|
||||
runner._running = False
|
||||
await real_sleep(0)
|
||||
|
||||
with patch("asyncio.sleep", side_effect=fake_sleep):
|
||||
await runner._platform_reconnect_watcher()
|
||||
|
||||
await run_one_iteration()
|
||||
|
||||
assert Platform.TELEGRAM not in runner._failed_platforms
|
||||
assert Platform.TELEGRAM in runner.adapters
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_nonretryable_removed_from_queue(self):
|
||||
"""Non-retryable errors should remove the platform from the retry queue."""
|
||||
runner = _make_runner()
|
||||
|
||||
platform_config = PlatformConfig(enabled=True, token="test")
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": platform_config,
|
||||
"attempts": 1,
|
||||
"next_retry": time.monotonic() - 1,
|
||||
}
|
||||
|
||||
fail_adapter = StubAdapter(
|
||||
succeed=False, fatal_error="bad token", fatal_retryable=False
|
||||
)
|
||||
|
||||
real_sleep = asyncio.sleep
|
||||
|
||||
with patch.object(runner, "_create_adapter", return_value=fail_adapter):
|
||||
async def run_one_iteration():
|
||||
runner._running = True
|
||||
call_count = 0
|
||||
|
||||
async def fake_sleep(n):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count > 1:
|
||||
runner._running = False
|
||||
await real_sleep(0)
|
||||
|
||||
with patch("asyncio.sleep", side_effect=fake_sleep):
|
||||
await runner._platform_reconnect_watcher()
|
||||
|
||||
await run_one_iteration()
|
||||
|
||||
assert Platform.TELEGRAM not in runner._failed_platforms
|
||||
assert Platform.TELEGRAM not in runner.adapters
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_retryable_stays_in_queue(self):
|
||||
"""Retryable failures should remain in the queue with incremented attempts."""
|
||||
runner = _make_runner()
|
||||
|
||||
platform_config = PlatformConfig(enabled=True, token="test")
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": platform_config,
|
||||
"attempts": 1,
|
||||
"next_retry": time.monotonic() - 1,
|
||||
}
|
||||
|
||||
fail_adapter = StubAdapter(
|
||||
succeed=False, fatal_error="DNS failure", fatal_retryable=True
|
||||
)
|
||||
|
||||
real_sleep = asyncio.sleep
|
||||
|
||||
with patch.object(runner, "_create_adapter", return_value=fail_adapter):
|
||||
async def run_one_iteration():
|
||||
runner._running = True
|
||||
call_count = 0
|
||||
|
||||
async def fake_sleep(n):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count > 1:
|
||||
runner._running = False
|
||||
await real_sleep(0)
|
||||
|
||||
with patch("asyncio.sleep", side_effect=fake_sleep):
|
||||
await runner._platform_reconnect_watcher()
|
||||
|
||||
await run_one_iteration()
|
||||
|
||||
assert Platform.TELEGRAM in runner._failed_platforms
|
||||
assert runner._failed_platforms[Platform.TELEGRAM]["attempts"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_gives_up_after_max_attempts(self):
|
||||
"""After max attempts, platform should be removed from retry queue."""
|
||||
runner = _make_runner()
|
||||
|
||||
platform_config = PlatformConfig(enabled=True, token="test")
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": platform_config,
|
||||
"attempts": 20, # At max
|
||||
"next_retry": time.monotonic() - 1,
|
||||
}
|
||||
|
||||
real_sleep = asyncio.sleep
|
||||
|
||||
with patch.object(runner, "_create_adapter") as mock_create:
|
||||
async def run_one_iteration():
|
||||
runner._running = True
|
||||
call_count = 0
|
||||
|
||||
async def fake_sleep(n):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count > 1:
|
||||
runner._running = False
|
||||
await real_sleep(0)
|
||||
|
||||
with patch("asyncio.sleep", side_effect=fake_sleep):
|
||||
await runner._platform_reconnect_watcher()
|
||||
|
||||
await run_one_iteration()
|
||||
|
||||
assert Platform.TELEGRAM not in runner._failed_platforms
|
||||
mock_create.assert_not_called() # Should give up without trying
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconnect_skips_when_not_time_yet(self):
|
||||
"""Watcher should skip platforms whose next_retry is in the future."""
|
||||
runner = _make_runner()
|
||||
|
||||
platform_config = PlatformConfig(enabled=True, token="test")
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": platform_config,
|
||||
"attempts": 1,
|
||||
"next_retry": time.monotonic() + 9999, # Far in the future
|
||||
}
|
||||
|
||||
real_sleep = asyncio.sleep
|
||||
|
||||
with patch.object(runner, "_create_adapter") as mock_create:
|
||||
async def run_one_iteration():
|
||||
runner._running = True
|
||||
call_count = 0
|
||||
|
||||
async def fake_sleep(n):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count > 1:
|
||||
runner._running = False
|
||||
await real_sleep(0)
|
||||
|
||||
with patch("asyncio.sleep", side_effect=fake_sleep):
|
||||
await runner._platform_reconnect_watcher()
|
||||
|
||||
await run_one_iteration()
|
||||
|
||||
assert Platform.TELEGRAM in runner._failed_platforms
|
||||
mock_create.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_failed_platforms_watcher_idles(self):
|
||||
"""When no platforms are failed, watcher should just idle."""
|
||||
runner = _make_runner()
|
||||
# No failed platforms
|
||||
|
||||
real_sleep = asyncio.sleep
|
||||
|
||||
with patch.object(runner, "_create_adapter") as mock_create:
|
||||
async def run_briefly():
|
||||
runner._running = True
|
||||
call_count = 0
|
||||
|
||||
async def fake_sleep(n):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count > 2:
|
||||
runner._running = False
|
||||
await real_sleep(0)
|
||||
|
||||
with patch("asyncio.sleep", side_effect=fake_sleep):
|
||||
await runner._platform_reconnect_watcher()
|
||||
|
||||
await run_briefly()
|
||||
|
||||
mock_create.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapter_create_returns_none(self):
|
||||
"""If _create_adapter returns None, remove from queue (missing deps)."""
|
||||
runner = _make_runner()
|
||||
|
||||
platform_config = PlatformConfig(enabled=True, token="test")
|
||||
runner._failed_platforms[Platform.TELEGRAM] = {
|
||||
"config": platform_config,
|
||||
"attempts": 1,
|
||||
"next_retry": time.monotonic() - 1,
|
||||
}
|
||||
|
||||
real_sleep = asyncio.sleep
|
||||
|
||||
with patch.object(runner, "_create_adapter", return_value=None):
|
||||
async def run_one_iteration():
|
||||
runner._running = True
|
||||
call_count = 0
|
||||
|
||||
async def fake_sleep(n):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count > 1:
|
||||
runner._running = False
|
||||
await real_sleep(0)
|
||||
|
||||
with patch("asyncio.sleep", side_effect=fake_sleep):
|
||||
await runner._platform_reconnect_watcher()
|
||||
|
||||
await run_one_iteration()
|
||||
|
||||
assert Platform.TELEGRAM not in runner._failed_platforms
|
||||
|
||||
|
||||
# --- Runtime disconnection queueing ---
|
||||
|
||||
class TestRuntimeDisconnectQueuing:
|
||||
"""Test that _handle_adapter_fatal_error queues retryable disconnections."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retryable_runtime_error_queued_for_reconnect(self):
|
||||
"""Retryable runtime errors should add the platform to _failed_platforms."""
|
||||
runner = _make_runner()
|
||||
|
||||
adapter = StubAdapter(succeed=True)
|
||||
adapter._set_fatal_error("network_error", "DNS failure", retryable=True)
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
|
||||
await runner._handle_adapter_fatal_error(adapter)
|
||||
|
||||
assert Platform.TELEGRAM in runner._failed_platforms
|
||||
assert runner._failed_platforms[Platform.TELEGRAM]["attempts"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonretryable_runtime_error_not_queued(self):
|
||||
"""Non-retryable runtime errors should not be queued for reconnection."""
|
||||
runner = _make_runner()
|
||||
|
||||
adapter = StubAdapter(succeed=True)
|
||||
adapter._set_fatal_error("auth_error", "bad token", retryable=False)
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
|
||||
# Need to prevent stop() from running fully
|
||||
runner.stop = AsyncMock()
|
||||
|
||||
await runner._handle_adapter_fatal_error(adapter)
|
||||
|
||||
assert Platform.TELEGRAM not in runner._failed_platforms
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retryable_error_prevents_shutdown_when_queued(self):
|
||||
"""Gateway should not shut down if failed platforms are queued for reconnection."""
|
||||
runner = _make_runner()
|
||||
runner.stop = AsyncMock()
|
||||
|
||||
adapter = StubAdapter(succeed=True)
|
||||
adapter._set_fatal_error("network_error", "DNS failure", retryable=True)
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
|
||||
await runner._handle_adapter_fatal_error(adapter)
|
||||
|
||||
# stop() should NOT have been called since we have platforms queued
|
||||
runner.stop.assert_not_called()
|
||||
assert Platform.TELEGRAM in runner._failed_platforms
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nonretryable_error_triggers_shutdown(self):
|
||||
"""Gateway should shut down when no adapters remain and nothing is queued."""
|
||||
runner = _make_runner()
|
||||
runner.stop = AsyncMock()
|
||||
|
||||
adapter = StubAdapter(succeed=True)
|
||||
adapter._set_fatal_error("auth_error", "bad token", retryable=False)
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
|
||||
await runner._handle_adapter_fatal_error(adapter)
|
||||
|
||||
runner.stop.assert_called_once()
|
||||
@@ -0,0 +1,165 @@
|
||||
"""Tests for /queue message consumption after normal agent completion.
|
||||
|
||||
Verifies that messages queued via /queue (which store in
|
||||
adapter._pending_messages WITHOUT triggering an interrupt) are consumed
|
||||
after the agent finishes its current task — not silently dropped.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.platforms.base import (
|
||||
BasePlatformAdapter,
|
||||
MessageEvent,
|
||||
MessageType,
|
||||
PlatformConfig,
|
||||
Platform,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal adapter for testing pending message storage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _StubAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="test"), Platform.TELEGRAM)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._mark_disconnected()
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
from gateway.platforms.base import SendResult
|
||||
return SendResult(success=True, message_id="msg-1")
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id, "type": "dm"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestQueueMessageStorage:
|
||||
"""Verify /queue stores messages correctly in adapter._pending_messages."""
|
||||
|
||||
def test_queue_stores_message_in_pending(self):
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:123"
|
||||
event = MessageEvent(
|
||||
text="do this next",
|
||||
message_type=MessageType.TEXT,
|
||||
source=MagicMock(chat_id="123", platform=Platform.TELEGRAM),
|
||||
message_id="q1",
|
||||
)
|
||||
adapter._pending_messages[session_key] = event
|
||||
|
||||
assert session_key in adapter._pending_messages
|
||||
assert adapter._pending_messages[session_key].text == "do this next"
|
||||
|
||||
def test_get_pending_message_consumes_and_clears(self):
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:123"
|
||||
event = MessageEvent(
|
||||
text="queued prompt",
|
||||
message_type=MessageType.TEXT,
|
||||
source=MagicMock(chat_id="123", platform=Platform.TELEGRAM),
|
||||
message_id="q2",
|
||||
)
|
||||
adapter._pending_messages[session_key] = event
|
||||
|
||||
retrieved = adapter.get_pending_message(session_key)
|
||||
assert retrieved is not None
|
||||
assert retrieved.text == "queued prompt"
|
||||
# Should be consumed (cleared)
|
||||
assert adapter.get_pending_message(session_key) is None
|
||||
|
||||
def test_queue_does_not_set_interrupt_event(self):
|
||||
"""The whole point of /queue — no interrupt signal."""
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:123"
|
||||
|
||||
# Simulate an active session (agent running)
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
# Store a queued message (what /queue does)
|
||||
event = MessageEvent(
|
||||
text="queued",
|
||||
message_type=MessageType.TEXT,
|
||||
source=MagicMock(),
|
||||
message_id="q3",
|
||||
)
|
||||
adapter._pending_messages[session_key] = event
|
||||
|
||||
# The interrupt event should NOT be set
|
||||
assert not adapter._active_sessions[session_key].is_set()
|
||||
assert not adapter.has_pending_interrupt(session_key)
|
||||
|
||||
def test_regular_message_sets_interrupt_event(self):
|
||||
"""Contrast: regular messages DO trigger interrupt."""
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:123"
|
||||
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
|
||||
# Simulate regular message arrival (what handle_message does)
|
||||
event = MessageEvent(
|
||||
text="new message",
|
||||
message_type=MessageType.TEXT,
|
||||
source=MagicMock(),
|
||||
message_id="m1",
|
||||
)
|
||||
adapter._pending_messages[session_key] = event
|
||||
adapter._active_sessions[session_key].set() # this is what handle_message does
|
||||
|
||||
assert adapter.has_pending_interrupt(session_key)
|
||||
|
||||
|
||||
class TestQueueConsumptionAfterCompletion:
|
||||
"""Verify that pending messages are consumed after normal completion."""
|
||||
|
||||
def test_pending_message_available_after_normal_completion(self):
|
||||
"""After agent finishes without interrupt, pending message should
|
||||
still be retrievable from adapter._pending_messages."""
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:123"
|
||||
|
||||
# Simulate: agent starts, /queue stores a message, agent finishes
|
||||
adapter._active_sessions[session_key] = asyncio.Event()
|
||||
event = MessageEvent(
|
||||
text="process this after",
|
||||
message_type=MessageType.TEXT,
|
||||
source=MagicMock(),
|
||||
message_id="q4",
|
||||
)
|
||||
adapter._pending_messages[session_key] = event
|
||||
|
||||
# Agent finishes (no interrupt)
|
||||
del adapter._active_sessions[session_key]
|
||||
|
||||
# The queued message should still be retrievable
|
||||
retrieved = adapter.get_pending_message(session_key)
|
||||
assert retrieved is not None
|
||||
assert retrieved.text == "process this after"
|
||||
|
||||
def test_multiple_queues_last_one_wins(self):
|
||||
"""If user /queue's multiple times, last message overwrites."""
|
||||
adapter = _StubAdapter()
|
||||
session_key = "telegram:user:123"
|
||||
|
||||
for text in ["first", "second", "third"]:
|
||||
event = MessageEvent(
|
||||
text=text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=MagicMock(),
|
||||
message_id=f"q-{text}",
|
||||
)
|
||||
adapter._pending_messages[session_key] = event
|
||||
|
||||
retrieved = adapter.get_pending_message(session_key)
|
||||
assert retrieved.text == "third"
|
||||
@@ -56,7 +56,7 @@ class ProgressCaptureAdapter(BasePlatformAdapter):
|
||||
|
||||
class FakeAgent:
|
||||
def __init__(self, **kwargs):
|
||||
self.tool_progress_callback = kwargs["tool_progress_callback"]
|
||||
self.tool_progress_callback = kwargs.get("tool_progress_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
@@ -27,6 +29,23 @@ class _FatalAdapter(BasePlatformAdapter):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
class _RuntimeRetryableAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="token"), Platform.WHATSAPP)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._mark_disconnected()
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_requests_clean_exit_for_nonretryable_startup_conflict(monkeypatch, tmp_path):
|
||||
config = GatewayConfig(
|
||||
@@ -44,3 +63,33 @@ async def test_runner_requests_clean_exit_for_nonretryable_startup_conflict(monk
|
||||
assert ok is True
|
||||
assert runner.should_exit_cleanly is True
|
||||
assert "already using this Telegram bot token" in runner.exit_reason
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_queues_retryable_runtime_fatal_for_reconnection(monkeypatch, tmp_path):
|
||||
"""Retryable runtime fatal errors queue the platform for reconnection
|
||||
instead of shutting down the gateway."""
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.WHATSAPP: PlatformConfig(enabled=True, token="token")
|
||||
},
|
||||
sessions_dir=tmp_path / "sessions",
|
||||
)
|
||||
runner = GatewayRunner(config)
|
||||
adapter = _RuntimeRetryableAdapter()
|
||||
adapter._set_fatal_error(
|
||||
"whatsapp_bridge_exited",
|
||||
"WhatsApp bridge process exited unexpectedly (code 1).",
|
||||
retryable=True,
|
||||
)
|
||||
|
||||
runner.adapters = {Platform.WHATSAPP: adapter}
|
||||
runner.delivery_router.adapters = runner.adapters
|
||||
runner.stop = AsyncMock()
|
||||
|
||||
await runner._handle_adapter_fatal_error(adapter)
|
||||
|
||||
# Should NOT shut down — platform is queued for reconnection
|
||||
runner.stop.assert_not_awaited()
|
||||
assert Platform.WHATSAPP in runner._failed_platforms
|
||||
assert runner._failed_platforms[Platform.WHATSAPP]["attempts"] == 0
|
||||
|
||||
@@ -147,6 +147,26 @@ class TestTelegramSendImageFile:
|
||||
call_kwargs = adapter._bot.send_photo.call_args.kwargs
|
||||
assert len(call_kwargs["caption"]) == 1024
|
||||
|
||||
def test_thread_id_forwarded(self, adapter, tmp_path):
|
||||
"""metadata thread_id is forwarded as message_thread_id (required for Telegram forum groups)."""
|
||||
img = tmp_path / "shot.png"
|
||||
img.write_bytes(b"\x89PNG" + b"\x00" * 50)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.message_id = 43
|
||||
adapter._bot.send_photo = AsyncMock(return_value=mock_msg)
|
||||
|
||||
_run(
|
||||
adapter.send_image_file(
|
||||
chat_id="12345",
|
||||
image_path=str(img),
|
||||
metadata={"thread_id": "789"},
|
||||
)
|
||||
)
|
||||
|
||||
call_kwargs = adapter._bot.send_photo.call_args.kwargs
|
||||
assert call_kwargs["message_thread_id"] == 789
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discord send_image_file tests
|
||||
|
||||
@@ -212,6 +212,61 @@ class TestSessionHygieneWarnThreshold:
|
||||
assert post_compress_tokens < warn_threshold
|
||||
|
||||
|
||||
class TestEstimatedTokenThreshold:
|
||||
"""Verify that hygiene thresholds are always below the model's context
|
||||
limit — for both actual and estimated token counts.
|
||||
|
||||
Regression: a previous 1.4x multiplier on rough estimates pushed the
|
||||
threshold to 85% * 1.4 = 119% of context, which exceeded the model's
|
||||
limit and prevented hygiene from ever firing for ~200K models (GLM-5).
|
||||
The fix removed the multiplier entirely — the 85% threshold already
|
||||
provides ample headroom over the agent's 50% compressor.
|
||||
"""
|
||||
|
||||
def test_threshold_below_context_for_200k_model(self):
|
||||
"""Hygiene threshold must always be below model context."""
|
||||
context_length = 200_000
|
||||
threshold = int(context_length * 0.85)
|
||||
assert threshold < context_length
|
||||
|
||||
def test_threshold_below_context_for_128k_model(self):
|
||||
context_length = 128_000
|
||||
threshold = int(context_length * 0.85)
|
||||
assert threshold < context_length
|
||||
|
||||
def test_no_multiplier_means_same_threshold_for_estimated_and_actual(self):
|
||||
"""Without the 1.4x, estimated and actual token paths use the same threshold."""
|
||||
context_length = 200_000
|
||||
threshold_pct = 0.85
|
||||
threshold = int(context_length * threshold_pct)
|
||||
# Both paths should use 170K — no inflation
|
||||
assert threshold == 170_000
|
||||
|
||||
def test_warn_threshold_below_context(self):
|
||||
"""Warn threshold (95%) must be below context length."""
|
||||
for ctx in (128_000, 200_000, 1_000_000):
|
||||
warn = int(ctx * 0.95)
|
||||
assert warn < ctx
|
||||
|
||||
def test_overestimate_fires_early_but_safely(self):
|
||||
"""If rough estimate is 50% inflated, hygiene fires at ~57% actual usage.
|
||||
|
||||
That's between the agent's 50% threshold and the model's limit —
|
||||
safe and harmless.
|
||||
"""
|
||||
context_length = 200_000
|
||||
threshold = int(context_length * 0.85) # 170K
|
||||
# If actual tokens = 113K, rough estimate = 113K * 1.5 = 170K
|
||||
# Hygiene fires when estimate hits 170K, actual is ~113K = 57% of ctx
|
||||
actual_when_fires = threshold / 1.5
|
||||
assert actual_when_fires > context_length * 0.50, (
|
||||
"Early fire should still be above agent's 50% threshold"
|
||||
)
|
||||
assert actual_when_fires < context_length, (
|
||||
"Early fire must be well below model limit"
|
||||
)
|
||||
|
||||
|
||||
class TestTokenEstimation:
|
||||
"""Verify rough token estimation works as expected for hygiene checks."""
|
||||
|
||||
|
||||
@@ -0,0 +1,207 @@
|
||||
"""Tests for session auto-reset notifications.
|
||||
|
||||
Verifies that:
|
||||
- _should_reset() returns a reason string ("idle" or "daily") instead of bool
|
||||
- SessionEntry captures auto_reset_reason
|
||||
- SessionResetPolicy.notify controls whether notifications are sent
|
||||
- notify_exclude_platforms skips notifications for excluded platforms
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import (
|
||||
GatewayConfig,
|
||||
Platform,
|
||||
PlatformConfig,
|
||||
SessionResetPolicy,
|
||||
)
|
||||
from gateway.session import SessionEntry, SessionSource, SessionStore
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_source(platform=Platform.TELEGRAM, chat_id="123", user_id="u1"):
|
||||
return SessionSource(
|
||||
platform=platform,
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
def _make_store(policy=None, tmp_path=None):
|
||||
config = GatewayConfig()
|
||||
if policy:
|
||||
config.default_reset_policy = policy
|
||||
store = SessionStore(sessions_dir=tmp_path or "/tmp/test-sessions", config=config)
|
||||
return store
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _should_reset returns reason string
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestShouldResetReason:
|
||||
def test_returns_none_when_not_expired(self, tmp_path):
|
||||
store = _make_store(
|
||||
SessionResetPolicy(mode="both", idle_minutes=60, at_hour=4),
|
||||
tmp_path,
|
||||
)
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(), # just updated
|
||||
)
|
||||
source = _make_source()
|
||||
assert store._should_reset(entry, source) is None
|
||||
|
||||
def test_returns_idle_when_idle_expired(self, tmp_path):
|
||||
store = _make_store(
|
||||
SessionResetPolicy(mode="idle", idle_minutes=30),
|
||||
tmp_path,
|
||||
)
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=datetime.now() - timedelta(hours=2),
|
||||
updated_at=datetime.now() - timedelta(hours=1), # 60min ago > 30min threshold
|
||||
)
|
||||
source = _make_source()
|
||||
assert store._should_reset(entry, source) == "idle"
|
||||
|
||||
def test_returns_daily_when_daily_boundary_crossed(self, tmp_path):
|
||||
now = datetime.now()
|
||||
store = _make_store(
|
||||
SessionResetPolicy(mode="daily", at_hour=now.hour),
|
||||
tmp_path,
|
||||
)
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=now - timedelta(days=2),
|
||||
updated_at=now - timedelta(days=1), # last active yesterday
|
||||
)
|
||||
source = _make_source()
|
||||
assert store._should_reset(entry, source) == "daily"
|
||||
|
||||
def test_returns_none_when_mode_is_none(self, tmp_path):
|
||||
store = _make_store(
|
||||
SessionResetPolicy(mode="none"),
|
||||
tmp_path,
|
||||
)
|
||||
entry = SessionEntry(
|
||||
session_key="test",
|
||||
session_id="s1",
|
||||
created_at=datetime.now() - timedelta(days=30),
|
||||
updated_at=datetime.now() - timedelta(days=30),
|
||||
)
|
||||
source = _make_source()
|
||||
assert store._should_reset(entry, source) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SessionEntry captures reason
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSessionEntryReason:
|
||||
def test_auto_reset_reason_stored(self, tmp_path):
|
||||
store = _make_store(
|
||||
SessionResetPolicy(mode="idle", idle_minutes=1),
|
||||
tmp_path,
|
||||
)
|
||||
source = _make_source()
|
||||
|
||||
# Create initial session
|
||||
entry1 = store.get_or_create_session(source)
|
||||
assert not entry1.was_auto_reset
|
||||
|
||||
# Age it past the idle threshold
|
||||
entry1.updated_at = datetime.now() - timedelta(minutes=5)
|
||||
store._save()
|
||||
|
||||
# Next call should create a new session with reason
|
||||
entry2 = store.get_or_create_session(source)
|
||||
assert entry2.was_auto_reset is True
|
||||
assert entry2.auto_reset_reason == "idle"
|
||||
assert entry2.session_id != entry1.session_id
|
||||
|
||||
def test_reset_had_activity_false_when_no_tokens(self, tmp_path):
|
||||
"""Expired session with no tokens → reset_had_activity=False."""
|
||||
store = _make_store(
|
||||
SessionResetPolicy(mode="idle", idle_minutes=1),
|
||||
tmp_path,
|
||||
)
|
||||
source = _make_source()
|
||||
|
||||
entry1 = store.get_or_create_session(source)
|
||||
# No tokens used — session was idle with no conversation
|
||||
entry1.updated_at = datetime.now() - timedelta(minutes=5)
|
||||
store._save()
|
||||
|
||||
entry2 = store.get_or_create_session(source)
|
||||
assert entry2.was_auto_reset is True
|
||||
assert entry2.reset_had_activity is False
|
||||
|
||||
def test_reset_had_activity_true_when_tokens_used(self, tmp_path):
|
||||
"""Expired session with tokens → reset_had_activity=True."""
|
||||
store = _make_store(
|
||||
SessionResetPolicy(mode="idle", idle_minutes=1),
|
||||
tmp_path,
|
||||
)
|
||||
source = _make_source()
|
||||
|
||||
entry1 = store.get_or_create_session(source)
|
||||
# Simulate some conversation happened
|
||||
entry1.total_tokens = 5000
|
||||
entry1.updated_at = datetime.now() - timedelta(minutes=5)
|
||||
store._save()
|
||||
|
||||
entry2 = store.get_or_create_session(source)
|
||||
assert entry2.was_auto_reset is True
|
||||
assert entry2.reset_had_activity is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SessionResetPolicy notify config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResetPolicyNotify:
|
||||
def test_notify_defaults_true(self):
|
||||
policy = SessionResetPolicy()
|
||||
assert policy.notify is True
|
||||
|
||||
def test_notify_exclude_defaults(self):
|
||||
policy = SessionResetPolicy()
|
||||
assert "api_server" in policy.notify_exclude_platforms
|
||||
assert "webhook" in policy.notify_exclude_platforms
|
||||
|
||||
def test_from_dict_with_notify_false(self):
|
||||
policy = SessionResetPolicy.from_dict({"notify": False})
|
||||
assert policy.notify is False
|
||||
|
||||
def test_from_dict_with_custom_excludes(self):
|
||||
policy = SessionResetPolicy.from_dict({
|
||||
"notify_exclude_platforms": ["api_server", "webhook", "homeassistant"],
|
||||
})
|
||||
assert "homeassistant" in policy.notify_exclude_platforms
|
||||
|
||||
def test_from_dict_preserves_defaults_on_missing_keys(self):
|
||||
policy = SessionResetPolicy.from_dict({})
|
||||
assert policy.notify is True
|
||||
assert "api_server" in policy.notify_exclude_platforms
|
||||
|
||||
def test_to_dict_roundtrip(self):
|
||||
original = SessionResetPolicy(
|
||||
mode="idle",
|
||||
notify=False,
|
||||
notify_exclude_platforms=("api_server",),
|
||||
)
|
||||
restored = SessionResetPolicy.from_dict(original.to_dict())
|
||||
assert restored.notify == original.notify
|
||||
assert restored.notify_exclude_platforms == original.notify_exclude_platforms
|
||||
assert restored.mode == original.mode
|
||||
@@ -229,6 +229,10 @@ class TestSignalSessionSource:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalPhoneRedaction:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _ensure_redaction_enabled(self, monkeypatch):
|
||||
monkeypatch.delenv("HERMES_REDACT_SECRETS", raising=False)
|
||||
|
||||
def test_us_number(self):
|
||||
from agent.redact import redact_sensitive_text
|
||||
result = redact_sensitive_text("Call +15551234567 now")
|
||||
|
||||
@@ -557,6 +557,25 @@ class TestSendDocument:
|
||||
call_kwargs = connected_adapter._bot.send_document.call_args[1]
|
||||
assert call_kwargs["reply_to_message_id"] == 50
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_document_thread_id(self, connected_adapter, tmp_path):
|
||||
"""metadata thread_id is forwarded as message_thread_id (required for Telegram forum groups)."""
|
||||
test_file = tmp_path / "report.pdf"
|
||||
test_file.write_bytes(b"%PDF-1.4 data")
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.message_id = 103
|
||||
connected_adapter._bot.send_document = AsyncMock(return_value=mock_msg)
|
||||
|
||||
await connected_adapter.send_document(
|
||||
chat_id="12345",
|
||||
file_path=str(test_file),
|
||||
metadata={"thread_id": "789"},
|
||||
)
|
||||
|
||||
call_kwargs = connected_adapter._bot.send_document.call_args[1]
|
||||
assert call_kwargs["message_thread_id"] == 789
|
||||
|
||||
|
||||
class TestTelegramPhotoBatching:
|
||||
@pytest.mark.asyncio
|
||||
@@ -654,3 +673,22 @@ class TestSendVideo:
|
||||
|
||||
assert result.success is False
|
||||
assert "Not connected" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_video_thread_id(self, connected_adapter, tmp_path):
|
||||
"""metadata thread_id is forwarded as message_thread_id (required for Telegram forum groups)."""
|
||||
test_file = tmp_path / "clip.mp4"
|
||||
test_file.write_bytes(b"\x00\x00\x00\x1c" + b"ftyp" + b"\x00" * 100)
|
||||
|
||||
mock_msg = MagicMock()
|
||||
mock_msg.message_id = 201
|
||||
connected_adapter._bot.send_video = AsyncMock(return_value=mock_msg)
|
||||
|
||||
await connected_adapter.send_video(
|
||||
chat_id="12345",
|
||||
video_path=str(test_file),
|
||||
metadata={"thread_id": "789"},
|
||||
)
|
||||
|
||||
call_kwargs = connected_adapter._bot.send_video.call_args[1]
|
||||
assert call_kwargs["message_thread_id"] == 789
|
||||
|
||||
@@ -53,6 +53,15 @@ def _make_adapter():
|
||||
adapter._bridge_process = None
|
||||
adapter._reply_prefix = None
|
||||
adapter._running = False
|
||||
adapter._message_handler = None
|
||||
adapter._fatal_error_code = None
|
||||
adapter._fatal_error_message = None
|
||||
adapter._fatal_error_retryable = True
|
||||
adapter._fatal_error_handler = None
|
||||
adapter._active_sessions = {}
|
||||
adapter._pending_messages = {}
|
||||
adapter._background_tasks = set()
|
||||
adapter._auto_tts_disabled_chats = set()
|
||||
adapter._message_queue = asyncio.Queue()
|
||||
return adapter
|
||||
|
||||
@@ -200,6 +209,54 @@ class TestFileHandleClosedOnError:
|
||||
mock_fh.close.assert_called_once()
|
||||
assert adapter._bridge_log_fh is None
|
||||
|
||||
|
||||
class TestBridgeRuntimeFailure:
|
||||
"""Verify runtime bridge death is surfaced as a fatal adapter error."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_marks_retryable_fatal_when_managed_bridge_exits(self):
|
||||
adapter = _make_adapter()
|
||||
fatal_handler = AsyncMock()
|
||||
adapter.set_fatal_error_handler(fatal_handler)
|
||||
adapter._running = True
|
||||
mock_fh = MagicMock()
|
||||
adapter._bridge_log_fh = mock_fh
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.poll.return_value = 7
|
||||
adapter._bridge_process = mock_proc
|
||||
|
||||
result = await adapter.send("chat-123", "hello")
|
||||
|
||||
assert result.success is False
|
||||
assert "exited unexpectedly" in result.error
|
||||
assert adapter.fatal_error_code == "whatsapp_bridge_exited"
|
||||
assert adapter.fatal_error_retryable is True
|
||||
fatal_handler.assert_awaited_once()
|
||||
mock_fh.close.assert_called_once()
|
||||
assert adapter._bridge_log_fh is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_poll_messages_marks_retryable_fatal_when_managed_bridge_exits(self):
|
||||
adapter = _make_adapter()
|
||||
fatal_handler = AsyncMock()
|
||||
adapter.set_fatal_error_handler(fatal_handler)
|
||||
adapter._running = True
|
||||
mock_fh = MagicMock()
|
||||
adapter._bridge_log_fh = mock_fh
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.poll.return_value = 23
|
||||
adapter._bridge_process = mock_proc
|
||||
|
||||
await adapter._poll_messages()
|
||||
|
||||
assert adapter.fatal_error_code == "whatsapp_bridge_exited"
|
||||
assert adapter.fatal_error_retryable is True
|
||||
fatal_handler.assert_awaited_once()
|
||||
mock_fh.close.assert_called_once()
|
||||
assert adapter._bridge_log_fh is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_closed_when_http_not_ready(self):
|
||||
"""Health endpoint never returns 200 within 15 attempts."""
|
||||
|
||||
@@ -290,21 +290,17 @@ class TestEnsureUserSystemdEnv:
|
||||
monkeypatch.delenv("DBUS_SESSION_BUS_ADDRESS", raising=False)
|
||||
monkeypatch.setattr(os, "getuid", lambda: 42)
|
||||
|
||||
# Patch Path so /run/user/42 resolves to our tmp dir (which exists)
|
||||
from pathlib import Path as RealPath
|
||||
|
||||
class FakePath(type(RealPath())):
|
||||
def __new__(cls, *args):
|
||||
p = str(args[0]) if args else ""
|
||||
if p == "/run/user/42":
|
||||
return RealPath.__new__(cls, str(tmp_path))
|
||||
return RealPath.__new__(cls, *args)
|
||||
|
||||
monkeypatch.setattr(gateway_cli, "Path", FakePath)
|
||||
# Patch Path.exists so /run/user/42 appears to exist.
|
||||
# Using a FakePath subclass breaks on Python 3.12+ where
|
||||
# PosixPath.__new__ ignores the redirected path argument.
|
||||
_orig_exists = gateway_cli.Path.exists
|
||||
monkeypatch.setattr(
|
||||
gateway_cli.Path, "exists",
|
||||
lambda self: True if str(self) == "/run/user/42" else _orig_exists(self),
|
||||
)
|
||||
|
||||
gateway_cli._ensure_user_systemd_env()
|
||||
|
||||
# Function sets the canonical string, not the fake path
|
||||
assert os.environ.get("XDG_RUNTIME_DIR") == "/run/user/42"
|
||||
|
||||
def test_sets_dbus_address_when_bus_socket_exists(self, tmp_path, monkeypatch):
|
||||
|
||||
@@ -0,0 +1,400 @@
|
||||
"""
|
||||
Tests for hermes_cli.mcp_config — ``hermes mcp`` subcommands.
|
||||
|
||||
These tests mock the MCP server connection layer so they run without
|
||||
any actual MCP servers or API keys.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import types
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _isolate_config(tmp_path, monkeypatch):
|
||||
"""Redirect all config I/O to a temp directory."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.config.get_hermes_home", lambda: tmp_path
|
||||
)
|
||||
config_path = tmp_path / "config.yaml"
|
||||
env_path = tmp_path / ".env"
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.config.get_config_path", lambda: config_path
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.config.get_env_path", lambda: env_path
|
||||
)
|
||||
return tmp_path
|
||||
|
||||
|
||||
def _make_args(**kwargs):
|
||||
"""Build a minimal argparse.Namespace."""
|
||||
defaults = {
|
||||
"name": "test-server",
|
||||
"url": None,
|
||||
"command": None,
|
||||
"args": None,
|
||||
"auth": None,
|
||||
"mcp_action": None,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return argparse.Namespace(**defaults)
|
||||
|
||||
|
||||
def _seed_config(tmp_path: Path, mcp_servers: dict):
|
||||
"""Write a config.yaml with the given mcp_servers."""
|
||||
import yaml
|
||||
|
||||
config = {"mcp_servers": mcp_servers, "_config_version": 9}
|
||||
config_path = tmp_path / "config.yaml"
|
||||
with open(config_path, "w") as f:
|
||||
yaml.safe_dump(config, f)
|
||||
|
||||
|
||||
class FakeTool:
|
||||
"""Mimics an MCP tool object returned by the SDK."""
|
||||
|
||||
def __init__(self, name: str, description: str = ""):
|
||||
self.name = name
|
||||
self.description = description
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: cmd_mcp_list
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMcpList:
|
||||
def test_list_empty_config(self, tmp_path, capsys):
|
||||
from hermes_cli.mcp_config import cmd_mcp_list
|
||||
|
||||
cmd_mcp_list()
|
||||
out = capsys.readouterr().out
|
||||
assert "No MCP servers configured" in out
|
||||
|
||||
def test_list_with_servers(self, tmp_path, capsys):
|
||||
_seed_config(tmp_path, {
|
||||
"ink": {
|
||||
"url": "https://mcp.ml.ink/mcp",
|
||||
"enabled": True,
|
||||
"tools": {"include": ["create_service", "get_service"]},
|
||||
},
|
||||
"github": {
|
||||
"command": "npx",
|
||||
"args": ["@mcp/github"],
|
||||
"enabled": False,
|
||||
},
|
||||
})
|
||||
from hermes_cli.mcp_config import cmd_mcp_list
|
||||
|
||||
cmd_mcp_list()
|
||||
out = capsys.readouterr().out
|
||||
assert "ink" in out
|
||||
assert "github" in out
|
||||
assert "2 selected" in out # ink has 2 in include
|
||||
assert "disabled" in out # github is disabled
|
||||
|
||||
def test_list_enabled_default_true(self, tmp_path, capsys):
|
||||
"""Server without explicit enabled key defaults to enabled."""
|
||||
_seed_config(tmp_path, {
|
||||
"myserver": {"url": "https://example.com/mcp"},
|
||||
})
|
||||
from hermes_cli.mcp_config import cmd_mcp_list
|
||||
|
||||
cmd_mcp_list()
|
||||
out = capsys.readouterr().out
|
||||
assert "myserver" in out
|
||||
assert "enabled" in out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: cmd_mcp_remove
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMcpRemove:
|
||||
def test_remove_existing_server(self, tmp_path, capsys, monkeypatch):
|
||||
_seed_config(tmp_path, {
|
||||
"myserver": {"url": "https://example.com/mcp"},
|
||||
})
|
||||
monkeypatch.setattr("builtins.input", lambda _: "y")
|
||||
from hermes_cli.mcp_config import cmd_mcp_remove
|
||||
|
||||
cmd_mcp_remove(_make_args(name="myserver"))
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Removed" in out
|
||||
|
||||
# Verify config updated
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
assert "myserver" not in config.get("mcp_servers", {})
|
||||
|
||||
def test_remove_nonexistent(self, tmp_path, capsys):
|
||||
_seed_config(tmp_path, {})
|
||||
from hermes_cli.mcp_config import cmd_mcp_remove
|
||||
|
||||
cmd_mcp_remove(_make_args(name="ghost"))
|
||||
out = capsys.readouterr().out
|
||||
assert "not found" in out
|
||||
|
||||
def test_remove_cleans_oauth_tokens(self, tmp_path, capsys, monkeypatch):
|
||||
_seed_config(tmp_path, {
|
||||
"oauth-srv": {"url": "https://example.com/mcp", "auth": "oauth"},
|
||||
})
|
||||
monkeypatch.setattr("builtins.input", lambda _: "y")
|
||||
# Also patch get_hermes_home in the mcp_config module namespace
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config.get_hermes_home", lambda: tmp_path
|
||||
)
|
||||
|
||||
# Create a fake token file
|
||||
token_dir = tmp_path / "mcp-tokens"
|
||||
token_dir.mkdir()
|
||||
token_file = token_dir / "oauth-srv.json"
|
||||
token_file.write_text("{}")
|
||||
|
||||
from hermes_cli.mcp_config import cmd_mcp_remove
|
||||
|
||||
cmd_mcp_remove(_make_args(name="oauth-srv"))
|
||||
assert not token_file.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: cmd_mcp_add
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMcpAdd:
|
||||
def test_add_no_transport(self, capsys):
|
||||
"""Must specify --url or --command."""
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
|
||||
cmd_mcp_add(_make_args(name="bad"))
|
||||
out = capsys.readouterr().out
|
||||
assert "Must specify" in out
|
||||
|
||||
def test_add_http_server_all_tools(self, tmp_path, capsys, monkeypatch):
|
||||
"""Add an HTTP server, accept all tools."""
|
||||
fake_tools = [
|
||||
FakeTool("create_service", "Deploy from repo"),
|
||||
FakeTool("list_services", "List all services"),
|
||||
]
|
||||
|
||||
def mock_probe(name, config, **kw):
|
||||
return [(t.name, t.description) for t in fake_tools]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._probe_single_server", mock_probe
|
||||
)
|
||||
# No auth, accept all tools
|
||||
inputs = iter(["n", ""]) # no auth needed, enable all
|
||||
monkeypatch.setattr("builtins.input", lambda _: next(inputs))
|
||||
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
|
||||
cmd_mcp_add(_make_args(name="ink", url="https://mcp.ml.ink/mcp"))
|
||||
out = capsys.readouterr().out
|
||||
assert "Saved" in out
|
||||
assert "2/2 tools" in out
|
||||
|
||||
# Verify config written
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
assert "ink" in config.get("mcp_servers", {})
|
||||
assert config["mcp_servers"]["ink"]["url"] == "https://mcp.ml.ink/mcp"
|
||||
|
||||
def test_add_stdio_server(self, tmp_path, capsys, monkeypatch):
|
||||
"""Add a stdio server."""
|
||||
fake_tools = [FakeTool("search", "Search repos")]
|
||||
|
||||
def mock_probe(name, config, **kw):
|
||||
return [(t.name, t.description) for t in fake_tools]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._probe_single_server", mock_probe
|
||||
)
|
||||
inputs = iter([""]) # accept all tools
|
||||
monkeypatch.setattr("builtins.input", lambda _: next(inputs))
|
||||
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
|
||||
cmd_mcp_add(_make_args(
|
||||
name="github",
|
||||
command="npx",
|
||||
args=["@mcp/github"],
|
||||
))
|
||||
out = capsys.readouterr().out
|
||||
assert "Saved" in out
|
||||
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
srv = config["mcp_servers"]["github"]
|
||||
assert srv["command"] == "npx"
|
||||
assert srv["args"] == ["@mcp/github"]
|
||||
|
||||
def test_add_connection_failure_save_disabled(
|
||||
self, tmp_path, capsys, monkeypatch
|
||||
):
|
||||
"""Failed connection → option to save as disabled."""
|
||||
|
||||
def mock_probe_fail(name, config, **kw):
|
||||
raise ConnectionError("Connection refused")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._probe_single_server", mock_probe_fail
|
||||
)
|
||||
inputs = iter(["n", "y"]) # no auth, yes save disabled
|
||||
monkeypatch.setattr("builtins.input", lambda _: next(inputs))
|
||||
|
||||
from hermes_cli.mcp_config import cmd_mcp_add
|
||||
|
||||
cmd_mcp_add(_make_args(name="broken", url="https://bad.host/mcp"))
|
||||
out = capsys.readouterr().out
|
||||
assert "disabled" in out
|
||||
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
config = load_config()
|
||||
assert config["mcp_servers"]["broken"]["enabled"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: cmd_mcp_test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMcpTest:
|
||||
def test_test_not_found(self, tmp_path, capsys):
|
||||
_seed_config(tmp_path, {})
|
||||
from hermes_cli.mcp_config import cmd_mcp_test
|
||||
|
||||
cmd_mcp_test(_make_args(name="ghost"))
|
||||
out = capsys.readouterr().out
|
||||
assert "not found" in out
|
||||
|
||||
def test_test_success(self, tmp_path, capsys, monkeypatch):
|
||||
_seed_config(tmp_path, {
|
||||
"ink": {"url": "https://mcp.ml.ink/mcp"},
|
||||
})
|
||||
|
||||
def mock_probe(name, config, **kw):
|
||||
return [("create_service", "Deploy"), ("list_services", "List all")]
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._probe_single_server", mock_probe
|
||||
)
|
||||
from hermes_cli.mcp_config import cmd_mcp_test
|
||||
|
||||
cmd_mcp_test(_make_args(name="ink"))
|
||||
out = capsys.readouterr().out
|
||||
assert "Connected" in out
|
||||
assert "Tools discovered: 2" in out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: env var interpolation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEnvVarInterpolation:
|
||||
def test_interpolate_simple(self, monkeypatch):
|
||||
monkeypatch.setenv("MY_KEY", "secret123")
|
||||
from tools.mcp_tool import _interpolate_env_vars
|
||||
|
||||
result = _interpolate_env_vars("Bearer ${MY_KEY}")
|
||||
assert result == "Bearer secret123"
|
||||
|
||||
def test_interpolate_missing_var(self, monkeypatch):
|
||||
monkeypatch.delenv("MISSING_VAR", raising=False)
|
||||
from tools.mcp_tool import _interpolate_env_vars
|
||||
|
||||
result = _interpolate_env_vars("Bearer ${MISSING_VAR}")
|
||||
assert result == "Bearer ${MISSING_VAR}"
|
||||
|
||||
def test_interpolate_nested_dict(self, monkeypatch):
|
||||
monkeypatch.setenv("API_KEY", "abc")
|
||||
from tools.mcp_tool import _interpolate_env_vars
|
||||
|
||||
result = _interpolate_env_vars({
|
||||
"url": "https://example.com",
|
||||
"headers": {"Authorization": "Bearer ${API_KEY}"},
|
||||
})
|
||||
assert result["headers"]["Authorization"] == "Bearer abc"
|
||||
assert result["url"] == "https://example.com"
|
||||
|
||||
def test_interpolate_list(self, monkeypatch):
|
||||
monkeypatch.setenv("ARG1", "hello")
|
||||
from tools.mcp_tool import _interpolate_env_vars
|
||||
|
||||
result = _interpolate_env_vars(["${ARG1}", "static"])
|
||||
assert result == ["hello", "static"]
|
||||
|
||||
def test_interpolate_non_string(self):
|
||||
from tools.mcp_tool import _interpolate_env_vars
|
||||
|
||||
assert _interpolate_env_vars(42) == 42
|
||||
assert _interpolate_env_vars(True) is True
|
||||
assert _interpolate_env_vars(None) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: config helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestConfigHelpers:
|
||||
def test_save_and_load_mcp_server(self, tmp_path):
|
||||
from hermes_cli.mcp_config import _save_mcp_server, _get_mcp_servers
|
||||
|
||||
_save_mcp_server("mysvr", {"url": "https://example.com/mcp"})
|
||||
servers = _get_mcp_servers()
|
||||
assert "mysvr" in servers
|
||||
assert servers["mysvr"]["url"] == "https://example.com/mcp"
|
||||
|
||||
def test_remove_mcp_server(self, tmp_path):
|
||||
from hermes_cli.mcp_config import (
|
||||
_save_mcp_server,
|
||||
_remove_mcp_server,
|
||||
_get_mcp_servers,
|
||||
)
|
||||
|
||||
_save_mcp_server("s1", {"command": "test"})
|
||||
_save_mcp_server("s2", {"command": "test2"})
|
||||
result = _remove_mcp_server("s1")
|
||||
assert result is True
|
||||
assert "s1" not in _get_mcp_servers()
|
||||
assert "s2" in _get_mcp_servers()
|
||||
|
||||
def test_remove_nonexistent(self, tmp_path):
|
||||
from hermes_cli.mcp_config import _remove_mcp_server
|
||||
|
||||
assert _remove_mcp_server("ghost") is False
|
||||
|
||||
def test_env_key_for_server(self):
|
||||
from hermes_cli.mcp_config import _env_key_for_server
|
||||
|
||||
assert _env_key_for_server("ink") == "MCP_INK_API_KEY"
|
||||
assert _env_key_for_server("my-server") == "MCP_MY_SERVER_API_KEY"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: dispatcher
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestDispatcher:
|
||||
def test_no_action_shows_list(self, tmp_path, capsys):
|
||||
from hermes_cli.mcp_config import mcp_command
|
||||
|
||||
_seed_config(tmp_path, {})
|
||||
mcp_command(_make_args(mcp_action=None))
|
||||
out = capsys.readouterr().out
|
||||
assert "Commands:" in out or "No MCP servers" in out
|
||||
@@ -100,3 +100,107 @@ def test_save_platform_tools_handles_invalid_existing_config():
|
||||
|
||||
saved_toolsets = config["platform_toolsets"]["cli"]
|
||||
assert "web" in saved_toolsets
|
||||
|
||||
|
||||
def test_save_platform_tools_does_not_preserve_platform_default_toolsets():
|
||||
"""Platform default toolsets (hermes-cli, hermes-telegram, etc.) must NOT
|
||||
be preserved across saves.
|
||||
|
||||
These "super" toolsets resolve to ALL tools, so if they survive in the
|
||||
config, they silently override any tools the user unchecked. Previously,
|
||||
the preserve filter only excluded configurable toolset keys (web, browser,
|
||||
terminal, etc.) and treated platform defaults as unknown custom entries
|
||||
(like MCP server names), causing them to be kept unconditionally.
|
||||
|
||||
Regression test: user unchecks image_gen and homeassistant via
|
||||
``hermes tools``, but hermes-cli stays in the config and re-enables
|
||||
everything on the next read.
|
||||
"""
|
||||
config = {
|
||||
"platform_toolsets": {
|
||||
"cli": [
|
||||
"browser", "clarify", "code_execution", "cronjob",
|
||||
"delegation", "file", "hermes-cli", # <-- the culprit
|
||||
"memory", "session_search", "skills", "terminal",
|
||||
"todo", "tts", "vision", "web",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
# User unchecks image_gen, homeassistant, moa — keeps the rest
|
||||
new_selection = {
|
||||
"browser", "clarify", "code_execution", "cronjob",
|
||||
"delegation", "file", "memory", "session_search",
|
||||
"skills", "terminal", "todo", "tts", "vision", "web",
|
||||
}
|
||||
|
||||
with patch("hermes_cli.tools_config.save_config"):
|
||||
_save_platform_tools(config, "cli", new_selection)
|
||||
|
||||
saved = config["platform_toolsets"]["cli"]
|
||||
|
||||
# hermes-cli must NOT survive — it's a platform default, not an MCP server
|
||||
assert "hermes-cli" not in saved
|
||||
|
||||
# The individual toolset keys the user selected must be present
|
||||
assert "web" in saved
|
||||
assert "terminal" in saved
|
||||
assert "browser" in saved
|
||||
|
||||
# Tools the user unchecked must NOT be present
|
||||
assert "image_gen" not in saved
|
||||
assert "homeassistant" not in saved
|
||||
assert "moa" not in saved
|
||||
|
||||
|
||||
def test_save_platform_tools_does_not_preserve_hermes_telegram():
|
||||
"""Same bug for Telegram — hermes-telegram must not be preserved."""
|
||||
config = {
|
||||
"platform_toolsets": {
|
||||
"telegram": [
|
||||
"browser", "file", "hermes-telegram", "terminal", "web",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
new_selection = {"browser", "file", "terminal", "web"}
|
||||
|
||||
with patch("hermes_cli.tools_config.save_config"):
|
||||
_save_platform_tools(config, "telegram", new_selection)
|
||||
|
||||
saved = config["platform_toolsets"]["telegram"]
|
||||
assert "hermes-telegram" not in saved
|
||||
assert "web" in saved
|
||||
|
||||
|
||||
def test_save_platform_tools_still_preserves_mcp_with_platform_default_present():
|
||||
"""MCP server names must still be preserved even when platform defaults
|
||||
are being stripped out."""
|
||||
config = {
|
||||
"platform_toolsets": {
|
||||
"cli": [
|
||||
"web", "terminal", "hermes-cli", "my-mcp-server", "github-tools",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
new_selection = {"web", "browser"}
|
||||
|
||||
with patch("hermes_cli.tools_config.save_config"):
|
||||
_save_platform_tools(config, "cli", new_selection)
|
||||
|
||||
saved = config["platform_toolsets"]["cli"]
|
||||
|
||||
# MCP servers preserved
|
||||
assert "my-mcp-server" in saved
|
||||
assert "github-tools" in saved
|
||||
|
||||
# Platform default stripped
|
||||
assert "hermes-cli" not in saved
|
||||
|
||||
# User selections present
|
||||
assert "web" in saved
|
||||
assert "browser" in saved
|
||||
|
||||
# Deselected configurable toolset removed
|
||||
assert "terminal" not in saved
|
||||
|
||||
@@ -68,6 +68,8 @@ def test_restore_stashed_changes_prompts_before_applying(monkeypatch, tmp_path,
|
||||
calls.append((cmd, kwargs))
|
||||
if cmd[1:3] == ["stash", "apply"]:
|
||||
return SimpleNamespace(stdout="applied\n", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["diff", "--name-only"]:
|
||||
return SimpleNamespace(stdout="", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["stash", "list"]:
|
||||
return SimpleNamespace(stdout="stash@{1} abc123\n", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["stash", "drop"]:
|
||||
@@ -81,8 +83,9 @@ def test_restore_stashed_changes_prompts_before_applying(monkeypatch, tmp_path,
|
||||
|
||||
assert restored is True
|
||||
assert calls[0][0] == ["git", "stash", "apply", "abc123"]
|
||||
assert calls[1][0] == ["git", "stash", "list", "--format=%gd %H"]
|
||||
assert calls[2][0] == ["git", "stash", "drop", "stash@{1}"]
|
||||
assert calls[1][0] == ["git", "diff", "--name-only", "--diff-filter=U"]
|
||||
assert calls[2][0] == ["git", "stash", "list", "--format=%gd %H"]
|
||||
assert calls[3][0] == ["git", "stash", "drop", "stash@{1}"]
|
||||
out = capsys.readouterr().out
|
||||
assert "Restore local changes now? [Y/n]" in out
|
||||
assert "restored on top of the updated codebase" in out
|
||||
@@ -117,6 +120,8 @@ def test_restore_stashed_changes_applies_without_prompt_when_disabled(monkeypatc
|
||||
calls.append((cmd, kwargs))
|
||||
if cmd[1:3] == ["stash", "apply"]:
|
||||
return SimpleNamespace(stdout="applied\n", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["diff", "--name-only"]:
|
||||
return SimpleNamespace(stdout="", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["stash", "list"]:
|
||||
return SimpleNamespace(stdout="stash@{0} abc123\n", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["stash", "drop"]:
|
||||
@@ -129,8 +134,9 @@ def test_restore_stashed_changes_applies_without_prompt_when_disabled(monkeypatc
|
||||
|
||||
assert restored is True
|
||||
assert calls[0][0] == ["git", "stash", "apply", "abc123"]
|
||||
assert calls[1][0] == ["git", "stash", "list", "--format=%gd %H"]
|
||||
assert calls[2][0] == ["git", "stash", "drop", "stash@{0}"]
|
||||
assert calls[1][0] == ["git", "diff", "--name-only", "--diff-filter=U"]
|
||||
assert calls[2][0] == ["git", "stash", "list", "--format=%gd %H"]
|
||||
assert calls[3][0] == ["git", "stash", "drop", "stash@{0}"]
|
||||
assert "Restore local changes now?" not in capsys.readouterr().out
|
||||
|
||||
|
||||
@@ -152,6 +158,8 @@ def test_restore_stashed_changes_keeps_going_when_stash_entry_cannot_be_resolved
|
||||
calls.append((cmd, kwargs))
|
||||
if cmd[1:3] == ["stash", "apply"]:
|
||||
return SimpleNamespace(stdout="applied\n", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["diff", "--name-only"]:
|
||||
return SimpleNamespace(stdout="", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["stash", "list"]:
|
||||
return SimpleNamespace(stdout="stash@{0} def456\n", stderr="", returncode=0)
|
||||
raise AssertionError(f"unexpected command: {cmd}")
|
||||
@@ -161,10 +169,9 @@ def test_restore_stashed_changes_keeps_going_when_stash_entry_cannot_be_resolved
|
||||
restored = hermes_main._restore_stashed_changes(["git"], tmp_path, "abc123", prompt_user=False)
|
||||
|
||||
assert restored is True
|
||||
assert calls == [
|
||||
(["git", "stash", "apply", "abc123"], {"cwd": tmp_path, "capture_output": True, "text": True}),
|
||||
(["git", "stash", "list", "--format=%gd %H"], {"cwd": tmp_path, "capture_output": True, "text": True, "check": True}),
|
||||
]
|
||||
assert calls[0] == (["git", "stash", "apply", "abc123"], {"cwd": tmp_path, "capture_output": True, "text": True})
|
||||
assert calls[1] == (["git", "diff", "--name-only", "--diff-filter=U"], {"cwd": tmp_path, "capture_output": True, "text": True})
|
||||
assert calls[2] == (["git", "stash", "list", "--format=%gd %H"], {"cwd": tmp_path, "capture_output": True, "text": True, "check": True})
|
||||
out = capsys.readouterr().out
|
||||
assert "couldn't find the stash entry to drop" in out
|
||||
assert "stash was left in place" in out
|
||||
@@ -181,6 +188,8 @@ def test_restore_stashed_changes_keeps_going_when_drop_fails(monkeypatch, tmp_pa
|
||||
calls.append((cmd, kwargs))
|
||||
if cmd[1:3] == ["stash", "apply"]:
|
||||
return SimpleNamespace(stdout="applied\n", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["diff", "--name-only"]:
|
||||
return SimpleNamespace(stdout="", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["stash", "list"]:
|
||||
return SimpleNamespace(stdout="stash@{0} abc123\n", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["stash", "drop"]:
|
||||
@@ -192,7 +201,7 @@ def test_restore_stashed_changes_keeps_going_when_drop_fails(monkeypatch, tmp_pa
|
||||
restored = hermes_main._restore_stashed_changes(["git"], tmp_path, "abc123", prompt_user=False)
|
||||
|
||||
assert restored is True
|
||||
assert calls[2][0] == ["git", "stash", "drop", "stash@{0}"]
|
||||
assert calls[3][0] == ["git", "stash", "drop", "stash@{0}"]
|
||||
out = capsys.readouterr().out
|
||||
assert "couldn't drop the saved stash entry" in out
|
||||
assert "drop failed" in out
|
||||
@@ -201,13 +210,18 @@ def test_restore_stashed_changes_keeps_going_when_drop_fails(monkeypatch, tmp_pa
|
||||
assert "git stash drop stash@{0}" in out
|
||||
|
||||
|
||||
def test_restore_stashed_changes_exits_cleanly_when_apply_fails(monkeypatch, tmp_path, capsys):
|
||||
def test_restore_stashed_changes_prompts_before_reset_on_conflict(monkeypatch, tmp_path, capsys):
|
||||
"""When conflicts occur interactively, user is prompted before reset."""
|
||||
calls = []
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
calls.append((cmd, kwargs))
|
||||
if cmd[1:3] == ["stash", "apply"]:
|
||||
return SimpleNamespace(stdout="conflict output\n", stderr="conflict stderr\n", returncode=1)
|
||||
if cmd[1:3] == ["diff", "--name-only"]:
|
||||
return SimpleNamespace(stdout="hermes_cli/main.py\n", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["reset", "--hard"]:
|
||||
return SimpleNamespace(stdout="", stderr="", returncode=0)
|
||||
raise AssertionError(f"unexpected command: {cmd}")
|
||||
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", fake_run)
|
||||
@@ -217,9 +231,64 @@ def test_restore_stashed_changes_exits_cleanly_when_apply_fails(monkeypatch, tmp
|
||||
hermes_main._restore_stashed_changes(["git"], tmp_path, "abc123", prompt_user=True)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Your changes are still preserved in git stash." in out
|
||||
assert "git stash apply abc123" in out
|
||||
assert calls == [(["git", "stash", "apply", "abc123"], {"cwd": tmp_path, "capture_output": True, "text": True})]
|
||||
assert "Conflicted files:" in out
|
||||
assert "hermes_cli/main.py" in out
|
||||
assert "stashed changes are preserved" in out
|
||||
assert "Reset working tree to clean state" in out
|
||||
assert "Working tree reset to clean state" in out
|
||||
reset_calls = [c for c, _ in calls if c[1:3] == ["reset", "--hard"]]
|
||||
assert len(reset_calls) == 1
|
||||
|
||||
|
||||
def test_restore_stashed_changes_user_declines_reset(monkeypatch, tmp_path, capsys):
|
||||
"""When user declines reset, working tree is left as-is."""
|
||||
calls = []
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
calls.append((cmd, kwargs))
|
||||
if cmd[1:3] == ["stash", "apply"]:
|
||||
return SimpleNamespace(stdout="", stderr="conflict\n", returncode=1)
|
||||
if cmd[1:3] == ["diff", "--name-only"]:
|
||||
return SimpleNamespace(stdout="cli.py\n", stderr="", returncode=0)
|
||||
raise AssertionError(f"unexpected command: {cmd}")
|
||||
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", fake_run)
|
||||
# First input: "y" to restore, second input: "n" to decline reset
|
||||
inputs = iter(["y", "n"])
|
||||
monkeypatch.setattr("builtins.input", lambda: next(inputs))
|
||||
|
||||
with pytest.raises(SystemExit, match="1"):
|
||||
hermes_main._restore_stashed_changes(["git"], tmp_path, "abc123", prompt_user=True)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "left as-is" in out
|
||||
reset_calls = [c for c, _ in calls if c[1:3] == ["reset", "--hard"]]
|
||||
assert len(reset_calls) == 0
|
||||
|
||||
|
||||
def test_restore_stashed_changes_auto_resets_non_interactive(monkeypatch, tmp_path, capsys):
|
||||
"""Non-interactive mode auto-resets without prompting."""
|
||||
calls = []
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
calls.append((cmd, kwargs))
|
||||
if cmd[1:3] == ["stash", "apply"]:
|
||||
return SimpleNamespace(stdout="applied\n", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["diff", "--name-only"]:
|
||||
return SimpleNamespace(stdout="cli.py\n", stderr="", returncode=0)
|
||||
if cmd[1:3] == ["reset", "--hard"]:
|
||||
return SimpleNamespace(stdout="", stderr="", returncode=0)
|
||||
raise AssertionError(f"unexpected command: {cmd}")
|
||||
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", fake_run)
|
||||
|
||||
with pytest.raises(SystemExit, match="1"):
|
||||
hermes_main._restore_stashed_changes(["git"], tmp_path, "abc123", prompt_user=False)
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Working tree reset to clean state" in out
|
||||
reset_calls = [c for c, _ in calls if c[1:3] == ["reset", "--hard"]]
|
||||
assert len(reset_calls) == 1
|
||||
|
||||
|
||||
def test_stash_local_changes_if_needed_raises_when_stash_ref_missing(monkeypatch, tmp_path):
|
||||
|
||||
@@ -11,6 +11,7 @@ from honcho_integration.client import (
|
||||
HonchoClientConfig,
|
||||
get_honcho_client,
|
||||
reset_honcho_client,
|
||||
resolve_config_path,
|
||||
GLOBAL_CONFIG_PATH,
|
||||
HOST,
|
||||
)
|
||||
@@ -25,7 +26,7 @@ class TestHonchoClientConfigDefaults:
|
||||
assert config.environment == "production"
|
||||
assert config.enabled is False
|
||||
assert config.save_messages is True
|
||||
assert config.session_strategy == "per-session"
|
||||
assert config.session_strategy == "per-directory"
|
||||
assert config.recall_mode == "hybrid"
|
||||
assert config.session_peer_prefix is False
|
||||
assert config.linked_hosts == []
|
||||
@@ -157,7 +158,7 @@ class TestFromGlobalConfig:
|
||||
config_file = tmp_path / "config.json"
|
||||
config_file.write_text(json.dumps({"apiKey": "key"}))
|
||||
config = HonchoClientConfig.from_global_config(config_path=config_file)
|
||||
assert config.session_strategy == "per-session"
|
||||
assert config.session_strategy == "per-directory"
|
||||
|
||||
def test_context_tokens_host_block_wins(self, tmp_path):
|
||||
"""Host block contextTokens should override root."""
|
||||
@@ -330,6 +331,47 @@ class TestGetLinkedWorkspaces:
|
||||
assert "cursor" in workspaces
|
||||
|
||||
|
||||
class TestResolveConfigPath:
|
||||
def test_prefers_hermes_home_when_exists(self, tmp_path):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
local_cfg = hermes_home / "honcho.json"
|
||||
local_cfg.write_text('{"apiKey": "local"}')
|
||||
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}):
|
||||
result = resolve_config_path()
|
||||
assert result == local_cfg
|
||||
|
||||
def test_falls_back_to_global_when_no_local(self, tmp_path):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
# No honcho.json in HERMES_HOME
|
||||
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}):
|
||||
result = resolve_config_path()
|
||||
assert result == GLOBAL_CONFIG_PATH
|
||||
|
||||
def test_falls_back_to_global_without_hermes_home_env(self):
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("HERMES_HOME", None)
|
||||
result = resolve_config_path()
|
||||
assert result == GLOBAL_CONFIG_PATH
|
||||
|
||||
def test_from_global_config_uses_local_path(self, tmp_path):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
local_cfg = hermes_home / "honcho.json"
|
||||
local_cfg.write_text(json.dumps({
|
||||
"apiKey": "local-key",
|
||||
"workspace": "local-ws",
|
||||
}))
|
||||
|
||||
with patch.dict(os.environ, {"HERMES_HOME": str(hermes_home)}):
|
||||
config = HonchoClientConfig.from_global_config()
|
||||
assert config.api_key == "local-key"
|
||||
assert config.workspace_id == "local-ws"
|
||||
|
||||
|
||||
class TestResetHonchoClient:
|
||||
def test_reset_clears_singleton(self):
|
||||
import honcho_integration.client as mod
|
||||
|
||||
@@ -450,6 +450,12 @@ class TestNormalizeModelName:
|
||||
assert normalize_model_name("claude-opus-4-6") == "claude-opus-4-6"
|
||||
assert normalize_model_name("claude-opus-4-5-20251101") == "claude-opus-4-5-20251101"
|
||||
|
||||
def test_preserve_dots_for_alibaba_dashscope(self):
|
||||
"""Alibaba/DashScope use dots in model names (e.g. qwen3.5-plus). Fixes #1739."""
|
||||
assert normalize_model_name("qwen3.5-plus", preserve_dots=True) == "qwen3.5-plus"
|
||||
assert normalize_model_name("anthropic/qwen3.5-plus", preserve_dots=True) == "qwen3.5-plus"
|
||||
assert normalize_model_name("qwen3.5-flash", preserve_dots=True) == "qwen3.5-flash"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool conversion
|
||||
@@ -712,7 +718,7 @@ class TestConvertMessages:
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "tc_1", "content": "result"},
|
||||
])
|
||||
], native_anthropic=True)
|
||||
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
user_msg = [m for m in result if m["role"] == "user"][0]
|
||||
|
||||
@@ -106,7 +106,8 @@ def test_main_raises_for_unknown_preloaded_skill(monkeypatch):
|
||||
cli_mod.main(skills="missing-skill", list_tools=True)
|
||||
|
||||
|
||||
def test_show_banner_prints_preloaded_skills_once_before_banner():
|
||||
def test_show_banner_does_not_print_skills():
|
||||
"""show_banner() no longer prints the activated skills line — it moved to run()."""
|
||||
cli_obj = _make_real_cli(compact=False)
|
||||
cli_obj.preloaded_skills = ["hermes-agent-dev", "github-auth"]
|
||||
cli_obj.console = MagicMock()
|
||||
@@ -115,7 +116,6 @@ def test_show_banner_prints_preloaded_skills_once_before_banner():
|
||||
"shutil.get_terminal_size", return_value=os.terminal_size((120, 40))
|
||||
):
|
||||
cli_obj.show_banner()
|
||||
cli_obj.show_banner()
|
||||
|
||||
print_calls = [
|
||||
call.args[0]
|
||||
@@ -123,8 +123,5 @@ def test_show_banner_prints_preloaded_skills_once_before_banner():
|
||||
if call.args and isinstance(call.args[0], str)
|
||||
]
|
||||
startup_lines = [line for line in print_calls if "Activated skills:" in line]
|
||||
|
||||
assert len(startup_lines) == 1
|
||||
assert "Activated skills:" in startup_lines[0]
|
||||
assert "hermes-agent-dev, github-auth" in startup_lines[0]
|
||||
assert mock_banner.call_count == 2
|
||||
assert len(startup_lines) == 0
|
||||
assert mock_banner.call_count == 1
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
"""Tests for ${ENV_VAR} substitution in config.yaml values."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from hermes_cli.config import _expand_env_vars, load_config
|
||||
from unittest.mock import patch as mock_patch
|
||||
|
||||
|
||||
class TestExpandEnvVars:
|
||||
def test_simple_substitution(self):
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setenv("MY_KEY", "secret123")
|
||||
assert _expand_env_vars("${MY_KEY}") == "secret123"
|
||||
|
||||
def test_missing_var_kept_verbatim(self):
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.delenv("UNDEFINED_VAR_XYZ", raising=False)
|
||||
assert _expand_env_vars("${UNDEFINED_VAR_XYZ}") == "${UNDEFINED_VAR_XYZ}"
|
||||
|
||||
def test_no_placeholder_unchanged(self):
|
||||
assert _expand_env_vars("plain-value") == "plain-value"
|
||||
|
||||
def test_dict_recursive(self):
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setenv("TOKEN", "tok-abc")
|
||||
result = _expand_env_vars({"key": "${TOKEN}", "other": "literal"})
|
||||
assert result == {"key": "tok-abc", "other": "literal"}
|
||||
|
||||
def test_nested_dict(self):
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setenv("API_KEY", "sk-xyz")
|
||||
result = _expand_env_vars({"model": {"api_key": "${API_KEY}"}})
|
||||
assert result["model"]["api_key"] == "sk-xyz"
|
||||
|
||||
def test_list_items(self):
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setenv("VAL", "hello")
|
||||
result = _expand_env_vars(["${VAL}", "literal", 42])
|
||||
assert result == ["hello", "literal", 42]
|
||||
|
||||
def test_non_string_values_untouched(self):
|
||||
assert _expand_env_vars(42) == 42
|
||||
assert _expand_env_vars(3.14) == 3.14
|
||||
assert _expand_env_vars(True) is True
|
||||
assert _expand_env_vars(None) is None
|
||||
|
||||
def test_multiple_placeholders_in_one_string(self):
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setenv("HOST", "localhost")
|
||||
mp.setenv("PORT", "5432")
|
||||
assert _expand_env_vars("${HOST}:${PORT}") == "localhost:5432"
|
||||
|
||||
def test_dict_keys_not_expanded(self):
|
||||
with pytest.MonkeyPatch().context() as mp:
|
||||
mp.setenv("KEY", "value")
|
||||
result = _expand_env_vars({"${KEY}": "no-expand-key"})
|
||||
assert "${KEY}" in result
|
||||
|
||||
|
||||
class TestLoadConfigExpansion:
|
||||
def test_load_config_expands_env_vars(self, tmp_path, monkeypatch):
|
||||
config_yaml = (
|
||||
"model:\n"
|
||||
" api_key: ${GOOGLE_API_KEY}\n"
|
||||
"platforms:\n"
|
||||
" telegram:\n"
|
||||
" token: ${TELEGRAM_BOT_TOKEN}\n"
|
||||
"plain: no-substitution\n"
|
||||
)
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text(config_yaml)
|
||||
|
||||
monkeypatch.setenv("GOOGLE_API_KEY", "gsk-test-key")
|
||||
monkeypatch.setenv("TELEGRAM_BOT_TOKEN", "1234567:ABC-token")
|
||||
monkeypatch.setattr("hermes_cli.config.get_config_path", lambda: config_file)
|
||||
|
||||
config = load_config()
|
||||
|
||||
assert config["model"]["api_key"] == "gsk-test-key"
|
||||
assert config["platforms"]["telegram"]["token"] == "1234567:ABC-token"
|
||||
assert config["plain"] == "no-substitution"
|
||||
|
||||
def test_load_config_unresolved_kept_verbatim(self, tmp_path, monkeypatch):
|
||||
config_yaml = "model:\n api_key: ${NOT_SET_XYZ_123}\n"
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text(config_yaml)
|
||||
|
||||
monkeypatch.delenv("NOT_SET_XYZ_123", raising=False)
|
||||
monkeypatch.setattr("hermes_cli.config.get_config_path", lambda: config_file)
|
||||
|
||||
config = load_config()
|
||||
|
||||
assert config["model"]["api_key"] == "${NOT_SET_XYZ_123}"
|
||||
|
||||
|
||||
class TestLoadCliConfigExpansion:
|
||||
"""Verify that load_cli_config() also expands ${VAR} references."""
|
||||
|
||||
def test_cli_config_expands_auxiliary_api_key(self, tmp_path, monkeypatch):
|
||||
config_yaml = (
|
||||
"auxiliary:\n"
|
||||
" vision:\n"
|
||||
" api_key: ${TEST_VISION_KEY_XYZ}\n"
|
||||
)
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text(config_yaml)
|
||||
|
||||
monkeypatch.setenv("TEST_VISION_KEY_XYZ", "vis-key-123")
|
||||
# Patch the hermes home so load_cli_config finds our test config
|
||||
monkeypatch.setattr("cli._hermes_home", tmp_path)
|
||||
|
||||
from cli import load_cli_config
|
||||
config = load_cli_config()
|
||||
|
||||
assert config["auxiliary"]["vision"]["api_key"] == "vis-key-123"
|
||||
|
||||
def test_cli_config_unresolved_kept_verbatim(self, tmp_path, monkeypatch):
|
||||
config_yaml = (
|
||||
"auxiliary:\n"
|
||||
" vision:\n"
|
||||
" api_key: ${UNSET_CLI_VAR_ABC}\n"
|
||||
)
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text(config_yaml)
|
||||
|
||||
monkeypatch.delenv("UNSET_CLI_VAR_ABC", raising=False)
|
||||
monkeypatch.setattr("cli._hermes_home", tmp_path)
|
||||
|
||||
from cli import load_cli_config
|
||||
config = load_cli_config()
|
||||
|
||||
assert config["auxiliary"]["vision"]["api_key"] == "${UNSET_CLI_VAR_ABC}"
|
||||
@@ -0,0 +1,268 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _git(cwd: Path, *args: str) -> str:
|
||||
result = subprocess.run(
|
||||
["git", *args],
|
||||
cwd=cwd,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
return result.stdout.strip()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_repo(tmp_path: Path) -> Path:
|
||||
repo = tmp_path / "repo"
|
||||
repo.mkdir()
|
||||
_git(repo, "init")
|
||||
_git(repo, "config", "user.name", "Hermes Tests")
|
||||
_git(repo, "config", "user.email", "tests@example.com")
|
||||
|
||||
(repo / "src").mkdir()
|
||||
(repo / "src" / "main.py").write_text(
|
||||
"def alpha():\n"
|
||||
" return 'a'\n\n"
|
||||
"def beta():\n"
|
||||
" return 'b'\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(repo / "src" / "helper.py").write_text("VALUE = 1\n", encoding="utf-8")
|
||||
(repo / "README.md").write_text("# Demo\n", encoding="utf-8")
|
||||
(repo / "blob.bin").write_bytes(b"\x00\x01\x02binary")
|
||||
|
||||
_git(repo, "add", ".")
|
||||
_git(repo, "commit", "-m", "initial")
|
||||
|
||||
(repo / "src" / "main.py").write_text(
|
||||
"def alpha():\n"
|
||||
" return 'changed'\n\n"
|
||||
"def beta():\n"
|
||||
" return 'b'\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(repo / "src" / "helper.py").write_text("VALUE = 2\n", encoding="utf-8")
|
||||
_git(repo, "add", "src/helper.py")
|
||||
return repo
|
||||
|
||||
|
||||
def test_parse_typed_references_ignores_emails_and_handles():
|
||||
from agent.context_references import parse_context_references
|
||||
|
||||
message = (
|
||||
"email me at user@example.com and ping @teammate "
|
||||
"but include @file:src/main.py:1-2 plus @diff and @git:2 "
|
||||
"and @url:https://example.com/docs"
|
||||
)
|
||||
|
||||
refs = parse_context_references(message)
|
||||
|
||||
assert [ref.kind for ref in refs] == ["file", "diff", "git", "url"]
|
||||
assert refs[0].target == "src/main.py"
|
||||
assert refs[0].line_start == 1
|
||||
assert refs[0].line_end == 2
|
||||
assert refs[2].target == "2"
|
||||
|
||||
|
||||
def test_parse_references_strips_trailing_punctuation():
|
||||
from agent.context_references import parse_context_references
|
||||
|
||||
refs = parse_context_references(
|
||||
"review @file:README.md, then see (@url:https://example.com/docs)."
|
||||
)
|
||||
|
||||
assert [ref.kind for ref in refs] == ["file", "url"]
|
||||
assert refs[0].target == "README.md"
|
||||
assert refs[1].target == "https://example.com/docs"
|
||||
|
||||
|
||||
def test_expand_file_range_and_folder_listing(sample_repo: Path):
|
||||
from agent.context_references import preprocess_context_references
|
||||
|
||||
result = preprocess_context_references(
|
||||
"Review @file:src/main.py:1-2 and @folder:src/",
|
||||
cwd=sample_repo,
|
||||
context_length=100_000,
|
||||
)
|
||||
|
||||
assert result.expanded
|
||||
assert "Review and" in result.message
|
||||
assert "Review @file:src/main.py:1-2" not in result.message
|
||||
assert "--- Attached Context ---" in result.message
|
||||
assert "def alpha():" in result.message
|
||||
assert "return 'changed'" in result.message
|
||||
assert "def beta():" not in result.message
|
||||
assert "src/" in result.message
|
||||
assert "main.py" in result.message
|
||||
assert "helper.py" in result.message
|
||||
assert result.injected_tokens > 0
|
||||
assert not result.warnings
|
||||
|
||||
|
||||
def test_expand_git_diff_staged_and_log(sample_repo: Path):
|
||||
from agent.context_references import preprocess_context_references
|
||||
|
||||
result = preprocess_context_references(
|
||||
"Inspect @diff and @staged and @git:1",
|
||||
cwd=sample_repo,
|
||||
context_length=100_000,
|
||||
)
|
||||
|
||||
assert result.expanded
|
||||
assert "git diff" in result.message
|
||||
assert "git diff --staged" in result.message
|
||||
assert "git log -1 -p" in result.message
|
||||
assert "initial" in result.message
|
||||
assert "return 'changed'" in result.message
|
||||
assert "VALUE = 2" in result.message
|
||||
|
||||
|
||||
def test_binary_and_missing_files_become_warnings(sample_repo: Path):
|
||||
from agent.context_references import preprocess_context_references
|
||||
|
||||
result = preprocess_context_references(
|
||||
"Check @file:blob.bin and @file:nope.txt",
|
||||
cwd=sample_repo,
|
||||
context_length=100_000,
|
||||
)
|
||||
|
||||
assert result.expanded
|
||||
assert len(result.warnings) == 2
|
||||
assert "binary" in result.message.lower()
|
||||
assert "not found" in result.message.lower()
|
||||
|
||||
|
||||
def test_soft_budget_warns_and_hard_budget_refuses(sample_repo: Path):
|
||||
from agent.context_references import preprocess_context_references
|
||||
|
||||
soft = preprocess_context_references(
|
||||
"Check @file:src/main.py",
|
||||
cwd=sample_repo,
|
||||
context_length=100,
|
||||
)
|
||||
assert soft.expanded
|
||||
assert any("25%" in warning for warning in soft.warnings)
|
||||
|
||||
hard = preprocess_context_references(
|
||||
"Check @file:src/main.py and @file:README.md",
|
||||
cwd=sample_repo,
|
||||
context_length=20,
|
||||
)
|
||||
assert not hard.expanded
|
||||
assert hard.blocked
|
||||
assert "@file:src/main.py" in hard.message
|
||||
assert any("50%" in warning for warning in hard.warnings)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_url_expansion_uses_fetcher(sample_repo: Path):
|
||||
from agent.context_references import preprocess_context_references_async
|
||||
|
||||
async def fake_fetch(url: str) -> str:
|
||||
assert url == "https://example.com/spec"
|
||||
return "# Spec\n\nImportant details."
|
||||
|
||||
result = await preprocess_context_references_async(
|
||||
"Use @url:https://example.com/spec",
|
||||
cwd=sample_repo,
|
||||
context_length=100_000,
|
||||
url_fetcher=fake_fetch,
|
||||
)
|
||||
|
||||
assert result.expanded
|
||||
assert "Important details." in result.message
|
||||
assert result.injected_tokens > 0
|
||||
|
||||
|
||||
def test_sync_url_expansion_uses_async_fetcher(sample_repo: Path):
|
||||
from agent.context_references import preprocess_context_references
|
||||
|
||||
async def fake_fetch(url: str) -> str:
|
||||
await asyncio.sleep(0)
|
||||
return f"Content for {url}"
|
||||
|
||||
result = preprocess_context_references(
|
||||
"Use @url:https://example.com/spec",
|
||||
cwd=sample_repo,
|
||||
context_length=100_000,
|
||||
url_fetcher=fake_fetch,
|
||||
)
|
||||
|
||||
assert result.expanded
|
||||
assert "Content for https://example.com/spec" in result.message
|
||||
|
||||
|
||||
def test_restricts_paths_to_allowed_root(tmp_path: Path):
|
||||
from agent.context_references import preprocess_context_references
|
||||
|
||||
workspace = tmp_path / "workspace"
|
||||
workspace.mkdir()
|
||||
(workspace / "notes.txt").write_text("inside\n", encoding="utf-8")
|
||||
secret = tmp_path / "secret.txt"
|
||||
secret.write_text("outside\n", encoding="utf-8")
|
||||
|
||||
result = preprocess_context_references(
|
||||
"read @file:../secret.txt and @file:notes.txt",
|
||||
cwd=workspace,
|
||||
context_length=100_000,
|
||||
allowed_root=workspace,
|
||||
)
|
||||
|
||||
assert result.expanded
|
||||
assert "```\noutside\n```" not in result.message
|
||||
assert "inside" in result.message
|
||||
assert any("outside the allowed workspace" in warning for warning in result.warnings)
|
||||
|
||||
|
||||
def test_defaults_allowed_root_to_cwd(tmp_path: Path):
|
||||
from agent.context_references import preprocess_context_references
|
||||
|
||||
workspace = tmp_path / "workspace"
|
||||
workspace.mkdir()
|
||||
secret = tmp_path / "secret.txt"
|
||||
secret.write_text("outside\n", encoding="utf-8")
|
||||
|
||||
result = preprocess_context_references(
|
||||
f"read @file:{secret}",
|
||||
cwd=workspace,
|
||||
context_length=100_000,
|
||||
)
|
||||
|
||||
assert result.expanded
|
||||
assert "```\noutside\n```" not in result.message
|
||||
assert any("outside the allowed workspace" in warning for warning in result.warnings)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocks_sensitive_home_and_hermes_paths(tmp_path: Path, monkeypatch):
|
||||
from agent.context_references import preprocess_context_references_async
|
||||
|
||||
monkeypatch.setenv("HOME", str(tmp_path))
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
|
||||
|
||||
hermes_env = tmp_path / ".hermes" / ".env"
|
||||
hermes_env.parent.mkdir(parents=True)
|
||||
hermes_env.write_text("API_KEY=super-secret\n", encoding="utf-8")
|
||||
|
||||
ssh_key = tmp_path / ".ssh" / "id_rsa"
|
||||
ssh_key.parent.mkdir(parents=True)
|
||||
ssh_key.write_text("PRIVATE-KEY\n", encoding="utf-8")
|
||||
|
||||
result = await preprocess_context_references_async(
|
||||
"read @file:.hermes/.env and @file:.ssh/id_rsa",
|
||||
cwd=tmp_path,
|
||||
allowed_root=tmp_path,
|
||||
context_length=100_000,
|
||||
)
|
||||
|
||||
assert result.expanded
|
||||
assert "API_KEY=super-secret" not in result.message
|
||||
assert "PRIVATE-KEY" not in result.message
|
||||
assert any("sensitive credential" in warning for warning in result.warnings)
|
||||
@@ -30,10 +30,22 @@ class _FakeAnthropicClient:
|
||||
pass
|
||||
|
||||
|
||||
class _FakeOpenAIClient:
|
||||
"""Fake OpenAI client returned by mocked resolve_provider_client."""
|
||||
api_key = "fake-codex-key"
|
||||
base_url = "https://api.openai.com/v1"
|
||||
_default_headers = None
|
||||
|
||||
|
||||
def _make_agent(monkeypatch, api_mode, provider, response_fn):
|
||||
_patch_bootstrap(monkeypatch)
|
||||
if api_mode == "anthropic_messages":
|
||||
monkeypatch.setattr("agent.anthropic_adapter.build_anthropic_client", lambda k, b=None: _FakeAnthropicClient())
|
||||
if provider == "openai-codex":
|
||||
monkeypatch.setattr(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
lambda *a, **kw: (_FakeOpenAIClient(), "test-model"),
|
||||
)
|
||||
|
||||
class _A(run_agent.AIAgent):
|
||||
def __init__(self, *a, **kw):
|
||||
|
||||
+21
-2
@@ -281,7 +281,7 @@ class TestPluginToolVisibility:
|
||||
"""Plugin-registered tools appear in get_tool_definitions()."""
|
||||
|
||||
def test_plugin_tools_in_definitions(self, tmp_path, monkeypatch):
|
||||
"""Tools from plugins bypass the toolset filter."""
|
||||
"""Plugin tools are included when their toolset is in enabled_toolsets."""
|
||||
import hermes_cli.plugins as plugins_mod
|
||||
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
@@ -304,10 +304,22 @@ class TestPluginToolVisibility:
|
||||
monkeypatch.setattr(plugins_mod, "_plugin_manager", mgr)
|
||||
|
||||
from model_tools import get_tool_definitions
|
||||
tools = get_tool_definitions(enabled_toolsets=["terminal"], quiet_mode=True)
|
||||
|
||||
# Plugin tools are included when their toolset is explicitly enabled
|
||||
tools = get_tool_definitions(enabled_toolsets=["terminal", "plugin_vis_plugin"], quiet_mode=True)
|
||||
tool_names = [t["function"]["name"] for t in tools]
|
||||
assert "vis_tool" in tool_names
|
||||
|
||||
# Plugin tools are excluded when only other toolsets are enabled
|
||||
tools2 = get_tool_definitions(enabled_toolsets=["terminal"], quiet_mode=True)
|
||||
tool_names2 = [t["function"]["name"] for t in tools2]
|
||||
assert "vis_tool" not in tool_names2
|
||||
|
||||
# Plugin tools are included when no toolset filter is active (all enabled)
|
||||
tools3 = get_tool_definitions(quiet_mode=True)
|
||||
tool_names3 = [t["function"]["name"] for t in tools3]
|
||||
assert "vis_tool" in tool_names3
|
||||
|
||||
|
||||
# ── TestPluginManagerList ──────────────────────────────────────────────────
|
||||
|
||||
@@ -352,3 +364,10 @@ class TestPluginManagerList:
|
||||
assert "enabled" in p
|
||||
assert "tools" in p
|
||||
assert "hooks" in p
|
||||
|
||||
|
||||
|
||||
# NOTE: TestPluginCommands removed – register_command() was never implemented
|
||||
# in PluginContext (hermes_cli/plugins.py). The tests referenced _plugin_commands,
|
||||
# commands_registered, get_plugin_command_handler, and GATEWAY_KNOWN_COMMANDS
|
||||
# integration — all of which are unimplemented features.
|
||||
|
||||
@@ -0,0 +1,409 @@
|
||||
"""Tests for hermes_cli.plugins_cmd — the ``hermes plugins`` CLI subcommand."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import types
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from hermes_cli.plugins_cmd import (
|
||||
_copy_example_files,
|
||||
_read_manifest,
|
||||
_repo_name_from_url,
|
||||
_resolve_git_url,
|
||||
_sanitize_plugin_name,
|
||||
plugins_command,
|
||||
)
|
||||
|
||||
|
||||
# ── _sanitize_plugin_name ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSanitizePluginName:
|
||||
"""Reject path-traversal attempts while accepting valid names."""
|
||||
|
||||
def test_valid_simple_name(self, tmp_path):
|
||||
target = _sanitize_plugin_name("my-plugin", tmp_path)
|
||||
assert target == (tmp_path / "my-plugin").resolve()
|
||||
|
||||
def test_valid_name_with_hyphen_and_digits(self, tmp_path):
|
||||
target = _sanitize_plugin_name("plugin-v2", tmp_path)
|
||||
assert target.name == "plugin-v2"
|
||||
|
||||
def test_rejects_dot_dot(self, tmp_path):
|
||||
with pytest.raises(ValueError, match="must not contain"):
|
||||
_sanitize_plugin_name("../../etc/passwd", tmp_path)
|
||||
|
||||
def test_rejects_single_dot_dot(self, tmp_path):
|
||||
with pytest.raises(ValueError, match="must not contain"):
|
||||
_sanitize_plugin_name("..", tmp_path)
|
||||
|
||||
def test_rejects_forward_slash(self, tmp_path):
|
||||
with pytest.raises(ValueError, match="must not contain"):
|
||||
_sanitize_plugin_name("foo/bar", tmp_path)
|
||||
|
||||
def test_rejects_backslash(self, tmp_path):
|
||||
with pytest.raises(ValueError, match="must not contain"):
|
||||
_sanitize_plugin_name("foo\\bar", tmp_path)
|
||||
|
||||
def test_rejects_absolute_path(self, tmp_path):
|
||||
with pytest.raises(ValueError, match="must not contain"):
|
||||
_sanitize_plugin_name("/etc/passwd", tmp_path)
|
||||
|
||||
def test_rejects_empty_name(self, tmp_path):
|
||||
with pytest.raises(ValueError, match="must not be empty"):
|
||||
_sanitize_plugin_name("", tmp_path)
|
||||
|
||||
|
||||
# ── _resolve_git_url ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestResolveGitUrl:
|
||||
"""Shorthand and full-URL resolution."""
|
||||
|
||||
def test_owner_repo_shorthand(self):
|
||||
url = _resolve_git_url("owner/repo")
|
||||
assert url == "https://github.com/owner/repo.git"
|
||||
|
||||
def test_https_url_passthrough(self):
|
||||
url = _resolve_git_url("https://github.com/x/y.git")
|
||||
assert url == "https://github.com/x/y.git"
|
||||
|
||||
def test_ssh_url_passthrough(self):
|
||||
url = _resolve_git_url("git@github.com:x/y.git")
|
||||
assert url == "git@github.com:x/y.git"
|
||||
|
||||
def test_http_url_passthrough(self):
|
||||
url = _resolve_git_url("http://example.com/repo.git")
|
||||
assert url == "http://example.com/repo.git"
|
||||
|
||||
def test_file_url_passthrough(self):
|
||||
url = _resolve_git_url("file:///tmp/repo")
|
||||
assert url == "file:///tmp/repo"
|
||||
|
||||
def test_invalid_single_word_raises(self):
|
||||
with pytest.raises(ValueError, match="Invalid plugin identifier"):
|
||||
_resolve_git_url("justoneword")
|
||||
|
||||
def test_invalid_three_parts_raises(self):
|
||||
with pytest.raises(ValueError, match="Invalid plugin identifier"):
|
||||
_resolve_git_url("a/b/c")
|
||||
|
||||
|
||||
# ── _repo_name_from_url ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRepoNameFromUrl:
|
||||
"""Extract plugin directory name from Git URLs."""
|
||||
|
||||
def test_https_with_dot_git(self):
|
||||
assert (
|
||||
_repo_name_from_url("https://github.com/owner/my-plugin.git") == "my-plugin"
|
||||
)
|
||||
|
||||
def test_https_without_dot_git(self):
|
||||
assert _repo_name_from_url("https://github.com/owner/my-plugin") == "my-plugin"
|
||||
|
||||
def test_trailing_slash(self):
|
||||
assert _repo_name_from_url("https://github.com/owner/repo/") == "repo"
|
||||
|
||||
def test_ssh_style(self):
|
||||
assert _repo_name_from_url("git@github.com:owner/repo.git") == "repo"
|
||||
|
||||
def test_ssh_protocol(self):
|
||||
assert _repo_name_from_url("ssh://git@github.com/owner/repo.git") == "repo"
|
||||
|
||||
|
||||
# ── plugins_command dispatch ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPluginsCommandDispatch:
|
||||
"""Verify alias routing in plugins_command()."""
|
||||
|
||||
def _make_args(self, action, **extras):
|
||||
args = MagicMock()
|
||||
args.plugins_action = action
|
||||
for k, v in extras.items():
|
||||
setattr(args, k, v)
|
||||
return args
|
||||
|
||||
@patch("hermes_cli.plugins_cmd.cmd_remove")
|
||||
def test_rm_alias(self, mock_remove):
|
||||
args = self._make_args("rm", name="some-plugin")
|
||||
plugins_command(args)
|
||||
mock_remove.assert_called_once_with("some-plugin")
|
||||
|
||||
@patch("hermes_cli.plugins_cmd.cmd_remove")
|
||||
def test_uninstall_alias(self, mock_remove):
|
||||
args = self._make_args("uninstall", name="some-plugin")
|
||||
plugins_command(args)
|
||||
mock_remove.assert_called_once_with("some-plugin")
|
||||
|
||||
@patch("hermes_cli.plugins_cmd.cmd_list")
|
||||
def test_ls_alias(self, mock_list):
|
||||
args = self._make_args("ls")
|
||||
plugins_command(args)
|
||||
mock_list.assert_called_once()
|
||||
|
||||
@patch("hermes_cli.plugins_cmd.cmd_list")
|
||||
def test_none_falls_through_to_list(self, mock_list):
|
||||
args = self._make_args(None)
|
||||
plugins_command(args)
|
||||
mock_list.assert_called_once()
|
||||
|
||||
@patch("hermes_cli.plugins_cmd.cmd_install")
|
||||
def test_install_dispatches(self, mock_install):
|
||||
args = self._make_args("install", identifier="owner/repo", force=False)
|
||||
plugins_command(args)
|
||||
mock_install.assert_called_once_with("owner/repo", force=False)
|
||||
|
||||
@patch("hermes_cli.plugins_cmd.cmd_update")
|
||||
def test_update_dispatches(self, mock_update):
|
||||
args = self._make_args("update", name="foo")
|
||||
plugins_command(args)
|
||||
mock_update.assert_called_once_with("foo")
|
||||
|
||||
@patch("hermes_cli.plugins_cmd.cmd_remove")
|
||||
def test_remove_dispatches(self, mock_remove):
|
||||
args = self._make_args("remove", name="bar")
|
||||
plugins_command(args)
|
||||
mock_remove.assert_called_once_with("bar")
|
||||
|
||||
|
||||
# ── _read_manifest ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestReadManifest:
|
||||
"""Manifest reading edge cases."""
|
||||
|
||||
def test_valid_yaml(self, tmp_path):
|
||||
manifest = {"name": "cool-plugin", "version": "1.0.0"}
|
||||
(tmp_path / "plugin.yaml").write_text(yaml.dump(manifest))
|
||||
result = _read_manifest(tmp_path)
|
||||
assert result["name"] == "cool-plugin"
|
||||
assert result["version"] == "1.0.0"
|
||||
|
||||
def test_missing_file_returns_empty(self, tmp_path):
|
||||
result = _read_manifest(tmp_path)
|
||||
assert result == {}
|
||||
|
||||
def test_invalid_yaml_returns_empty_and_logs(self, tmp_path, caplog):
|
||||
(tmp_path / "plugin.yaml").write_text(": : : bad yaml [[[")
|
||||
with caplog.at_level(logging.WARNING, logger="hermes_cli.plugins_cmd"):
|
||||
result = _read_manifest(tmp_path)
|
||||
assert result == {}
|
||||
assert any("Failed to read plugin.yaml" in r.message for r in caplog.records)
|
||||
|
||||
def test_empty_file_returns_empty(self, tmp_path):
|
||||
(tmp_path / "plugin.yaml").write_text("")
|
||||
result = _read_manifest(tmp_path)
|
||||
assert result == {}
|
||||
|
||||
|
||||
# ── cmd_install tests ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCmdInstall:
|
||||
"""Test the install command."""
|
||||
|
||||
def test_install_requires_identifier(self):
|
||||
from hermes_cli.plugins_cmd import cmd_install
|
||||
import argparse
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
cmd_install("")
|
||||
|
||||
@patch("hermes_cli.plugins_cmd._resolve_git_url")
|
||||
def test_install_validates_identifier(self, mock_resolve):
|
||||
from hermes_cli.plugins_cmd import cmd_install
|
||||
|
||||
mock_resolve.side_effect = ValueError("Invalid identifier")
|
||||
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
cmd_install("invalid")
|
||||
assert exc_info.value.code == 1
|
||||
|
||||
|
||||
# ── cmd_update tests ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCmdUpdate:
|
||||
"""Test the update command."""
|
||||
|
||||
@patch("hermes_cli.plugins_cmd._sanitize_plugin_name")
|
||||
@patch("hermes_cli.plugins_cmd._plugins_dir")
|
||||
@patch("hermes_cli.plugins_cmd.subprocess.run")
|
||||
def test_update_git_pull_success(self, mock_run, mock_plugins_dir, mock_sanitize):
|
||||
from hermes_cli.plugins_cmd import cmd_update
|
||||
|
||||
mock_plugins_dir_val = MagicMock()
|
||||
mock_plugins_dir.return_value = mock_plugins_dir_val
|
||||
mock_target = MagicMock()
|
||||
mock_target.exists.return_value = True
|
||||
mock_target.__truediv__ = lambda self, x: MagicMock(
|
||||
exists=MagicMock(return_value=True)
|
||||
)
|
||||
mock_sanitize.return_value = mock_target
|
||||
|
||||
mock_run.return_value = MagicMock(returncode=0, stdout="Updated", stderr="")
|
||||
|
||||
cmd_update("test-plugin")
|
||||
|
||||
mock_run.assert_called_once()
|
||||
|
||||
@patch("hermes_cli.plugins_cmd._sanitize_plugin_name")
|
||||
@patch("hermes_cli.plugins_cmd._plugins_dir")
|
||||
def test_update_plugin_not_found(self, mock_plugins_dir, mock_sanitize):
|
||||
from hermes_cli.plugins_cmd import cmd_update
|
||||
|
||||
mock_plugins_dir_val = MagicMock()
|
||||
mock_plugins_dir_val.iterdir.return_value = []
|
||||
mock_plugins_dir.return_value = mock_plugins_dir_val
|
||||
mock_target = MagicMock()
|
||||
mock_target.exists.return_value = False
|
||||
mock_sanitize.return_value = mock_target
|
||||
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
cmd_update("nonexistent-plugin")
|
||||
|
||||
assert exc_info.value.code == 1
|
||||
|
||||
|
||||
# ── cmd_remove tests ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCmdRemove:
|
||||
"""Test the remove command."""
|
||||
|
||||
@patch("hermes_cli.plugins_cmd._sanitize_plugin_name")
|
||||
@patch("hermes_cli.plugins_cmd._plugins_dir")
|
||||
@patch("hermes_cli.plugins_cmd.shutil.rmtree")
|
||||
def test_remove_deletes_plugin(self, mock_rmtree, mock_plugins_dir, mock_sanitize):
|
||||
from hermes_cli.plugins_cmd import cmd_remove
|
||||
|
||||
mock_plugins_dir.return_value = MagicMock()
|
||||
mock_target = MagicMock()
|
||||
mock_target.exists.return_value = True
|
||||
mock_sanitize.return_value = mock_target
|
||||
|
||||
cmd_remove("test-plugin")
|
||||
|
||||
mock_rmtree.assert_called_once_with(mock_target)
|
||||
|
||||
@patch("hermes_cli.plugins_cmd._sanitize_plugin_name")
|
||||
@patch("hermes_cli.plugins_cmd._plugins_dir")
|
||||
def test_remove_plugin_not_found(self, mock_plugins_dir, mock_sanitize):
|
||||
from hermes_cli.plugins_cmd import cmd_remove
|
||||
|
||||
mock_plugins_dir_val = MagicMock()
|
||||
mock_plugins_dir_val.iterdir.return_value = []
|
||||
mock_plugins_dir.return_value = mock_plugins_dir_val
|
||||
mock_target = MagicMock()
|
||||
mock_target.exists.return_value = False
|
||||
mock_sanitize.return_value = mock_target
|
||||
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
cmd_remove("nonexistent-plugin")
|
||||
|
||||
assert exc_info.value.code == 1
|
||||
|
||||
|
||||
# ── cmd_list tests ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCmdList:
|
||||
"""Test the list command."""
|
||||
|
||||
@patch("hermes_cli.plugins_cmd._plugins_dir")
|
||||
def test_list_empty_plugins_dir(self, mock_plugins_dir):
|
||||
from hermes_cli.plugins_cmd import cmd_list
|
||||
|
||||
mock_plugins_dir_val = MagicMock()
|
||||
mock_plugins_dir_val.iterdir.return_value = []
|
||||
mock_plugins_dir.return_value = mock_plugins_dir_val
|
||||
|
||||
cmd_list()
|
||||
|
||||
@patch("hermes_cli.plugins_cmd._plugins_dir")
|
||||
@patch("hermes_cli.plugins_cmd._read_manifest")
|
||||
def test_list_with_plugins(self, mock_read_manifest, mock_plugins_dir):
|
||||
from hermes_cli.plugins_cmd import cmd_list
|
||||
|
||||
mock_plugins_dir_val = MagicMock()
|
||||
mock_plugin_dir = MagicMock()
|
||||
mock_plugin_dir.name = "test-plugin"
|
||||
mock_plugin_dir.is_dir.return_value = True
|
||||
mock_plugin_dir.__truediv__ = lambda self, x: MagicMock(
|
||||
exists=MagicMock(return_value=False)
|
||||
)
|
||||
mock_plugins_dir_val.iterdir.return_value = [mock_plugin_dir]
|
||||
mock_plugins_dir.return_value = mock_plugins_dir_val
|
||||
mock_read_manifest.return_value = {"name": "test-plugin", "version": "1.0.0"}
|
||||
|
||||
cmd_list()
|
||||
|
||||
|
||||
# ── _copy_example_files tests ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCopyExampleFiles:
|
||||
"""Test example file copying."""
|
||||
|
||||
def test_copies_example_files(self, tmp_path):
|
||||
from hermes_cli.plugins_cmd import _copy_example_files
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
console = MagicMock()
|
||||
|
||||
# Create example file
|
||||
example_file = tmp_path / "config.yaml.example"
|
||||
example_file.write_text("key: value")
|
||||
|
||||
_copy_example_files(tmp_path, console)
|
||||
|
||||
# Should have created the file
|
||||
assert (tmp_path / "config.yaml").exists()
|
||||
console.print.assert_called()
|
||||
|
||||
def test_skips_existing_files(self, tmp_path):
|
||||
from hermes_cli.plugins_cmd import _copy_example_files
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
console = MagicMock()
|
||||
|
||||
# Create both example and real file
|
||||
example_file = tmp_path / "config.yaml.example"
|
||||
example_file.write_text("key: value")
|
||||
real_file = tmp_path / "config.yaml"
|
||||
real_file.write_text("existing: true")
|
||||
|
||||
_copy_example_files(tmp_path, console)
|
||||
|
||||
# Should NOT have overwritten
|
||||
assert real_file.read_text() == "existing: true"
|
||||
|
||||
def test_handles_copy_error_gracefully(self, tmp_path):
|
||||
from hermes_cli.plugins_cmd import _copy_example_files
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
console = MagicMock()
|
||||
|
||||
# Create example file
|
||||
example_file = tmp_path / "config.yaml.example"
|
||||
example_file.write_text("key: value")
|
||||
|
||||
# Mock shutil.copy2 to raise an error
|
||||
with patch(
|
||||
"hermes_cli.plugins_cmd.shutil.copy2",
|
||||
side_effect=OSError("Permission denied"),
|
||||
):
|
||||
# Should not raise, just warn
|
||||
_copy_example_files(tmp_path, console)
|
||||
|
||||
# Should have printed a warning
|
||||
assert any("Warning" in str(c) for c in console.print.call_args_list)
|
||||
@@ -2413,6 +2413,7 @@ class TestAnthropicCredentialRefresh:
|
||||
agent._anthropic_client = old_client
|
||||
agent._anthropic_api_key = "sk-ant-oat01-stale-token"
|
||||
agent._anthropic_base_url = "https://api.anthropic.com"
|
||||
agent.provider = "anthropic"
|
||||
|
||||
with (
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-oat01-fresh-token"),
|
||||
@@ -2908,6 +2909,7 @@ class TestOAuthFlagAfterCredentialRefresh:
|
||||
def test_oauth_flag_updates_api_key_to_oauth(self, agent):
|
||||
"""Refreshing from API key to OAuth token must set flag to True."""
|
||||
agent.api_mode = "anthropic_messages"
|
||||
agent.provider = "anthropic"
|
||||
agent._anthropic_api_key = "sk-ant-api-old"
|
||||
agent._anthropic_client = MagicMock()
|
||||
agent._is_anthropic_oauth = False
|
||||
@@ -2926,6 +2928,7 @@ class TestOAuthFlagAfterCredentialRefresh:
|
||||
def test_oauth_flag_updates_oauth_to_api_key(self, agent):
|
||||
"""Refreshing from OAuth to API key must set flag to False."""
|
||||
agent.api_mode = "anthropic_messages"
|
||||
agent.provider = "anthropic"
|
||||
agent._anthropic_api_key = "sk-ant-setup-old"
|
||||
agent._anthropic_client = MagicMock()
|
||||
agent._is_anthropic_oauth = True
|
||||
|
||||
@@ -534,6 +534,34 @@ def test_minimax_explicit_api_mode_respected(monkeypatch):
|
||||
assert resolved["api_mode"] == "chat_completions"
|
||||
|
||||
|
||||
def test_alibaba_default_anthropic_endpoint_uses_anthropic_messages(monkeypatch):
|
||||
"""Alibaba with default /apps/anthropic URL should use anthropic_messages mode."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "alibaba")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
monkeypatch.setenv("DASHSCOPE_API_KEY", "test-dashscope-key")
|
||||
monkeypatch.delenv("DASHSCOPE_BASE_URL", raising=False)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="alibaba")
|
||||
|
||||
assert resolved["provider"] == "alibaba"
|
||||
assert resolved["api_mode"] == "anthropic_messages"
|
||||
assert resolved["base_url"] == "https://dashscope-intl.aliyuncs.com/apps/anthropic"
|
||||
|
||||
|
||||
def test_alibaba_openai_compatible_v1_endpoint_stays_chat_completions(monkeypatch):
|
||||
"""Alibaba with /v1 coding endpoint should use chat_completions mode."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "alibaba")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
monkeypatch.setenv("DASHSCOPE_API_KEY", "test-dashscope-key")
|
||||
monkeypatch.setenv("DASHSCOPE_BASE_URL", "https://coding-intl.dashscope.aliyuncs.com/v1")
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="alibaba")
|
||||
|
||||
assert resolved["provider"] == "alibaba"
|
||||
assert resolved["api_mode"] == "chat_completions"
|
||||
assert resolved["base_url"] == "https://coding-intl.dashscope.aliyuncs.com/v1"
|
||||
|
||||
|
||||
def test_named_custom_provider_anthropic_api_mode(monkeypatch):
|
||||
"""Custom providers should accept api_mode: anthropic_messages."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "my-anthropic-proxy")
|
||||
|
||||
@@ -209,3 +209,66 @@ class TestDeepSeekV3Parser:
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
|
||||
|
||||
# ─── Mistral parser tests ───────────────────────────────────────────────
|
||||
|
||||
class TestMistralParser:
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
return get_parser("mistral")
|
||||
|
||||
def test_no_tool_call(self, parser):
|
||||
text = "Hello, how can I help you?"
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert content == text
|
||||
assert tool_calls is None
|
||||
|
||||
def test_pre_v11_single_tool_call(self, parser):
|
||||
text = '[TOOL_CALLS] [{"name": "func", "arguments": {"key": "val"}}]'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].function.name == "func"
|
||||
args = json.loads(tool_calls[0].function.arguments)
|
||||
assert args["key"] == "val"
|
||||
|
||||
def test_pre_v11_nested_json(self, parser):
|
||||
text = '[TOOL_CALLS] [{"name": "func", "arguments": {"nested": {"deep": true}}}]'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].function.name == "func"
|
||||
args = json.loads(tool_calls[0].function.arguments)
|
||||
assert args["nested"]["deep"] is True
|
||||
|
||||
def test_v11_single_tool_call(self, parser):
|
||||
text = '[TOOL_CALLS]get_weather{"city": "London"}'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].function.name == "get_weather"
|
||||
args = json.loads(tool_calls[0].function.arguments)
|
||||
assert args["city"] == "London"
|
||||
|
||||
def test_v11_multiple_tool_calls(self, parser):
|
||||
text = '[TOOL_CALLS]func1{"a": 1}[TOOL_CALLS]func2{"b": 2}'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 2
|
||||
names = [tc.function.name for tc in tool_calls]
|
||||
assert "func1" in names
|
||||
assert "func2" in names
|
||||
|
||||
def test_preceding_text_preserved(self, parser):
|
||||
text = 'Hello[TOOL_CALLS]func{"a": 1}'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert content == "Hello"
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].function.name == "func"
|
||||
|
||||
def test_malformed_json_fallback(self, parser):
|
||||
text = "[TOOL_CALLS] not valid json"
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is None
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
"""Comprehensive tests for ANSI escape sequence stripping (ECMA-48).
|
||||
|
||||
The strip_ansi function in tools/ansi_strip.py is the source-level fix for
|
||||
ANSI codes leaking into the model's context via terminal/execute_code output.
|
||||
It must strip ALL terminal escape sequences while preserving legitimate text.
|
||||
"""
|
||||
|
||||
from tools.ansi_strip import strip_ansi
|
||||
|
||||
|
||||
class TestStripAnsiBasicSGR:
|
||||
"""Select Graphic Rendition — the most common ANSI sequences."""
|
||||
|
||||
def test_reset(self):
|
||||
assert strip_ansi("\x1b[0m") == ""
|
||||
|
||||
def test_color(self):
|
||||
assert strip_ansi("\x1b[31;1m") == ""
|
||||
|
||||
def test_truecolor_semicolon(self):
|
||||
assert strip_ansi("\x1b[38;2;255;0;0m") == ""
|
||||
|
||||
def test_truecolor_colon_separated(self):
|
||||
"""Modern terminals use colon-separated SGR params."""
|
||||
assert strip_ansi("\x1b[38:2:255:0:0m") == ""
|
||||
assert strip_ansi("\x1b[48:2:0:255:0m") == ""
|
||||
|
||||
|
||||
class TestStripAnsiCSIPrivateMode:
|
||||
"""CSI sequences with ? prefix (DEC private modes)."""
|
||||
|
||||
def test_cursor_show_hide(self):
|
||||
assert strip_ansi("\x1b[?25h") == ""
|
||||
assert strip_ansi("\x1b[?25l") == ""
|
||||
|
||||
def test_alt_screen(self):
|
||||
assert strip_ansi("\x1b[?1049h") == ""
|
||||
assert strip_ansi("\x1b[?1049l") == ""
|
||||
|
||||
def test_bracketed_paste(self):
|
||||
assert strip_ansi("\x1b[?2004h") == ""
|
||||
|
||||
|
||||
class TestStripAnsiCSIIntermediate:
|
||||
"""CSI sequences with intermediate bytes (space, etc.)."""
|
||||
|
||||
def test_cursor_shape(self):
|
||||
assert strip_ansi("\x1b[0 q") == ""
|
||||
assert strip_ansi("\x1b[2 q") == ""
|
||||
assert strip_ansi("\x1b[6 q") == ""
|
||||
|
||||
|
||||
class TestStripAnsiOSC:
|
||||
"""Operating System Command sequences."""
|
||||
|
||||
def test_bel_terminator(self):
|
||||
assert strip_ansi("\x1b]0;title\x07") == ""
|
||||
|
||||
def test_st_terminator(self):
|
||||
assert strip_ansi("\x1b]0;title\x1b\\") == ""
|
||||
|
||||
def test_hyperlink_preserves_text(self):
|
||||
assert strip_ansi(
|
||||
"\x1b]8;;https://example.com\x1b\\click\x1b]8;;\x1b\\"
|
||||
) == "click"
|
||||
|
||||
|
||||
class TestStripAnsiDECPrivate:
|
||||
"""DEC private / Fp escape sequences."""
|
||||
|
||||
def test_save_restore_cursor(self):
|
||||
assert strip_ansi("\x1b7") == ""
|
||||
assert strip_ansi("\x1b8") == ""
|
||||
|
||||
def test_keypad_modes(self):
|
||||
assert strip_ansi("\x1b=") == ""
|
||||
assert strip_ansi("\x1b>") == ""
|
||||
|
||||
|
||||
class TestStripAnsiFe:
|
||||
"""Fe (C1 as 7-bit) escape sequences."""
|
||||
|
||||
def test_reverse_index(self):
|
||||
assert strip_ansi("\x1bM") == ""
|
||||
|
||||
def test_reset_terminal(self):
|
||||
assert strip_ansi("\x1bc") == ""
|
||||
|
||||
def test_index_and_newline(self):
|
||||
assert strip_ansi("\x1bD") == ""
|
||||
assert strip_ansi("\x1bE") == ""
|
||||
|
||||
|
||||
class TestStripAnsiNF:
|
||||
"""nF (character set selection) sequences."""
|
||||
|
||||
def test_charset_selection(self):
|
||||
assert strip_ansi("\x1b(A") == ""
|
||||
assert strip_ansi("\x1b(B") == ""
|
||||
assert strip_ansi("\x1b(0") == ""
|
||||
|
||||
|
||||
class TestStripAnsiDCS:
|
||||
"""Device Control String sequences."""
|
||||
|
||||
def test_dcs(self):
|
||||
assert strip_ansi("\x1bP+q\x1b\\") == ""
|
||||
|
||||
|
||||
class TestStripAnsi8BitC1:
|
||||
"""8-bit C1 control characters."""
|
||||
|
||||
def test_8bit_csi(self):
|
||||
assert strip_ansi("\x9b31m") == ""
|
||||
assert strip_ansi("\x9b38;2;255;0;0m") == ""
|
||||
|
||||
def test_8bit_standalone(self):
|
||||
assert strip_ansi("\x9c") == ""
|
||||
assert strip_ansi("\x9d") == ""
|
||||
assert strip_ansi("\x90") == ""
|
||||
|
||||
|
||||
class TestStripAnsiRealWorld:
|
||||
"""Real-world contamination scenarios from bug reports."""
|
||||
|
||||
def test_colored_shebang(self):
|
||||
"""The original reported bug: shebang corrupted by color codes."""
|
||||
assert strip_ansi(
|
||||
"\x1b[32m#!/usr/bin/env python3\x1b[0m\nprint('hello')"
|
||||
) == "#!/usr/bin/env python3\nprint('hello')"
|
||||
|
||||
def test_stacked_sgr(self):
|
||||
assert strip_ansi(
|
||||
"\x1b[1m\x1b[31m\x1b[42mhello\x1b[0m"
|
||||
) == "hello"
|
||||
|
||||
def test_ansi_mid_code(self):
|
||||
assert strip_ansi(
|
||||
"def foo(\x1b[33m):\x1b[0m\n return 42"
|
||||
) == "def foo():\n return 42"
|
||||
|
||||
|
||||
class TestStripAnsiPassthrough:
|
||||
"""Clean content must pass through unmodified."""
|
||||
|
||||
def test_plain_text(self):
|
||||
assert strip_ansi("normal text") == "normal text"
|
||||
|
||||
def test_empty(self):
|
||||
assert strip_ansi("") == ""
|
||||
|
||||
def test_none(self):
|
||||
assert strip_ansi(None) is None
|
||||
|
||||
def test_whitespace_preserved(self):
|
||||
assert strip_ansi("line1\nline2\ttab") == "line1\nline2\ttab"
|
||||
|
||||
def test_unicode_safe(self):
|
||||
assert strip_ansi("emoji 🎉 and ñ café") == "emoji 🎉 and ñ café"
|
||||
|
||||
def test_backslash_in_code(self):
|
||||
code = "path = 'C:\\\\Users\\\\test'"
|
||||
assert strip_ansi(code) == code
|
||||
|
||||
def test_square_brackets_in_code(self):
|
||||
"""Array indexing must not be confused with CSI."""
|
||||
code = "arr[0] = arr[31]"
|
||||
assert strip_ansi(code) == code
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import patch as mock_patch
|
||||
|
||||
import tools.approval as approval_module
|
||||
from tools.approval import (
|
||||
_get_approval_mode,
|
||||
approve_session,
|
||||
clear_session,
|
||||
detect_dangerous_command,
|
||||
@@ -16,6 +17,16 @@ from tools.approval import (
|
||||
)
|
||||
|
||||
|
||||
class TestApprovalModeParsing:
|
||||
def test_unquoted_yaml_off_boolean_false_maps_to_off(self):
|
||||
with mock_patch("hermes_cli.config.load_config", return_value={"approvals": {"mode": False}}):
|
||||
assert _get_approval_mode() == "off"
|
||||
|
||||
def test_string_off_still_maps_to_off(self):
|
||||
with mock_patch("hermes_cli.config.load_config", return_value={"approvals": {"mode": "off"}}):
|
||||
assert _get_approval_mode() == "off"
|
||||
|
||||
|
||||
class TestDetectDangerousRm:
|
||||
def test_rm_rf_detected(self):
|
||||
is_dangerous, key, desc = detect_dangerous_command("rm -rf /home/user")
|
||||
@@ -464,3 +475,40 @@ class TestForkBombDetection:
|
||||
dangerous, key, desc = detect_dangerous_command("echo hello:world")
|
||||
assert dangerous is False
|
||||
|
||||
|
||||
class TestGatewayProtection:
|
||||
"""Prevent agents from starting the gateway outside systemd management."""
|
||||
|
||||
def test_gateway_run_with_disown_detected(self):
|
||||
cmd = "kill 1605 && cd ~/.hermes/hermes-agent && source venv/bin/activate && python -m hermes_cli.main gateway run --replace &disown; echo done"
|
||||
dangerous, key, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
assert "systemctl" in desc
|
||||
|
||||
def test_gateway_run_with_ampersand_detected(self):
|
||||
cmd = "python -m hermes_cli.main gateway run --replace &"
|
||||
dangerous, key, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
|
||||
def test_gateway_run_with_nohup_detected(self):
|
||||
cmd = "nohup python -m hermes_cli.main gateway run --replace"
|
||||
dangerous, key, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
|
||||
def test_gateway_run_with_setsid_detected(self):
|
||||
cmd = "hermes_cli.main gateway run --replace &disown"
|
||||
dangerous, key, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is True
|
||||
|
||||
def test_gateway_run_foreground_not_flagged(self):
|
||||
"""Normal foreground gateway run (as in systemd ExecStart) is fine."""
|
||||
cmd = "python -m hermes_cli.main gateway run --replace"
|
||||
dangerous, key, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is False
|
||||
|
||||
def test_systemctl_restart_not_flagged(self):
|
||||
"""Using systemctl to manage the gateway is the correct approach."""
|
||||
cmd = "systemctl --user restart hermes-gateway"
|
||||
dangerous, key, desc = detect_dangerous_command(cmd)
|
||||
assert dangerous is False
|
||||
|
||||
|
||||
@@ -309,3 +309,6 @@ class TestSearchHints:
|
||||
raw = search_tool(pattern="foo", offset=50, limit=50)
|
||||
assert "[Hint:" in raw
|
||||
assert "offset=100" in raw
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,238 @@
|
||||
"""Tests for tools/mcp_oauth.py — thin OAuth adapter over MCP SDK."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.mcp_oauth import (
|
||||
HermesTokenStorage,
|
||||
build_oauth_auth,
|
||||
remove_oauth_tokens,
|
||||
_find_free_port,
|
||||
_can_open_browser,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HermesTokenStorage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHermesTokenStorage:
|
||||
def test_roundtrip_tokens(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("test-server")
|
||||
|
||||
import asyncio
|
||||
|
||||
# Initially empty
|
||||
assert asyncio.run(storage.get_tokens()) is None
|
||||
|
||||
# Save and retrieve
|
||||
mock_token = MagicMock()
|
||||
mock_token.model_dump.return_value = {
|
||||
"access_token": "abc123",
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "ref456",
|
||||
}
|
||||
asyncio.run(storage.set_tokens(mock_token))
|
||||
|
||||
# File exists with correct permissions
|
||||
token_path = tmp_path / "mcp-tokens" / "test-server.json"
|
||||
assert token_path.exists()
|
||||
data = json.loads(token_path.read_text())
|
||||
assert data["access_token"] == "abc123"
|
||||
|
||||
def test_roundtrip_client_info(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("test-server")
|
||||
import asyncio
|
||||
|
||||
assert asyncio.run(storage.get_client_info()) is None
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.model_dump.return_value = {
|
||||
"client_id": "hermes-123",
|
||||
"client_secret": "secret",
|
||||
}
|
||||
asyncio.run(storage.set_client_info(mock_client))
|
||||
|
||||
client_path = tmp_path / "mcp-tokens" / "test-server.client.json"
|
||||
assert client_path.exists()
|
||||
|
||||
def test_remove_cleans_up(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("test-server")
|
||||
|
||||
# Create files
|
||||
d = tmp_path / "mcp-tokens"
|
||||
d.mkdir(parents=True)
|
||||
(d / "test-server.json").write_text("{}")
|
||||
(d / "test-server.client.json").write_text("{}")
|
||||
|
||||
storage.remove()
|
||||
assert not (d / "test-server.json").exists()
|
||||
assert not (d / "test-server.client.json").exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_oauth_auth
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBuildOAuthAuth:
|
||||
def test_returns_oauth_provider(self):
|
||||
try:
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
except ImportError:
|
||||
pytest.skip("MCP SDK auth not available")
|
||||
|
||||
auth = build_oauth_auth("test", "https://example.com/mcp")
|
||||
assert isinstance(auth, OAuthClientProvider)
|
||||
|
||||
def test_returns_none_without_sdk(self, monkeypatch):
|
||||
import tools.mcp_oauth as mod
|
||||
orig_import = __builtins__.__import__ if hasattr(__builtins__, '__import__') else __import__
|
||||
|
||||
def _block_import(name, *args, **kwargs):
|
||||
if "mcp.client.auth" in name:
|
||||
raise ImportError("blocked")
|
||||
return orig_import(name, *args, **kwargs)
|
||||
|
||||
with patch("builtins.__import__", side_effect=_block_import):
|
||||
result = build_oauth_auth("test", "https://example.com")
|
||||
# May or may not be None depending on import caching, but shouldn't crash
|
||||
assert result is None or result is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Utility functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUtilities:
|
||||
def test_find_free_port_returns_int(self):
|
||||
port = _find_free_port()
|
||||
assert isinstance(port, int)
|
||||
assert 1024 <= port <= 65535
|
||||
|
||||
def test_can_open_browser_false_in_ssh(self, monkeypatch):
|
||||
monkeypatch.setenv("SSH_CLIENT", "1.2.3.4 1234 22")
|
||||
assert _can_open_browser() is False
|
||||
|
||||
def test_can_open_browser_false_without_display(self, monkeypatch):
|
||||
monkeypatch.delenv("SSH_CLIENT", raising=False)
|
||||
monkeypatch.delenv("SSH_TTY", raising=False)
|
||||
monkeypatch.delenv("DISPLAY", raising=False)
|
||||
# Mock os.name and uname for non-macOS, non-Windows
|
||||
monkeypatch.setattr(os, "name", "posix")
|
||||
monkeypatch.setattr(os, "uname", lambda: type("", (), {"sysname": "Linux"})())
|
||||
assert _can_open_browser() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# remove_oauth_tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPathTraversal:
|
||||
"""Verify server_name is sanitized to prevent path traversal."""
|
||||
|
||||
def test_path_traversal_blocked(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("../../.ssh/config")
|
||||
path = storage._tokens_path()
|
||||
# Should stay within mcp-tokens directory
|
||||
assert "mcp-tokens" in str(path)
|
||||
assert ".ssh" not in str(path.resolve())
|
||||
|
||||
def test_dots_and_slashes_sanitized(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("../../../etc/passwd")
|
||||
path = storage._tokens_path()
|
||||
resolved = path.resolve()
|
||||
assert resolved.is_relative_to((tmp_path / "mcp-tokens").resolve())
|
||||
|
||||
def test_normal_name_unchanged(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("my-mcp-server")
|
||||
assert "my-mcp-server.json" in str(storage._tokens_path())
|
||||
|
||||
def test_special_chars_sanitized(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("server@host:8080/path")
|
||||
path = storage._tokens_path()
|
||||
assert "@" not in path.name
|
||||
assert ":" not in path.name
|
||||
assert "/" not in path.stem
|
||||
|
||||
|
||||
class TestCallbackHandlerIsolation:
|
||||
"""Verify concurrent OAuth flows don't share state."""
|
||||
|
||||
def test_independent_result_dicts(self):
|
||||
from tools.mcp_oauth import _make_callback_handler
|
||||
_, result_a = _make_callback_handler()
|
||||
_, result_b = _make_callback_handler()
|
||||
|
||||
result_a["auth_code"] = "code_A"
|
||||
result_b["auth_code"] = "code_B"
|
||||
|
||||
assert result_a["auth_code"] == "code_A"
|
||||
assert result_b["auth_code"] == "code_B"
|
||||
|
||||
def test_handler_writes_to_own_result(self):
|
||||
from tools.mcp_oauth import _make_callback_handler
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
HandlerClass, result = _make_callback_handler()
|
||||
assert result["auth_code"] is None
|
||||
|
||||
# Simulate a GET request
|
||||
handler = HandlerClass.__new__(HandlerClass)
|
||||
handler.path = "/callback?code=test123&state=mystate"
|
||||
handler.wfile = BytesIO()
|
||||
handler.send_response = MagicMock()
|
||||
handler.send_header = MagicMock()
|
||||
handler.end_headers = MagicMock()
|
||||
handler.do_GET()
|
||||
|
||||
assert result["auth_code"] == "test123"
|
||||
assert result["state"] == "mystate"
|
||||
|
||||
|
||||
class TestOAuthPortSharing:
|
||||
"""Verify build_oauth_auth and _wait_for_callback use the same port."""
|
||||
|
||||
def test_port_stored_globally(self):
|
||||
import tools.mcp_oauth as mod
|
||||
# Reset
|
||||
mod._oauth_port = None
|
||||
|
||||
try:
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
except ImportError:
|
||||
pytest.skip("MCP SDK auth not available")
|
||||
|
||||
build_oauth_auth("test-port", "https://example.com/mcp")
|
||||
assert mod._oauth_port is not None
|
||||
assert isinstance(mod._oauth_port, int)
|
||||
assert 1024 <= mod._oauth_port <= 65535
|
||||
|
||||
|
||||
class TestRemoveOAuthTokens:
|
||||
def test_removes_files(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
d = tmp_path / "mcp-tokens"
|
||||
d.mkdir()
|
||||
(d / "myserver.json").write_text("{}")
|
||||
(d / "myserver.client.json").write_text("{}")
|
||||
|
||||
remove_oauth_tokens("myserver")
|
||||
|
||||
assert not (d / "myserver.json").exists()
|
||||
assert not (d / "myserver.client.json").exists()
|
||||
|
||||
def test_no_error_when_files_missing(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
remove_oauth_tokens("nonexistent") # should not raise
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user