Compare commits
48 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| f7d57c0108 | |||
| d77783d198 | |||
| 4af69097f2 | |||
| 59471b79e5 | |||
| 0e459f2b7b | |||
| 3befb9389f | |||
| 3baafea380 | |||
| e26393ffc2 | |||
| e19252afc4 | |||
| d684d7ee7e | |||
| 7d26feb9a3 | |||
| 875a72e4c8 | |||
| 20a5e589c6 | |||
| 7156f8d866 | |||
| 8de91ce9d2 | |||
| 8385f54e98 | |||
| 105caa001b | |||
| d46db0a1b4 | |||
| 5f4b93c20f | |||
| 5d2fc6d928 | |||
| 3377017eb4 | |||
| a1213d06bd | |||
| 1631895d5a | |||
| 4f467700d4 | |||
| ff6a86cb52 | |||
| 86960cdbb0 | |||
| 8b0afa0e57 | |||
| ab21fbfd89 | |||
| bdc72ec355 | |||
| c8a5e36be8 | |||
| 1368caf66f | |||
| 30ea423ce8 | |||
| 19b0ddce40 | |||
| 383db35925 | |||
| 55ac056920 | |||
| 085c1c6875 | |||
| a18e5b95ad | |||
| 3696c74bfb | |||
| bbcff8dcd0 | |||
| 77c5bc9da9 | |||
| 65e24c942e | |||
| 22d1bda185 | |||
| ab271ebe10 | |||
| e1befe5077 | |||
| fff237e111 | |||
| 598c25d43e | |||
| 5c03f2e7cc | |||
| 8d7a98d2ff |
@@ -81,6 +81,14 @@
|
||||
# HF_TOKEN=
|
||||
# OPENCODE_GO_BASE_URL=https://opencode.ai/zen/go/v1 # Override default base URL
|
||||
|
||||
# =============================================================================
|
||||
# LLM PROVIDER (Qwen OAuth)
|
||||
# =============================================================================
|
||||
# Qwen OAuth reuses your local Qwen CLI login (qwen auth qwen-oauth).
|
||||
# No API key needed — credentials come from ~/.qwen/oauth_creds.json.
|
||||
# Optional base URL override:
|
||||
# HERMES_QWEN_BASE_URL=https://portal.qwen.ai/v1
|
||||
|
||||
# =============================================================================
|
||||
# TOOL API KEYS
|
||||
# =============================================================================
|
||||
|
||||
@@ -0,0 +1,346 @@
|
||||
# Hermes Agent v0.8.0 (v2026.4.8)
|
||||
|
||||
**Release Date:** April 8, 2026
|
||||
|
||||
> The intelligence release — background task auto-notifications, free MiMo v2 Pro on Nous Portal, live model switching across all platforms, self-optimized GPT/Codex guidance, native Google AI Studio, smart inactivity timeouts, approval buttons, MCP OAuth 2.1, and 209 merged PRs with 82 resolved issues.
|
||||
|
||||
---
|
||||
|
||||
## ✨ Highlights
|
||||
|
||||
- **Background Process Auto-Notifications (`notify_on_complete`)** — Background tasks can now automatically notify the agent when they finish. Start a long-running process (AI model training, test suites, deployments, builds) and the agent gets notified on completion — no polling needed. The agent can keep working on other things and pick up results when they land. ([#5779](https://github.com/NousResearch/hermes-agent/pull/5779))
|
||||
|
||||
- **Free Xiaomi MiMo v2 Pro on Nous Portal** — Nous Portal now supports the free-tier Xiaomi MiMo v2 Pro model for auxiliary tasks (compression, vision, summarization), with free-tier model gating and pricing display in model selection. ([#6018](https://github.com/NousResearch/hermes-agent/pull/6018), [#5880](https://github.com/NousResearch/hermes-agent/pull/5880))
|
||||
|
||||
- **Live Model Switching (`/model` Command)** — Switch models and providers mid-session from CLI, Telegram, Discord, Slack, or any gateway platform. Aggregator-aware resolution keeps you on OpenRouter/Nous when possible, with automatic cross-provider fallback when needed. Interactive model pickers on Telegram and Discord with inline buttons. ([#5181](https://github.com/NousResearch/hermes-agent/pull/5181), [#5742](https://github.com/NousResearch/hermes-agent/pull/5742))
|
||||
|
||||
- **Self-Optimized GPT/Codex Tool-Use Guidance** — The agent diagnosed and patched 5 failure modes in GPT and Codex tool calling through automated behavioral benchmarking, dramatically improving reliability on OpenAI models. Includes execution discipline guidance and thinking-only prefill continuation for structured reasoning. ([#6120](https://github.com/NousResearch/hermes-agent/pull/6120), [#5414](https://github.com/NousResearch/hermes-agent/pull/5414), [#5931](https://github.com/NousResearch/hermes-agent/pull/5931))
|
||||
|
||||
- **Google AI Studio (Gemini) Native Provider** — Direct access to Gemini models through Google's AI Studio API. Includes automatic models.dev registry integration for real-time context length detection across any provider. ([#5577](https://github.com/NousResearch/hermes-agent/pull/5577))
|
||||
|
||||
- **Inactivity-Based Agent Timeouts** — Gateway and cron timeouts now track actual tool activity instead of wall-clock time. Long-running tasks that are actively working will never be killed — only truly idle agents time out. ([#5389](https://github.com/NousResearch/hermes-agent/pull/5389), [#5440](https://github.com/NousResearch/hermes-agent/pull/5440))
|
||||
|
||||
- **Approval Buttons on Slack & Telegram** — Dangerous command approval via native platform buttons instead of typing `/approve`. Slack gets thread context preservation; Telegram gets emoji reactions for approval status. ([#5890](https://github.com/NousResearch/hermes-agent/pull/5890), [#5975](https://github.com/NousResearch/hermes-agent/pull/5975))
|
||||
|
||||
- **MCP OAuth 2.1 PKCE + OSV Malware Scanning** — Full standards-compliant OAuth for MCP server authentication, plus automatic malware scanning of MCP extension packages via the OSV vulnerability database. ([#5420](https://github.com/NousResearch/hermes-agent/pull/5420), [#5305](https://github.com/NousResearch/hermes-agent/pull/5305))
|
||||
|
||||
- **Centralized Logging & Config Validation** — Structured logging to `~/.hermes/logs/` (agent.log + errors.log) with the `hermes logs` command for tailing and filtering. Config structure validation catches malformed YAML at startup before it causes cryptic failures. ([#5430](https://github.com/NousResearch/hermes-agent/pull/5430), [#5426](https://github.com/NousResearch/hermes-agent/pull/5426))
|
||||
|
||||
- **Plugin System Expansion** — Plugins can now register CLI subcommands, receive request-scoped API hooks with correlation IDs, prompt for required env vars during install, and hook into session lifecycle events (finalize/reset). ([#5295](https://github.com/NousResearch/hermes-agent/pull/5295), [#5427](https://github.com/NousResearch/hermes-agent/pull/5427), [#5470](https://github.com/NousResearch/hermes-agent/pull/5470), [#6129](https://github.com/NousResearch/hermes-agent/pull/6129))
|
||||
|
||||
- **Matrix Tier 1 & Platform Hardening** — Matrix gets reactions, read receipts, rich formatting, and room management. Discord adds channel controls and ignored channels. Signal gets full MEDIA: tag delivery. Mattermost gets file attachments. Comprehensive reliability fixes across all platforms. ([#5275](https://github.com/NousResearch/hermes-agent/pull/5275), [#5975](https://github.com/NousResearch/hermes-agent/pull/5975), [#5602](https://github.com/NousResearch/hermes-agent/pull/5602))
|
||||
|
||||
- **Security Hardening Pass** — Consolidated SSRF protections, timing attack mitigations, tar traversal prevention, credential leakage guards, cron path traversal hardening, and cross-session isolation. Terminal workdir sanitization across all backends. ([#5944](https://github.com/NousResearch/hermes-agent/pull/5944), [#5613](https://github.com/NousResearch/hermes-agent/pull/5613), [#5629](https://github.com/NousResearch/hermes-agent/pull/5629))
|
||||
|
||||
---
|
||||
|
||||
## 🏗️ Core Agent & Architecture
|
||||
|
||||
### Provider & Model Support
|
||||
- **Native Google AI Studio (Gemini) provider** with models.dev integration for automatic context length detection ([#5577](https://github.com/NousResearch/hermes-agent/pull/5577))
|
||||
- **`/model` command — full provider+model system overhaul** — live switching across CLI and all gateway platforms with aggregator-aware resolution ([#5181](https://github.com/NousResearch/hermes-agent/pull/5181))
|
||||
- **Interactive model picker for Telegram and Discord** — inline button-based model selection ([#5742](https://github.com/NousResearch/hermes-agent/pull/5742))
|
||||
- **Nous Portal free-tier model gating** with pricing display in model selection ([#5880](https://github.com/NousResearch/hermes-agent/pull/5880))
|
||||
- **Model pricing display** for OpenRouter and Nous Portal providers ([#5416](https://github.com/NousResearch/hermes-agent/pull/5416))
|
||||
- **xAI (Grok) prompt caching** via `x-grok-conv-id` header ([#5604](https://github.com/NousResearch/hermes-agent/pull/5604))
|
||||
- **Grok added to tool-use enforcement models** for direct xAI usage ([#5595](https://github.com/NousResearch/hermes-agent/pull/5595))
|
||||
- **MiniMax TTS provider** (speech-2.8) ([#4963](https://github.com/NousResearch/hermes-agent/pull/4963))
|
||||
- **Non-agentic model warning** — warns users when loading Hermes LLM models not designed for tool use ([#5378](https://github.com/NousResearch/hermes-agent/pull/5378))
|
||||
- **Ollama Cloud auth, /model switch persistence**, and alias tab completion ([#5269](https://github.com/NousResearch/hermes-agent/pull/5269))
|
||||
- **Preserve dots in OpenCode Go model names** (minimax-m2.7, glm-4.5, kimi-k2.5) ([#5597](https://github.com/NousResearch/hermes-agent/pull/5597))
|
||||
- **MiniMax models 404 fix** — strip /v1 from Anthropic base URL for OpenCode Go ([#4918](https://github.com/NousResearch/hermes-agent/pull/4918))
|
||||
- **Provider credential reset windows** honored in pooled failover ([#5188](https://github.com/NousResearch/hermes-agent/pull/5188))
|
||||
- **OAuth token sync** between credential pool and credentials file ([#4981](https://github.com/NousResearch/hermes-agent/pull/4981))
|
||||
- **Stale OAuth credentials** no longer block OpenRouter users on auto-detect ([#5746](https://github.com/NousResearch/hermes-agent/pull/5746))
|
||||
- **Codex OAuth credential pool disconnect** + expired token import fix ([#5681](https://github.com/NousResearch/hermes-agent/pull/5681))
|
||||
- **Codex pool entry sync** from `~/.codex/auth.json` on exhaustion — @GratefulDave ([#5610](https://github.com/NousResearch/hermes-agent/pull/5610))
|
||||
- **Auxiliary client payment fallback** — retry with next provider on 402 ([#5599](https://github.com/NousResearch/hermes-agent/pull/5599))
|
||||
- **Auxiliary client resolves named custom providers** and 'main' alias ([#5978](https://github.com/NousResearch/hermes-agent/pull/5978))
|
||||
- **Use mimo-v2-pro** for non-vision auxiliary tasks on Nous free tier ([#6018](https://github.com/NousResearch/hermes-agent/pull/6018))
|
||||
- **Vision auto-detection** tries main provider first ([#6041](https://github.com/NousResearch/hermes-agent/pull/6041))
|
||||
- **Provider re-ordering and Quick Install** — @austinpickett ([#4664](https://github.com/NousResearch/hermes-agent/pull/4664))
|
||||
- **Nous OAuth access_token** no longer used as inference API key — @SHL0MS ([#5564](https://github.com/NousResearch/hermes-agent/pull/5564))
|
||||
- **HERMES_PORTAL_BASE_URL env var** respected during Nous login — @benbarclay ([#5745](https://github.com/NousResearch/hermes-agent/pull/5745))
|
||||
- **Env var overrides** for Nous portal/inference URLs ([#5419](https://github.com/NousResearch/hermes-agent/pull/5419))
|
||||
- **Z.AI endpoint auto-detect** via probe and cache ([#5763](https://github.com/NousResearch/hermes-agent/pull/5763))
|
||||
- **MiniMax context lengths, model catalog, thinking guard, aux model, and config base_url** corrections ([#6082](https://github.com/NousResearch/hermes-agent/pull/6082))
|
||||
- **Community provider/model resolution fixes** — salvaged 4 community PRs + MiniMax aux URL ([#5983](https://github.com/NousResearch/hermes-agent/pull/5983))
|
||||
|
||||
### Agent Loop & Conversation
|
||||
- **Self-optimized GPT/Codex tool-use guidance** via automated behavioral benchmarking — agent self-diagnosed and patched 5 failure modes ([#6120](https://github.com/NousResearch/hermes-agent/pull/6120))
|
||||
- **GPT/Codex execution discipline guidance** in system prompts ([#5414](https://github.com/NousResearch/hermes-agent/pull/5414))
|
||||
- **Thinking-only prefill continuation** for structured reasoning responses ([#5931](https://github.com/NousResearch/hermes-agent/pull/5931))
|
||||
- **Accept reasoning-only responses** without retries — set content to "(empty)" instead of infinite retry ([#5278](https://github.com/NousResearch/hermes-agent/pull/5278))
|
||||
- **Jittered retry backoff** — exponential backoff with jitter for API retries ([#6048](https://github.com/NousResearch/hermes-agent/pull/6048))
|
||||
- **Smart thinking block signature management** — preserve and manage Anthropic thinking signatures across turns ([#6112](https://github.com/NousResearch/hermes-agent/pull/6112))
|
||||
- **Coerce tool call arguments** to match JSON Schema types — fixes models that send strings instead of numbers/booleans ([#5265](https://github.com/NousResearch/hermes-agent/pull/5265))
|
||||
- **Save oversized tool results to file** instead of destructive truncation ([#5210](https://github.com/NousResearch/hermes-agent/pull/5210))
|
||||
- **Sandbox-aware tool result persistence** ([#6085](https://github.com/NousResearch/hermes-agent/pull/6085))
|
||||
- **Streaming fallback** improved after edit failures ([#6110](https://github.com/NousResearch/hermes-agent/pull/6110))
|
||||
- **Codex empty-output gaps** covered in fallback + normalizer + auxiliary client ([#5724](https://github.com/NousResearch/hermes-agent/pull/5724), [#5730](https://github.com/NousResearch/hermes-agent/pull/5730), [#5734](https://github.com/NousResearch/hermes-agent/pull/5734))
|
||||
- **Codex stream output backfill** from output_item.done events ([#5689](https://github.com/NousResearch/hermes-agent/pull/5689))
|
||||
- **Stream consumer creates new message** after tool boundaries ([#5739](https://github.com/NousResearch/hermes-agent/pull/5739))
|
||||
- **Codex validation aligned** with normalization for empty stream output ([#5940](https://github.com/NousResearch/hermes-agent/pull/5940))
|
||||
- **Bridge tool-calls** in copilot-acp adapter ([#5460](https://github.com/NousResearch/hermes-agent/pull/5460))
|
||||
- **Filter transcript-only roles** from chat-completions payload ([#4880](https://github.com/NousResearch/hermes-agent/pull/4880))
|
||||
- **Context compaction failures fixed** on temperature-restricted models — @MadKangYu ([#5608](https://github.com/NousResearch/hermes-agent/pull/5608))
|
||||
- **Sanitize tool_calls for all strict APIs** (Fireworks, Mistral, etc.) — @lumethegreat ([#5183](https://github.com/NousResearch/hermes-agent/pull/5183))
|
||||
|
||||
### Memory & Sessions
|
||||
- **Supermemory memory provider** — new memory plugin with multi-container, search_mode, identity template, and env var override ([#5737](https://github.com/NousResearch/hermes-agent/pull/5737), [#5933](https://github.com/NousResearch/hermes-agent/pull/5933))
|
||||
- **Shared thread sessions** by default — multi-user thread support across gateway platforms ([#5391](https://github.com/NousResearch/hermes-agent/pull/5391))
|
||||
- **Subagent sessions linked to parent** and hidden from session list ([#5309](https://github.com/NousResearch/hermes-agent/pull/5309))
|
||||
- **Profile-scoped memory isolation** and clone support ([#4845](https://github.com/NousResearch/hermes-agent/pull/4845))
|
||||
- **Thread gateway user_id to memory plugins** for per-user scoping ([#5895](https://github.com/NousResearch/hermes-agent/pull/5895))
|
||||
- **Honcho plugin drift overhaul** + plugin CLI registration system ([#5295](https://github.com/NousResearch/hermes-agent/pull/5295))
|
||||
- **Honcho holographic prompt and trust score** rendering preserved ([#4872](https://github.com/NousResearch/hermes-agent/pull/4872))
|
||||
- **Honcho doctor fix** — use recall_mode instead of memory_mode — @techguysimon ([#5645](https://github.com/NousResearch/hermes-agent/pull/5645))
|
||||
- **RetainDB** — API routes, write queue, dialectic, agent model, file tools fixes ([#5461](https://github.com/NousResearch/hermes-agent/pull/5461))
|
||||
- **Hindsight memory plugin overhaul** + memory setup wizard fixes ([#5094](https://github.com/NousResearch/hermes-agent/pull/5094))
|
||||
- **mem0 API v2 compat**, prefetch context fencing, secret redaction ([#5423](https://github.com/NousResearch/hermes-agent/pull/5423))
|
||||
- **mem0 env vars merged** with mem0.json instead of either/or ([#4939](https://github.com/NousResearch/hermes-agent/pull/4939))
|
||||
- **Clean user message** used for all memory provider operations ([#4940](https://github.com/NousResearch/hermes-agent/pull/4940))
|
||||
- **Silent memory flush failure** on /new and /resume fixed — @ryanautomated ([#5640](https://github.com/NousResearch/hermes-agent/pull/5640))
|
||||
- **OpenViking atexit safety net** for session commit ([#5664](https://github.com/NousResearch/hermes-agent/pull/5664))
|
||||
- **OpenViking tenant-scoping headers** for multi-tenant servers ([#4936](https://github.com/NousResearch/hermes-agent/pull/4936))
|
||||
- **ByteRover brv query** runs synchronously before LLM call ([#4831](https://github.com/NousResearch/hermes-agent/pull/4831))
|
||||
|
||||
---
|
||||
|
||||
## 📱 Messaging Platforms (Gateway)
|
||||
|
||||
### Gateway Core
|
||||
- **Inactivity-based agent timeout** — replaces wall-clock timeout with smart activity tracking; long-running active tasks never killed ([#5389](https://github.com/NousResearch/hermes-agent/pull/5389))
|
||||
- **Approval buttons for Slack & Telegram** + Slack thread context preservation ([#5890](https://github.com/NousResearch/hermes-agent/pull/5890))
|
||||
- **Live-stream /update output** + forward interactive prompts to user ([#5180](https://github.com/NousResearch/hermes-agent/pull/5180))
|
||||
- **Infinite timeout support** + periodic notifications + actionable error messages ([#4959](https://github.com/NousResearch/hermes-agent/pull/4959))
|
||||
- **Duplicate message prevention** — gateway dedup + partial stream guard ([#4878](https://github.com/NousResearch/hermes-agent/pull/4878))
|
||||
- **Webhook delivery_info persistence** + full session id in /status ([#5942](https://github.com/NousResearch/hermes-agent/pull/5942))
|
||||
- **Tool preview truncation** respects tool_preview_length in all/new progress modes ([#5937](https://github.com/NousResearch/hermes-agent/pull/5937))
|
||||
- **Short preview truncation** restored for all/new tool progress modes ([#4935](https://github.com/NousResearch/hermes-agent/pull/4935))
|
||||
- **Update-pending state** written atomically to prevent corruption ([#4923](https://github.com/NousResearch/hermes-agent/pull/4923))
|
||||
- **Approval session key isolated** per turn ([#4884](https://github.com/NousResearch/hermes-agent/pull/4884))
|
||||
- **Active-session guard bypass** for /approve, /deny, /stop, /new ([#4926](https://github.com/NousResearch/hermes-agent/pull/4926), [#5765](https://github.com/NousResearch/hermes-agent/pull/5765))
|
||||
- **Typing indicator paused** during approval waits ([#5893](https://github.com/NousResearch/hermes-agent/pull/5893))
|
||||
- **Caption check** uses exact line-by-line match instead of substring (all platforms) ([#5939](https://github.com/NousResearch/hermes-agent/pull/5939))
|
||||
- **MEDIA: tags stripped** from streamed gateway messages ([#5152](https://github.com/NousResearch/hermes-agent/pull/5152))
|
||||
- **MEDIA: tags extracted** from cron delivery before sending ([#5598](https://github.com/NousResearch/hermes-agent/pull/5598))
|
||||
- **Profile-aware service units** + voice transcription cleanup ([#5972](https://github.com/NousResearch/hermes-agent/pull/5972))
|
||||
- **Thread-safe PairingStore** with atomic writes — @CharlieKerfoot ([#5656](https://github.com/NousResearch/hermes-agent/pull/5656))
|
||||
- **Sanitize media URLs** in base platform logs — @WAXLYY ([#5631](https://github.com/NousResearch/hermes-agent/pull/5631))
|
||||
- **Reduce Telegram fallback IP activation log noise** — @MadKangYu ([#5615](https://github.com/NousResearch/hermes-agent/pull/5615))
|
||||
- **Cron static method wrappers** to prevent self-binding ([#5299](https://github.com/NousResearch/hermes-agent/pull/5299))
|
||||
- **Stale 'hermes login' replaced** with 'hermes auth' + credential removal re-seeding fix ([#5670](https://github.com/NousResearch/hermes-agent/pull/5670))
|
||||
|
||||
### Telegram
|
||||
- **Group topics skill binding** for supergroup forum topics ([#4886](https://github.com/NousResearch/hermes-agent/pull/4886))
|
||||
- **Emoji reactions** for approval status and notifications ([#5975](https://github.com/NousResearch/hermes-agent/pull/5975))
|
||||
- **Duplicate message delivery prevented** on send timeout ([#5153](https://github.com/NousResearch/hermes-agent/pull/5153))
|
||||
- **Command names sanitized** to strip invalid characters ([#5596](https://github.com/NousResearch/hermes-agent/pull/5596))
|
||||
- **Per-platform disabled skills** respected in Telegram menu and gateway dispatch ([#4799](https://github.com/NousResearch/hermes-agent/pull/4799))
|
||||
- **/approve and /deny** routed through running-agent guard ([#4798](https://github.com/NousResearch/hermes-agent/pull/4798))
|
||||
|
||||
### Discord
|
||||
- **Channel controls** — ignored_channels and no_thread_channels config options ([#5975](https://github.com/NousResearch/hermes-agent/pull/5975))
|
||||
- **Skills registered as native slash commands** via shared gateway logic ([#5603](https://github.com/NousResearch/hermes-agent/pull/5603))
|
||||
- **/approve, /deny, /queue, /background, /btw** registered as native slash commands ([#4800](https://github.com/NousResearch/hermes-agent/pull/4800), [#5477](https://github.com/NousResearch/hermes-agent/pull/5477))
|
||||
- **Unnecessary members intent** removed on startup + token lock leak fix ([#5302](https://github.com/NousResearch/hermes-agent/pull/5302))
|
||||
|
||||
### Slack
|
||||
- **Thread engagement** — auto-respond in bot-started and mentioned threads ([#5897](https://github.com/NousResearch/hermes-agent/pull/5897))
|
||||
- **mrkdwn in edit_message** + thread replies without @mentions ([#5733](https://github.com/NousResearch/hermes-agent/pull/5733))
|
||||
|
||||
### Matrix
|
||||
- **Tier 1 feature parity** — reactions, read receipts, rich formatting, room management ([#5275](https://github.com/NousResearch/hermes-agent/pull/5275))
|
||||
- **MATRIX_REQUIRE_MENTION and MATRIX_AUTO_THREAD** support ([#5106](https://github.com/NousResearch/hermes-agent/pull/5106))
|
||||
- **Comprehensive reliability** — encrypted media, auth recovery, cron E2EE, Synapse compat ([#5271](https://github.com/NousResearch/hermes-agent/pull/5271))
|
||||
- **CJK input, E2EE, and reconnect** fixes ([#5665](https://github.com/NousResearch/hermes-agent/pull/5665))
|
||||
|
||||
### Signal
|
||||
- **Full MEDIA: tag delivery** — send_image_file, send_voice, and send_video implemented ([#5602](https://github.com/NousResearch/hermes-agent/pull/5602))
|
||||
|
||||
### Mattermost
|
||||
- **File attachments** — set message type to DOCUMENT when post has file attachments — @nericervin ([#5609](https://github.com/NousResearch/hermes-agent/pull/5609))
|
||||
|
||||
### Feishu
|
||||
- **Interactive card approval buttons** ([#6043](https://github.com/NousResearch/hermes-agent/pull/6043))
|
||||
- **Reconnect and ACL** fixes ([#5665](https://github.com/NousResearch/hermes-agent/pull/5665))
|
||||
|
||||
### Webhooks
|
||||
- **`{__raw__}` template token** and thread_id passthrough for forum topics ([#5662](https://github.com/NousResearch/hermes-agent/pull/5662))
|
||||
|
||||
---
|
||||
|
||||
## 🖥️ CLI & User Experience
|
||||
|
||||
### Interactive CLI
|
||||
- **Defer response content** until reasoning block completes ([#5773](https://github.com/NousResearch/hermes-agent/pull/5773))
|
||||
- **Ghost status-bar lines cleared** on terminal resize ([#4960](https://github.com/NousResearch/hermes-agent/pull/4960))
|
||||
- **Normalise \r\n and \r line endings** in pasted text ([#4849](https://github.com/NousResearch/hermes-agent/pull/4849))
|
||||
- **ChatConsole errors, curses scroll, skin-aware banner, git state** banner fixes ([#5974](https://github.com/NousResearch/hermes-agent/pull/5974))
|
||||
- **Native Windows image paste** support ([#5917](https://github.com/NousResearch/hermes-agent/pull/5917))
|
||||
- **--yolo and other flags** no longer silently dropped when placed before 'chat' subcommand ([#5145](https://github.com/NousResearch/hermes-agent/pull/5145))
|
||||
|
||||
### Setup & Configuration
|
||||
- **Config structure validation** — detect malformed YAML at startup with actionable error messages ([#5426](https://github.com/NousResearch/hermes-agent/pull/5426))
|
||||
- **Centralized logging** to `~/.hermes/logs/` — agent.log (INFO+), errors.log (WARNING+) with `hermes logs` command ([#5430](https://github.com/NousResearch/hermes-agent/pull/5430))
|
||||
- **Docs links added** to setup wizard sections ([#5283](https://github.com/NousResearch/hermes-agent/pull/5283))
|
||||
- **Doctor diagnostics** — sync provider checks, config migration, WAL and mem0 diagnostics ([#5077](https://github.com/NousResearch/hermes-agent/pull/5077))
|
||||
- **Timeout debug logging** and user-facing diagnostics improved ([#5370](https://github.com/NousResearch/hermes-agent/pull/5370))
|
||||
- **Reasoning effort unified** to config.yaml only ([#6118](https://github.com/NousResearch/hermes-agent/pull/6118))
|
||||
- **Permanent command allowlist** loaded on startup ([#5076](https://github.com/NousResearch/hermes-agent/pull/5076))
|
||||
- **`hermes auth remove`** now clears env-seeded credentials permanently ([#5285](https://github.com/NousResearch/hermes-agent/pull/5285))
|
||||
- **Bundled skills synced to all profiles** during update ([#5795](https://github.com/NousResearch/hermes-agent/pull/5795))
|
||||
- **`hermes update` no longer kills** freshly-restarted gateway service ([#5448](https://github.com/NousResearch/hermes-agent/pull/5448))
|
||||
- **Subprocess.run() timeouts** added to all gateway CLI commands ([#5424](https://github.com/NousResearch/hermes-agent/pull/5424))
|
||||
- **Actionable error message** when Codex refresh token is reused — @tymrtn ([#5612](https://github.com/NousResearch/hermes-agent/pull/5612))
|
||||
- **Google-workspace skill scripts** can now run directly — @xinbenlv ([#5624](https://github.com/NousResearch/hermes-agent/pull/5624))
|
||||
|
||||
### Cron System
|
||||
- **Inactivity-based cron timeout** — replaces wall-clock; active tasks run indefinitely ([#5440](https://github.com/NousResearch/hermes-agent/pull/5440))
|
||||
- **Pre-run script injection** for data collection and change detection ([#5082](https://github.com/NousResearch/hermes-agent/pull/5082))
|
||||
- **Delivery failure tracking** in job status ([#6042](https://github.com/NousResearch/hermes-agent/pull/6042))
|
||||
- **Delivery guidance** in cron prompts — stops send_message thrashing ([#5444](https://github.com/NousResearch/hermes-agent/pull/5444))
|
||||
- **MEDIA files delivered** as native platform attachments ([#5921](https://github.com/NousResearch/hermes-agent/pull/5921))
|
||||
- **[SILENT] suppression** works anywhere in response — @auspic7 ([#5654](https://github.com/NousResearch/hermes-agent/pull/5654))
|
||||
- **Cron path traversal** hardening ([#5147](https://github.com/NousResearch/hermes-agent/pull/5147))
|
||||
|
||||
---
|
||||
|
||||
## 🔧 Tool System
|
||||
|
||||
### Terminal & Execution
|
||||
- **Execute_code on remote backends** — code execution now works on Docker, SSH, Modal, and other remote terminal backends ([#5088](https://github.com/NousResearch/hermes-agent/pull/5088))
|
||||
- **Exit code context** for common CLI tools in terminal results — helps agent understand what went wrong ([#5144](https://github.com/NousResearch/hermes-agent/pull/5144))
|
||||
- **Progressive subdirectory hint discovery** — agent learns project structure as it navigates ([#5291](https://github.com/NousResearch/hermes-agent/pull/5291))
|
||||
- **notify_on_complete for background processes** — get notified when long-running tasks finish ([#5779](https://github.com/NousResearch/hermes-agent/pull/5779))
|
||||
- **Docker env config** — explicit container environment variables via docker_env config ([#4738](https://github.com/NousResearch/hermes-agent/pull/4738))
|
||||
- **Approval metadata included** in terminal tool results ([#5141](https://github.com/NousResearch/hermes-agent/pull/5141))
|
||||
- **Workdir parameter sanitized** in terminal tool across all backends ([#5629](https://github.com/NousResearch/hermes-agent/pull/5629))
|
||||
- **Detached process crash recovery** state corrected ([#6101](https://github.com/NousResearch/hermes-agent/pull/6101))
|
||||
- **Agent-browser paths with spaces** preserved — @Vasanthdev2004 ([#6077](https://github.com/NousResearch/hermes-agent/pull/6077))
|
||||
- **Portable base64 encoding** for image reading on macOS — @CharlieKerfoot ([#5657](https://github.com/NousResearch/hermes-agent/pull/5657))
|
||||
|
||||
### Browser
|
||||
- **Switch managed browser provider** from Browserbase to Browser Use — @benbarclay ([#5750](https://github.com/NousResearch/hermes-agent/pull/5750))
|
||||
- **Firecrawl cloud browser** provider — @alt-glitch ([#5628](https://github.com/NousResearch/hermes-agent/pull/5628))
|
||||
- **JS evaluation** via browser_console expression parameter ([#5303](https://github.com/NousResearch/hermes-agent/pull/5303))
|
||||
- **Windows browser** fixes ([#5665](https://github.com/NousResearch/hermes-agent/pull/5665))
|
||||
|
||||
### MCP
|
||||
- **MCP OAuth 2.1 PKCE** — full standards-compliant OAuth client support ([#5420](https://github.com/NousResearch/hermes-agent/pull/5420))
|
||||
- **OSV malware check** for MCP extension packages ([#5305](https://github.com/NousResearch/hermes-agent/pull/5305))
|
||||
- **Prefer structuredContent over text** + no_mcp sentinel ([#5979](https://github.com/NousResearch/hermes-agent/pull/5979))
|
||||
- **Unknown toolsets warning suppressed** for MCP server names ([#5279](https://github.com/NousResearch/hermes-agent/pull/5279))
|
||||
|
||||
### Web & Files
|
||||
- **.zip document support** + auto-mount cache dirs into remote backends ([#4846](https://github.com/NousResearch/hermes-agent/pull/4846))
|
||||
- **Redact query secrets** in send_message errors — @WAXLYY ([#5650](https://github.com/NousResearch/hermes-agent/pull/5650))
|
||||
|
||||
### Delegation
|
||||
- **Credential pool sharing** + workspace path hints for subagents ([#5748](https://github.com/NousResearch/hermes-agent/pull/5748))
|
||||
|
||||
### ACP (VS Code / Zed / JetBrains)
|
||||
- **Aggregate ACP improvements** — auth compat, protocol fixes, command ads, delegation, SSE events ([#5292](https://github.com/NousResearch/hermes-agent/pull/5292))
|
||||
|
||||
---
|
||||
|
||||
## 🧩 Skills Ecosystem
|
||||
|
||||
### Skills System
|
||||
- **Skill config interface** — skills can declare required config.yaml settings, prompted during setup, injected at load time ([#5635](https://github.com/NousResearch/hermes-agent/pull/5635))
|
||||
- **Plugin CLI registration system** — plugins register their own CLI subcommands without touching main.py ([#5295](https://github.com/NousResearch/hermes-agent/pull/5295))
|
||||
- **Request-scoped API hooks** with tool call correlation IDs for plugins ([#5427](https://github.com/NousResearch/hermes-agent/pull/5427))
|
||||
- **Session lifecycle hooks** — on_session_finalize and on_session_reset for CLI + gateway ([#6129](https://github.com/NousResearch/hermes-agent/pull/6129))
|
||||
- **Prompt for required env vars** during plugin install — @kshitijk4poor ([#5470](https://github.com/NousResearch/hermes-agent/pull/5470))
|
||||
- **Plugin name validation** — reject names that resolve to plugins root ([#5368](https://github.com/NousResearch/hermes-agent/pull/5368))
|
||||
- **pre_llm_call plugin context** moved to user message to preserve prompt cache ([#5146](https://github.com/NousResearch/hermes-agent/pull/5146))
|
||||
|
||||
### New & Updated Skills
|
||||
- **popular-web-designs** — 54 production website design systems ([#5194](https://github.com/NousResearch/hermes-agent/pull/5194))
|
||||
- **p5js creative coding** — @SHL0MS ([#5600](https://github.com/NousResearch/hermes-agent/pull/5600))
|
||||
- **manim-video** — mathematical and technical animations — @SHL0MS ([#4930](https://github.com/NousResearch/hermes-agent/pull/4930))
|
||||
- **llm-wiki** — Karpathy's LLM Wiki skill ([#5635](https://github.com/NousResearch/hermes-agent/pull/5635))
|
||||
- **gitnexus-explorer** — codebase indexing and knowledge serving ([#5208](https://github.com/NousResearch/hermes-agent/pull/5208))
|
||||
- **research-paper-writing** — AI-Scientist & GPT-Researcher patterns — @SHL0MS ([#5421](https://github.com/NousResearch/hermes-agent/pull/5421))
|
||||
- **blogwatcher** updated to JulienTant's fork ([#5759](https://github.com/NousResearch/hermes-agent/pull/5759))
|
||||
- **claude-code skill** comprehensive rewrite v2.0 + v2.2 ([#5155](https://github.com/NousResearch/hermes-agent/pull/5155), [#5158](https://github.com/NousResearch/hermes-agent/pull/5158))
|
||||
- **Code verification skills** consolidated into one ([#4854](https://github.com/NousResearch/hermes-agent/pull/4854))
|
||||
- **Manim CE reference docs** expanded — geometry, animations, LaTeX — @leotrs ([#5791](https://github.com/NousResearch/hermes-agent/pull/5791))
|
||||
- **Manim-video references** — design thinking, updaters, paper explainer, decorations, production quality — @SHL0MS ([#5588](https://github.com/NousResearch/hermes-agent/pull/5588), [#5408](https://github.com/NousResearch/hermes-agent/pull/5408))
|
||||
|
||||
---
|
||||
|
||||
## 🔒 Security & Reliability
|
||||
|
||||
### Security Hardening
|
||||
- **Consolidated security** — SSRF protections, timing attack mitigations, tar traversal prevention, credential leakage guards ([#5944](https://github.com/NousResearch/hermes-agent/pull/5944))
|
||||
- **Cross-session isolation** + cron path traversal hardening ([#5613](https://github.com/NousResearch/hermes-agent/pull/5613))
|
||||
- **Workdir parameter sanitized** in terminal tool across all backends ([#5629](https://github.com/NousResearch/hermes-agent/pull/5629))
|
||||
- **Approval 'once' session escalation** prevented + cron delivery platform validation ([#5280](https://github.com/NousResearch/hermes-agent/pull/5280))
|
||||
- **Profile-scoped Google Workspace OAuth tokens** protected ([#4910](https://github.com/NousResearch/hermes-agent/pull/4910))
|
||||
|
||||
### Reliability
|
||||
- **Aggressive worktree and branch cleanup** to prevent accumulation ([#6134](https://github.com/NousResearch/hermes-agent/pull/6134))
|
||||
- **O(n²) catastrophic backtracking** in redact regex fixed — 100x improvement on large outputs ([#4962](https://github.com/NousResearch/hermes-agent/pull/4962))
|
||||
- **Runtime stability fixes** across core, web, delegate, and browser tools ([#4843](https://github.com/NousResearch/hermes-agent/pull/4843))
|
||||
- **API server streaming fix** + conversation history support ([#5977](https://github.com/NousResearch/hermes-agent/pull/5977))
|
||||
- **OpenViking API endpoint paths** and response parsing corrected ([#5078](https://github.com/NousResearch/hermes-agent/pull/5078))
|
||||
|
||||
---
|
||||
|
||||
## 🐛 Notable Bug Fixes
|
||||
|
||||
- **9 community bugfixes salvaged** — gateway, cron, deps, macOS launchd in one batch ([#5288](https://github.com/NousResearch/hermes-agent/pull/5288))
|
||||
- **Batch core bug fixes** — model config, session reset, alias fallback, launchctl, delegation, atomic writes ([#5630](https://github.com/NousResearch/hermes-agent/pull/5630))
|
||||
- **Batch gateway/platform fixes** — matrix E2EE, CJK input, Windows browser, Feishu reconnect + ACL ([#5665](https://github.com/NousResearch/hermes-agent/pull/5665))
|
||||
- **Stale test skips removed**, regex backtracking, file search bug, and test flakiness ([#4969](https://github.com/NousResearch/hermes-agent/pull/4969))
|
||||
- **Nix flake** — read version, regen uv.lock, add hermes_logging — @alt-glitch ([#5651](https://github.com/NousResearch/hermes-agent/pull/5651))
|
||||
- **Lowercase variable redaction** regression tests ([#5185](https://github.com/NousResearch/hermes-agent/pull/5185))
|
||||
|
||||
---
|
||||
|
||||
## 🧪 Testing
|
||||
|
||||
- **57 failing CI tests repaired** across 14 files ([#5823](https://github.com/NousResearch/hermes-agent/pull/5823))
|
||||
- **Test suite re-architecture** + CI failure fixes — @alt-glitch ([#5946](https://github.com/NousResearch/hermes-agent/pull/5946))
|
||||
- **Codebase-wide lint cleanup** — unused imports, dead code, and inefficient patterns ([#5821](https://github.com/NousResearch/hermes-agent/pull/5821))
|
||||
- **browser_close tool removed** — auto-cleanup handles it ([#5792](https://github.com/NousResearch/hermes-agent/pull/5792))
|
||||
|
||||
---
|
||||
|
||||
## 📚 Documentation
|
||||
|
||||
- **Comprehensive documentation audit** — fix stale info, expand thin pages, add depth ([#5393](https://github.com/NousResearch/hermes-agent/pull/5393))
|
||||
- **40+ discrepancies fixed** between documentation and codebase ([#5818](https://github.com/NousResearch/hermes-agent/pull/5818))
|
||||
- **13 features documented** from last week's PRs ([#5815](https://github.com/NousResearch/hermes-agent/pull/5815))
|
||||
- **Guides section overhaul** — fix existing + add 3 new tutorials ([#5735](https://github.com/NousResearch/hermes-agent/pull/5735))
|
||||
- **Salvaged 4 docs PRs** — docker setup, post-update validation, local LLM guide, signal-cli install ([#5727](https://github.com/NousResearch/hermes-agent/pull/5727))
|
||||
- **Discord configuration reference** ([#5386](https://github.com/NousResearch/hermes-agent/pull/5386))
|
||||
- **Community FAQ entries** for common workflows and troubleshooting ([#4797](https://github.com/NousResearch/hermes-agent/pull/4797))
|
||||
- **WSL2 networking guide** for local model servers ([#5616](https://github.com/NousResearch/hermes-agent/pull/5616))
|
||||
- **Honcho CLI reference** + plugin CLI registration docs ([#5308](https://github.com/NousResearch/hermes-agent/pull/5308))
|
||||
- **Obsidian Headless setup** for servers in llm-wiki ([#5660](https://github.com/NousResearch/hermes-agent/pull/5660))
|
||||
- **Hermes Mod visual skin editor** added to skins page ([#6095](https://github.com/NousResearch/hermes-agent/pull/6095))
|
||||
|
||||
---
|
||||
|
||||
## 👥 Contributors
|
||||
|
||||
### Core
|
||||
- **@teknium1** — 179 PRs
|
||||
|
||||
### Top Community Contributors
|
||||
- **@SHL0MS** (7 PRs) — p5js creative coding skill, manim-video skill + 5 reference expansions, research-paper-writing, Nous OAuth fix, manim font fix
|
||||
- **@alt-glitch** (3 PRs) — Firecrawl cloud browser provider, test re-architecture + CI fixes, Nix flake fixes
|
||||
- **@benbarclay** (2 PRs) — Browser Use managed provider switch, Nous portal base URL fix
|
||||
- **@CharlieKerfoot** (2 PRs) — macOS portable base64 encoding, thread-safe PairingStore
|
||||
- **@WAXLYY** (2 PRs) — send_message secret redaction, gateway media URL sanitization
|
||||
- **@MadKangYu** (2 PRs) — Telegram log noise reduction, context compaction fix for temperature-restricted models
|
||||
|
||||
### All Contributors
|
||||
@alt-glitch, @austinpickett, @auspic7, @benbarclay, @CharlieKerfoot, @GratefulDave, @kshitijk4poor, @leotrs, @lumethegreat, @MadKangYu, @nericervin, @ryanautomated, @SHL0MS, @techguysimon, @tymrtn, @Vasanthdev2004, @WAXLYY, @xinbenlv
|
||||
|
||||
---
|
||||
|
||||
**Full Changelog**: [v2026.4.3...v2026.4.8](https://github.com/NousResearch/hermes-agent/compare/v2026.4.3...v2026.4.8)
|
||||
+117
-12
@@ -163,6 +163,17 @@ def _is_oauth_token(key: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _normalize_base_url_text(base_url) -> str:
|
||||
"""Normalize SDK/base transport URL values to a plain string for inspection.
|
||||
|
||||
Some client objects expose ``base_url`` as an ``httpx.URL`` instead of a raw
|
||||
string. Provider/auth detection should accept either shape.
|
||||
"""
|
||||
if not base_url:
|
||||
return ""
|
||||
return str(base_url).strip()
|
||||
|
||||
|
||||
def _is_third_party_anthropic_endpoint(base_url: str | None) -> bool:
|
||||
"""Return True for non-Anthropic endpoints using the Anthropic Messages API.
|
||||
|
||||
@@ -170,9 +181,10 @@ def _is_third_party_anthropic_endpoint(base_url: str | None) -> bool:
|
||||
with their own API keys via x-api-key, not Anthropic OAuth tokens. OAuth
|
||||
detection should be skipped for these endpoints.
|
||||
"""
|
||||
if not base_url:
|
||||
normalized = _normalize_base_url_text(base_url)
|
||||
if not normalized:
|
||||
return False # No base_url = direct Anthropic API
|
||||
normalized = base_url.rstrip("/").lower()
|
||||
normalized = normalized.rstrip("/").lower()
|
||||
if "anthropic.com" in normalized:
|
||||
return False # Direct Anthropic API — OAuth applies
|
||||
return True # Any other endpoint is a third-party proxy
|
||||
@@ -182,12 +194,13 @@ def _requires_bearer_auth(base_url: str | None) -> bool:
|
||||
"""Return True for Anthropic-compatible providers that require Bearer auth.
|
||||
|
||||
Some third-party /anthropic endpoints implement Anthropic's Messages API but
|
||||
require Authorization: Bearer instead of Anthropic's native x-api-key header.
|
||||
require Authorization: Bearer *** of Anthropic's native x-api-key header.
|
||||
MiniMax's global and China Anthropic-compatible endpoints follow this pattern.
|
||||
"""
|
||||
if not base_url:
|
||||
normalized = _normalize_base_url_text(base_url)
|
||||
if not normalized:
|
||||
return False
|
||||
normalized = base_url.rstrip("/").lower()
|
||||
normalized = normalized.rstrip("/").lower()
|
||||
return normalized.startswith(("https://api.minimax.io/anthropic", "https://api.minimaxi.com/anthropic"))
|
||||
|
||||
|
||||
@@ -203,13 +216,14 @@ def build_anthropic_client(api_key: str, base_url: str = None):
|
||||
)
|
||||
from httpx import Timeout
|
||||
|
||||
normalized_base_url = _normalize_base_url_text(base_url)
|
||||
kwargs = {
|
||||
"timeout": Timeout(timeout=900.0, connect=10.0),
|
||||
}
|
||||
if base_url:
|
||||
kwargs["base_url"] = base_url
|
||||
if normalized_base_url:
|
||||
kwargs["base_url"] = normalized_base_url
|
||||
|
||||
if _requires_bearer_auth(base_url):
|
||||
if _requires_bearer_auth(normalized_base_url):
|
||||
# Some Anthropic-compatible providers (e.g. MiniMax) expect the API key in
|
||||
# Authorization: Bearer even for regular API keys. Route those endpoints
|
||||
# through auth_token so the SDK sends Bearer auth instead of x-api-key.
|
||||
@@ -942,12 +956,18 @@ def _convert_content_to_anthropic(content: Any) -> Any:
|
||||
|
||||
def convert_messages_to_anthropic(
|
||||
messages: List[Dict],
|
||||
base_url: str | None = None,
|
||||
) -> Tuple[Optional[Any], List[Dict]]:
|
||||
"""Convert OpenAI-format messages to Anthropic format.
|
||||
|
||||
Returns (system_prompt, anthropic_messages).
|
||||
System messages are extracted since Anthropic takes them as a separate param.
|
||||
system_prompt is a string or list of content blocks (when cache_control present).
|
||||
|
||||
When *base_url* is provided and points to a third-party Anthropic-compatible
|
||||
endpoint, all thinking block signatures are stripped. Signatures are
|
||||
Anthropic-proprietary — third-party endpoints cannot validate them and will
|
||||
reject them with HTTP 400 "Invalid signature in thinking block".
|
||||
"""
|
||||
system = None
|
||||
result = []
|
||||
@@ -1102,7 +1122,15 @@ def convert_messages_to_anthropic(
|
||||
curr_content = [{"type": "text", "text": curr_content}]
|
||||
fixed[-1]["content"] = prev_content + curr_content
|
||||
else:
|
||||
# Consecutive assistant messages — merge text content
|
||||
# Consecutive assistant messages — merge text content.
|
||||
# Drop thinking blocks from the *second* message: their
|
||||
# signature was computed against a different turn boundary
|
||||
# and becomes invalid once merged.
|
||||
if isinstance(m["content"], list):
|
||||
m["content"] = [
|
||||
b for b in m["content"]
|
||||
if not (isinstance(b, dict) and b.get("type") in ("thinking", "redacted_thinking"))
|
||||
]
|
||||
prev_blocks = fixed[-1]["content"]
|
||||
curr_blocks = m["content"]
|
||||
if isinstance(prev_blocks, list) and isinstance(curr_blocks, list):
|
||||
@@ -1120,6 +1148,79 @@ def convert_messages_to_anthropic(
|
||||
fixed.append(m)
|
||||
result = fixed
|
||||
|
||||
# ── Thinking block signature management ──────────────────────────
|
||||
# Anthropic signs thinking blocks against the full turn content.
|
||||
# Any upstream mutation (context compression, session truncation,
|
||||
# orphan stripping, message merging) invalidates the signature,
|
||||
# causing HTTP 400 "Invalid signature in thinking block".
|
||||
#
|
||||
# Signatures are Anthropic-proprietary. Third-party endpoints
|
||||
# (MiniMax, Azure AI Foundry, self-hosted proxies) cannot validate
|
||||
# them and will reject them outright. When targeting a third-party
|
||||
# endpoint, strip ALL thinking/redacted_thinking blocks from every
|
||||
# assistant message — the third-party will generate its own
|
||||
# thinking blocks if it supports extended thinking.
|
||||
#
|
||||
# For direct Anthropic (strategy following clawdbot/OpenClaw):
|
||||
# 1. Strip thinking/redacted_thinking from all assistant messages
|
||||
# EXCEPT the last one — preserves reasoning continuity on the
|
||||
# current tool-use chain while avoiding stale signature errors.
|
||||
# 2. Downgrade unsigned thinking blocks (no signature) to text —
|
||||
# Anthropic can't validate them and will reject them.
|
||||
# 3. Strip cache_control from thinking/redacted_thinking blocks —
|
||||
# cache markers can interfere with signature validation.
|
||||
_THINKING_TYPES = frozenset(("thinking", "redacted_thinking"))
|
||||
_is_third_party = _is_third_party_anthropic_endpoint(base_url)
|
||||
|
||||
last_assistant_idx = None
|
||||
for i in range(len(result) - 1, -1, -1):
|
||||
if result[i].get("role") == "assistant":
|
||||
last_assistant_idx = i
|
||||
break
|
||||
|
||||
for idx, m in enumerate(result):
|
||||
if m.get("role") != "assistant" or not isinstance(m.get("content"), list):
|
||||
continue
|
||||
|
||||
if _is_third_party or idx != last_assistant_idx:
|
||||
# Third-party endpoint: strip ALL thinking blocks from every
|
||||
# assistant message — signatures are Anthropic-proprietary.
|
||||
# Direct Anthropic: strip from non-latest assistant messages only.
|
||||
stripped = [
|
||||
b for b in m["content"]
|
||||
if not (isinstance(b, dict) and b.get("type") in _THINKING_TYPES)
|
||||
]
|
||||
m["content"] = stripped or [{"type": "text", "text": "(thinking elided)"}]
|
||||
else:
|
||||
# Latest assistant on direct Anthropic: keep signed thinking
|
||||
# blocks for reasoning continuity; downgrade unsigned ones to
|
||||
# plain text.
|
||||
new_content = []
|
||||
for b in m["content"]:
|
||||
if not isinstance(b, dict) or b.get("type") not in _THINKING_TYPES:
|
||||
new_content.append(b)
|
||||
continue
|
||||
if b.get("type") == "redacted_thinking":
|
||||
# Redacted blocks use 'data' for the signature payload
|
||||
if b.get("data"):
|
||||
new_content.append(b)
|
||||
# else: drop — no data means it can't be validated
|
||||
elif b.get("signature"):
|
||||
# Signed thinking block — keep it
|
||||
new_content.append(b)
|
||||
else:
|
||||
# Unsigned thinking — downgrade to text so it's not lost
|
||||
thinking_text = b.get("thinking", "")
|
||||
if thinking_text:
|
||||
new_content.append({"type": "text", "text": thinking_text})
|
||||
m["content"] = new_content or [{"type": "text", "text": "(empty)"}]
|
||||
|
||||
# Strip cache_control from any remaining thinking/redacted_thinking
|
||||
# blocks — cache markers interfere with signature validation.
|
||||
for b in m["content"]:
|
||||
if isinstance(b, dict) and b.get("type") in _THINKING_TYPES:
|
||||
b.pop("cache_control", None)
|
||||
|
||||
return system, result
|
||||
|
||||
|
||||
@@ -1133,6 +1234,7 @@ def build_anthropic_kwargs(
|
||||
is_oauth: bool = False,
|
||||
preserve_dots: bool = False,
|
||||
context_length: Optional[int] = None,
|
||||
base_url: str | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build kwargs for anthropic.messages.create().
|
||||
|
||||
@@ -1146,8 +1248,11 @@ def build_anthropic_kwargs(
|
||||
|
||||
When *preserve_dots* is True, model name dots are not converted to hyphens
|
||||
(for Alibaba/DashScope anthropic-compatible endpoints: qwen3.5-plus).
|
||||
|
||||
When *base_url* points to a third-party Anthropic-compatible endpoint,
|
||||
thinking block signatures are stripped (they are Anthropic-proprietary).
|
||||
"""
|
||||
system, anthropic_messages = convert_messages_to_anthropic(messages)
|
||||
system, anthropic_messages = convert_messages_to_anthropic(messages, base_url=base_url)
|
||||
anthropic_tools = convert_tools_to_anthropic(tools) if tools else []
|
||||
|
||||
model = normalize_model_name(model, preserve_dots=preserve_dots)
|
||||
@@ -1224,9 +1329,9 @@ def build_anthropic_kwargs(
|
||||
# Map reasoning_config to Anthropic's thinking parameter.
|
||||
# Claude 4.6 models use adaptive thinking + output_config.effort.
|
||||
# Older models use manual thinking with budget_tokens.
|
||||
# Haiku models do NOT support extended thinking at all — skip entirely.
|
||||
# Haiku and MiniMax models do NOT support extended thinking — skip entirely.
|
||||
if reasoning_config and isinstance(reasoning_config, dict):
|
||||
if reasoning_config.get("enabled") is not False and "haiku" not in model.lower():
|
||||
if reasoning_config.get("enabled") is not False and "haiku" not in model.lower() and "minimax" not in model.lower():
|
||||
effort = str(reasoning_config.get("effort", "medium")).lower()
|
||||
budget = THINKING_BUDGET.get(effort, 8000)
|
||||
if _supports_adaptive_thinking(model):
|
||||
|
||||
+120
-51
@@ -59,13 +59,48 @@ from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PROVIDER_ALIASES = {
|
||||
"google": "gemini",
|
||||
"google-gemini": "gemini",
|
||||
"google-ai-studio": "gemini",
|
||||
"glm": "zai",
|
||||
"z-ai": "zai",
|
||||
"z.ai": "zai",
|
||||
"zhipu": "zai",
|
||||
"kimi": "kimi-coding",
|
||||
"moonshot": "kimi-coding",
|
||||
"minimax-china": "minimax-cn",
|
||||
"minimax_cn": "minimax-cn",
|
||||
"claude": "anthropic",
|
||||
"claude-code": "anthropic",
|
||||
}
|
||||
|
||||
|
||||
def _normalize_aux_provider(provider: Optional[str], *, for_vision: bool = False) -> str:
|
||||
normalized = (provider or "auto").strip().lower()
|
||||
if normalized.startswith("custom:"):
|
||||
suffix = normalized.split(":", 1)[1].strip()
|
||||
if not suffix:
|
||||
return "custom"
|
||||
normalized = suffix if not for_vision else "custom"
|
||||
if normalized == "codex":
|
||||
return "openai-codex"
|
||||
if normalized == "main":
|
||||
# Resolve to the user's actual main provider so named custom providers
|
||||
# and non-aggregator providers (DeepSeek, Alibaba, etc.) work correctly.
|
||||
main_prov = _read_main_provider()
|
||||
if main_prov and main_prov not in ("auto", "main", ""):
|
||||
return main_prov
|
||||
return "custom"
|
||||
return _PROVIDER_ALIASES.get(normalized, normalized)
|
||||
|
||||
# Default auxiliary models for direct API-key providers (cheap/fast for side tasks)
|
||||
_API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = {
|
||||
"gemini": "gemini-3-flash-preview",
|
||||
"zai": "glm-4.5-flash",
|
||||
"kimi-coding": "kimi-k2-turbo-preview",
|
||||
"minimax": "MiniMax-M2.7-highspeed",
|
||||
"minimax-cn": "MiniMax-M2.7-highspeed",
|
||||
"minimax": "MiniMax-M2.7",
|
||||
"minimax-cn": "MiniMax-M2.7",
|
||||
"anthropic": "claude-haiku-4-5-20251001",
|
||||
"ai-gateway": "google/gemini-3-flash",
|
||||
"opencode-zen": "gemini-3-flash",
|
||||
@@ -92,6 +127,7 @@ auxiliary_is_nous: bool = False
|
||||
_OPENROUTER_MODEL = "google/gemini-3-flash-preview"
|
||||
_NOUS_MODEL = "google/gemini-3-flash-preview"
|
||||
_NOUS_FREE_TIER_VISION_MODEL = "xiaomi/mimo-v2-omni"
|
||||
_NOUS_FREE_TIER_AUX_MODEL = "xiaomi/mimo-v2-pro"
|
||||
_NOUS_DEFAULT_BASE_URL = "https://inference-api.nousresearch.com/v1"
|
||||
_ANTHROPIC_DEFAULT_BASE_URL = "https://api.anthropic.com"
|
||||
_AUTH_JSON_PATH = get_hermes_home() / "auth.json"
|
||||
@@ -105,6 +141,23 @@ _CODEX_AUX_MODEL = "gpt-5.2-codex"
|
||||
_CODEX_AUX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
|
||||
|
||||
def _to_openai_base_url(base_url: str) -> str:
|
||||
"""Normalize an Anthropic-style base URL to OpenAI-compatible format.
|
||||
|
||||
Some providers (MiniMax, MiniMax-CN) expose an ``/anthropic`` endpoint for
|
||||
the Anthropic Messages API and a separate ``/v1`` endpoint for OpenAI chat
|
||||
completions. The auxiliary client uses the OpenAI SDK, so it must hit the
|
||||
``/v1`` surface. Passing the raw ``inference_base_url`` causes requests to
|
||||
land on ``/anthropic/chat/completions`` — a 404.
|
||||
"""
|
||||
url = str(base_url or "").strip().rstrip("/")
|
||||
if url.endswith("/anthropic"):
|
||||
rewritten = url[: -len("/anthropic")] + "/v1"
|
||||
logger.debug("Auxiliary client: rewrote base URL %s → %s", url, rewritten)
|
||||
return rewritten
|
||||
return url
|
||||
|
||||
|
||||
def _select_pool_entry(provider: str) -> Tuple[bool, Optional[Any]]:
|
||||
"""Return (pool_exists_for_provider, selected_entry)."""
|
||||
try:
|
||||
@@ -634,7 +687,9 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
if not api_key:
|
||||
continue
|
||||
|
||||
base_url = _pool_runtime_base_url(entry, pconfig.inference_base_url) or pconfig.inference_base_url
|
||||
base_url = _to_openai_base_url(
|
||||
_pool_runtime_base_url(entry, pconfig.inference_base_url) or pconfig.inference_base_url
|
||||
)
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id, "default")
|
||||
logger.debug("Auxiliary text client: %s (%s) via pool", pconfig.name, model)
|
||||
extra = {}
|
||||
@@ -651,7 +706,9 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
if not api_key:
|
||||
continue
|
||||
|
||||
base_url = str(creds.get("base_url", "")).strip().rstrip("/") or pconfig.inference_base_url
|
||||
base_url = _to_openai_base_url(
|
||||
str(creds.get("base_url", "")).strip().rstrip("/") or pconfig.inference_base_url
|
||||
)
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id, "default")
|
||||
logger.debug("Auxiliary text client: %s (%s)", pconfig.name, model)
|
||||
extra = {}
|
||||
@@ -713,7 +770,7 @@ def _try_openrouter() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
default_headers=_OR_HEADERS), _OPENROUTER_MODEL
|
||||
|
||||
|
||||
def _try_nous() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
def _try_nous(vision: bool = False) -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
nous = _read_nous_auth()
|
||||
if not nous:
|
||||
return None, None
|
||||
@@ -725,12 +782,13 @@ def _try_nous() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
else:
|
||||
model = _NOUS_MODEL
|
||||
# Free-tier users can't use paid auxiliary models — use the free
|
||||
# multimodal model instead so vision/browser-vision still works.
|
||||
# models instead: mimo-v2-omni for vision, mimo-v2-pro for text tasks.
|
||||
try:
|
||||
from hermes_cli.models import check_nous_free_tier
|
||||
if check_nous_free_tier():
|
||||
model = _NOUS_FREE_TIER_VISION_MODEL
|
||||
logger.debug("Free-tier Nous account — using %s for auxiliary/vision", model)
|
||||
model = _NOUS_FREE_TIER_VISION_MODEL if vision else _NOUS_FREE_TIER_AUX_MODEL
|
||||
logger.debug("Free-tier Nous account — using %s for auxiliary/%s",
|
||||
model, "vision" if vision else "text")
|
||||
except Exception:
|
||||
pass
|
||||
return (
|
||||
@@ -1138,17 +1196,7 @@ def resolve_provider_client(
|
||||
(client, resolved_model) or (None, None) if auth is unavailable.
|
||||
"""
|
||||
# Normalise aliases
|
||||
provider = (provider or "auto").strip().lower()
|
||||
if provider == "codex":
|
||||
provider = "openai-codex"
|
||||
if provider == "main":
|
||||
# Resolve to the user's actual main provider so named custom providers
|
||||
# and non-aggregator providers (DeepSeek, Alibaba, etc.) work correctly.
|
||||
main_prov = _read_main_provider()
|
||||
if main_prov and main_prov not in ("auto", "main", ""):
|
||||
provider = main_prov
|
||||
else:
|
||||
provider = "custom"
|
||||
provider = _normalize_aux_provider(provider)
|
||||
|
||||
# ── Auto: try all providers in priority order ────────────────────
|
||||
if provider == "auto":
|
||||
@@ -1298,7 +1346,9 @@ def resolve_provider_client(
|
||||
provider, ", ".join(tried_sources))
|
||||
return None, None
|
||||
|
||||
base_url = str(creds.get("base_url", "")).strip().rstrip("/") or pconfig.inference_base_url
|
||||
base_url = _to_openai_base_url(
|
||||
str(creds.get("base_url", "")).strip().rstrip("/") or pconfig.inference_base_url
|
||||
)
|
||||
|
||||
default_model = _API_KEY_PROVIDER_AUX_MODELS.get(provider, "")
|
||||
final_model = model or default_model
|
||||
@@ -1375,24 +1425,11 @@ def get_async_text_auxiliary_client(task: str = ""):
|
||||
_VISION_AUTO_PROVIDER_ORDER = (
|
||||
"openrouter",
|
||||
"nous",
|
||||
"openai-codex",
|
||||
"anthropic",
|
||||
"custom",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_vision_provider(provider: Optional[str]) -> str:
|
||||
provider = (provider or "auto").strip().lower()
|
||||
if provider == "codex":
|
||||
return "openai-codex"
|
||||
if provider == "main":
|
||||
# Resolve to actual main provider — named custom providers and
|
||||
# non-aggregator providers need to pass through as their real name.
|
||||
main_prov = _read_main_provider()
|
||||
if main_prov and main_prov not in ("auto", "main", ""):
|
||||
return main_prov
|
||||
return "custom"
|
||||
return provider
|
||||
return _normalize_aux_provider(provider, for_vision=True)
|
||||
|
||||
|
||||
def _resolve_strict_vision_backend(provider: str) -> Tuple[Optional[Any], Optional[str]]:
|
||||
@@ -1400,7 +1437,7 @@ def _resolve_strict_vision_backend(provider: str) -> Tuple[Optional[Any], Option
|
||||
if provider == "openrouter":
|
||||
return _try_openrouter()
|
||||
if provider == "nous":
|
||||
return _try_nous()
|
||||
return _try_nous(vision=True)
|
||||
if provider == "openai-codex":
|
||||
return _try_codex()
|
||||
if provider == "anthropic":
|
||||
@@ -1433,17 +1470,26 @@ def _preferred_main_vision_provider() -> Optional[str]:
|
||||
def get_available_vision_backends() -> List[str]:
|
||||
"""Return the currently available vision backends in auto-selection order.
|
||||
|
||||
This is the single source of truth for setup, tool gating, and runtime
|
||||
auto-routing of vision tasks. The selected main provider is preferred when
|
||||
it is also a known-good vision backend; otherwise Hermes falls back through
|
||||
the standard conservative order.
|
||||
Order: active provider → OpenRouter → Nous → stop. This is the single
|
||||
source of truth for setup, tool gating, and runtime auto-routing of
|
||||
vision tasks.
|
||||
"""
|
||||
ordered = list(_VISION_AUTO_PROVIDER_ORDER)
|
||||
preferred = _preferred_main_vision_provider()
|
||||
if preferred in ordered:
|
||||
ordered.remove(preferred)
|
||||
ordered.insert(0, preferred)
|
||||
return [provider for provider in ordered if _strict_vision_backend_available(provider)]
|
||||
available: List[str] = []
|
||||
# 1. Active provider — if the user configured a provider, try it first.
|
||||
main_provider = _read_main_provider()
|
||||
if main_provider and main_provider not in ("auto", ""):
|
||||
if main_provider in _VISION_AUTO_PROVIDER_ORDER:
|
||||
if _strict_vision_backend_available(main_provider):
|
||||
available.append(main_provider)
|
||||
else:
|
||||
client, _ = resolve_provider_client(main_provider, _read_main_model())
|
||||
if client is not None:
|
||||
available.append(main_provider)
|
||||
# 2. OpenRouter, 3. Nous — skip if already covered by main provider.
|
||||
for p in _VISION_AUTO_PROVIDER_ORDER:
|
||||
if p not in available and _strict_vision_backend_available(p):
|
||||
available.append(p)
|
||||
return available
|
||||
|
||||
|
||||
def resolve_vision_provider_client(
|
||||
@@ -1488,16 +1534,39 @@ def resolve_vision_provider_client(
|
||||
return "custom", client, final_model
|
||||
|
||||
if requested == "auto":
|
||||
ordered = list(_VISION_AUTO_PROVIDER_ORDER)
|
||||
preferred = _preferred_main_vision_provider()
|
||||
if preferred in ordered:
|
||||
ordered.remove(preferred)
|
||||
ordered.insert(0, preferred)
|
||||
# Vision auto-detection order:
|
||||
# 1. Active provider + model (user's main chat config)
|
||||
# 2. OpenRouter (known vision-capable default model)
|
||||
# 3. Nous Portal (known vision-capable default model)
|
||||
# 4. Stop
|
||||
main_provider = _read_main_provider()
|
||||
main_model = _read_main_model()
|
||||
if main_provider and main_provider not in ("auto", ""):
|
||||
if main_provider in _VISION_AUTO_PROVIDER_ORDER:
|
||||
# Known strict backend — use its defaults.
|
||||
sync_client, default_model = _resolve_strict_vision_backend(main_provider)
|
||||
if sync_client is not None:
|
||||
return _finalize(main_provider, sync_client, default_model)
|
||||
else:
|
||||
# Exotic provider (DeepSeek, Alibaba, named custom, etc.)
|
||||
rpc_client, rpc_model = resolve_provider_client(
|
||||
main_provider, main_model)
|
||||
if rpc_client is not None:
|
||||
logger.info(
|
||||
"Vision auto-detect: using active provider %s (%s)",
|
||||
main_provider, rpc_model or main_model,
|
||||
)
|
||||
return _finalize(
|
||||
main_provider, rpc_client, rpc_model or main_model)
|
||||
|
||||
for candidate in ordered:
|
||||
# Fall back through aggregators.
|
||||
for candidate in _VISION_AUTO_PROVIDER_ORDER:
|
||||
if candidate == main_provider:
|
||||
continue # already tried above
|
||||
sync_client, default_model = _resolve_strict_vision_backend(candidate)
|
||||
if sync_client is not None:
|
||||
return _finalize(candidate, sync_client, default_model)
|
||||
|
||||
logger.debug("Auxiliary vision client: none available")
|
||||
return None, None, None
|
||||
|
||||
|
||||
+66
-3
@@ -26,12 +26,14 @@ _PROVIDER_PREFIXES: frozenset[str] = frozenset({
|
||||
"openrouter", "nous", "openai-codex", "copilot", "copilot-acp",
|
||||
"gemini", "zai", "kimi-coding", "minimax", "minimax-cn", "anthropic", "deepseek",
|
||||
"opencode-zen", "opencode-go", "ai-gateway", "kilocode", "alibaba",
|
||||
"qwen-oauth",
|
||||
"custom", "local",
|
||||
# Common aliases
|
||||
"google", "google-gemini", "google-ai-studio",
|
||||
"glm", "z-ai", "z.ai", "zhipu", "github", "github-copilot",
|
||||
"github-models", "kimi", "moonshot", "claude", "deep-seek",
|
||||
"opencode", "zen", "go", "vercel", "kilo", "dashscope", "aliyun", "qwen",
|
||||
"qwen-portal",
|
||||
})
|
||||
|
||||
|
||||
@@ -113,8 +115,15 @@ DEFAULT_CONTEXT_LENGTHS = {
|
||||
"llama": 131072,
|
||||
# Qwen
|
||||
"qwen": 131072,
|
||||
# MiniMax
|
||||
"minimax": 204800,
|
||||
# MiniMax (lowercase — lookup lowercases model names at line 973)
|
||||
"minimax-m1-256k": 1000000,
|
||||
"minimax-m1-128k": 1000000,
|
||||
"minimax-m1-80k": 1000000,
|
||||
"minimax-m1-40k": 1000000,
|
||||
"minimax-m1": 1000000,
|
||||
"minimax-m2.5": 1048576,
|
||||
"minimax-m2.7": 1048576,
|
||||
"minimax": 1048576,
|
||||
# GLM
|
||||
"glm": 202752,
|
||||
# Kimi
|
||||
@@ -127,7 +136,7 @@ DEFAULT_CONTEXT_LENGTHS = {
|
||||
"deepseek-ai/DeepSeek-V3.2": 65536,
|
||||
"moonshotai/Kimi-K2.5": 262144,
|
||||
"moonshotai/Kimi-K2-Thinking": 262144,
|
||||
"MiniMaxAI/MiniMax-M2.5": 204800,
|
||||
"MiniMaxAI/MiniMax-M2.5": 1048576,
|
||||
"XiaomiMiMo/MiMo-V2-Flash": 32768,
|
||||
"mimo-v2-pro": 1048576,
|
||||
"mimo-v2-omni": 1048576,
|
||||
@@ -180,6 +189,7 @@ _URL_TO_PROVIDER: Dict[str, str] = {
|
||||
"api.minimax": "minimax",
|
||||
"dashscope.aliyuncs.com": "alibaba",
|
||||
"dashscope-intl.aliyuncs.com": "alibaba",
|
||||
"portal.qwen.ai": "qwen-oauth",
|
||||
"openrouter.ai": "openrouter",
|
||||
"generativelanguage.googleapis.com": "gemini",
|
||||
"inference-api.nousresearch.com": "nous",
|
||||
@@ -611,6 +621,59 @@ def _model_id_matches(candidate_id: str, lookup_model: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def query_ollama_num_ctx(model: str, base_url: str) -> Optional[int]:
|
||||
"""Query an Ollama server for the model's context length.
|
||||
|
||||
Returns the model's maximum context from GGUF metadata via ``/api/show``,
|
||||
or the explicit ``num_ctx`` from the Modelfile if set. Returns None if
|
||||
the server is unreachable or not Ollama.
|
||||
|
||||
This is the value that should be passed as ``num_ctx`` in Ollama chat
|
||||
requests to override the default 2048.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
bare_model = _strip_provider_prefix(model)
|
||||
server_url = base_url.rstrip("/")
|
||||
if server_url.endswith("/v1"):
|
||||
server_url = server_url[:-3]
|
||||
|
||||
try:
|
||||
server_type = detect_local_server_type(base_url)
|
||||
except Exception:
|
||||
return None
|
||||
if server_type != "ollama":
|
||||
return None
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=3.0) as client:
|
||||
resp = client.post(f"{server_url}/api/show", json={"name": bare_model})
|
||||
if resp.status_code != 200:
|
||||
return None
|
||||
data = resp.json()
|
||||
|
||||
# Prefer explicit num_ctx from Modelfile parameters (user override)
|
||||
params = data.get("parameters", "")
|
||||
if "num_ctx" in params:
|
||||
for line in params.split("\n"):
|
||||
if "num_ctx" in line:
|
||||
parts = line.strip().split()
|
||||
if len(parts) >= 2:
|
||||
try:
|
||||
return int(parts[-1])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Fall back to GGUF model_info context_length (training max)
|
||||
model_info = data.get("model_info", {})
|
||||
for key, value in model_info.items():
|
||||
if "context_length" in key and isinstance(value, (int, float)):
|
||||
return int(value)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
|
||||
"""Query a local server for the model's context length."""
|
||||
import httpx
|
||||
|
||||
@@ -153,6 +153,7 @@ PROVIDER_TO_MODELS_DEV: Dict[str, str] = {
|
||||
"minimax-cn": "minimax-cn",
|
||||
"deepseek": "deepseek",
|
||||
"alibaba": "alibaba",
|
||||
"qwen-oauth": "alibaba",
|
||||
"copilot": "github-copilot",
|
||||
"ai-gateway": "vercel",
|
||||
"opencode-zen": "opencode",
|
||||
|
||||
@@ -204,6 +204,30 @@ OPENAI_MODEL_EXECUTION_GUIDANCE = (
|
||||
"the result.\n"
|
||||
"</tool_persistence>\n"
|
||||
"\n"
|
||||
"<mandatory_tool_use>\n"
|
||||
"NEVER answer these from memory or mental computation — ALWAYS use a tool:\n"
|
||||
"- Arithmetic, math, calculations → use terminal or execute_code\n"
|
||||
"- Hashes, encodings, checksums → use terminal (e.g. sha256sum, base64)\n"
|
||||
"- Current time, date, timezone → use terminal (e.g. date)\n"
|
||||
"- System state: OS, CPU, memory, disk, ports, processes → use terminal\n"
|
||||
"- File contents, sizes, line counts → use read_file, search_files, or terminal\n"
|
||||
"- Git history, branches, diffs → use terminal\n"
|
||||
"- Current facts (weather, news, versions) → use web_search\n"
|
||||
"Your memory and user profile describe the USER, not the system you are "
|
||||
"running on. The execution environment may differ from what the user profile "
|
||||
"says about their personal setup.\n"
|
||||
"</mandatory_tool_use>\n"
|
||||
"\n"
|
||||
"<act_dont_ask>\n"
|
||||
"When a question has an obvious default interpretation, act on it immediately "
|
||||
"instead of asking for clarification. Examples:\n"
|
||||
"- 'Is port 443 open?' → check THIS machine (don't ask 'open where?')\n"
|
||||
"- 'What OS am I running?' → check the live system (don't use user profile)\n"
|
||||
"- 'What time is it?' → run `date` (don't guess)\n"
|
||||
"Only ask for clarification when the ambiguity genuinely changes what tool "
|
||||
"you would call.\n"
|
||||
"</act_dont_ask>\n"
|
||||
"\n"
|
||||
"<prerequisite_checks>\n"
|
||||
"- Before taking an action, check whether prerequisite discovery, lookup, or "
|
||||
"context-gathering steps are needed.\n"
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
"""Retry utilities — jittered backoff for decorrelated retries.
|
||||
|
||||
Replaces fixed exponential backoff with jittered delays to prevent
|
||||
thundering-herd retry spikes when multiple sessions hit the same
|
||||
rate-limited provider concurrently.
|
||||
"""
|
||||
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
|
||||
# Monotonic counter for jitter seed uniqueness within the same process.
|
||||
# Protected by a lock to avoid race conditions in concurrent retry paths
|
||||
# (e.g. multiple gateway sessions retrying simultaneously).
|
||||
_jitter_counter = 0
|
||||
_jitter_lock = threading.Lock()
|
||||
|
||||
|
||||
def jittered_backoff(
|
||||
attempt: int,
|
||||
*,
|
||||
base_delay: float = 5.0,
|
||||
max_delay: float = 120.0,
|
||||
jitter_ratio: float = 0.5,
|
||||
) -> float:
|
||||
"""Compute a jittered exponential backoff delay.
|
||||
|
||||
Args:
|
||||
attempt: 1-based retry attempt number.
|
||||
base_delay: Base delay in seconds for attempt 1.
|
||||
max_delay: Maximum delay cap in seconds.
|
||||
jitter_ratio: Fraction of computed delay to use as random jitter
|
||||
range. 0.5 means jitter is uniform in [0, 0.5 * delay].
|
||||
|
||||
Returns:
|
||||
Delay in seconds: min(base * 2^(attempt-1), max_delay) + jitter.
|
||||
|
||||
The jitter decorrelates concurrent retries so multiple sessions
|
||||
hitting the same provider don't all retry at the same instant.
|
||||
"""
|
||||
global _jitter_counter
|
||||
with _jitter_lock:
|
||||
_jitter_counter += 1
|
||||
tick = _jitter_counter
|
||||
|
||||
exponent = max(0, attempt - 1)
|
||||
if exponent >= 63 or base_delay <= 0:
|
||||
delay = max_delay
|
||||
else:
|
||||
delay = min(base_delay * (2 ** exponent), max_delay)
|
||||
|
||||
# Seed from time + counter for decorrelation even with coarse clocks.
|
||||
seed = (time.time_ns() ^ (tick * 0x9E3779B9)) & 0xFFFFFFFF
|
||||
rng = random.Random(seed)
|
||||
jitter = rng.uniform(0, jitter_ratio * delay)
|
||||
|
||||
return delay + jitter
|
||||
@@ -644,10 +644,14 @@ platform_toolsets:
|
||||
# Voice Transcription (Speech-to-Text)
|
||||
# =============================================================================
|
||||
# Automatically transcribe voice messages on messaging platforms.
|
||||
# Requires OPENAI_API_KEY in .env (uses OpenAI Whisper API directly).
|
||||
# Providers: local (free, faster-whisper) | groq (free tier) | openai (Whisper API) | mistral (Voxtral Transcribe)
|
||||
# Set the corresponding API key in .env: GROQ_API_KEY, OPENAI_API_KEY, or MISTRAL_API_KEY.
|
||||
stt:
|
||||
enabled: true
|
||||
# provider: "local" # auto-detected if omitted
|
||||
model: "whisper-1" # whisper-1 (cheapest) | gpt-4o-mini-transcribe | gpt-4o-transcribe
|
||||
# mistral:
|
||||
# model: "voxtral-mini-latest" # voxtral-mini-latest | voxtral-mini-2602
|
||||
|
||||
# =============================================================================
|
||||
# Response Pacing (Messaging Platforms)
|
||||
|
||||
@@ -612,6 +612,11 @@ def _run_cleanup():
|
||||
pass
|
||||
# Shut down memory provider (on_session_end + shutdown_all) at actual
|
||||
# session boundary — NOT per-turn inside run_conversation().
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook as _invoke_hook
|
||||
_invoke_hook("on_session_finalize", session_id=_active_agent_ref.session_id if _active_agent_ref else None, platform="cli")
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if _active_agent_ref and hasattr(_active_agent_ref, 'shutdown_memory_provider'):
|
||||
_active_agent_ref.shutdown_memory_provider(
|
||||
@@ -755,7 +760,10 @@ def _setup_worktree(repo_root: str = None) -> Optional[Dict[str, str]]:
|
||||
def _cleanup_worktree(info: Dict[str, str] = None) -> None:
|
||||
"""Remove a worktree and its branch on exit.
|
||||
|
||||
If the worktree has uncommitted changes, warn and keep it.
|
||||
Preserves the worktree only if it has unpushed commits (real work
|
||||
that hasn't been pushed to any remote). Uncommitted changes alone
|
||||
(untracked files, test artifacts) are not enough to keep it — agent
|
||||
work lives in commits/PRs, not the working tree.
|
||||
"""
|
||||
global _active_worktree
|
||||
info = info or _active_worktree
|
||||
@@ -771,23 +779,27 @@ def _cleanup_worktree(info: Dict[str, str] = None) -> None:
|
||||
if not Path(wt_path).exists():
|
||||
return
|
||||
|
||||
# Check for uncommitted changes
|
||||
# Check for unpushed commits — commits reachable from HEAD but not
|
||||
# from any remote branch. These represent real work the agent did
|
||||
# but didn't push.
|
||||
has_unpushed = False
|
||||
try:
|
||||
status = subprocess.run(
|
||||
["git", "status", "--porcelain"],
|
||||
result = subprocess.run(
|
||||
["git", "log", "--oneline", "HEAD", "--not", "--remotes"],
|
||||
capture_output=True, text=True, timeout=10, cwd=wt_path,
|
||||
)
|
||||
has_changes = bool(status.stdout.strip())
|
||||
has_unpushed = bool(result.stdout.strip())
|
||||
except Exception:
|
||||
has_changes = True # Assume dirty on error — don't delete
|
||||
has_unpushed = True # Assume unpushed on error — don't delete
|
||||
|
||||
if has_changes:
|
||||
print(f"\n\033[33m⚠ Worktree has uncommitted changes, keeping: {wt_path}\033[0m")
|
||||
print(f" To clean up manually: git worktree remove {wt_path}")
|
||||
if has_unpushed:
|
||||
print(f"\n\033[33m⚠ Worktree has unpushed commits, keeping: {wt_path}\033[0m")
|
||||
print(f" To clean up manually: git worktree remove --force {wt_path}")
|
||||
_active_worktree = None
|
||||
return
|
||||
|
||||
# Remove worktree
|
||||
# Remove worktree (even if working tree is dirty — uncommitted
|
||||
# changes without unpushed commits are just artifacts)
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "worktree", "remove", wt_path, "--force"],
|
||||
@@ -796,7 +808,7 @@ def _cleanup_worktree(info: Dict[str, str] = None) -> None:
|
||||
except Exception as e:
|
||||
logger.debug("Failed to remove worktree: %s", e)
|
||||
|
||||
# Delete the branch (only if it was never pushed / has no upstream)
|
||||
# Delete the branch
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "branch", "-D", branch],
|
||||
@@ -810,19 +822,27 @@ def _cleanup_worktree(info: Dict[str, str] = None) -> None:
|
||||
|
||||
|
||||
def _prune_stale_worktrees(repo_root: str, max_age_hours: int = 24) -> None:
|
||||
"""Remove worktrees older than max_age_hours that have no uncommitted changes.
|
||||
"""Remove stale worktrees and orphaned branches on startup.
|
||||
|
||||
Runs silently on startup to clean up after crashed/killed sessions.
|
||||
Age-based tiers:
|
||||
- Under max_age_hours (24h): skip — session may still be active.
|
||||
- 24h–72h: remove if no unpushed commits.
|
||||
- Over 72h: force remove regardless (nothing should sit this long).
|
||||
|
||||
Also prunes orphaned ``hermes/*`` and ``pr-*`` local branches that
|
||||
have no corresponding worktree.
|
||||
"""
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
worktrees_dir = Path(repo_root) / ".worktrees"
|
||||
if not worktrees_dir.exists():
|
||||
_prune_orphaned_branches(repo_root)
|
||||
return
|
||||
|
||||
now = time.time()
|
||||
cutoff = now - (max_age_hours * 3600)
|
||||
soft_cutoff = now - (max_age_hours * 3600) # 24h default
|
||||
hard_cutoff = now - (max_age_hours * 3 * 3600) # 72h default
|
||||
|
||||
for entry in worktrees_dir.iterdir():
|
||||
if not entry.is_dir() or not entry.name.startswith("hermes-"):
|
||||
@@ -831,21 +851,24 @@ def _prune_stale_worktrees(repo_root: str, max_age_hours: int = 24) -> None:
|
||||
# Check age
|
||||
try:
|
||||
mtime = entry.stat().st_mtime
|
||||
if mtime > cutoff:
|
||||
if mtime > soft_cutoff:
|
||||
continue # Too recent — skip
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# Check for uncommitted changes
|
||||
try:
|
||||
status = subprocess.run(
|
||||
["git", "status", "--porcelain"],
|
||||
capture_output=True, text=True, timeout=5, cwd=str(entry),
|
||||
)
|
||||
if status.stdout.strip():
|
||||
continue # Has changes — skip
|
||||
except Exception:
|
||||
continue # Can't check — skip
|
||||
force = mtime <= hard_cutoff # Over 72h — force remove
|
||||
|
||||
if not force:
|
||||
# 24h–72h tier: only remove if no unpushed commits
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "log", "--oneline", "HEAD", "--not", "--remotes"],
|
||||
capture_output=True, text=True, timeout=5, cwd=str(entry),
|
||||
)
|
||||
if result.stdout.strip():
|
||||
continue # Has unpushed commits — skip
|
||||
except Exception:
|
||||
continue # Can't check — skip
|
||||
|
||||
# Safe to remove
|
||||
try:
|
||||
@@ -864,10 +887,81 @@ def _prune_stale_worktrees(repo_root: str, max_age_hours: int = 24) -> None:
|
||||
["git", "branch", "-D", branch],
|
||||
capture_output=True, text=True, timeout=10, cwd=repo_root,
|
||||
)
|
||||
logger.debug("Pruned stale worktree: %s", entry.name)
|
||||
logger.debug("Pruned stale worktree: %s (force=%s)", entry.name, force)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to prune worktree %s: %s", entry.name, e)
|
||||
|
||||
_prune_orphaned_branches(repo_root)
|
||||
|
||||
|
||||
def _prune_orphaned_branches(repo_root: str) -> None:
|
||||
"""Delete local ``hermes/hermes-*`` and ``pr-*`` branches with no worktree.
|
||||
|
||||
These are auto-generated by ``hermes -w`` sessions and PR review
|
||||
workflows respectively. Once their worktree is gone they serve no
|
||||
purpose and just accumulate.
|
||||
"""
|
||||
import subprocess
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", "branch", "--format=%(refname:short)"],
|
||||
capture_output=True, text=True, timeout=10, cwd=repo_root,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return
|
||||
all_branches = [b.strip() for b in result.stdout.strip().split("\n") if b.strip()]
|
||||
except Exception:
|
||||
return
|
||||
|
||||
# Collect branches that are actively checked out in a worktree
|
||||
active_branches: set = set()
|
||||
try:
|
||||
wt_result = subprocess.run(
|
||||
["git", "worktree", "list", "--porcelain"],
|
||||
capture_output=True, text=True, timeout=10, cwd=repo_root,
|
||||
)
|
||||
for line in wt_result.stdout.split("\n"):
|
||||
if line.startswith("branch refs/heads/"):
|
||||
active_branches.add(line.split("branch refs/heads/", 1)[-1].strip())
|
||||
except Exception:
|
||||
return # Can't determine active branches — bail
|
||||
|
||||
# Also protect the currently checked-out branch and main
|
||||
try:
|
||||
head_result = subprocess.run(
|
||||
["git", "branch", "--show-current"],
|
||||
capture_output=True, text=True, timeout=5, cwd=repo_root,
|
||||
)
|
||||
current = head_result.stdout.strip()
|
||||
if current:
|
||||
active_branches.add(current)
|
||||
except Exception:
|
||||
pass
|
||||
active_branches.add("main")
|
||||
|
||||
orphaned = [
|
||||
b for b in all_branches
|
||||
if b not in active_branches
|
||||
and (b.startswith("hermes/hermes-") or b.startswith("pr-"))
|
||||
]
|
||||
|
||||
if not orphaned:
|
||||
return
|
||||
|
||||
# Delete in batches
|
||||
for i in range(0, len(orphaned), 50):
|
||||
batch = orphaned[i:i + 50]
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "branch", "-D"] + batch,
|
||||
capture_output=True, text=True, timeout=30, cwd=repo_root,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to prune orphaned branches: %s", e)
|
||||
|
||||
logger.debug("Pruned %d orphaned branches", len(orphaned))
|
||||
|
||||
# ============================================================================
|
||||
# ASCII Art & Branding
|
||||
# ============================================================================
|
||||
@@ -3314,6 +3408,22 @@ class HermesCLI:
|
||||
flush_tool_summary()
|
||||
print()
|
||||
|
||||
def _notify_session_boundary(self, event_type: str) -> None:
|
||||
"""Fire a session-boundary plugin hook (on_session_finalize or on_session_reset).
|
||||
|
||||
Non-blocking — errors are caught and logged. Safe to call from any
|
||||
lifecycle point (shutdown, /new, /reset).
|
||||
"""
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook as _invoke_hook
|
||||
_invoke_hook(
|
||||
event_type,
|
||||
session_id=self.agent.session_id if self.agent else None,
|
||||
platform=getattr(self, "platform", None) or "cli",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def new_session(self, silent=False):
|
||||
"""Start a fresh session with a new session ID and cleared agent state."""
|
||||
if self.agent and self.conversation_history:
|
||||
@@ -3321,6 +3431,10 @@ class HermesCLI:
|
||||
self.agent.flush_memories(self.conversation_history)
|
||||
except (Exception, KeyboardInterrupt):
|
||||
pass
|
||||
self._notify_session_boundary("on_session_finalize")
|
||||
elif self.agent:
|
||||
# First session or empty history — still finalize the old session
|
||||
self._notify_session_boundary("on_session_finalize")
|
||||
|
||||
old_session_id = self.session_id
|
||||
if self._session_db and old_session_id:
|
||||
@@ -3365,6 +3479,7 @@ class HermesCLI:
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
self._notify_session_boundary("on_session_reset")
|
||||
|
||||
if not silent:
|
||||
print("(^_^)v New session started!")
|
||||
@@ -4553,13 +4668,13 @@ class HermesCLI:
|
||||
if output:
|
||||
self.console.print(_rich_text_from_ansi(output))
|
||||
else:
|
||||
ChatConsole().print("[dim]Command returned no output[/]")
|
||||
self.console.print("[dim]Command returned no output[/]")
|
||||
except subprocess.TimeoutExpired:
|
||||
ChatConsole().print("[bold red]Quick command timed out (30s)[/]")
|
||||
self.console.print("[bold red]Quick command timed out (30s)[/]")
|
||||
except Exception as e:
|
||||
ChatConsole().print(f"[bold red]Quick command error: {e}[/]")
|
||||
self.console.print(f"[bold red]Quick command error: {e}[/]")
|
||||
else:
|
||||
ChatConsole().print(f"[bold red]Quick command '{base_cmd}' has no command defined[/]")
|
||||
self.console.print(f"[bold red]Quick command '{base_cmd}' has no command defined[/]")
|
||||
elif qcmd.get("type") == "alias":
|
||||
target = qcmd.get("target", "").strip()
|
||||
if target:
|
||||
@@ -4568,9 +4683,9 @@ class HermesCLI:
|
||||
aliased_command = f"{target} {user_args}".strip()
|
||||
return self.process_command(aliased_command)
|
||||
else:
|
||||
ChatConsole().print(f"[bold red]Quick command '{base_cmd}' has no target defined[/]")
|
||||
self.console.print(f"[bold red]Quick command '{base_cmd}' has no target defined[/]")
|
||||
else:
|
||||
ChatConsole().print(f"[bold red]Quick command '{base_cmd}' has unsupported type (supported: 'exec', 'alias')[/]")
|
||||
self.console.print(f"[bold red]Quick command '{base_cmd}' has unsupported type (supported: 'exec', 'alias')[/]")
|
||||
# Check for plugin-registered slash commands
|
||||
elif base_cmd.lstrip("/") in _get_plugin_cmd_handler_names():
|
||||
from hermes_cli.plugins import get_plugin_command_handler
|
||||
|
||||
+7
-1
@@ -574,12 +574,16 @@ def remove_job(job_id: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
|
||||
def mark_job_run(job_id: str, success: bool, error: Optional[str] = None,
|
||||
delivery_error: Optional[str] = None):
|
||||
"""
|
||||
Mark a job as having been run.
|
||||
|
||||
Updates last_run_at, last_status, increments completed count,
|
||||
computes next_run_at, and auto-deletes if repeat limit reached.
|
||||
|
||||
``delivery_error`` is tracked separately from the agent error — a job
|
||||
can succeed (agent produced output) but fail delivery (platform down).
|
||||
"""
|
||||
jobs = load_jobs()
|
||||
for i, job in enumerate(jobs):
|
||||
@@ -588,6 +592,8 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
|
||||
job["last_run_at"] = now
|
||||
job["last_status"] = "ok" if success else "error"
|
||||
job["last_error"] = error if not success else None
|
||||
# Track delivery failures separately — cleared on successful delivery
|
||||
job["last_delivery_error"] = delivery_error
|
||||
|
||||
# Increment completed count
|
||||
if job.get("repeat"):
|
||||
|
||||
+32
-25
@@ -196,7 +196,7 @@ def _send_media_via_adapter(adapter, chat_id: str, media_files: list, metadata:
|
||||
logger.warning("Job '%s': failed to send media %s: %s", job.get("id", "?"), media_path, e)
|
||||
|
||||
|
||||
def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> None:
|
||||
def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> Optional[str]:
|
||||
"""
|
||||
Deliver job output to the configured target (origin chat, specific platform, etc.).
|
||||
|
||||
@@ -204,16 +204,16 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> None:
|
||||
use the live adapter first — this supports E2EE rooms (e.g. Matrix) where
|
||||
the standalone HTTP path cannot encrypt. Falls back to standalone send if
|
||||
the adapter path fails or is unavailable.
|
||||
|
||||
Returns None on success, or an error string on failure.
|
||||
"""
|
||||
target = _resolve_delivery_target(job)
|
||||
if not target:
|
||||
if job.get("deliver", "local") != "local":
|
||||
logger.warning(
|
||||
"Job '%s' deliver=%s but no concrete delivery target could be resolved",
|
||||
job["id"],
|
||||
job.get("deliver", "local"),
|
||||
)
|
||||
return
|
||||
msg = f"no delivery target resolved for deliver={job.get('deliver', 'local')}"
|
||||
logger.warning("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
return None # local-only jobs don't deliver — not a failure
|
||||
|
||||
platform_name = target["platform"]
|
||||
chat_id = target["chat_id"]
|
||||
@@ -239,19 +239,22 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> None:
|
||||
}
|
||||
platform = platform_map.get(platform_name.lower())
|
||||
if not platform:
|
||||
logger.warning("Job '%s': unknown platform '%s' for delivery", job["id"], platform_name)
|
||||
return
|
||||
msg = f"unknown platform '{platform_name}'"
|
||||
logger.warning("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
|
||||
try:
|
||||
config = load_gateway_config()
|
||||
except Exception as e:
|
||||
logger.error("Job '%s': failed to load gateway config for delivery: %s", job["id"], e)
|
||||
return
|
||||
msg = f"failed to load gateway config: {e}"
|
||||
logger.error("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
|
||||
pconfig = config.platforms.get(platform)
|
||||
if not pconfig or not pconfig.enabled:
|
||||
logger.warning("Job '%s': platform '%s' not configured/enabled", job["id"], platform_name)
|
||||
return
|
||||
msg = f"platform '{platform_name}' not configured/enabled"
|
||||
logger.warning("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
|
||||
# Optionally wrap the content with a header/footer so the user knows this
|
||||
# is a cron delivery. Wrapping is on by default; set cron.wrap_response: false
|
||||
@@ -307,7 +310,7 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> None:
|
||||
|
||||
if adapter_ok:
|
||||
logger.info("Job '%s': delivered to %s:%s via live adapter", job["id"], platform_name, chat_id)
|
||||
return
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Job '%s': live adapter delivery to %s:%s failed (%s), falling back to standalone",
|
||||
@@ -329,13 +332,17 @@ def _deliver_result(job: dict, content: str, adapters=None, loop=None) -> None:
|
||||
future = pool.submit(asyncio.run, _send_to_platform(platform, pconfig, chat_id, cleaned_delivery_content, thread_id=thread_id, media_files=media_files))
|
||||
result = future.result(timeout=30)
|
||||
except Exception as e:
|
||||
logger.error("Job '%s': delivery to %s:%s failed: %s", job["id"], platform_name, chat_id, e)
|
||||
return
|
||||
msg = f"delivery to {platform_name}:{chat_id} failed: {e}"
|
||||
logger.error("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
|
||||
if result and result.get("error"):
|
||||
logger.error("Job '%s': delivery error: %s", job["id"], result["error"])
|
||||
else:
|
||||
logger.info("Job '%s': delivered to %s:%s", job["id"], platform_name, chat_id)
|
||||
msg = f"delivery error: {result['error']}"
|
||||
logger.error("Job '%s': %s", job["id"], msg)
|
||||
return msg
|
||||
|
||||
logger.info("Job '%s': delivered to %s:%s", job["id"], platform_name, chat_id)
|
||||
return None
|
||||
|
||||
|
||||
_SCRIPT_TIMEOUT = 120 # seconds
|
||||
@@ -578,11 +585,9 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
except Exception as e:
|
||||
logger.warning("Job '%s': failed to load config.yaml, using defaults: %s", job_id, e)
|
||||
|
||||
# Reasoning config from env or config.yaml
|
||||
# Reasoning config from config.yaml
|
||||
from hermes_constants import parse_reasoning_effort
|
||||
effort = os.getenv("HERMES_REASONING_EFFORT", "")
|
||||
if not effort:
|
||||
effort = str(_cfg.get("agent", {}).get("reasoning_effort", "")).strip()
|
||||
effort = str(_cfg.get("agent", {}).get("reasoning_effort", "")).strip()
|
||||
reasoning_config = parse_reasoning_effort(effort)
|
||||
|
||||
# Prefill messages from env or config.yaml
|
||||
@@ -868,13 +873,15 @@ def tick(verbose: bool = True, adapters=None, loop=None) -> int:
|
||||
logger.info("Job '%s': agent returned %s — skipping delivery", job["id"], SILENT_MARKER)
|
||||
should_deliver = False
|
||||
|
||||
delivery_error = None
|
||||
if should_deliver:
|
||||
try:
|
||||
_deliver_result(job, deliver_content, adapters=adapters, loop=loop)
|
||||
delivery_error = _deliver_result(job, deliver_content, adapters=adapters, loop=loop)
|
||||
except Exception as de:
|
||||
delivery_error = str(de)
|
||||
logger.error("Delivery failed for job %s: %s", job["id"], de)
|
||||
|
||||
mark_job_run(job["id"], success, error)
|
||||
mark_job_run(job["id"], success, error, delivery_error=delivery_error)
|
||||
executed += 1
|
||||
|
||||
except Exception as e:
|
||||
|
||||
+58
-14
@@ -21,6 +21,8 @@ from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from model_tools import handle_function_call
|
||||
from tools.terminal_tool import get_active_env
|
||||
from tools.tool_result_storage import maybe_persist_tool_result, enforce_turn_budget
|
||||
|
||||
# Thread pool for running sync tool calls that internally use asyncio.run()
|
||||
# (e.g., the Modal/Docker/Daytona terminal backends). Running them in a separate
|
||||
@@ -136,8 +138,10 @@ class HermesAgentLoop:
|
||||
max_turns: int = 30,
|
||||
task_id: Optional[str] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
extra_body: Optional[Dict[str, Any]] = None,
|
||||
budget_config: Optional["BudgetConfig"] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the agent loop.
|
||||
@@ -150,19 +154,26 @@ class HermesAgentLoop:
|
||||
max_turns: Maximum number of LLM calls before stopping
|
||||
task_id: Unique ID for terminal/browser session isolation
|
||||
temperature: Sampling temperature for generation
|
||||
top_p: Nucleus sampling top_p (None = omit, use provider default)
|
||||
max_tokens: Max tokens per generation (None for server default)
|
||||
extra_body: Extra parameters passed to the OpenAI client's create() call.
|
||||
Used for OpenRouter provider preferences, transforms, etc.
|
||||
e.g. {"provider": {"ignore": ["DeepInfra"]}}
|
||||
budget_config: Tool result persistence budget. Controls per-tool
|
||||
thresholds, per-turn aggregate budget, and preview size.
|
||||
If None, uses DEFAULT_BUDGET (current hardcoded values).
|
||||
"""
|
||||
from tools.budget_config import DEFAULT_BUDGET
|
||||
self.server = server
|
||||
self.tool_schemas = tool_schemas
|
||||
self.valid_tool_names = valid_tool_names
|
||||
self.max_turns = max_turns
|
||||
self.task_id = task_id or str(uuid.uuid4())
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.max_tokens = max_tokens
|
||||
self.extra_body = extra_body
|
||||
self.budget_config = budget_config or DEFAULT_BUDGET
|
||||
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
|
||||
"""
|
||||
@@ -203,6 +214,9 @@ class HermesAgentLoop:
|
||||
"temperature": self.temperature,
|
||||
}
|
||||
|
||||
if self.top_p is not None:
|
||||
chat_kwargs["top_p"] = self.top_p
|
||||
|
||||
# Only pass tools if we have them
|
||||
if self.tool_schemas:
|
||||
chat_kwargs["tools"] = self.tool_schemas
|
||||
@@ -217,20 +231,35 @@ class HermesAgentLoop:
|
||||
chat_kwargs["extra_body"] = self.extra_body
|
||||
|
||||
# Make the API call -- standard OpenAI spec
|
||||
# Retry on timeout/connection errors (provider queuing, rate limits)
|
||||
api_start = _time.monotonic()
|
||||
try:
|
||||
response = await self.server.chat_completion(**chat_kwargs)
|
||||
except Exception as e:
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
logger.error("API call failed on turn %d (%.1fs): %s", turn + 1, api_elapsed, e)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
response = None
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
response = await self.server.chat_completion(**chat_kwargs)
|
||||
break
|
||||
except Exception as e:
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
is_retryable = "timeout" in type(e).__name__.lower() or "connection" in type(e).__name__.lower()
|
||||
if is_retryable and attempt < max_retries - 1:
|
||||
wait = 2 ** attempt
|
||||
logger.warning(
|
||||
"[%s] API call timed out on turn %d attempt %d (%.1fs), retrying in %ds: %s",
|
||||
self.task_id[:8], turn + 1, attempt + 1, api_elapsed, wait, e,
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
api_start = _time.monotonic()
|
||||
continue
|
||||
logger.error("API call failed on turn %d (%.1fs): %s", turn + 1, api_elapsed, e)
|
||||
return AgentResult(
|
||||
messages=messages,
|
||||
managed_state=self._get_managed_state(),
|
||||
turns_used=turn + 1,
|
||||
finished_naturally=False,
|
||||
reasoning_per_turn=reasoning_per_turn,
|
||||
tool_errors=tool_errors,
|
||||
)
|
||||
|
||||
api_elapsed = _time.monotonic() - api_start
|
||||
|
||||
@@ -446,8 +475,15 @@ class HermesAgentLoop:
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Add tool response to conversation
|
||||
tc_id = tc.get("id", "") if isinstance(tc, dict) else tc.id
|
||||
tool_result = maybe_persist_tool_result(
|
||||
content=tool_result,
|
||||
tool_name=tool_name,
|
||||
tool_use_id=tc_id,
|
||||
env=get_active_env(self.task_id),
|
||||
config=self.budget_config,
|
||||
)
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
@@ -456,6 +492,14 @@ class HermesAgentLoop:
|
||||
}
|
||||
)
|
||||
|
||||
num_tcs = len(assistant_msg.tool_calls)
|
||||
if num_tcs > 0:
|
||||
enforce_turn_budget(
|
||||
messages[-num_tcs:],
|
||||
env=get_active_env(self.task_id),
|
||||
config=self.budget_config,
|
||||
)
|
||||
|
||||
turn_elapsed = _time.monotonic() - turn_start
|
||||
logger.info(
|
||||
"[%s] turn %d: api=%.1fs, %d tools, turn_total=%.1fs",
|
||||
|
||||
@@ -1048,6 +1048,7 @@ class AgenticOPDEnv(HermesAgentBaseEnv):
|
||||
temperature=0.0,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
|
||||
@@ -15,15 +15,15 @@
|
||||
|
||||
env:
|
||||
enabled_toolsets: ["terminal", "file"]
|
||||
max_agent_turns: 60
|
||||
max_agent_turns: 100
|
||||
max_token_length: 32000
|
||||
agent_temperature: 0.8
|
||||
agent_temperature: 1.0
|
||||
terminal_backend: "modal"
|
||||
terminal_timeout: 300 # 5 min per command (builds, pip install)
|
||||
tool_pool_size: 128 # thread pool for 89 parallel tasks
|
||||
dataset_name: "NousResearch/terminal-bench-2"
|
||||
terminal_timeout: 300 # 5 min per command (builds, pip install)
|
||||
tool_pool_size: 128 # thread pool for 89 parallel tasks
|
||||
dataset_name: "NousResearch/terminal-bench-2-verified-flattened"
|
||||
test_timeout: 600
|
||||
task_timeout: 1800 # 30 min wall-clock per task, auto-FAIL if exceeded
|
||||
task_timeout: 900 # 15 min wall-clock per task, auto-FAIL if exceeded
|
||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
||||
use_wandb: true
|
||||
wandb_name: "terminal-bench-2"
|
||||
@@ -33,10 +33,15 @@ env:
|
||||
# Modal's blocking calls (App.lookup, etc.) deadlock when too many sandboxes
|
||||
# are created simultaneously inside thread pool workers via asyncio.run().
|
||||
max_concurrent_tasks: 8
|
||||
extra_body:
|
||||
provider:
|
||||
order: ["DeepInfra"]
|
||||
allow_fallbacks: false
|
||||
|
||||
openai:
|
||||
base_url: "https://openrouter.ai/api/v1"
|
||||
model_name: "anthropic/claude-opus-4.6"
|
||||
model_name: "nvidia/nemotron-3-super-120b-a12b"
|
||||
server_type: "openai"
|
||||
health_check: false
|
||||
timeout: 300 # 5 min per API call (default 1200s causes 20min stalls)
|
||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
||||
|
||||
@@ -32,8 +32,8 @@ export PYTHONUNBUFFERED=1
|
||||
# These go to the log file; tqdm + [START]/[PASS]/[FAIL] go to terminal
|
||||
export LOGLEVEL=INFO
|
||||
|
||||
python terminalbench2_env.py evaluate \
|
||||
--config default.yaml \
|
||||
uv run python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
||||
--config environments/benchmarks/terminalbench_2/default.yaml \
|
||||
"$@" \
|
||||
2>&1 | tee "$LOG_FILE"
|
||||
|
||||
|
||||
@@ -52,18 +52,18 @@ _repo_root = Path(__file__).resolve().parent.parent.parent.parent
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from atroposlib.envs.base import EvalHandlingEnum
|
||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||||
from pydantic import Field
|
||||
|
||||
from agent.prompt_builder import DEFAULT_AGENT_IDENTITY
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||||
from environments.tool_context import ToolContext
|
||||
from tools.terminal_tool import (
|
||||
register_task_env_overrides,
|
||||
clear_task_env_overrides,
|
||||
cleanup_vm,
|
||||
clear_task_env_overrides,
|
||||
register_task_env_overrides,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -73,6 +73,7 @@ logger = logging.getLogger(__name__)
|
||||
# Configuration
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TerminalBench2EvalConfig(HermesAgentEnvConfig):
|
||||
"""
|
||||
Configuration for the Terminal-Bench 2.0 evaluation environment.
|
||||
@@ -138,11 +139,27 @@ class TerminalBench2EvalConfig(HermesAgentEnvConfig):
|
||||
|
||||
# Tasks that cannot run properly on Modal and are excluded from scoring.
|
||||
MODAL_INCOMPATIBLE_TASKS = {
|
||||
"qemu-startup", # Needs KVM/hardware virtualization
|
||||
"qemu-alpine-ssh", # Needs KVM/hardware virtualization
|
||||
"crack-7z-hash", # Password brute-force -- too slow for cloud sandbox timeouts
|
||||
"qemu-startup", # Needs KVM/hardware virtualization
|
||||
"qemu-alpine-ssh", # Needs KVM/hardware virtualization
|
||||
"crack-7z-hash", # Password brute-force -- too slow for cloud sandbox timeouts
|
||||
}
|
||||
|
||||
# Injected as a user message when the model responds with plain text instead of
|
||||
# calling a tool or including a <task_status> tag.
|
||||
_FORMAT_NUDGE_MESSAGE = (
|
||||
"You wrote a plain text response instead of using your tools. "
|
||||
"Plain text responses do not affect the environment — nothing was executed or saved.\n\n"
|
||||
"You MUST use your tools (terminal, read_file, write_file) to actually complete the task. "
|
||||
"Do not describe what you would do — execute it now by making tool calls.\n\n"
|
||||
"If you have already completed all required work using tools in previous turns, "
|
||||
"respond with exactly: <task_status>DONE</task_status>\n"
|
||||
"If you have exhausted all approaches and cannot make further progress, "
|
||||
"respond with exactly: <task_status>UNFINISHED</task_status>"
|
||||
)
|
||||
|
||||
# Maximum number of format nudges before giving up and moving on to scoring.
|
||||
_MAX_FORMAT_NUDGES = 3
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tar extraction helper
|
||||
@@ -203,7 +220,6 @@ def _safe_extract_tar(tar: tarfile.TarFile, target_dir: Path) -> None:
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _extract_base64_tar(b64_data: str, target_dir: Path):
|
||||
"""Extract a base64-encoded tar.gz archive into target_dir."""
|
||||
if not b64_data:
|
||||
@@ -218,6 +234,7 @@ def _extract_base64_tar(b64_data: str, target_dir: Path):
|
||||
# Main Environment
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
"""
|
||||
Terminal-Bench 2.0 evaluation environment (eval-only, no training).
|
||||
@@ -262,23 +279,18 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
enabled_toolsets=["terminal", "file"],
|
||||
disabled_toolsets=None,
|
||||
distribution=None,
|
||||
|
||||
# Agent settings -- TB2 tasks are complex, need many turns
|
||||
max_agent_turns=60,
|
||||
max_token_length=16000,
|
||||
agent_temperature=0.6,
|
||||
system_prompt=None,
|
||||
|
||||
system_prompt=DEFAULT_AGENT_IDENTITY,
|
||||
# Modal backend for per-task cloud-isolated sandboxes
|
||||
terminal_backend="modal",
|
||||
terminal_timeout=300, # 5 min per command (builds, pip install, etc.)
|
||||
|
||||
terminal_timeout=300, # 5 min per command (builds, pip install, etc.)
|
||||
# Test execution timeout (TB2 test scripts can install deps like pytest)
|
||||
test_timeout=180,
|
||||
|
||||
# 89 tasks run in parallel, each needs a thread for tool calls
|
||||
tool_pool_size=128,
|
||||
|
||||
# --- Eval-only Atropos settings ---
|
||||
# These settings make the env work as an eval-only environment:
|
||||
# - STOP_TRAIN: pauses training during eval (standard for eval envs)
|
||||
@@ -288,7 +300,6 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
group_size=1,
|
||||
steps_per_eval=1,
|
||||
total_steps=1,
|
||||
|
||||
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
|
||||
use_wandb=True,
|
||||
wandb_name="terminal-bench-2",
|
||||
@@ -336,7 +347,11 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
|
||||
# Skip tasks incompatible with the current backend (e.g., QEMU on Modal)
|
||||
# plus any user-specified skip_tasks
|
||||
skip = set(MODAL_INCOMPATIBLE_TASKS) if self.config.terminal_backend == "modal" else set()
|
||||
skip = (
|
||||
set(MODAL_INCOMPATIBLE_TASKS)
|
||||
if self.config.terminal_backend == "modal"
|
||||
else set()
|
||||
)
|
||||
if self.config.skip_tasks:
|
||||
skip |= {name.strip() for name in self.config.skip_tasks.split(",")}
|
||||
if skip:
|
||||
@@ -344,7 +359,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
tasks = [t for t in tasks if t["task_name"] not in skip]
|
||||
skipped = before - len(tasks)
|
||||
if skipped > 0:
|
||||
print(f" Skipped {skipped} incompatible tasks: {sorted(skip & {t['task_name'] for t in ds})}")
|
||||
print(
|
||||
f" Skipped {skipped} incompatible tasks: {sorted(skip & {t['task_name'] for t in ds})}"
|
||||
)
|
||||
|
||||
self.all_eval_items = tasks
|
||||
self.iter = 0
|
||||
@@ -354,6 +371,16 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
for i, task in enumerate(self.all_eval_items):
|
||||
self.category_index[task.get("category", "unknown")].append(i)
|
||||
|
||||
# Pre-compute which tasks need Modal's add_python (avoids re-decoding
|
||||
# multi-MB environment_tar blobs during per-task rollouts).
|
||||
self._needs_add_python: Dict[str, bool] = {
|
||||
task["task_name"]: self._image_needs_add_python(task)
|
||||
for task in self.all_eval_items
|
||||
}
|
||||
add_py_count = sum(self._needs_add_python.values())
|
||||
if add_py_count:
|
||||
print(f" {add_py_count} tasks need add_python (non-python base image)")
|
||||
|
||||
# Reward tracking for wandb logging
|
||||
self.eval_metrics: List[Tuple[str, float]] = []
|
||||
|
||||
@@ -361,15 +388,30 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
# immediately on completion so data is preserved even on Ctrl+C.
|
||||
# Timestamped filename so each run produces a unique file.
|
||||
import datetime
|
||||
|
||||
log_dir = os.path.join(os.path.dirname(__file__), "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
run_ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
self._streaming_path = os.path.join(log_dir, f"samples_{run_ts}.jsonl")
|
||||
model_name = self.server.servers[0].config.model_name
|
||||
model_slug = model_name.replace("/", "_").replace(":", "_")
|
||||
self._streaming_path = os.path.join(
|
||||
log_dir, f"samples_{run_ts}_{model_slug}.jsonl"
|
||||
)
|
||||
self._streaming_file = open(self._streaming_path, "w")
|
||||
self._streaming_lock = __import__("threading").Lock()
|
||||
self._run_meta = {
|
||||
"model_name": model_name,
|
||||
"temperature": self.config.agent_temperature,
|
||||
"top_p": self.config.agent_top_p,
|
||||
"max_agent_turns": self.config.max_agent_turns,
|
||||
"task_timeout": self.config.task_timeout,
|
||||
"terminal_backend": self.config.terminal_backend,
|
||||
}
|
||||
print(f" Streaming results to: {self._streaming_path}")
|
||||
|
||||
print(f"TB2 ready: {len(self.all_eval_items)} tasks across {len(self.category_index)} categories")
|
||||
print(
|
||||
f"TB2 ready: {len(self.all_eval_items)} tasks across {len(self.category_index)} categories"
|
||||
)
|
||||
for cat, indices in sorted(self.category_index.items()):
|
||||
print(f" {cat}: {len(indices)} tasks")
|
||||
|
||||
@@ -378,7 +420,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
if not hasattr(self, "_streaming_file") or self._streaming_file.closed:
|
||||
return
|
||||
with self._streaming_lock:
|
||||
self._streaming_file.write(json.dumps(result, ensure_ascii=False, default=str) + "\n")
|
||||
self._streaming_file.write(
|
||||
json.dumps(result, ensure_ascii=False, default=str) + "\n"
|
||||
)
|
||||
self._streaming_file.flush()
|
||||
|
||||
# =========================================================================
|
||||
@@ -414,6 +458,36 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
# Docker image resolution
|
||||
# =========================================================================
|
||||
|
||||
@staticmethod
|
||||
def _image_needs_add_python(item: Dict[str, Any]) -> bool:
|
||||
"""Check if the task's base image lacks `python` on PATH.
|
||||
|
||||
Parses the Dockerfile FROM line in environment_tar. Returns True
|
||||
for non-python base images (ubuntu, debian, etc.) that need
|
||||
Modal's add_python parameter.
|
||||
"""
|
||||
environment_tar = item.get("environment_tar", "")
|
||||
if not environment_tar:
|
||||
return False
|
||||
try:
|
||||
raw = base64.b64decode(environment_tar)
|
||||
buf = io.BytesIO(raw)
|
||||
with tarfile.open(fileobj=buf, mode="r:gz") as tar:
|
||||
for member in tar:
|
||||
if not member.isfile() or "Dockerfile" not in member.name:
|
||||
continue
|
||||
f = tar.extractfile(member)
|
||||
if not f:
|
||||
continue
|
||||
for line in f.read().decode("utf-8", errors="ignore").splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped.upper().startswith("FROM "):
|
||||
base = stripped.split()[1].lower()
|
||||
return not base.startswith("python:")
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _resolve_task_image(
|
||||
self, item: Dict[str, Any], task_name: str
|
||||
) -> Tuple[str, Optional[Path]]:
|
||||
@@ -446,7 +520,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
if dockerfile_path.exists():
|
||||
logger.info(
|
||||
"Task %s: building from Dockerfile (force_build=%s, docker_image=%s)",
|
||||
task_name, self.config.force_build, bool(docker_image),
|
||||
task_name,
|
||||
self.config.force_build,
|
||||
bool(docker_image),
|
||||
)
|
||||
return str(dockerfile_path), task_dir
|
||||
|
||||
@@ -454,12 +530,80 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
if docker_image:
|
||||
logger.warning(
|
||||
"Task %s: force_build=True but no environment_tar, "
|
||||
"falling back to docker_image %s", task_name, docker_image,
|
||||
"falling back to docker_image %s",
|
||||
task_name,
|
||||
docker_image,
|
||||
)
|
||||
return docker_image, None
|
||||
|
||||
return "", None
|
||||
|
||||
# =========================================================================
|
||||
# Agent loop with format nudging
|
||||
# =========================================================================
|
||||
|
||||
async def _run_with_nudges(
|
||||
self,
|
||||
server,
|
||||
tools: List[Dict[str, Any]],
|
||||
valid_names: set,
|
||||
messages: List[Dict[str, Any]],
|
||||
task_id: str,
|
||||
task_name: str,
|
||||
) -> Tuple["AgentResult", int]:
|
||||
"""Run the agent loop, nudging if the model returns plain text without task_status tag."""
|
||||
total_turns_used = 0
|
||||
nudge_count = 0
|
||||
result = None
|
||||
|
||||
while total_turns_used < self.config.max_agent_turns:
|
||||
remaining = self.config.max_agent_turns - total_turns_used
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=remaining,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
top_p=self.config.agent_top_p,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
total_turns_used += result.turns_used
|
||||
|
||||
if not result.finished_naturally:
|
||||
break
|
||||
|
||||
last_content = next(
|
||||
(
|
||||
m.get("content", "") or ""
|
||||
for m in reversed(messages)
|
||||
if m.get("role") == "assistant"
|
||||
),
|
||||
"",
|
||||
)
|
||||
if "<task_status>" in last_content:
|
||||
break
|
||||
|
||||
if nudge_count >= _MAX_FORMAT_NUDGES:
|
||||
logger.warning(
|
||||
"Task %s: model ignored %d format nudges; stopping.",
|
||||
task_name,
|
||||
nudge_count,
|
||||
)
|
||||
break
|
||||
nudge_count += 1
|
||||
logger.info(
|
||||
"Task %s: nudging model (nudge %d/%d) — no tool calls and no task_status",
|
||||
task_name,
|
||||
nudge_count,
|
||||
_MAX_FORMAT_NUDGES,
|
||||
)
|
||||
messages.append({"role": "user", "content": _FORMAT_NUDGE_MESSAGE})
|
||||
|
||||
return result, total_turns_used
|
||||
|
||||
# =========================================================================
|
||||
# Per-task evaluation -- agent loop + test verification
|
||||
# =========================================================================
|
||||
@@ -488,6 +632,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
task_dir = None # Set if we extract a Dockerfile (needs cleanup)
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
tqdm.write(f" [START] {task_name} (task_id={task_id[:8]})")
|
||||
task_start = time.time()
|
||||
|
||||
@@ -495,24 +640,32 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
# --- 1. Resolve Docker image ---
|
||||
modal_image, task_dir = self._resolve_task_image(eval_item, task_name)
|
||||
if not modal_image:
|
||||
logger.error("Task %s: no docker_image or environment_tar, skipping", task_name)
|
||||
logger.error(
|
||||
"Task %s: no docker_image or environment_tar, skipping", task_name
|
||||
)
|
||||
return {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
"passed": False,
|
||||
"reward": 0.0,
|
||||
"task_name": task_name,
|
||||
"category": category,
|
||||
"error": "no_image",
|
||||
}
|
||||
|
||||
# --- 2. Register per-task image override ---
|
||||
# Set both modal_image and docker_image so the task image is used
|
||||
# regardless of which backend is configured.
|
||||
register_task_env_overrides(task_id, {
|
||||
overrides = {
|
||||
"modal_image": modal_image,
|
||||
"docker_image": modal_image,
|
||||
"cwd": "/app",
|
||||
})
|
||||
}
|
||||
if self._needs_add_python.get(task_name, False):
|
||||
overrides["add_python"] = "3.12"
|
||||
register_task_env_overrides(task_id, overrides)
|
||||
logger.info(
|
||||
"Task %s: registered image override for task_id %s",
|
||||
task_name, task_id[:8],
|
||||
task_name,
|
||||
task_id[:8],
|
||||
)
|
||||
|
||||
# --- 3. Resolve tools and build messages ---
|
||||
@@ -520,51 +673,48 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
|
||||
messages: List[Dict[str, Any]] = []
|
||||
if self.config.system_prompt:
|
||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
||||
messages.append(
|
||||
{"role": "system", "content": self.config.system_prompt}
|
||||
)
|
||||
messages.append({"role": "user", "content": self.format_prompt(eval_item)})
|
||||
|
||||
# --- 4. Run agent loop ---
|
||||
# Use ManagedServer (Phase 2) for vLLM/SGLang backends to get
|
||||
# token-level tracking via /generate. Falls back to direct
|
||||
# ServerManager (Phase 1) for OpenAI endpoints.
|
||||
# --- 4. Run agent loop with format enforcement ---
|
||||
# The model must either call a tool or end with <task_status>DONE/UNFINISHED</task_status>.
|
||||
# If it returns plain text without the tag, inject a nudge user message and
|
||||
# continue with the remaining turn budget (up to _MAX_FORMAT_NUDGES times).
|
||||
if self._use_managed_server():
|
||||
async with self.server.managed_server(
|
||||
tokenizer=self.tokenizer,
|
||||
preserve_think_blocks=bool(self.config.thinking_mode),
|
||||
) as managed:
|
||||
agent = HermesAgentLoop(
|
||||
result, total_turns_used = await self._run_with_nudges(
|
||||
server=managed,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
tools=tools,
|
||||
valid_names=valid_names,
|
||||
messages=messages,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
task_name=task_name,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
else:
|
||||
agent = HermesAgentLoop(
|
||||
result, total_turns_used = await self._run_with_nudges(
|
||||
server=self.server,
|
||||
tool_schemas=tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=self.config.max_agent_turns,
|
||||
tools=tools,
|
||||
valid_names=valid_names,
|
||||
messages=messages,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
task_name=task_name,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
# --- 5. Verify -- run test suite in the agent's sandbox ---
|
||||
# Skip verification if the agent produced no meaningful output
|
||||
only_system_and_user = all(
|
||||
msg.get("role") in ("system", "user") for msg in result.messages
|
||||
msg.get("role") in ("system", "user") for msg in messages
|
||||
)
|
||||
if result.turns_used == 0 or only_system_and_user:
|
||||
if total_turns_used == 0 or only_system_and_user:
|
||||
logger.warning(
|
||||
"Task %s: agent produced no output (turns=%d). Reward=0.",
|
||||
task_name, result.turns_used,
|
||||
task_name,
|
||||
total_turns_used,
|
||||
)
|
||||
reward = 0.0
|
||||
else:
|
||||
@@ -576,7 +726,10 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
loop = asyncio.get_event_loop()
|
||||
reward = await loop.run_in_executor(
|
||||
None, # default thread pool
|
||||
self._run_tests, eval_item, ctx, task_name,
|
||||
self._run_tests,
|
||||
eval_item,
|
||||
ctx,
|
||||
task_name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Task %s: test verification failed: %s", task_name, e)
|
||||
@@ -587,20 +740,26 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
passed = reward == 1.0
|
||||
status = "PASS" if passed else "FAIL"
|
||||
elapsed = time.time() - task_start
|
||||
tqdm.write(f" [{status}] {task_name} (turns={result.turns_used}, {elapsed:.0f}s)")
|
||||
tqdm.write(
|
||||
f" [{status}] {task_name} (turns={total_turns_used}, {elapsed:.0f}s)"
|
||||
)
|
||||
logger.info(
|
||||
"Task %s: reward=%.1f, turns=%d, finished=%s",
|
||||
task_name, reward, result.turns_used, result.finished_naturally,
|
||||
task_name,
|
||||
reward,
|
||||
total_turns_used,
|
||||
result.finished_naturally,
|
||||
)
|
||||
|
||||
out = {
|
||||
**self._run_meta,
|
||||
"passed": passed,
|
||||
"reward": reward,
|
||||
"task_name": task_name,
|
||||
"category": category,
|
||||
"turns_used": result.turns_used,
|
||||
"turns_used": total_turns_used,
|
||||
"finished_naturally": result.finished_naturally,
|
||||
"messages": result.messages,
|
||||
"messages": messages,
|
||||
}
|
||||
self._save_result(out)
|
||||
return out
|
||||
@@ -610,8 +769,11 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
logger.error("Task %s: rollout failed: %s", task_name, e, exc_info=True)
|
||||
tqdm.write(f" [ERROR] {task_name}: {e} ({elapsed:.0f}s)")
|
||||
out = {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
**self._run_meta,
|
||||
"passed": False,
|
||||
"reward": 0.0,
|
||||
"task_name": task_name,
|
||||
"category": category,
|
||||
"error": str(e),
|
||||
}
|
||||
self._save_result(out)
|
||||
@@ -684,7 +846,8 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
# Execute the test suite
|
||||
logger.info(
|
||||
"Task %s: running test suite (timeout=%ds)",
|
||||
task_name, self.config.test_timeout,
|
||||
task_name,
|
||||
self.config.test_timeout,
|
||||
)
|
||||
test_result = ctx.terminal(
|
||||
"bash /tests/test.sh",
|
||||
@@ -717,7 +880,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
logger.warning(
|
||||
"Task %s: reward.txt content unexpected (%r), "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, content, exit_code,
|
||||
task_name,
|
||||
content,
|
||||
exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
else:
|
||||
@@ -725,14 +890,17 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
logger.warning(
|
||||
"Task %s: reward.txt not found after download, "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, exit_code,
|
||||
task_name,
|
||||
exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Task %s: failed to download verifier dir: %s, "
|
||||
"falling back to exit_code=%d",
|
||||
task_name, e, exit_code,
|
||||
task_name,
|
||||
e,
|
||||
exit_code,
|
||||
)
|
||||
reward = 1.0 if exit_code == 0 else 0.0
|
||||
finally:
|
||||
@@ -743,7 +911,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
output_preview = output[-500:] if output else "(no output)"
|
||||
logger.info(
|
||||
"Task %s: FAIL (exit_code=%d)\n%s",
|
||||
task_name, exit_code, output_preview,
|
||||
task_name,
|
||||
exit_code,
|
||||
output_preview,
|
||||
)
|
||||
|
||||
return reward
|
||||
@@ -768,12 +938,18 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
from tqdm import tqdm
|
||||
|
||||
elapsed = self.config.task_timeout
|
||||
tqdm.write(f" [TIMEOUT] {task_name} (exceeded {elapsed}s wall-clock limit)")
|
||||
tqdm.write(
|
||||
f" [TIMEOUT] {task_name} (exceeded {elapsed}s wall-clock limit)"
|
||||
)
|
||||
logger.error("Task %s: wall-clock timeout after %ds", task_name, elapsed)
|
||||
out = {
|
||||
"passed": False, "reward": 0.0,
|
||||
"task_name": task_name, "category": category,
|
||||
**self._run_meta,
|
||||
"passed": False,
|
||||
"reward": 0.0,
|
||||
"task_name": task_name,
|
||||
"category": category,
|
||||
"error": f"timeout ({elapsed}s)",
|
||||
}
|
||||
self._save_result(out)
|
||||
@@ -807,23 +983,25 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
self.handleError(record)
|
||||
|
||||
handler = _TqdmHandler()
|
||||
handler.setFormatter(logging.Formatter(
|
||||
"%(asctime)s [%(name)s] %(levelname)s: %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
))
|
||||
handler.setFormatter(
|
||||
logging.Formatter(
|
||||
"%(asctime)s [%(name)s] %(levelname)s: %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
)
|
||||
root = logging.getLogger()
|
||||
root.handlers = [handler] # Replace any existing handlers
|
||||
root.setLevel(logging.INFO)
|
||||
|
||||
# Silence noisy third-party loggers that flood the output
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING) # Every HTTP request
|
||||
logging.getLogger("openai").setLevel(logging.WARNING) # OpenAI client retries
|
||||
logging.getLogger("rex-deploy").setLevel(logging.WARNING) # Swerex deployment
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING) # Every HTTP request
|
||||
logging.getLogger("openai").setLevel(logging.WARNING) # OpenAI client retries
|
||||
logging.getLogger("rex-deploy").setLevel(logging.WARNING) # Swerex deployment
|
||||
logging.getLogger("rex_image_builder").setLevel(logging.WARNING) # Image builds
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"\n{'=' * 60}")
|
||||
print("Starting Terminal-Bench 2.0 Evaluation")
|
||||
print(f"{'='*60}")
|
||||
print(f"{'=' * 60}")
|
||||
print(f" Dataset: {self.config.dataset_name}")
|
||||
print(f" Total tasks: {len(self.all_eval_items)}")
|
||||
print(f" Max agent turns: {self.config.max_agent_turns}")
|
||||
@@ -831,9 +1009,11 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
print(f" Terminal backend: {self.config.terminal_backend}")
|
||||
print(f" Tool thread pool: {self.config.tool_pool_size}")
|
||||
print(f" Terminal timeout: {self.config.terminal_timeout}s/cmd")
|
||||
print(f" Terminal lifetime: {self.config.terminal_lifetime}s (auto: task_timeout + 120)")
|
||||
print(
|
||||
f" Terminal lifetime: {self.config.terminal_lifetime}s (auto: task_timeout + 120)"
|
||||
)
|
||||
print(f" Max concurrent tasks: {self.config.max_concurrent_tasks}")
|
||||
print(f"{'='*60}\n")
|
||||
print(f"{'=' * 60}\n")
|
||||
|
||||
# Semaphore to limit concurrent Modal sandbox creations.
|
||||
# Without this, all 86 tasks fire simultaneously, each creating a Modal
|
||||
@@ -875,6 +1055,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
await asyncio.gather(*eval_tasks, return_exceptions=True)
|
||||
# Belt-and-suspenders: clean up any remaining sandboxes
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
|
||||
cleanup_all_environments()
|
||||
print("All sandboxes cleaned up.")
|
||||
return
|
||||
@@ -920,9 +1101,9 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
self.eval_metrics = [(k, v) for k, v in eval_metrics.items()]
|
||||
|
||||
# ---- Print summary ----
|
||||
print(f"\n{'='*60}")
|
||||
print(f"\n{'=' * 60}")
|
||||
print("Terminal-Bench 2.0 Evaluation Results")
|
||||
print(f"{'='*60}")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"Overall Pass Rate: {overall_pass_rate:.4f} ({passed}/{total})")
|
||||
print(f"Evaluation Time: {end_time - start_time:.1f} seconds")
|
||||
|
||||
@@ -942,7 +1123,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
extra = f" (error: {error})" if error else ""
|
||||
print(f" [{status}] {r['task_name']} (turns={turns}){extra}")
|
||||
|
||||
print(f"{'='*60}\n")
|
||||
print(f"{'=' * 60}\n")
|
||||
|
||||
# Build sample records for evaluate_log (includes full conversations)
|
||||
samples = [
|
||||
@@ -967,6 +1148,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
end_time=end_time,
|
||||
generation_parameters={
|
||||
"temperature": self.config.agent_temperature,
|
||||
"top_p": self.config.agent_top_p,
|
||||
"max_tokens": self.config.max_token_length,
|
||||
"max_agent_turns": self.config.max_agent_turns,
|
||||
"terminal_backend": self.config.terminal_backend,
|
||||
@@ -983,6 +1165,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
# Kill all remaining sandboxes. Timed-out tasks leave orphaned thread
|
||||
# pool workers still executing commands -- cleanup_all stops them.
|
||||
from tools.terminal_tool import cleanup_all_environments
|
||||
|
||||
print("\nCleaning up all sandboxes...")
|
||||
cleanup_all_environments()
|
||||
|
||||
@@ -990,6 +1173,7 @@ class TerminalBench2EvalEnv(HermesAgentBaseEnv):
|
||||
# tasks are killed immediately instead of retrying against dead
|
||||
# sandboxes and spamming the console with TimeoutError warnings.
|
||||
from environments.agent_loop import _tool_executor
|
||||
|
||||
_tool_executor.shutdown(wait=False, cancel_futures=True)
|
||||
print("Done.")
|
||||
|
||||
|
||||
@@ -549,6 +549,7 @@ class YCBenchEvalEnv(HermesAgentBaseEnv):
|
||||
temperature=self.config.agent_temperature,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
|
||||
@@ -62,6 +62,11 @@ from atroposlib.type_definitions import Item
|
||||
|
||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
||||
from environments.tool_context import ToolContext
|
||||
from tools.budget_config import (
|
||||
DEFAULT_RESULT_SIZE_CHARS,
|
||||
DEFAULT_TURN_BUDGET_CHARS,
|
||||
DEFAULT_PREVIEW_SIZE_CHARS,
|
||||
)
|
||||
|
||||
# Import hermes-agent toolset infrastructure
|
||||
from model_tools import get_tool_definitions
|
||||
@@ -110,6 +115,10 @@ class HermesAgentEnvConfig(BaseEnvConfig):
|
||||
default=1.0,
|
||||
description="Sampling temperature for agent generation during rollouts.",
|
||||
)
|
||||
agent_top_p: Optional[float] = Field(
|
||||
default=None,
|
||||
description="Nucleus sampling top_p for agent generation. None = provider default.",
|
||||
)
|
||||
|
||||
# --- Terminal backend ---
|
||||
terminal_backend: str = Field(
|
||||
@@ -160,6 +169,32 @@ class HermesAgentEnvConfig(BaseEnvConfig):
|
||||
"Options: hermes, mistral, llama3_json, qwen, deepseek_v3, etc.",
|
||||
)
|
||||
|
||||
# --- Tool result budget ---
|
||||
# Defaults imported from tools.budget_config (single source of truth).
|
||||
default_result_size_chars: int = Field(
|
||||
default=DEFAULT_RESULT_SIZE_CHARS,
|
||||
description="Default per-tool threshold (chars) for persisting large results "
|
||||
"to sandbox. Results exceeding this are written to /tmp/hermes-results/ "
|
||||
"and replaced with a preview. Per-tool registry values take precedence "
|
||||
"unless overridden via tool_result_overrides.",
|
||||
)
|
||||
turn_budget_chars: int = Field(
|
||||
default=DEFAULT_TURN_BUDGET_CHARS,
|
||||
description="Aggregate char budget per assistant turn. If all tool results "
|
||||
"in a single turn exceed this, the largest are persisted to disk first.",
|
||||
)
|
||||
preview_size_chars: int = Field(
|
||||
default=DEFAULT_PREVIEW_SIZE_CHARS,
|
||||
description="Size of the inline preview shown after a tool result is persisted.",
|
||||
)
|
||||
tool_result_overrides: Optional[Dict[str, int]] = Field(
|
||||
default=None,
|
||||
description="Per-tool threshold overrides (chars). Keys are tool names, "
|
||||
"values are char thresholds. Overrides both the default and registry "
|
||||
"per-tool values. Example: {'terminal': 10000, 'search_files': 5000}. "
|
||||
"Note: read_file is pinned to infinity and cannot be overridden.",
|
||||
)
|
||||
|
||||
# --- Provider-specific parameters ---
|
||||
# Passed as extra_body to the OpenAI client's chat.completions.create() call.
|
||||
# Useful for OpenRouter provider preferences, transforms, route settings, etc.
|
||||
@@ -176,6 +211,16 @@ class HermesAgentEnvConfig(BaseEnvConfig):
|
||||
"transforms, and other provider-specific settings.",
|
||||
)
|
||||
|
||||
def build_budget_config(self):
|
||||
"""Build a BudgetConfig from env config fields."""
|
||||
from tools.budget_config import BudgetConfig
|
||||
return BudgetConfig(
|
||||
default_result_size=self.default_result_size_chars,
|
||||
turn_budget=self.turn_budget_chars,
|
||||
preview_size=self.preview_size_chars,
|
||||
tool_overrides=dict(self.tool_result_overrides) if self.tool_result_overrides else {},
|
||||
)
|
||||
|
||||
|
||||
class HermesAgentBaseEnv(BaseEnv):
|
||||
"""
|
||||
@@ -488,8 +533,10 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
top_p=self.config.agent_top_p,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
except NotImplementedError:
|
||||
@@ -505,8 +552,10 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
top_p=self.config.agent_top_p,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
else:
|
||||
@@ -518,8 +567,10 @@ class HermesAgentBaseEnv(BaseEnv):
|
||||
max_turns=self.config.max_agent_turns,
|
||||
task_id=task_id,
|
||||
temperature=self.config.agent_temperature,
|
||||
top_p=self.config.agent_top_p,
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
|
||||
@@ -472,6 +472,7 @@ class WebResearchEnv(HermesAgentBaseEnv):
|
||||
temperature=0.0, # Deterministic for eval
|
||||
max_tokens=self.config.max_token_length,
|
||||
extra_body=self.config.extra_body,
|
||||
budget_config=self.config.build_budget_config(),
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
|
||||
@@ -712,6 +712,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
name=os.getenv("DISCORD_HOME_CHANNEL_NAME", "Home"),
|
||||
)
|
||||
|
||||
# Reply threading mode for Discord (off/first/all)
|
||||
discord_reply_mode = os.getenv("DISCORD_REPLY_TO_MODE", "").lower()
|
||||
if discord_reply_mode in ("off", "first", "all"):
|
||||
if Platform.DISCORD not in config.platforms:
|
||||
config.platforms[Platform.DISCORD] = PlatformConfig()
|
||||
config.platforms[Platform.DISCORD].reply_to_mode = discord_reply_mode
|
||||
|
||||
# WhatsApp (typically uses different auth mechanism)
|
||||
whatsapp_enabled = os.getenv("WHATSAPP_ENABLED", "").lower() in ("true", "1", "yes")
|
||||
if whatsapp_enabled:
|
||||
|
||||
@@ -455,6 +455,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
self._seen_messages: Dict[str, float] = {}
|
||||
self._SEEN_TTL = 300 # 5 minutes
|
||||
self._SEEN_MAX = 2000 # prune threshold
|
||||
# Reply threading mode: "off" (no replies), "first" (reply on first
|
||||
# chunk only, default), "all" (reply-reference on every chunk).
|
||||
self._reply_to_mode: str = getattr(config, 'reply_to_mode', 'first') or 'first'
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to Discord and start receiving events."""
|
||||
@@ -774,7 +777,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
message_ids = []
|
||||
reference = None
|
||||
|
||||
if reply_to:
|
||||
if reply_to and self._reply_to_mode != "off":
|
||||
try:
|
||||
ref_msg = await channel.fetch_message(int(reply_to))
|
||||
reference = ref_msg
|
||||
@@ -782,7 +785,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
logger.debug("Could not fetch reply-to message: %s", e)
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
chunk_reference = reference if i == 0 else None
|
||||
if self._reply_to_mode == "all":
|
||||
chunk_reference = reference
|
||||
else: # "first" (default) or "off"
|
||||
chunk_reference = reference if i == 0 else None
|
||||
try:
|
||||
msg = await channel.send(
|
||||
content=chunk,
|
||||
|
||||
@@ -20,6 +20,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import itertools
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
@@ -1052,6 +1053,9 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
self._media_batch_state = FeishuBatchState()
|
||||
self._pending_media_batches = self._media_batch_state.events
|
||||
self._pending_media_batch_tasks = self._media_batch_state.tasks
|
||||
# Exec approval button state (approval_id → {session_key, message_id, chat_id})
|
||||
self._approval_state: Dict[int, Dict[str, str]] = {}
|
||||
self._approval_counter = itertools.count(1)
|
||||
self._load_seen_message_ids()
|
||||
|
||||
@staticmethod
|
||||
@@ -1394,6 +1398,104 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
logger.error("[Feishu] Failed to edit message %s: %s", message_id, exc, exc_info=True)
|
||||
return SendResult(success=False, error=str(exc))
|
||||
|
||||
async def send_exec_approval(
|
||||
self, chat_id: str, command: str, session_key: str,
|
||||
description: str = "dangerous command",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Send an interactive card with approval buttons.
|
||||
|
||||
The buttons carry ``hermes_action`` in their value dict so that
|
||||
``_handle_card_action_event`` can intercept them and call
|
||||
``resolve_gateway_approval()`` to unblock the waiting agent thread.
|
||||
"""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
||||
try:
|
||||
approval_id = next(self._approval_counter)
|
||||
cmd_preview = command[:3000] + "..." if len(command) > 3000 else command
|
||||
|
||||
def _btn(label: str, action_name: str, btn_type: str = "default") -> dict:
|
||||
return {
|
||||
"tag": "button",
|
||||
"text": {"tag": "plain_text", "content": label},
|
||||
"type": btn_type,
|
||||
"value": {"hermes_action": action_name, "approval_id": approval_id},
|
||||
}
|
||||
|
||||
card = {
|
||||
"config": {"wide_screen_mode": True},
|
||||
"header": {
|
||||
"title": {"content": "⚠️ Command Approval Required", "tag": "plain_text"},
|
||||
"template": "orange",
|
||||
},
|
||||
"elements": [
|
||||
{
|
||||
"tag": "markdown",
|
||||
"content": f"```\n{cmd_preview}\n```\n**Reason:** {description}",
|
||||
},
|
||||
{
|
||||
"tag": "action",
|
||||
"actions": [
|
||||
_btn("✅ Allow Once", "approve_once", "primary"),
|
||||
_btn("✅ Session", "approve_session"),
|
||||
_btn("✅ Always", "approve_always"),
|
||||
_btn("❌ Deny", "deny", "danger"),
|
||||
],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
payload = json.dumps(card, ensure_ascii=False)
|
||||
response = await self._feishu_send_with_retry(
|
||||
chat_id=chat_id,
|
||||
msg_type="interactive",
|
||||
payload=payload,
|
||||
reply_to=None,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
result = self._finalize_send_result(response, "send_exec_approval failed")
|
||||
if result.success:
|
||||
self._approval_state[approval_id] = {
|
||||
"session_key": session_key,
|
||||
"message_id": result.message_id or "",
|
||||
"chat_id": chat_id,
|
||||
}
|
||||
return result
|
||||
except Exception as exc:
|
||||
logger.warning("[Feishu] send_exec_approval failed: %s", exc)
|
||||
return SendResult(success=False, error=str(exc))
|
||||
|
||||
async def _update_approval_card(
|
||||
self, message_id: str, label: str, user_name: str, choice: str,
|
||||
) -> None:
|
||||
"""Replace the approval card with a resolved status card."""
|
||||
if not self._client or not message_id:
|
||||
return
|
||||
icon = "❌" if choice == "deny" else "✅"
|
||||
card = {
|
||||
"config": {"wide_screen_mode": True},
|
||||
"header": {
|
||||
"title": {"content": f"{icon} {label}", "tag": "plain_text"},
|
||||
"template": "red" if choice == "deny" else "green",
|
||||
},
|
||||
"elements": [
|
||||
{
|
||||
"tag": "markdown",
|
||||
"content": f"{icon} **{label}** by {user_name}",
|
||||
},
|
||||
],
|
||||
}
|
||||
try:
|
||||
payload = json.dumps(card, ensure_ascii=False)
|
||||
body = self._build_update_message_body(msg_type="interactive", content=payload)
|
||||
request = self._build_update_message_request(message_id=message_id, request_body=body)
|
||||
await asyncio.to_thread(self._client.im.v1.message.update, request)
|
||||
except Exception as exc:
|
||||
logger.warning("[Feishu] Failed to update approval card %s: %s", message_id, exc)
|
||||
|
||||
async def send_voice(
|
||||
self,
|
||||
chat_id: str,
|
||||
@@ -1820,6 +1922,52 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||
action = getattr(event, "action", None)
|
||||
action_tag = str(getattr(action, "tag", "") or "button")
|
||||
action_value = getattr(action, "value", {}) or {}
|
||||
|
||||
# --- Exec approval button intercept ---
|
||||
hermes_action = action_value.get("hermes_action") if isinstance(action_value, dict) else None
|
||||
if hermes_action:
|
||||
approval_id = action_value.get("approval_id")
|
||||
state = self._approval_state.pop(approval_id, None)
|
||||
if not state:
|
||||
logger.debug("[Feishu] Approval %s already resolved or unknown", approval_id)
|
||||
return
|
||||
|
||||
choice_map = {
|
||||
"approve_once": "once",
|
||||
"approve_session": "session",
|
||||
"approve_always": "always",
|
||||
"deny": "deny",
|
||||
}
|
||||
choice = choice_map.get(hermes_action, "deny")
|
||||
|
||||
label_map = {
|
||||
"once": "Approved once",
|
||||
"session": "Approved for session",
|
||||
"always": "Approved permanently",
|
||||
"deny": "Denied",
|
||||
}
|
||||
label = label_map.get(choice, "Resolved")
|
||||
|
||||
# Resolve sender name for the status card
|
||||
sender_id = SimpleNamespace(open_id=open_id, user_id=None, union_id=None)
|
||||
sender_profile = await self._resolve_sender_profile(sender_id)
|
||||
user_name = sender_profile.get("user_name") or open_id
|
||||
|
||||
# Resolve the approval — unblocks the agent thread
|
||||
try:
|
||||
from tools.approval import resolve_gateway_approval
|
||||
count = resolve_gateway_approval(state["session_key"], choice)
|
||||
logger.info(
|
||||
"Feishu button resolved %d approval(s) for session %s (choice=%s, user=%s)",
|
||||
count, state["session_key"], choice, user_name,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to resolve gateway approval from Feishu button: %s", exc)
|
||||
|
||||
# Update the card to show the decision
|
||||
await self._update_approval_card(state.get("message_id", ""), label, user_name, choice)
|
||||
return
|
||||
|
||||
synthetic_text = f"/card {action_tag}"
|
||||
if action_value:
|
||||
try:
|
||||
|
||||
@@ -647,7 +647,11 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
|
||||
if result is not None:
|
||||
self._track_sent_timestamp(result)
|
||||
return SendResult(success=True)
|
||||
# Use the timestamp from the RPC result as a pseudo message_id.
|
||||
# Signal doesn't have real message IDs, but the stream consumer
|
||||
# needs a truthy value to follow its edit→fallback path correctly.
|
||||
_msg_id = str(result.get("timestamp", "")) if isinstance(result, dict) else None
|
||||
return SendResult(success=True, message_id=_msg_id or None)
|
||||
return SendResult(success=False, error="RPC send failed")
|
||||
|
||||
def _track_sent_timestamp(self, rpc_result) -> None:
|
||||
@@ -837,6 +841,11 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def stop_typing(self, chat_id: str) -> None:
|
||||
"""Public interface for stopping typing — called by base adapter's
|
||||
_keep_typing finally block to clean up platform-level typing tasks."""
|
||||
await self._stop_typing_indicator(chat_id)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Chat Info
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
+41
-10
@@ -921,12 +921,11 @@ class GatewayRunner:
|
||||
|
||||
@staticmethod
|
||||
def _load_reasoning_config() -> dict | None:
|
||||
"""Load reasoning effort from config with env fallback.
|
||||
"""Load reasoning effort from config.yaml.
|
||||
|
||||
Checks agent.reasoning_effort in config.yaml first, then
|
||||
HERMES_REASONING_EFFORT as a fallback. Valid: "xhigh", "high",
|
||||
"medium", "low", "minimal", "none". Returns None to use default
|
||||
(medium).
|
||||
Reads agent.reasoning_effort from config.yaml. Valid: "xhigh",
|
||||
"high", "medium", "low", "minimal", "none". Returns None to use
|
||||
default (medium).
|
||||
"""
|
||||
from hermes_constants import parse_reasoning_effort
|
||||
effort = ""
|
||||
@@ -939,8 +938,6 @@ class GatewayRunner:
|
||||
effort = str(cfg.get("agent", {}).get("reasoning_effort", "") or "").strip()
|
||||
except Exception:
|
||||
pass
|
||||
if not effort:
|
||||
effort = os.getenv("HERMES_REASONING_EFFORT", "")
|
||||
result = parse_reasoning_effort(effort)
|
||||
if effort and effort.strip() and result is None:
|
||||
logger.warning("Unknown reasoning_effort '%s', using default (medium)", effort)
|
||||
@@ -1484,6 +1481,14 @@ class GatewayRunner:
|
||||
logger.debug("Interrupted running agent for session %s during shutdown", session_key[:20])
|
||||
except Exception as e:
|
||||
logger.debug("Failed interrupting agent during shutdown: %s", e)
|
||||
# Fire plugin on_session_finalize hook before memory shutdown
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook as _invoke_hook
|
||||
_invoke_hook("on_session_finalize",
|
||||
session_id=getattr(agent, 'session_id', None),
|
||||
platform="gateway")
|
||||
except Exception:
|
||||
pass
|
||||
# Shut down memory provider at actual session boundary
|
||||
try:
|
||||
if hasattr(agent, 'shutdown_memory_provider'):
|
||||
@@ -3277,6 +3282,15 @@ class GatewayRunner:
|
||||
# the configured default instead of the previously switched model.
|
||||
self._session_model_overrides.pop(session_key, None)
|
||||
|
||||
# Fire plugin on_session_finalize hook (session boundary)
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook as _invoke_hook
|
||||
_old_sid = old_entry.session_id if old_entry else None
|
||||
_invoke_hook("on_session_finalize", session_id=_old_sid,
|
||||
platform=source.platform.value if source.platform else "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Emit session:end hook (session is ending)
|
||||
await self.hooks.emit("session:end", {
|
||||
"platform": source.platform.value if source.platform else "",
|
||||
@@ -3290,7 +3304,7 @@ class GatewayRunner:
|
||||
"user_id": source.user_id,
|
||||
"session_key": session_key,
|
||||
})
|
||||
|
||||
|
||||
# Resolve session config info to surface to the user
|
||||
try:
|
||||
session_info = self._format_session_info()
|
||||
@@ -3301,9 +3315,18 @@ class GatewayRunner:
|
||||
header = "✨ Session reset! Starting fresh."
|
||||
else:
|
||||
# No existing session, just create one
|
||||
self.session_store.get_or_create_session(source, force_new=True)
|
||||
new_entry = self.session_store.get_or_create_session(source, force_new=True)
|
||||
header = "✨ New session started!"
|
||||
|
||||
# Fire plugin on_session_reset hook (new session guaranteed to exist)
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook as _invoke_hook
|
||||
_new_sid = new_entry.session_id if new_entry else None
|
||||
_invoke_hook("on_session_reset", session_id=_new_sid,
|
||||
platform=source.platform.value if source.platform else "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if session_info:
|
||||
return f"{header}\n\n{session_info}"
|
||||
return header
|
||||
@@ -6285,7 +6308,15 @@ class GatewayRunner:
|
||||
# Falls back to env vars for backward compatibility.
|
||||
# YAML 1.1 parses bare `off` as boolean False — normalise before
|
||||
# the `or` chain so it doesn't silently fall through to "all".
|
||||
_raw_tp = user_config.get("display", {}).get("tool_progress")
|
||||
#
|
||||
# Per-platform overrides (display.tool_progress_overrides) take
|
||||
# priority over the global setting — e.g. Signal users can set
|
||||
# tool_progress to "off" while keeping Telegram on "all".
|
||||
_display_cfg = user_config.get("display", {})
|
||||
_overrides = _display_cfg.get("tool_progress_overrides", {})
|
||||
_raw_tp = _overrides.get(platform_key)
|
||||
if _raw_tp is None:
|
||||
_raw_tp = _display_cfg.get("tool_progress")
|
||||
if _raw_tp is False:
|
||||
_raw_tp = "off"
|
||||
progress_mode = (
|
||||
|
||||
+119
-7
@@ -74,6 +74,8 @@ class GatewayStreamConsumer:
|
||||
self._edit_supported = True # Disabled on first edit failure (Signal/Email/HA)
|
||||
self._last_edit_time = 0.0
|
||||
self._last_sent_text = "" # Track last-sent text to skip redundant edits
|
||||
self._fallback_final_send = False
|
||||
self._fallback_prefix = ""
|
||||
|
||||
@property
|
||||
def already_sent(self) -> bool:
|
||||
@@ -138,12 +140,19 @@ class GatewayStreamConsumer:
|
||||
while (
|
||||
len(self._accumulated) > _safe_limit
|
||||
and self._message_id is not None
|
||||
and self._edit_supported
|
||||
):
|
||||
split_at = self._accumulated.rfind("\n", 0, _safe_limit)
|
||||
if split_at < _safe_limit // 2:
|
||||
split_at = _safe_limit
|
||||
chunk = self._accumulated[:split_at]
|
||||
await self._send_or_edit(chunk)
|
||||
if self._fallback_final_send:
|
||||
# Edit failed while attempting to split an oversized
|
||||
# message. Keep the full accumulated text intact so
|
||||
# the fallback final-send path can deliver the
|
||||
# remaining continuation without dropping content.
|
||||
break
|
||||
self._accumulated = self._accumulated[split_at:].lstrip("\n")
|
||||
self._message_id = None
|
||||
self._last_sent_text = ""
|
||||
@@ -156,9 +165,17 @@ class GatewayStreamConsumer:
|
||||
self._last_edit_time = time.monotonic()
|
||||
|
||||
if got_done:
|
||||
# Final edit without cursor
|
||||
if self._accumulated and self._message_id:
|
||||
await self._send_or_edit(self._accumulated)
|
||||
# Final edit without cursor. If progressive editing failed
|
||||
# mid-stream, send a single continuation/fallback message
|
||||
# here instead of letting the base gateway path send the
|
||||
# full response again.
|
||||
if self._accumulated:
|
||||
if self._fallback_final_send:
|
||||
await self._send_fallback_final(self._accumulated)
|
||||
elif self._message_id:
|
||||
await self._send_or_edit(self._accumulated)
|
||||
elif not self._already_sent:
|
||||
await self._send_or_edit(self._accumulated)
|
||||
return
|
||||
|
||||
# Tool boundary: the should_edit block above already flushed
|
||||
@@ -169,6 +186,8 @@ class GatewayStreamConsumer:
|
||||
self._message_id = None
|
||||
self._accumulated = ""
|
||||
self._last_sent_text = ""
|
||||
self._fallback_final_send = False
|
||||
self._fallback_prefix = ""
|
||||
|
||||
await asyncio.sleep(0.05) # Small yield to not busy-loop
|
||||
|
||||
@@ -207,6 +226,86 @@ class GatewayStreamConsumer:
|
||||
# Strip trailing whitespace/newlines but preserve leading content
|
||||
return cleaned.rstrip()
|
||||
|
||||
def _visible_prefix(self) -> str:
|
||||
"""Return the visible text already shown in the streamed message."""
|
||||
prefix = self._last_sent_text or ""
|
||||
if self.cfg.cursor and prefix.endswith(self.cfg.cursor):
|
||||
prefix = prefix[:-len(self.cfg.cursor)]
|
||||
return self._clean_for_display(prefix)
|
||||
|
||||
def _continuation_text(self, final_text: str) -> str:
|
||||
"""Return only the part of final_text the user has not already seen."""
|
||||
prefix = self._fallback_prefix or self._visible_prefix()
|
||||
if prefix and final_text.startswith(prefix):
|
||||
return final_text[len(prefix):].lstrip()
|
||||
return final_text
|
||||
|
||||
@staticmethod
|
||||
def _split_text_chunks(text: str, limit: int) -> list[str]:
|
||||
"""Split text into reasonably sized chunks for fallback sends."""
|
||||
if len(text) <= limit:
|
||||
return [text]
|
||||
chunks: list[str] = []
|
||||
remaining = text
|
||||
while len(remaining) > limit:
|
||||
split_at = remaining.rfind("\n", 0, limit)
|
||||
if split_at < limit // 2:
|
||||
split_at = limit
|
||||
chunks.append(remaining[:split_at])
|
||||
remaining = remaining[split_at:].lstrip("\n")
|
||||
if remaining:
|
||||
chunks.append(remaining)
|
||||
return chunks
|
||||
|
||||
async def _send_fallback_final(self, text: str) -> None:
|
||||
"""Send the final continuation after streaming edits stop working."""
|
||||
final_text = self._clean_for_display(text)
|
||||
continuation = self._continuation_text(final_text)
|
||||
self._fallback_final_send = False
|
||||
if not continuation.strip():
|
||||
# Nothing new to send — the visible partial already matches final text.
|
||||
self._already_sent = True
|
||||
return
|
||||
|
||||
raw_limit = getattr(self.adapter, "MAX_MESSAGE_LENGTH", 4096)
|
||||
safe_limit = max(500, raw_limit - 100)
|
||||
chunks = self._split_text_chunks(continuation, safe_limit)
|
||||
|
||||
last_message_id: Optional[str] = None
|
||||
last_successful_chunk = ""
|
||||
sent_any_chunk = False
|
||||
for chunk in chunks:
|
||||
result = await self.adapter.send(
|
||||
chat_id=self.chat_id,
|
||||
content=chunk,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
if not result.success:
|
||||
if sent_any_chunk:
|
||||
# Some continuation text already reached the user. Suppress
|
||||
# the base gateway final-send path so we don't resend the
|
||||
# full response and create another duplicate.
|
||||
self._already_sent = True
|
||||
self._message_id = last_message_id
|
||||
self._last_sent_text = last_successful_chunk
|
||||
self._fallback_prefix = ""
|
||||
return
|
||||
# No fallback chunk reached the user — allow the normal gateway
|
||||
# final-send path to try one more time.
|
||||
self._already_sent = False
|
||||
self._message_id = None
|
||||
self._last_sent_text = ""
|
||||
self._fallback_prefix = ""
|
||||
return
|
||||
sent_any_chunk = True
|
||||
last_successful_chunk = chunk
|
||||
last_message_id = result.message_id or last_message_id
|
||||
|
||||
self._message_id = last_message_id
|
||||
self._already_sent = True
|
||||
self._last_sent_text = chunks[-1]
|
||||
self._fallback_prefix = ""
|
||||
|
||||
async def _send_or_edit(self, text: str) -> None:
|
||||
"""Send or edit the streaming message."""
|
||||
# Strip MEDIA: directives so they don't appear as visible text.
|
||||
@@ -232,14 +331,16 @@ class GatewayStreamConsumer:
|
||||
self._last_sent_text = text
|
||||
else:
|
||||
# If an edit fails mid-stream (especially Telegram flood control),
|
||||
# stop progressive edits and let the normal final send path deliver
|
||||
# the complete answer instead of leaving the user with a partial.
|
||||
# stop progressive edits and send only the missing tail once the
|
||||
# final response is available.
|
||||
logger.debug("Edit failed, disabling streaming for this adapter")
|
||||
self._fallback_prefix = self._visible_prefix()
|
||||
self._fallback_final_send = True
|
||||
self._edit_supported = False
|
||||
self._already_sent = False
|
||||
self._already_sent = True
|
||||
else:
|
||||
# Editing not supported — skip intermediate updates.
|
||||
# The final response will be sent by the normal path.
|
||||
# The final response will be sent by the fallback path.
|
||||
pass
|
||||
else:
|
||||
# First message — send new
|
||||
@@ -252,6 +353,17 @@ class GatewayStreamConsumer:
|
||||
self._message_id = result.message_id
|
||||
self._already_sent = True
|
||||
self._last_sent_text = text
|
||||
elif result.success:
|
||||
# Platform accepted the message but returned no message_id
|
||||
# (e.g. Signal). Can't edit without an ID — switch to
|
||||
# fallback mode: suppress intermediate deltas, send only
|
||||
# the missing tail once the final response is ready.
|
||||
self._already_sent = True
|
||||
self._edit_supported = False
|
||||
self._fallback_prefix = self._clean_for_display(text)
|
||||
self._fallback_final_send = True
|
||||
# Sentinel prevents re-entering this branch on every delta
|
||||
self._message_id = "__no_edit__"
|
||||
else:
|
||||
# Initial send failed — disable streaming for this session
|
||||
self._edit_supported = False
|
||||
|
||||
@@ -11,5 +11,5 @@ Provides subcommands for:
|
||||
- hermes cron - Manage cron jobs
|
||||
"""
|
||||
|
||||
__version__ = "0.7.0"
|
||||
__release_date__ = "2026.4.3"
|
||||
__version__ = "0.8.0"
|
||||
__release_date__ = "2026.4.8"
|
||||
|
||||
@@ -67,12 +67,16 @@ DEFAULT_AGENT_KEY_MIN_TTL_SECONDS = 30 * 60 # 30 minutes
|
||||
ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # refresh 2 min before expiry
|
||||
DEVICE_AUTH_POLL_INTERVAL_CAP_SECONDS = 1 # poll at most every 1s
|
||||
DEFAULT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
DEFAULT_QWEN_BASE_URL = "https://portal.qwen.ai/v1"
|
||||
DEFAULT_GITHUB_MODELS_BASE_URL = "https://api.githubcopilot.com"
|
||||
DEFAULT_COPILOT_ACP_BASE_URL = "acp://copilot"
|
||||
DEFAULT_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
CODEX_OAUTH_CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
CODEX_OAUTH_TOKEN_URL = "https://auth.openai.com/oauth/token"
|
||||
CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
|
||||
QWEN_OAUTH_CLIENT_ID = "f0304373b74a44d2b584a3fb70ca9e56"
|
||||
QWEN_OAUTH_TOKEN_URL = "https://chat.qwen.ai/api/v1/oauth2/token"
|
||||
QWEN_ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -112,6 +116,12 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
auth_type="oauth_external",
|
||||
inference_base_url=DEFAULT_CODEX_BASE_URL,
|
||||
),
|
||||
"qwen-oauth": ProviderConfig(
|
||||
id="qwen-oauth",
|
||||
name="Qwen OAuth",
|
||||
auth_type="oauth_external",
|
||||
inference_base_url=DEFAULT_QWEN_BASE_URL,
|
||||
),
|
||||
"copilot": ProviderConfig(
|
||||
id="copilot",
|
||||
name="GitHub Copilot",
|
||||
@@ -817,6 +827,7 @@ def resolve_provider(
|
||||
"github-copilot-acp": "copilot-acp", "copilot-acp-agent": "copilot-acp",
|
||||
"aigateway": "ai-gateway", "vercel": "ai-gateway", "vercel-ai-gateway": "ai-gateway",
|
||||
"opencode": "opencode-zen", "zen": "opencode-zen",
|
||||
"qwen-portal": "qwen-oauth", "qwen-cli": "qwen-oauth", "qwen-oauth": "qwen-oauth",
|
||||
"hf": "huggingface", "hugging-face": "huggingface", "huggingface-hub": "huggingface",
|
||||
"go": "opencode-go", "opencode-go-sub": "opencode-go",
|
||||
"kilo": "kilocode", "kilo-code": "kilocode", "kilo-gateway": "kilocode",
|
||||
@@ -946,6 +957,176 @@ def _codex_access_token_is_expiring(access_token: Any, skew_seconds: int) -> boo
|
||||
return float(exp) <= (time.time() + max(0, int(skew_seconds)))
|
||||
|
||||
|
||||
def _qwen_cli_auth_path() -> Path:
|
||||
return Path.home() / ".qwen" / "oauth_creds.json"
|
||||
|
||||
|
||||
def _read_qwen_cli_tokens() -> Dict[str, Any]:
|
||||
auth_path = _qwen_cli_auth_path()
|
||||
if not auth_path.exists():
|
||||
raise AuthError(
|
||||
"Qwen CLI credentials not found. Run 'qwen auth qwen-oauth' first.",
|
||||
provider="qwen-oauth",
|
||||
code="qwen_auth_missing",
|
||||
)
|
||||
try:
|
||||
data = json.loads(auth_path.read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
raise AuthError(
|
||||
f"Failed to read Qwen CLI credentials from {auth_path}: {exc}",
|
||||
provider="qwen-oauth",
|
||||
code="qwen_auth_read_failed",
|
||||
) from exc
|
||||
if not isinstance(data, dict):
|
||||
raise AuthError(
|
||||
f"Invalid Qwen CLI credentials in {auth_path}.",
|
||||
provider="qwen-oauth",
|
||||
code="qwen_auth_invalid",
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
def _save_qwen_cli_tokens(tokens: Dict[str, Any]) -> Path:
|
||||
auth_path = _qwen_cli_auth_path()
|
||||
auth_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp_path = auth_path.with_suffix(".tmp")
|
||||
tmp_path.write_text(json.dumps(tokens, indent=2, sort_keys=True) + "\n", encoding="utf-8")
|
||||
os.chmod(tmp_path, stat.S_IRUSR | stat.S_IWUSR)
|
||||
tmp_path.replace(auth_path)
|
||||
return auth_path
|
||||
|
||||
|
||||
def _qwen_access_token_is_expiring(expiry_date_ms: Any, skew_seconds: int = QWEN_ACCESS_TOKEN_REFRESH_SKEW_SECONDS) -> bool:
|
||||
try:
|
||||
expiry_ms = int(expiry_date_ms)
|
||||
except Exception:
|
||||
return True
|
||||
return (time.time() + max(0, int(skew_seconds))) * 1000 >= expiry_ms
|
||||
|
||||
|
||||
def _refresh_qwen_cli_tokens(tokens: Dict[str, Any], timeout_seconds: float = 20.0) -> Dict[str, Any]:
|
||||
refresh_token = str(tokens.get("refresh_token", "") or "").strip()
|
||||
if not refresh_token:
|
||||
raise AuthError(
|
||||
"Qwen OAuth refresh token missing. Re-run 'qwen auth qwen-oauth'.",
|
||||
provider="qwen-oauth",
|
||||
code="qwen_refresh_token_missing",
|
||||
)
|
||||
|
||||
try:
|
||||
response = httpx.post(
|
||||
QWEN_OAUTH_TOKEN_URL,
|
||||
headers={
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Accept": "application/json",
|
||||
},
|
||||
data={
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": QWEN_OAUTH_CLIENT_ID,
|
||||
},
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
except Exception as exc:
|
||||
raise AuthError(
|
||||
f"Qwen OAuth refresh failed: {exc}",
|
||||
provider="qwen-oauth",
|
||||
code="qwen_refresh_failed",
|
||||
) from exc
|
||||
|
||||
if response.status_code >= 400:
|
||||
body = response.text.strip()
|
||||
raise AuthError(
|
||||
"Qwen OAuth refresh failed. Re-run 'qwen auth qwen-oauth'."
|
||||
+ (f" Response: {body}" if body else ""),
|
||||
provider="qwen-oauth",
|
||||
code="qwen_refresh_failed",
|
||||
)
|
||||
|
||||
try:
|
||||
payload = response.json()
|
||||
except Exception as exc:
|
||||
raise AuthError(
|
||||
f"Qwen OAuth refresh returned invalid JSON: {exc}",
|
||||
provider="qwen-oauth",
|
||||
code="qwen_refresh_invalid_json",
|
||||
) from exc
|
||||
|
||||
if not isinstance(payload, dict) or not str(payload.get("access_token", "") or "").strip():
|
||||
raise AuthError(
|
||||
"Qwen OAuth refresh response missing access_token.",
|
||||
provider="qwen-oauth",
|
||||
code="qwen_refresh_invalid_response",
|
||||
)
|
||||
|
||||
expires_in = payload.get("expires_in")
|
||||
try:
|
||||
expires_in_seconds = int(expires_in)
|
||||
except Exception:
|
||||
expires_in_seconds = 6 * 60 * 60
|
||||
|
||||
refreshed = {
|
||||
"access_token": str(payload.get("access_token", "") or "").strip(),
|
||||
"refresh_token": str(payload.get("refresh_token", refresh_token) or refresh_token).strip(),
|
||||
"token_type": str(payload.get("token_type", tokens.get("token_type", "Bearer")) or "Bearer").strip() or "Bearer",
|
||||
"resource_url": str(payload.get("resource_url", tokens.get("resource_url", "portal.qwen.ai")) or "portal.qwen.ai").strip(),
|
||||
"expiry_date": int(time.time() * 1000) + max(1, expires_in_seconds) * 1000,
|
||||
}
|
||||
_save_qwen_cli_tokens(refreshed)
|
||||
return refreshed
|
||||
|
||||
|
||||
def resolve_qwen_runtime_credentials(
|
||||
*,
|
||||
force_refresh: bool = False,
|
||||
refresh_if_expiring: bool = True,
|
||||
refresh_skew_seconds: int = QWEN_ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
|
||||
) -> Dict[str, Any]:
|
||||
tokens = _read_qwen_cli_tokens()
|
||||
access_token = str(tokens.get("access_token", "") or "").strip()
|
||||
should_refresh = bool(force_refresh)
|
||||
if not should_refresh and refresh_if_expiring:
|
||||
should_refresh = _qwen_access_token_is_expiring(tokens.get("expiry_date"), refresh_skew_seconds)
|
||||
if should_refresh:
|
||||
tokens = _refresh_qwen_cli_tokens(tokens)
|
||||
access_token = str(tokens.get("access_token", "") or "").strip()
|
||||
if not access_token:
|
||||
raise AuthError(
|
||||
"Qwen OAuth access token missing. Re-run 'qwen auth qwen-oauth'.",
|
||||
provider="qwen-oauth",
|
||||
code="qwen_access_token_missing",
|
||||
)
|
||||
|
||||
base_url = os.getenv("HERMES_QWEN_BASE_URL", "").strip().rstrip("/") or DEFAULT_QWEN_BASE_URL
|
||||
return {
|
||||
"provider": "qwen-oauth",
|
||||
"base_url": base_url,
|
||||
"api_key": access_token,
|
||||
"source": "qwen-cli",
|
||||
"expires_at_ms": tokens.get("expiry_date"),
|
||||
"auth_file": str(_qwen_cli_auth_path()),
|
||||
}
|
||||
|
||||
|
||||
def get_qwen_auth_status() -> Dict[str, Any]:
|
||||
auth_path = _qwen_cli_auth_path()
|
||||
try:
|
||||
creds = resolve_qwen_runtime_credentials(refresh_if_expiring=False)
|
||||
return {
|
||||
"logged_in": True,
|
||||
"auth_file": str(auth_path),
|
||||
"source": creds.get("source"),
|
||||
"api_key": creds.get("api_key"),
|
||||
"expires_at_ms": creds.get("expires_at_ms"),
|
||||
}
|
||||
except AuthError as exc:
|
||||
return {
|
||||
"logged_in": False,
|
||||
"auth_file": str(auth_path),
|
||||
"error": str(exc),
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SSH / remote session detection
|
||||
# =============================================================================
|
||||
@@ -2072,6 +2253,8 @@ def get_auth_status(provider_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
return get_nous_auth_status()
|
||||
if target == "openai-codex":
|
||||
return get_codex_auth_status()
|
||||
if target == "qwen-oauth":
|
||||
return get_qwen_auth_status()
|
||||
if target == "copilot-acp":
|
||||
return get_external_process_provider_status(target)
|
||||
# API-key providers
|
||||
|
||||
@@ -32,7 +32,7 @@ from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
|
||||
# Providers that support OAuth login in addition to API keys.
|
||||
_OAUTH_CAPABLE_PROVIDERS = {"anthropic", "nous", "openai-codex"}
|
||||
_OAUTH_CAPABLE_PROVIDERS = {"anthropic", "nous", "openai-codex", "qwen-oauth"}
|
||||
|
||||
|
||||
def _get_custom_provider_names() -> list:
|
||||
@@ -147,7 +147,7 @@ def auth_add_command(args) -> None:
|
||||
if provider.startswith(CUSTOM_POOL_PREFIX):
|
||||
requested_type = AUTH_TYPE_API_KEY
|
||||
else:
|
||||
requested_type = AUTH_TYPE_OAUTH if provider in {"anthropic", "nous", "openai-codex"} else AUTH_TYPE_API_KEY
|
||||
requested_type = AUTH_TYPE_OAUTH if provider in {"anthropic", "nous", "openai-codex", "qwen-oauth"} else AUTH_TYPE_API_KEY
|
||||
|
||||
pool = load_pool(provider)
|
||||
|
||||
@@ -250,6 +250,26 @@ def auth_add_command(args) -> None:
|
||||
print(f'Added {provider} OAuth credential #{len(pool.entries())}: "{entry.label}"')
|
||||
return
|
||||
|
||||
if provider == "qwen-oauth":
|
||||
creds = auth_mod.resolve_qwen_runtime_credentials(refresh_if_expiring=False)
|
||||
label = (getattr(args, "label", None) or "").strip() or label_from_token(
|
||||
creds["api_key"],
|
||||
_oauth_default_label(provider, len(pool.entries()) + 1),
|
||||
)
|
||||
entry = PooledCredential(
|
||||
provider=provider,
|
||||
id=uuid.uuid4().hex[:6],
|
||||
label=label,
|
||||
auth_type=AUTH_TYPE_OAUTH,
|
||||
priority=0,
|
||||
source=f"{SOURCE_MANUAL}:qwen_cli",
|
||||
access_token=creds["api_key"],
|
||||
base_url=creds.get("base_url"),
|
||||
)
|
||||
pool.add_entry(entry)
|
||||
print(f'Added {provider} OAuth credential #{len(pool.entries())}: "{entry.label}"')
|
||||
return
|
||||
|
||||
raise SystemExit(f"`hermes auth add {provider}` is not implemented for auth type {requested_type} yet.")
|
||||
|
||||
|
||||
|
||||
+35
-3
@@ -157,7 +157,14 @@ def get_project_root() -> Path:
|
||||
return Path(__file__).parent.parent.resolve()
|
||||
|
||||
def _secure_dir(path):
|
||||
"""Set directory to owner-only access (0700). No-op on Windows."""
|
||||
"""Set directory to owner-only access (0700). No-op on Windows.
|
||||
|
||||
Skipped in managed mode — the NixOS module sets group-readable
|
||||
permissions (0750) so interactive users in the hermes group can
|
||||
share state with the gateway service.
|
||||
"""
|
||||
if is_managed():
|
||||
return
|
||||
try:
|
||||
os.chmod(path, 0o700)
|
||||
except (OSError, NotImplementedError):
|
||||
@@ -165,7 +172,13 @@ def _secure_dir(path):
|
||||
|
||||
|
||||
def _secure_file(path):
|
||||
"""Set file to owner-only read/write (0600). No-op on Windows."""
|
||||
"""Set file to owner-only read/write (0600). No-op on Windows.
|
||||
|
||||
Skipped in managed mode — the NixOS activation script sets
|
||||
group-readable permissions (0640) on config files.
|
||||
"""
|
||||
if is_managed():
|
||||
return
|
||||
try:
|
||||
if os.path.exists(str(path)):
|
||||
os.chmod(path, 0o600)
|
||||
@@ -379,6 +392,7 @@ DEFAULT_CONFIG = {
|
||||
"show_cost": False, # Show $ cost in the status bar (off by default)
|
||||
"skin": "default",
|
||||
"tool_progress_command": False, # Enable /verbose command in messaging gateway
|
||||
"tool_progress_overrides": {}, # Per-platform overrides: {"signal": "off", "telegram": "all"}
|
||||
"tool_preview_length": 0, # Max chars for tool call previews (0 = no limit, show full paths/commands)
|
||||
},
|
||||
|
||||
@@ -413,7 +427,7 @@ DEFAULT_CONFIG = {
|
||||
|
||||
"stt": {
|
||||
"enabled": True,
|
||||
"provider": "local", # "local" (free, faster-whisper) | "groq" | "openai" (Whisper API)
|
||||
"provider": "local", # "local" (free, faster-whisper) | "groq" | "openai" (Whisper API) | "mistral" (Voxtral Transcribe)
|
||||
"local": {
|
||||
"model": "base", # tiny, base, small, medium, large-v3
|
||||
"language": "", # auto-detect by default; set to "en", "es", "fr", etc. to force
|
||||
@@ -421,6 +435,9 @@ DEFAULT_CONFIG = {
|
||||
"openai": {
|
||||
"model": "whisper-1", # whisper-1, gpt-4o-mini-transcribe, gpt-4o-transcribe
|
||||
},
|
||||
"mistral": {
|
||||
"model": "voxtral-mini-latest", # voxtral-mini-latest, voxtral-mini-2602
|
||||
},
|
||||
},
|
||||
|
||||
"voice": {
|
||||
@@ -724,6 +741,14 @@ OPTIONAL_ENV_VARS = {
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"HERMES_QWEN_BASE_URL": {
|
||||
"description": "Qwen Portal base URL override (default: https://portal.qwen.ai/v1)",
|
||||
"prompt": "Qwen Portal base URL (leave empty for default)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"OPENCODE_ZEN_API_KEY": {
|
||||
"description": "OpenCode Zen API key (pay-as-you-go access to curated models)",
|
||||
"prompt": "OpenCode Zen API key",
|
||||
@@ -975,6 +1000,13 @@ OPTIONAL_ENV_VARS = {
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
},
|
||||
"DISCORD_REPLY_TO_MODE": {
|
||||
"description": "Discord reply threading mode: 'off' (no reply references), 'first' (reply on first message only, default), 'all' (reply on every chunk)",
|
||||
"prompt": "Discord reply mode (off/first/all)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "messaging",
|
||||
},
|
||||
"SLACK_BOT_TOKEN": {
|
||||
"description": "Slack bot token (xoxb-). Get from OAuth & Permissions after installing your app. "
|
||||
"Required scopes: chat:write, app_mentions:read, channels:history, groups:history, "
|
||||
|
||||
@@ -93,6 +93,21 @@ def cron_list(show_all: bool = False):
|
||||
script = job.get("script")
|
||||
if script:
|
||||
print(f" Script: {script}")
|
||||
|
||||
# Execution history
|
||||
last_status = job.get("last_status")
|
||||
if last_status:
|
||||
last_run = job.get("last_run_at", "?")
|
||||
if last_status == "ok":
|
||||
status_display = color("ok", Colors.GREEN)
|
||||
else:
|
||||
status_display = color(f"{last_status}: {job.get('last_error', '?')}", Colors.RED)
|
||||
print(f" Last run: {last_run} {status_display}")
|
||||
|
||||
delivery_err = job.get("last_delivery_error")
|
||||
if delivery_err:
|
||||
print(f" {color('⚠ Delivery failed:', Colors.YELLOW)} {delivery_err}")
|
||||
|
||||
print()
|
||||
|
||||
from hermes_cli.gateway import find_gateway_pids
|
||||
|
||||
+70
-56
@@ -812,69 +812,83 @@ def run_doctor(args):
|
||||
check_warn("No GITHUB_TOKEN", f"(60 req/hr rate limit — set in {_DHH}/.env for better rates)")
|
||||
|
||||
# =========================================================================
|
||||
# Honcho memory
|
||||
# Memory Provider (only check the active provider, if any)
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Honcho Memory", Colors.CYAN, Colors.BOLD))
|
||||
print(color("◆ Memory Provider", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
_active_memory_provider = ""
|
||||
try:
|
||||
from plugins.memory.honcho.client import HonchoClientConfig, resolve_config_path
|
||||
hcfg = HonchoClientConfig.from_global_config()
|
||||
_honcho_cfg_path = resolve_config_path()
|
||||
import yaml as _yaml
|
||||
_mem_cfg_path = HERMES_HOME / "config.yaml"
|
||||
if _mem_cfg_path.exists():
|
||||
with open(_mem_cfg_path) as _f:
|
||||
_raw_cfg = _yaml.safe_load(_f) or {}
|
||||
_active_memory_provider = (_raw_cfg.get("memory") or {}).get("provider", "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not _honcho_cfg_path.exists():
|
||||
check_warn("Honcho config not found", "run: hermes memory setup")
|
||||
elif not hcfg.enabled:
|
||||
check_info(f"Honcho disabled (set enabled: true in {_honcho_cfg_path} to activate)")
|
||||
elif not (hcfg.api_key or hcfg.base_url):
|
||||
check_fail("Honcho API key or base URL not set", "run: hermes memory setup")
|
||||
issues.append("No Honcho API key — run 'hermes memory setup'")
|
||||
else:
|
||||
from plugins.memory.honcho.client import get_honcho_client, reset_honcho_client
|
||||
reset_honcho_client()
|
||||
try:
|
||||
get_honcho_client(hcfg)
|
||||
check_ok(
|
||||
"Honcho connected",
|
||||
f"workspace={hcfg.workspace_id} mode={hcfg.recall_mode} freq={hcfg.write_frequency}",
|
||||
)
|
||||
except Exception as _e:
|
||||
check_fail("Honcho connection failed", str(_e))
|
||||
issues.append(f"Honcho unreachable: {_e}")
|
||||
except ImportError:
|
||||
check_warn("honcho-ai not installed", "pip install honcho-ai")
|
||||
except Exception as _e:
|
||||
check_warn("Honcho check failed", str(_e))
|
||||
if not _active_memory_provider:
|
||||
check_ok("Built-in memory active", "(no external provider configured — this is fine)")
|
||||
elif _active_memory_provider == "honcho":
|
||||
try:
|
||||
from plugins.memory.honcho.client import HonchoClientConfig, resolve_config_path
|
||||
hcfg = HonchoClientConfig.from_global_config()
|
||||
_honcho_cfg_path = resolve_config_path()
|
||||
|
||||
# =========================================================================
|
||||
# Mem0 memory
|
||||
# =========================================================================
|
||||
print()
|
||||
print(color("◆ Mem0 Memory", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
try:
|
||||
from plugins.memory.mem0 import _load_config as _load_mem0_config
|
||||
mem0_cfg = _load_mem0_config()
|
||||
mem0_key = mem0_cfg.get("api_key", "")
|
||||
if mem0_key:
|
||||
check_ok("Mem0 API key configured")
|
||||
check_info(f"user_id={mem0_cfg.get('user_id', '?')} agent_id={mem0_cfg.get('agent_id', '?')}")
|
||||
# Check if mem0.json exists but is missing api_key (the bug we fixed)
|
||||
mem0_json = HERMES_HOME / "mem0.json"
|
||||
if mem0_json.exists():
|
||||
if not _honcho_cfg_path.exists():
|
||||
check_warn("Honcho config not found", "run: hermes memory setup")
|
||||
elif not hcfg.enabled:
|
||||
check_info(f"Honcho disabled (set enabled: true in {_honcho_cfg_path} to activate)")
|
||||
elif not (hcfg.api_key or hcfg.base_url):
|
||||
check_fail("Honcho API key or base URL not set", "run: hermes memory setup")
|
||||
issues.append("No Honcho API key — run 'hermes memory setup'")
|
||||
else:
|
||||
from plugins.memory.honcho.client import get_honcho_client, reset_honcho_client
|
||||
reset_honcho_client()
|
||||
try:
|
||||
import json as _json
|
||||
file_cfg = _json.loads(mem0_json.read_text())
|
||||
if not file_cfg.get("api_key") and mem0_key:
|
||||
check_info("api_key from .env (not in mem0.json) — this is fine")
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
check_warn("Mem0 not configured", "(set MEM0_API_KEY in .env or run hermes memory setup)")
|
||||
except ImportError:
|
||||
check_warn("Mem0 plugin not loadable", "(optional)")
|
||||
except Exception as _e:
|
||||
check_warn("Mem0 check failed", str(_e))
|
||||
get_honcho_client(hcfg)
|
||||
check_ok(
|
||||
"Honcho connected",
|
||||
f"workspace={hcfg.workspace_id} mode={hcfg.recall_mode} freq={hcfg.write_frequency}",
|
||||
)
|
||||
except Exception as _e:
|
||||
check_fail("Honcho connection failed", str(_e))
|
||||
issues.append(f"Honcho unreachable: {_e}")
|
||||
except ImportError:
|
||||
check_fail("honcho-ai not installed", "pip install honcho-ai")
|
||||
issues.append("Honcho is set as memory provider but honcho-ai is not installed")
|
||||
except Exception as _e:
|
||||
check_warn("Honcho check failed", str(_e))
|
||||
elif _active_memory_provider == "mem0":
|
||||
try:
|
||||
from plugins.memory.mem0 import _load_config as _load_mem0_config
|
||||
mem0_cfg = _load_mem0_config()
|
||||
mem0_key = mem0_cfg.get("api_key", "")
|
||||
if mem0_key:
|
||||
check_ok("Mem0 API key configured")
|
||||
check_info(f"user_id={mem0_cfg.get('user_id', '?')} agent_id={mem0_cfg.get('agent_id', '?')}")
|
||||
else:
|
||||
check_fail("Mem0 API key not set", "(set MEM0_API_KEY in .env or run hermes memory setup)")
|
||||
issues.append("Mem0 is set as memory provider but API key is missing")
|
||||
except ImportError:
|
||||
check_fail("Mem0 plugin not loadable", "pip install mem0ai")
|
||||
issues.append("Mem0 is set as memory provider but mem0ai is not installed")
|
||||
except Exception as _e:
|
||||
check_warn("Mem0 check failed", str(_e))
|
||||
else:
|
||||
# Generic check for other memory providers (openviking, hindsight, etc.)
|
||||
try:
|
||||
from plugins.memory import load_memory_provider
|
||||
_provider = load_memory_provider(_active_memory_provider)
|
||||
if _provider and _provider.is_available():
|
||||
check_ok(f"{_active_memory_provider} provider active")
|
||||
elif _provider:
|
||||
check_warn(f"{_active_memory_provider} configured but not available", "run: hermes memory status")
|
||||
else:
|
||||
check_warn(f"{_active_memory_provider} plugin not found", "run: hermes memory setup")
|
||||
except Exception as _e:
|
||||
check_warn(f"{_active_memory_provider} check failed", str(_e))
|
||||
|
||||
# =========================================================================
|
||||
# Profiles
|
||||
|
||||
@@ -918,6 +918,7 @@ def select_provider_and_model(args=None):
|
||||
"openrouter": "OpenRouter",
|
||||
"nous": "Nous Portal",
|
||||
"openai-codex": "OpenAI Codex",
|
||||
"qwen-oauth": "Qwen OAuth",
|
||||
"copilot-acp": "GitHub Copilot ACP",
|
||||
"copilot": "GitHub Copilot",
|
||||
"anthropic": "Anthropic",
|
||||
@@ -947,6 +948,7 @@ def select_provider_and_model(args=None):
|
||||
("openrouter", "OpenRouter (100+ models, pay-per-use)"),
|
||||
("anthropic", "Anthropic (Claude models — API key or Claude Code)"),
|
||||
("openai-codex", "OpenAI Codex"),
|
||||
("qwen-oauth", "Qwen OAuth (reuses local Qwen CLI login)"),
|
||||
("copilot", "GitHub Copilot (uses GITHUB_TOKEN or gh auth token)"),
|
||||
("huggingface", "Hugging Face Inference Providers (20+ open models)"),
|
||||
]
|
||||
@@ -1043,6 +1045,8 @@ def select_provider_and_model(args=None):
|
||||
_model_flow_nous(config, current_model, args=args)
|
||||
elif selected_provider == "openai-codex":
|
||||
_model_flow_openai_codex(config, current_model)
|
||||
elif selected_provider == "qwen-oauth":
|
||||
_model_flow_qwen_oauth(config, current_model)
|
||||
elif selected_provider == "copilot-acp":
|
||||
_model_flow_copilot_acp(config, current_model)
|
||||
elif selected_provider == "copilot":
|
||||
@@ -1359,6 +1363,56 @@ def _model_flow_openai_codex(config, current_model=""):
|
||||
|
||||
|
||||
|
||||
_DEFAULT_QWEN_PORTAL_MODELS = [
|
||||
"qwen3-coder-plus",
|
||||
"qwen3-coder",
|
||||
]
|
||||
|
||||
|
||||
def _model_flow_qwen_oauth(_config, current_model=""):
|
||||
"""Qwen OAuth provider: reuse local Qwen CLI login, then pick model."""
|
||||
from hermes_cli.auth import (
|
||||
get_qwen_auth_status,
|
||||
resolve_qwen_runtime_credentials,
|
||||
_prompt_model_selection,
|
||||
_save_model_choice,
|
||||
_update_config_for_provider,
|
||||
DEFAULT_QWEN_BASE_URL,
|
||||
)
|
||||
from hermes_cli.models import fetch_api_models
|
||||
|
||||
status = get_qwen_auth_status()
|
||||
if not status.get("logged_in"):
|
||||
print("Not logged into Qwen CLI OAuth.")
|
||||
print("Run: qwen auth qwen-oauth")
|
||||
auth_file = status.get("auth_file")
|
||||
if auth_file:
|
||||
print(f"Expected credentials file: {auth_file}")
|
||||
if status.get("error"):
|
||||
print(f"Error: {status.get('error')}")
|
||||
return
|
||||
|
||||
# Try live model discovery, fall back to curated list.
|
||||
models = None
|
||||
try:
|
||||
creds = resolve_qwen_runtime_credentials(refresh_if_expiring=True)
|
||||
models = fetch_api_models(creds["api_key"], creds["base_url"])
|
||||
except Exception:
|
||||
pass
|
||||
if not models:
|
||||
models = list(_DEFAULT_QWEN_PORTAL_MODELS)
|
||||
|
||||
default = current_model or (models[0] if models else "qwen3-coder-plus")
|
||||
selected = _prompt_model_selection(models, current_model=default)
|
||||
if selected:
|
||||
_save_model_choice(selected)
|
||||
_update_config_for_provider("qwen-oauth", DEFAULT_QWEN_BASE_URL)
|
||||
print(f"Default model set to: {selected} (via Qwen OAuth)")
|
||||
else:
|
||||
print("No change.")
|
||||
|
||||
|
||||
|
||||
def _model_flow_custom(config):
|
||||
"""Custom endpoint: collect URL, API key, and model name.
|
||||
|
||||
|
||||
@@ -84,6 +84,7 @@ _PASSTHROUGH_PROVIDERS: frozenset[str] = frozenset({
|
||||
"minimax",
|
||||
"minimax-cn",
|
||||
"alibaba",
|
||||
"qwen-oauth",
|
||||
"huggingface",
|
||||
"openai-codex",
|
||||
"custom",
|
||||
|
||||
@@ -791,12 +791,12 @@ def list_authenticated_providers(
|
||||
if overlay.auth_type in ("oauth_device_code", "oauth_external", "external_process"):
|
||||
# These use auth stores, not env vars — check for auth.json entries
|
||||
try:
|
||||
from hermes_cli.auth import _read_auth_store
|
||||
store = _read_auth_store()
|
||||
if store and pid in store:
|
||||
from hermes_cli.auth import _load_auth_store
|
||||
store = _load_auth_store()
|
||||
if store and (pid in store.get("providers", {}) or pid in store.get("credential_pool", {})):
|
||||
has_creds = True
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.debug("Auth store check failed for %s: %s", pid, exc)
|
||||
if not has_creds:
|
||||
continue
|
||||
|
||||
|
||||
+15
-8
@@ -144,18 +144,22 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"kimi-k2-0905-preview",
|
||||
],
|
||||
"minimax": [
|
||||
"MiniMax-M2.7",
|
||||
"MiniMax-M2.7-highspeed",
|
||||
"MiniMax-M1",
|
||||
"MiniMax-M1-40k",
|
||||
"MiniMax-M1-80k",
|
||||
"MiniMax-M1-128k",
|
||||
"MiniMax-M1-256k",
|
||||
"MiniMax-M2.5",
|
||||
"MiniMax-M2.5-highspeed",
|
||||
"MiniMax-M2.1",
|
||||
"MiniMax-M2.7",
|
||||
],
|
||||
"minimax-cn": [
|
||||
"MiniMax-M2.7",
|
||||
"MiniMax-M2.7-highspeed",
|
||||
"MiniMax-M1",
|
||||
"MiniMax-M1-40k",
|
||||
"MiniMax-M1-80k",
|
||||
"MiniMax-M1-128k",
|
||||
"MiniMax-M1-256k",
|
||||
"MiniMax-M2.5",
|
||||
"MiniMax-M2.5-highspeed",
|
||||
"MiniMax-M2.1",
|
||||
"MiniMax-M2.7",
|
||||
],
|
||||
"anthropic": [
|
||||
"claude-opus-4-6",
|
||||
@@ -479,6 +483,7 @@ _PROVIDER_LABELS = {
|
||||
"ai-gateway": "AI Gateway",
|
||||
"kilocode": "Kilo Code",
|
||||
"alibaba": "Alibaba Cloud (DashScope)",
|
||||
"qwen-oauth": "Qwen OAuth (Portal)",
|
||||
"huggingface": "Hugging Face",
|
||||
"custom": "Custom endpoint",
|
||||
}
|
||||
@@ -518,6 +523,7 @@ _PROVIDER_ALIASES = {
|
||||
"aliyun": "alibaba",
|
||||
"qwen": "alibaba",
|
||||
"alibaba-cloud": "alibaba",
|
||||
"qwen-portal": "qwen-oauth",
|
||||
"hf": "huggingface",
|
||||
"hugging-face": "huggingface",
|
||||
"huggingface-hub": "huggingface",
|
||||
@@ -763,6 +769,7 @@ def list_available_providers() -> list[dict[str, str]]:
|
||||
"openrouter", "nous", "openai-codex", "copilot", "copilot-acp",
|
||||
"gemini", "huggingface",
|
||||
"zai", "kimi-coding", "minimax", "minimax-cn", "kilocode", "anthropic", "alibaba",
|
||||
"qwen-oauth",
|
||||
"opencode-zen", "opencode-go",
|
||||
"ai-gateway", "deepseek", "custom",
|
||||
]
|
||||
|
||||
@@ -61,6 +61,8 @@ VALID_HOOKS: Set[str] = {
|
||||
"post_api_request",
|
||||
"on_session_start",
|
||||
"on_session_end",
|
||||
"on_session_finalize",
|
||||
"on_session_reset",
|
||||
}
|
||||
|
||||
ENTRY_POINTS_GROUP = "hermes_agent.plugins"
|
||||
|
||||
@@ -58,6 +58,12 @@ HERMES_OVERLAYS: Dict[str, HermesOverlay] = {
|
||||
auth_type="oauth_external",
|
||||
base_url_override="https://chatgpt.com/backend-api/codex",
|
||||
),
|
||||
"qwen-oauth": HermesOverlay(
|
||||
transport="openai_chat",
|
||||
auth_type="oauth_external",
|
||||
base_url_override="https://portal.qwen.ai/v1",
|
||||
base_url_env_var="HERMES_QWEN_BASE_URL",
|
||||
),
|
||||
"copilot-acp": HermesOverlay(
|
||||
transport="codex_responses",
|
||||
auth_type="external_process",
|
||||
|
||||
@@ -14,11 +14,13 @@ from agent.credential_pool import CredentialPool, PooledCredential, get_custom_p
|
||||
from hermes_cli.auth import (
|
||||
AuthError,
|
||||
DEFAULT_CODEX_BASE_URL,
|
||||
DEFAULT_QWEN_BASE_URL,
|
||||
PROVIDER_REGISTRY,
|
||||
format_auth_error,
|
||||
resolve_provider,
|
||||
resolve_nous_runtime_credentials,
|
||||
resolve_codex_runtime_credentials,
|
||||
resolve_qwen_runtime_credentials,
|
||||
resolve_api_key_provider_credentials,
|
||||
resolve_external_process_provider_credentials,
|
||||
has_usable_secret,
|
||||
@@ -148,6 +150,9 @@ def _resolve_runtime_from_pool_entry(
|
||||
if provider == "openai-codex":
|
||||
api_mode = "codex_responses"
|
||||
base_url = base_url or DEFAULT_CODEX_BASE_URL
|
||||
elif provider == "qwen-oauth":
|
||||
api_mode = "chat_completions"
|
||||
base_url = base_url or DEFAULT_QWEN_BASE_URL
|
||||
elif provider == "anthropic":
|
||||
api_mode = "anthropic_messages"
|
||||
cfg_provider = str(model_cfg.get("provider") or "").strip().lower()
|
||||
@@ -163,6 +168,16 @@ def _resolve_runtime_from_pool_entry(
|
||||
api_mode = _copilot_runtime_api_mode(model_cfg, getattr(entry, "runtime_api_key", ""))
|
||||
else:
|
||||
configured_provider = str(model_cfg.get("provider") or "").strip().lower()
|
||||
# Honour model.base_url from config.yaml when the configured provider
|
||||
# matches this provider — same pattern as the Anthropic branch above.
|
||||
# Only override when the pool entry has no explicit base_url (i.e. it
|
||||
# fell back to the hardcoded default). Env var overrides win (#6039).
|
||||
pconfig = PROVIDER_REGISTRY.get(provider)
|
||||
pool_url_is_default = pconfig and base_url.rstrip("/") == pconfig.inference_base_url.rstrip("/")
|
||||
if configured_provider == provider and pool_url_is_default:
|
||||
cfg_base_url = str(model_cfg.get("base_url") or "").strip().rstrip("/")
|
||||
if cfg_base_url:
|
||||
base_url = cfg_base_url
|
||||
configured_mode = _parse_api_mode(model_cfg.get("api_mode"))
|
||||
if configured_mode and _provider_supports_explicit_api_mode(provider, configured_provider):
|
||||
api_mode = configured_mode
|
||||
@@ -681,6 +696,24 @@ def resolve_runtime_provider(
|
||||
logger.info("Auto-detected Codex provider but credentials failed; "
|
||||
"falling through to next provider.")
|
||||
|
||||
if provider == "qwen-oauth":
|
||||
try:
|
||||
creds = resolve_qwen_runtime_credentials()
|
||||
return {
|
||||
"provider": "qwen-oauth",
|
||||
"api_mode": "chat_completions",
|
||||
"base_url": creds.get("base_url", "").rstrip("/"),
|
||||
"api_key": creds.get("api_key", ""),
|
||||
"source": creds.get("source", "qwen-cli"),
|
||||
"expires_at_ms": creds.get("expires_at_ms"),
|
||||
"requested_provider": requested_provider,
|
||||
}
|
||||
except AuthError:
|
||||
if requested_provider != "auto":
|
||||
raise
|
||||
logger.info("Qwen OAuth credentials failed; "
|
||||
"falling through to next provider.")
|
||||
|
||||
if provider == "copilot-acp":
|
||||
creds = resolve_external_process_provider_credentials(provider)
|
||||
return {
|
||||
@@ -724,7 +757,15 @@ def resolve_runtime_provider(
|
||||
pconfig = PROVIDER_REGISTRY.get(provider)
|
||||
if pconfig and pconfig.auth_type == "api_key":
|
||||
creds = resolve_api_key_provider_credentials(provider)
|
||||
base_url = creds.get("base_url", "").rstrip("/")
|
||||
# Honour model.base_url from config.yaml when the configured provider
|
||||
# matches this provider — mirrors the Anthropic path above. Without
|
||||
# this, users who set model.base_url to e.g. api.minimaxi.com/anthropic
|
||||
# (China endpoint) still get the hardcoded api.minimax.io default (#6039).
|
||||
cfg_provider = str(model_cfg.get("provider") or "").strip().lower()
|
||||
cfg_base_url = ""
|
||||
if cfg_provider == provider:
|
||||
cfg_base_url = (model_cfg.get("base_url") or "").strip().rstrip("/")
|
||||
base_url = cfg_base_url or creds.get("base_url", "").rstrip("/")
|
||||
api_mode = "chat_completions"
|
||||
if provider == "copilot":
|
||||
api_mode = _copilot_runtime_api_mode(model_cfg, creds.get("api_key", ""))
|
||||
|
||||
+2
-2
@@ -105,8 +105,8 @@ _DEFAULT_PROVIDER_MODELS = {
|
||||
],
|
||||
"zai": ["glm-5", "glm-4.7", "glm-4.5", "glm-4.5-flash"],
|
||||
"kimi-coding": ["kimi-k2.5", "kimi-k2-thinking", "kimi-k2-turbo-preview"],
|
||||
"minimax": ["MiniMax-M2.7", "MiniMax-M2.7-highspeed", "MiniMax-M2.5", "MiniMax-M2.5-highspeed", "MiniMax-M2.1"],
|
||||
"minimax-cn": ["MiniMax-M2.7", "MiniMax-M2.7-highspeed", "MiniMax-M2.5", "MiniMax-M2.5-highspeed", "MiniMax-M2.1"],
|
||||
"minimax": ["MiniMax-M1", "MiniMax-M1-40k", "MiniMax-M1-80k", "MiniMax-M1-128k", "MiniMax-M1-256k", "MiniMax-M2.5", "MiniMax-M2.7"],
|
||||
"minimax-cn": ["MiniMax-M1", "MiniMax-M1-40k", "MiniMax-M1-80k", "MiniMax-M1-128k", "MiniMax-M1-256k", "MiniMax-M2.5", "MiniMax-M2.7"],
|
||||
"ai-gateway": ["anthropic/claude-opus-4.6", "anthropic/claude-sonnet-4.6", "openai/gpt-5", "google/gemini-3-flash"],
|
||||
"kilocode": ["anthropic/claude-opus-4.6", "anthropic/claude-sonnet-4.6", "openai/gpt-5.4", "google/gemini-3-pro-preview", "google/gemini-3-flash-preview"],
|
||||
"opencode-zen": ["gpt-5.4", "gpt-5.3-codex", "claude-sonnet-4-6", "gemini-3-flash", "glm-5", "kimi-k2.5", "minimax-m2.7"],
|
||||
|
||||
+18
-1
@@ -153,12 +153,14 @@ def show_status(args):
|
||||
print(color("◆ Auth Providers", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
try:
|
||||
from hermes_cli.auth import get_nous_auth_status, get_codex_auth_status
|
||||
from hermes_cli.auth import get_nous_auth_status, get_codex_auth_status, get_qwen_auth_status
|
||||
nous_status = get_nous_auth_status()
|
||||
codex_status = get_codex_auth_status()
|
||||
qwen_status = get_qwen_auth_status()
|
||||
except Exception:
|
||||
nous_status = {}
|
||||
codex_status = {}
|
||||
qwen_status = {}
|
||||
|
||||
nous_logged_in = bool(nous_status.get("logged_in"))
|
||||
print(
|
||||
@@ -189,6 +191,21 @@ def show_status(args):
|
||||
if codex_status.get("error") and not codex_logged_in:
|
||||
print(f" Error: {codex_status.get('error')}")
|
||||
|
||||
qwen_logged_in = bool(qwen_status.get("logged_in"))
|
||||
print(
|
||||
f" {'Qwen OAuth':<12} {check_mark(qwen_logged_in)} "
|
||||
f"{'logged in' if qwen_logged_in else 'not logged in (run: qwen auth qwen-oauth)'}"
|
||||
)
|
||||
qwen_auth_file = qwen_status.get("auth_file")
|
||||
if qwen_auth_file:
|
||||
print(f" Auth file: {qwen_auth_file}")
|
||||
qwen_exp = qwen_status.get("expires_at_ms")
|
||||
if qwen_exp:
|
||||
from datetime import datetime, timezone
|
||||
print(f" Access exp: {datetime.fromtimestamp(int(qwen_exp) / 1000, tz=timezone.utc).isoformat()}")
|
||||
if qwen_status.get("error") and not qwen_logged_in:
|
||||
print(f" Error: {qwen_status.get('error')}")
|
||||
|
||||
# =========================================================================
|
||||
# Nous Subscription Features
|
||||
# =========================================================================
|
||||
|
||||
+10
-2
@@ -464,7 +464,11 @@
|
||||
addToSystemPackages = mkOption {
|
||||
type = types.bool;
|
||||
default = false;
|
||||
description = "Add hermes CLI to environment.systemPackages.";
|
||||
description = ''
|
||||
Add the hermes CLI to environment.systemPackages and export
|
||||
HERMES_HOME system-wide (via environment.variables) so interactive
|
||||
shells share state with the gateway service.
|
||||
'';
|
||||
};
|
||||
|
||||
# ── OCI Container (opt-in) ──────────────────────────────────────────
|
||||
@@ -545,8 +549,12 @@
|
||||
})
|
||||
|
||||
# ── Host CLI ──────────────────────────────────────────────────────
|
||||
# Add the hermes CLI to system PATH and export HERMES_HOME system-wide
|
||||
# so interactive shells share state (sessions, skills, cron) with the
|
||||
# gateway service instead of creating a separate ~/.hermes/.
|
||||
(lib.mkIf cfg.addToSystemPackages {
|
||||
environment.systemPackages = [ cfg.package ];
|
||||
environment.variables.HERMES_HOME = "${cfg.stateDir}/.hermes";
|
||||
})
|
||||
|
||||
# ── Directories ───────────────────────────────────────────────────
|
||||
@@ -601,7 +609,7 @@
|
||||
# so this is the single source of truth for both native and container mode.
|
||||
${lib.optionalString (cfg.environment != {} || cfg.environmentFiles != []) ''
|
||||
ENV_FILE="${cfg.stateDir}/.hermes/.env"
|
||||
install -o ${cfg.user} -g ${cfg.group} -m 0600 /dev/null "$ENV_FILE"
|
||||
install -o ${cfg.user} -g ${cfg.group} -m 0640 /dev/null "$ENV_FILE"
|
||||
cat > "$ENV_FILE" <<'HERMES_NIX_ENV_EOF'
|
||||
${envFileContent}
|
||||
HERMES_NIX_ENV_EOF
|
||||
|
||||
@@ -6,14 +6,68 @@
|
||||
uv2nix,
|
||||
pyproject-nix,
|
||||
pyproject-build-systems,
|
||||
stdenv,
|
||||
}:
|
||||
let
|
||||
workspace = uv2nix.lib.workspace.loadWorkspace { workspaceRoot = ./..; };
|
||||
hacks = callPackage pyproject-nix.build.hacks { };
|
||||
|
||||
overlay = workspace.mkPyprojectOverlay {
|
||||
sourcePreference = "wheel";
|
||||
};
|
||||
|
||||
isAarch64Darwin = stdenv.hostPlatform.system == "aarch64-darwin";
|
||||
|
||||
# Keep the workspace locked through uv2nix, but supply the local voice stack
|
||||
# from nixpkgs so wheel-only transitive artifacts do not break evaluation.
|
||||
mkPrebuiltPassthru = dependencies: {
|
||||
inherit dependencies;
|
||||
optional-dependencies = { };
|
||||
dependency-groups = { };
|
||||
};
|
||||
|
||||
mkPrebuiltOverride = final: from: dependencies:
|
||||
hacks.nixpkgsPrebuilt {
|
||||
inherit from;
|
||||
prev = {
|
||||
nativeBuildInputs = [ final.pyprojectHook ];
|
||||
passthru = mkPrebuiltPassthru dependencies;
|
||||
};
|
||||
};
|
||||
|
||||
pythonPackageOverrides = final: _prev:
|
||||
if isAarch64Darwin then {
|
||||
numpy = mkPrebuiltOverride final python311.pkgs.numpy { };
|
||||
|
||||
av = mkPrebuiltOverride final python311.pkgs.av { };
|
||||
|
||||
humanfriendly = mkPrebuiltOverride final python311.pkgs.humanfriendly { };
|
||||
|
||||
coloredlogs = mkPrebuiltOverride final python311.pkgs.coloredlogs {
|
||||
humanfriendly = [ ];
|
||||
};
|
||||
|
||||
onnxruntime = mkPrebuiltOverride final python311.pkgs.onnxruntime {
|
||||
coloredlogs = [ ];
|
||||
numpy = [ ];
|
||||
packaging = [ ];
|
||||
};
|
||||
|
||||
ctranslate2 = mkPrebuiltOverride final python311.pkgs.ctranslate2 {
|
||||
numpy = [ ];
|
||||
pyyaml = [ ];
|
||||
};
|
||||
|
||||
faster-whisper = mkPrebuiltOverride final python311.pkgs.faster-whisper {
|
||||
av = [ ];
|
||||
ctranslate2 = [ ];
|
||||
huggingface-hub = [ ];
|
||||
onnxruntime = [ ];
|
||||
tokenizers = [ ];
|
||||
tqdm = [ ];
|
||||
};
|
||||
} else {};
|
||||
|
||||
pythonSet =
|
||||
(callPackage pyproject-nix.build.packages {
|
||||
python = python311;
|
||||
@@ -21,6 +75,7 @@ let
|
||||
(lib.composeManyExtensions [
|
||||
pyproject-build-systems.overlays.default
|
||||
overlay
|
||||
pythonPackageOverrides
|
||||
]);
|
||||
in
|
||||
pythonSet.mkVirtualEnv "hermes-agent-env" {
|
||||
|
||||
@@ -73,6 +73,7 @@ Config file: `~/.hermes/hindsight/config.json`
|
||||
|-----|---------|-------------|
|
||||
| `llm_provider` | `openai` | LLM provider: `openai`, `anthropic`, `gemini`, `groq`, `minimax`, `ollama` |
|
||||
| `llm_model` | per-provider | Model name (e.g. `gpt-4o-mini`, `openai/gpt-oss-120b`) |
|
||||
| `llm_base_url` | — | LLM Base URL override (e.g. `https://openrouter.ai/api/v1`) |
|
||||
|
||||
The LLM API key is stored in `~/.hermes/.env` as `HINDSIGHT_LLM_API_KEY`.
|
||||
|
||||
@@ -92,6 +93,7 @@ Available in `hybrid` and `tools` memory modes:
|
||||
|----------|-------------|
|
||||
| `HINDSIGHT_API_KEY` | API key for Hindsight Cloud |
|
||||
| `HINDSIGHT_LLM_API_KEY` | LLM API key for local mode |
|
||||
| `HINDSIGHT_API_LLM_BASE_URL` | LLM Base URL for local mode (e.g. OpenRouter) |
|
||||
| `HINDSIGHT_API_URL` | Override API endpoint |
|
||||
| `HINDSIGHT_BANK_ID` | Override bank name |
|
||||
| `HINDSIGHT_BUDGET` | Override recall budget |
|
||||
|
||||
@@ -23,6 +23,8 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
@@ -142,7 +144,6 @@ def _load_config() -> dict:
|
||||
3. Environment variables
|
||||
"""
|
||||
from pathlib import Path
|
||||
from hermes_constants import get_hermes_home
|
||||
|
||||
# Profile-scoped path (preferred)
|
||||
profile_path = get_hermes_home() / "hindsight" / "config.json"
|
||||
@@ -234,6 +235,7 @@ class HindsightMemoryProvider(MemoryProvider):
|
||||
{"key": "api_key", "description": "Hindsight Cloud API key", "secret": True, "env_var": "HINDSIGHT_API_KEY", "url": "https://ui.hindsight.vectorize.io", "when": {"mode": "cloud"}},
|
||||
{"key": "llm_provider", "description": "LLM provider for local mode", "default": "openai", "choices": ["openai", "anthropic", "gemini", "groq", "minimax", "ollama"], "when": {"mode": "local"}},
|
||||
{"key": "llm_api_key", "description": "LLM API key for local Hindsight", "secret": True, "env_var": "HINDSIGHT_LLM_API_KEY", "when": {"mode": "local"}},
|
||||
{"key": "llm_base_url", "description": "LLM Base URL (e.g. for OpenRouter)", "default": "", "env_var": "HINDSIGHT_API_LLM_BASE_URL", "when": {"mode": "local"}},
|
||||
{"key": "llm_model", "description": "LLM model for local mode", "default": "gpt-4o-mini", "default_from": {"field": "llm_provider", "map": _PROVIDER_DEFAULT_MODELS}, "when": {"mode": "local"}},
|
||||
{"key": "bank_id", "description": "Memory bank name", "default": "hermes"},
|
||||
{"key": "budget", "description": "Recall thoroughness", "default": "mid", "choices": ["low", "mid", "high"]},
|
||||
@@ -250,12 +252,16 @@ class HindsightMemoryProvider(MemoryProvider):
|
||||
# different loop" errors during GC — we handle cleanup in
|
||||
# shutdown() instead.
|
||||
HindsightEmbedded.__del__ = lambda self: None
|
||||
self._client = HindsightEmbedded(
|
||||
kwargs = dict(
|
||||
profile=self._config.get("profile", "hermes"),
|
||||
llm_provider=self._config.get("llm_provider", ""),
|
||||
llm_api_key=self._config.get("llmApiKey") or os.environ.get("HINDSIGHT_LLM_API_KEY", ""),
|
||||
llm_api_key=self._config.get("llm_api_key") or os.environ.get("HINDSIGHT_LLM_API_KEY", ""),
|
||||
llm_model=self._config.get("llm_model", ""),
|
||||
)
|
||||
base_url = self._config.get("llm_base_url") or os.environ.get("HINDSIGHT_API_LLM_BASE_URL", "")
|
||||
if base_url:
|
||||
kwargs["llm_base_url"] = base_url
|
||||
self._client = HindsightEmbedded(**kwargs)
|
||||
else:
|
||||
from hindsight_client import Hindsight
|
||||
kwargs = {"base_url": self._api_url, "timeout": 30.0}
|
||||
@@ -310,9 +316,10 @@ class HindsightMemoryProvider(MemoryProvider):
|
||||
# If the config changed and the daemon is running, stop it.
|
||||
from pathlib import Path as _Path
|
||||
profile_env = _Path.home() / ".hindsight" / "profiles" / f"{profile}.env"
|
||||
current_key = self._config.get("llmApiKey") or os.environ.get("HINDSIGHT_LLM_API_KEY", "")
|
||||
current_key = self._config.get("llm_api_key") or os.environ.get("HINDSIGHT_LLM_API_KEY", "")
|
||||
current_provider = self._config.get("llm_provider", "")
|
||||
current_model = self._config.get("llm_model", "")
|
||||
current_base_url = self._config.get("llm_base_url") or os.environ.get("HINDSIGHT_API_LLM_BASE_URL", "")
|
||||
|
||||
# Read saved profile config
|
||||
saved = {}
|
||||
@@ -325,18 +332,22 @@ class HindsightMemoryProvider(MemoryProvider):
|
||||
config_changed = (
|
||||
saved.get("HINDSIGHT_API_LLM_PROVIDER") != current_provider or
|
||||
saved.get("HINDSIGHT_API_LLM_MODEL") != current_model or
|
||||
saved.get("HINDSIGHT_API_LLM_API_KEY") != current_key
|
||||
saved.get("HINDSIGHT_API_LLM_API_KEY") != current_key or
|
||||
saved.get("HINDSIGHT_API_LLM_BASE_URL", "") != current_base_url
|
||||
)
|
||||
|
||||
if config_changed:
|
||||
# Write updated profile .env
|
||||
profile_env.parent.mkdir(parents=True, exist_ok=True)
|
||||
profile_env.write_text(
|
||||
env_lines = (
|
||||
f"HINDSIGHT_API_LLM_PROVIDER={current_provider}\n"
|
||||
f"HINDSIGHT_API_LLM_API_KEY={current_key}\n"
|
||||
f"HINDSIGHT_API_LLM_MODEL={current_model}\n"
|
||||
f"HINDSIGHT_API_LOG_LEVEL=info\n"
|
||||
)
|
||||
if current_base_url:
|
||||
env_lines += f"HINDSIGHT_API_LLM_BASE_URL={current_base_url}\n"
|
||||
profile_env.write_text(env_lines)
|
||||
if client._manager.is_running(profile):
|
||||
with open(log_path, "a") as f:
|
||||
f.write("\n=== Config changed, restarting daemon ===\n")
|
||||
|
||||
+3
-1
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "hermes-agent"
|
||||
version = "0.7.0"
|
||||
version = "0.8.0"
|
||||
description = "The self-improving AI agent — creates skills from experience, improves them during use, and runs anywhere"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
@@ -62,6 +62,7 @@ mcp = ["mcp>=1.2.0,<2"]
|
||||
homeassistant = ["aiohttp>=3.9.0,<4"]
|
||||
sms = ["aiohttp>=3.9.0,<4"]
|
||||
acp = ["agent-client-protocol>=0.9.0,<1.0"]
|
||||
mistral = ["mistralai>=2.3.0,<3"]
|
||||
dingtalk = ["dingtalk-stream>=0.1.0,<1"]
|
||||
feishu = ["lark-oapi>=1.5.3,<2"]
|
||||
rl = [
|
||||
@@ -94,6 +95,7 @@ all = [
|
||||
"hermes-agent[voice]",
|
||||
"hermes-agent[dingtalk]",
|
||||
"hermes-agent[feishu]",
|
||||
"hermes-agent[mistral]",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
+207
-65
@@ -66,7 +66,8 @@ from model_tools import (
|
||||
handle_function_call,
|
||||
check_toolset_requirements,
|
||||
)
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
from tools.terminal_tool import cleanup_vm, get_active_env
|
||||
from tools.tool_result_storage import maybe_persist_tool_result, enforce_turn_budget
|
||||
from tools.interrupt import set_interrupt as _set_interrupt
|
||||
from tools.browser_tool import cleanup_browser
|
||||
|
||||
@@ -75,6 +76,7 @@ from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
# Agent internals extracted to agent/ package for modularity
|
||||
from agent.memory_manager import build_memory_context_block
|
||||
from agent.retry_utils import jittered_backoff
|
||||
from agent.prompt_builder import (
|
||||
DEFAULT_AGENT_IDENTITY, PLATFORM_HINTS,
|
||||
MEMORY_GUIDANCE, SESSION_SEARCH_GUIDANCE, SKILLS_GUIDANCE,
|
||||
@@ -85,6 +87,7 @@ from agent.model_metadata import (
|
||||
estimate_tokens_rough, estimate_messages_tokens_rough, estimate_request_tokens_rough,
|
||||
get_next_probe_tier, parse_context_limit_from_error,
|
||||
save_context_length, is_local_endpoint,
|
||||
query_ollama_num_ctx,
|
||||
)
|
||||
from agent.context_compressor import ContextCompressor
|
||||
from agent.subdirectory_hints import SubdirectoryHintTracker
|
||||
@@ -409,62 +412,26 @@ def _strip_budget_warnings_from_history(messages: list) -> None:
|
||||
# Large tool result handler — save oversized output to temp file
|
||||
# =========================================================================
|
||||
|
||||
# Threshold at which tool results are saved to a file instead of kept inline.
|
||||
# 100K chars ≈ 25K tokens — generous for any reasonable output but prevents
|
||||
# catastrophic context explosions.
|
||||
_LARGE_RESULT_CHARS = 100_000
|
||||
|
||||
# How many characters of the original result to include as an inline preview
|
||||
# so the model has immediate context about what the tool returned.
|
||||
_LARGE_RESULT_PREVIEW_CHARS = 1_500
|
||||
# =========================================================================
|
||||
# Qwen Portal headers — mimics QwenCode CLI for portal.qwen.ai compatibility.
|
||||
# Extracted as a module-level helper so both __init__ and
|
||||
# _apply_client_headers_for_base_url can share it.
|
||||
# =========================================================================
|
||||
_QWEN_CODE_VERSION = "0.14.1"
|
||||
|
||||
|
||||
def _save_oversized_tool_result(function_name: str, function_result: str) -> str:
|
||||
"""Replace oversized tool results with a file reference + preview.
|
||||
def _qwen_portal_headers() -> dict:
|
||||
"""Return default HTTP headers required by Qwen Portal API."""
|
||||
import platform as _plat
|
||||
|
||||
When a tool returns more than ``_LARGE_RESULT_CHARS`` characters, the full
|
||||
content is written to a temporary file under ``HERMES_HOME/cache/tool_responses/``
|
||||
and the result sent to the model is replaced with:
|
||||
• a brief head preview (first ``_LARGE_RESULT_PREVIEW_CHARS`` chars)
|
||||
• the file path so the model can use ``read_file`` / ``search_files``
|
||||
|
||||
Falls back to destructive truncation if the file write fails.
|
||||
"""
|
||||
original_len = len(function_result)
|
||||
if original_len <= _LARGE_RESULT_CHARS:
|
||||
return function_result
|
||||
|
||||
# Build the target directory
|
||||
try:
|
||||
response_dir = os.path.join(get_hermes_home(), "cache", "tool_responses")
|
||||
os.makedirs(response_dir, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
# Sanitize tool name for use in filename
|
||||
safe_name = re.sub(r"[^\w\-]", "_", function_name)[:40]
|
||||
filename = f"{safe_name}_{timestamp}.txt"
|
||||
filepath = os.path.join(response_dir, filename)
|
||||
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write(function_result)
|
||||
|
||||
preview = function_result[:_LARGE_RESULT_PREVIEW_CHARS]
|
||||
return (
|
||||
f"{preview}\n\n"
|
||||
f"[Large tool response: {original_len:,} characters total — "
|
||||
f"only the first {_LARGE_RESULT_PREVIEW_CHARS:,} shown above. "
|
||||
f"Full output saved to: {filepath}\n"
|
||||
f"Use read_file or search_files on that path to access the rest.]"
|
||||
)
|
||||
except Exception as exc:
|
||||
# Fall back to destructive truncation if file write fails
|
||||
logger.warning("Failed to save large tool result to file: %s", exc)
|
||||
return (
|
||||
function_result[:_LARGE_RESULT_CHARS]
|
||||
+ f"\n\n[Truncated: tool response was {original_len:,} chars, "
|
||||
f"exceeding the {_LARGE_RESULT_CHARS:,} char limit. "
|
||||
f"File save failed: {exc}]"
|
||||
)
|
||||
_ua = f"QwenCode/{_QWEN_CODE_VERSION} ({_plat.system().lower()}; {_plat.machine()})"
|
||||
return {
|
||||
"User-Agent": _ua,
|
||||
"X-DashScope-CacheControl": "enable",
|
||||
"X-DashScope-UserAgent": _ua,
|
||||
"X-DashScope-AuthType": "qwen-oauth",
|
||||
}
|
||||
|
||||
|
||||
class AIAgent:
|
||||
@@ -810,6 +777,8 @@ class AIAgent:
|
||||
client_kwargs["default_headers"] = {
|
||||
"User-Agent": "KimiCLI/1.3",
|
||||
}
|
||||
elif "portal.qwen.ai" in effective_base.lower():
|
||||
client_kwargs["default_headers"] = _qwen_portal_headers()
|
||||
else:
|
||||
# No explicit creds — use the centralized provider router
|
||||
from agent.auxiliary_client import resolve_provider_client
|
||||
@@ -1216,6 +1185,33 @@ class AIAgent:
|
||||
self.session_cost_status = "unknown"
|
||||
self.session_cost_source = "none"
|
||||
|
||||
# ── Ollama num_ctx injection ──
|
||||
# Ollama defaults to 2048 context regardless of the model's capabilities.
|
||||
# When running against an Ollama server, detect the model's max context
|
||||
# and pass num_ctx on every chat request so the full window is used.
|
||||
# User override: set model.ollama_num_ctx in config.yaml to cap VRAM use.
|
||||
self._ollama_num_ctx: int | None = None
|
||||
_ollama_num_ctx_override = None
|
||||
if isinstance(_model_cfg, dict):
|
||||
_ollama_num_ctx_override = _model_cfg.get("ollama_num_ctx")
|
||||
if _ollama_num_ctx_override is not None:
|
||||
try:
|
||||
self._ollama_num_ctx = int(_ollama_num_ctx_override)
|
||||
except (TypeError, ValueError):
|
||||
logger.debug("Invalid ollama_num_ctx config value: %r", _ollama_num_ctx_override)
|
||||
if self._ollama_num_ctx is None and self.base_url and is_local_endpoint(self.base_url):
|
||||
try:
|
||||
_detected = query_ollama_num_ctx(self.model, self.base_url)
|
||||
if _detected and _detected > 0:
|
||||
self._ollama_num_ctx = _detected
|
||||
except Exception as exc:
|
||||
logger.debug("Ollama num_ctx detection failed: %s", exc)
|
||||
if self._ollama_num_ctx and not self.quiet_mode:
|
||||
logger.info(
|
||||
"Ollama num_ctx: will request %d tokens (model max from /api/show)",
|
||||
self._ollama_num_ctx,
|
||||
)
|
||||
|
||||
if not self.quiet_mode:
|
||||
if compression_enabled:
|
||||
print(f"📊 Context limit: {self.context_compressor.context_length:,} tokens (compress at {int(compression_threshold*100)}% = {self.context_compressor.threshold_tokens:,})")
|
||||
@@ -4107,6 +4103,8 @@ class AIAgent:
|
||||
self._client_kwargs["default_headers"] = copilot_default_headers()
|
||||
elif "api.kimi.com" in normalized:
|
||||
self._client_kwargs["default_headers"] = {"User-Agent": "KimiCLI/1.3"}
|
||||
elif "portal.qwen.ai" in normalized:
|
||||
self._client_kwargs["default_headers"] = _qwen_portal_headers()
|
||||
else:
|
||||
self._client_kwargs.pop("default_headers", None)
|
||||
|
||||
@@ -4897,7 +4895,7 @@ class AIAgent:
|
||||
effective_key = (fb_client.api_key or resolve_anthropic_token() or "") if fb_provider == "anthropic" else (fb_client.api_key or "")
|
||||
self.api_key = effective_key
|
||||
self._anthropic_api_key = effective_key
|
||||
self._anthropic_base_url = getattr(fb_client, "base_url", None)
|
||||
self._anthropic_base_url = fb_base_url
|
||||
self._anthropic_client = build_anthropic_client(effective_key, self._anthropic_base_url)
|
||||
self._is_anthropic_oauth = _is_oauth_token(effective_key)
|
||||
self.client = None
|
||||
@@ -5253,6 +5251,71 @@ class AIAgent:
|
||||
base = (getattr(self, "base_url", "") or "").lower()
|
||||
return "dashscope" in base or "aliyuncs" in base or "opencode.ai/zen/go" in base
|
||||
|
||||
def _is_qwen_portal(self) -> bool:
|
||||
"""Return True when the base URL targets Qwen Portal."""
|
||||
return "portal.qwen.ai" in self._base_url_lower
|
||||
|
||||
def _qwen_prepare_chat_messages(self, api_messages: list) -> list:
|
||||
prepared = copy.deepcopy(api_messages)
|
||||
if not prepared:
|
||||
return prepared
|
||||
|
||||
for msg in prepared:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
msg["content"] = [{"type": "text", "text": content}]
|
||||
elif isinstance(content, list):
|
||||
# Normalize: convert bare strings to text dicts, keep dicts as-is.
|
||||
# deepcopy already created independent copies, no need for dict().
|
||||
normalized_parts = []
|
||||
for part in content:
|
||||
if isinstance(part, str):
|
||||
normalized_parts.append({"type": "text", "text": part})
|
||||
elif isinstance(part, dict):
|
||||
normalized_parts.append(part)
|
||||
if normalized_parts:
|
||||
msg["content"] = normalized_parts
|
||||
|
||||
# Inject cache_control on the last part of the system message.
|
||||
for msg in prepared:
|
||||
if isinstance(msg, dict) and msg.get("role") == "system":
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list) and content and isinstance(content[-1], dict):
|
||||
content[-1]["cache_control"] = {"type": "ephemeral"}
|
||||
break
|
||||
|
||||
return prepared
|
||||
|
||||
def _qwen_prepare_chat_messages_inplace(self, messages: list) -> None:
|
||||
"""In-place variant — mutates an already-copied message list."""
|
||||
if not messages:
|
||||
return
|
||||
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
msg["content"] = [{"type": "text", "text": content}]
|
||||
elif isinstance(content, list):
|
||||
normalized_parts = []
|
||||
for part in content:
|
||||
if isinstance(part, str):
|
||||
normalized_parts.append({"type": "text", "text": part})
|
||||
elif isinstance(part, dict):
|
||||
normalized_parts.append(part)
|
||||
if normalized_parts:
|
||||
msg["content"] = normalized_parts
|
||||
|
||||
for msg in messages:
|
||||
if isinstance(msg, dict) and msg.get("role") == "system":
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list) and content and isinstance(content[-1], dict):
|
||||
content[-1]["cache_control"] = {"type": "ephemeral"}
|
||||
break
|
||||
|
||||
def _build_api_kwargs(self, api_messages: list) -> dict:
|
||||
"""Build the keyword arguments dict for the active API mode."""
|
||||
if self.api_mode == "anthropic_messages":
|
||||
@@ -5271,6 +5334,7 @@ class AIAgent:
|
||||
is_oauth=self._is_anthropic_oauth,
|
||||
preserve_dots=self._anthropic_preserve_dots(),
|
||||
context_length=ctx_len,
|
||||
base_url=getattr(self, "_anthropic_base_url", None),
|
||||
)
|
||||
|
||||
if self.api_mode == "codex_responses":
|
||||
@@ -5364,6 +5428,17 @@ class AIAgent:
|
||||
tool_call.pop("call_id", None)
|
||||
tool_call.pop("response_item_id", None)
|
||||
|
||||
# Qwen portal: normalize content to list-of-dicts, inject cache_control.
|
||||
# Must run AFTER codex sanitization so we transform the final messages.
|
||||
# If sanitization already deepcopied, reuse that copy (in-place).
|
||||
if self._is_qwen_portal():
|
||||
if sanitized_messages is api_messages:
|
||||
# No sanitization was done — we need our own copy.
|
||||
sanitized_messages = self._qwen_prepare_chat_messages(sanitized_messages)
|
||||
else:
|
||||
# Already a deepcopy — transform in place to avoid a second deepcopy.
|
||||
self._qwen_prepare_chat_messages_inplace(sanitized_messages)
|
||||
|
||||
# GPT-5 and Codex models respond better to 'developer' than 'system'
|
||||
# for instruction-following. Swap the role at the API boundary so
|
||||
# internal message representation stays uniform ("system").
|
||||
@@ -5396,11 +5471,17 @@ class AIAgent:
|
||||
"messages": sanitized_messages,
|
||||
"timeout": float(os.getenv("HERMES_API_TIMEOUT", 1800.0)),
|
||||
}
|
||||
if self._is_qwen_portal():
|
||||
api_kwargs["metadata"] = {
|
||||
"sessionId": self.session_id or "hermes",
|
||||
"promptId": str(uuid.uuid4()),
|
||||
}
|
||||
if self.tools:
|
||||
api_kwargs["tools"] = self.tools
|
||||
|
||||
if self.max_tokens is not None:
|
||||
api_kwargs.update(self._max_tokens_param(self.max_tokens))
|
||||
if not self._is_qwen_portal():
|
||||
api_kwargs.update(self._max_tokens_param(self.max_tokens))
|
||||
elif self._is_openrouter_url() and "claude" in (self.model or "").lower():
|
||||
# OpenRouter translates requests to Anthropic's Messages API,
|
||||
# which requires max_tokens as a mandatory field. When we omit
|
||||
@@ -5456,6 +5537,18 @@ class AIAgent:
|
||||
if _is_nous:
|
||||
extra_body["tags"] = ["product=hermes-agent"]
|
||||
|
||||
# Ollama num_ctx: override the 2048 default so the model actually
|
||||
# uses the context window it was trained for. Passed via the OpenAI
|
||||
# SDK's extra_body → options.num_ctx, which Ollama's OpenAI-compat
|
||||
# endpoint forwards to the runner as --ctx-size.
|
||||
if self._ollama_num_ctx:
|
||||
options = extra_body.get("options", {})
|
||||
options["num_ctx"] = self._ollama_num_ctx
|
||||
extra_body["options"] = options
|
||||
|
||||
if self._is_qwen_portal():
|
||||
extra_body["vl_high_resolution_images"] = True
|
||||
|
||||
if extra_body:
|
||||
api_kwargs["extra_body"] = extra_body
|
||||
|
||||
@@ -6224,15 +6317,17 @@ class AIAgent:
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool complete callback error: {cb_err}")
|
||||
|
||||
# Save oversized results to file instead of destructive truncation
|
||||
function_result = _save_oversized_tool_result(name, function_result)
|
||||
function_result = maybe_persist_tool_result(
|
||||
content=function_result,
|
||||
tool_name=name,
|
||||
tool_use_id=tc.id,
|
||||
env=get_active_env(effective_task_id),
|
||||
)
|
||||
|
||||
# Discover subdirectory context files from tool arguments
|
||||
subdir_hints = self._subdirectory_hints.check_tool_call(name, args)
|
||||
if subdir_hints:
|
||||
function_result += subdir_hints
|
||||
|
||||
# Append tool result message in order
|
||||
tool_msg = {
|
||||
"role": "tool",
|
||||
"content": function_result,
|
||||
@@ -6240,6 +6335,12 @@ class AIAgent:
|
||||
}
|
||||
messages.append(tool_msg)
|
||||
|
||||
# ── Per-turn aggregate budget enforcement ─────────────────────────
|
||||
num_tools = len(parsed_calls)
|
||||
if num_tools > 0:
|
||||
turn_tool_msgs = messages[-num_tools:]
|
||||
enforce_turn_budget(turn_tool_msgs, env=get_active_env(effective_task_id))
|
||||
|
||||
# ── Budget pressure injection ────────────────────────────────────
|
||||
budget_warning = self._get_budget_warning(api_call_count)
|
||||
if budget_warning and messages and messages[-1].get("role") == "tool":
|
||||
@@ -6524,8 +6625,12 @@ class AIAgent:
|
||||
except Exception as cb_err:
|
||||
logging.debug(f"Tool complete callback error: {cb_err}")
|
||||
|
||||
# Save oversized results to file instead of destructive truncation
|
||||
function_result = _save_oversized_tool_result(function_name, function_result)
|
||||
function_result = maybe_persist_tool_result(
|
||||
content=function_result,
|
||||
tool_name=function_name,
|
||||
tool_use_id=tool_call.id,
|
||||
env=get_active_env(effective_task_id),
|
||||
)
|
||||
|
||||
# Discover subdirectory context files from tool arguments
|
||||
subdir_hints = self._subdirectory_hints.check_tool_call(function_name, function_args)
|
||||
@@ -6563,6 +6668,11 @@ class AIAgent:
|
||||
if self.tool_delay > 0 and i < len(assistant_message.tool_calls):
|
||||
time.sleep(self.tool_delay)
|
||||
|
||||
# ── Per-turn aggregate budget enforcement ─────────────────────────
|
||||
num_tools_seq = len(assistant_message.tool_calls)
|
||||
if num_tools_seq > 0:
|
||||
enforce_turn_budget(messages[-num_tools_seq:], env=get_active_env(effective_task_id))
|
||||
|
||||
# ── Budget pressure injection ─────────────────────────────────
|
||||
# After all tool calls in this turn are processed, check if we're
|
||||
# approaching max_iterations. If so, inject a warning into the LAST
|
||||
@@ -7289,6 +7399,7 @@ class AIAgent:
|
||||
codex_auth_retry_attempted=False
|
||||
anthropic_auth_retry_attempted=False
|
||||
nous_auth_retry_attempted=False
|
||||
thinking_sig_retry_attempted = False
|
||||
has_retried_429 = False
|
||||
restart_with_compressed_messages = False
|
||||
restart_with_length_continuation = False
|
||||
@@ -7504,7 +7615,8 @@ class AIAgent:
|
||||
}
|
||||
|
||||
# Longer backoff for rate limiting (likely cause of None choices)
|
||||
wait_time = min(5 * (2 ** (retry_count - 1)), 120) # 5s, 10s, 20s, 40s, 80s, 120s
|
||||
# Jittered exponential: 5s base, 120s cap + random jitter
|
||||
wait_time = jittered_backoff(retry_count, base_delay=5.0, max_delay=120.0)
|
||||
self._vprint(f"{self.log_prefix}⏳ Retrying in {wait_time}s (extended backoff for possible rate limit)...", force=True)
|
||||
logging.warning(f"Invalid API response (retry {retry_count}/{max_retries}): {', '.join(error_details)} | Provider: {provider_name}")
|
||||
|
||||
@@ -7877,8 +7989,38 @@ class AIAgent:
|
||||
print(f"{self.log_prefix} • Check ANTHROPIC_API_KEY in {_dhh}/.env for API keys or legacy token values")
|
||||
print(f"{self.log_prefix} • For API keys: verify at https://console.anthropic.com/settings/keys")
|
||||
print(f"{self.log_prefix} • For Claude Code: run 'claude /login' to refresh, then retry")
|
||||
print(f"{self.log_prefix} • Clear stale keys: hermes config set ANTHROPIC_TOKEN \"\"")
|
||||
print(f"{self.log_prefix} • Legacy cleanup: hermes config set ANTHROPIC_API_KEY \"\"")
|
||||
print(f"{self.log_prefix} • Legacy cleanup: hermes config set ANTHROPIC_TOKEN \"\"")
|
||||
print(f"{self.log_prefix} • Clear stale keys: hermes config set ANTHROPIC_API_KEY \"\"")
|
||||
|
||||
# ── Thinking block signature recovery ─────────────────
|
||||
# Anthropic signs thinking blocks against the full turn
|
||||
# content. Any upstream mutation (context compression,
|
||||
# session truncation, message merging) invalidates the
|
||||
# signature → HTTP 400. Recovery: strip reasoning_details
|
||||
# from all messages so the next retry sends no thinking
|
||||
# blocks at all. One-shot — don't retry infinitely.
|
||||
if (
|
||||
self.api_mode == "anthropic_messages"
|
||||
and status_code == 400
|
||||
and not thinking_sig_retry_attempted
|
||||
):
|
||||
_err_msg_lower = str(api_error).lower()
|
||||
if "signature" in _err_msg_lower and "thinking" in _err_msg_lower:
|
||||
thinking_sig_retry_attempted = True
|
||||
for _m in messages:
|
||||
if isinstance(_m, dict):
|
||||
_m.pop("reasoning_details", None)
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Thinking block signature invalid — "
|
||||
f"stripped all thinking blocks, retrying...",
|
||||
force=True,
|
||||
)
|
||||
logging.warning(
|
||||
"%sThinking block signature recovery: stripped "
|
||||
"reasoning_details from %d messages",
|
||||
self.log_prefix, len(messages),
|
||||
)
|
||||
continue
|
||||
|
||||
retry_count += 1
|
||||
elapsed_time = time.time() - api_start_time
|
||||
@@ -8361,7 +8503,7 @@ class AIAgent:
|
||||
_retry_after = min(int(_ra_raw), 120) # Cap at 2 minutes
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
wait_time = _retry_after if _retry_after else min(2 ** retry_count, 60)
|
||||
wait_time = _retry_after if _retry_after else jittered_backoff(retry_count, base_delay=2.0, max_delay=60.0)
|
||||
if is_rate_limited:
|
||||
self._emit_status(f"⏱️ Rate limit reached. Waiting {wait_time}s before retry (attempt {retry_count + 1}/{max_retries})...")
|
||||
else:
|
||||
|
||||
@@ -1276,6 +1276,258 @@ class TestRoleAlternation:
|
||||
assert [m["role"] for m in result] == ["user", "assistant", "user"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Thinking block signature management
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestThinkingBlockSignatureManagement:
|
||||
"""Tests for the thinking block handling strategy:
|
||||
strip from old turns, preserve latest signed, downgrade unsigned."""
|
||||
|
||||
def test_thinking_stripped_from_non_last_assistant(self):
|
||||
"""Thinking blocks are removed from all assistant messages except the last."""
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{"id": "tc_1", "function": {"name": "tool1", "arguments": "{}"}},
|
||||
],
|
||||
"reasoning_details": [
|
||||
{"type": "thinking", "thinking": "Old reasoning.", "signature": "sig_old"},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "tc_1", "content": "result 1"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{"id": "tc_2", "function": {"name": "tool2", "arguments": "{}"}},
|
||||
],
|
||||
"reasoning_details": [
|
||||
{"type": "thinking", "thinking": "Latest reasoning.", "signature": "sig_new"},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "tc_2", "content": "result 2"},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
|
||||
# Find both assistant messages
|
||||
assistants = [m for m in result if m["role"] == "assistant"]
|
||||
assert len(assistants) == 2
|
||||
|
||||
# First (non-last) assistant: no thinking blocks
|
||||
first_types = [b.get("type") for b in assistants[0]["content"]]
|
||||
assert "thinking" not in first_types
|
||||
assert "redacted_thinking" not in first_types
|
||||
assert "tool_use" in first_types # tool_use should survive
|
||||
|
||||
# Last assistant: thinking block preserved with signature
|
||||
last_blocks = assistants[1]["content"]
|
||||
thinking_blocks = [b for b in last_blocks if b.get("type") == "thinking"]
|
||||
assert len(thinking_blocks) == 1
|
||||
assert thinking_blocks[0]["thinking"] == "Latest reasoning."
|
||||
assert thinking_blocks[0]["signature"] == "sig_new"
|
||||
|
||||
def test_signed_thinking_preserved_on_last_turn(self):
|
||||
"""A signed thinking block on the last assistant message is kept."""
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The answer is 42.",
|
||||
"reasoning_details": [
|
||||
{"type": "thinking", "thinking": "Deep thought.", "signature": "sig_valid"},
|
||||
],
|
||||
},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
blocks = result[0]["content"]
|
||||
thinking = [b for b in blocks if b.get("type") == "thinking"]
|
||||
assert len(thinking) == 1
|
||||
assert thinking[0]["signature"] == "sig_valid"
|
||||
|
||||
def test_unsigned_thinking_downgraded_to_text_on_last_turn(self):
|
||||
"""Unsigned thinking blocks on the last turn become text blocks."""
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Response text.",
|
||||
"reasoning_details": [
|
||||
{"type": "thinking", "thinking": "Unsigned reasoning."},
|
||||
# No 'signature' field
|
||||
],
|
||||
},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
blocks = result[0]["content"]
|
||||
|
||||
# No thinking blocks should remain
|
||||
assert not any(b.get("type") == "thinking" for b in blocks)
|
||||
# The reasoning text should be preserved as a text block
|
||||
text_contents = [b.get("text", "") for b in blocks if b.get("type") == "text"]
|
||||
assert "Unsigned reasoning." in text_contents
|
||||
|
||||
def test_redacted_thinking_with_data_preserved(self):
|
||||
"""Redacted thinking with 'data' field is kept on last turn."""
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Response.",
|
||||
"reasoning_details": [
|
||||
{"type": "redacted_thinking", "data": "opaque_signature_data"},
|
||||
],
|
||||
},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
blocks = result[0]["content"]
|
||||
redacted = [b for b in blocks if b.get("type") == "redacted_thinking"]
|
||||
assert len(redacted) == 1
|
||||
assert redacted[0]["data"] == "opaque_signature_data"
|
||||
|
||||
def test_redacted_thinking_without_data_dropped(self):
|
||||
"""Redacted thinking without 'data' is dropped — can't be validated."""
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Response.",
|
||||
"reasoning_details": [
|
||||
{"type": "redacted_thinking"},
|
||||
# No 'data' field
|
||||
],
|
||||
},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
blocks = result[0]["content"]
|
||||
assert not any(b.get("type") == "redacted_thinking" for b in blocks)
|
||||
|
||||
def test_cache_control_stripped_from_thinking_blocks(self):
|
||||
"""cache_control markers are removed from thinking/redacted_thinking blocks."""
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{"id": "tc_1", "function": {"name": "t", "arguments": "{}"}},
|
||||
],
|
||||
"reasoning_details": [
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": "Reasoning.",
|
||||
"signature": "sig_1",
|
||||
"cache_control": {"type": "ephemeral"},
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "tc_1", "content": "result"},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
assistant = next(m for m in result if m["role"] == "assistant")
|
||||
for block in assistant["content"]:
|
||||
if block.get("type") in ("thinking", "redacted_thinking"):
|
||||
assert "cache_control" not in block
|
||||
|
||||
def test_thinking_stripped_from_merged_consecutive_assistants(self):
|
||||
"""When consecutive assistants are merged, second one's thinking is dropped."""
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "First response.",
|
||||
"reasoning_details": [
|
||||
{"type": "thinking", "thinking": "First thought.", "signature": "sig_1"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Second response.",
|
||||
"reasoning_details": [
|
||||
{"type": "thinking", "thinking": "Second thought.", "signature": "sig_2"},
|
||||
],
|
||||
},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
|
||||
# Should be merged into one assistant message
|
||||
assistants = [m for m in result if m["role"] == "assistant"]
|
||||
assert len(assistants) == 1
|
||||
|
||||
# Only the first thinking block should remain (signed, on the last/only assistant)
|
||||
blocks = assistants[0]["content"]
|
||||
thinking = [b for b in blocks if b.get("type") == "thinking"]
|
||||
assert len(thinking) == 1
|
||||
assert thinking[0]["thinking"] == "First thought."
|
||||
|
||||
def test_empty_content_after_strip_gets_placeholder(self):
|
||||
"""If stripping thinking leaves an empty message, a placeholder is added."""
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_details": [
|
||||
{"type": "thinking", "thinking": "Only thinking, no text."},
|
||||
# Unsigned — will be downgraded, but content was empty string
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "Next message."},
|
||||
{"role": "assistant", "content": "Final."},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
# First assistant is non-last, so thinking is stripped completely.
|
||||
# The original content was empty and thinking was unsigned → placeholder
|
||||
first_assistant = result[0]
|
||||
assert first_assistant["role"] == "assistant"
|
||||
assert len(first_assistant["content"]) >= 1
|
||||
|
||||
def test_multi_turn_conversation_preserves_only_last(self):
|
||||
"""Full multi-turn conversation: only last assistant keeps thinking."""
|
||||
messages = [
|
||||
{"role": "user", "content": "Question 1"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Answer 1",
|
||||
"reasoning_details": [
|
||||
{"type": "thinking", "thinking": "Thought 1", "signature": "sig_1"},
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "Question 2"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Answer 2",
|
||||
"reasoning_details": [
|
||||
{"type": "thinking", "thinking": "Thought 2", "signature": "sig_2"},
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "Question 3"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Answer 3",
|
||||
"reasoning_details": [
|
||||
{"type": "thinking", "thinking": "Thought 3", "signature": "sig_3"},
|
||||
],
|
||||
},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
|
||||
assistants = [m for m in result if m["role"] == "assistant"]
|
||||
assert len(assistants) == 3
|
||||
|
||||
# First two: no thinking blocks
|
||||
for a in assistants[:2]:
|
||||
assert not any(
|
||||
b.get("type") in ("thinking", "redacted_thinking")
|
||||
for b in a["content"]
|
||||
if isinstance(b, dict)
|
||||
)
|
||||
|
||||
# Last one: thinking preserved
|
||||
last_thinking = [
|
||||
b for b in assistants[2]["content"]
|
||||
if isinstance(b, dict) and b.get("type") == "thinking"
|
||||
]
|
||||
assert len(last_thinking) == 1
|
||||
assert last_thinking[0]["signature"] == "sig_3"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool choice
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -471,6 +471,23 @@ class TestExplicitProviderRouting:
|
||||
client, model = resolve_provider_client("zai")
|
||||
assert client is not None
|
||||
|
||||
def test_explicit_google_alias_uses_gemini_credentials(self):
|
||||
"""provider='google' should route through the gemini API-key provider."""
|
||||
with (
|
||||
patch("hermes_cli.auth.resolve_api_key_provider_credentials", return_value={
|
||||
"api_key": "gemini-key",
|
||||
"base_url": "https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
}),
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
):
|
||||
mock_openai.return_value = MagicMock()
|
||||
client, model = resolve_provider_client("google", model="gemini-3.1-pro-preview")
|
||||
|
||||
assert client is not None
|
||||
assert model == "gemini-3.1-pro-preview"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "gemini-key"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
|
||||
def test_explicit_unknown_returns_none(self, monkeypatch):
|
||||
"""Unknown provider should return None."""
|
||||
client, model = resolve_provider_client("nonexistent-provider")
|
||||
@@ -624,12 +641,15 @@ class TestVisionClientFallback:
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
def test_vision_auto_includes_anthropic_when_configured(self, monkeypatch):
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-key")
|
||||
def test_vision_auto_includes_active_provider_when_configured(self, monkeypatch):
|
||||
"""Active provider appears in available backends when credentials exist."""
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "***")
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
|
||||
patch("agent.auxiliary_client._read_main_provider", return_value="anthropic"),
|
||||
patch("agent.auxiliary_client._read_main_model", return_value="claude-sonnet-4"),
|
||||
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-api03-key"),
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="***"),
|
||||
):
|
||||
backends = get_available_vision_backends()
|
||||
|
||||
@@ -702,88 +722,51 @@ class TestAuxiliaryPoolAwareness:
|
||||
assert call_kwargs["base_url"] == "https://api.githubcopilot.com"
|
||||
assert call_kwargs["default_headers"]["Editor-Version"]
|
||||
|
||||
def test_vision_auto_uses_anthropic_when_no_higher_priority_backend(self, monkeypatch):
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-key")
|
||||
def test_vision_auto_uses_active_provider_as_fallback(self, monkeypatch):
|
||||
"""When no OpenRouter/Nous available, vision auto falls back to active provider."""
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "***")
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
|
||||
patch("agent.auxiliary_client._read_main_provider", return_value="anthropic"),
|
||||
patch("agent.auxiliary_client._read_main_model", return_value="claude-sonnet-4"),
|
||||
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-api03-key"),
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="***"),
|
||||
):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
|
||||
assert client is not None
|
||||
assert client.__class__.__name__ == "AnthropicAuxiliaryClient"
|
||||
assert model == "claude-haiku-4-5-20251001"
|
||||
|
||||
def test_selected_anthropic_provider_is_preferred_for_vision_auto(self, monkeypatch):
|
||||
def test_vision_auto_prefers_active_provider_over_openrouter(self, monkeypatch):
|
||||
"""Active provider is tried before OpenRouter in vision auto."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-api03-key")
|
||||
|
||||
def fake_load_config():
|
||||
return {"model": {"provider": "anthropic", "default": "claude-sonnet-4-6"}}
|
||||
monkeypatch.setenv("ANTHROPIC_API_KEY", "***")
|
||||
|
||||
with (
|
||||
patch("agent.auxiliary_client._read_nous_auth", return_value=None),
|
||||
patch("agent.auxiliary_client._read_main_provider", return_value="anthropic"),
|
||||
patch("agent.auxiliary_client._read_main_model", return_value="claude-sonnet-4"),
|
||||
patch("agent.anthropic_adapter.build_anthropic_client", return_value=MagicMock()),
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="sk-ant-api03-key"),
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
patch("hermes_cli.config.load_config", fake_load_config),
|
||||
):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
|
||||
assert client is not None
|
||||
assert client.__class__.__name__ == "AnthropicAuxiliaryClient"
|
||||
assert model == "claude-haiku-4-5-20251001"
|
||||
|
||||
def test_selected_codex_provider_short_circuits_vision_auto(self, monkeypatch):
|
||||
def fake_load_config():
|
||||
return {"model": {"provider": "openai-codex", "default": "gpt-5.2-codex"}}
|
||||
|
||||
codex_client = MagicMock()
|
||||
with (
|
||||
patch("hermes_cli.config.load_config", fake_load_config),
|
||||
patch("agent.auxiliary_client._try_codex", return_value=(codex_client, "gpt-5.2-codex")) as mock_codex,
|
||||
patch("agent.auxiliary_client._try_openrouter") as mock_openrouter,
|
||||
patch("agent.auxiliary_client._try_nous") as mock_nous,
|
||||
patch("agent.auxiliary_client._try_anthropic") as mock_anthropic,
|
||||
patch("agent.auxiliary_client._try_custom_endpoint") as mock_custom,
|
||||
patch("agent.anthropic_adapter.resolve_anthropic_token", return_value="***"),
|
||||
):
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert provider == "openai-codex"
|
||||
assert client is codex_client
|
||||
assert model == "gpt-5.2-codex"
|
||||
mock_codex.assert_called_once()
|
||||
mock_openrouter.assert_not_called()
|
||||
mock_nous.assert_not_called()
|
||||
mock_anthropic.assert_not_called()
|
||||
mock_custom.assert_not_called()
|
||||
# Active provider should win over OpenRouter
|
||||
assert provider == "anthropic"
|
||||
|
||||
def test_vision_auto_includes_codex(self, codex_auth_dir):
|
||||
"""Codex supports vision (gpt-5.3-codex), so auto mode should use it."""
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client.OpenAI"):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
from agent.auxiliary_client import CodexAuxiliaryClient
|
||||
assert isinstance(client, CodexAuxiliaryClient)
|
||||
assert model == "gpt-5.2-codex"
|
||||
|
||||
def test_vision_auto_falls_back_to_custom_endpoint(self, monkeypatch):
|
||||
"""Custom endpoint is used as fallback in vision auto mode.
|
||||
|
||||
Many local models (Qwen-VL, LLaVA, etc.) support vision.
|
||||
When no OpenRouter/Nous/Codex is available, try the custom endpoint.
|
||||
"""
|
||||
def test_vision_auto_uses_named_custom_as_active_provider(self, monkeypatch):
|
||||
"""Named custom provider works as active provider fallback in vision auto."""
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client._select_pool_entry", return_value=(False, None)), \
|
||||
patch("agent.auxiliary_client._read_codex_access_token", return_value=None), \
|
||||
patch("agent.auxiliary_client._resolve_custom_runtime",
|
||||
return_value=("http://localhost:1234/v1", "local-key")), \
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai:
|
||||
client, model = get_vision_auxiliary_client()
|
||||
assert client is not None # Custom endpoint picked up as fallback
|
||||
patch("agent.auxiliary_client._read_main_provider", return_value="custom:local"), \
|
||||
patch("agent.auxiliary_client._read_main_model", return_value="my-local-model"), \
|
||||
patch("agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(MagicMock(), "my-local-model")) as mock_resolve:
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
assert client is not None
|
||||
assert provider == "custom:local"
|
||||
|
||||
def test_vision_direct_endpoint_override(self, monkeypatch):
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
@@ -822,6 +805,31 @@ class TestAuxiliaryPoolAwareness:
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
assert client is not None
|
||||
|
||||
def test_vision_config_google_provider_uses_gemini_credentials(self, monkeypatch):
|
||||
config = {
|
||||
"auxiliary": {
|
||||
"vision": {
|
||||
"provider": "google",
|
||||
"model": "gemini-3.1-pro-preview",
|
||||
}
|
||||
}
|
||||
}
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
|
||||
with (
|
||||
patch("hermes_cli.auth.resolve_api_key_provider_credentials", return_value={
|
||||
"api_key": "gemini-key",
|
||||
"base_url": "https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
}),
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
):
|
||||
resolved_provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert resolved_provider == "gemini"
|
||||
assert client is not None
|
||||
assert model == "gemini-3.1-pro-preview"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "gemini-key"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
|
||||
def test_vision_forced_main_uses_custom_endpoint(self, monkeypatch):
|
||||
"""When explicitly forced to 'main', vision CAN use custom endpoint."""
|
||||
config = {
|
||||
@@ -846,7 +854,14 @@ class TestAuxiliaryPoolAwareness:
|
||||
monkeypatch.setenv("AUXILIARY_VISION_PROVIDER", "main")
|
||||
monkeypatch.delenv("OPENAI_BASE_URL", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
# Clear client cache to avoid stale entries from previous tests
|
||||
from agent.auxiliary_client import _client_cache
|
||||
_client_cache.clear()
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
patch("agent.auxiliary_client._read_main_provider", return_value=""), \
|
||||
patch("agent.auxiliary_client._read_main_model", return_value=""), \
|
||||
patch("agent.auxiliary_client._select_pool_entry", return_value=(False, None)), \
|
||||
patch("agent.auxiliary_client._resolve_custom_runtime", return_value=(None, None)), \
|
||||
patch("agent.auxiliary_client._read_codex_access_token", return_value=None), \
|
||||
patch("agent.auxiliary_client._resolve_api_key_provider", return_value=(None, None)):
|
||||
client, model = get_vision_auxiliary_client()
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Tests for MiniMax auxiliary client URL normalization.
|
||||
|
||||
MiniMax and MiniMax-CN set inference_base_url to the /anthropic path.
|
||||
The auxiliary client uses the OpenAI SDK, which needs /v1 instead.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from agent.auxiliary_client import _to_openai_base_url
|
||||
|
||||
|
||||
class TestToOpenaiBaseUrl:
|
||||
def test_minimax_global_anthropic_suffix_replaced(self):
|
||||
assert _to_openai_base_url("https://api.minimax.io/anthropic") == "https://api.minimax.io/v1"
|
||||
|
||||
def test_minimax_cn_anthropic_suffix_replaced(self):
|
||||
assert _to_openai_base_url("https://api.minimaxi.com/anthropic") == "https://api.minimaxi.com/v1"
|
||||
|
||||
def test_trailing_slash_stripped_before_replace(self):
|
||||
assert _to_openai_base_url("https://api.minimax.io/anthropic/") == "https://api.minimax.io/v1"
|
||||
|
||||
def test_v1_url_unchanged(self):
|
||||
assert _to_openai_base_url("https://api.openai.com/v1") == "https://api.openai.com/v1"
|
||||
|
||||
def test_openrouter_url_unchanged(self):
|
||||
assert _to_openai_base_url("https://openrouter.ai/api/v1") == "https://openrouter.ai/api/v1"
|
||||
|
||||
def test_anthropic_domain_unchanged(self):
|
||||
"""api.anthropic.com doesn't end with /anthropic — should be untouched."""
|
||||
assert _to_openai_base_url("https://api.anthropic.com") == "https://api.anthropic.com"
|
||||
|
||||
def test_anthropic_in_subpath_unchanged(self):
|
||||
assert _to_openai_base_url("https://example.com/anthropic/extra") == "https://example.com/anthropic/extra"
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _to_openai_base_url("") == ""
|
||||
|
||||
def test_none(self):
|
||||
assert _to_openai_base_url(None) == ""
|
||||
@@ -0,0 +1,105 @@
|
||||
"""Tests for MiniMax provider hardening — context lengths, thinking guard, catalog."""
|
||||
|
||||
|
||||
class TestMinimaxContextLengths:
|
||||
"""Verify per-model context length entries for MiniMax models."""
|
||||
|
||||
def test_m1_variants_have_1m_context(self):
|
||||
from agent.model_metadata import DEFAULT_CONTEXT_LENGTHS
|
||||
# Keys are lowercase because the lookup lowercases model names
|
||||
for model in ("minimax-m1", "minimax-m1-40k", "minimax-m1-80k",
|
||||
"minimax-m1-128k", "minimax-m1-256k"):
|
||||
assert model in DEFAULT_CONTEXT_LENGTHS, f"{model} missing from context lengths"
|
||||
assert DEFAULT_CONTEXT_LENGTHS[model] == 1_000_000, f"{model} expected 1M"
|
||||
|
||||
def test_m2_variants_have_1m_context(self):
|
||||
from agent.model_metadata import DEFAULT_CONTEXT_LENGTHS
|
||||
# Keys are lowercase because the lookup lowercases model names
|
||||
for model in ("minimax-m2.5", "minimax-m2.7"):
|
||||
assert model in DEFAULT_CONTEXT_LENGTHS, f"{model} missing from context lengths"
|
||||
assert DEFAULT_CONTEXT_LENGTHS[model] == 1_048_576, f"{model} expected 1048576"
|
||||
|
||||
def test_minimax_prefix_fallback(self):
|
||||
from agent.model_metadata import DEFAULT_CONTEXT_LENGTHS
|
||||
# The generic "minimax" prefix entry should be 1M for unknown models
|
||||
assert DEFAULT_CONTEXT_LENGTHS["minimax"] == 1_048_576
|
||||
|
||||
|
||||
|
||||
class TestMinimaxThinkingGuard:
|
||||
"""Verify that build_anthropic_kwargs does NOT add thinking params for MiniMax models."""
|
||||
|
||||
def test_no_thinking_for_minimax_m27(self):
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="MiniMax-M2.7",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
tools=None,
|
||||
max_tokens=4096,
|
||||
reasoning_config={"enabled": True, "effort": "medium"},
|
||||
)
|
||||
assert "thinking" not in kwargs
|
||||
assert "output_config" not in kwargs
|
||||
|
||||
def test_no_thinking_for_minimax_m1(self):
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="MiniMax-M1-128k",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
tools=None,
|
||||
max_tokens=4096,
|
||||
reasoning_config={"enabled": True, "effort": "high"},
|
||||
)
|
||||
assert "thinking" not in kwargs
|
||||
|
||||
def test_thinking_still_works_for_claude(self):
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="claude-sonnet-4-20250514",
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
tools=None,
|
||||
max_tokens=4096,
|
||||
reasoning_config={"enabled": True, "effort": "medium"},
|
||||
)
|
||||
assert "thinking" in kwargs
|
||||
|
||||
|
||||
class TestMinimaxAuxModel:
|
||||
"""Verify auxiliary model is standard (not highspeed)."""
|
||||
|
||||
def test_minimax_aux_is_standard(self):
|
||||
from agent.auxiliary_client import _API_KEY_PROVIDER_AUX_MODELS
|
||||
assert _API_KEY_PROVIDER_AUX_MODELS["minimax"] == "MiniMax-M2.7"
|
||||
assert _API_KEY_PROVIDER_AUX_MODELS["minimax-cn"] == "MiniMax-M2.7"
|
||||
|
||||
def test_minimax_aux_not_highspeed(self):
|
||||
from agent.auxiliary_client import _API_KEY_PROVIDER_AUX_MODELS
|
||||
assert "highspeed" not in _API_KEY_PROVIDER_AUX_MODELS["minimax"]
|
||||
assert "highspeed" not in _API_KEY_PROVIDER_AUX_MODELS["minimax-cn"]
|
||||
|
||||
|
||||
class TestMinimaxModelCatalog:
|
||||
"""Verify the model catalog includes M1 family and excludes deprecated models."""
|
||||
|
||||
def test_catalog_includes_m1_family(self):
|
||||
from hermes_cli.models import _PROVIDER_MODELS
|
||||
for provider in ("minimax", "minimax-cn"):
|
||||
models = _PROVIDER_MODELS[provider]
|
||||
assert "MiniMax-M1" in models
|
||||
assert "MiniMax-M1-40k" in models
|
||||
assert "MiniMax-M1-80k" in models
|
||||
assert "MiniMax-M1-128k" in models
|
||||
assert "MiniMax-M1-256k" in models
|
||||
|
||||
def test_catalog_excludes_deprecated(self):
|
||||
from hermes_cli.models import _PROVIDER_MODELS
|
||||
for provider in ("minimax", "minimax-cn"):
|
||||
models = _PROVIDER_MODELS[provider]
|
||||
assert "MiniMax-M2.1" not in models
|
||||
|
||||
def test_catalog_excludes_highspeed(self):
|
||||
from hermes_cli.models import _PROVIDER_MODELS
|
||||
for provider in ("minimax", "minimax-cn"):
|
||||
models = _PROVIDER_MODELS[provider]
|
||||
assert "MiniMax-M2.7-highspeed" not in models
|
||||
assert "MiniMax-M2.5-highspeed" not in models
|
||||
@@ -0,0 +1,66 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from hermes_cli.plugins import VALID_HOOKS, PluginManager
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from cli import HermesCLI
|
||||
|
||||
|
||||
def test_session_hooks_in_valid_hooks():
|
||||
"""Verify on_session_finalize and on_session_reset are registered as valid hooks."""
|
||||
assert "on_session_finalize" in VALID_HOOKS
|
||||
assert "on_session_reset" in VALID_HOOKS
|
||||
|
||||
|
||||
@patch("hermes_cli.plugins.invoke_hook")
|
||||
def test_session_finalize_on_reset(mock_invoke_hook):
|
||||
"""Verify on_session_finalize fires when /new or /reset is used."""
|
||||
cli = HermesCLI()
|
||||
cli.agent = MagicMock()
|
||||
cli.agent.session_id = "test-session-id"
|
||||
|
||||
# Simulate /new command which triggers on_session_finalize for the old session
|
||||
cli.new_session(silent=True)
|
||||
|
||||
# Check if on_session_finalize was called for the old session
|
||||
mock_invoke_hook.assert_any_call(
|
||||
"on_session_finalize", session_id="test-session-id", platform="cli"
|
||||
)
|
||||
# Check if on_session_reset was called for the new session
|
||||
mock_invoke_hook.assert_any_call(
|
||||
"on_session_reset", session_id=cli.session_id, platform="cli"
|
||||
)
|
||||
|
||||
|
||||
@patch("hermes_cli.plugins.invoke_hook")
|
||||
def test_session_finalize_on_cleanup(mock_invoke_hook):
|
||||
"""Verify on_session_finalize fires during CLI exit cleanup."""
|
||||
import cli as cli_mod
|
||||
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.session_id = "cleanup-session-id"
|
||||
cli_mod._active_agent_ref = mock_agent
|
||||
cli_mod._cleanup_done = False
|
||||
|
||||
cli_mod._run_cleanup()
|
||||
|
||||
mock_invoke_hook.assert_any_call(
|
||||
"on_session_finalize", session_id="cleanup-session-id", platform="cli"
|
||||
)
|
||||
|
||||
|
||||
@patch("hermes_cli.plugins.invoke_hook")
|
||||
def test_hook_errors_are_caught(mock_invoke_hook):
|
||||
"""Verify hook exceptions are caught and don't crash the agent."""
|
||||
mgr = PluginManager()
|
||||
|
||||
# Register a hook that raises
|
||||
def bad_callback(**kwargs):
|
||||
raise Exception("Hook failed")
|
||||
|
||||
mgr._hooks["on_session_finalize"] = [bad_callback]
|
||||
|
||||
# This should not raise
|
||||
results = mgr.invoke_hook("on_session_finalize", session_id="test", platform="cli")
|
||||
assert results == []
|
||||
+231
-24
@@ -33,6 +33,13 @@ def git_repo(tmp_path):
|
||||
["git", "commit", "-m", "Initial commit"],
|
||||
cwd=repo, capture_output=True,
|
||||
)
|
||||
# Add a fake remote ref so cleanup logic sees the initial commit as
|
||||
# "pushed". Without this, `git log HEAD --not --remotes` treats every
|
||||
# commit as unpushed and cleanup refuses to delete worktrees.
|
||||
subprocess.run(
|
||||
["git", "update-ref", "refs/remotes/origin/main", "HEAD"],
|
||||
cwd=repo, capture_output=True,
|
||||
)
|
||||
return repo
|
||||
|
||||
|
||||
@@ -81,7 +88,11 @@ def _setup_worktree(repo_root):
|
||||
|
||||
|
||||
def _cleanup_worktree(info):
|
||||
"""Test version of _cleanup_worktree."""
|
||||
"""Test version of _cleanup_worktree.
|
||||
|
||||
Preserves the worktree only if it has unpushed commits.
|
||||
Dirty working tree alone is not enough to keep it.
|
||||
"""
|
||||
wt_path = info["path"]
|
||||
branch = info["branch"]
|
||||
repo_root = info["repo_root"]
|
||||
@@ -89,15 +100,15 @@ def _cleanup_worktree(info):
|
||||
if not Path(wt_path).exists():
|
||||
return
|
||||
|
||||
# Check for uncommitted changes
|
||||
status = subprocess.run(
|
||||
["git", "status", "--porcelain"],
|
||||
# Check for unpushed commits
|
||||
result = subprocess.run(
|
||||
["git", "log", "--oneline", "HEAD", "--not", "--remotes"],
|
||||
capture_output=True, text=True, timeout=10, cwd=wt_path,
|
||||
)
|
||||
has_changes = bool(status.stdout.strip())
|
||||
has_unpushed = bool(result.stdout.strip())
|
||||
|
||||
if has_changes:
|
||||
return False # Did not clean up
|
||||
if has_unpushed:
|
||||
return False # Did not clean up — has unpushed commits
|
||||
|
||||
subprocess.run(
|
||||
["git", "worktree", "remove", wt_path, "--force"],
|
||||
@@ -204,20 +215,45 @@ class TestWorktreeCleanup:
|
||||
assert result is True
|
||||
assert not Path(info["path"]).exists()
|
||||
|
||||
def test_dirty_worktree_kept(self, git_repo):
|
||||
def test_dirty_worktree_cleaned_when_no_unpushed(self, git_repo):
|
||||
"""Dirty working tree without unpushed commits is cleaned up.
|
||||
|
||||
Agent sessions typically leave untracked files / artifacts behind.
|
||||
Since all real work is in pushed commits, these don't warrant
|
||||
keeping the worktree.
|
||||
"""
|
||||
info = _setup_worktree(str(git_repo))
|
||||
assert info is not None
|
||||
|
||||
# Make uncommitted changes
|
||||
# Make uncommitted changes (untracked file)
|
||||
(Path(info["path"]) / "new-file.txt").write_text("uncommitted")
|
||||
subprocess.run(
|
||||
["git", "add", "new-file.txt"],
|
||||
cwd=info["path"], capture_output=True,
|
||||
)
|
||||
|
||||
# The git_repo fixture already has a fake remote ref so the initial
|
||||
# commit is seen as "pushed". No unpushed commits → cleanup proceeds.
|
||||
result = _cleanup_worktree(info)
|
||||
assert result is False
|
||||
assert Path(info["path"]).exists() # Still there
|
||||
assert result is True # Cleaned up despite dirty working tree
|
||||
assert not Path(info["path"]).exists()
|
||||
|
||||
def test_worktree_with_unpushed_commits_kept(self, git_repo):
|
||||
"""Worktree with unpushed commits is preserved."""
|
||||
info = _setup_worktree(str(git_repo))
|
||||
assert info is not None
|
||||
|
||||
# Make a commit that is NOT on any remote
|
||||
(Path(info["path"]) / "work.txt").write_text("real work")
|
||||
subprocess.run(["git", "add", "work.txt"], cwd=info["path"], capture_output=True)
|
||||
subprocess.run(
|
||||
["git", "commit", "-m", "agent work"],
|
||||
cwd=info["path"], capture_output=True,
|
||||
)
|
||||
|
||||
result = _cleanup_worktree(info)
|
||||
assert result is False # Kept — has unpushed commits
|
||||
assert Path(info["path"]).exists()
|
||||
|
||||
def test_branch_deleted_on_cleanup(self, git_repo):
|
||||
info = _setup_worktree(str(git_repo))
|
||||
@@ -367,7 +403,7 @@ class TestMultipleWorktrees:
|
||||
lines = [l for l in result.stdout.strip().splitlines() if l.strip()]
|
||||
assert len(lines) == 11
|
||||
|
||||
# Cleanup all
|
||||
# Cleanup all (git_repo fixture has a fake remote ref so cleanup works)
|
||||
for info in worktrees:
|
||||
# Discard changes first so cleanup works
|
||||
subprocess.run(
|
||||
@@ -492,33 +528,77 @@ class TestStaleWorktreePruning:
|
||||
assert not pruned
|
||||
assert Path(info["path"]).exists()
|
||||
|
||||
def test_keeps_dirty_old_worktree(self, git_repo):
|
||||
"""Old worktrees with uncommitted changes should NOT be pruned."""
|
||||
def test_keeps_old_worktree_with_unpushed_commits(self, git_repo):
|
||||
"""Old worktrees (24-72h) with unpushed commits should NOT be pruned."""
|
||||
import time
|
||||
|
||||
info = _setup_worktree(str(git_repo))
|
||||
assert info is not None
|
||||
|
||||
# Make it dirty
|
||||
(Path(info["path"]) / "dirty.txt").write_text("uncommitted")
|
||||
# Make an unpushed commit
|
||||
(Path(info["path"]) / "work.txt").write_text("real work")
|
||||
subprocess.run(["git", "add", "work.txt"], cwd=info["path"], capture_output=True)
|
||||
subprocess.run(
|
||||
["git", "add", "dirty.txt"],
|
||||
["git", "commit", "-m", "agent work"],
|
||||
cwd=info["path"], capture_output=True,
|
||||
)
|
||||
|
||||
# Make it old
|
||||
# Make it old (25h — in the 24-72h soft tier)
|
||||
old_time = time.time() - (25 * 3600)
|
||||
os.utime(info["path"], (old_time, old_time))
|
||||
|
||||
# Check if it would be pruned
|
||||
status = subprocess.run(
|
||||
["git", "status", "--porcelain"],
|
||||
# Check for unpushed commits (simulates prune logic)
|
||||
result = subprocess.run(
|
||||
["git", "log", "--oneline", "HEAD", "--not", "--remotes"],
|
||||
capture_output=True, text=True, cwd=info["path"],
|
||||
)
|
||||
has_changes = bool(status.stdout.strip())
|
||||
assert has_changes # Should be dirty → not pruned
|
||||
has_unpushed = bool(result.stdout.strip())
|
||||
assert has_unpushed # Has unpushed commits → not pruned in soft tier
|
||||
assert Path(info["path"]).exists()
|
||||
|
||||
def test_force_prunes_very_old_worktree(self, git_repo):
|
||||
"""Worktrees older than 72h should be force-pruned regardless."""
|
||||
import time
|
||||
|
||||
info = _setup_worktree(str(git_repo))
|
||||
assert info is not None
|
||||
|
||||
# Make an unpushed commit (would normally protect it)
|
||||
(Path(info["path"]) / "work.txt").write_text("stale work")
|
||||
subprocess.run(["git", "add", "work.txt"], cwd=info["path"], capture_output=True)
|
||||
subprocess.run(
|
||||
["git", "commit", "-m", "old agent work"],
|
||||
cwd=info["path"], capture_output=True,
|
||||
)
|
||||
|
||||
# Make it very old (73h — beyond the 72h hard threshold)
|
||||
old_time = time.time() - (73 * 3600)
|
||||
os.utime(info["path"], (old_time, old_time))
|
||||
|
||||
# Simulate the force-prune tier check
|
||||
hard_cutoff = time.time() - (72 * 3600)
|
||||
mtime = Path(info["path"]).stat().st_mtime
|
||||
assert mtime <= hard_cutoff # Should qualify for force removal
|
||||
|
||||
# Actually remove it (simulates _prune_stale_worktrees force path)
|
||||
branch_result = subprocess.run(
|
||||
["git", "branch", "--show-current"],
|
||||
capture_output=True, text=True, timeout=5, cwd=info["path"],
|
||||
)
|
||||
branch = branch_result.stdout.strip()
|
||||
|
||||
subprocess.run(
|
||||
["git", "worktree", "remove", info["path"], "--force"],
|
||||
capture_output=True, text=True, timeout=15, cwd=str(git_repo),
|
||||
)
|
||||
if branch:
|
||||
subprocess.run(
|
||||
["git", "branch", "-D", branch],
|
||||
capture_output=True, text=True, timeout=10, cwd=str(git_repo),
|
||||
)
|
||||
|
||||
assert not Path(info["path"]).exists()
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases for robustness."""
|
||||
@@ -611,6 +691,133 @@ class TestTerminalCWDIntegration:
|
||||
assert result.stdout.strip() == "true"
|
||||
|
||||
|
||||
class TestOrphanedBranchPruning:
|
||||
"""Test cleanup of orphaned hermes/* and pr-* branches."""
|
||||
|
||||
def test_prunes_orphaned_hermes_branch(self, git_repo):
|
||||
"""hermes/hermes-* branches with no worktree should be deleted."""
|
||||
# Create a branch that looks like a worktree branch but has no worktree
|
||||
subprocess.run(
|
||||
["git", "branch", "hermes/hermes-deadbeef", "HEAD"],
|
||||
cwd=str(git_repo), capture_output=True,
|
||||
)
|
||||
|
||||
# Verify it exists
|
||||
result = subprocess.run(
|
||||
["git", "branch", "--list", "hermes/hermes-deadbeef"],
|
||||
capture_output=True, text=True, cwd=str(git_repo),
|
||||
)
|
||||
assert "hermes/hermes-deadbeef" in result.stdout
|
||||
|
||||
# Simulate _prune_orphaned_branches logic
|
||||
result = subprocess.run(
|
||||
["git", "branch", "--format=%(refname:short)"],
|
||||
capture_output=True, text=True, cwd=str(git_repo),
|
||||
)
|
||||
all_branches = [b.strip() for b in result.stdout.strip().split("\n") if b.strip()]
|
||||
|
||||
wt_result = subprocess.run(
|
||||
["git", "worktree", "list", "--porcelain"],
|
||||
capture_output=True, text=True, cwd=str(git_repo),
|
||||
)
|
||||
active_branches = {"main"}
|
||||
for line in wt_result.stdout.split("\n"):
|
||||
if line.startswith("branch refs/heads/"):
|
||||
active_branches.add(line.split("branch refs/heads/", 1)[-1].strip())
|
||||
|
||||
orphaned = [
|
||||
b for b in all_branches
|
||||
if b not in active_branches
|
||||
and (b.startswith("hermes/hermes-") or b.startswith("pr-"))
|
||||
]
|
||||
assert "hermes/hermes-deadbeef" in orphaned
|
||||
|
||||
# Delete them
|
||||
if orphaned:
|
||||
subprocess.run(
|
||||
["git", "branch", "-D"] + orphaned,
|
||||
capture_output=True, text=True, cwd=str(git_repo),
|
||||
)
|
||||
|
||||
# Verify gone
|
||||
result = subprocess.run(
|
||||
["git", "branch", "--list", "hermes/hermes-deadbeef"],
|
||||
capture_output=True, text=True, cwd=str(git_repo),
|
||||
)
|
||||
assert "hermes/hermes-deadbeef" not in result.stdout
|
||||
|
||||
def test_prunes_orphaned_pr_branch(self, git_repo):
|
||||
"""pr-* branches should be deleted during pruning."""
|
||||
subprocess.run(
|
||||
["git", "branch", "pr-1234", "HEAD"],
|
||||
cwd=str(git_repo), capture_output=True,
|
||||
)
|
||||
subprocess.run(
|
||||
["git", "branch", "pr-5678", "HEAD"],
|
||||
cwd=str(git_repo), capture_output=True,
|
||||
)
|
||||
|
||||
result = subprocess.run(
|
||||
["git", "branch", "--format=%(refname:short)"],
|
||||
capture_output=True, text=True, cwd=str(git_repo),
|
||||
)
|
||||
all_branches = [b.strip() for b in result.stdout.strip().split("\n") if b.strip()]
|
||||
|
||||
active_branches = {"main"}
|
||||
orphaned = [
|
||||
b for b in all_branches
|
||||
if b not in active_branches and b.startswith("pr-")
|
||||
]
|
||||
assert "pr-1234" in orphaned
|
||||
assert "pr-5678" in orphaned
|
||||
|
||||
subprocess.run(
|
||||
["git", "branch", "-D"] + orphaned,
|
||||
capture_output=True, text=True, cwd=str(git_repo),
|
||||
)
|
||||
|
||||
# Verify gone
|
||||
result = subprocess.run(
|
||||
["git", "branch", "--format=%(refname:short)"],
|
||||
capture_output=True, text=True, cwd=str(git_repo),
|
||||
)
|
||||
remaining = result.stdout.strip()
|
||||
assert "pr-1234" not in remaining
|
||||
assert "pr-5678" not in remaining
|
||||
|
||||
def test_preserves_active_worktree_branch(self, git_repo):
|
||||
"""Branches with active worktrees should NOT be pruned."""
|
||||
info = _setup_worktree(str(git_repo))
|
||||
assert info is not None
|
||||
|
||||
result = subprocess.run(
|
||||
["git", "worktree", "list", "--porcelain"],
|
||||
capture_output=True, text=True, cwd=str(git_repo),
|
||||
)
|
||||
active_branches = set()
|
||||
for line in result.stdout.split("\n"):
|
||||
if line.startswith("branch refs/heads/"):
|
||||
active_branches.add(line.split("branch refs/heads/", 1)[-1].strip())
|
||||
|
||||
assert info["branch"] in active_branches # Protected
|
||||
|
||||
def test_preserves_main_branch(self, git_repo):
|
||||
"""main branch should never be pruned."""
|
||||
result = subprocess.run(
|
||||
["git", "branch", "--format=%(refname:short)"],
|
||||
capture_output=True, text=True, cwd=str(git_repo),
|
||||
)
|
||||
all_branches = [b.strip() for b in result.stdout.strip().split("\n") if b.strip()]
|
||||
active_branches = {"main"}
|
||||
|
||||
orphaned = [
|
||||
b for b in all_branches
|
||||
if b not in active_branches
|
||||
and (b.startswith("hermes/hermes-") or b.startswith("pr-"))
|
||||
]
|
||||
assert "main" not in orphaned
|
||||
|
||||
|
||||
class TestSystemPromptInjection:
|
||||
"""Test that the agent gets worktree context in its system prompt."""
|
||||
|
||||
@@ -625,7 +832,7 @@ class TestSystemPromptInjection:
|
||||
f"{info['path']}. Your branch is `{info['branch']}`. "
|
||||
f"Changes here do not affect the main working tree or other agents. "
|
||||
f"Remember to commit and push your changes, and create a PR if appropriate. "
|
||||
f"The original repo is at {info['repo_root']}.]"
|
||||
f"The original repo is at {info['repo_root']}.]\n"
|
||||
)
|
||||
|
||||
assert info["path"] in wt_note
|
||||
|
||||
@@ -339,6 +339,36 @@ class TestMarkJobRun:
|
||||
assert updated["last_status"] == "error"
|
||||
assert updated["last_error"] == "timeout"
|
||||
|
||||
def test_delivery_error_tracked_separately(self, tmp_cron_dir):
|
||||
"""Agent succeeds but delivery fails — both tracked independently."""
|
||||
job = create_job(prompt="Report", schedule="every 1h")
|
||||
mark_job_run(job["id"], success=True, delivery_error="platform 'telegram' not configured")
|
||||
updated = get_job(job["id"])
|
||||
assert updated["last_status"] == "ok"
|
||||
assert updated["last_error"] is None
|
||||
assert updated["last_delivery_error"] == "platform 'telegram' not configured"
|
||||
|
||||
def test_delivery_error_cleared_on_success(self, tmp_cron_dir):
|
||||
"""Successful delivery clears the previous delivery error."""
|
||||
job = create_job(prompt="Report", schedule="every 1h")
|
||||
mark_job_run(job["id"], success=True, delivery_error="network timeout")
|
||||
updated = get_job(job["id"])
|
||||
assert updated["last_delivery_error"] == "network timeout"
|
||||
# Next run delivers successfully
|
||||
mark_job_run(job["id"], success=True, delivery_error=None)
|
||||
updated = get_job(job["id"])
|
||||
assert updated["last_delivery_error"] is None
|
||||
|
||||
def test_both_agent_and_delivery_error(self, tmp_cron_dir):
|
||||
"""Agent fails AND delivery fails — both errors recorded."""
|
||||
job = create_job(prompt="Report", schedule="every 1h")
|
||||
mark_job_run(job["id"], success=False, error="model timeout",
|
||||
delivery_error="platform 'discord' not enabled")
|
||||
updated = get_job(job["id"])
|
||||
assert updated["last_status"] == "error"
|
||||
assert updated["last_error"] == "model timeout"
|
||||
assert updated["last_delivery_error"] == "platform 'discord' not enabled"
|
||||
|
||||
|
||||
class TestAdvanceNextRun:
|
||||
"""Tests for advance_next_run() — crash-safety for recurring jobs."""
|
||||
|
||||
@@ -508,6 +508,90 @@ class TestDeliverResultWrapping:
|
||||
assert send_mock.call_args.kwargs["thread_id"] == "17585"
|
||||
|
||||
|
||||
class TestDeliverResultErrorReturns:
|
||||
"""Verify _deliver_result returns error strings on failure, None on success."""
|
||||
|
||||
def test_returns_none_on_successful_delivery(self):
|
||||
from gateway.config import Platform
|
||||
|
||||
pconfig = MagicMock()
|
||||
pconfig.enabled = True
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.platforms = {Platform.TELEGRAM: pconfig}
|
||||
|
||||
with patch("gateway.config.load_gateway_config", return_value=mock_cfg), \
|
||||
patch("tools.send_message_tool._send_to_platform", new=AsyncMock(return_value={"success": True})):
|
||||
job = {
|
||||
"id": "ok-job",
|
||||
"deliver": "origin",
|
||||
"origin": {"platform": "telegram", "chat_id": "123"},
|
||||
}
|
||||
result = _deliver_result(job, "Output.")
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_for_local_delivery(self):
|
||||
"""local-only jobs don't deliver — not a failure."""
|
||||
job = {"id": "local-job", "deliver": "local"}
|
||||
result = _deliver_result(job, "Output.")
|
||||
assert result is None
|
||||
|
||||
def test_returns_error_for_unknown_platform(self):
|
||||
job = {
|
||||
"id": "bad-platform",
|
||||
"deliver": "origin",
|
||||
"origin": {"platform": "fax", "chat_id": "123"},
|
||||
}
|
||||
with patch("gateway.config.load_gateway_config"):
|
||||
result = _deliver_result(job, "Output.")
|
||||
assert result is not None
|
||||
assert "unknown platform" in result
|
||||
|
||||
def test_returns_error_when_platform_disabled(self):
|
||||
from gateway.config import Platform
|
||||
|
||||
pconfig = MagicMock()
|
||||
pconfig.enabled = False
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.platforms = {Platform.TELEGRAM: pconfig}
|
||||
|
||||
with patch("gateway.config.load_gateway_config", return_value=mock_cfg):
|
||||
job = {
|
||||
"id": "disabled",
|
||||
"deliver": "origin",
|
||||
"origin": {"platform": "telegram", "chat_id": "123"},
|
||||
}
|
||||
result = _deliver_result(job, "Output.")
|
||||
assert result is not None
|
||||
assert "not configured" in result
|
||||
|
||||
def test_returns_error_on_send_failure(self):
|
||||
from gateway.config import Platform
|
||||
|
||||
pconfig = MagicMock()
|
||||
pconfig.enabled = True
|
||||
mock_cfg = MagicMock()
|
||||
mock_cfg.platforms = {Platform.TELEGRAM: pconfig}
|
||||
|
||||
with patch("gateway.config.load_gateway_config", return_value=mock_cfg), \
|
||||
patch("tools.send_message_tool._send_to_platform", new=AsyncMock(return_value={"error": "rate limited"})):
|
||||
job = {
|
||||
"id": "rate-limited",
|
||||
"deliver": "origin",
|
||||
"origin": {"platform": "telegram", "chat_id": "123"},
|
||||
}
|
||||
result = _deliver_result(job, "Output.")
|
||||
assert result is not None
|
||||
assert "rate limited" in result
|
||||
|
||||
def test_returns_error_for_unresolved_target(self, monkeypatch):
|
||||
"""Non-local delivery with no resolvable target should return an error."""
|
||||
monkeypatch.delenv("TELEGRAM_HOME_CHANNEL", raising=False)
|
||||
job = {"id": "no-target", "deliver": "telegram"}
|
||||
result = _deliver_result(job, "Output.")
|
||||
assert result is not None
|
||||
assert "no delivery target" in result
|
||||
|
||||
|
||||
class TestRunJobSessionPersistence:
|
||||
def test_run_job_passes_session_db_and_cron_platform(self, tmp_path):
|
||||
job = {
|
||||
|
||||
@@ -0,0 +1,277 @@
|
||||
"""Tests for Discord reply_to_mode functionality.
|
||||
|
||||
Covers the threading behavior control for multi-chunk replies:
|
||||
- "off": Never reply-reference to original message
|
||||
- "first": Only first chunk uses reply reference (default)
|
||||
- "all": All chunks reply-reference the original message
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig, GatewayConfig, Platform, _apply_env_overrides
|
||||
|
||||
|
||||
def _ensure_discord_mock():
|
||||
"""Install a mock discord module when discord.py isn't available."""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
|
||||
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, secondary=2, danger=3, green=1, grey=2, blurple=2, red=3)
|
||||
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4, purple=lambda: 5)
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def adapter_factory():
|
||||
"""Factory to create DiscordAdapter with custom reply_to_mode."""
|
||||
def create(reply_to_mode: str = "first"):
|
||||
config = PlatformConfig(enabled=True, token="test-token", reply_to_mode=reply_to_mode)
|
||||
return DiscordAdapter(config)
|
||||
return create
|
||||
|
||||
|
||||
class TestReplyToModeConfig:
|
||||
"""Tests for reply_to_mode configuration loading."""
|
||||
|
||||
def test_default_mode_is_first(self, adapter_factory):
|
||||
adapter = adapter_factory()
|
||||
assert adapter._reply_to_mode == "first"
|
||||
|
||||
def test_off_mode(self, adapter_factory):
|
||||
adapter = adapter_factory(reply_to_mode="off")
|
||||
assert adapter._reply_to_mode == "off"
|
||||
|
||||
def test_first_mode(self, adapter_factory):
|
||||
adapter = adapter_factory(reply_to_mode="first")
|
||||
assert adapter._reply_to_mode == "first"
|
||||
|
||||
def test_all_mode(self, adapter_factory):
|
||||
adapter = adapter_factory(reply_to_mode="all")
|
||||
assert adapter._reply_to_mode == "all"
|
||||
|
||||
def test_invalid_mode_stored_as_is(self, adapter_factory):
|
||||
"""Invalid modes are stored but send() handles them gracefully."""
|
||||
adapter = adapter_factory(reply_to_mode="invalid")
|
||||
assert adapter._reply_to_mode == "invalid"
|
||||
|
||||
def test_none_mode_defaults_to_first(self):
|
||||
config = PlatformConfig(enabled=True, token="test-token")
|
||||
adapter = DiscordAdapter(config)
|
||||
assert adapter._reply_to_mode == "first"
|
||||
|
||||
def test_empty_string_mode_defaults_to_first(self):
|
||||
config = PlatformConfig(enabled=True, token="test-token", reply_to_mode="")
|
||||
adapter = DiscordAdapter(config)
|
||||
assert adapter._reply_to_mode == "first"
|
||||
|
||||
|
||||
def _make_discord_adapter(reply_to_mode: str = "first"):
|
||||
"""Create a DiscordAdapter with mocked client and channel for send() tests."""
|
||||
config = PlatformConfig(enabled=True, token="test-token", reply_to_mode=reply_to_mode)
|
||||
adapter = DiscordAdapter(config)
|
||||
|
||||
# Mock the Discord client and channel
|
||||
mock_channel = AsyncMock()
|
||||
ref_message = MagicMock()
|
||||
mock_channel.fetch_message = AsyncMock(return_value=ref_message)
|
||||
|
||||
sent_msg = MagicMock()
|
||||
sent_msg.id = 42
|
||||
mock_channel.send = AsyncMock(return_value=sent_msg)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.get_channel = MagicMock(return_value=mock_channel)
|
||||
|
||||
adapter._client = mock_client
|
||||
return adapter, mock_channel, ref_message
|
||||
|
||||
|
||||
class TestSendWithReplyToMode:
|
||||
"""Tests for send() method respecting reply_to_mode."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_off_mode_no_reply_reference(self):
|
||||
adapter, channel, ref_msg = _make_discord_adapter("off")
|
||||
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]
|
||||
|
||||
await adapter.send("12345", "test content", reply_to="999")
|
||||
|
||||
# Should never try to fetch the reference message
|
||||
channel.fetch_message.assert_not_called()
|
||||
# All chunks sent without reference
|
||||
for call in channel.send.call_args_list:
|
||||
assert call.kwargs.get("reference") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_mode_only_first_chunk_references(self):
|
||||
adapter, channel, ref_msg = _make_discord_adapter("first")
|
||||
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]
|
||||
|
||||
await adapter.send("12345", "test content", reply_to="999")
|
||||
|
||||
# Should fetch the reference message
|
||||
channel.fetch_message.assert_called_once_with(999)
|
||||
calls = channel.send.call_args_list
|
||||
assert len(calls) == 3
|
||||
assert calls[0].kwargs.get("reference") is ref_msg
|
||||
assert calls[1].kwargs.get("reference") is None
|
||||
assert calls[2].kwargs.get("reference") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_mode_all_chunks_reference(self):
|
||||
adapter, channel, ref_msg = _make_discord_adapter("all")
|
||||
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"]
|
||||
|
||||
await adapter.send("12345", "test content", reply_to="999")
|
||||
|
||||
channel.fetch_message.assert_called_once_with(999)
|
||||
calls = channel.send.call_args_list
|
||||
assert len(calls) == 3
|
||||
for call in calls:
|
||||
assert call.kwargs.get("reference") is ref_msg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_reply_to_param_no_reference(self):
|
||||
adapter, channel, ref_msg = _make_discord_adapter("all")
|
||||
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2"]
|
||||
|
||||
await adapter.send("12345", "test content", reply_to=None)
|
||||
|
||||
channel.fetch_message.assert_not_called()
|
||||
for call in channel.send.call_args_list:
|
||||
assert call.kwargs.get("reference") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_chunk_respects_first_mode(self):
|
||||
adapter, channel, ref_msg = _make_discord_adapter("first")
|
||||
adapter.truncate_message = lambda content, max_len: ["single chunk"]
|
||||
|
||||
await adapter.send("12345", "test", reply_to="999")
|
||||
|
||||
calls = channel.send.call_args_list
|
||||
assert len(calls) == 1
|
||||
assert calls[0].kwargs.get("reference") is ref_msg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_chunk_off_mode(self):
|
||||
adapter, channel, ref_msg = _make_discord_adapter("off")
|
||||
adapter.truncate_message = lambda content, max_len: ["single chunk"]
|
||||
|
||||
await adapter.send("12345", "test", reply_to="999")
|
||||
|
||||
channel.fetch_message.assert_not_called()
|
||||
calls = channel.send.call_args_list
|
||||
assert len(calls) == 1
|
||||
assert calls[0].kwargs.get("reference") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_mode_falls_back_to_first_behavior(self):
|
||||
"""Invalid mode behaves like 'first' — only first chunk gets reference."""
|
||||
adapter, channel, ref_msg = _make_discord_adapter("banana")
|
||||
adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2"]
|
||||
|
||||
await adapter.send("12345", "test", reply_to="999")
|
||||
|
||||
calls = channel.send.call_args_list
|
||||
assert len(calls) == 2
|
||||
assert calls[0].kwargs.get("reference") is ref_msg
|
||||
assert calls[1].kwargs.get("reference") is None
|
||||
|
||||
|
||||
class TestConfigSerialization:
|
||||
"""Tests for reply_to_mode serialization (shared with Telegram)."""
|
||||
|
||||
def test_to_dict_includes_reply_to_mode(self):
|
||||
config = PlatformConfig(enabled=True, token="test", reply_to_mode="all")
|
||||
result = config.to_dict()
|
||||
assert result["reply_to_mode"] == "all"
|
||||
|
||||
def test_from_dict_loads_reply_to_mode(self):
|
||||
data = {"enabled": True, "token": "***", "reply_to_mode": "off"}
|
||||
config = PlatformConfig.from_dict(data)
|
||||
assert config.reply_to_mode == "off"
|
||||
|
||||
def test_from_dict_defaults_to_first(self):
|
||||
data = {"enabled": True, "token": "***"}
|
||||
config = PlatformConfig.from_dict(data)
|
||||
assert config.reply_to_mode == "first"
|
||||
|
||||
|
||||
class TestEnvVarOverride:
|
||||
"""Tests for DISCORD_REPLY_TO_MODE environment variable override."""
|
||||
|
||||
def _make_config(self):
|
||||
config = GatewayConfig()
|
||||
config.platforms[Platform.DISCORD] = PlatformConfig(enabled=True, token="test")
|
||||
return config
|
||||
|
||||
def test_env_var_sets_off_mode(self):
|
||||
config = self._make_config()
|
||||
with patch.dict(os.environ, {"DISCORD_REPLY_TO_MODE": "off"}, clear=False):
|
||||
_apply_env_overrides(config)
|
||||
assert config.platforms[Platform.DISCORD].reply_to_mode == "off"
|
||||
|
||||
def test_env_var_sets_all_mode(self):
|
||||
config = self._make_config()
|
||||
with patch.dict(os.environ, {"DISCORD_REPLY_TO_MODE": "all"}, clear=False):
|
||||
_apply_env_overrides(config)
|
||||
assert config.platforms[Platform.DISCORD].reply_to_mode == "all"
|
||||
|
||||
def test_env_var_case_insensitive(self):
|
||||
config = self._make_config()
|
||||
with patch.dict(os.environ, {"DISCORD_REPLY_TO_MODE": "ALL"}, clear=False):
|
||||
_apply_env_overrides(config)
|
||||
assert config.platforms[Platform.DISCORD].reply_to_mode == "all"
|
||||
|
||||
def test_env_var_invalid_value_ignored(self):
|
||||
config = self._make_config()
|
||||
with patch.dict(os.environ, {"DISCORD_REPLY_TO_MODE": "banana"}, clear=False):
|
||||
_apply_env_overrides(config)
|
||||
assert config.platforms[Platform.DISCORD].reply_to_mode == "first"
|
||||
|
||||
def test_env_var_empty_value_ignored(self):
|
||||
config = self._make_config()
|
||||
with patch.dict(os.environ, {"DISCORD_REPLY_TO_MODE": ""}, clear=False):
|
||||
_apply_env_overrides(config)
|
||||
assert config.platforms[Platform.DISCORD].reply_to_mode == "first"
|
||||
|
||||
def test_env_var_creates_platform_config_if_missing(self):
|
||||
"""DISCORD_REPLY_TO_MODE creates PlatformConfig even without DISCORD_BOT_TOKEN."""
|
||||
config = GatewayConfig()
|
||||
assert Platform.DISCORD not in config.platforms
|
||||
with patch.dict(os.environ, {"DISCORD_REPLY_TO_MODE": "off"}, clear=False):
|
||||
_apply_env_overrides(config)
|
||||
assert Platform.DISCORD in config.platforms
|
||||
assert config.platforms[Platform.DISCORD].reply_to_mode == "off"
|
||||
@@ -0,0 +1,432 @@
|
||||
"""Tests for Feishu interactive card approval buttons."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ensure the repo root is importable
|
||||
# ---------------------------------------------------------------------------
|
||||
_repo = str(Path(__file__).resolve().parents[2])
|
||||
if _repo not in sys.path:
|
||||
sys.path.insert(0, _repo)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal Feishu mock so FeishuAdapter can be imported without lark-oapi
|
||||
# ---------------------------------------------------------------------------
|
||||
def _ensure_feishu_mocks():
|
||||
"""Provide stubs for lark-oapi / aiohttp.web so the import succeeds."""
|
||||
if "lark_oapi" not in sys.modules:
|
||||
mod = MagicMock()
|
||||
for name in (
|
||||
"lark_oapi", "lark_oapi.api.im.v1",
|
||||
"lark_oapi.event", "lark_oapi.event.callback_type",
|
||||
):
|
||||
sys.modules.setdefault(name, mod)
|
||||
if "aiohttp" not in sys.modules:
|
||||
aio = MagicMock()
|
||||
sys.modules.setdefault("aiohttp", aio)
|
||||
sys.modules.setdefault("aiohttp.web", aio.web)
|
||||
|
||||
|
||||
_ensure_feishu_mocks()
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.feishu import FeishuAdapter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_adapter() -> FeishuAdapter:
|
||||
"""Create a FeishuAdapter with mocked internals."""
|
||||
config = PlatformConfig(enabled=True)
|
||||
adapter = FeishuAdapter(config)
|
||||
adapter._client = MagicMock()
|
||||
return adapter
|
||||
|
||||
|
||||
def _make_card_action_data(
|
||||
action_value: dict,
|
||||
chat_id: str = "oc_12345",
|
||||
open_id: str = "ou_user1",
|
||||
token: str = "tok_abc",
|
||||
) -> SimpleNamespace:
|
||||
"""Create a mock Feishu card action callback data object."""
|
||||
return SimpleNamespace(
|
||||
event=SimpleNamespace(
|
||||
token=token,
|
||||
context=SimpleNamespace(open_chat_id=chat_id),
|
||||
operator=SimpleNamespace(open_id=open_id),
|
||||
action=SimpleNamespace(
|
||||
tag="button",
|
||||
value=action_value,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# send_exec_approval — interactive card with buttons
|
||||
# ===========================================================================
|
||||
|
||||
class TestFeishuExecApproval:
|
||||
"""Test send_exec_approval sends an interactive card."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sends_interactive_card(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
mock_response = SimpleNamespace(
|
||||
success=lambda: True,
|
||||
data=SimpleNamespace(message_id="msg_001"),
|
||||
)
|
||||
with patch.object(
|
||||
adapter, "_feishu_send_with_retry", new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_send:
|
||||
result = await adapter.send_exec_approval(
|
||||
chat_id="oc_12345",
|
||||
command="rm -rf /important",
|
||||
session_key="agent:main:feishu:group:oc_12345",
|
||||
description="dangerous deletion",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "msg_001"
|
||||
|
||||
mock_send.assert_called_once()
|
||||
kwargs = mock_send.call_args[1]
|
||||
assert kwargs["chat_id"] == "oc_12345"
|
||||
assert kwargs["msg_type"] == "interactive"
|
||||
|
||||
# Verify card payload contains the command and buttons
|
||||
card = json.loads(kwargs["payload"])
|
||||
assert card["header"]["template"] == "orange"
|
||||
assert "rm -rf /important" in card["elements"][0]["content"]
|
||||
assert "dangerous deletion" in card["elements"][0]["content"]
|
||||
|
||||
# Check buttons
|
||||
actions = card["elements"][1]["actions"]
|
||||
assert len(actions) == 4
|
||||
action_names = [a["value"]["hermes_action"] for a in actions]
|
||||
assert action_names == [
|
||||
"approve_once", "approve_session", "approve_always", "deny"
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stores_approval_state(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
mock_response = SimpleNamespace(
|
||||
success=lambda: True,
|
||||
data=SimpleNamespace(message_id="msg_002"),
|
||||
)
|
||||
with patch.object(
|
||||
adapter, "_feishu_send_with_retry", new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
):
|
||||
await adapter.send_exec_approval(
|
||||
chat_id="oc_12345",
|
||||
command="echo test",
|
||||
session_key="my-session-key",
|
||||
)
|
||||
|
||||
assert len(adapter._approval_state) == 1
|
||||
approval_id = list(adapter._approval_state.keys())[0]
|
||||
state = adapter._approval_state[approval_id]
|
||||
assert state["session_key"] == "my-session-key"
|
||||
assert state["message_id"] == "msg_002"
|
||||
assert state["chat_id"] == "oc_12345"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_connected(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._client = None
|
||||
result = await adapter.send_exec_approval(
|
||||
chat_id="oc_12345", command="ls", session_key="s"
|
||||
)
|
||||
assert result.success is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncates_long_command(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
mock_response = SimpleNamespace(
|
||||
success=lambda: True,
|
||||
data=SimpleNamespace(message_id="msg_003"),
|
||||
)
|
||||
with patch.object(
|
||||
adapter, "_feishu_send_with_retry", new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_send:
|
||||
long_cmd = "x" * 5000
|
||||
await adapter.send_exec_approval(
|
||||
chat_id="oc_12345", command=long_cmd, session_key="s"
|
||||
)
|
||||
|
||||
card = json.loads(mock_send.call_args[1]["payload"])
|
||||
content = card["elements"][0]["content"]
|
||||
assert "..." in content
|
||||
assert len(content) < 5000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_approvals_get_unique_ids(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
mock_response = SimpleNamespace(
|
||||
success=lambda: True,
|
||||
data=SimpleNamespace(message_id="msg_x"),
|
||||
)
|
||||
with patch.object(
|
||||
adapter, "_feishu_send_with_retry", new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
):
|
||||
await adapter.send_exec_approval(
|
||||
chat_id="oc_1", command="cmd1", session_key="s1"
|
||||
)
|
||||
await adapter.send_exec_approval(
|
||||
chat_id="oc_2", command="cmd2", session_key="s2"
|
||||
)
|
||||
|
||||
assert len(adapter._approval_state) == 2
|
||||
ids = list(adapter._approval_state.keys())
|
||||
assert ids[0] != ids[1]
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _handle_card_action_event — approval button clicks
|
||||
# ===========================================================================
|
||||
|
||||
class TestFeishuApprovalCallback:
|
||||
"""Test the approval intercept in _handle_card_action_event."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolves_approval_on_click(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._approval_state[1] = {
|
||||
"session_key": "agent:main:feishu:group:oc_12345",
|
||||
"message_id": "msg_001",
|
||||
"chat_id": "oc_12345",
|
||||
}
|
||||
|
||||
data = _make_card_action_data(
|
||||
action_value={"hermes_action": "approve_once", "approval_id": 1},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
adapter, "_resolve_sender_profile", new_callable=AsyncMock,
|
||||
return_value={"user_id": "ou_user1", "user_name": "Norbert", "user_id_alt": None},
|
||||
),
|
||||
patch.object(adapter, "_update_approval_card", new_callable=AsyncMock) as mock_update,
|
||||
patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve,
|
||||
):
|
||||
await adapter._handle_card_action_event(data)
|
||||
|
||||
mock_resolve.assert_called_once_with("agent:main:feishu:group:oc_12345", "once")
|
||||
mock_update.assert_called_once_with("msg_001", "Approved once", "Norbert", "once")
|
||||
|
||||
# State should be cleaned up
|
||||
assert 1 not in adapter._approval_state
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deny_button(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._approval_state[2] = {
|
||||
"session_key": "some-session",
|
||||
"message_id": "msg_002",
|
||||
"chat_id": "oc_12345",
|
||||
}
|
||||
|
||||
data = _make_card_action_data(
|
||||
action_value={"hermes_action": "deny", "approval_id": 2},
|
||||
token="tok_deny",
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
adapter, "_resolve_sender_profile", new_callable=AsyncMock,
|
||||
return_value={"user_id": "ou_alice", "user_name": "Alice", "user_id_alt": None},
|
||||
),
|
||||
patch.object(adapter, "_update_approval_card", new_callable=AsyncMock) as mock_update,
|
||||
patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve,
|
||||
):
|
||||
await adapter._handle_card_action_event(data)
|
||||
|
||||
mock_resolve.assert_called_once_with("some-session", "deny")
|
||||
mock_update.assert_called_once_with("msg_002", "Denied", "Alice", "deny")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_approval(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._approval_state[3] = {
|
||||
"session_key": "sess-3",
|
||||
"message_id": "msg_003",
|
||||
"chat_id": "oc_99",
|
||||
}
|
||||
|
||||
data = _make_card_action_data(
|
||||
action_value={"hermes_action": "approve_session", "approval_id": 3},
|
||||
token="tok_ses",
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
adapter, "_resolve_sender_profile", new_callable=AsyncMock,
|
||||
return_value={"user_id": "ou_u", "user_name": "Bob", "user_id_alt": None},
|
||||
),
|
||||
patch.object(adapter, "_update_approval_card", new_callable=AsyncMock) as mock_update,
|
||||
patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve,
|
||||
):
|
||||
await adapter._handle_card_action_event(data)
|
||||
|
||||
mock_resolve.assert_called_once_with("sess-3", "session")
|
||||
mock_update.assert_called_once_with("msg_003", "Approved for session", "Bob", "session")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_always_approval(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._approval_state[4] = {
|
||||
"session_key": "sess-4",
|
||||
"message_id": "msg_004",
|
||||
"chat_id": "oc_55",
|
||||
}
|
||||
|
||||
data = _make_card_action_data(
|
||||
action_value={"hermes_action": "approve_always", "approval_id": 4},
|
||||
token="tok_alw",
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
adapter, "_resolve_sender_profile", new_callable=AsyncMock,
|
||||
return_value={"user_id": "ou_u", "user_name": "Carol", "user_id_alt": None},
|
||||
),
|
||||
patch.object(adapter, "_update_approval_card", new_callable=AsyncMock),
|
||||
patch("tools.approval.resolve_gateway_approval", return_value=1) as mock_resolve,
|
||||
):
|
||||
await adapter._handle_card_action_event(data)
|
||||
|
||||
mock_resolve.assert_called_once_with("sess-4", "always")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_already_resolved_drops_silently(self):
|
||||
adapter = _make_adapter()
|
||||
# No state for approval_id 99 — already resolved
|
||||
|
||||
data = _make_card_action_data(
|
||||
action_value={"hermes_action": "approve_once", "approval_id": 99},
|
||||
token="tok_gone",
|
||||
)
|
||||
|
||||
with patch("tools.approval.resolve_gateway_approval") as mock_resolve:
|
||||
await adapter._handle_card_action_event(data)
|
||||
|
||||
# Should NOT resolve — already handled
|
||||
mock_resolve.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_approval_actions_route_normally(self):
|
||||
"""Non-approval card actions should still become synthetic commands."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
data = _make_card_action_data(
|
||||
action_value={"custom_action": "something_else"},
|
||||
token="tok_normal",
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
adapter, "_resolve_sender_profile", new_callable=AsyncMock,
|
||||
return_value={"user_id": "ou_u", "user_name": "Dave", "user_id_alt": None},
|
||||
),
|
||||
patch.object(adapter, "get_chat_info", new_callable=AsyncMock, return_value={"name": "Test Chat"}),
|
||||
patch.object(adapter, "_handle_message_with_guards", new_callable=AsyncMock) as mock_handle,
|
||||
patch("tools.approval.resolve_gateway_approval") as mock_resolve,
|
||||
):
|
||||
await adapter._handle_card_action_event(data)
|
||||
|
||||
# Should NOT resolve any approval
|
||||
mock_resolve.assert_not_called()
|
||||
# Should have routed as synthetic command
|
||||
mock_handle.assert_called_once()
|
||||
event = mock_handle.call_args[0][0]
|
||||
assert "/card button" in event.text
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# _update_approval_card — card replacement after resolution
|
||||
# ===========================================================================
|
||||
|
||||
class TestFeishuUpdateApprovalCard:
|
||||
"""Test the card update after approval resolution."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_updates_card_on_approve(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
mock_update = AsyncMock()
|
||||
adapter._client.im.v1.message.update = MagicMock()
|
||||
|
||||
with patch("asyncio.to_thread", new_callable=AsyncMock) as mock_thread:
|
||||
await adapter._update_approval_card(
|
||||
"msg_001", "Approved once", "Norbert", "once"
|
||||
)
|
||||
|
||||
mock_thread.assert_called_once()
|
||||
# Verify the update request was built
|
||||
call_args = mock_thread.call_args
|
||||
assert call_args[0][0] == adapter._client.im.v1.message.update
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_updates_card_on_deny(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
with patch("asyncio.to_thread", new_callable=AsyncMock) as mock_thread:
|
||||
await adapter._update_approval_card(
|
||||
"msg_002", "Denied", "Alice", "deny"
|
||||
)
|
||||
|
||||
mock_thread.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_update_when_not_connected(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._client = None
|
||||
|
||||
with patch("asyncio.to_thread", new_callable=AsyncMock) as mock_thread:
|
||||
await adapter._update_approval_card(
|
||||
"msg_001", "Approved", "Bob", "once"
|
||||
)
|
||||
|
||||
mock_thread.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_update_when_no_message_id(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
with patch("asyncio.to_thread", new_callable=AsyncMock) as mock_thread:
|
||||
await adapter._update_approval_card(
|
||||
"", "Approved", "Bob", "once"
|
||||
)
|
||||
|
||||
mock_thread.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_swallows_update_errors(self):
|
||||
adapter = _make_adapter()
|
||||
|
||||
with patch("asyncio.to_thread", new_callable=AsyncMock, side_effect=Exception("API error")):
|
||||
# Should not raise
|
||||
await adapter._update_approval_card(
|
||||
"msg_001", "Approved", "Bob", "once"
|
||||
)
|
||||
@@ -87,7 +87,6 @@ class TestReasoningCommand:
|
||||
)
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
|
||||
monkeypatch.delenv("HERMES_REASONING_EFFORT", raising=False)
|
||||
|
||||
runner = _make_runner()
|
||||
runner._reasoning_config = {"enabled": True, "effort": "xhigh"}
|
||||
@@ -108,7 +107,6 @@ class TestReasoningCommand:
|
||||
config_path.write_text("agent:\n reasoning_effort: medium\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
|
||||
monkeypatch.delenv("HERMES_REASONING_EFFORT", raising=False)
|
||||
|
||||
runner = _make_runner()
|
||||
runner._reasoning_config = {"enabled": True, "effort": "medium"}
|
||||
@@ -138,7 +136,6 @@ class TestReasoningCommand:
|
||||
"api_key": "test-key",
|
||||
},
|
||||
)
|
||||
monkeypatch.delenv("HERMES_REASONING_EFFORT", raising=False)
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = _CapturingAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
@@ -170,55 +167,6 @@ class TestReasoningCommand:
|
||||
assert _CapturingAgent.last_init is not None
|
||||
assert _CapturingAgent.last_init["reasoning_config"] == {"enabled": True, "effort": "low"}
|
||||
|
||||
def test_run_agent_prefers_config_over_stale_reasoning_env(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
(hermes_home / "config.yaml").write_text("agent:\n reasoning_effort: none\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
|
||||
monkeypatch.setattr(gateway_run, "_env_path", hermes_home / ".env")
|
||||
monkeypatch.setattr(gateway_run, "load_dotenv", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(
|
||||
gateway_run,
|
||||
"_resolve_runtime_agent_kwargs",
|
||||
lambda: {
|
||||
"provider": "openrouter",
|
||||
"api_mode": "chat_completions",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"api_key": "test-key",
|
||||
},
|
||||
)
|
||||
monkeypatch.setenv("HERMES_REASONING_EFFORT", "low")
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = _CapturingAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
_CapturingAgent.last_init = None
|
||||
runner = _make_runner()
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.LOCAL,
|
||||
chat_id="cli",
|
||||
chat_name="CLI",
|
||||
chat_type="dm",
|
||||
user_id="user-1",
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
runner._run_agent(
|
||||
message="ping",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id="session-1",
|
||||
session_key="agent:main:local:dm",
|
||||
)
|
||||
)
|
||||
|
||||
assert result["final_response"] == "ok"
|
||||
assert _CapturingAgent.last_init is not None
|
||||
assert _CapturingAgent.last_init["reasoning_config"] == {"enabled": False}
|
||||
|
||||
def test_run_agent_includes_enabled_mcp_servers_in_gateway_toolsets(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
"""Tests that on_session_finalize and on_session_reset plugin hooks fire in the gateway."""
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||||
|
||||
|
||||
def _make_source() -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="u1",
|
||||
chat_id="c1",
|
||||
user_name="tester",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
def _make_event(text: str) -> MessageEvent:
|
||||
return MessageEvent(text=text, source=_make_source(), message_id="m1")
|
||||
|
||||
|
||||
def _make_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
|
||||
)
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock()
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
runner._voice_mode = {}
|
||||
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
|
||||
runner._session_model_overrides = {}
|
||||
runner._pending_model_notes = {}
|
||||
runner._background_tasks = set()
|
||||
|
||||
session_key = build_session_key(_make_source())
|
||||
session_entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id="sess-old",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
new_session_entry = SessionEntry(
|
||||
session_key=session_key,
|
||||
session_id="sess-new",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store.get_or_create_session.return_value = new_session_entry
|
||||
runner.session_store.reset_session.return_value = new_session_entry
|
||||
runner.session_store._entries = {session_key: session_entry}
|
||||
runner.session_store._generate_session_key.return_value = session_key
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = None
|
||||
runner._agent_cache_lock = None
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
runner._format_session_info = lambda: ""
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("hermes_cli.plugins.invoke_hook")
|
||||
async def test_reset_fires_finalize_hook(mock_invoke_hook):
|
||||
"""/new must fire on_session_finalize with the OLD session id."""
|
||||
runner = _make_runner()
|
||||
|
||||
await runner._handle_reset_command(_make_event("/new"))
|
||||
|
||||
mock_invoke_hook.assert_any_call(
|
||||
"on_session_finalize", session_id="sess-old", platform="telegram"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("hermes_cli.plugins.invoke_hook")
|
||||
async def test_reset_fires_reset_hook(mock_invoke_hook):
|
||||
"""/new must fire on_session_reset with the NEW session id."""
|
||||
runner = _make_runner()
|
||||
|
||||
await runner._handle_reset_command(_make_event("/new"))
|
||||
|
||||
mock_invoke_hook.assert_any_call(
|
||||
"on_session_reset", session_id="sess-new", platform="telegram"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("hermes_cli.plugins.invoke_hook")
|
||||
async def test_finalize_before_reset(mock_invoke_hook):
|
||||
"""on_session_finalize must fire before on_session_reset."""
|
||||
runner = _make_runner()
|
||||
|
||||
await runner._handle_reset_command(_make_event("/new"))
|
||||
|
||||
calls = [c for c in mock_invoke_hook.call_args_list
|
||||
if c[0][0] in ("on_session_finalize", "on_session_reset")]
|
||||
hook_names = [c[0][0] for c in calls]
|
||||
assert hook_names == ["on_session_finalize", "on_session_reset"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("hermes_cli.plugins.invoke_hook")
|
||||
async def test_shutdown_fires_finalize_for_active_agents(mock_invoke_hook):
|
||||
"""Gateway stop() must fire on_session_finalize for each active agent."""
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner._running = True
|
||||
runner._background_tasks = set()
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._shutdown_event = MagicMock()
|
||||
runner.adapters = {}
|
||||
runner._exit_reason = "test"
|
||||
|
||||
agent1 = MagicMock()
|
||||
agent1.session_id = "sess-a"
|
||||
agent2 = MagicMock()
|
||||
agent2.session_id = "sess-b"
|
||||
runner._running_agents = {"key-a": agent1, "key-b": agent2}
|
||||
|
||||
with patch("gateway.status.remove_pid_file"), \
|
||||
patch("gateway.status.write_runtime_status"):
|
||||
await runner.stop()
|
||||
|
||||
finalize_calls = [
|
||||
c for c in mock_invoke_hook.call_args_list
|
||||
if c[0][0] == "on_session_finalize"
|
||||
]
|
||||
session_ids = {c[1]["session_id"] for c in finalize_calls}
|
||||
assert session_ids == {"sess-a", "sess-b"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("hermes_cli.plugins.invoke_hook", side_effect=Exception("boom"))
|
||||
async def test_hook_error_does_not_break_reset(mock_invoke_hook):
|
||||
"""Plugin hook errors must not prevent /new from completing."""
|
||||
runner = _make_runner()
|
||||
|
||||
result = await runner._handle_reset_command(_make_event("/new"))
|
||||
|
||||
# Should still return a success message despite hook errors
|
||||
assert "Session reset" in result or "New session" in result
|
||||
@@ -707,3 +707,66 @@ class TestSignalSendDocumentViaHelper:
|
||||
|
||||
assert result.success is False
|
||||
assert "/nonexistent.pdf" in result.error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send() returns message_id from timestamp (#4647)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalSendReturnsMessageId:
|
||||
"""Signal send() must return a timestamp-based message_id so the stream
|
||||
consumer can follow its edit→fallback path correctly."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_returns_timestamp_as_message_id(self, monkeypatch):
|
||||
adapter = _make_signal_adapter(monkeypatch)
|
||||
mock_rpc, _ = _stub_rpc({"timestamp": 1712345678000})
|
||||
adapter._rpc = mock_rpc
|
||||
adapter._stop_typing_indicator = AsyncMock()
|
||||
|
||||
result = await adapter.send(chat_id="+155****4567", content="hello")
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "1712345678000"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_returns_none_message_id_when_no_timestamp(self, monkeypatch):
|
||||
adapter = _make_signal_adapter(monkeypatch)
|
||||
mock_rpc, _ = _stub_rpc({}) # No timestamp key
|
||||
adapter._rpc = mock_rpc
|
||||
adapter._stop_typing_indicator = AsyncMock()
|
||||
|
||||
result = await adapter.send(chat_id="+155****4567", content="hello")
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_returns_none_message_id_for_non_dict(self, monkeypatch):
|
||||
adapter = _make_signal_adapter(monkeypatch)
|
||||
mock_rpc, _ = _stub_rpc("ok") # Non-dict result
|
||||
adapter._rpc = mock_rpc
|
||||
adapter._stop_typing_indicator = AsyncMock()
|
||||
|
||||
result = await adapter.send(chat_id="+155****4567", content="hello")
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# stop_typing() delegates to _stop_typing_indicator (#4647)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSignalStopTyping:
|
||||
"""Signal must expose a public stop_typing() so base adapter's
|
||||
_keep_typing finally block can clean up platform-level typing tasks."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_typing_calls_private_method(self, monkeypatch):
|
||||
adapter = _make_signal_adapter(monkeypatch)
|
||||
adapter._stop_typing_indicator = AsyncMock()
|
||||
|
||||
await adapter.stop_typing("+155****4567")
|
||||
|
||||
adapter._stop_typing_indicator.assert_awaited_once_with("+155****4567")
|
||||
|
||||
@@ -324,3 +324,145 @@ class TestSegmentBreakOnToolBoundary:
|
||||
await consumer.run()
|
||||
|
||||
assert consumer.already_sent
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_failure_sends_only_unsent_tail_at_finish(self):
|
||||
"""If an edit fails mid-stream, send only the missing tail once at finish."""
|
||||
adapter = MagicMock()
|
||||
send_results = [
|
||||
SimpleNamespace(success=True, message_id="msg_1"),
|
||||
SimpleNamespace(success=True, message_id="msg_2"),
|
||||
]
|
||||
adapter.send = AsyncMock(side_effect=send_results)
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=False, error="flood_control:6"))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉")
|
||||
consumer = GatewayStreamConsumer(adapter, "chat_123", config)
|
||||
|
||||
consumer.on_delta("Hello")
|
||||
task = asyncio.create_task(consumer.run())
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.on_delta(" world")
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.finish()
|
||||
await task
|
||||
|
||||
assert adapter.send.call_count == 2
|
||||
first_text = adapter.send.call_args_list[0][1]["content"]
|
||||
second_text = adapter.send.call_args_list[1][1]["content"]
|
||||
assert "Hello" in first_text
|
||||
assert second_text.strip() == "world"
|
||||
assert consumer.already_sent
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_segment_break_clears_failed_edit_fallback_state(self):
|
||||
"""A tool boundary after edit failure must not duplicate the next segment."""
|
||||
adapter = MagicMock()
|
||||
send_results = [
|
||||
SimpleNamespace(success=True, message_id="msg_1"),
|
||||
SimpleNamespace(success=True, message_id="msg_2"),
|
||||
]
|
||||
adapter.send = AsyncMock(side_effect=send_results)
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=False, error="flood_control:6"))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉")
|
||||
consumer = GatewayStreamConsumer(adapter, "chat_123", config)
|
||||
|
||||
consumer.on_delta("Hello")
|
||||
task = asyncio.create_task(consumer.run())
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.on_delta(" world")
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.on_delta(None)
|
||||
consumer.on_delta("Next segment")
|
||||
consumer.finish()
|
||||
await task
|
||||
|
||||
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
|
||||
assert sent_texts == ["Hello ▉", "Next segment"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_message_id_enters_fallback_mode(self):
|
||||
"""Platform returns success but no message_id (Signal) — must not
|
||||
re-send on every delta. Should enter fallback mode and send only
|
||||
the continuation at finish."""
|
||||
adapter = MagicMock()
|
||||
# First send succeeds but returns no message_id (Signal behavior)
|
||||
send_result_no_id = SimpleNamespace(success=True, message_id=None)
|
||||
# Fallback final send succeeds
|
||||
send_result_final = SimpleNamespace(success=True, message_id="msg_final")
|
||||
adapter.send = AsyncMock(side_effect=[send_result_no_id, send_result_final])
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True))
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5)
|
||||
consumer = GatewayStreamConsumer(adapter, "chat_123", config)
|
||||
|
||||
consumer.on_delta("Hello")
|
||||
task = asyncio.create_task(consumer.run())
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.on_delta(" world, this is a longer response.")
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.finish()
|
||||
await task
|
||||
|
||||
# Should send exactly 2 messages: initial chunk + fallback continuation
|
||||
# NOT one message per delta
|
||||
assert adapter.send.call_count == 2
|
||||
assert consumer.already_sent
|
||||
# edit_message should NOT have been called (no valid message_id to edit)
|
||||
adapter.edit_message.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_message_id_single_delta_marks_already_sent(self):
|
||||
"""When the entire response fits in one delta and platform returns no
|
||||
message_id, already_sent must still be True to prevent the gateway
|
||||
from re-sending the full response."""
|
||||
adapter = MagicMock()
|
||||
send_result = SimpleNamespace(success=True, message_id=None)
|
||||
adapter.send = AsyncMock(return_value=send_result)
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5)
|
||||
consumer = GatewayStreamConsumer(adapter, "chat_123", config)
|
||||
|
||||
consumer.on_delta("Short response.")
|
||||
consumer.finish()
|
||||
|
||||
await consumer.run()
|
||||
|
||||
assert consumer.already_sent
|
||||
# Only one send call (the initial message)
|
||||
assert adapter.send.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_final_splits_long_continuation_without_dropping_text(self):
|
||||
"""Long continuation tails should be chunked when fallback final-send runs."""
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock(side_effect=[
|
||||
SimpleNamespace(success=True, message_id="msg_1"),
|
||||
SimpleNamespace(success=True, message_id="msg_2"),
|
||||
SimpleNamespace(success=True, message_id="msg_3"),
|
||||
])
|
||||
adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=False, error="flood_control:6"))
|
||||
adapter.MAX_MESSAGE_LENGTH = 610
|
||||
|
||||
config = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor=" ▉")
|
||||
consumer = GatewayStreamConsumer(adapter, "chat_123", config)
|
||||
|
||||
prefix = "abc"
|
||||
tail = "x" * 620
|
||||
consumer.on_delta(prefix)
|
||||
task = asyncio.create_task(consumer.run())
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.on_delta(tail)
|
||||
await asyncio.sleep(0.08)
|
||||
consumer.finish()
|
||||
await task
|
||||
|
||||
sent_texts = [call[1]["content"] for call in adapter.send.call_args_list]
|
||||
assert len(sent_texts) == 3
|
||||
assert sent_texts[0].startswith(prefix)
|
||||
assert sum(len(t) for t in sent_texts[1:]) == len(tail)
|
||||
|
||||
@@ -0,0 +1,399 @@
|
||||
"""Tests for Qwen OAuth provider authentication (hermes_cli/auth.py).
|
||||
|
||||
Covers: _qwen_cli_auth_path, _read_qwen_cli_tokens, _save_qwen_cli_tokens,
|
||||
_qwen_access_token_is_expiring, _refresh_qwen_cli_tokens,
|
||||
resolve_qwen_runtime_credentials, get_qwen_auth_status.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import stat
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli.auth import (
|
||||
AuthError,
|
||||
DEFAULT_QWEN_BASE_URL,
|
||||
QWEN_ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
|
||||
_qwen_cli_auth_path,
|
||||
_read_qwen_cli_tokens,
|
||||
_save_qwen_cli_tokens,
|
||||
_qwen_access_token_is_expiring,
|
||||
_refresh_qwen_cli_tokens,
|
||||
resolve_qwen_runtime_credentials,
|
||||
get_qwen_auth_status,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_qwen_tokens(
|
||||
access_token="test-access-token",
|
||||
refresh_token="test-refresh-token",
|
||||
expiry_date=None,
|
||||
**extra,
|
||||
):
|
||||
"""Create a minimal Qwen CLI OAuth credential dict."""
|
||||
if expiry_date is None:
|
||||
# 1 hour from now in milliseconds
|
||||
expiry_date = int((time.time() + 3600) * 1000)
|
||||
data = {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"token_type": "Bearer",
|
||||
"expiry_date": expiry_date,
|
||||
"resource_url": "portal.qwen.ai",
|
||||
}
|
||||
data.update(extra)
|
||||
return data
|
||||
|
||||
|
||||
def _write_qwen_creds(tmp_path, tokens=None):
|
||||
"""Write tokens to the Qwen CLI credentials file and return the path."""
|
||||
qwen_dir = tmp_path / ".qwen"
|
||||
qwen_dir.mkdir(parents=True, exist_ok=True)
|
||||
creds_path = qwen_dir / "oauth_creds.json"
|
||||
if tokens is None:
|
||||
tokens = _make_qwen_tokens()
|
||||
creds_path.write_text(json.dumps(tokens), encoding="utf-8")
|
||||
return creds_path
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def qwen_env(tmp_path, monkeypatch):
|
||||
"""Redirect _qwen_cli_auth_path to tmp_path/.qwen/oauth_creds.json."""
|
||||
creds_path = tmp_path / ".qwen" / "oauth_creds.json"
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth._qwen_cli_auth_path", lambda: creds_path
|
||||
)
|
||||
return tmp_path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _qwen_cli_auth_path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_qwen_cli_auth_path_returns_expected_location():
|
||||
path = _qwen_cli_auth_path()
|
||||
assert path == Path.home() / ".qwen" / "oauth_creds.json"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_qwen_cli_tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_read_qwen_cli_tokens_success(qwen_env):
|
||||
tokens = _make_qwen_tokens(access_token="my-access")
|
||||
_write_qwen_creds(qwen_env, tokens)
|
||||
result = _read_qwen_cli_tokens()
|
||||
assert result["access_token"] == "my-access"
|
||||
assert result["refresh_token"] == "test-refresh-token"
|
||||
|
||||
|
||||
def test_read_qwen_cli_tokens_missing_file(qwen_env):
|
||||
with pytest.raises(AuthError) as exc:
|
||||
_read_qwen_cli_tokens()
|
||||
assert exc.value.code == "qwen_auth_missing"
|
||||
|
||||
|
||||
def test_read_qwen_cli_tokens_invalid_json(qwen_env):
|
||||
creds_path = qwen_env / ".qwen" / "oauth_creds.json"
|
||||
creds_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
creds_path.write_text("not json{{{", encoding="utf-8")
|
||||
with pytest.raises(AuthError) as exc:
|
||||
_read_qwen_cli_tokens()
|
||||
assert exc.value.code == "qwen_auth_read_failed"
|
||||
|
||||
|
||||
def test_read_qwen_cli_tokens_non_dict(qwen_env):
|
||||
creds_path = qwen_env / ".qwen" / "oauth_creds.json"
|
||||
creds_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
creds_path.write_text(json.dumps(["a", "b"]), encoding="utf-8")
|
||||
with pytest.raises(AuthError) as exc:
|
||||
_read_qwen_cli_tokens()
|
||||
assert exc.value.code == "qwen_auth_invalid"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _save_qwen_cli_tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_save_qwen_cli_tokens_roundtrip(qwen_env):
|
||||
tokens = _make_qwen_tokens(access_token="saved-token")
|
||||
saved_path = _save_qwen_cli_tokens(tokens)
|
||||
assert saved_path.exists()
|
||||
loaded = json.loads(saved_path.read_text(encoding="utf-8"))
|
||||
assert loaded["access_token"] == "saved-token"
|
||||
|
||||
|
||||
def test_save_qwen_cli_tokens_creates_parent(qwen_env):
|
||||
tokens = _make_qwen_tokens()
|
||||
saved_path = _save_qwen_cli_tokens(tokens)
|
||||
assert saved_path.parent.exists()
|
||||
|
||||
|
||||
def test_save_qwen_cli_tokens_permissions(qwen_env):
|
||||
tokens = _make_qwen_tokens()
|
||||
saved_path = _save_qwen_cli_tokens(tokens)
|
||||
mode = saved_path.stat().st_mode
|
||||
assert mode & stat.S_IRUSR # owner read
|
||||
assert mode & stat.S_IWUSR # owner write
|
||||
assert not (mode & stat.S_IRGRP) # no group read
|
||||
assert not (mode & stat.S_IROTH) # no other read
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _qwen_access_token_is_expiring
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_expiring_token_not_expired():
|
||||
# 1 hour from now in milliseconds
|
||||
future_ms = int((time.time() + 3600) * 1000)
|
||||
assert not _qwen_access_token_is_expiring(future_ms)
|
||||
|
||||
|
||||
def test_expiring_token_already_expired():
|
||||
# 1 hour ago in milliseconds
|
||||
past_ms = int((time.time() - 3600) * 1000)
|
||||
assert _qwen_access_token_is_expiring(past_ms)
|
||||
|
||||
|
||||
def test_expiring_token_within_skew():
|
||||
# Just inside the default skew window
|
||||
near_ms = int((time.time() + QWEN_ACCESS_TOKEN_REFRESH_SKEW_SECONDS - 5) * 1000)
|
||||
assert _qwen_access_token_is_expiring(near_ms)
|
||||
|
||||
|
||||
def test_expiring_token_none_returns_true():
|
||||
assert _qwen_access_token_is_expiring(None)
|
||||
|
||||
|
||||
def test_expiring_token_non_numeric_returns_true():
|
||||
assert _qwen_access_token_is_expiring("not-a-number")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _refresh_qwen_cli_tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_refresh_qwen_cli_tokens_success(qwen_env):
|
||||
tokens = _make_qwen_tokens(refresh_token="old-refresh")
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = {
|
||||
"access_token": "new-access",
|
||||
"refresh_token": "new-refresh",
|
||||
"expires_in": 7200,
|
||||
}
|
||||
|
||||
with patch("hermes_cli.auth.httpx") as mock_httpx:
|
||||
mock_httpx.post.return_value = resp
|
||||
result = _refresh_qwen_cli_tokens(tokens)
|
||||
|
||||
assert result["access_token"] == "new-access"
|
||||
assert result["refresh_token"] == "new-refresh"
|
||||
assert "expiry_date" in result
|
||||
|
||||
|
||||
def test_refresh_qwen_cli_tokens_preserves_old_refresh_if_not_in_response(qwen_env):
|
||||
tokens = _make_qwen_tokens(refresh_token="keep-me")
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = {
|
||||
"access_token": "new-access",
|
||||
# No refresh_token in response — should keep old one
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
with patch("hermes_cli.auth.httpx") as mock_httpx:
|
||||
mock_httpx.post.return_value = resp
|
||||
result = _refresh_qwen_cli_tokens(tokens)
|
||||
|
||||
assert result["refresh_token"] == "keep-me"
|
||||
|
||||
|
||||
def test_refresh_qwen_cli_tokens_missing_refresh_token():
|
||||
tokens = {"access_token": "at", "refresh_token": ""}
|
||||
with pytest.raises(AuthError) as exc:
|
||||
_refresh_qwen_cli_tokens(tokens)
|
||||
assert exc.value.code == "qwen_refresh_token_missing"
|
||||
|
||||
|
||||
def test_refresh_qwen_cli_tokens_http_error(qwen_env):
|
||||
tokens = _make_qwen_tokens()
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 401
|
||||
resp.text = "unauthorized"
|
||||
|
||||
with patch("hermes_cli.auth.httpx") as mock_httpx:
|
||||
mock_httpx.post.return_value = resp
|
||||
with pytest.raises(AuthError) as exc:
|
||||
_refresh_qwen_cli_tokens(tokens)
|
||||
assert exc.value.code == "qwen_refresh_failed"
|
||||
|
||||
|
||||
def test_refresh_qwen_cli_tokens_network_error(qwen_env):
|
||||
tokens = _make_qwen_tokens()
|
||||
|
||||
with patch("hermes_cli.auth.httpx") as mock_httpx:
|
||||
mock_httpx.post.side_effect = ConnectionError("timeout")
|
||||
with pytest.raises(AuthError) as exc:
|
||||
_refresh_qwen_cli_tokens(tokens)
|
||||
assert exc.value.code == "qwen_refresh_failed"
|
||||
|
||||
|
||||
def test_refresh_qwen_cli_tokens_invalid_json_response(qwen_env):
|
||||
tokens = _make_qwen_tokens()
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.json.side_effect = ValueError("bad json")
|
||||
|
||||
with patch("hermes_cli.auth.httpx") as mock_httpx:
|
||||
mock_httpx.post.return_value = resp
|
||||
with pytest.raises(AuthError) as exc:
|
||||
_refresh_qwen_cli_tokens(tokens)
|
||||
assert exc.value.code == "qwen_refresh_invalid_json"
|
||||
|
||||
|
||||
def test_refresh_qwen_cli_tokens_missing_access_token_in_response(qwen_env):
|
||||
tokens = _make_qwen_tokens()
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = {"something": "but no access_token"}
|
||||
|
||||
with patch("hermes_cli.auth.httpx") as mock_httpx:
|
||||
mock_httpx.post.return_value = resp
|
||||
with pytest.raises(AuthError) as exc:
|
||||
_refresh_qwen_cli_tokens(tokens)
|
||||
assert exc.value.code == "qwen_refresh_invalid_response"
|
||||
|
||||
|
||||
def test_refresh_qwen_cli_tokens_default_expires_in(qwen_env):
|
||||
"""When expires_in is missing, default to 6 hours."""
|
||||
tokens = _make_qwen_tokens()
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = {"access_token": "new"}
|
||||
|
||||
with patch("hermes_cli.auth.httpx") as mock_httpx:
|
||||
mock_httpx.post.return_value = resp
|
||||
result = _refresh_qwen_cli_tokens(tokens)
|
||||
|
||||
# Verify expiry_date is roughly now + 6h (within 60s tolerance)
|
||||
expected_ms = int(time.time() * 1000) + 6 * 60 * 60 * 1000
|
||||
assert abs(result["expiry_date"] - expected_ms) < 60_000
|
||||
|
||||
|
||||
def test_refresh_qwen_cli_tokens_saves_to_disk(qwen_env):
|
||||
tokens = _make_qwen_tokens()
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = {
|
||||
"access_token": "disk-check",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
with patch("hermes_cli.auth.httpx") as mock_httpx:
|
||||
mock_httpx.post.return_value = resp
|
||||
_refresh_qwen_cli_tokens(tokens)
|
||||
|
||||
# Verify it was persisted
|
||||
creds_path = qwen_env / ".qwen" / "oauth_creds.json"
|
||||
assert creds_path.exists()
|
||||
saved = json.loads(creds_path.read_text(encoding="utf-8"))
|
||||
assert saved["access_token"] == "disk-check"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_qwen_runtime_credentials
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_resolve_qwen_runtime_credentials_fresh_token(qwen_env):
|
||||
tokens = _make_qwen_tokens(access_token="fresh-at")
|
||||
_write_qwen_creds(qwen_env, tokens)
|
||||
|
||||
creds = resolve_qwen_runtime_credentials(refresh_if_expiring=False)
|
||||
assert creds["provider"] == "qwen-oauth"
|
||||
assert creds["api_key"] == "fresh-at"
|
||||
assert creds["base_url"] == DEFAULT_QWEN_BASE_URL
|
||||
assert creds["source"] == "qwen-cli"
|
||||
|
||||
|
||||
def test_resolve_qwen_runtime_credentials_triggers_refresh(qwen_env):
|
||||
# Write an expired token
|
||||
expired_ms = int((time.time() - 3600) * 1000)
|
||||
tokens = _make_qwen_tokens(access_token="old", expiry_date=expired_ms)
|
||||
_write_qwen_creds(qwen_env, tokens)
|
||||
|
||||
refreshed = _make_qwen_tokens(access_token="refreshed-at")
|
||||
|
||||
with patch(
|
||||
"hermes_cli.auth._refresh_qwen_cli_tokens", return_value=refreshed
|
||||
) as mock_refresh:
|
||||
creds = resolve_qwen_runtime_credentials()
|
||||
mock_refresh.assert_called_once()
|
||||
assert creds["api_key"] == "refreshed-at"
|
||||
|
||||
|
||||
def test_resolve_qwen_runtime_credentials_force_refresh(qwen_env):
|
||||
tokens = _make_qwen_tokens(access_token="old-at")
|
||||
_write_qwen_creds(qwen_env, tokens)
|
||||
|
||||
refreshed = _make_qwen_tokens(access_token="force-refreshed")
|
||||
|
||||
with patch(
|
||||
"hermes_cli.auth._refresh_qwen_cli_tokens", return_value=refreshed
|
||||
) as mock_refresh:
|
||||
creds = resolve_qwen_runtime_credentials(force_refresh=True)
|
||||
mock_refresh.assert_called_once()
|
||||
assert creds["api_key"] == "force-refreshed"
|
||||
|
||||
|
||||
def test_resolve_qwen_runtime_credentials_missing_access_token(qwen_env):
|
||||
tokens = _make_qwen_tokens(access_token="")
|
||||
_write_qwen_creds(qwen_env, tokens)
|
||||
|
||||
with pytest.raises(AuthError) as exc:
|
||||
resolve_qwen_runtime_credentials(refresh_if_expiring=False)
|
||||
assert exc.value.code == "qwen_access_token_missing"
|
||||
|
||||
|
||||
def test_resolve_qwen_runtime_credentials_base_url_env_override(qwen_env, monkeypatch):
|
||||
tokens = _make_qwen_tokens(access_token="at")
|
||||
_write_qwen_creds(qwen_env, tokens)
|
||||
monkeypatch.setenv("HERMES_QWEN_BASE_URL", "https://custom.qwen.ai/v1")
|
||||
|
||||
creds = resolve_qwen_runtime_credentials(refresh_if_expiring=False)
|
||||
assert creds["base_url"] == "https://custom.qwen.ai/v1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_qwen_auth_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_get_qwen_auth_status_logged_in(qwen_env):
|
||||
tokens = _make_qwen_tokens(access_token="status-at")
|
||||
_write_qwen_creds(qwen_env, tokens)
|
||||
|
||||
status = get_qwen_auth_status()
|
||||
assert status["logged_in"] is True
|
||||
assert status["api_key"] == "status-at"
|
||||
|
||||
|
||||
def test_get_qwen_auth_status_not_logged_in(qwen_env):
|
||||
# No credentials file
|
||||
status = get_qwen_auth_status()
|
||||
assert status["logged_in"] is False
|
||||
assert "error" in status
|
||||
@@ -136,3 +136,73 @@ def test_check_gateway_service_linger_skips_when_service_not_installed(monkeypat
|
||||
out = capsys.readouterr().out
|
||||
assert out == ""
|
||||
assert issues == []
|
||||
|
||||
|
||||
# ── Memory provider section (doctor should only check the *active* provider) ──
|
||||
|
||||
|
||||
class TestDoctorMemoryProviderSection:
|
||||
"""The ◆ Memory Provider section should respect memory.provider config."""
|
||||
|
||||
def _make_hermes_home(self, tmp_path, provider=""):
|
||||
"""Create a minimal HERMES_HOME with config.yaml."""
|
||||
home = tmp_path / ".hermes"
|
||||
home.mkdir(parents=True, exist_ok=True)
|
||||
import yaml
|
||||
config = {"memory": {"provider": provider}} if provider else {"memory": {}}
|
||||
(home / "config.yaml").write_text(yaml.dump(config))
|
||||
return home
|
||||
|
||||
def _run_doctor_and_capture(self, monkeypatch, tmp_path, provider=""):
|
||||
"""Run doctor and capture stdout."""
|
||||
home = self._make_hermes_home(tmp_path, provider)
|
||||
monkeypatch.setattr(doctor_mod, "HERMES_HOME", home)
|
||||
monkeypatch.setattr(doctor_mod, "PROJECT_ROOT", tmp_path / "project")
|
||||
monkeypatch.setattr(doctor_mod, "_DHH", str(home))
|
||||
(tmp_path / "project").mkdir(exist_ok=True)
|
||||
|
||||
# Stub tool availability (returns empty) so doctor runs past it
|
||||
fake_model_tools = types.SimpleNamespace(
|
||||
check_tool_availability=lambda *a, **kw: ([], []),
|
||||
TOOLSET_REQUIREMENTS={},
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "model_tools", fake_model_tools)
|
||||
|
||||
# Stub auth checks to avoid real API calls
|
||||
try:
|
||||
from hermes_cli import auth as _auth_mod
|
||||
monkeypatch.setattr(_auth_mod, "get_nous_auth_status", lambda: {})
|
||||
monkeypatch.setattr(_auth_mod, "get_codex_auth_status", lambda: {})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
import io, contextlib
|
||||
buf = io.StringIO()
|
||||
with contextlib.redirect_stdout(buf):
|
||||
doctor_mod.run_doctor(Namespace(fix=False))
|
||||
return buf.getvalue()
|
||||
|
||||
def test_no_provider_shows_builtin_ok(self, monkeypatch, tmp_path):
|
||||
out = self._run_doctor_and_capture(monkeypatch, tmp_path, provider="")
|
||||
assert "Memory Provider" in out
|
||||
assert "Built-in memory active" in out
|
||||
# Should NOT mention Honcho or Mem0 errors
|
||||
assert "Honcho API key" not in out
|
||||
assert "Mem0" not in out
|
||||
|
||||
def test_honcho_provider_not_installed_shows_fail(self, monkeypatch, tmp_path):
|
||||
# Make honcho import fail
|
||||
monkeypatch.setitem(
|
||||
sys.modules, "plugins.memory.honcho.client", None
|
||||
)
|
||||
out = self._run_doctor_and_capture(monkeypatch, tmp_path, provider="honcho")
|
||||
assert "Memory Provider" in out
|
||||
# Should show failure since honcho is set but not importable
|
||||
assert "Built-in memory active" not in out
|
||||
|
||||
def test_mem0_provider_not_installed_shows_fail(self, monkeypatch, tmp_path):
|
||||
# Make mem0 import fail
|
||||
monkeypatch.setitem(sys.modules, "plugins.memory.mem0", None)
|
||||
out = self._run_doctor_and_capture(monkeypatch, tmp_path, provider="mem0")
|
||||
assert "Memory Provider" in out
|
||||
assert "Built-in memory active" not in out
|
||||
|
||||
@@ -143,6 +143,82 @@ def test_resolve_runtime_provider_codex(monkeypatch):
|
||||
assert resolved["requested_provider"] == "openai-codex"
|
||||
|
||||
|
||||
def test_resolve_runtime_provider_qwen_oauth(monkeypatch):
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "qwen-oauth")
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"resolve_qwen_runtime_credentials",
|
||||
lambda: {
|
||||
"provider": "qwen-oauth",
|
||||
"base_url": "https://portal.qwen.ai/v1",
|
||||
"api_key": "qwen-token",
|
||||
"source": "qwen-cli",
|
||||
"expires_at_ms": 1775640710946,
|
||||
},
|
||||
)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="qwen-oauth")
|
||||
|
||||
assert resolved["provider"] == "qwen-oauth"
|
||||
assert resolved["api_mode"] == "chat_completions"
|
||||
assert resolved["base_url"] == "https://portal.qwen.ai/v1"
|
||||
assert resolved["api_key"] == "qwen-token"
|
||||
assert resolved["requested_provider"] == "qwen-oauth"
|
||||
|
||||
|
||||
def test_resolve_runtime_provider_uses_qwen_pool_entry(monkeypatch):
|
||||
class _Entry:
|
||||
access_token = "pool-qwen-token"
|
||||
source = "manual:qwen_cli"
|
||||
base_url = "https://portal.qwen.ai/v1"
|
||||
|
||||
class _Pool:
|
||||
def has_credentials(self):
|
||||
return True
|
||||
|
||||
def select(self):
|
||||
return _Entry()
|
||||
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "qwen-oauth")
|
||||
monkeypatch.setattr(rp, "load_pool", lambda provider: _Pool())
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {"provider": "qwen-oauth", "default": "coder-model"})
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="qwen-oauth")
|
||||
|
||||
assert resolved["provider"] == "qwen-oauth"
|
||||
assert resolved["api_mode"] == "chat_completions"
|
||||
assert resolved["base_url"] == "https://portal.qwen.ai/v1"
|
||||
assert resolved["api_key"] == "pool-qwen-token"
|
||||
assert resolved["source"] == "manual:qwen_cli"
|
||||
|
||||
|
||||
def test_resolve_provider_alias_qwen(monkeypatch):
|
||||
monkeypatch.setattr(rp.auth_mod, "_load_auth_store", lambda: {})
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
assert rp.resolve_provider("qwen-portal") == "qwen-oauth"
|
||||
assert rp.resolve_provider("qwen-cli") == "qwen-oauth"
|
||||
|
||||
|
||||
def test_qwen_oauth_auto_fallthrough_on_auth_failure(monkeypatch):
|
||||
"""When requested_provider is 'auto' and Qwen creds fail, fall through."""
|
||||
from hermes_cli.auth import AuthError
|
||||
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "qwen-oauth")
|
||||
monkeypatch.setattr(
|
||||
rp,
|
||||
"resolve_qwen_runtime_credentials",
|
||||
lambda **kw: (_ for _ in ()).throw(AuthError("stale", provider="qwen-oauth", code="qwen_auth_missing")),
|
||||
)
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "test-or-key")
|
||||
|
||||
# Should NOT raise — falls through to OpenRouter
|
||||
resolved = rp.resolve_runtime_provider(requested="auto")
|
||||
# The fallthrough means it won't be qwen-oauth
|
||||
assert resolved["provider"] != "qwen-oauth"
|
||||
|
||||
|
||||
def test_resolve_runtime_provider_ai_gateway(monkeypatch):
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "ai-gateway")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
@@ -808,6 +884,55 @@ def test_minimax_explicit_api_mode_respected(monkeypatch):
|
||||
assert resolved["api_mode"] == "chat_completions"
|
||||
|
||||
|
||||
def test_minimax_config_base_url_overrides_hardcoded_default(monkeypatch):
|
||||
"""model.base_url in config.yaml should override the hardcoded default (#6039)."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "minimax")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {
|
||||
"provider": "minimax",
|
||||
"base_url": "https://api.minimaxi.com/anthropic",
|
||||
})
|
||||
monkeypatch.setenv("MINIMAX_API_KEY", "test-minimax-key")
|
||||
monkeypatch.delenv("MINIMAX_BASE_URL", raising=False)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="minimax")
|
||||
|
||||
assert resolved["provider"] == "minimax"
|
||||
assert resolved["base_url"] == "https://api.minimaxi.com/anthropic"
|
||||
assert resolved["api_mode"] == "anthropic_messages"
|
||||
|
||||
|
||||
def test_minimax_env_base_url_still_wins_over_config(monkeypatch):
|
||||
"""MINIMAX_BASE_URL env var should take priority over config.yaml model.base_url."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "minimax")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {
|
||||
"provider": "minimax",
|
||||
"base_url": "https://api.minimaxi.com/anthropic",
|
||||
})
|
||||
monkeypatch.setenv("MINIMAX_API_KEY", "test-minimax-key")
|
||||
monkeypatch.setenv("MINIMAX_BASE_URL", "https://custom.example.com/v1")
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="minimax")
|
||||
|
||||
# Env var wins because resolve_api_key_provider_credentials prefers it
|
||||
assert resolved["base_url"] == "https://custom.example.com/v1"
|
||||
|
||||
|
||||
def test_minimax_config_base_url_ignored_for_different_provider(monkeypatch):
|
||||
"""model.base_url should NOT be used when model.provider doesn't match."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "minimax")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {
|
||||
"provider": "openrouter",
|
||||
"base_url": "https://some-other-endpoint.com/v1",
|
||||
})
|
||||
monkeypatch.setenv("MINIMAX_API_KEY", "test-minimax-key")
|
||||
monkeypatch.delenv("MINIMAX_BASE_URL", raising=False)
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="minimax")
|
||||
|
||||
# Should use the default, NOT the config base_url from a different provider
|
||||
assert resolved["base_url"] == "https://api.minimax.io/anthropic"
|
||||
|
||||
|
||||
def test_alibaba_default_coding_intl_endpoint_uses_chat_completions(monkeypatch):
|
||||
"""Alibaba default coding-intl /v1 URL should use chat_completions mode."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "alibaba")
|
||||
|
||||
@@ -34,8 +34,8 @@ class TestSetupProviderModelSelection:
|
||||
@pytest.mark.parametrize("provider_id,expected_defaults", [
|
||||
("zai", ["glm-5", "glm-4.7", "glm-4.5", "glm-4.5-flash"]),
|
||||
("kimi-coding", ["kimi-k2.5", "kimi-k2-thinking", "kimi-k2-turbo-preview"]),
|
||||
("minimax", ["MiniMax-M2.7", "MiniMax-M2.7-highspeed", "MiniMax-M2.5", "MiniMax-M2.5-highspeed", "MiniMax-M2.1"]),
|
||||
("minimax-cn", ["MiniMax-M2.7", "MiniMax-M2.7-highspeed", "MiniMax-M2.5", "MiniMax-M2.5-highspeed", "MiniMax-M2.1"]),
|
||||
("minimax", ["MiniMax-M1", "MiniMax-M1-40k", "MiniMax-M1-80k", "MiniMax-M1-128k", "MiniMax-M1-256k", "MiniMax-M2.5", "MiniMax-M2.7"]),
|
||||
("minimax-cn", ["MiniMax-M1", "MiniMax-M1-40k", "MiniMax-M1-80k", "MiniMax-M1-128k", "MiniMax-M1-256k", "MiniMax-M2.5", "MiniMax-M2.7"]),
|
||||
("opencode-zen", ["gpt-5.4", "gpt-5.3-codex", "claude-sonnet-4-6", "gemini-3-flash"]),
|
||||
("opencode-go", ["glm-5", "kimi-k2.5", "minimax-m2.5", "minimax-m2.7"]),
|
||||
])
|
||||
|
||||
@@ -1,162 +0,0 @@
|
||||
"""Tests for _save_oversized_tool_result() — the large tool response handler.
|
||||
|
||||
When a tool returns more than _LARGE_RESULT_CHARS characters, the full content
|
||||
is saved to a file and the model receives a preview + file path instead.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
from run_agent import (
|
||||
_save_oversized_tool_result,
|
||||
_LARGE_RESULT_CHARS,
|
||||
_LARGE_RESULT_PREVIEW_CHARS,
|
||||
)
|
||||
|
||||
|
||||
class TestSaveOversizedToolResult:
|
||||
"""Unit tests for the large tool result handler."""
|
||||
|
||||
def test_small_result_returned_unchanged(self):
|
||||
"""Results under the threshold pass through untouched."""
|
||||
small = "x" * 1000
|
||||
assert _save_oversized_tool_result("terminal", small) is small
|
||||
|
||||
def test_exactly_at_threshold_returned_unchanged(self):
|
||||
"""Results exactly at the threshold pass through."""
|
||||
exact = "y" * _LARGE_RESULT_CHARS
|
||||
assert _save_oversized_tool_result("terminal", exact) is exact
|
||||
|
||||
def test_oversized_result_saved_to_file(self, tmp_path, monkeypatch):
|
||||
"""Results over the threshold are written to a file."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
|
||||
os.makedirs(tmp_path / ".hermes", exist_ok=True)
|
||||
|
||||
big = "A" * (_LARGE_RESULT_CHARS + 500)
|
||||
result = _save_oversized_tool_result("terminal", big)
|
||||
|
||||
# Should contain the preview
|
||||
assert result.startswith("A" * _LARGE_RESULT_PREVIEW_CHARS)
|
||||
# Should mention the file path
|
||||
assert "Full output saved to:" in result
|
||||
# Should mention original size
|
||||
assert f"{len(big):,}" in result
|
||||
|
||||
# Extract the file path and verify the file exists with full content
|
||||
match = re.search(r"Full output saved to: (.+?)\n", result)
|
||||
assert match, f"No file path found in result: {result[:300]}"
|
||||
filepath = match.group(1)
|
||||
assert os.path.isfile(filepath)
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
saved = f.read()
|
||||
assert saved == big
|
||||
assert len(saved) == _LARGE_RESULT_CHARS + 500
|
||||
|
||||
def test_file_placed_in_cache_tool_responses(self, tmp_path, monkeypatch):
|
||||
"""Saved file lives under HERMES_HOME/cache/tool_responses/."""
|
||||
hermes_home = str(tmp_path / ".hermes")
|
||||
monkeypatch.setenv("HERMES_HOME", hermes_home)
|
||||
os.makedirs(hermes_home, exist_ok=True)
|
||||
|
||||
big = "B" * (_LARGE_RESULT_CHARS + 1)
|
||||
result = _save_oversized_tool_result("web_search", big)
|
||||
|
||||
match = re.search(r"Full output saved to: (.+?)\n", result)
|
||||
filepath = match.group(1)
|
||||
expected_dir = os.path.join(hermes_home, "cache", "tool_responses")
|
||||
assert filepath.startswith(expected_dir)
|
||||
|
||||
def test_filename_contains_tool_name(self, tmp_path, monkeypatch):
|
||||
"""The saved filename includes a sanitized version of the tool name."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
|
||||
os.makedirs(tmp_path / ".hermes", exist_ok=True)
|
||||
|
||||
big = "C" * (_LARGE_RESULT_CHARS + 1)
|
||||
result = _save_oversized_tool_result("browser_navigate", big)
|
||||
|
||||
match = re.search(r"Full output saved to: (.+?)\n", result)
|
||||
filename = os.path.basename(match.group(1))
|
||||
assert filename.startswith("browser_navigate_")
|
||||
assert filename.endswith(".txt")
|
||||
|
||||
def test_tool_name_sanitized(self, tmp_path, monkeypatch):
|
||||
"""Special characters in tool names are replaced in the filename."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
|
||||
os.makedirs(tmp_path / ".hermes", exist_ok=True)
|
||||
|
||||
big = "D" * (_LARGE_RESULT_CHARS + 1)
|
||||
result = _save_oversized_tool_result("mcp:some/weird tool", big)
|
||||
|
||||
match = re.search(r"Full output saved to: (.+?)\n", result)
|
||||
filename = os.path.basename(match.group(1))
|
||||
# No slashes or colons in filename
|
||||
assert "/" not in filename
|
||||
assert ":" not in filename
|
||||
|
||||
def test_fallback_on_write_failure(self, tmp_path, monkeypatch):
|
||||
"""When file write fails, falls back to destructive truncation."""
|
||||
# Point HERMES_HOME to a path that will fail (file, not directory)
|
||||
bad_path = str(tmp_path / "not_a_dir.txt")
|
||||
with open(bad_path, "w") as f:
|
||||
f.write("I'm a file, not a directory")
|
||||
monkeypatch.setenv("HERMES_HOME", bad_path)
|
||||
|
||||
big = "E" * (_LARGE_RESULT_CHARS + 50_000)
|
||||
result = _save_oversized_tool_result("terminal", big)
|
||||
|
||||
# Should still contain data (fallback truncation)
|
||||
assert len(result) > 0
|
||||
assert result.startswith("E" * 1000)
|
||||
# Should mention the failure
|
||||
assert "File save failed" in result
|
||||
# Should be truncated to approximately _LARGE_RESULT_CHARS + error msg
|
||||
assert len(result) < len(big)
|
||||
|
||||
def test_preview_length_capped(self, tmp_path, monkeypatch):
|
||||
"""The inline preview is capped at _LARGE_RESULT_PREVIEW_CHARS."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
|
||||
os.makedirs(tmp_path / ".hermes", exist_ok=True)
|
||||
|
||||
# Use distinct chars so we can measure the preview
|
||||
big = "Z" * (_LARGE_RESULT_CHARS + 5000)
|
||||
result = _save_oversized_tool_result("terminal", big)
|
||||
|
||||
# The preview section is the content before the "[Large tool response:" marker
|
||||
marker_pos = result.index("[Large tool response:")
|
||||
preview_section = result[:marker_pos].rstrip()
|
||||
assert len(preview_section) == _LARGE_RESULT_PREVIEW_CHARS
|
||||
|
||||
def test_guidance_message_mentions_tools(self, tmp_path, monkeypatch):
|
||||
"""The replacement message tells the model how to access the file."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
|
||||
os.makedirs(tmp_path / ".hermes", exist_ok=True)
|
||||
|
||||
big = "F" * (_LARGE_RESULT_CHARS + 1)
|
||||
result = _save_oversized_tool_result("terminal", big)
|
||||
|
||||
assert "read_file" in result
|
||||
assert "search_files" in result
|
||||
|
||||
def test_empty_result_passes_through(self):
|
||||
"""Empty strings are not oversized."""
|
||||
assert _save_oversized_tool_result("terminal", "") == ""
|
||||
|
||||
def test_unicode_content_preserved(self, tmp_path, monkeypatch):
|
||||
"""Unicode content is fully preserved in the saved file."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes"))
|
||||
os.makedirs(tmp_path / ".hermes", exist_ok=True)
|
||||
|
||||
# Mix of ASCII and multi-byte unicode to exceed threshold
|
||||
unit = "Hello 世界! 🎉 " * 100 # ~1400 chars per repeat
|
||||
big = unit * ((_LARGE_RESULT_CHARS // len(unit)) + 1)
|
||||
assert len(big) > _LARGE_RESULT_CHARS
|
||||
|
||||
result = _save_oversized_tool_result("terminal", big)
|
||||
match = re.search(r"Full output saved to: (.+?)\n", result)
|
||||
filepath = match.group(1)
|
||||
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
saved = f.read()
|
||||
assert saved == big
|
||||
@@ -872,6 +872,52 @@ class TestBuildApiKwargs:
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
assert kwargs["max_tokens"] == 4096
|
||||
|
||||
def test_qwen_portal_formats_messages_and_metadata(self, agent):
|
||||
agent.base_url = "https://portal.qwen.ai/v1"
|
||||
agent._base_url_lower = agent.base_url.lower()
|
||||
agent.session_id = "sess-123"
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful"},
|
||||
{"role": "assistant", "content": "Got it"},
|
||||
{"role": "user", "content": "hi"},
|
||||
]
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
assert kwargs["metadata"]["sessionId"] == "sess-123"
|
||||
assert kwargs["extra_body"]["vl_high_resolution_images"] is True
|
||||
assert isinstance(kwargs["messages"][0]["content"], list)
|
||||
assert kwargs["messages"][0]["content"][0]["cache_control"] == {"type": "ephemeral"}
|
||||
assert kwargs["messages"][2]["content"][0]["text"] == "hi"
|
||||
|
||||
def test_qwen_portal_normalizes_bare_string_content_parts(self, agent):
|
||||
agent.base_url = "https://portal.qwen.ai/v1"
|
||||
agent._base_url_lower = agent.base_url.lower()
|
||||
messages = [
|
||||
{"role": "system", "content": [{"type": "text", "text": "system"}]},
|
||||
{"role": "user", "content": ["hello", {"type": "text", "text": "world"}]},
|
||||
]
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
user_content = kwargs["messages"][1]["content"]
|
||||
assert user_content[0] == {"type": "text", "text": "hello"}
|
||||
assert user_content[1] == {"type": "text", "text": "world"}
|
||||
|
||||
def test_qwen_portal_no_system_message(self, agent):
|
||||
agent.base_url = "https://portal.qwen.ai/v1"
|
||||
agent._base_url_lower = agent.base_url.lower()
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
# Should not crash even without a system message
|
||||
assert kwargs["messages"][0]["content"][0]["text"] == "hi"
|
||||
assert "cache_control" not in kwargs["messages"][0]["content"][0]
|
||||
|
||||
def test_qwen_portal_omits_max_tokens(self, agent):
|
||||
agent.base_url = "https://portal.qwen.ai/v1"
|
||||
agent._base_url_lower = agent.base_url.lower()
|
||||
agent.max_tokens = 4096
|
||||
messages = [{"role": "system", "content": "sys"}, {"role": "user", "content": "hi"}]
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
assert "max_tokens" not in kwargs
|
||||
assert "max_completion_tokens" not in kwargs
|
||||
|
||||
|
||||
class TestBuildAssistantMessage:
|
||||
def test_basic_message(self, agent):
|
||||
@@ -1011,10 +1057,9 @@ class TestExecuteToolCalls:
|
||||
big_result = "x" * 150_000
|
||||
with patch("run_agent.handle_function_call", return_value=big_result):
|
||||
agent._execute_tool_calls(mock_msg, messages, "task-1")
|
||||
# Content should be replaced with preview + file path
|
||||
# Content should be replaced with persisted-output or truncation
|
||||
assert len(messages[0]["content"]) < 150_000
|
||||
assert "Large tool response" in messages[0]["content"]
|
||||
assert "Full output saved to:" in messages[0]["content"]
|
||||
assert ("Truncated" in messages[0]["content"] or "<persisted-output>" in messages[0]["content"])
|
||||
|
||||
|
||||
class TestConcurrentToolExecution:
|
||||
@@ -1249,8 +1294,7 @@ class TestConcurrentToolExecution:
|
||||
assert len(messages) == 2
|
||||
for m in messages:
|
||||
assert len(m["content"]) < 150_000
|
||||
assert "Large tool response" in m["content"]
|
||||
assert "Full output saved to:" in m["content"]
|
||||
assert ("Truncated" in m["content"] or "<persisted-output>" in m["content"])
|
||||
|
||||
def test_invoke_tool_dispatches_to_handle_function_call(self, agent):
|
||||
"""_invoke_tool should route regular tools through handle_function_call."""
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
"""Tests for Ollama num_ctx context length detection and injection.
|
||||
|
||||
Covers:
|
||||
agent/model_metadata.py — query_ollama_num_ctx()
|
||||
run_agent.py — _ollama_num_ctx detection + extra_body injection
|
||||
"""
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.model_metadata import query_ollama_num_ctx
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Level 1: query_ollama_num_ctx — Ollama API interaction
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def _mock_httpx_client(show_response_data, status_code=200):
|
||||
"""Create a mock httpx.Client context manager that returns given /api/show data."""
|
||||
mock_resp = MagicMock(status_code=status_code)
|
||||
mock_resp.json.return_value = show_response_data
|
||||
mock_client = MagicMock()
|
||||
mock_client.post.return_value = mock_resp
|
||||
mock_ctx = MagicMock()
|
||||
mock_ctx.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_ctx.__exit__ = MagicMock(return_value=False)
|
||||
return mock_ctx, mock_client
|
||||
|
||||
|
||||
class TestQueryOllamaNumCtx:
|
||||
"""Test the Ollama /api/show context length query."""
|
||||
|
||||
def test_returns_context_from_model_info(self):
|
||||
"""Should extract context_length from GGUF model_info metadata."""
|
||||
show_data = {
|
||||
"model_info": {"llama.context_length": 131072},
|
||||
"parameters": "",
|
||||
}
|
||||
mock_ctx, _ = _mock_httpx_client(show_data)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"):
|
||||
# httpx is imported inside the function — patch the module import
|
||||
import httpx
|
||||
with patch.object(httpx, "Client", return_value=mock_ctx):
|
||||
result = query_ollama_num_ctx("llama3.1:8b", "http://localhost:11434/v1")
|
||||
|
||||
assert result == 131072
|
||||
|
||||
def test_prefers_explicit_num_ctx_from_modelfile(self):
|
||||
"""If the Modelfile sets num_ctx explicitly, that should take priority."""
|
||||
show_data = {
|
||||
"model_info": {"llama.context_length": 131072},
|
||||
"parameters": "num_ctx 32768\ntemperature 0.7",
|
||||
}
|
||||
mock_ctx, _ = _mock_httpx_client(show_data)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"):
|
||||
import httpx
|
||||
with patch.object(httpx, "Client", return_value=mock_ctx):
|
||||
result = query_ollama_num_ctx("custom-model", "http://localhost:11434")
|
||||
|
||||
assert result == 32768
|
||||
|
||||
def test_returns_none_for_non_ollama_server(self):
|
||||
"""Should return None if the server is not Ollama."""
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"):
|
||||
result = query_ollama_num_ctx("model", "http://localhost:1234")
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_on_connection_error(self):
|
||||
"""Should return None if the server is unreachable."""
|
||||
with patch("agent.model_metadata.detect_local_server_type", side_effect=Exception("timeout")):
|
||||
result = query_ollama_num_ctx("model", "http://localhost:11434")
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_on_404(self):
|
||||
"""Should return None if the model is not found."""
|
||||
mock_ctx, _ = _mock_httpx_client({}, status_code=404)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"):
|
||||
import httpx
|
||||
with patch.object(httpx, "Client", return_value=mock_ctx):
|
||||
result = query_ollama_num_ctx("nonexistent", "http://localhost:11434")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_strips_provider_prefix(self):
|
||||
"""Should strip 'local:' prefix from model name before querying."""
|
||||
show_data = {
|
||||
"model_info": {"qwen2.context_length": 32768},
|
||||
"parameters": "",
|
||||
}
|
||||
mock_ctx, mock_client = _mock_httpx_client(show_data)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"):
|
||||
import httpx
|
||||
with patch.object(httpx, "Client", return_value=mock_ctx):
|
||||
result = query_ollama_num_ctx("local:qwen2.5:7b", "http://localhost:11434/v1")
|
||||
|
||||
# Verify the post was called with stripped name (no "local:" prefix)
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[1]["json"]["name"] == "qwen2.5:7b" or call_args[0][1] is not None
|
||||
assert result == 32768
|
||||
|
||||
def test_handles_qwen2_architecture_key(self):
|
||||
"""Different model architectures use different key prefixes in model_info."""
|
||||
show_data = {
|
||||
"model_info": {"qwen2.context_length": 65536},
|
||||
"parameters": "",
|
||||
}
|
||||
mock_ctx, _ = _mock_httpx_client(show_data)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"):
|
||||
import httpx
|
||||
with patch.object(httpx, "Client", return_value=mock_ctx):
|
||||
result = query_ollama_num_ctx("qwen2.5:32b", "http://localhost:11434")
|
||||
|
||||
assert result == 65536
|
||||
|
||||
def test_returns_none_when_model_info_empty(self):
|
||||
"""Should return None if model_info has no context_length key."""
|
||||
show_data = {
|
||||
"model_info": {"llama.embedding_length": 4096},
|
||||
"parameters": "",
|
||||
}
|
||||
mock_ctx, _ = _mock_httpx_client(show_data)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"):
|
||||
import httpx
|
||||
with patch.object(httpx, "Client", return_value=mock_ctx):
|
||||
result = query_ollama_num_ctx("model", "http://localhost:11434")
|
||||
|
||||
assert result is None
|
||||
@@ -0,0 +1,117 @@
|
||||
"""Tests for agent.retry_utils jittered backoff."""
|
||||
|
||||
import threading
|
||||
|
||||
import agent.retry_utils as retry_utils
|
||||
from agent.retry_utils import jittered_backoff
|
||||
|
||||
|
||||
def test_backoff_is_exponential():
|
||||
"""Base delay should double each attempt (before jitter)."""
|
||||
for attempt in (1, 2, 3, 4):
|
||||
delays = [jittered_backoff(attempt, base_delay=5.0, max_delay=120.0, jitter_ratio=0.0) for _ in range(100)]
|
||||
expected = min(5.0 * (2 ** (attempt - 1)), 120.0)
|
||||
mean = sum(delays) / len(delays)
|
||||
assert abs(mean - expected) < 0.01, f"attempt {attempt}: expected {expected}, got {mean}"
|
||||
|
||||
|
||||
def test_backoff_respects_max_delay():
|
||||
"""Even with high attempt numbers, delay should not exceed max_delay."""
|
||||
for attempt in (10, 20, 100):
|
||||
delay = jittered_backoff(attempt, base_delay=5.0, max_delay=60.0, jitter_ratio=0.0)
|
||||
assert delay <= 60.0, f"attempt {attempt}: delay {delay} exceeds max 60s"
|
||||
|
||||
|
||||
def test_backoff_adds_jitter():
|
||||
"""With jitter enabled, delays should vary across calls."""
|
||||
delays = [jittered_backoff(1, base_delay=10.0, max_delay=120.0, jitter_ratio=0.5) for _ in range(50)]
|
||||
assert min(delays) != max(delays), "jitter should produce varying delays"
|
||||
assert all(d >= 10.0 for d in delays), "jittered delay should be >= base delay"
|
||||
assert all(d <= 15.0 for d in delays), "jittered delay should be bounded"
|
||||
|
||||
|
||||
def test_backoff_attempt_1_is_base():
|
||||
"""First attempt delay should equal base_delay (with no jitter)."""
|
||||
delay = jittered_backoff(1, base_delay=3.0, max_delay=120.0, jitter_ratio=0.0)
|
||||
assert delay == 3.0
|
||||
|
||||
|
||||
def test_backoff_with_zero_base_delay_returns_max():
|
||||
"""base_delay=0 should return max_delay (guard against busy-wait)."""
|
||||
delay = jittered_backoff(1, base_delay=0.0, max_delay=60.0, jitter_ratio=0.0)
|
||||
assert delay == 60.0
|
||||
|
||||
|
||||
def test_backoff_with_extreme_attempt_returns_max():
|
||||
"""Very large attempt numbers should not overflow and should return max_delay."""
|
||||
delay = jittered_backoff(999, base_delay=5.0, max_delay=120.0, jitter_ratio=0.0)
|
||||
assert delay == 120.0
|
||||
|
||||
|
||||
def test_backoff_negative_attempt_treated_as_one():
|
||||
"""Negative attempt should not crash and behaves like attempt=1."""
|
||||
delay = jittered_backoff(-5, base_delay=10.0, max_delay=120.0, jitter_ratio=0.0)
|
||||
assert delay == 10.0
|
||||
|
||||
|
||||
def test_backoff_thread_safety():
|
||||
"""Concurrent calls should generally produce different delays."""
|
||||
results = []
|
||||
barrier = threading.Barrier(8)
|
||||
|
||||
def _call_backoff():
|
||||
barrier.wait()
|
||||
results.append(jittered_backoff(1, base_delay=10.0, max_delay=120.0, jitter_ratio=0.5))
|
||||
|
||||
threads = [threading.Thread(target=_call_backoff) for _ in range(8)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=5)
|
||||
|
||||
assert len(results) == 8
|
||||
unique = len(set(results))
|
||||
assert unique >= 6, f"Expected mostly unique delays, got {unique}/8 unique"
|
||||
|
||||
|
||||
def test_backoff_uses_locked_tick_for_seed(monkeypatch):
|
||||
"""Seed derivation should use per-call tick captured under lock."""
|
||||
import time
|
||||
|
||||
monkeypatch.setattr(retry_utils, "_jitter_counter", 0)
|
||||
|
||||
recorded_seeds = []
|
||||
|
||||
class _RecordingRandom:
|
||||
def __init__(self, seed):
|
||||
recorded_seeds.append(seed)
|
||||
|
||||
def uniform(self, a, b):
|
||||
return 0.0
|
||||
|
||||
monkeypatch.setattr(retry_utils.random, "Random", _RecordingRandom)
|
||||
|
||||
fixed_time_ns = 123456789
|
||||
|
||||
def _time_ns_wait_for_two_ticks():
|
||||
deadline = time.time() + 2.0
|
||||
while retry_utils._jitter_counter < 2 and time.time() < deadline:
|
||||
time.sleep(0.001)
|
||||
return fixed_time_ns
|
||||
|
||||
monkeypatch.setattr(retry_utils.time, "time_ns", _time_ns_wait_for_two_ticks)
|
||||
|
||||
barrier = threading.Barrier(2)
|
||||
|
||||
def _call():
|
||||
barrier.wait()
|
||||
jittered_backoff(1, base_delay=10.0, max_delay=120.0, jitter_ratio=0.5)
|
||||
|
||||
threads = [threading.Thread(target=_call) for _ in range(2)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join(timeout=5)
|
||||
|
||||
assert len(recorded_seeds) == 2
|
||||
assert len(set(recorded_seeds)) == 2, f"Expected unique seeds, got {recorded_seeds}"
|
||||
@@ -0,0 +1,174 @@
|
||||
"""Tests for BaseEnvironment unified execution model.
|
||||
|
||||
Tests _wrap_command(), _extract_cwd_from_output(), _embed_stdin_heredoc(),
|
||||
init_session() failure handling, and the CWD marker contract.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from tools.environments.base import BaseEnvironment, _cwd_marker
|
||||
|
||||
|
||||
class _TestableEnv(BaseEnvironment):
|
||||
"""Concrete subclass for testing base class methods."""
|
||||
|
||||
def __init__(self, cwd="/tmp", timeout=10):
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
|
||||
def _run_bash(self, cmd_string, *, login=False, timeout=120, stdin_data=None):
|
||||
raise NotImplementedError("Use mock")
|
||||
|
||||
def cleanup(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestWrapCommand:
|
||||
def test_basic_shape(self):
|
||||
env = _TestableEnv()
|
||||
env._snapshot_ready = True
|
||||
wrapped = env._wrap_command("echo hello", "/tmp")
|
||||
|
||||
assert "source" in wrapped
|
||||
assert "cd /tmp" in wrapped or "cd '/tmp'" in wrapped
|
||||
assert "eval 'echo hello'" in wrapped
|
||||
assert "__hermes_ec=$?" in wrapped
|
||||
assert "export -p >" in wrapped
|
||||
assert "pwd -P >" in wrapped
|
||||
assert env._cwd_marker in wrapped
|
||||
assert "exit $__hermes_ec" in wrapped
|
||||
|
||||
def test_no_snapshot_skips_source(self):
|
||||
env = _TestableEnv()
|
||||
env._snapshot_ready = False
|
||||
wrapped = env._wrap_command("echo hello", "/tmp")
|
||||
|
||||
assert "source" not in wrapped
|
||||
|
||||
def test_single_quote_escaping(self):
|
||||
env = _TestableEnv()
|
||||
env._snapshot_ready = True
|
||||
wrapped = env._wrap_command("echo 'hello world'", "/tmp")
|
||||
|
||||
assert "eval 'echo '\\''hello world'\\'''" in wrapped
|
||||
|
||||
def test_tilde_not_quoted(self):
|
||||
env = _TestableEnv()
|
||||
env._snapshot_ready = True
|
||||
wrapped = env._wrap_command("ls", "~")
|
||||
|
||||
assert "cd ~" in wrapped
|
||||
assert "cd '~'" not in wrapped
|
||||
|
||||
def test_cd_failure_exit_126(self):
|
||||
env = _TestableEnv()
|
||||
env._snapshot_ready = True
|
||||
wrapped = env._wrap_command("ls", "/nonexistent")
|
||||
|
||||
assert "exit 126" in wrapped
|
||||
|
||||
|
||||
class TestExtractCwdFromOutput:
|
||||
def test_happy_path(self):
|
||||
env = _TestableEnv()
|
||||
marker = env._cwd_marker
|
||||
result = {
|
||||
"output": f"hello\n{marker}/home/user{marker}\n",
|
||||
}
|
||||
env._extract_cwd_from_output(result)
|
||||
|
||||
assert env.cwd == "/home/user"
|
||||
assert marker not in result["output"]
|
||||
|
||||
def test_missing_marker(self):
|
||||
env = _TestableEnv()
|
||||
result = {"output": "hello world\n"}
|
||||
env._extract_cwd_from_output(result)
|
||||
|
||||
assert env.cwd == "/tmp" # unchanged
|
||||
|
||||
def test_marker_in_command_output(self):
|
||||
"""If the marker appears in command output AND as the real marker,
|
||||
rfind grabs the last (real) one."""
|
||||
env = _TestableEnv()
|
||||
marker = env._cwd_marker
|
||||
result = {
|
||||
"output": f"user typed {marker} in their output\nreal output\n{marker}/correct/path{marker}\n",
|
||||
}
|
||||
env._extract_cwd_from_output(result)
|
||||
|
||||
assert env.cwd == "/correct/path"
|
||||
|
||||
def test_output_cleaned(self):
|
||||
env = _TestableEnv()
|
||||
marker = env._cwd_marker
|
||||
result = {
|
||||
"output": f"hello\n{marker}/tmp{marker}\n",
|
||||
}
|
||||
env._extract_cwd_from_output(result)
|
||||
|
||||
assert "hello" in result["output"]
|
||||
assert marker not in result["output"]
|
||||
|
||||
|
||||
class TestEmbedStdinHeredoc:
|
||||
def test_heredoc_format(self):
|
||||
result = BaseEnvironment._embed_stdin_heredoc("cat", "hello world")
|
||||
|
||||
assert result.startswith("cat << '")
|
||||
assert "hello world" in result
|
||||
assert "HERMES_STDIN_" in result
|
||||
|
||||
def test_unique_delimiter_each_call(self):
|
||||
r1 = BaseEnvironment._embed_stdin_heredoc("cat", "data")
|
||||
r2 = BaseEnvironment._embed_stdin_heredoc("cat", "data")
|
||||
|
||||
# Extract delimiters
|
||||
d1 = r1.split("'")[1]
|
||||
d2 = r2.split("'")[1]
|
||||
assert d1 != d2 # UUID-based, should be unique
|
||||
|
||||
|
||||
class TestInitSessionFailure:
|
||||
def test_snapshot_ready_false_on_failure(self):
|
||||
env = _TestableEnv()
|
||||
|
||||
def failing_run_bash(*args, **kwargs):
|
||||
raise RuntimeError("bash not found")
|
||||
|
||||
env._run_bash = failing_run_bash
|
||||
env.init_session()
|
||||
|
||||
assert env._snapshot_ready is False
|
||||
|
||||
def test_login_flag_when_snapshot_not_ready(self):
|
||||
"""When _snapshot_ready=False, execute() should pass login=True to _run_bash."""
|
||||
env = _TestableEnv()
|
||||
env._snapshot_ready = False
|
||||
|
||||
calls = []
|
||||
def mock_run_bash(cmd, *, login=False, timeout=120, stdin_data=None):
|
||||
calls.append({"login": login})
|
||||
# Return a mock process handle
|
||||
mock = MagicMock()
|
||||
mock.poll.return_value = 0
|
||||
mock.returncode = 0
|
||||
mock.stdout = iter([])
|
||||
return mock
|
||||
|
||||
env._run_bash = mock_run_bash
|
||||
env.execute("echo test")
|
||||
|
||||
assert len(calls) == 1
|
||||
assert calls[0]["login"] is True
|
||||
|
||||
|
||||
class TestCwdMarker:
|
||||
def test_marker_contains_session_id(self):
|
||||
env = _TestableEnv()
|
||||
assert env._session_id in env._cwd_marker
|
||||
|
||||
def test_unique_per_instance(self):
|
||||
env1 = _TestableEnv()
|
||||
env2 = _TestableEnv()
|
||||
assert env1._cwd_marker != env2._cwd_marker
|
||||
@@ -16,6 +16,7 @@ from tools.browser_camofox import (
|
||||
_managed_persistence_enabled,
|
||||
camofox_close,
|
||||
camofox_navigate,
|
||||
camofox_soft_cleanup,
|
||||
check_camofox_available,
|
||||
cleanup_all_camofox_sessions,
|
||||
get_vnc_url,
|
||||
@@ -240,3 +241,50 @@ class TestVncUrlDiscovery:
|
||||
|
||||
assert result["vnc_url"] == "http://localhost:6080"
|
||||
assert "vnc_hint" in result
|
||||
|
||||
|
||||
class TestCamofoxSoftCleanup:
|
||||
"""camofox_soft_cleanup drops local state only when managed persistence is on."""
|
||||
|
||||
def test_returns_true_and_drops_session_when_enabled(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("CAMOFOX_URL", "http://localhost:9377")
|
||||
|
||||
with _enable_persistence():
|
||||
_get_session("task-1")
|
||||
result = camofox_soft_cleanup("task-1")
|
||||
|
||||
assert result is True
|
||||
# Session should have been dropped from in-memory store
|
||||
import tools.browser_camofox as mod
|
||||
with mod._sessions_lock:
|
||||
assert "task-1" not in mod._sessions
|
||||
|
||||
def test_returns_false_when_disabled(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("CAMOFOX_URL", "http://localhost:9377")
|
||||
|
||||
_get_session("task-1")
|
||||
config = {"browser": {"camofox": {"managed_persistence": False}}}
|
||||
with patch("tools.browser_camofox.load_config", return_value=config):
|
||||
result = camofox_soft_cleanup("task-1")
|
||||
|
||||
assert result is False
|
||||
# Session should still be present — not dropped
|
||||
import tools.browser_camofox as mod
|
||||
with mod._sessions_lock:
|
||||
assert "task-1" in mod._sessions
|
||||
|
||||
def test_does_not_call_server_delete(self, tmp_path, monkeypatch):
|
||||
"""Soft cleanup must never hit the Camofox /sessions DELETE endpoint."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("CAMOFOX_URL", "http://localhost:9377")
|
||||
|
||||
with (
|
||||
_enable_persistence(),
|
||||
patch("tools.browser_camofox.requests.delete") as mock_delete,
|
||||
):
|
||||
_get_session("task-1")
|
||||
camofox_soft_cleanup("task-1")
|
||||
|
||||
mock_delete.assert_not_called()
|
||||
|
||||
@@ -65,6 +65,62 @@ class TestBrowserCleanup:
|
||||
mock_stop.assert_called_once_with("task-1")
|
||||
mock_run.assert_called_once_with("task-1", "close", [], timeout=10)
|
||||
|
||||
def test_cleanup_camofox_managed_persistence_skips_close(self):
|
||||
"""When camofox mode + managed persistence, soft_cleanup fires instead of close."""
|
||||
browser_tool = self.browser_tool
|
||||
browser_tool._active_sessions["task-1"] = {
|
||||
"session_name": "sess-1",
|
||||
"bb_session_id": None,
|
||||
}
|
||||
browser_tool._session_last_activity["task-1"] = 123.0
|
||||
|
||||
with (
|
||||
patch("tools.browser_tool._is_camofox_mode", return_value=True),
|
||||
patch("tools.browser_tool._maybe_stop_recording") as mock_stop,
|
||||
patch(
|
||||
"tools.browser_tool._run_browser_command",
|
||||
return_value={"success": True},
|
||||
),
|
||||
patch("tools.browser_tool.os.path.exists", return_value=False),
|
||||
patch(
|
||||
"tools.browser_camofox.camofox_soft_cleanup",
|
||||
return_value=True,
|
||||
) as mock_soft,
|
||||
patch("tools.browser_camofox.camofox_close") as mock_close,
|
||||
):
|
||||
browser_tool.cleanup_browser("task-1")
|
||||
|
||||
mock_soft.assert_called_once_with("task-1")
|
||||
mock_close.assert_not_called()
|
||||
|
||||
def test_cleanup_camofox_no_persistence_calls_close(self):
|
||||
"""When camofox mode but managed persistence is off, camofox_close fires."""
|
||||
browser_tool = self.browser_tool
|
||||
browser_tool._active_sessions["task-1"] = {
|
||||
"session_name": "sess-1",
|
||||
"bb_session_id": None,
|
||||
}
|
||||
browser_tool._session_last_activity["task-1"] = 123.0
|
||||
|
||||
with (
|
||||
patch("tools.browser_tool._is_camofox_mode", return_value=True),
|
||||
patch("tools.browser_tool._maybe_stop_recording") as mock_stop,
|
||||
patch(
|
||||
"tools.browser_tool._run_browser_command",
|
||||
return_value={"success": True},
|
||||
),
|
||||
patch("tools.browser_tool.os.path.exists", return_value=False),
|
||||
patch(
|
||||
"tools.browser_camofox.camofox_soft_cleanup",
|
||||
return_value=False,
|
||||
) as mock_soft,
|
||||
patch("tools.browser_camofox.camofox_close") as mock_close,
|
||||
):
|
||||
browser_tool.cleanup_browser("task-1")
|
||||
|
||||
mock_soft.assert_called_once_with("task-1")
|
||||
mock_close.assert_called_once_with("task-1")
|
||||
|
||||
def test_emergency_cleanup_clears_all_tracking_state(self):
|
||||
browser_tool = self.browser_tool
|
||||
browser_tool._cleanup_done = False
|
||||
|
||||
@@ -152,6 +152,109 @@ class TestFindAgentBrowser:
|
||||
class TestRunBrowserCommandPathConstruction:
|
||||
"""Verify _run_browser_command() includes Homebrew node dirs in subprocess PATH."""
|
||||
|
||||
def test_subprocess_preserves_executable_path_with_spaces(self, tmp_path):
|
||||
"""A local agent-browser path containing spaces must stay one argv entry."""
|
||||
captured_cmd = None
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait.return_value = 0
|
||||
|
||||
def capture_popen(cmd, **kwargs):
|
||||
nonlocal captured_cmd
|
||||
captured_cmd = cmd
|
||||
return mock_proc
|
||||
|
||||
fake_session = {
|
||||
"session_name": "test-session",
|
||||
"session_id": "test-id",
|
||||
"cdp_url": None,
|
||||
}
|
||||
fake_json = json.dumps({"success": True})
|
||||
browser_path = "/Users/test/Library/Application Support/hermes/node_modules/.bin/agent-browser"
|
||||
hermes_home = str(tmp_path / "hermes-home")
|
||||
|
||||
with patch("tools.browser_tool._find_agent_browser", return_value=browser_path), \
|
||||
patch("tools.browser_tool._get_session_info", return_value=fake_session), \
|
||||
patch("tools.browser_tool._socket_safe_tmpdir", return_value=str(tmp_path)), \
|
||||
patch("tools.browser_tool._discover_homebrew_node_dirs", return_value=[]), \
|
||||
patch("hermes_constants.Path.home", return_value=tmp_path), \
|
||||
patch("subprocess.Popen", side_effect=capture_popen), \
|
||||
patch("os.open", return_value=99), \
|
||||
patch("os.close"), \
|
||||
patch("tools.interrupt.is_interrupted", return_value=False), \
|
||||
patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"PATH": "/usr/bin:/bin",
|
||||
"HOME": "/home/test",
|
||||
"HERMES_HOME": hermes_home,
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
with patch("builtins.open", mock_open(read_data=fake_json)):
|
||||
_run_browser_command("test-task", "navigate", ["https://example.com"])
|
||||
|
||||
assert captured_cmd is not None
|
||||
assert captured_cmd[0] == browser_path
|
||||
assert captured_cmd[1:5] == [
|
||||
"--session",
|
||||
"test-session",
|
||||
"--json",
|
||||
"navigate",
|
||||
]
|
||||
|
||||
def test_subprocess_splits_npx_fallback_into_command_and_package(self, tmp_path):
|
||||
"""The synthetic npx fallback should still expand into separate argv items."""
|
||||
captured_cmd = None
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait.return_value = 0
|
||||
|
||||
def capture_popen(cmd, **kwargs):
|
||||
nonlocal captured_cmd
|
||||
captured_cmd = cmd
|
||||
return mock_proc
|
||||
|
||||
fake_session = {
|
||||
"session_name": "test-session",
|
||||
"session_id": "test-id",
|
||||
"cdp_url": None,
|
||||
}
|
||||
fake_json = json.dumps({"success": True})
|
||||
hermes_home = str(tmp_path / "hermes-home")
|
||||
|
||||
with patch("tools.browser_tool._find_agent_browser", return_value="npx agent-browser"), \
|
||||
patch("tools.browser_tool._get_session_info", return_value=fake_session), \
|
||||
patch("tools.browser_tool._socket_safe_tmpdir", return_value=str(tmp_path)), \
|
||||
patch("tools.browser_tool._discover_homebrew_node_dirs", return_value=[]), \
|
||||
patch("hermes_constants.Path.home", return_value=tmp_path), \
|
||||
patch("subprocess.Popen", side_effect=capture_popen), \
|
||||
patch("os.open", return_value=99), \
|
||||
patch("os.close"), \
|
||||
patch("tools.interrupt.is_interrupted", return_value=False), \
|
||||
patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"PATH": "/usr/bin:/bin",
|
||||
"HOME": "/home/test",
|
||||
"HERMES_HOME": hermes_home,
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
with patch("builtins.open", mock_open(read_data=fake_json)):
|
||||
_run_browser_command("test-task", "navigate", ["https://example.com"])
|
||||
|
||||
assert captured_cmd is not None
|
||||
assert captured_cmd[:2] == ["npx", "agent-browser"]
|
||||
assert captured_cmd[2:6] == [
|
||||
"--session",
|
||||
"test-session",
|
||||
"--json",
|
||||
"navigate",
|
||||
]
|
||||
|
||||
def test_subprocess_path_includes_homebrew_node_dirs(self, tmp_path):
|
||||
"""When _discover_homebrew_node_dirs returns dirs, they should appear
|
||||
in the subprocess env PATH passed to Popen."""
|
||||
|
||||
@@ -59,8 +59,8 @@ def daytona_sdk(monkeypatch):
|
||||
@pytest.fixture()
|
||||
def make_env(daytona_sdk, monkeypatch):
|
||||
"""Factory that creates a DaytonaEnvironment with a mocked SDK."""
|
||||
# Prevent is_interrupted from interfering
|
||||
monkeypatch.setattr("tools.interrupt.is_interrupted", lambda: False)
|
||||
# Prevent is_interrupted from interfering — patch where it's used (base.py)
|
||||
monkeypatch.setattr("tools.environments.base.is_interrupted", lambda: False)
|
||||
# Prevent skills/credential sync from consuming mock exec calls
|
||||
monkeypatch.setattr("tools.credential_files.get_credential_file_mounts", lambda: [])
|
||||
monkeypatch.setattr("tools.credential_files.get_skills_directory_mount", lambda **kw: None)
|
||||
@@ -221,41 +221,45 @@ class TestCleanup:
|
||||
class TestExecute:
|
||||
def test_basic_command(self, make_env):
|
||||
sb = _make_sandbox()
|
||||
# First call: $HOME detection; subsequent calls: actual commands
|
||||
# Calls: (1) $HOME detection, (2) init_session bootstrap, (3) actual command
|
||||
sb.process.exec.side_effect = [
|
||||
_make_exec_response(result="/root"), # $HOME
|
||||
_make_exec_response(result="", exit_code=0), # init_session
|
||||
_make_exec_response(result="hello", exit_code=0), # actual cmd
|
||||
]
|
||||
sb.state = "started"
|
||||
env = make_env(sandbox=sb)
|
||||
|
||||
result = env.execute("echo hello")
|
||||
assert result["output"] == "hello"
|
||||
assert "hello" in result["output"]
|
||||
assert result["returncode"] == 0
|
||||
|
||||
def test_command_wrapped_with_shell_timeout(self, make_env):
|
||||
def test_sdk_timeout_passed_to_exec(self, make_env):
|
||||
"""SDK native timeout is passed to sandbox.process.exec()."""
|
||||
sb = _make_sandbox()
|
||||
sb.process.exec.side_effect = [
|
||||
_make_exec_response(result="/root"),
|
||||
_make_exec_response(result="", exit_code=0), # init_session
|
||||
_make_exec_response(result="ok", exit_code=0),
|
||||
]
|
||||
sb.state = "started"
|
||||
env = make_env(sandbox=sb, timeout=42)
|
||||
|
||||
env.execute("echo hello")
|
||||
# The command sent to exec should be wrapped with `timeout N sh -c '...'`
|
||||
# The exec call should receive timeout= kwarg (SDK native timeout)
|
||||
call_args = sb.process.exec.call_args_list[-1]
|
||||
assert call_args[1]["timeout"] == 42
|
||||
# The command should NOT have a shell `timeout` prefix
|
||||
cmd = call_args[0][0]
|
||||
assert cmd.startswith("timeout 42 sh -c ")
|
||||
# SDK timeout param should NOT be passed
|
||||
assert "timeout" not in call_args[1]
|
||||
assert not cmd.startswith("timeout ")
|
||||
|
||||
def test_timeout_returns_exit_code_124(self, make_env):
|
||||
"""Shell timeout utility returns exit code 124."""
|
||||
"""SDK-level timeout surfaces as exit code 124 via _wait_for_process."""
|
||||
sb = _make_sandbox()
|
||||
sb.process.exec.side_effect = [
|
||||
_make_exec_response(result="/root"),
|
||||
_make_exec_response(result="", exit_code=124),
|
||||
_make_exec_response(result="", exit_code=0), # init_session
|
||||
_make_exec_response(result="", exit_code=124), # actual cmd
|
||||
]
|
||||
sb.state = "started"
|
||||
env = make_env(sandbox=sb)
|
||||
@@ -267,6 +271,7 @@ class TestExecute:
|
||||
sb = _make_sandbox()
|
||||
sb.process.exec.side_effect = [
|
||||
_make_exec_response(result="/root"),
|
||||
_make_exec_response(result="", exit_code=0), # init_session
|
||||
_make_exec_response(result="not found", exit_code=127),
|
||||
]
|
||||
sb.state = "started"
|
||||
@@ -279,6 +284,7 @@ class TestExecute:
|
||||
sb = _make_sandbox()
|
||||
sb.process.exec.side_effect = [
|
||||
_make_exec_response(result="/root"),
|
||||
_make_exec_response(result="", exit_code=0), # init_session
|
||||
_make_exec_response(result="ok", exit_code=0),
|
||||
]
|
||||
sb.state = "started"
|
||||
@@ -286,39 +292,47 @@ class TestExecute:
|
||||
|
||||
env.execute("python3", stdin_data="print('hi')")
|
||||
# Check that the command passed to exec contains heredoc markers
|
||||
# (single quotes get shell-escaped by shlex.quote, so check components)
|
||||
# Base class uses HERMES_STDIN_ prefix for heredoc delimiters
|
||||
call_args = sb.process.exec.call_args_list[-1]
|
||||
cmd = call_args[0][0]
|
||||
assert "HERMES_EOF_" in cmd
|
||||
assert "HERMES_STDIN_" in cmd
|
||||
assert "print" in cmd
|
||||
assert "hi" in cmd
|
||||
|
||||
def test_custom_cwd_passed_through(self, make_env):
|
||||
def test_custom_cwd_in_command_wrapper(self, make_env):
|
||||
"""CWD is handled by _wrap_command() in the command string, not as a kwarg."""
|
||||
sb = _make_sandbox()
|
||||
sb.process.exec.side_effect = [
|
||||
_make_exec_response(result="/root"),
|
||||
_make_exec_response(result="", exit_code=0), # init_session
|
||||
_make_exec_response(result="/tmp", exit_code=0),
|
||||
]
|
||||
sb.state = "started"
|
||||
env = make_env(sandbox=sb)
|
||||
|
||||
env.execute("pwd", cwd="/tmp")
|
||||
call_kwargs = sb.process.exec.call_args_list[-1][1]
|
||||
assert call_kwargs["cwd"] == "/tmp"
|
||||
# CWD should be embedded in the command string via _wrap_command
|
||||
call_args = sb.process.exec.call_args_list[-1]
|
||||
cmd = call_args[0][0]
|
||||
assert "cd /tmp" in cmd
|
||||
# CWD should NOT be passed as a kwarg to exec
|
||||
assert "cwd" not in call_args[1]
|
||||
|
||||
def test_daytona_error_triggers_retry(self, make_env, daytona_sdk):
|
||||
sb = _make_sandbox()
|
||||
sb.state = "started"
|
||||
sb.process.exec.side_effect = [
|
||||
_make_exec_response(result="/root"), # $HOME
|
||||
_make_exec_response(result="", exit_code=0), # init_session
|
||||
daytona_sdk.DaytonaError("transient"), # first attempt fails
|
||||
_make_exec_response(result="ok", exit_code=0), # retry succeeds
|
||||
]
|
||||
env = make_env(sandbox=sb)
|
||||
|
||||
result = env.execute("echo retry")
|
||||
assert result["output"] == "ok"
|
||||
assert result["returncode"] == 0
|
||||
# DaytonaError now surfaces directly through _ThreadedProcessHandle
|
||||
# (no retry logic) — the error becomes returncode=1
|
||||
assert result["returncode"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -359,14 +373,18 @@ class TestInterrupt:
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
return _make_exec_response(result="/root") # $HOME detection
|
||||
if calls["n"] == 2:
|
||||
return _make_exec_response(result="", exit_code=0) # init_session
|
||||
event.wait(timeout=5) # simulate long-running command
|
||||
return _make_exec_response(result="done", exit_code=0)
|
||||
|
||||
sb.process.exec.side_effect = exec_side_effect
|
||||
env = make_env(sandbox=sb)
|
||||
|
||||
# is_interrupted is checked by base.py's _wait_for_process,
|
||||
# patch where it's actually referenced (base.py's local binding)
|
||||
monkeypatch.setattr(
|
||||
"tools.environments.daytona.is_interrupted", lambda: True
|
||||
"tools.environments.base.is_interrupted", lambda: True
|
||||
)
|
||||
try:
|
||||
result = env.execute("sleep 10")
|
||||
@@ -377,23 +395,24 @@ class TestInterrupt:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Retry exhaustion
|
||||
# DaytonaError surfaces directly (no retry)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRetryExhausted:
|
||||
def test_both_attempts_fail(self, make_env, daytona_sdk):
|
||||
"""DaytonaError surfaces directly as rc=1 (retry logic was removed)."""
|
||||
sb = _make_sandbox()
|
||||
sb.state = "started"
|
||||
sb.process.exec.side_effect = [
|
||||
_make_exec_response(result="/root"), # $HOME
|
||||
daytona_sdk.DaytonaError("fail1"), # first attempt
|
||||
daytona_sdk.DaytonaError("fail2"), # retry
|
||||
_make_exec_response(result="", exit_code=0), # init_session
|
||||
daytona_sdk.DaytonaError("fail1"), # actual command fails
|
||||
]
|
||||
env = make_env(sandbox=sb)
|
||||
|
||||
result = env.execute("echo x")
|
||||
# Error surfaces directly through _ThreadedProcessHandle (rc=1)
|
||||
assert result["returncode"] == 1
|
||||
assert "Daytona execution error" in result["output"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -245,43 +245,42 @@ def _make_execute_only_env(forward_env=None):
|
||||
env._timeout_result = lambda timeout: {"output": f"timed out after {timeout}", "returncode": 124}
|
||||
env._container_id = "test-container"
|
||||
env._docker_exe = "/usr/bin/docker"
|
||||
# Base class attributes needed by unified execute()
|
||||
env._session_id = "test123"
|
||||
env._snapshot_path = "/tmp/hermes-snap-test123.sh"
|
||||
env._cwd_file = "/tmp/hermes-cwd-test123.txt"
|
||||
env._cwd_marker = "__HERMES_CWD_test123__"
|
||||
env._snapshot_ready = True
|
||||
env._last_sync_time = None
|
||||
env._init_env_args = []
|
||||
return env
|
||||
|
||||
|
||||
def test_execute_uses_hermes_dotenv_for_allowlisted_env(monkeypatch):
|
||||
def test_init_env_args_uses_hermes_dotenv_for_allowlisted_env(monkeypatch):
|
||||
"""_build_init_env_args picks up forwarded env vars from .env file at init time."""
|
||||
env = _make_execute_only_env(["GITHUB_TOKEN"])
|
||||
popen_calls = []
|
||||
|
||||
def _fake_popen(cmd, **kwargs):
|
||||
popen_calls.append(cmd)
|
||||
return _FakePopen(cmd, **kwargs)
|
||||
|
||||
monkeypatch.delenv("GITHUB_TOKEN", raising=False)
|
||||
monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {"GITHUB_TOKEN": "value_from_dotenv"})
|
||||
monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen)
|
||||
|
||||
result = env.execute("echo hi")
|
||||
args = env._build_init_env_args()
|
||||
args_str = " ".join(args)
|
||||
|
||||
assert result["returncode"] == 0
|
||||
assert "GITHUB_TOKEN=value_from_dotenv" in popen_calls[0]
|
||||
assert "GITHUB_TOKEN=value_from_dotenv" in args_str
|
||||
|
||||
|
||||
def test_execute_prefers_shell_env_over_hermes_dotenv(monkeypatch):
|
||||
def test_init_env_args_prefers_shell_env_over_hermes_dotenv(monkeypatch):
|
||||
"""Shell env vars take priority over .env file values in init env args."""
|
||||
env = _make_execute_only_env(["GITHUB_TOKEN"])
|
||||
popen_calls = []
|
||||
|
||||
def _fake_popen(cmd, **kwargs):
|
||||
popen_calls.append(cmd)
|
||||
return _FakePopen(cmd, **kwargs)
|
||||
|
||||
monkeypatch.setenv("GITHUB_TOKEN", "value_from_shell")
|
||||
monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {"GITHUB_TOKEN": "value_from_dotenv"})
|
||||
monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen)
|
||||
|
||||
env.execute("echo hi")
|
||||
args = env._build_init_env_args()
|
||||
args_str = " ".join(args)
|
||||
|
||||
assert "GITHUB_TOKEN=value_from_shell" in popen_calls[0]
|
||||
assert "GITHUB_TOKEN=value_from_dotenv" not in popen_calls[0]
|
||||
assert "GITHUB_TOKEN=value_from_shell" in args_str
|
||||
assert "value_from_dotenv" not in args_str
|
||||
|
||||
|
||||
# ── docker_env tests ──────────────────────────────────────────────
|
||||
@@ -302,64 +301,46 @@ def test_docker_env_appears_in_run_command(monkeypatch):
|
||||
assert "GNUPGHOME=/root/.gnupg" in run_args_str
|
||||
|
||||
|
||||
def test_docker_env_appears_in_exec_command(monkeypatch):
|
||||
"""Explicit docker_env values should also be passed via -e at docker exec time."""
|
||||
def test_docker_env_appears_in_init_env_args(monkeypatch):
|
||||
"""Explicit docker_env values should appear in _build_init_env_args."""
|
||||
env = _make_execute_only_env()
|
||||
env._env = {"MY_VAR": "my_value"}
|
||||
popen_calls = []
|
||||
|
||||
def _fake_popen(cmd, **kwargs):
|
||||
popen_calls.append(cmd)
|
||||
return _FakePopen(cmd, **kwargs)
|
||||
args = env._build_init_env_args()
|
||||
args_str = " ".join(args)
|
||||
|
||||
monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen)
|
||||
|
||||
env.execute("echo hi")
|
||||
|
||||
assert popen_calls, "Popen should have been called"
|
||||
assert "MY_VAR=my_value" in popen_calls[0]
|
||||
assert "MY_VAR=my_value" in args_str
|
||||
|
||||
|
||||
def test_forward_env_overrides_docker_env(monkeypatch):
|
||||
def test_forward_env_overrides_docker_env_in_init_args(monkeypatch):
|
||||
"""docker_forward_env should override docker_env for the same key."""
|
||||
env = _make_execute_only_env(forward_env=["MY_KEY"])
|
||||
env._env = {"MY_KEY": "static_value"}
|
||||
popen_calls = []
|
||||
|
||||
def _fake_popen(cmd, **kwargs):
|
||||
popen_calls.append(cmd)
|
||||
return _FakePopen(cmd, **kwargs)
|
||||
|
||||
monkeypatch.setenv("MY_KEY", "dynamic_value")
|
||||
monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {})
|
||||
monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen)
|
||||
|
||||
env.execute("echo hi")
|
||||
args = env._build_init_env_args()
|
||||
args_str = " ".join(args)
|
||||
|
||||
cmd_str = " ".join(popen_calls[0])
|
||||
assert "MY_KEY=dynamic_value" in cmd_str
|
||||
assert "MY_KEY=static_value" not in cmd_str
|
||||
assert "MY_KEY=dynamic_value" in args_str
|
||||
assert "MY_KEY=static_value" not in args_str
|
||||
|
||||
|
||||
def test_docker_env_and_forward_env_merge(monkeypatch):
|
||||
def test_docker_env_and_forward_env_merge_in_init_args(monkeypatch):
|
||||
"""docker_env and docker_forward_env with different keys should both appear."""
|
||||
env = _make_execute_only_env(forward_env=["TOKEN"])
|
||||
env._env = {"SSH_AUTH_SOCK": "/run/user/1000/agent.sock"}
|
||||
popen_calls = []
|
||||
|
||||
def _fake_popen(cmd, **kwargs):
|
||||
popen_calls.append(cmd)
|
||||
return _FakePopen(cmd, **kwargs)
|
||||
|
||||
monkeypatch.setenv("TOKEN", "secret123")
|
||||
monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {})
|
||||
monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen)
|
||||
|
||||
env.execute("echo hi")
|
||||
args = env._build_init_env_args()
|
||||
args_str = " ".join(args)
|
||||
|
||||
assert "SSH_AUTH_SOCK=/run/user/1000/agent.sock" in args_str
|
||||
assert "TOKEN=secret123" in args_str
|
||||
|
||||
cmd_str = " ".join(popen_calls[0])
|
||||
assert "SSH_AUTH_SOCK=/run/user/1000/agent.sock" in cmd_str
|
||||
assert "TOKEN=secret123" in cmd_str
|
||||
|
||||
|
||||
def test_normalize_env_dict_filters_invalid_keys():
|
||||
|
||||
@@ -22,21 +22,19 @@ import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
||||
|
||||
from tools.environments.local import (
|
||||
LocalEnvironment,
|
||||
_clean_shell_noise,
|
||||
_extract_fenced_output,
|
||||
_OUTPUT_FENCE,
|
||||
_SHELL_NOISE_SUBSTRINGS,
|
||||
)
|
||||
from tools.environments.local import LocalEnvironment
|
||||
from tools.file_operations import ShellFileOperations
|
||||
|
||||
|
||||
# ── Shared noise detection ───────────────────────────────────────────────
|
||||
# Every known shell noise pattern. If ANY of these appear in output that
|
||||
# isn't explicitly expected, the test fails with a clear message.
|
||||
# Known shell noise patterns that should never appear in command output.
|
||||
|
||||
_ALL_NOISE_PATTERNS = list(_SHELL_NOISE_SUBSTRINGS) + [
|
||||
_ALL_NOISE_PATTERNS = [
|
||||
"bash: cannot set terminal process group",
|
||||
"bash: no job control in this shell",
|
||||
"no job control in this shell",
|
||||
"cannot set terminal process group",
|
||||
"tcsetattr: Inappropriate ioctl for device",
|
||||
"bash: ",
|
||||
"Inappropriate ioctl",
|
||||
"Auto-suggestions:",
|
||||
@@ -88,134 +86,6 @@ def populated_dir(tmp_path):
|
||||
return tmp_path
|
||||
|
||||
|
||||
# ── _clean_shell_noise unit tests ────────────────────────────────────────
|
||||
|
||||
class TestCleanShellNoise:
|
||||
def test_single_noise_line(self):
|
||||
output = "bash: no job control in this shell\nhello world\n"
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "hello world\n"
|
||||
|
||||
def test_double_noise_lines(self):
|
||||
output = (
|
||||
"bash: cannot set terminal process group (-1): Inappropriate ioctl for device\n"
|
||||
"bash: no job control in this shell\n"
|
||||
"actual output here\n"
|
||||
)
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "actual output here\n"
|
||||
_assert_clean(result)
|
||||
|
||||
def test_tcsetattr_noise(self):
|
||||
output = (
|
||||
"bash: [12345: 2 (255)] tcsetattr: Inappropriate ioctl for device\n"
|
||||
"real content\n"
|
||||
)
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "real content\n"
|
||||
_assert_clean(result)
|
||||
|
||||
def test_triple_noise_lines(self):
|
||||
output = (
|
||||
"bash: cannot set terminal process group (-1): Inappropriate ioctl for device\n"
|
||||
"bash: no job control in this shell\n"
|
||||
"bash: [999: 2 (255)] tcsetattr: Inappropriate ioctl for device\n"
|
||||
"clean\n"
|
||||
)
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "clean\n"
|
||||
|
||||
def test_no_noise_untouched(self):
|
||||
assert _clean_shell_noise("hello\nworld\n") == "hello\nworld\n"
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _clean_shell_noise("") == ""
|
||||
|
||||
def test_only_noise_produces_empty(self):
|
||||
output = "bash: no job control in this shell\n"
|
||||
result = _clean_shell_noise(output)
|
||||
_assert_clean(result)
|
||||
|
||||
def test_noise_in_middle_not_stripped(self):
|
||||
"""Noise in the middle is real output and should be preserved."""
|
||||
output = "real\nbash: no job control in this shell\nmore real\n"
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == output
|
||||
|
||||
def test_zsh_restored_session(self):
|
||||
output = "Restored session: Mon Mar 2 22:16:54 +03 2026\nhello\n"
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "hello\n"
|
||||
|
||||
def test_zsh_saving_session_trailing(self):
|
||||
output = "hello\nSaving session...completed.\n"
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "hello\n"
|
||||
|
||||
def test_zsh_oh_my_zsh_banner(self):
|
||||
output = "Oh My Zsh on! | Auto-suggestions: press right\nhello\n"
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "hello\n"
|
||||
|
||||
def test_zsh_full_noise_sandwich(self):
|
||||
"""Both leading and trailing zsh noise stripped."""
|
||||
output = (
|
||||
"Restored session: Mon Mar 2\n"
|
||||
"command not found: docker\n"
|
||||
"Oh My Zsh on!\n"
|
||||
"actual output\n"
|
||||
"Saving session...completed.\n"
|
||||
)
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "actual output\n"
|
||||
|
||||
def test_last_login_stripped(self):
|
||||
output = "Last login: Mon Mar 2 22:00:00 on ttys001\nhello\n"
|
||||
result = _clean_shell_noise(output)
|
||||
assert result == "hello\n"
|
||||
|
||||
|
||||
# ── _extract_fenced_output unit tests ────────────────────────────────────
|
||||
|
||||
class TestExtractFencedOutput:
|
||||
def test_normal_fenced_output(self):
|
||||
raw = f"noise\n{_OUTPUT_FENCE}hello world\n{_OUTPUT_FENCE}more noise\n"
|
||||
assert _extract_fenced_output(raw) == "hello world\n"
|
||||
|
||||
def test_no_trailing_newline(self):
|
||||
"""printf output with no trailing newline is preserved."""
|
||||
raw = f"noise{_OUTPUT_FENCE}exact{_OUTPUT_FENCE}noise"
|
||||
assert _extract_fenced_output(raw) == "exact"
|
||||
|
||||
def test_no_fences_falls_back(self):
|
||||
"""Without fences, falls back to pattern-based cleaning."""
|
||||
raw = "bash: no job control in this shell\nhello\n"
|
||||
result = _extract_fenced_output(raw)
|
||||
assert result == "hello\n"
|
||||
|
||||
def test_only_start_fence(self):
|
||||
"""Only start fence (e.g. user command called exit)."""
|
||||
raw = f"noise{_OUTPUT_FENCE}hello\nSaving session...\n"
|
||||
result = _extract_fenced_output(raw)
|
||||
assert result == "hello\n"
|
||||
|
||||
def test_user_outputs_fence_string(self):
|
||||
"""If user command outputs the fence marker, it is preserved."""
|
||||
raw = f"noise{_OUTPUT_FENCE}{_OUTPUT_FENCE}real\n{_OUTPUT_FENCE}noise"
|
||||
result = _extract_fenced_output(raw)
|
||||
# first fence -> last fence captures the middle including user's fence
|
||||
assert _OUTPUT_FENCE in result
|
||||
assert "real\n" in result
|
||||
|
||||
def test_empty_command_output(self):
|
||||
raw = f"noise{_OUTPUT_FENCE}{_OUTPUT_FENCE}noise"
|
||||
assert _extract_fenced_output(raw) == ""
|
||||
|
||||
def test_multiline_output(self):
|
||||
raw = f"noise\n{_OUTPUT_FENCE}line1\nline2\nline3\n{_OUTPUT_FENCE}noise\n"
|
||||
assert _extract_fenced_output(raw) == "line1\nline2\nline3\n"
|
||||
|
||||
|
||||
# ── LocalEnvironment.execute() ───────────────────────────────────────────
|
||||
|
||||
class TestLocalEnvironmentExecute:
|
||||
|
||||
@@ -1,164 +0,0 @@
|
||||
"""Tests for the local persistent shell backend."""
|
||||
|
||||
import glob as glob_mod
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments.local import LocalEnvironment
|
||||
from tools.environments.persistent_shell import PersistentShellMixin
|
||||
|
||||
|
||||
class TestLocalConfig:
|
||||
def test_local_persistent_default_false(self, monkeypatch):
|
||||
monkeypatch.delenv("TERMINAL_LOCAL_PERSISTENT", raising=False)
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["local_persistent"] is False
|
||||
|
||||
def test_local_persistent_true(self, monkeypatch):
|
||||
monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "true")
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["local_persistent"] is True
|
||||
|
||||
def test_local_persistent_yes(self, monkeypatch):
|
||||
monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "yes")
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["local_persistent"] is True
|
||||
|
||||
|
||||
class TestMergeOutput:
|
||||
def test_stdout_only(self):
|
||||
assert PersistentShellMixin._merge_output("out", "") == "out"
|
||||
|
||||
def test_stderr_only(self):
|
||||
assert PersistentShellMixin._merge_output("", "err") == "err"
|
||||
|
||||
def test_both(self):
|
||||
assert PersistentShellMixin._merge_output("out", "err") == "out\nerr"
|
||||
|
||||
def test_empty(self):
|
||||
assert PersistentShellMixin._merge_output("", "") == ""
|
||||
|
||||
def test_strips_trailing_newlines(self):
|
||||
assert PersistentShellMixin._merge_output("out\n\n", "err\n") == "out\nerr"
|
||||
|
||||
|
||||
class TestLocalOneShotRegression:
|
||||
def test_echo(self):
|
||||
env = LocalEnvironment(persistent=False)
|
||||
r = env.execute("echo hello")
|
||||
assert r["returncode"] == 0
|
||||
assert "hello" in r["output"]
|
||||
env.cleanup()
|
||||
|
||||
def test_exit_code(self):
|
||||
env = LocalEnvironment(persistent=False)
|
||||
r = env.execute("exit 42")
|
||||
assert r["returncode"] == 42
|
||||
env.cleanup()
|
||||
|
||||
def test_state_does_not_persist(self):
|
||||
env = LocalEnvironment(persistent=False)
|
||||
env.execute("export HERMES_ONESHOT_LOCAL=yes")
|
||||
r = env.execute("echo $HERMES_ONESHOT_LOCAL")
|
||||
assert r["output"].strip() == ""
|
||||
env.cleanup()
|
||||
|
||||
def test_oneshot_heredoc_does_not_leak_fence_wrapper(self):
|
||||
"""Heredoc closing line must not be merged with the fence wrapper tail."""
|
||||
env = LocalEnvironment(persistent=False)
|
||||
cmd = "cat <<'H_EOF'\nheredoc body line\nH_EOF"
|
||||
r = env.execute(cmd)
|
||||
env.cleanup()
|
||||
assert r["returncode"] == 0
|
||||
assert "heredoc body line" in r["output"]
|
||||
assert "__hermes_rc" not in r["output"]
|
||||
assert "printf '" not in r["output"]
|
||||
assert "exit $" not in r["output"]
|
||||
|
||||
|
||||
class TestLocalPersistent:
|
||||
@pytest.fixture
|
||||
def env(self):
|
||||
e = LocalEnvironment(persistent=True)
|
||||
yield e
|
||||
e.cleanup()
|
||||
|
||||
def test_echo(self, env):
|
||||
r = env.execute("echo hello-persistent")
|
||||
assert r["returncode"] == 0
|
||||
assert "hello-persistent" in r["output"]
|
||||
|
||||
def test_env_var_persists(self, env):
|
||||
env.execute("export HERMES_LOCAL_PERSIST_TEST=works")
|
||||
r = env.execute("echo $HERMES_LOCAL_PERSIST_TEST")
|
||||
assert r["output"].strip() == "works"
|
||||
|
||||
def test_cwd_persists(self, env):
|
||||
env.execute("cd /tmp")
|
||||
r = env.execute("pwd")
|
||||
assert r["output"].strip() == "/tmp"
|
||||
|
||||
def test_exit_code(self, env):
|
||||
r = env.execute("(exit 42)")
|
||||
assert r["returncode"] == 42
|
||||
|
||||
def test_stderr(self, env):
|
||||
r = env.execute("echo oops >&2")
|
||||
assert r["returncode"] == 0
|
||||
assert "oops" in r["output"]
|
||||
|
||||
def test_multiline_output(self, env):
|
||||
r = env.execute("echo a; echo b; echo c")
|
||||
lines = r["output"].strip().splitlines()
|
||||
assert lines == ["a", "b", "c"]
|
||||
|
||||
def test_timeout_then_recovery(self, env):
|
||||
r = env.execute("sleep 999", timeout=2)
|
||||
assert r["returncode"] in (124, 130)
|
||||
r = env.execute("echo alive")
|
||||
assert r["returncode"] == 0
|
||||
assert "alive" in r["output"]
|
||||
|
||||
def test_large_output(self, env):
|
||||
r = env.execute("seq 1 1000")
|
||||
assert r["returncode"] == 0
|
||||
lines = r["output"].strip().splitlines()
|
||||
assert len(lines) == 1000
|
||||
assert lines[0] == "1"
|
||||
assert lines[-1] == "1000"
|
||||
|
||||
def test_shell_variable_persists(self, env):
|
||||
env.execute("MY_LOCAL_VAR=hello123")
|
||||
r = env.execute("echo $MY_LOCAL_VAR")
|
||||
assert r["output"].strip() == "hello123"
|
||||
|
||||
def test_cleanup_removes_temp_files(self, env):
|
||||
env.execute("echo warmup")
|
||||
prefix = env._temp_prefix
|
||||
assert len(glob_mod.glob(f"{prefix}-*")) > 0
|
||||
env.cleanup()
|
||||
remaining = glob_mod.glob(f"{prefix}-*")
|
||||
assert remaining == []
|
||||
|
||||
def test_state_does_not_leak_between_instances(self):
|
||||
env1 = LocalEnvironment(persistent=True)
|
||||
env2 = LocalEnvironment(persistent=True)
|
||||
try:
|
||||
env1.execute("export LEAK_TEST=from_env1")
|
||||
r = env2.execute("echo $LEAK_TEST")
|
||||
assert r["output"].strip() == ""
|
||||
finally:
|
||||
env1.cleanup()
|
||||
env2.cleanup()
|
||||
|
||||
def test_special_characters_in_command(self, env):
|
||||
r = env.execute("echo 'hello world'")
|
||||
assert r["output"].strip() == "hello world"
|
||||
|
||||
def test_pipe_command(self, env):
|
||||
r = env.execute("echo hello | tr 'h' 'H'")
|
||||
assert r["output"].strip() == "Hello"
|
||||
|
||||
def test_multiple_commands_semicolon(self, env):
|
||||
r = env.execute("X=42; echo $X")
|
||||
assert r["output"].strip() == "42"
|
||||
@@ -110,7 +110,7 @@ class _FakeResponse:
|
||||
def test_managed_modal_execute_polls_until_completed(monkeypatch):
|
||||
_install_fake_tools_package()
|
||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||
modal_common = sys.modules["tools.environments.modal_common"]
|
||||
modal_common = sys.modules["tools.environments.modal_utils"]
|
||||
|
||||
calls = []
|
||||
poll_count = {"value": 0}
|
||||
@@ -173,7 +173,7 @@ def test_managed_modal_create_sends_a_stable_idempotency_key(monkeypatch):
|
||||
def test_managed_modal_execute_cancels_on_interrupt(monkeypatch):
|
||||
interrupt_event = _install_fake_tools_package()
|
||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||
modal_common = sys.modules["tools.environments.modal_common"]
|
||||
modal_common = sys.modules["tools.environments.modal_utils"]
|
||||
|
||||
calls = []
|
||||
|
||||
@@ -215,7 +215,7 @@ def test_managed_modal_execute_cancels_on_interrupt(monkeypatch):
|
||||
def test_managed_modal_execute_returns_descriptive_error_on_missing_exec(monkeypatch):
|
||||
_install_fake_tools_package()
|
||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||
modal_common = sys.modules["tools.environments.modal_common"]
|
||||
modal_common = sys.modules["tools.environments.modal_utils"]
|
||||
|
||||
def fake_request(method, url, headers=None, json=None, timeout=None):
|
||||
if method == "POST" and url.endswith("/v1/sandboxes"):
|
||||
@@ -293,7 +293,7 @@ def test_managed_modal_rejects_host_credential_passthrough():
|
||||
def test_managed_modal_execute_times_out_and_cancels(monkeypatch):
|
||||
_install_fake_tools_package()
|
||||
managed_modal = _load_tool_module("tools.environments.managed_modal", "environments/managed_modal.py")
|
||||
modal_common = sys.modules["tools.environments.modal_common"]
|
||||
modal_common = sys.modules["tools.environments.modal_utils"]
|
||||
|
||||
calls = []
|
||||
monotonic_values = iter([0.0, 12.5])
|
||||
|
||||
@@ -231,20 +231,20 @@ class TestEnsurepipFix:
|
||||
"""Verify the pip fix is applied in the ModalEnvironment init."""
|
||||
|
||||
def test_modal_environment_creates_image_with_setup_commands(self):
|
||||
"""ModalEnvironment.__init__ should create a modal.Image with pip fix."""
|
||||
"""_resolve_modal_image should create a modal.Image with pip fix."""
|
||||
try:
|
||||
from tools.environments.modal import ModalEnvironment
|
||||
from tools.environments.modal import _resolve_modal_image
|
||||
except ImportError:
|
||||
pytest.skip("tools.environments.modal not importable")
|
||||
|
||||
import inspect
|
||||
source = inspect.getsource(ModalEnvironment.__init__)
|
||||
source = inspect.getsource(_resolve_modal_image)
|
||||
assert "ensurepip" in source, (
|
||||
"ModalEnvironment should include ensurepip fix "
|
||||
"_resolve_modal_image should include ensurepip fix "
|
||||
"for Modal's legacy image builder"
|
||||
)
|
||||
assert "setup_dockerfile_commands" in source, (
|
||||
"ModalEnvironment should use setup_dockerfile_commands "
|
||||
"_resolve_modal_image should use setup_dockerfile_commands "
|
||||
"to fix pip before Modal's bootstrap"
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,401 @@
|
||||
"""Tests verifying the Modal sandbox timeout bug fix.
|
||||
|
||||
Bug: `lifetime_seconds` from container_config was never passed through to
|
||||
`sandbox_kwargs["timeout"]`, so Modal always used its default of 3600s.
|
||||
|
||||
Fix applied to:
|
||||
- tools/terminal_tool.py: `_create_environment()` now sets
|
||||
`sandbox_kwargs["timeout"]` from `cc.get("lifetime_seconds", 3600)`
|
||||
- tools/terminal_tool.py: `container_config` dict now includes
|
||||
`"lifetime_seconds"` from config
|
||||
- tools/environments/managed_modal.py: `_create_sandbox()` reads timeout
|
||||
from `self._sandbox_kwargs` instead of hardcoding 3_600_000
|
||||
"""
|
||||
|
||||
import sys
|
||||
import types
|
||||
import tempfile
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Repo root on sys.path
|
||||
# ---------------------------------------------------------------------------
|
||||
_repo_root = Path(__file__).resolve().parents[2]
|
||||
if str(_repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(_repo_root))
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Load terminal_tool (may be skipped if deps are missing)
|
||||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
import tools.terminal_tool as _tt_mod
|
||||
except ImportError:
|
||||
pytest.skip("tools.terminal_tool not importable (missing deps)", allow_module_level=True)
|
||||
|
||||
TOOLS_DIR = _repo_root / "tools"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers shared with test_managed_modal_environment.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _reset_modules(prefixes: tuple):
|
||||
for name in list(sys.modules):
|
||||
if name.startswith(prefixes):
|
||||
sys.modules.pop(name, None)
|
||||
|
||||
|
||||
def _install_fake_tools_package(*, credential_mounts=None):
|
||||
"""Install a minimal fake tools package so managed_modal.py can be loaded
|
||||
without network access or real Modal credentials."""
|
||||
_reset_modules(("tools", "agent", "hermes_cli"))
|
||||
|
||||
hermes_cli = types.ModuleType("hermes_cli")
|
||||
hermes_cli.__path__ = [] # type: ignore[attr-defined]
|
||||
sys.modules["hermes_cli"] = hermes_cli
|
||||
sys.modules["hermes_cli.config"] = types.SimpleNamespace(
|
||||
get_hermes_home=lambda: Path(tempfile.gettempdir()) / "hermes-home",
|
||||
)
|
||||
|
||||
tools_package = types.ModuleType("tools")
|
||||
tools_package.__path__ = [str(TOOLS_DIR)] # type: ignore[attr-defined]
|
||||
sys.modules["tools"] = tools_package
|
||||
|
||||
env_package = types.ModuleType("tools.environments")
|
||||
env_package.__path__ = [str(TOOLS_DIR / "environments")] # type: ignore[attr-defined]
|
||||
sys.modules["tools.environments"] = env_package
|
||||
|
||||
interrupt_event = threading.Event()
|
||||
sys.modules["tools.interrupt"] = types.SimpleNamespace(
|
||||
set_interrupt=lambda value=True: interrupt_event.set() if value else interrupt_event.clear(),
|
||||
is_interrupted=lambda: interrupt_event.is_set(),
|
||||
_interrupt_event=interrupt_event,
|
||||
)
|
||||
|
||||
class _DummyBaseEnvironment:
|
||||
def __init__(self, cwd: str = "/root", timeout: int = 60, env=None):
|
||||
self.cwd = cwd
|
||||
self.timeout = timeout
|
||||
self.env = env or {}
|
||||
|
||||
sys.modules["tools.environments.base"] = types.SimpleNamespace(
|
||||
BaseEnvironment=_DummyBaseEnvironment,
|
||||
)
|
||||
sys.modules["tools.managed_tool_gateway"] = types.SimpleNamespace(
|
||||
resolve_managed_tool_gateway=lambda vendor: types.SimpleNamespace(
|
||||
vendor=vendor,
|
||||
gateway_origin="https://modal-gateway.example.com",
|
||||
nous_user_token="user-token",
|
||||
managed_mode=True,
|
||||
)
|
||||
)
|
||||
sys.modules["tools.credential_files"] = types.SimpleNamespace(
|
||||
get_credential_file_mounts=lambda: list(credential_mounts or []),
|
||||
)
|
||||
|
||||
return interrupt_event
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
"""Minimal requests.Response substitute."""
|
||||
|
||||
def __init__(self, status_code: int, payload=None):
|
||||
self.status_code = status_code
|
||||
self._payload = payload
|
||||
self.text = ""
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests: _create_environment (direct modal path)
|
||||
# ===========================================================================
|
||||
|
||||
class TestCreateEnvironmentTimeoutPassthrough:
|
||||
"""_create_environment() must set sandbox_kwargs['timeout'] from lifetime_seconds."""
|
||||
|
||||
def test_lifetime_seconds_7200_reaches_modal_environment(self, monkeypatch):
|
||||
"""When container_config has lifetime_seconds=7200, ModalEnvironment gets timeout=7200."""
|
||||
captured_kwargs = {}
|
||||
sentinel = object()
|
||||
|
||||
def _fake_modal_env(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return sentinel
|
||||
|
||||
# Force the direct backend so we hit ModalEnvironment, not ManagedModalEnvironment
|
||||
monkeypatch.setattr(
|
||||
_tt_mod,
|
||||
"_get_modal_backend_state",
|
||||
lambda _: {"selected_backend": "direct"},
|
||||
)
|
||||
monkeypatch.setattr(_tt_mod, "_ModalEnvironment", _fake_modal_env)
|
||||
|
||||
result = _tt_mod._create_environment(
|
||||
env_type="modal",
|
||||
image="python:3.11",
|
||||
cwd="/root",
|
||||
timeout=60,
|
||||
container_config={"lifetime_seconds": 7200},
|
||||
)
|
||||
|
||||
assert result is sentinel, "Should have used our fake ModalEnvironment"
|
||||
modal_sandbox_kwargs = captured_kwargs.get("modal_sandbox_kwargs", {})
|
||||
assert modal_sandbox_kwargs.get("timeout") == 7200, (
|
||||
f"Expected timeout=7200 in modal_sandbox_kwargs, got: {modal_sandbox_kwargs}"
|
||||
)
|
||||
|
||||
def test_lifetime_seconds_defaults_to_3600_when_absent(self, monkeypatch):
|
||||
"""When lifetime_seconds is not in container_config, timeout defaults to 3600."""
|
||||
captured_kwargs = {}
|
||||
sentinel = object()
|
||||
|
||||
def _fake_modal_env(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return sentinel
|
||||
|
||||
monkeypatch.setattr(
|
||||
_tt_mod,
|
||||
"_get_modal_backend_state",
|
||||
lambda _: {"selected_backend": "direct"},
|
||||
)
|
||||
monkeypatch.setattr(_tt_mod, "_ModalEnvironment", _fake_modal_env)
|
||||
|
||||
result = _tt_mod._create_environment(
|
||||
env_type="modal",
|
||||
image="python:3.11",
|
||||
cwd="/root",
|
||||
timeout=60,
|
||||
container_config={}, # no lifetime_seconds
|
||||
)
|
||||
|
||||
assert result is sentinel
|
||||
modal_sandbox_kwargs = captured_kwargs.get("modal_sandbox_kwargs", {})
|
||||
assert modal_sandbox_kwargs.get("timeout") == 3600, (
|
||||
f"Expected default timeout=3600, got: {modal_sandbox_kwargs}"
|
||||
)
|
||||
|
||||
def test_lifetime_seconds_none_container_config_defaults_to_3600(self, monkeypatch):
|
||||
"""When container_config is None, timeout defaults to 3600."""
|
||||
captured_kwargs = {}
|
||||
sentinel = object()
|
||||
|
||||
def _fake_modal_env(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return sentinel
|
||||
|
||||
monkeypatch.setattr(
|
||||
_tt_mod,
|
||||
"_get_modal_backend_state",
|
||||
lambda _: {"selected_backend": "direct"},
|
||||
)
|
||||
monkeypatch.setattr(_tt_mod, "_ModalEnvironment", _fake_modal_env)
|
||||
|
||||
result = _tt_mod._create_environment(
|
||||
env_type="modal",
|
||||
image="python:3.11",
|
||||
cwd="/root",
|
||||
timeout=60,
|
||||
container_config=None, # None container_config
|
||||
)
|
||||
|
||||
assert result is sentinel
|
||||
modal_sandbox_kwargs = captured_kwargs.get("modal_sandbox_kwargs", {})
|
||||
assert modal_sandbox_kwargs.get("timeout") == 3600, (
|
||||
f"Expected default timeout=3600, got: {modal_sandbox_kwargs}"
|
||||
)
|
||||
|
||||
def test_lifetime_seconds_7200_reaches_managed_modal_environment(self, monkeypatch):
|
||||
"""When managed backend is selected, ManagedModalEnvironment also gets timeout=7200."""
|
||||
captured_kwargs = {}
|
||||
sentinel = object()
|
||||
|
||||
def _fake_managed_env(**kwargs):
|
||||
captured_kwargs.update(kwargs)
|
||||
return sentinel
|
||||
|
||||
monkeypatch.setattr(
|
||||
_tt_mod,
|
||||
"_get_modal_backend_state",
|
||||
lambda _: {"selected_backend": "managed"},
|
||||
)
|
||||
monkeypatch.setattr(_tt_mod, "_ManagedModalEnvironment", _fake_managed_env)
|
||||
|
||||
result = _tt_mod._create_environment(
|
||||
env_type="modal",
|
||||
image="python:3.11",
|
||||
cwd="/root",
|
||||
timeout=60,
|
||||
container_config={"lifetime_seconds": 7200},
|
||||
)
|
||||
|
||||
assert result is sentinel
|
||||
modal_sandbox_kwargs = captured_kwargs.get("modal_sandbox_kwargs", {})
|
||||
assert modal_sandbox_kwargs.get("timeout") == 7200, (
|
||||
f"Expected timeout=7200 in modal_sandbox_kwargs for managed env, got: {modal_sandbox_kwargs}"
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests: container_config includes lifetime_seconds from _get_env_config
|
||||
# ===========================================================================
|
||||
|
||||
class TestContainerConfigLifetimeSeconds:
|
||||
"""container_config dict built in terminal_tool must include lifetime_seconds."""
|
||||
|
||||
def test_container_config_includes_lifetime_seconds_from_env(self, monkeypatch):
|
||||
"""TERMINAL_LIFETIME_SECONDS env var flows into container_config."""
|
||||
monkeypatch.setenv("TERMINAL_ENV", "modal")
|
||||
monkeypatch.setenv("TERMINAL_LIFETIME_SECONDS", "7200")
|
||||
config = _tt_mod._get_env_config()
|
||||
assert config.get("lifetime_seconds") == 7200, (
|
||||
f"Expected lifetime_seconds=7200 in config, got: {config.get('lifetime_seconds')}"
|
||||
)
|
||||
|
||||
def test_container_config_lifetime_seconds_default_is_300(self, monkeypatch):
|
||||
"""Without TERMINAL_LIFETIME_SECONDS, the default should be 300 (cleanup thread default)."""
|
||||
monkeypatch.setenv("TERMINAL_ENV", "modal")
|
||||
monkeypatch.delenv("TERMINAL_LIFETIME_SECONDS", raising=False)
|
||||
config = _tt_mod._get_env_config()
|
||||
assert "lifetime_seconds" in config, "lifetime_seconds must be present in config"
|
||||
# Default from code is 300
|
||||
assert config["lifetime_seconds"] == 300, (
|
||||
f"Expected default lifetime_seconds=300, got: {config['lifetime_seconds']}"
|
||||
)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Tests: ManagedModalEnvironment._create_sandbox uses sandbox_kwargs timeout
|
||||
# ===========================================================================
|
||||
|
||||
class TestManagedModalTimeoutPassthrough:
|
||||
"""ManagedModalEnvironment must read timeout from sandbox_kwargs, not hardcode 3_600_000."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _restore_modules(self):
|
||||
"""Save and restore sys.modules so fake package doesn't leak."""
|
||||
saved = {
|
||||
name: mod for name, mod in sys.modules.items()
|
||||
if name.startswith(("tools", "hermes_cli"))
|
||||
}
|
||||
yield
|
||||
_reset_modules(("tools", "hermes_cli"))
|
||||
sys.modules.update(saved)
|
||||
|
||||
def test_sandbox_created_with_7200_timeout(self, monkeypatch):
|
||||
"""ManagedModalEnvironment with lifetime_seconds=7200 sends timeoutMs=7_200_000."""
|
||||
_install_fake_tools_package()
|
||||
|
||||
# Load managed_modal fresh after installing fake package
|
||||
from importlib.util import spec_from_file_location, module_from_spec
|
||||
spec = spec_from_file_location(
|
||||
"tools.environments.managed_modal",
|
||||
TOOLS_DIR / "environments" / "managed_modal.py",
|
||||
)
|
||||
managed_modal = module_from_spec(spec)
|
||||
sys.modules["tools.environments.managed_modal"] = managed_modal
|
||||
spec.loader.exec_module(managed_modal)
|
||||
|
||||
create_payloads = []
|
||||
|
||||
def fake_request(method, url, headers=None, json=None, timeout=None):
|
||||
if method == "POST" and url.endswith("/v1/sandboxes"):
|
||||
create_payloads.append(json)
|
||||
return _FakeResponse(200, {"id": "sandbox-1"})
|
||||
if method == "POST" and url.endswith("/terminate"):
|
||||
return _FakeResponse(200, {"status": "terminated"})
|
||||
raise AssertionError(f"Unexpected request: {method} {url}")
|
||||
|
||||
monkeypatch.setattr(managed_modal.requests, "request", fake_request)
|
||||
|
||||
env = managed_modal.ManagedModalEnvironment(
|
||||
image="python:3.11",
|
||||
modal_sandbox_kwargs={"timeout": 7200},
|
||||
)
|
||||
env.cleanup()
|
||||
|
||||
assert len(create_payloads) == 1
|
||||
payload = create_payloads[0]
|
||||
assert payload["timeoutMs"] == 7_200_000, (
|
||||
f"Expected timeoutMs=7_200_000 (7200s * 1000), got: {payload['timeoutMs']}. "
|
||||
"ManagedModalEnvironment must read timeout from sandbox_kwargs, not hardcode 3600."
|
||||
)
|
||||
|
||||
def test_sandbox_created_with_default_3600_timeout(self, monkeypatch):
|
||||
"""ManagedModalEnvironment with no explicit timeout sends timeoutMs=3_600_000."""
|
||||
_install_fake_tools_package()
|
||||
|
||||
from importlib.util import spec_from_file_location, module_from_spec
|
||||
spec = spec_from_file_location(
|
||||
"tools.environments.managed_modal",
|
||||
TOOLS_DIR / "environments" / "managed_modal.py",
|
||||
)
|
||||
managed_modal = module_from_spec(spec)
|
||||
sys.modules["tools.environments.managed_modal"] = managed_modal
|
||||
spec.loader.exec_module(managed_modal)
|
||||
|
||||
create_payloads = []
|
||||
|
||||
def fake_request(method, url, headers=None, json=None, timeout=None):
|
||||
if method == "POST" and url.endswith("/v1/sandboxes"):
|
||||
create_payloads.append(json)
|
||||
return _FakeResponse(200, {"id": "sandbox-1"})
|
||||
if method == "POST" and url.endswith("/terminate"):
|
||||
return _FakeResponse(200, {"status": "terminated"})
|
||||
raise AssertionError(f"Unexpected request: {method} {url}")
|
||||
|
||||
monkeypatch.setattr(managed_modal.requests, "request", fake_request)
|
||||
|
||||
env = managed_modal.ManagedModalEnvironment(
|
||||
image="python:3.11",
|
||||
modal_sandbox_kwargs={}, # no timeout key — should default to 3600
|
||||
)
|
||||
env.cleanup()
|
||||
|
||||
assert len(create_payloads) == 1
|
||||
payload = create_payloads[0]
|
||||
assert payload["timeoutMs"] == 3_600_000, (
|
||||
f"Expected default timeoutMs=3_600_000, got: {payload['timeoutMs']}"
|
||||
)
|
||||
|
||||
def test_sandbox_created_with_none_kwargs_defaults_to_3600(self, monkeypatch):
|
||||
"""ManagedModalEnvironment with modal_sandbox_kwargs=None defaults to 3600."""
|
||||
_install_fake_tools_package()
|
||||
|
||||
from importlib.util import spec_from_file_location, module_from_spec
|
||||
spec = spec_from_file_location(
|
||||
"tools.environments.managed_modal",
|
||||
TOOLS_DIR / "environments" / "managed_modal.py",
|
||||
)
|
||||
managed_modal = module_from_spec(spec)
|
||||
sys.modules["tools.environments.managed_modal"] = managed_modal
|
||||
spec.loader.exec_module(managed_modal)
|
||||
|
||||
create_payloads = []
|
||||
|
||||
def fake_request(method, url, headers=None, json=None, timeout=None):
|
||||
if method == "POST" and url.endswith("/v1/sandboxes"):
|
||||
create_payloads.append(json)
|
||||
return _FakeResponse(200, {"id": "sandbox-1"})
|
||||
if method == "POST" and url.endswith("/terminate"):
|
||||
return _FakeResponse(200, {"status": "terminated"})
|
||||
raise AssertionError(f"Unexpected request: {method} {url}")
|
||||
|
||||
monkeypatch.setattr(managed_modal.requests, "request", fake_request)
|
||||
|
||||
env = managed_modal.ManagedModalEnvironment(
|
||||
image="python:3.11",
|
||||
modal_sandbox_kwargs=None,
|
||||
)
|
||||
env.cleanup()
|
||||
|
||||
assert len(create_payloads) == 1
|
||||
payload = create_payloads[0]
|
||||
assert payload["timeoutMs"] == 3_600_000, (
|
||||
f"Expected default timeoutMs=3_600_000, got: {payload['timeoutMs']}"
|
||||
)
|
||||
@@ -85,11 +85,47 @@ def _install_modal_test_modules(
|
||||
def _prepare_command(self, command: str):
|
||||
return command, None
|
||||
|
||||
sys.modules["tools.environments.base"] = types.SimpleNamespace(BaseEnvironment=_DummyBaseEnvironment)
|
||||
def init_session(self):
|
||||
pass
|
||||
|
||||
# Stub _ThreadedProcessHandle: modal.py imports it but only uses it at
|
||||
# runtime inside _run_bash; the snapshot-isolation tests never call _run_bash,
|
||||
# so a class placeholder is sufficient.
|
||||
class _DummyThreadedProcessHandle:
|
||||
def __init__(self, exec_fn, cancel_fn=None):
|
||||
pass
|
||||
|
||||
def _load_json_store(path):
|
||||
if path.exists():
|
||||
try:
|
||||
return json.loads(path.read_text())
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
def _save_json_store(path, data):
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(data, indent=2))
|
||||
|
||||
def _file_mtime_key(host_path):
|
||||
try:
|
||||
st = Path(host_path).stat()
|
||||
return (st.st_mtime, st.st_size)
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
sys.modules["tools.environments.base"] = types.SimpleNamespace(
|
||||
BaseEnvironment=_DummyBaseEnvironment,
|
||||
_ThreadedProcessHandle=_DummyThreadedProcessHandle,
|
||||
_load_json_store=_load_json_store,
|
||||
_save_json_store=_save_json_store,
|
||||
_file_mtime_key=_file_mtime_key,
|
||||
)
|
||||
sys.modules["tools.interrupt"] = types.SimpleNamespace(is_interrupted=lambda: False)
|
||||
sys.modules["tools.credential_files"] = types.SimpleNamespace(
|
||||
get_credential_file_mounts=lambda: [],
|
||||
iter_skills_files=lambda: [],
|
||||
iter_cache_files=lambda: [],
|
||||
)
|
||||
|
||||
from_id_calls: list[str] = []
|
||||
|
||||
@@ -197,6 +197,26 @@ class TestCheckpointNotify:
|
||||
s = registry.get("proc_live")
|
||||
assert s.notify_on_complete is True
|
||||
|
||||
def test_recover_requeues_notify_watchers(self, registry, tmp_path):
|
||||
checkpoint = tmp_path / "procs.json"
|
||||
checkpoint.write_text(json.dumps([{
|
||||
"session_id": "proc_live",
|
||||
"command": "sleep 999",
|
||||
"pid": os.getpid(),
|
||||
"task_id": "t1",
|
||||
"session_key": "sk1",
|
||||
"watcher_platform": "telegram",
|
||||
"watcher_chat_id": "123",
|
||||
"watcher_thread_id": "42",
|
||||
"watcher_interval": 5,
|
||||
"notify_on_complete": True,
|
||||
}]))
|
||||
with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint):
|
||||
recovered = registry.recover_from_checkpoint()
|
||||
assert recovered == 1
|
||||
assert len(registry.pending_watchers) == 1
|
||||
assert registry.pending_watchers[0]["notify_on_complete"] is True
|
||||
|
||||
def test_recover_defaults_false(self, registry, tmp_path):
|
||||
"""Old checkpoint entries without the field default to False."""
|
||||
checkpoint = tmp_path / "procs.json"
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
@@ -45,6 +48,23 @@ def _make_session(
|
||||
return s
|
||||
|
||||
|
||||
def _spawn_python_sleep(seconds: float) -> subprocess.Popen:
|
||||
"""Spawn a portable short-lived Python sleep process."""
|
||||
return subprocess.Popen(
|
||||
[sys.executable, "-c", f"import time; time.sleep({seconds})"],
|
||||
)
|
||||
|
||||
|
||||
def _wait_until(predicate, timeout: float = 5.0, interval: float = 0.05) -> bool:
|
||||
"""Poll a predicate until it returns truthy or the timeout elapses."""
|
||||
deadline = time.monotonic() + timeout
|
||||
while time.monotonic() < deadline:
|
||||
if predicate():
|
||||
return True
|
||||
time.sleep(interval)
|
||||
return False
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Get / Poll
|
||||
# =========================================================================
|
||||
@@ -349,6 +369,88 @@ class TestCheckpoint:
|
||||
assert recovered == 1
|
||||
assert len(registry.pending_watchers) == 0
|
||||
|
||||
def test_recovery_keeps_live_checkpoint_entries(self, registry, tmp_path):
|
||||
checkpoint = tmp_path / "procs.json"
|
||||
checkpoint.write_text(json.dumps([{
|
||||
"session_id": "proc_live",
|
||||
"command": "sleep 999",
|
||||
"pid": os.getpid(),
|
||||
"task_id": "t1",
|
||||
"session_key": "sk1",
|
||||
}]))
|
||||
|
||||
with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint):
|
||||
recovered = registry.recover_from_checkpoint()
|
||||
assert recovered == 1
|
||||
assert registry.get("proc_live") is not None
|
||||
|
||||
data = json.loads(checkpoint.read_text())
|
||||
assert len(data) == 1
|
||||
assert data[0]["session_id"] == "proc_live"
|
||||
assert data[0]["pid"] == os.getpid()
|
||||
assert data != []
|
||||
|
||||
def test_recovery_skips_explicit_sandbox_backed_entries(self, registry, tmp_path):
|
||||
checkpoint = tmp_path / "procs.json"
|
||||
original = [{
|
||||
"session_id": "proc_remote",
|
||||
"command": "sleep 999",
|
||||
"pid": os.getpid(),
|
||||
"task_id": "t1",
|
||||
"pid_scope": "sandbox",
|
||||
}]
|
||||
checkpoint.write_text(json.dumps(original))
|
||||
|
||||
with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint):
|
||||
recovered = registry.recover_from_checkpoint()
|
||||
assert recovered == 0
|
||||
assert registry.get("proc_remote") is None
|
||||
|
||||
data = json.loads(checkpoint.read_text())
|
||||
assert data == []
|
||||
|
||||
def test_detached_recovered_process_eventually_exits(self, registry, tmp_path):
|
||||
proc = _spawn_python_sleep(0.4)
|
||||
checkpoint = tmp_path / "procs.json"
|
||||
checkpoint.write_text(json.dumps([{
|
||||
"session_id": "proc_live",
|
||||
"command": "python -c 'import time; time.sleep(0.4)'",
|
||||
"pid": proc.pid,
|
||||
"task_id": "t1",
|
||||
"session_key": "sk1",
|
||||
}]))
|
||||
|
||||
try:
|
||||
with patch("tools.process_registry.CHECKPOINT_PATH", checkpoint):
|
||||
recovered = registry.recover_from_checkpoint()
|
||||
assert recovered == 1
|
||||
|
||||
session = registry.get("proc_live")
|
||||
assert session is not None
|
||||
assert session.detached is True
|
||||
|
||||
proc.wait(timeout=5)
|
||||
|
||||
assert _wait_until(
|
||||
lambda: registry.get("proc_live") is not None
|
||||
and registry.get("proc_live").exited,
|
||||
timeout=5,
|
||||
)
|
||||
|
||||
poll_result = registry.poll("proc_live")
|
||||
assert poll_result["status"] == "exited"
|
||||
|
||||
wait_result = registry.wait("proc_live", timeout=1)
|
||||
assert wait_result["status"] == "exited"
|
||||
finally:
|
||||
if proc.poll() is None:
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=5)
|
||||
except Exception:
|
||||
proc.kill()
|
||||
proc.wait(timeout=5)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Kill process
|
||||
@@ -365,6 +467,27 @@ class TestKillProcess:
|
||||
result = registry.kill_process(s.id)
|
||||
assert result["status"] == "already_exited"
|
||||
|
||||
def test_kill_detached_session_uses_host_pid(self, registry):
|
||||
s = _make_session(sid="proc_detached", command="sleep 999")
|
||||
s.pid = 424242
|
||||
s.detached = True
|
||||
registry._running[s.id] = s
|
||||
|
||||
calls = []
|
||||
|
||||
def fake_kill(pid, sig):
|
||||
calls.append((pid, sig))
|
||||
|
||||
try:
|
||||
with patch("tools.process_registry.os.kill", side_effect=fake_kill):
|
||||
result = registry.kill_process(s.id)
|
||||
|
||||
assert result["status"] == "killed"
|
||||
assert (424242, 0) in calls
|
||||
assert (424242, signal.SIGTERM) in calls
|
||||
finally:
|
||||
registry._running.pop(s.id, None)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tool handler
|
||||
|
||||
@@ -43,7 +43,7 @@ class TestBuildSSHCommand:
|
||||
lambda *a, **k: MagicMock(stdout=iter([]),
|
||||
stderr=iter([]),
|
||||
stdin=MagicMock()))
|
||||
monkeypatch.setattr("tools.environments.ssh.time.sleep", lambda _: None)
|
||||
monkeypatch.setattr("tools.environments.base.time.sleep", lambda _: None)
|
||||
|
||||
def test_base_flags(self):
|
||||
env = SSHEnvironment(host="h", user="u")
|
||||
|
||||
@@ -0,0 +1,144 @@
|
||||
"""Tests for _ThreadedProcessHandle — the adapter for SDK backends."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
from tools.environments.base import _ThreadedProcessHandle
|
||||
|
||||
|
||||
class TestBasicExecution:
|
||||
def test_successful_execution(self):
|
||||
def exec_fn():
|
||||
return ("hello world", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn)
|
||||
handle.wait(timeout=5)
|
||||
|
||||
assert handle.returncode == 0
|
||||
output = handle.stdout.read()
|
||||
assert "hello world" in output
|
||||
|
||||
def test_nonzero_exit_code(self):
|
||||
def exec_fn():
|
||||
return ("error occurred", 42)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn)
|
||||
handle.wait(timeout=5)
|
||||
|
||||
assert handle.returncode == 42
|
||||
output = handle.stdout.read()
|
||||
assert "error occurred" in output
|
||||
|
||||
def test_exception_in_exec_fn(self):
|
||||
def exec_fn():
|
||||
raise RuntimeError("boom")
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn)
|
||||
handle.wait(timeout=5)
|
||||
|
||||
assert handle.returncode == 1
|
||||
|
||||
def test_empty_output(self):
|
||||
def exec_fn():
|
||||
return ("", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn)
|
||||
handle.wait(timeout=5)
|
||||
|
||||
assert handle.returncode == 0
|
||||
output = handle.stdout.read()
|
||||
assert output == ""
|
||||
|
||||
|
||||
class TestPolling:
|
||||
def test_poll_returns_none_while_running(self):
|
||||
event = threading.Event()
|
||||
|
||||
def exec_fn():
|
||||
event.wait(timeout=5)
|
||||
return ("done", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn)
|
||||
assert handle.poll() is None
|
||||
|
||||
event.set()
|
||||
handle.wait(timeout=5)
|
||||
assert handle.poll() == 0
|
||||
|
||||
def test_poll_returns_returncode_when_done(self):
|
||||
def exec_fn():
|
||||
return ("ok", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn)
|
||||
handle.wait(timeout=5)
|
||||
assert handle.poll() == 0
|
||||
|
||||
|
||||
class TestCancelFn:
|
||||
def test_cancel_fn_called_on_kill(self):
|
||||
called = threading.Event()
|
||||
|
||||
def cancel():
|
||||
called.set()
|
||||
|
||||
def exec_fn():
|
||||
time.sleep(10)
|
||||
return ("", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
|
||||
handle.kill()
|
||||
assert called.is_set()
|
||||
|
||||
def test_cancel_fn_none_is_safe(self):
|
||||
def exec_fn():
|
||||
return ("ok", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn, cancel_fn=None)
|
||||
handle.kill() # should not raise
|
||||
handle.wait(timeout=5)
|
||||
assert handle.returncode == 0
|
||||
|
||||
def test_cancel_fn_exception_swallowed(self):
|
||||
def cancel():
|
||||
raise RuntimeError("cancel failed")
|
||||
|
||||
def exec_fn():
|
||||
return ("ok", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
|
||||
handle.kill() # should not raise despite cancel raising
|
||||
handle.wait(timeout=5)
|
||||
|
||||
|
||||
class TestStdoutPipe:
|
||||
def test_stdout_is_readable(self):
|
||||
def exec_fn():
|
||||
return ("line1\nline2\nline3\n", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn)
|
||||
handle.wait(timeout=5)
|
||||
|
||||
lines = handle.stdout.readlines()
|
||||
assert len(lines) == 3
|
||||
assert lines[0] == "line1\n"
|
||||
|
||||
def test_stdout_iterable(self):
|
||||
def exec_fn():
|
||||
return ("a\nb\nc\n", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn)
|
||||
handle.wait(timeout=5)
|
||||
|
||||
collected = list(handle.stdout)
|
||||
assert len(collected) == 3
|
||||
|
||||
def test_unicode_output(self):
|
||||
def exec_fn():
|
||||
return ("hello 世界 🌍\n", 0)
|
||||
|
||||
handle = _ThreadedProcessHandle(exec_fn)
|
||||
handle.wait(timeout=5)
|
||||
|
||||
output = handle.stdout.read()
|
||||
assert "世界" in output
|
||||
assert "🌍" in output
|
||||
@@ -0,0 +1,472 @@
|
||||
"""Tests for tools/tool_result_storage.py -- 3-layer tool result persistence."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tools.budget_config import (
|
||||
DEFAULT_RESULT_SIZE_CHARS,
|
||||
DEFAULT_TURN_BUDGET_CHARS,
|
||||
DEFAULT_PREVIEW_SIZE_CHARS,
|
||||
BudgetConfig,
|
||||
)
|
||||
from tools.tool_result_storage import (
|
||||
HEREDOC_MARKER,
|
||||
PERSISTED_OUTPUT_TAG,
|
||||
PERSISTED_OUTPUT_CLOSING_TAG,
|
||||
STORAGE_DIR,
|
||||
_build_persisted_message,
|
||||
_heredoc_marker,
|
||||
_write_to_sandbox,
|
||||
enforce_turn_budget,
|
||||
generate_preview,
|
||||
maybe_persist_tool_result,
|
||||
)
|
||||
|
||||
|
||||
# ── generate_preview ──────────────────────────────────────────────────
|
||||
|
||||
class TestGeneratePreview:
|
||||
def test_short_content_unchanged(self):
|
||||
text = "short result"
|
||||
preview, has_more = generate_preview(text)
|
||||
assert preview == text
|
||||
assert has_more is False
|
||||
|
||||
def test_long_content_truncated(self):
|
||||
text = "x" * 5000
|
||||
preview, has_more = generate_preview(text, max_chars=2000)
|
||||
assert len(preview) <= 2000
|
||||
assert has_more is True
|
||||
|
||||
def test_truncates_at_newline_boundary(self):
|
||||
# 1500 chars + newline + 600 chars (past halfway)
|
||||
text = "a" * 1500 + "\n" + "b" * 600
|
||||
preview, has_more = generate_preview(text, max_chars=2000)
|
||||
assert preview == "a" * 1500 + "\n"
|
||||
assert has_more is True
|
||||
|
||||
def test_ignores_early_newline(self):
|
||||
# Newline at position 100, well before halfway of 2000
|
||||
text = "a" * 100 + "\n" + "b" * 3000
|
||||
preview, has_more = generate_preview(text, max_chars=2000)
|
||||
assert len(preview) == 2000
|
||||
assert has_more is True
|
||||
|
||||
def test_empty_content(self):
|
||||
preview, has_more = generate_preview("")
|
||||
assert preview == ""
|
||||
assert has_more is False
|
||||
|
||||
def test_exact_boundary(self):
|
||||
text = "x" * DEFAULT_PREVIEW_SIZE_CHARS
|
||||
preview, has_more = generate_preview(text)
|
||||
assert preview == text
|
||||
assert has_more is False
|
||||
|
||||
|
||||
# ── _heredoc_marker ───────────────────────────────────────────────────
|
||||
|
||||
class TestHeredocMarker:
|
||||
def test_default_marker_when_no_collision(self):
|
||||
assert _heredoc_marker("normal content") == HEREDOC_MARKER
|
||||
|
||||
def test_uuid_marker_on_collision(self):
|
||||
content = f"some text with {HEREDOC_MARKER} embedded"
|
||||
marker = _heredoc_marker(content)
|
||||
assert marker != HEREDOC_MARKER
|
||||
assert marker.startswith("HERMES_PERSIST_")
|
||||
assert marker not in content
|
||||
|
||||
|
||||
# ── _write_to_sandbox ─────────────────────────────────────────────────
|
||||
|
||||
class TestWriteToSandbox:
|
||||
def test_success(self):
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
result = _write_to_sandbox("hello world", "/tmp/hermes-results/abc.txt", env)
|
||||
assert result is True
|
||||
env.execute.assert_called_once()
|
||||
cmd = env.execute.call_args[0][0]
|
||||
assert "mkdir -p" in cmd
|
||||
assert "hello world" in cmd
|
||||
assert HEREDOC_MARKER in cmd
|
||||
|
||||
def test_failure_returns_false(self):
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "error", "returncode": 1}
|
||||
result = _write_to_sandbox("content", "/tmp/hermes-results/abc.txt", env)
|
||||
assert result is False
|
||||
|
||||
def test_heredoc_collision_uses_uuid_marker(self):
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
content = f"text with {HEREDOC_MARKER} inside"
|
||||
_write_to_sandbox(content, "/tmp/hermes-results/abc.txt", env)
|
||||
cmd = env.execute.call_args[0][0]
|
||||
# The default marker should NOT be used as the delimiter
|
||||
lines = cmd.split("\n")
|
||||
# The first and last lines contain the actual delimiter
|
||||
assert HEREDOC_MARKER not in lines[0].split("<<")[1]
|
||||
|
||||
def test_timeout_passed(self):
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
_write_to_sandbox("content", "/tmp/hermes-results/abc.txt", env)
|
||||
assert env.execute.call_args[1]["timeout"] == 30
|
||||
|
||||
|
||||
# ── _build_persisted_message ──────────────────────────────────────────
|
||||
|
||||
class TestBuildPersistedMessage:
|
||||
def test_structure(self):
|
||||
msg = _build_persisted_message(
|
||||
preview="first 100 chars...",
|
||||
has_more=True,
|
||||
original_size=50_000,
|
||||
file_path="/tmp/hermes-results/test123.txt",
|
||||
)
|
||||
assert msg.startswith(PERSISTED_OUTPUT_TAG)
|
||||
assert msg.endswith(PERSISTED_OUTPUT_CLOSING_TAG)
|
||||
assert "50,000 characters" in msg
|
||||
assert "/tmp/hermes-results/test123.txt" in msg
|
||||
assert "read_file" in msg
|
||||
assert "first 100 chars..." in msg
|
||||
assert "..." in msg # has_more indicator
|
||||
|
||||
def test_no_ellipsis_when_complete(self):
|
||||
msg = _build_persisted_message(
|
||||
preview="complete content",
|
||||
has_more=False,
|
||||
original_size=16,
|
||||
file_path="/tmp/hermes-results/x.txt",
|
||||
)
|
||||
# Should not have the trailing "..." indicator before closing tag
|
||||
lines = msg.strip().split("\n")
|
||||
assert lines[-2] != "..."
|
||||
|
||||
def test_large_size_shows_mb(self):
|
||||
msg = _build_persisted_message(
|
||||
preview="x",
|
||||
has_more=True,
|
||||
original_size=2_000_000,
|
||||
file_path="/tmp/hermes-results/big.txt",
|
||||
)
|
||||
assert "MB" in msg
|
||||
|
||||
|
||||
# ── maybe_persist_tool_result ─────────────────────────────────────────
|
||||
|
||||
class TestMaybePersistToolResult:
|
||||
def test_below_threshold_returns_unchanged(self):
|
||||
content = "small result"
|
||||
result = maybe_persist_tool_result(
|
||||
content=content,
|
||||
tool_name="terminal",
|
||||
tool_use_id="tc_123",
|
||||
env=None,
|
||||
threshold=50_000,
|
||||
)
|
||||
assert result == content
|
||||
|
||||
def test_above_threshold_with_env_persists(self):
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
content = "x" * 60_000
|
||||
result = maybe_persist_tool_result(
|
||||
content=content,
|
||||
tool_name="terminal",
|
||||
tool_use_id="tc_456",
|
||||
env=env,
|
||||
threshold=30_000,
|
||||
)
|
||||
assert PERSISTED_OUTPUT_TAG in result
|
||||
assert "tc_456.txt" in result
|
||||
assert len(result) < len(content)
|
||||
env.execute.assert_called_once()
|
||||
|
||||
def test_persists_full_content_as_is(self):
|
||||
"""Content is persisted verbatim — no JSON extraction."""
|
||||
import json
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
raw = "line1\nline2\n" * 5_000
|
||||
content = json.dumps({"output": raw, "exit_code": 0, "error": None})
|
||||
result = maybe_persist_tool_result(
|
||||
content=content,
|
||||
tool_name="terminal",
|
||||
tool_use_id="tc_json",
|
||||
env=env,
|
||||
threshold=30_000,
|
||||
)
|
||||
assert PERSISTED_OUTPUT_TAG in result
|
||||
# The heredoc written to sandbox should contain the full JSON blob
|
||||
cmd = env.execute.call_args[0][0]
|
||||
assert '"exit_code"' in cmd
|
||||
|
||||
def test_above_threshold_no_env_truncates_inline(self):
|
||||
content = "x" * 60_000
|
||||
result = maybe_persist_tool_result(
|
||||
content=content,
|
||||
tool_name="terminal",
|
||||
tool_use_id="tc_789",
|
||||
env=None,
|
||||
threshold=30_000,
|
||||
)
|
||||
assert PERSISTED_OUTPUT_TAG not in result
|
||||
assert "Truncated" in result
|
||||
assert len(result) < len(content)
|
||||
|
||||
def test_env_write_failure_falls_back_to_truncation(self):
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "disk full", "returncode": 1}
|
||||
content = "x" * 60_000
|
||||
result = maybe_persist_tool_result(
|
||||
content=content,
|
||||
tool_name="terminal",
|
||||
tool_use_id="tc_fail",
|
||||
env=env,
|
||||
threshold=30_000,
|
||||
)
|
||||
assert PERSISTED_OUTPUT_TAG not in result
|
||||
assert "Truncated" in result
|
||||
|
||||
def test_env_execute_exception_falls_back(self):
|
||||
env = MagicMock()
|
||||
env.execute.side_effect = RuntimeError("connection lost")
|
||||
content = "x" * 60_000
|
||||
result = maybe_persist_tool_result(
|
||||
content=content,
|
||||
tool_name="terminal",
|
||||
tool_use_id="tc_exc",
|
||||
env=env,
|
||||
threshold=30_000,
|
||||
)
|
||||
assert "Truncated" in result
|
||||
|
||||
def test_read_file_never_persisted(self):
|
||||
"""read_file has threshold=inf, should never be persisted."""
|
||||
env = MagicMock()
|
||||
content = "x" * 200_000
|
||||
result = maybe_persist_tool_result(
|
||||
content=content,
|
||||
tool_name="read_file",
|
||||
tool_use_id="tc_rf",
|
||||
env=env,
|
||||
threshold=float("inf"),
|
||||
)
|
||||
assert result == content
|
||||
env.execute.assert_not_called()
|
||||
|
||||
def test_uses_registry_threshold_when_not_provided(self):
|
||||
"""When threshold=None, looks up from registry."""
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
content = "x" * 60_000
|
||||
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_max_result_size.return_value = 30_000
|
||||
|
||||
with patch("tools.registry.registry", mock_registry):
|
||||
result = maybe_persist_tool_result(
|
||||
content=content,
|
||||
tool_name="terminal",
|
||||
tool_use_id="tc_reg",
|
||||
env=env,
|
||||
threshold=None,
|
||||
)
|
||||
# Should have persisted since 60K > 30K
|
||||
assert PERSISTED_OUTPUT_TAG in result or "Truncated" in result
|
||||
|
||||
def test_unicode_content_survives(self):
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
content = "日本語テスト " * 10_000 # ~60K chars of unicode
|
||||
result = maybe_persist_tool_result(
|
||||
content=content,
|
||||
tool_name="terminal",
|
||||
tool_use_id="tc_uni",
|
||||
env=env,
|
||||
threshold=30_000,
|
||||
)
|
||||
assert PERSISTED_OUTPUT_TAG in result
|
||||
# Preview should contain unicode
|
||||
assert "日本語テスト" in result
|
||||
|
||||
def test_empty_content_returns_unchanged(self):
|
||||
result = maybe_persist_tool_result(
|
||||
content="",
|
||||
tool_name="terminal",
|
||||
tool_use_id="tc_empty",
|
||||
env=None,
|
||||
threshold=30_000,
|
||||
)
|
||||
assert result == ""
|
||||
|
||||
def test_whitespace_only_below_threshold(self):
|
||||
content = " " * 100
|
||||
result = maybe_persist_tool_result(
|
||||
content=content,
|
||||
tool_name="terminal",
|
||||
tool_use_id="tc_ws",
|
||||
env=None,
|
||||
threshold=30_000,
|
||||
)
|
||||
assert result == content
|
||||
|
||||
def test_file_path_uses_tool_use_id(self):
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
content = "x" * 60_000
|
||||
result = maybe_persist_tool_result(
|
||||
content=content,
|
||||
tool_name="terminal",
|
||||
tool_use_id="unique_id_abc",
|
||||
env=env,
|
||||
threshold=30_000,
|
||||
)
|
||||
assert "unique_id_abc.txt" in result
|
||||
|
||||
def test_preview_included_in_persisted_output(self):
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
# Create content with a distinctive start
|
||||
content = "DISTINCTIVE_START_MARKER" + "x" * 60_000
|
||||
result = maybe_persist_tool_result(
|
||||
content=content,
|
||||
tool_name="terminal",
|
||||
tool_use_id="tc_prev",
|
||||
env=env,
|
||||
threshold=30_000,
|
||||
)
|
||||
assert "DISTINCTIVE_START_MARKER" in result
|
||||
|
||||
def test_threshold_zero_forces_persist(self):
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
content = "even short content"
|
||||
result = maybe_persist_tool_result(
|
||||
content=content,
|
||||
tool_name="terminal",
|
||||
tool_use_id="tc_zero",
|
||||
env=env,
|
||||
threshold=0,
|
||||
)
|
||||
# Any non-empty content with threshold=0 should be persisted
|
||||
assert PERSISTED_OUTPUT_TAG in result
|
||||
|
||||
|
||||
# ── enforce_turn_budget ───────────────────────────────────────────────
|
||||
|
||||
class TestEnforceTurnBudget:
|
||||
def test_under_budget_no_changes(self):
|
||||
msgs = [
|
||||
{"role": "tool", "tool_call_id": "t1", "content": "small"},
|
||||
{"role": "tool", "tool_call_id": "t2", "content": "also small"},
|
||||
]
|
||||
result = enforce_turn_budget(msgs, env=None, config=BudgetConfig(turn_budget=200_000))
|
||||
assert result[0]["content"] == "small"
|
||||
assert result[1]["content"] == "also small"
|
||||
|
||||
def test_over_budget_largest_persisted_first(self):
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
msgs = [
|
||||
{"role": "tool", "tool_call_id": "t1", "content": "a" * 80_000},
|
||||
{"role": "tool", "tool_call_id": "t2", "content": "b" * 130_000},
|
||||
]
|
||||
# Total 210K > 200K budget
|
||||
enforce_turn_budget(msgs, env=env, config=BudgetConfig(turn_budget=200_000))
|
||||
# The larger one (130K) should be persisted first
|
||||
assert PERSISTED_OUTPUT_TAG in msgs[1]["content"]
|
||||
|
||||
def test_already_persisted_results_skipped(self):
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
msgs = [
|
||||
{"role": "tool", "tool_call_id": "t1",
|
||||
"content": f"{PERSISTED_OUTPUT_TAG}\nalready persisted\n{PERSISTED_OUTPUT_CLOSING_TAG}"},
|
||||
{"role": "tool", "tool_call_id": "t2", "content": "x" * 250_000},
|
||||
]
|
||||
enforce_turn_budget(msgs, env=env, config=BudgetConfig(turn_budget=200_000))
|
||||
# t1 should be untouched (already persisted)
|
||||
assert msgs[0]["content"].startswith(PERSISTED_OUTPUT_TAG)
|
||||
# t2 should be persisted
|
||||
assert PERSISTED_OUTPUT_TAG in msgs[1]["content"]
|
||||
|
||||
def test_medium_result_regression(self):
|
||||
"""6 results of 42K chars each (252K total) — each under 100K default
|
||||
threshold but aggregate exceeds 200K budget. L3 should persist."""
|
||||
env = MagicMock()
|
||||
env.execute.return_value = {"output": "", "returncode": 0}
|
||||
msgs = [
|
||||
{"role": "tool", "tool_call_id": f"t{i}", "content": "x" * 42_000}
|
||||
for i in range(6)
|
||||
]
|
||||
enforce_turn_budget(msgs, env=env, config=BudgetConfig(turn_budget=200_000))
|
||||
# At least some results should be persisted to get under 200K
|
||||
persisted_count = sum(
|
||||
1 for m in msgs if PERSISTED_OUTPUT_TAG in m["content"]
|
||||
)
|
||||
assert persisted_count >= 2 # Need to shed at least ~52K
|
||||
|
||||
def test_no_env_falls_back_to_truncation(self):
|
||||
msgs = [
|
||||
{"role": "tool", "tool_call_id": "t1", "content": "x" * 250_000},
|
||||
]
|
||||
enforce_turn_budget(msgs, env=None, config=BudgetConfig(turn_budget=200_000))
|
||||
# Should be truncated (no sandbox available)
|
||||
assert "Truncated" in msgs[0]["content"] or PERSISTED_OUTPUT_TAG in msgs[0]["content"]
|
||||
|
||||
def test_returns_same_list(self):
|
||||
msgs = [{"role": "tool", "tool_call_id": "t1", "content": "ok"}]
|
||||
result = enforce_turn_budget(msgs, env=None, config=BudgetConfig(turn_budget=200_000))
|
||||
assert result is msgs
|
||||
|
||||
def test_empty_messages(self):
|
||||
result = enforce_turn_budget([], env=None, config=BudgetConfig(turn_budget=200_000))
|
||||
assert result == []
|
||||
|
||||
|
||||
# ── Per-tool threshold integration ────────────────────────────────────
|
||||
|
||||
class TestPerToolThresholds:
|
||||
"""Verify registry wiring for per-tool thresholds."""
|
||||
|
||||
def test_registry_has_get_max_result_size(self):
|
||||
from tools.registry import registry
|
||||
assert hasattr(registry, "get_max_result_size")
|
||||
|
||||
def test_default_threshold(self):
|
||||
from tools.registry import registry
|
||||
# Unknown tool should return the default
|
||||
val = registry.get_max_result_size("nonexistent_tool_xyz")
|
||||
assert val == DEFAULT_RESULT_SIZE_CHARS
|
||||
|
||||
def test_terminal_threshold(self):
|
||||
from tools.registry import registry
|
||||
# Trigger import of terminal_tool to register the tool
|
||||
try:
|
||||
import tools.terminal_tool # noqa: F401
|
||||
val = registry.get_max_result_size("terminal")
|
||||
assert val == 100_000
|
||||
except ImportError:
|
||||
pytest.skip("terminal_tool not importable in test env")
|
||||
|
||||
def test_read_file_never_persisted(self):
|
||||
from tools.registry import registry
|
||||
try:
|
||||
import tools.file_tools # noqa: F401
|
||||
val = registry.get_max_result_size("read_file")
|
||||
assert val == float("inf")
|
||||
except ImportError:
|
||||
pytest.skip("file_tools not importable in test env")
|
||||
|
||||
def test_search_files_threshold(self):
|
||||
from tools.registry import registry
|
||||
try:
|
||||
import tools.file_tools # noqa: F401
|
||||
val = registry.get_max_result_size("search_files")
|
||||
assert val == 100_000
|
||||
except ImportError:
|
||||
pytest.skip("file_tools not importable in test env")
|
||||
@@ -48,6 +48,7 @@ def clean_env(monkeypatch):
|
||||
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
|
||||
monkeypatch.delenv("HERMES_LOCAL_STT_COMMAND", raising=False)
|
||||
monkeypatch.delenv("HERMES_LOCAL_STT_LANGUAGE", raising=False)
|
||||
|
||||
@@ -858,3 +859,183 @@ class TestGetSttModelFromConfig:
|
||||
|
||||
from tools.transcription_tools import get_stt_model_from_config
|
||||
assert get_stt_model_from_config() is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _transcribe_mistral
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mistral_module():
|
||||
"""Inject a fake mistralai module into sys.modules for testing."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_client.__exit__ = MagicMock(return_value=False)
|
||||
mock_mistral_cls = MagicMock(return_value=mock_client)
|
||||
fake_module = MagicMock()
|
||||
fake_module.Mistral = mock_mistral_cls
|
||||
with patch.dict("sys.modules", {"mistralai": fake_module, "mistralai.client": fake_module}):
|
||||
yield mock_client
|
||||
|
||||
|
||||
class TestTranscribeMistral:
|
||||
def test_no_key(self, monkeypatch):
|
||||
monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
|
||||
from tools.transcription_tools import _transcribe_mistral
|
||||
result = _transcribe_mistral("/tmp/test.ogg", "voxtral-mini-latest")
|
||||
assert result["success"] is False
|
||||
assert "MISTRAL_API_KEY" in result["error"]
|
||||
|
||||
def test_successful_transcription(self, monkeypatch, sample_ogg, mock_mistral_module):
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.text = "hello from mistral"
|
||||
mock_mistral_module.audio.transcriptions.complete.return_value = mock_result
|
||||
|
||||
from tools.transcription_tools import _transcribe_mistral
|
||||
result = _transcribe_mistral(sample_ogg, "voxtral-mini-latest")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["transcript"] == "hello from mistral"
|
||||
assert result["provider"] == "mistral"
|
||||
mock_mistral_module.audio.transcriptions.complete.assert_called_once()
|
||||
mock_mistral_module.__exit__.assert_called_once()
|
||||
|
||||
def test_api_error_returns_failure(self, monkeypatch, sample_ogg, mock_mistral_module):
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
|
||||
mock_mistral_module.audio.transcriptions.complete.side_effect = RuntimeError("secret-key-leaked")
|
||||
|
||||
from tools.transcription_tools import _transcribe_mistral
|
||||
result = _transcribe_mistral(sample_ogg, "voxtral-mini-latest")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "RuntimeError" in result["error"]
|
||||
assert "secret-key-leaked" not in result["error"]
|
||||
|
||||
def test_permission_error(self, monkeypatch, sample_ogg, mock_mistral_module):
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
|
||||
mock_mistral_module.audio.transcriptions.complete.side_effect = PermissionError("denied")
|
||||
|
||||
from tools.transcription_tools import _transcribe_mistral
|
||||
result = _transcribe_mistral(sample_ogg, "voxtral-mini-latest")
|
||||
|
||||
assert result["success"] is False
|
||||
assert "Permission denied" in result["error"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _get_provider — Mistral
|
||||
# ============================================================================
|
||||
|
||||
class TestGetProviderMistral:
|
||||
"""Mistral-specific provider selection tests."""
|
||||
|
||||
def test_mistral_when_key_and_sdk_available(self, monkeypatch):
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
|
||||
with patch("tools.transcription_tools._HAS_MISTRAL", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "mistral"}) == "mistral"
|
||||
|
||||
def test_mistral_explicit_no_key_returns_none(self, monkeypatch):
|
||||
"""Explicit mistral with no key returns none — no cross-provider fallback."""
|
||||
monkeypatch.delenv("MISTRAL_API_KEY", raising=False)
|
||||
with patch("tools.transcription_tools._HAS_MISTRAL", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "mistral"}) == "none"
|
||||
|
||||
def test_mistral_explicit_no_sdk_returns_none(self, monkeypatch):
|
||||
"""Explicit mistral with key but no SDK returns none."""
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
|
||||
with patch("tools.transcription_tools._HAS_MISTRAL", False):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "mistral"}) == "none"
|
||||
|
||||
def test_auto_detect_mistral_after_openai(self, monkeypatch):
|
||||
"""Auto-detect: mistral is tried after openai when both are unavailable."""
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", False), \
|
||||
patch("tools.transcription_tools._HAS_MISTRAL", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({}) == "mistral"
|
||||
|
||||
def test_auto_detect_openai_preferred_over_mistral(self, monkeypatch):
|
||||
"""Auto-detect: openai is preferred over mistral (both paid, openai more common)."""
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools._HAS_MISTRAL", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({}) == "openai"
|
||||
|
||||
def test_auto_detect_groq_preferred_over_mistral(self, monkeypatch):
|
||||
"""Auto-detect: groq (free) is preferred over mistral (paid)."""
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools._HAS_MISTRAL", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({}) == "groq"
|
||||
|
||||
def test_auto_detect_skips_mistral_without_sdk(self, monkeypatch):
|
||||
"""Auto-detect: mistral skipped when key is set but SDK is not installed."""
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.setenv("MISTRAL_API_KEY", "test-key")
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", False), \
|
||||
patch("tools.transcription_tools._HAS_MISTRAL", False):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({}) == "none"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# transcribe_audio — Mistral dispatch
|
||||
# ============================================================================
|
||||
|
||||
class TestTranscribeAudioMistralDispatch:
|
||||
def test_dispatches_to_mistral(self, sample_ogg):
|
||||
with patch("tools.transcription_tools._load_stt_config", return_value={"provider": "mistral"}), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="mistral"), \
|
||||
patch("tools.transcription_tools._transcribe_mistral",
|
||||
return_value={"success": True, "transcript": "hi", "provider": "mistral"}) as mock_mistral:
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
result = transcribe_audio(sample_ogg)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["provider"] == "mistral"
|
||||
mock_mistral.assert_called_once()
|
||||
|
||||
def test_config_mistral_model_used(self, sample_ogg):
|
||||
config = {"provider": "mistral", "mistral": {"model": "voxtral-mini-2602"}}
|
||||
with patch("tools.transcription_tools._load_stt_config", return_value=config), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="mistral"), \
|
||||
patch("tools.transcription_tools._transcribe_mistral",
|
||||
return_value={"success": True, "transcript": "hi"}) as mock_mistral:
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
transcribe_audio(sample_ogg, model=None)
|
||||
|
||||
assert mock_mistral.call_args[0][1] == "voxtral-mini-2602"
|
||||
|
||||
def test_model_override_passed_to_mistral(self, sample_ogg):
|
||||
with patch("tools.transcription_tools._load_stt_config", return_value={}), \
|
||||
patch("tools.transcription_tools._get_provider", return_value="mistral"), \
|
||||
patch("tools.transcription_tools._transcribe_mistral",
|
||||
return_value={"success": True, "transcript": "hi"}) as mock_mistral:
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
transcribe_audio(sample_ogg, model="voxtral-mini-2602")
|
||||
|
||||
assert mock_mistral.call_args[0][1] == "voxtral-mini-2602"
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Binary file extensions to skip for text-based operations.
|
||||
|
||||
These files can't be meaningfully compared as text and are often large.
|
||||
Ported from free-code src/constants/files.ts.
|
||||
"""
|
||||
|
||||
BINARY_EXTENSIONS = frozenset({
|
||||
# Images
|
||||
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".ico", ".webp", ".tiff", ".tif",
|
||||
# Videos
|
||||
".mp4", ".mov", ".avi", ".mkv", ".webm", ".wmv", ".flv", ".m4v", ".mpeg", ".mpg",
|
||||
# Audio
|
||||
".mp3", ".wav", ".ogg", ".flac", ".aac", ".m4a", ".wma", ".aiff", ".opus",
|
||||
# Archives
|
||||
".zip", ".tar", ".gz", ".bz2", ".7z", ".rar", ".xz", ".z", ".tgz", ".iso",
|
||||
# Executables/binaries
|
||||
".exe", ".dll", ".so", ".dylib", ".bin", ".o", ".a", ".obj", ".lib",
|
||||
".app", ".msi", ".deb", ".rpm",
|
||||
# Documents (exclude .pdf — text-based, agents may want to inspect)
|
||||
".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx",
|
||||
".odt", ".ods", ".odp",
|
||||
# Fonts
|
||||
".ttf", ".otf", ".woff", ".woff2", ".eot",
|
||||
# Bytecode / VM artifacts
|
||||
".pyc", ".pyo", ".class", ".jar", ".war", ".ear", ".node", ".wasm", ".rlib",
|
||||
# Database files
|
||||
".sqlite", ".sqlite3", ".db", ".mdb", ".idx",
|
||||
# Design / 3D
|
||||
".psd", ".ai", ".eps", ".sketch", ".fig", ".xd", ".blend", ".3ds", ".max",
|
||||
# Flash
|
||||
".swf", ".fla",
|
||||
# Lock/profiling data
|
||||
".lockb", ".dat", ".data",
|
||||
})
|
||||
|
||||
|
||||
def has_binary_extension(path: str) -> bool:
|
||||
"""Check if a file path has a binary extension. Pure string check, no I/O."""
|
||||
dot = path.rfind(".")
|
||||
if dot == -1:
|
||||
return False
|
||||
return path[dot:].lower() in BINARY_EXTENSIONS
|
||||
@@ -101,7 +101,8 @@ def _managed_persistence_enabled() -> bool:
|
||||
"""
|
||||
try:
|
||||
camofox_cfg = load_config().get("browser", {}).get("camofox", {})
|
||||
except Exception:
|
||||
except Exception as exc:
|
||||
logger.warning("managed_persistence check failed, defaulting to disabled: %s", exc)
|
||||
return False
|
||||
return bool(camofox_cfg.get("managed_persistence"))
|
||||
|
||||
@@ -172,6 +173,22 @@ def _drop_session(task_id: Optional[str]) -> Optional[Dict[str, Any]]:
|
||||
return _sessions.pop(task_id, None)
|
||||
|
||||
|
||||
def camofox_soft_cleanup(task_id: Optional[str] = None) -> bool:
|
||||
"""Release the in-memory session without destroying the server-side context.
|
||||
|
||||
When managed persistence is enabled the browser profile (and its cookies)
|
||||
must survive across agent tasks. This helper drops only the local tracking
|
||||
entry and returns ``True``. When managed persistence is *not* enabled it
|
||||
does nothing and returns ``False`` so the caller can fall back to
|
||||
:func:`camofox_close`.
|
||||
"""
|
||||
if _managed_persistence_enabled():
|
||||
_drop_session(task_id)
|
||||
logger.debug("Camofox soft cleanup for task %s (managed persistence)", task_id)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
+12
-4
@@ -877,7 +877,11 @@ def _run_browser_command(
|
||||
# Local mode — launch a headless Chromium instance
|
||||
backend_args = ["--session", session_info["session_name"]]
|
||||
|
||||
cmd_parts = browser_cmd.split() + backend_args + [
|
||||
# Keep concrete executable paths intact, even when they contain spaces.
|
||||
# Only the synthetic npx fallback needs to expand into multiple argv items.
|
||||
cmd_prefix = ["npx", "agent-browser"] if browser_cmd == "npx agent-browser" else [browser_cmd]
|
||||
|
||||
cmd_parts = cmd_prefix + backend_args + [
|
||||
"--json",
|
||||
command
|
||||
] + args
|
||||
@@ -1931,11 +1935,15 @@ def cleanup_browser(task_id: Optional[str] = None) -> None:
|
||||
if task_id is None:
|
||||
task_id = "default"
|
||||
|
||||
# Also clean up Camofox session if running in Camofox mode
|
||||
# Also clean up Camofox session if running in Camofox mode.
|
||||
# Skip full close when managed persistence is enabled — the browser
|
||||
# profile (and its session cookies) must survive across agent tasks.
|
||||
# The inactivity reaper still frees idle resources.
|
||||
if _is_camofox_mode():
|
||||
try:
|
||||
from tools.browser_camofox import camofox_close
|
||||
camofox_close(task_id)
|
||||
from tools.browser_camofox import camofox_close, camofox_soft_cleanup
|
||||
if not camofox_soft_cleanup(task_id):
|
||||
camofox_close(task_id)
|
||||
except Exception as e:
|
||||
logger.debug("Camofox cleanup for task %s: %s", task_id, e)
|
||||
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
"""Configurable budget constants for tool result persistence.
|
||||
|
||||
Overridable at the RL environment level via HermesAgentEnvConfig fields.
|
||||
Per-tool resolution: pinned > config overrides > registry > default.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict
|
||||
|
||||
# Tools whose thresholds must never be overridden.
|
||||
# read_file=inf prevents infinite persist->read->persist loops.
|
||||
PINNED_THRESHOLDS: Dict[str, float] = {
|
||||
"read_file": float("inf"),
|
||||
}
|
||||
|
||||
# Defaults matching the current hardcoded values in tool_result_storage.py.
|
||||
# Kept here as the single source of truth; tool_result_storage.py imports these.
|
||||
DEFAULT_RESULT_SIZE_CHARS: int = 100_000
|
||||
DEFAULT_TURN_BUDGET_CHARS: int = 200_000
|
||||
DEFAULT_PREVIEW_SIZE_CHARS: int = 1_500
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BudgetConfig:
|
||||
"""Immutable budget constants for the 3-layer tool result persistence system.
|
||||
|
||||
Layer 2 (per-result): resolve_threshold(tool_name) -> threshold in chars.
|
||||
Layer 3 (per-turn): turn_budget -> aggregate char budget across all tool
|
||||
results in a single assistant turn.
|
||||
Preview: preview_size -> inline snippet size after persistence.
|
||||
"""
|
||||
|
||||
default_result_size: int = DEFAULT_RESULT_SIZE_CHARS
|
||||
turn_budget: int = DEFAULT_TURN_BUDGET_CHARS
|
||||
preview_size: int = DEFAULT_PREVIEW_SIZE_CHARS
|
||||
tool_overrides: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
def resolve_threshold(self, tool_name: str) -> int | float:
|
||||
"""Resolve the persistence threshold for a tool.
|
||||
|
||||
Priority: pinned -> tool_overrides -> registry per-tool -> default.
|
||||
"""
|
||||
if tool_name in PINNED_THRESHOLDS:
|
||||
return PINNED_THRESHOLDS[tool_name]
|
||||
if tool_name in self.tool_overrides:
|
||||
return self.tool_overrides[tool_name]
|
||||
from tools.registry import registry
|
||||
return registry.get_max_result_size(tool_name, default=self.default_result_size)
|
||||
|
||||
|
||||
# Default config -- matches current hardcoded behavior exactly.
|
||||
DEFAULT_BUDGET = BudgetConfig()
|
||||
@@ -18,7 +18,7 @@ Architecture (two transports):
|
||||
2. Parent ships both files to the remote environment
|
||||
3. Script runs inside the terminal backend (Docker/SSH/Modal/Daytona/etc.)
|
||||
4. Tool calls are written as request files; a polling thread on the parent
|
||||
reads them via execute_oneshot(), dispatches, and writes response files
|
||||
reads them via env.execute(), dispatches, and writes response files
|
||||
5. The script polls for response files and continues
|
||||
|
||||
In both cases, only the script's stdout is returned to the LLM; intermediate
|
||||
@@ -536,7 +536,7 @@ def _ship_file_to_remote(env, remote_path: str, content: str) -> None:
|
||||
quotes are fine.
|
||||
"""
|
||||
encoded = base64.b64encode(content.encode("utf-8")).decode("ascii")
|
||||
env.execute_oneshot(
|
||||
env.execute(
|
||||
f"echo '{encoded}' | base64 -d > {remote_path}",
|
||||
cwd="/",
|
||||
timeout=30,
|
||||
@@ -555,9 +555,9 @@ def _rpc_poll_loop(
|
||||
):
|
||||
"""Poll the remote filesystem for tool call requests and dispatch them.
|
||||
|
||||
Runs in a background thread. Uses ``env.execute_oneshot()`` so it can
|
||||
operate concurrently with the script-execution thread that holds
|
||||
``env.execute()`` (important for persistent-shell backends like SSH).
|
||||
Runs in a background thread. Each ``env.execute()`` spawns an
|
||||
independent process, so these calls run safely concurrent with the
|
||||
script-execution thread.
|
||||
"""
|
||||
from model_tools import handle_function_call
|
||||
|
||||
@@ -566,7 +566,7 @@ def _rpc_poll_loop(
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
# List pending request files (skip .tmp partials)
|
||||
ls_result = env.execute_oneshot(
|
||||
ls_result = env.execute(
|
||||
f"ls -1 {rpc_dir}/req_* 2>/dev/null || true",
|
||||
cwd="/",
|
||||
timeout=10,
|
||||
@@ -590,7 +590,7 @@ def _rpc_poll_loop(
|
||||
call_start = time.monotonic()
|
||||
|
||||
# Read request
|
||||
read_result = env.execute_oneshot(
|
||||
read_result = env.execute(
|
||||
f"cat {req_file}",
|
||||
cwd="/",
|
||||
timeout=10,
|
||||
@@ -600,7 +600,7 @@ def _rpc_poll_loop(
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
logger.debug("Malformed RPC request in %s", req_file)
|
||||
# Remove bad request to avoid infinite retry
|
||||
env.execute_oneshot(f"rm -f {req_file}", cwd="/", timeout=5)
|
||||
env.execute(f"rm -f {req_file}", cwd="/", timeout=5)
|
||||
continue
|
||||
|
||||
tool_name = request.get("tool", "")
|
||||
@@ -664,7 +664,7 @@ def _rpc_poll_loop(
|
||||
encoded_result = base64.b64encode(
|
||||
tool_result.encode("utf-8")
|
||||
).decode("ascii")
|
||||
env.execute_oneshot(
|
||||
env.execute(
|
||||
f"echo '{encoded_result}' | base64 -d > {res_file}.tmp"
|
||||
f" && mv {res_file}.tmp {res_file}",
|
||||
cwd="/",
|
||||
@@ -672,7 +672,7 @@ def _rpc_poll_loop(
|
||||
)
|
||||
|
||||
# Remove the request file
|
||||
env.execute_oneshot(f"rm -f {req_file}", cwd="/", timeout=5)
|
||||
env.execute(f"rm -f {req_file}", cwd="/", timeout=5)
|
||||
|
||||
except Exception as e:
|
||||
if not stop_event.is_set():
|
||||
@@ -717,7 +717,7 @@ def _execute_remote(
|
||||
|
||||
try:
|
||||
# Verify Python is available on the remote
|
||||
py_check = env.execute_oneshot(
|
||||
py_check = env.execute(
|
||||
"command -v python3 >/dev/null 2>&1 && echo OK",
|
||||
cwd="/", timeout=15,
|
||||
)
|
||||
@@ -734,7 +734,7 @@ def _execute_remote(
|
||||
})
|
||||
|
||||
# Create sandbox directory on remote
|
||||
env.execute_oneshot(
|
||||
env.execute(
|
||||
f"mkdir -p {sandbox_dir}/rpc", cwd="/", timeout=10,
|
||||
)
|
||||
|
||||
@@ -806,7 +806,7 @@ def _execute_remote(
|
||||
|
||||
# Clean up remote sandbox dir
|
||||
try:
|
||||
env.execute_oneshot(
|
||||
env.execute(
|
||||
f"rm -rf {sandbox_dir}", cwd="/", timeout=15,
|
||||
)
|
||||
except Exception:
|
||||
@@ -1343,4 +1343,5 @@ registry.register(
|
||||
enabled_tools=kw.get("enabled_tools")),
|
||||
check_fn=check_sandbox_requirements,
|
||||
emoji="🐍",
|
||||
max_result_size_chars=100_000,
|
||||
)
|
||||
|
||||
@@ -195,6 +195,7 @@ def _format_job(job: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"next_run_at": job.get("next_run_at"),
|
||||
"last_run_at": job.get("last_run_at"),
|
||||
"last_status": job.get("last_status"),
|
||||
"last_delivery_error": job.get("last_delivery_error"),
|
||||
"enabled": job.get("enabled", True),
|
||||
"state": job.get("state", "scheduled" if job.get("enabled", True) else "paused"),
|
||||
"paused_at": job.get("paused_at"),
|
||||
|
||||
+501
-55
@@ -1,11 +1,27 @@
|
||||
"""Base class for all Hermes execution environment backends."""
|
||||
"""Base class for all Hermes execution environment backends.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
Unified spawn-per-call model: every command spawns a fresh ``bash -c`` process.
|
||||
A session snapshot (env vars, functions, aliases) is captured once at init and
|
||||
re-sourced before each command. CWD persists via in-band stdout markers (remote)
|
||||
or a temp file (local).
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import IO, Callable, Protocol
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_sandbox_dir() -> Path:
|
||||
@@ -23,30 +39,501 @@ def get_sandbox_dir() -> Path:
|
||||
return p
|
||||
|
||||
|
||||
class BaseEnvironment(ABC):
|
||||
"""Common interface for all Hermes execution backends.
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared constants and utilities
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Subclasses implement execute() and cleanup(). Shared helpers eliminate
|
||||
duplicated subprocess boilerplate across backends.
|
||||
_SYNC_INTERVAL_SECONDS = 5.0
|
||||
|
||||
|
||||
def _pipe_stdin(proc: subprocess.Popen, data: str) -> None:
|
||||
"""Write *data* to proc.stdin on a daemon thread to avoid pipe-buffer deadlocks."""
|
||||
|
||||
def _write():
|
||||
try:
|
||||
proc.stdin.write(data)
|
||||
proc.stdin.close()
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
|
||||
threading.Thread(target=_write, daemon=True).start()
|
||||
|
||||
|
||||
def _popen_bash(
|
||||
cmd: list[str], stdin_data: str | None = None, **kwargs
|
||||
) -> subprocess.Popen:
|
||||
"""Spawn a subprocess with standard stdout/stderr/stdin setup.
|
||||
|
||||
If *stdin_data* is provided, writes it asynchronously via :func:`_pipe_stdin`.
|
||||
Backends with special Popen needs (e.g. local's ``preexec_fn``) can bypass
|
||||
this and call :func:`_pipe_stdin` directly.
|
||||
"""
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if stdin_data is not None else subprocess.DEVNULL,
|
||||
text=True,
|
||||
**kwargs,
|
||||
)
|
||||
if stdin_data is not None:
|
||||
_pipe_stdin(proc, stdin_data)
|
||||
return proc
|
||||
|
||||
|
||||
def _load_json_store(path: Path) -> dict:
|
||||
"""Load a JSON file as a dict, returning ``{}`` on any error."""
|
||||
if path.exists():
|
||||
try:
|
||||
return json.loads(path.read_text())
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
def _save_json_store(path: Path, data: dict) -> None:
|
||||
"""Write *data* as pretty-printed JSON to *path*."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(data, indent=2))
|
||||
|
||||
|
||||
def _file_mtime_key(host_path: str) -> tuple[float, int] | None:
|
||||
"""Return ``(mtime, size)`` for cache comparison, or ``None`` if unreadable."""
|
||||
try:
|
||||
st = Path(host_path).stat()
|
||||
return (st.st_mtime, st.st_size)
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ProcessHandle protocol
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ProcessHandle(Protocol):
|
||||
"""Duck type that every backend's _run_bash() must return.
|
||||
|
||||
subprocess.Popen satisfies this natively. SDK backends (Modal, Daytona)
|
||||
return _ThreadedProcessHandle which adapts their blocking calls.
|
||||
"""
|
||||
|
||||
def poll(self) -> int | None: ...
|
||||
def kill(self) -> None: ...
|
||||
def wait(self, timeout: float | None = None) -> int: ...
|
||||
|
||||
@property
|
||||
def stdout(self) -> IO[str] | None: ...
|
||||
|
||||
@property
|
||||
def returncode(self) -> int | None: ...
|
||||
|
||||
|
||||
class _ThreadedProcessHandle:
|
||||
"""Adapter for SDK backends (Modal, Daytona) that have no real subprocess.
|
||||
|
||||
Wraps a blocking ``exec_fn() -> (output_str, exit_code)`` in a background
|
||||
thread and exposes a ProcessHandle-compatible interface. An optional
|
||||
``cancel_fn`` is invoked on ``kill()`` for backend-specific cancellation
|
||||
(e.g. Modal sandbox.terminate, Daytona sandbox.stop).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
exec_fn: Callable[[], tuple[str, int]],
|
||||
cancel_fn: Callable[[], None] | None = None,
|
||||
):
|
||||
self._cancel_fn = cancel_fn
|
||||
self._done = threading.Event()
|
||||
self._returncode: int | None = None
|
||||
self._error: Exception | None = None
|
||||
|
||||
# Pipe for stdout — drain thread in _wait_for_process reads the read end.
|
||||
read_fd, write_fd = os.pipe()
|
||||
self._stdout = os.fdopen(read_fd, "r", encoding="utf-8", errors="replace")
|
||||
self._write_fd = write_fd
|
||||
|
||||
def _worker():
|
||||
try:
|
||||
output, exit_code = exec_fn()
|
||||
self._returncode = exit_code
|
||||
# Write output into the pipe so drain thread picks it up.
|
||||
try:
|
||||
os.write(self._write_fd, output.encode("utf-8", errors="replace"))
|
||||
except OSError:
|
||||
pass
|
||||
except Exception as exc:
|
||||
self._error = exc
|
||||
self._returncode = 1
|
||||
finally:
|
||||
try:
|
||||
os.close(self._write_fd)
|
||||
except OSError:
|
||||
pass
|
||||
self._done.set()
|
||||
|
||||
t = threading.Thread(target=_worker, daemon=True)
|
||||
t.start()
|
||||
|
||||
@property
|
||||
def stdout(self):
|
||||
return self._stdout
|
||||
|
||||
@property
|
||||
def returncode(self) -> int | None:
|
||||
return self._returncode
|
||||
|
||||
def poll(self) -> int | None:
|
||||
return self._returncode if self._done.is_set() else None
|
||||
|
||||
def kill(self):
|
||||
if self._cancel_fn:
|
||||
try:
|
||||
self._cancel_fn()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def wait(self, timeout: float | None = None) -> int:
|
||||
self._done.wait(timeout=timeout)
|
||||
return self._returncode
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CWD marker for remote backends
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _cwd_marker(session_id: str) -> str:
|
||||
return f"__HERMES_CWD_{session_id}__"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BaseEnvironment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class BaseEnvironment(ABC):
|
||||
"""Common interface and unified execution flow for all Hermes backends.
|
||||
|
||||
Subclasses implement ``_run_bash()`` and ``cleanup()``. The base class
|
||||
provides ``execute()`` with session snapshot sourcing, CWD tracking,
|
||||
interrupt handling, and timeout enforcement.
|
||||
"""
|
||||
|
||||
# Subclasses that embed stdin as a heredoc (Modal, Daytona) set this.
|
||||
_stdin_mode: str = "pipe" # "pipe" or "heredoc"
|
||||
|
||||
# Snapshot creation timeout (override for slow cold-starts).
|
||||
_snapshot_timeout: int = 30
|
||||
|
||||
def __init__(self, cwd: str, timeout: int, env: dict = None):
|
||||
self.cwd = cwd
|
||||
self.timeout = timeout
|
||||
self.env = env or {}
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
"""Execute a command, return {"output": str, "returncode": int}."""
|
||||
...
|
||||
self._session_id = uuid.uuid4().hex[:12]
|
||||
self._snapshot_path = f"/tmp/hermes-snap-{self._session_id}.sh"
|
||||
self._cwd_file = f"/tmp/hermes-cwd-{self._session_id}.txt"
|
||||
self._cwd_marker = _cwd_marker(self._session_id)
|
||||
self._snapshot_ready = False
|
||||
self._last_sync_time: float | None = (
|
||||
None # set to 0 by backends that need file sync
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Abstract methods
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run_bash(
|
||||
self,
|
||||
cmd_string: str,
|
||||
*,
|
||||
login: bool = False,
|
||||
timeout: int = 120,
|
||||
stdin_data: str | None = None,
|
||||
) -> ProcessHandle:
|
||||
"""Spawn a bash process to run *cmd_string*.
|
||||
|
||||
Returns a ProcessHandle (subprocess.Popen or _ThreadedProcessHandle).
|
||||
Must be overridden by every backend.
|
||||
"""
|
||||
raise NotImplementedError(f"{type(self).__name__} must implement _run_bash()")
|
||||
|
||||
@abstractmethod
|
||||
def cleanup(self):
|
||||
"""Release backend resources (container, instance, connection)."""
|
||||
...
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Session snapshot (init_session)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def init_session(self):
|
||||
"""Capture login shell environment into a snapshot file.
|
||||
|
||||
Called once after backend construction. On success, sets
|
||||
``_snapshot_ready = True`` so subsequent commands source the snapshot
|
||||
instead of running with ``bash -l``.
|
||||
"""
|
||||
# Full capture: env vars, functions (filtered), aliases, shell options.
|
||||
bootstrap = (
|
||||
f"export -p > {self._snapshot_path}\n"
|
||||
f"declare -f | grep -vE '^_[^_]' >> {self._snapshot_path}\n"
|
||||
f"alias -p >> {self._snapshot_path}\n"
|
||||
f"echo 'shopt -s expand_aliases' >> {self._snapshot_path}\n"
|
||||
f"echo 'set +e' >> {self._snapshot_path}\n"
|
||||
f"echo 'set +u' >> {self._snapshot_path}\n"
|
||||
f"pwd -P > {self._cwd_file} 2>/dev/null || true\n"
|
||||
f"printf '\\n{self._cwd_marker}%s{self._cwd_marker}\\n' \"$(pwd -P)\"\n"
|
||||
)
|
||||
try:
|
||||
proc = self._run_bash(bootstrap, login=True, timeout=self._snapshot_timeout)
|
||||
result = self._wait_for_process(proc, timeout=self._snapshot_timeout)
|
||||
self._snapshot_ready = True
|
||||
self._update_cwd(result)
|
||||
logger.info(
|
||||
"Session snapshot created (session=%s, cwd=%s)",
|
||||
self._session_id,
|
||||
self.cwd,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"init_session failed (session=%s): %s — "
|
||||
"falling back to bash -l per command",
|
||||
self._session_id,
|
||||
exc,
|
||||
)
|
||||
self._snapshot_ready = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Command wrapping
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _wrap_command(self, command: str, cwd: str) -> str:
|
||||
"""Build the full bash script that sources snapshot, cd's, runs command,
|
||||
re-dumps env vars, and emits CWD markers."""
|
||||
escaped = command.replace("'", "'\\''")
|
||||
|
||||
parts = []
|
||||
|
||||
# Source snapshot (env vars from previous commands)
|
||||
if self._snapshot_ready:
|
||||
parts.append(f"source {self._snapshot_path} 2>/dev/null || true")
|
||||
|
||||
# cd to working directory — let bash expand ~ natively
|
||||
quoted_cwd = (
|
||||
shlex.quote(cwd) if cwd != "~" and not cwd.startswith("~/") else cwd
|
||||
)
|
||||
parts.append(f"cd {quoted_cwd} || exit 126")
|
||||
|
||||
# Run the actual command
|
||||
parts.append(f"eval '{escaped}'")
|
||||
parts.append("__hermes_ec=$?")
|
||||
|
||||
# Re-dump env vars to snapshot (last-writer-wins for concurrent calls)
|
||||
if self._snapshot_ready:
|
||||
parts.append(f"export -p > {self._snapshot_path} 2>/dev/null || true")
|
||||
|
||||
# Write CWD to file (local reads this) and stdout marker (remote parses this)
|
||||
parts.append(f"pwd -P > {self._cwd_file} 2>/dev/null || true")
|
||||
# Use a distinct line for the marker. The leading \n ensures
|
||||
# the marker starts on its own line even if the command doesn't
|
||||
# end with a newline (e.g. printf 'exact'). We'll strip this
|
||||
# injected newline in _extract_cwd_from_output.
|
||||
parts.append(
|
||||
f"printf '\\n{self._cwd_marker}%s{self._cwd_marker}\\n' \"$(pwd -P)\""
|
||||
)
|
||||
parts.append("exit $__hermes_ec")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Stdin heredoc embedding (for SDK backends)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _embed_stdin_heredoc(command: str, stdin_data: str) -> str:
|
||||
"""Append stdin_data as a shell heredoc to the command string."""
|
||||
delimiter = f"HERMES_STDIN_{uuid.uuid4().hex[:12]}"
|
||||
return f"{command} << '{delimiter}'\n{stdin_data}\n{delimiter}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Process lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _wait_for_process(self, proc: ProcessHandle, timeout: int = 120) -> dict:
|
||||
"""Poll-based wait with interrupt checking and stdout draining.
|
||||
|
||||
Shared across all backends — not overridden.
|
||||
"""
|
||||
output_chunks: list[str] = []
|
||||
|
||||
def _drain():
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
output_chunks.append(line)
|
||||
except UnicodeDecodeError:
|
||||
output_chunks.clear()
|
||||
output_chunks.append(
|
||||
"[binary output detected — raw bytes not displayable]"
|
||||
)
|
||||
except (ValueError, OSError):
|
||||
pass
|
||||
|
||||
drain_thread = threading.Thread(target=_drain, daemon=True)
|
||||
drain_thread.start()
|
||||
deadline = time.monotonic() + timeout
|
||||
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
self._kill_process(proc)
|
||||
drain_thread.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(output_chunks) + "\n[Command interrupted]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
self._kill_process(proc)
|
||||
drain_thread.join(timeout=2)
|
||||
partial = "".join(output_chunks)
|
||||
timeout_msg = f"\n[Command timed out after {timeout}s]"
|
||||
return {
|
||||
"output": partial + timeout_msg
|
||||
if partial
|
||||
else timeout_msg.lstrip(),
|
||||
"returncode": 124,
|
||||
}
|
||||
time.sleep(0.2)
|
||||
|
||||
drain_thread.join(timeout=5)
|
||||
|
||||
try:
|
||||
proc.stdout.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {"output": "".join(output_chunks), "returncode": proc.returncode}
|
||||
|
||||
def _kill_process(self, proc: ProcessHandle):
|
||||
"""Terminate a process. Subclasses may override for process-group kill."""
|
||||
try:
|
||||
proc.kill()
|
||||
except (ProcessLookupError, PermissionError, OSError):
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# CWD extraction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _update_cwd(self, result: dict):
|
||||
"""Extract CWD from command output. Override for local file-based read."""
|
||||
self._extract_cwd_from_output(result)
|
||||
|
||||
def _extract_cwd_from_output(self, result: dict):
|
||||
"""Parse the __HERMES_CWD_{session}__ marker from stdout output.
|
||||
|
||||
Updates self.cwd and strips the marker from result["output"].
|
||||
Used by remote backends (Docker, SSH, Modal, Daytona, Singularity).
|
||||
"""
|
||||
output = result.get("output", "")
|
||||
marker = self._cwd_marker
|
||||
last = output.rfind(marker)
|
||||
if last == -1:
|
||||
return
|
||||
|
||||
# Find the opening marker before this closing one
|
||||
search_start = max(0, last - 4096) # CWD path won't be >4KB
|
||||
first = output.rfind(marker, search_start, last)
|
||||
if first == -1 or first == last:
|
||||
return
|
||||
|
||||
cwd_path = output[first + len(marker) : last].strip()
|
||||
if cwd_path:
|
||||
self.cwd = cwd_path
|
||||
|
||||
# Strip the marker line AND the \n we injected before it.
|
||||
# The wrapper emits: printf '\n__MARKER__%s__MARKER__\n'
|
||||
# So the output looks like: <cmd output>\n__MARKER__path__MARKER__\n
|
||||
# We want to remove everything from the injected \n onwards.
|
||||
line_start = output.rfind("\n", 0, first)
|
||||
if line_start == -1:
|
||||
line_start = first
|
||||
line_end = output.find("\n", last + len(marker))
|
||||
line_end = line_end + 1 if line_end != -1 else len(output)
|
||||
|
||||
result["output"] = output[:line_start] + output[line_end:]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Hooks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _before_execute(self):
|
||||
"""Rate-limited file sync before each command.
|
||||
|
||||
Backends that need pre-command sync set ``self._last_sync_time = 0``
|
||||
in ``__init__`` and override :meth:`_sync_files`. Backends needing
|
||||
extra pre-exec logic (e.g. Daytona sandbox restart check) override
|
||||
this method and call ``super()._before_execute()``.
|
||||
"""
|
||||
if self._last_sync_time is not None:
|
||||
now = time.monotonic()
|
||||
if now - self._last_sync_time >= _SYNC_INTERVAL_SECONDS:
|
||||
self._sync_files()
|
||||
self._last_sync_time = now
|
||||
|
||||
def _sync_files(self):
|
||||
"""Push files to remote environment. Called rate-limited by _before_execute."""
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Unified execute()
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def execute(
|
||||
self,
|
||||
command: str,
|
||||
cwd: str = "",
|
||||
*,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None,
|
||||
) -> dict:
|
||||
"""Execute a command, return {"output": str, "returncode": int}."""
|
||||
self._before_execute()
|
||||
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
effective_timeout = timeout or self.timeout
|
||||
effective_cwd = cwd or self.cwd
|
||||
|
||||
# Merge sudo stdin with caller stdin
|
||||
if sudo_stdin is not None and stdin_data is not None:
|
||||
effective_stdin = sudo_stdin + stdin_data
|
||||
elif sudo_stdin is not None:
|
||||
effective_stdin = sudo_stdin
|
||||
else:
|
||||
effective_stdin = stdin_data
|
||||
|
||||
# Embed stdin as heredoc for backends that need it
|
||||
if effective_stdin and self._stdin_mode == "heredoc":
|
||||
exec_command = self._embed_stdin_heredoc(exec_command, effective_stdin)
|
||||
effective_stdin = None
|
||||
|
||||
wrapped = self._wrap_command(exec_command, effective_cwd)
|
||||
|
||||
# Use login shell if snapshot failed (so user's profile still loads)
|
||||
login = not self._snapshot_ready
|
||||
|
||||
proc = self._run_bash(
|
||||
wrapped, login=login, timeout=effective_timeout, stdin_data=effective_stdin
|
||||
)
|
||||
result = self._wait_for_process(proc, timeout=effective_timeout)
|
||||
self._update_cwd(result)
|
||||
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Shared helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def stop(self):
|
||||
"""Alias for cleanup (compat with older callers)."""
|
||||
self.cleanup()
|
||||
@@ -57,53 +544,12 @@ class BaseEnvironment(ABC):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Shared helpers (eliminate duplication across backends)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _prepare_command(self, command: str) -> tuple[str, str | None]:
|
||||
"""Transform sudo commands if SUDO_PASSWORD is available.
|
||||
|
||||
Returns:
|
||||
(transformed_command, sudo_stdin) — see _transform_sudo_command
|
||||
for the full contract. Callers that drive a subprocess directly
|
||||
should prepend sudo_stdin (when not None) to any stdin_data they
|
||||
pass to Popen. Callers that embed stdin via heredoc (modal,
|
||||
daytona) handle sudo_stdin in their own execute() method.
|
||||
"""
|
||||
"""Transform sudo commands if SUDO_PASSWORD is available."""
|
||||
from tools.terminal_tool import _transform_sudo_command
|
||||
|
||||
return _transform_sudo_command(command)
|
||||
|
||||
def _build_run_kwargs(self, timeout: int | None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
"""Build common subprocess.run kwargs for non-interactive execution."""
|
||||
kw = {
|
||||
"text": True,
|
||||
"timeout": timeout or self.timeout,
|
||||
"encoding": "utf-8",
|
||||
"errors": "replace",
|
||||
"stdout": subprocess.PIPE,
|
||||
"stderr": subprocess.STDOUT,
|
||||
}
|
||||
if stdin_data is not None:
|
||||
kw["input"] = stdin_data
|
||||
else:
|
||||
kw["stdin"] = subprocess.DEVNULL
|
||||
return kw
|
||||
|
||||
def execute_oneshot(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
"""Execute a command bypassing any persistent shell.
|
||||
|
||||
Safe for concurrent use alongside a long-running execute() call.
|
||||
Backends that maintain a persistent shell (SSH, Local) override this
|
||||
to route through their oneshot path, avoiding the shell lock.
|
||||
Non-persistent backends delegate to execute().
|
||||
"""
|
||||
return self.execute(command, cwd=cwd, timeout=timeout,
|
||||
stdin_data=stdin_data)
|
||||
|
||||
def _timeout_result(self, timeout: int | None) -> dict:
|
||||
"""Standard return dict when a command times out."""
|
||||
return {
|
||||
|
||||
+48
-133
@@ -6,17 +6,18 @@ and resumed on next creation, preserving the filesystem across sessions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import math
|
||||
import shlex
|
||||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.interrupt import is_interrupted
|
||||
from tools.environments.base import (
|
||||
BaseEnvironment,
|
||||
_ThreadedProcessHandle,
|
||||
_file_mtime_key,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,22 +25,25 @@ logger = logging.getLogger(__name__)
|
||||
class DaytonaEnvironment(BaseEnvironment):
|
||||
"""Daytona cloud sandbox execution backend.
|
||||
|
||||
Uses stopped/started sandbox lifecycle for filesystem persistence
|
||||
instead of snapshots, making it faster and stateless on the host.
|
||||
Spawn-per-call via _ThreadedProcessHandle wrapping blocking SDK calls.
|
||||
cancel_fn wired to sandbox.stop() for interrupt support.
|
||||
Shell timeout wrapper preserved (SDK timeout unreliable).
|
||||
"""
|
||||
|
||||
_stdin_mode = "heredoc"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image: str,
|
||||
cwd: str = "/home/daytona",
|
||||
timeout: int = 60,
|
||||
cpu: int = 1,
|
||||
memory: int = 5120, # MB (hermes convention)
|
||||
disk: int = 10240, # MB (Daytona platform max is 10GB)
|
||||
memory: int = 5120,
|
||||
disk: int = 10240,
|
||||
persistent_filesystem: bool = True,
|
||||
task_id: str = "default",
|
||||
):
|
||||
self._requested_cwd = cwd
|
||||
requested_cwd = cwd
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
|
||||
from daytona import (
|
||||
@@ -53,16 +57,18 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
self._persistent = persistent_filesystem
|
||||
self._task_id = task_id
|
||||
self._SandboxState = SandboxState
|
||||
self._DaytonaError = DaytonaError
|
||||
self._daytona = Daytona()
|
||||
self._sandbox = None
|
||||
self._lock = threading.Lock()
|
||||
self._last_sync_time: float = 0
|
||||
|
||||
memory_gib = max(1, math.ceil(memory / 1024))
|
||||
disk_gib = max(1, math.ceil(disk / 1024))
|
||||
if disk_gib > 10:
|
||||
warnings.warn(
|
||||
f"Daytona: requested disk ({disk_gib}GB) exceeds platform limit (10GB). "
|
||||
f"Capping to 10GB. Set container_disk: 10240 in config to silence this.",
|
||||
f"Capping to 10GB.",
|
||||
stacklevel=2,
|
||||
)
|
||||
disk_gib = 10
|
||||
@@ -71,9 +77,7 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
labels = {"hermes_task_id": task_id}
|
||||
sandbox_name = f"hermes-{task_id}"
|
||||
|
||||
# Try to resume an existing sandbox for this task
|
||||
if self._persistent:
|
||||
# 1. Try name-based lookup (new path)
|
||||
try:
|
||||
self._sandbox = self._daytona.get(sandbox_name)
|
||||
self._sandbox.start()
|
||||
@@ -86,7 +90,6 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
task_id, e)
|
||||
self._sandbox = None
|
||||
|
||||
# 2. Legacy fallback: find sandbox created before the naming migration
|
||||
if self._sandbox is None:
|
||||
try:
|
||||
page = self._daytona.list(labels=labels, page=1, limit=1)
|
||||
@@ -100,7 +103,6 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
task_id, e)
|
||||
self._sandbox = None
|
||||
|
||||
# Create a fresh sandbox if we don't have one
|
||||
if self._sandbox is None:
|
||||
self._sandbox = self._daytona.create(
|
||||
CreateSandboxFromImageParams(
|
||||
@@ -114,32 +116,25 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
logger.info("Daytona: created sandbox %s for task %s",
|
||||
self._sandbox.id, task_id)
|
||||
|
||||
# Detect remote home dir first so mounts go to the right place.
|
||||
# Detect remote home dir
|
||||
self._remote_home = "/root"
|
||||
try:
|
||||
home = self._sandbox.process.exec("echo $HOME").result.strip()
|
||||
if home:
|
||||
self._remote_home = home
|
||||
if self._requested_cwd in ("~", "/home/daytona"):
|
||||
if requested_cwd in ("~", "/home/daytona"):
|
||||
self.cwd = home
|
||||
except Exception:
|
||||
pass
|
||||
logger.info("Daytona: resolved home to %s, cwd to %s", self._remote_home, self.cwd)
|
||||
|
||||
# Track synced files to avoid redundant uploads.
|
||||
# Key: remote_path, Value: (mtime, size)
|
||||
self._synced_files: Dict[str, tuple] = {}
|
||||
|
||||
# Upload credential files and skills directory into the sandbox.
|
||||
self._sync_skills_and_credentials()
|
||||
self._sync_files()
|
||||
self.init_session()
|
||||
|
||||
def _upload_if_changed(self, host_path: str, remote_path: str) -> bool:
|
||||
"""Upload a file if its mtime/size changed since last sync."""
|
||||
hp = Path(host_path)
|
||||
try:
|
||||
stat = hp.stat()
|
||||
file_key = (stat.st_mtime, stat.st_size)
|
||||
except OSError:
|
||||
file_key = _file_mtime_key(host_path)
|
||||
if file_key is None:
|
||||
return False
|
||||
if self._synced_files.get(remote_path) == file_key:
|
||||
return False
|
||||
@@ -153,20 +148,15 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
logger.debug("Daytona: upload failed %s: %s", host_path, e)
|
||||
return False
|
||||
|
||||
def _sync_skills_and_credentials(self) -> None:
|
||||
"""Upload changed credential files and skill files into the sandbox."""
|
||||
def _sync_files(self) -> None:
|
||||
container_base = f"{self._remote_home}/.hermes"
|
||||
try:
|
||||
from tools.credential_files import get_credential_file_mounts, iter_skills_files
|
||||
|
||||
for mount_entry in get_credential_file_mounts():
|
||||
remote_path = mount_entry["container_path"].replace("/root/.hermes", container_base, 1)
|
||||
if self._upload_if_changed(mount_entry["host_path"], remote_path):
|
||||
logger.debug("Daytona: synced credential %s", remote_path)
|
||||
|
||||
self._upload_if_changed(mount_entry["host_path"], remote_path)
|
||||
for entry in iter_skills_files(container_base=container_base):
|
||||
if self._upload_if_changed(entry["host_path"], entry["container_path"]):
|
||||
logger.debug("Daytona: synced skill %s", entry["container_path"])
|
||||
self._upload_if_changed(entry["host_path"], entry["container_path"])
|
||||
except Exception as e:
|
||||
logger.debug("Daytona: could not sync skills/credentials: %s", e)
|
||||
|
||||
@@ -177,111 +167,36 @@ class DaytonaEnvironment(BaseEnvironment):
|
||||
self._sandbox.start()
|
||||
logger.info("Daytona: restarted sandbox %s", self._sandbox.id)
|
||||
|
||||
def _exec_in_thread(self, exec_command: str, cwd: Optional[str], timeout: int) -> dict:
|
||||
"""Run exec in a background thread with interrupt polling.
|
||||
|
||||
The Daytona SDK's exec(timeout=...) parameter is unreliable (the
|
||||
server-side timeout is not enforced and the SDK has no client-side
|
||||
fallback), so we wrap the command with the shell ``timeout`` utility
|
||||
which reliably kills the process and returns exit code 124.
|
||||
"""
|
||||
# Wrap with shell `timeout` to enforce the deadline reliably.
|
||||
# Add a small buffer so the shell timeout fires before any SDK-level
|
||||
# timeout would, giving us a clean exit code 124.
|
||||
timed_command = f"timeout {timeout} sh -c {shlex.quote(exec_command)}"
|
||||
|
||||
result_holder: dict = {"value": None, "error": None}
|
||||
|
||||
def _run():
|
||||
try:
|
||||
response = self._sandbox.process.exec(
|
||||
timed_command, cwd=cwd,
|
||||
)
|
||||
result_holder["value"] = {
|
||||
"output": response.result or "",
|
||||
"returncode": response.exit_code,
|
||||
}
|
||||
except Exception as e:
|
||||
result_holder["error"] = e
|
||||
|
||||
t = threading.Thread(target=_run, daemon=True)
|
||||
t.start()
|
||||
# Wait for timeout + generous buffer for network/SDK overhead
|
||||
deadline = time.monotonic() + timeout + 10
|
||||
while t.is_alive():
|
||||
t.join(timeout=0.2)
|
||||
if is_interrupted():
|
||||
with self._lock:
|
||||
try:
|
||||
self._sandbox.stop()
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"output": "[Command interrupted - Daytona sandbox stopped]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
# Shell timeout didn't fire and SDK is hung — force stop
|
||||
with self._lock:
|
||||
try:
|
||||
self._sandbox.stop()
|
||||
except Exception:
|
||||
pass
|
||||
return self._timeout_result(timeout)
|
||||
|
||||
if result_holder["error"]:
|
||||
return {"error": result_holder["error"]}
|
||||
return result_holder["value"]
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: Optional[int] = None,
|
||||
stdin_data: Optional[str] = None) -> dict:
|
||||
def _before_execute(self):
|
||||
"""Ensure sandbox is ready, then rate-limited file sync via base class."""
|
||||
with self._lock:
|
||||
self._ensure_sandbox_ready()
|
||||
# Incremental sync before each command so mid-session credential
|
||||
# refreshes and skill updates are picked up.
|
||||
self._sync_skills_and_credentials()
|
||||
super()._before_execute()
|
||||
|
||||
if stdin_data is not None:
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
while marker in stdin_data:
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
command = f"{command} << '{marker}'\n{stdin_data}\n{marker}"
|
||||
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||
timeout: int = 120,
|
||||
stdin_data: str | None = None):
|
||||
"""Return a _ThreadedProcessHandle wrapping a blocking Daytona SDK call."""
|
||||
sandbox = self._sandbox
|
||||
lock = self._lock
|
||||
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
def cancel():
|
||||
with lock:
|
||||
try:
|
||||
sandbox.stop()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Daytona sandboxes execute commands via the Daytona SDK and cannot
|
||||
# pipe subprocess stdin directly the way a local Popen can. When a
|
||||
# sudo password is present, use a shell-level pipe from printf so that
|
||||
# the password feeds sudo -S without appearing as an echo argument
|
||||
# embedded in the shell string. The password is still visible in the
|
||||
# remote sandbox's command line, but it is not exposed on the user's
|
||||
# local machine — which is the primary threat being mitigated.
|
||||
if sudo_stdin is not None:
|
||||
import shlex
|
||||
exec_command = (
|
||||
f"printf '%s\\n' {shlex.quote(sudo_stdin.rstrip())} | {exec_command}"
|
||||
)
|
||||
effective_cwd = cwd or self.cwd or None
|
||||
effective_timeout = timeout or self.timeout
|
||||
if login:
|
||||
shell_cmd = f"bash -l -c {shlex.quote(cmd_string)}"
|
||||
else:
|
||||
shell_cmd = f"bash -c {shlex.quote(cmd_string)}"
|
||||
|
||||
result = self._exec_in_thread(exec_command, effective_cwd, effective_timeout)
|
||||
def exec_fn() -> tuple[str, int]:
|
||||
response = sandbox.process.exec(shell_cmd, timeout=timeout)
|
||||
return (response.result or "", response.exit_code)
|
||||
|
||||
if "error" in result:
|
||||
from daytona import DaytonaError
|
||||
err = result["error"]
|
||||
if isinstance(err, DaytonaError):
|
||||
with self._lock:
|
||||
try:
|
||||
self._ensure_sandbox_ready()
|
||||
except Exception:
|
||||
return {"output": f"Daytona execution error: {err}", "returncode": 1}
|
||||
result = self._exec_in_thread(exec_command, effective_cwd, effective_timeout)
|
||||
if "error" not in result:
|
||||
return result
|
||||
return {"output": f"Daytona execution error: {err}", "returncode": 1}
|
||||
|
||||
return result
|
||||
return _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
|
||||
|
||||
def cleanup(self):
|
||||
with self._lock:
|
||||
|
||||
+64
-108
@@ -8,18 +8,14 @@ persistence via bind mounts.
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.environments.base import BaseEnvironment, _popen_bash
|
||||
from tools.environments.local import _HERMES_PROVIDER_ENV_BLOCKLIST
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -431,6 +427,69 @@ class DockerEnvironment(BaseEnvironment):
|
||||
self._container_id = result.stdout.strip()
|
||||
logger.info(f"Started container {container_name} ({self._container_id[:12]})")
|
||||
|
||||
# Build the init-time env forwarding args (used only by init_session
|
||||
# to inject host env vars into the snapshot; subsequent commands get
|
||||
# them from the snapshot file).
|
||||
self._init_env_args = self._build_init_env_args()
|
||||
|
||||
# Initialize session snapshot inside the container
|
||||
self.init_session()
|
||||
|
||||
def _build_init_env_args(self) -> list[str]:
|
||||
"""Build -e KEY=VALUE args for injecting host env vars into init_session.
|
||||
|
||||
These are used once during init_session() so that export -p captures
|
||||
them into the snapshot. Subsequent execute() calls don't need -e flags.
|
||||
"""
|
||||
exec_env: dict[str, str] = dict(self._env)
|
||||
|
||||
explicit_forward_keys = set(self._forward_env)
|
||||
passthrough_keys: set[str] = set()
|
||||
try:
|
||||
from tools.env_passthrough import get_all_passthrough
|
||||
passthrough_keys = set(get_all_passthrough())
|
||||
except Exception:
|
||||
pass
|
||||
# Explicit docker_forward_env entries are an intentional opt-in and must
|
||||
# win over the generic Hermes secret blocklist. Only implicit passthrough
|
||||
# keys are filtered.
|
||||
forward_keys = explicit_forward_keys | (passthrough_keys - _HERMES_PROVIDER_ENV_BLOCKLIST)
|
||||
hermes_env = _load_hermes_env_vars() if forward_keys else {}
|
||||
for key in sorted(forward_keys):
|
||||
value = os.getenv(key)
|
||||
if value is None:
|
||||
value = hermes_env.get(key)
|
||||
if value is not None:
|
||||
exec_env[key] = value
|
||||
|
||||
args = []
|
||||
for key in sorted(exec_env):
|
||||
args.extend(["-e", f"{key}={exec_env[key]}"])
|
||||
return args
|
||||
|
||||
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||
timeout: int = 120,
|
||||
stdin_data: str | None = None) -> subprocess.Popen:
|
||||
"""Spawn a bash process inside the Docker container."""
|
||||
assert self._container_id, "Container not started"
|
||||
cmd = [self._docker_exe, "exec"]
|
||||
if stdin_data is not None:
|
||||
cmd.append("-i")
|
||||
|
||||
# Only inject -e env args during init_session (login=True).
|
||||
# Subsequent commands get env vars from the snapshot.
|
||||
if login:
|
||||
cmd.extend(self._init_env_args)
|
||||
|
||||
cmd.extend([self._container_id])
|
||||
|
||||
if login:
|
||||
cmd.extend(["bash", "-l", "-c", cmd_string])
|
||||
else:
|
||||
cmd.extend(["bash", "-c", cmd_string])
|
||||
|
||||
return _popen_bash(cmd, stdin_data)
|
||||
|
||||
@staticmethod
|
||||
def _storage_opt_supported() -> bool:
|
||||
"""Check if Docker's storage driver supports --storage-opt size=.
|
||||
@@ -471,109 +530,6 @@ class DockerEnvironment(BaseEnvironment):
|
||||
logger.debug("Docker --storage-opt support: %s", _storage_opt_ok)
|
||||
return _storage_opt_ok
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
work_dir = cwd or self.cwd
|
||||
effective_timeout = timeout or self.timeout
|
||||
|
||||
# Merge sudo password (if any) with caller-supplied stdin_data.
|
||||
if sudo_stdin is not None and stdin_data is not None:
|
||||
effective_stdin = sudo_stdin + stdin_data
|
||||
elif sudo_stdin is not None:
|
||||
effective_stdin = sudo_stdin
|
||||
else:
|
||||
effective_stdin = stdin_data
|
||||
|
||||
# docker exec -w doesn't expand ~, so prepend a cd into the command.
|
||||
# Keep ~ unquoted (for shell expansion) and quote only the subpath.
|
||||
if work_dir == "~":
|
||||
exec_command = f"cd ~ && {exec_command}"
|
||||
work_dir = "/"
|
||||
elif work_dir.startswith("~/"):
|
||||
exec_command = f"cd ~/{shlex.quote(work_dir[2:])} && {exec_command}"
|
||||
work_dir = "/"
|
||||
|
||||
assert self._container_id, "Container not started"
|
||||
cmd = [self._docker_exe, "exec"]
|
||||
if effective_stdin is not None:
|
||||
cmd.append("-i")
|
||||
cmd.extend(["-w", work_dir])
|
||||
# Build the per-exec environment: start with explicit docker_env values
|
||||
# (static config), then overlay docker_forward_env / skill env_passthrough
|
||||
# (dynamic from host process). Forward values take precedence.
|
||||
exec_env: dict[str, str] = dict(self._env)
|
||||
|
||||
forward_keys = set(self._forward_env)
|
||||
try:
|
||||
from tools.env_passthrough import get_all_passthrough
|
||||
forward_keys |= get_all_passthrough()
|
||||
except Exception:
|
||||
pass
|
||||
# Strip Hermes-managed secrets so they never leak into the container.
|
||||
forward_keys -= _HERMES_PROVIDER_ENV_BLOCKLIST
|
||||
hermes_env = _load_hermes_env_vars() if forward_keys else {}
|
||||
for key in sorted(forward_keys):
|
||||
value = os.getenv(key)
|
||||
if value is None:
|
||||
value = hermes_env.get(key)
|
||||
if value is not None:
|
||||
exec_env[key] = value
|
||||
|
||||
for key in sorted(exec_env):
|
||||
cmd.extend(["-e", f"{key}={exec_env[key]}"])
|
||||
cmd.extend([self._container_id, "bash", "-lc", exec_command])
|
||||
|
||||
try:
|
||||
_output_chunks = []
|
||||
proc = subprocess.Popen(
|
||||
cmd,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if effective_stdin else subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
if effective_stdin:
|
||||
try:
|
||||
proc.stdin.write(effective_stdin)
|
||||
proc.stdin.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _drain():
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
_output_chunks.append(line)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
reader = threading.Thread(target=_drain, daemon=True)
|
||||
reader.start()
|
||||
deadline = time.monotonic() + effective_timeout
|
||||
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=1)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(_output_chunks) + "\n[Command interrupted]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return self._timeout_result(effective_timeout)
|
||||
time.sleep(0.2)
|
||||
|
||||
reader.join(timeout=5)
|
||||
return {"output": "".join(_output_chunks), "returncode": proc.returncode}
|
||||
except Exception as e:
|
||||
return {"output": f"Docker execution error: {e}", "returncode": 1}
|
||||
|
||||
def cleanup(self):
|
||||
"""Stop and remove the container. Bind-mount dirs persist if persistent=True."""
|
||||
if self._container_id:
|
||||
|
||||
+69
-283
@@ -1,42 +1,22 @@
|
||||
"""Local execution environment with interrupt support and non-blocking I/O."""
|
||||
"""Local execution environment — spawn-per-call with session snapshot."""
|
||||
|
||||
import glob
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
|
||||
from tools.environments.base import BaseEnvironment, _pipe_stdin
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
|
||||
from tools.environments.base import BaseEnvironment
|
||||
from tools.environments.persistent_shell import PersistentShellMixin
|
||||
from tools.interrupt import is_interrupted
|
||||
|
||||
# Unique marker to isolate real command output from shell init/exit noise.
|
||||
# printf (no trailing newline) keeps the boundaries clean for splitting.
|
||||
_OUTPUT_FENCE = "__HERMES_FENCE_a9f7b3__"
|
||||
|
||||
# Hermes-internal env vars that should NOT leak into terminal subprocesses.
|
||||
# These are loaded from ~/.hermes/.env for Hermes' own LLM/provider calls
|
||||
# but can break external CLIs (e.g. codex) that also honor them.
|
||||
# See: https://github.com/NousResearch/hermes-agent/issues/1002
|
||||
#
|
||||
# Built dynamically from the provider registry so new providers are
|
||||
# automatically covered without manual blocklist maintenance.
|
||||
_HERMES_PROVIDER_ENV_FORCE_PREFIX = "_HERMES_FORCE_"
|
||||
|
||||
|
||||
def _build_provider_env_blocklist() -> frozenset:
|
||||
"""Derive the blocklist from provider, tool, and gateway config.
|
||||
|
||||
Automatically picks up api_key_env_vars and base_url_env_var from
|
||||
every registered provider, plus tool/messaging env vars from the
|
||||
optional config registry, so new Hermes-managed secrets are blocked
|
||||
in subprocesses without having to maintain multiple static lists.
|
||||
"""
|
||||
"""Derive the blocklist from provider, tool, and gateway config."""
|
||||
blocked: set[str] = set()
|
||||
|
||||
try:
|
||||
@@ -59,33 +39,30 @@ def _build_provider_env_blocklist() -> frozenset:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Vars not covered above but still Hermes-internal / conflict-prone.
|
||||
blocked.update({
|
||||
"OPENAI_BASE_URL",
|
||||
"OPENAI_API_KEY",
|
||||
"OPENAI_API_BASE", # legacy alias
|
||||
"OPENAI_API_BASE",
|
||||
"OPENAI_ORG_ID",
|
||||
"OPENAI_ORGANIZATION",
|
||||
"OPENROUTER_API_KEY",
|
||||
"ANTHROPIC_BASE_URL",
|
||||
"ANTHROPIC_TOKEN", # OAuth token (not in registry as env var)
|
||||
"ANTHROPIC_TOKEN",
|
||||
"CLAUDE_CODE_OAUTH_TOKEN",
|
||||
"LLM_MODEL",
|
||||
# Expanded isolation for other major providers (Issue #1002)
|
||||
"GOOGLE_API_KEY", # Gemini / Google AI Studio
|
||||
"DEEPSEEK_API_KEY", # DeepSeek
|
||||
"MISTRAL_API_KEY", # Mistral AI
|
||||
"GROQ_API_KEY", # Groq
|
||||
"TOGETHER_API_KEY", # Together AI
|
||||
"PERPLEXITY_API_KEY", # Perplexity
|
||||
"COHERE_API_KEY", # Cohere
|
||||
"FIREWORKS_API_KEY", # Fireworks AI
|
||||
"XAI_API_KEY", # xAI (Grok)
|
||||
"HELICONE_API_KEY", # LLM Observability proxy
|
||||
"GOOGLE_API_KEY",
|
||||
"DEEPSEEK_API_KEY",
|
||||
"MISTRAL_API_KEY",
|
||||
"GROQ_API_KEY",
|
||||
"TOGETHER_API_KEY",
|
||||
"PERPLEXITY_API_KEY",
|
||||
"COHERE_API_KEY",
|
||||
"FIREWORKS_API_KEY",
|
||||
"XAI_API_KEY",
|
||||
"HELICONE_API_KEY",
|
||||
"PARALLEL_API_KEY",
|
||||
"FIRECRAWL_API_KEY",
|
||||
"FIRECRAWL_API_URL",
|
||||
# Gateway/runtime config not represented in OPTIONAL_ENV_VARS.
|
||||
"TELEGRAM_HOME_CHANNEL",
|
||||
"TELEGRAM_HOME_CHANNEL_NAME",
|
||||
"DISCORD_HOME_CHANNEL",
|
||||
@@ -115,12 +92,10 @@ def _build_provider_env_blocklist() -> frozenset:
|
||||
"EMAIL_HOME_ADDRESS",
|
||||
"EMAIL_HOME_ADDRESS_NAME",
|
||||
"GATEWAY_ALLOWED_USERS",
|
||||
# Skills Hub / GitHub app auth paths and aliases.
|
||||
"GH_TOKEN",
|
||||
"GITHUB_APP_ID",
|
||||
"GITHUB_APP_PRIVATE_KEY_PATH",
|
||||
"GITHUB_APP_INSTALLATION_ID",
|
||||
# Remote sandbox backend credentials.
|
||||
"MODAL_TOKEN_ID",
|
||||
"MODAL_TOKEN_SECRET",
|
||||
"DAYTONA_API_KEY",
|
||||
@@ -132,13 +107,7 @@ _HERMES_PROVIDER_ENV_BLOCKLIST = _build_provider_env_blocklist()
|
||||
|
||||
|
||||
def _sanitize_subprocess_env(base_env: dict | None, extra_env: dict | None = None) -> dict:
|
||||
"""Filter Hermes-managed secrets from a subprocess environment.
|
||||
|
||||
`_HERMES_FORCE_<VAR>` entries in ``extra_env`` opt a blocked variable back in
|
||||
intentionally for callers that truly need it. Vars registered via
|
||||
:mod:`tools.env_passthrough` (skill-declared or user-configured) also
|
||||
bypass the blocklist.
|
||||
"""
|
||||
"""Filter Hermes-managed secrets from a subprocess environment."""
|
||||
try:
|
||||
from tools.env_passthrough import is_env_passthrough as _is_passthrough
|
||||
except Exception:
|
||||
@@ -163,33 +132,24 @@ def _sanitize_subprocess_env(base_env: dict | None, extra_env: dict | None = Non
|
||||
|
||||
|
||||
def _find_bash() -> str:
|
||||
"""Find bash for command execution.
|
||||
|
||||
The fence wrapper uses bash syntax (semicolons, $?, printf), so we
|
||||
must use bash — not the user's $SHELL which could be fish/zsh/etc.
|
||||
On Windows: uses Git Bash (bundled with Git for Windows).
|
||||
"""
|
||||
"""Find bash for command execution."""
|
||||
if not _IS_WINDOWS:
|
||||
return (
|
||||
shutil.which("bash")
|
||||
or ("/usr/bin/bash" if os.path.isfile("/usr/bin/bash") else None)
|
||||
or ("/bin/bash" if os.path.isfile("/bin/bash") else None)
|
||||
or os.environ.get("SHELL") # last resort: whatever they have
|
||||
or os.environ.get("SHELL")
|
||||
or "/bin/sh"
|
||||
)
|
||||
|
||||
# Windows: look for Git Bash (installed with Git for Windows).
|
||||
# Allow override via env var (same pattern as Claude Code).
|
||||
custom = os.environ.get("HERMES_GIT_BASH_PATH")
|
||||
if custom and os.path.isfile(custom):
|
||||
return custom
|
||||
|
||||
# shutil.which finds bash.exe if Git\bin is on PATH
|
||||
found = shutil.which("bash")
|
||||
if found:
|
||||
return found
|
||||
|
||||
# Check common Git for Windows install locations
|
||||
for candidate in (
|
||||
os.path.join(os.environ.get("ProgramFiles", r"C:\Program Files"), "Git", "bin", "bash.exe"),
|
||||
os.path.join(os.environ.get("ProgramFiles(x86)", r"C:\Program Files (x86)"), "Git", "bin", "bash.exe"),
|
||||
@@ -209,60 +169,7 @@ def _find_bash() -> str:
|
||||
_find_shell = _find_bash
|
||||
|
||||
|
||||
# Noise lines emitted by interactive shells when stdin is not a terminal.
|
||||
# Used as a fallback when output fence markers are missing.
|
||||
_SHELL_NOISE_SUBSTRINGS = (
|
||||
# bash
|
||||
"bash: cannot set terminal process group",
|
||||
"bash: no job control in this shell",
|
||||
"no job control in this shell",
|
||||
"cannot set terminal process group",
|
||||
"tcsetattr: Inappropriate ioctl for device",
|
||||
# zsh / oh-my-zsh / macOS terminal session
|
||||
"Restored session:",
|
||||
"Saving session...",
|
||||
"Last login:",
|
||||
"command not found:",
|
||||
"Oh My Zsh",
|
||||
"compinit:",
|
||||
)
|
||||
|
||||
|
||||
def _clean_shell_noise(output: str) -> str:
|
||||
"""Strip shell startup/exit warnings that leak when using -i without a TTY.
|
||||
|
||||
Removes lines matching known noise patterns from both the beginning
|
||||
and end of the output. Lines in the middle are left untouched.
|
||||
"""
|
||||
|
||||
def _is_noise(line: str) -> bool:
|
||||
return any(noise in line for noise in _SHELL_NOISE_SUBSTRINGS)
|
||||
|
||||
lines = output.split("\n")
|
||||
|
||||
# Strip leading noise
|
||||
while lines and _is_noise(lines[0]):
|
||||
lines.pop(0)
|
||||
|
||||
# Strip trailing noise (walk backwards, skip empty lines from split)
|
||||
end = len(lines) - 1
|
||||
while end >= 0 and (not lines[end] or _is_noise(lines[end])):
|
||||
end -= 1
|
||||
|
||||
if end < 0:
|
||||
return ""
|
||||
|
||||
cleaned = lines[: end + 1]
|
||||
result = "\n".join(cleaned)
|
||||
|
||||
# Preserve trailing newline if original had one
|
||||
if output.endswith("\n") and result and not result.endswith("\n"):
|
||||
result += "\n"
|
||||
return result
|
||||
|
||||
|
||||
# Standard PATH entries for environments with minimal PATH (e.g. systemd services).
|
||||
# Includes macOS Homebrew paths (/opt/homebrew/* for Apple Silicon).
|
||||
# Standard PATH entries for environments with minimal PATH.
|
||||
_SANE_PATH = (
|
||||
"/opt/homebrew/bin:/opt/homebrew/sbin:"
|
||||
"/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
@@ -290,197 +197,76 @@ def _make_run_env(env: dict) -> dict:
|
||||
return run_env
|
||||
|
||||
|
||||
def _extract_fenced_output(raw: str) -> str:
|
||||
"""Extract real command output from between fence markers.
|
||||
|
||||
The execute() method wraps each command with printf(FENCE) markers.
|
||||
This function finds the first and last fence and returns only the
|
||||
content between them, which is the actual command output free of
|
||||
any shell init/exit noise.
|
||||
|
||||
Falls back to pattern-based _clean_shell_noise if fences are missing.
|
||||
"""
|
||||
first = raw.find(_OUTPUT_FENCE)
|
||||
if first == -1:
|
||||
return _clean_shell_noise(raw)
|
||||
|
||||
start = first + len(_OUTPUT_FENCE)
|
||||
last = raw.rfind(_OUTPUT_FENCE)
|
||||
|
||||
if last <= first:
|
||||
# Only start fence found (e.g. user command called `exit`)
|
||||
return _clean_shell_noise(raw[start:])
|
||||
|
||||
return raw[start:last]
|
||||
|
||||
|
||||
class LocalEnvironment(PersistentShellMixin, BaseEnvironment):
|
||||
class LocalEnvironment(BaseEnvironment):
|
||||
"""Run commands directly on the host machine.
|
||||
|
||||
Features:
|
||||
- Popen + polling for interrupt support (user can cancel mid-command)
|
||||
- Background stdout drain thread to prevent pipe buffer deadlocks
|
||||
- stdin_data support for piping content (bypasses ARG_MAX limits)
|
||||
- sudo -S transform via SUDO_PASSWORD env var
|
||||
- Uses interactive login shell so full user env is available
|
||||
- Optional persistent shell mode (cwd/env vars survive across calls)
|
||||
Spawn-per-call: every execute() spawns a fresh bash process.
|
||||
Session snapshot preserves env vars across calls.
|
||||
CWD persists via file-based read after each command.
|
||||
"""
|
||||
|
||||
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None,
|
||||
persistent: bool = False):
|
||||
def __init__(self, cwd: str = "", timeout: int = 60, env: dict = None):
|
||||
super().__init__(cwd=cwd or os.getcwd(), timeout=timeout, env=env)
|
||||
self.persistent = persistent
|
||||
if self.persistent:
|
||||
self._init_persistent_shell()
|
||||
self.init_session()
|
||||
|
||||
@property
|
||||
def _temp_prefix(self) -> str:
|
||||
return f"/tmp/hermes-local-{self._session_id}"
|
||||
|
||||
def _spawn_shell_process(self) -> subprocess.Popen:
|
||||
user_shell = _find_bash()
|
||||
run_env = _make_run_env(self.env)
|
||||
return subprocess.Popen(
|
||||
[user_shell, "-l"],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
env=run_env,
|
||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||
)
|
||||
|
||||
def _read_temp_files(self, *paths: str) -> list[str]:
|
||||
results = []
|
||||
for path in paths:
|
||||
if os.path.exists(path):
|
||||
with open(path) as f:
|
||||
results.append(f.read())
|
||||
else:
|
||||
results.append("")
|
||||
return results
|
||||
|
||||
def _kill_shell_children(self):
|
||||
if self._shell_pid is None:
|
||||
return
|
||||
try:
|
||||
subprocess.run(
|
||||
["pkill", "-P", str(self._shell_pid)],
|
||||
capture_output=True, timeout=5,
|
||||
)
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
pass
|
||||
|
||||
def _cleanup_temp_files(self):
|
||||
for f in glob.glob(f"{self._temp_prefix}-*"):
|
||||
if os.path.exists(f):
|
||||
os.remove(f)
|
||||
|
||||
def _execute_oneshot(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
stdin_data: str | None = None) -> dict:
|
||||
work_dir = cwd or self.cwd or os.getcwd()
|
||||
effective_timeout = timeout or self.timeout
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
|
||||
if sudo_stdin is not None and stdin_data is not None:
|
||||
effective_stdin = sudo_stdin + stdin_data
|
||||
elif sudo_stdin is not None:
|
||||
effective_stdin = sudo_stdin
|
||||
else:
|
||||
effective_stdin = stdin_data
|
||||
|
||||
user_shell = _find_bash()
|
||||
# Newline-separated wrapper (not `cmd; __hermes_rc=...` on one line).
|
||||
# A trailing `; __hermes_rc` glued to `<<EOF` / a closing `EOF` line breaks
|
||||
# heredoc parsing: the delimiter must be alone on its line, otherwise the
|
||||
# rest of this script becomes heredoc body and leaks into stdout (e.g. gh
|
||||
# issue/PR flows that use here-documents for bodies).
|
||||
fenced_cmd = (
|
||||
f"printf '{_OUTPUT_FENCE}'\n"
|
||||
f"{exec_command}\n"
|
||||
f"__hermes_rc=$?\n"
|
||||
f"printf '{_OUTPUT_FENCE}'\n"
|
||||
f"exit $__hermes_rc\n"
|
||||
)
|
||||
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||
timeout: int = 120,
|
||||
stdin_data: str | None = None) -> subprocess.Popen:
|
||||
bash = _find_bash()
|
||||
args = [bash, "-l", "-c", cmd_string] if login else [bash, "-c", cmd_string]
|
||||
run_env = _make_run_env(self.env)
|
||||
|
||||
proc = subprocess.Popen(
|
||||
[user_shell, "-lic", fenced_cmd],
|
||||
args,
|
||||
text=True,
|
||||
cwd=work_dir,
|
||||
env=run_env,
|
||||
encoding="utf-8",
|
||||
errors="replace",
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
stdin=subprocess.PIPE if effective_stdin is not None else subprocess.DEVNULL,
|
||||
stdin=subprocess.PIPE if stdin_data is not None else subprocess.DEVNULL,
|
||||
preexec_fn=None if _IS_WINDOWS else os.setsid,
|
||||
)
|
||||
|
||||
if effective_stdin is not None:
|
||||
def _write_stdin():
|
||||
if stdin_data is not None:
|
||||
_pipe_stdin(proc, stdin_data)
|
||||
|
||||
return proc
|
||||
|
||||
def _kill_process(self, proc):
|
||||
"""Kill the entire process group (all children)."""
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
proc.terminate()
|
||||
else:
|
||||
pgid = os.getpgid(proc.pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
try:
|
||||
proc.stdin.write(effective_stdin)
|
||||
proc.stdin.close()
|
||||
except (BrokenPipeError, OSError):
|
||||
pass
|
||||
threading.Thread(target=_write_stdin, daemon=True).start()
|
||||
|
||||
_output_chunks: list[str] = []
|
||||
|
||||
def _drain_stdout():
|
||||
proc.wait(timeout=1.0)
|
||||
except subprocess.TimeoutExpired:
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
_output_chunks.append(line)
|
||||
except ValueError:
|
||||
proc.kill()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
proc.stdout.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
reader = threading.Thread(target=_drain_stdout, daemon=True)
|
||||
reader.start()
|
||||
deadline = time.monotonic() + effective_timeout
|
||||
def _update_cwd(self, result: dict):
|
||||
"""Read CWD from temp file (local-only, no round-trip needed)."""
|
||||
try:
|
||||
cwd_path = open(self._cwd_file).read().strip()
|
||||
if cwd_path:
|
||||
self.cwd = cwd_path
|
||||
except (OSError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
while proc.poll() is None:
|
||||
if is_interrupted():
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
proc.terminate()
|
||||
else:
|
||||
pgid = os.getpgid(proc.pid)
|
||||
os.killpg(pgid, signal.SIGTERM)
|
||||
try:
|
||||
proc.wait(timeout=1.0)
|
||||
except subprocess.TimeoutExpired:
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
return {
|
||||
"output": "".join(_output_chunks) + "\n[Command interrupted — user sent a new message]",
|
||||
"returncode": 130,
|
||||
}
|
||||
if time.monotonic() > deadline:
|
||||
try:
|
||||
if _IS_WINDOWS:
|
||||
proc.terminate()
|
||||
else:
|
||||
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
|
||||
except (ProcessLookupError, PermissionError):
|
||||
proc.kill()
|
||||
reader.join(timeout=2)
|
||||
partial = "".join(_output_chunks)
|
||||
timeout_msg = f"\n[Command timed out after {effective_timeout}s]"
|
||||
return {
|
||||
"output": partial + timeout_msg if partial else timeout_msg.lstrip(),
|
||||
"returncode": 124,
|
||||
}
|
||||
time.sleep(0.2)
|
||||
# Still strip the marker from output so it's not visible
|
||||
self._extract_cwd_from_output(result)
|
||||
|
||||
reader.join(timeout=5)
|
||||
output = _extract_fenced_output("".join(_output_chunks))
|
||||
return {"output": output, "returncode": proc.returncode}
|
||||
def cleanup(self):
|
||||
"""Clean up temp files."""
|
||||
for f in (self._snapshot_path, self._cwd_file):
|
||||
try:
|
||||
os.unlink(f)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
@@ -10,7 +10,7 @@ import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from tools.environments.modal_common import (
|
||||
from tools.environments.modal_utils import (
|
||||
BaseModalExecutionEnvironment,
|
||||
ModalExecStart,
|
||||
PreparedModalExec,
|
||||
@@ -185,7 +185,7 @@ class ManagedModalEnvironment(BaseModalExecutionEnvironment):
|
||||
"cwd": self.cwd,
|
||||
"cpu": cpu,
|
||||
"memoryMiB": memory,
|
||||
"timeoutMs": 3_600_000,
|
||||
"timeoutMs": int(self._sandbox_kwargs.get("timeout", 3600)) * 1000,
|
||||
"idleTimeoutMs": max(300_000, int(self.timeout * 1000)),
|
||||
"persistentFilesystem": self._persistent,
|
||||
"logicalKey": self._task_id,
|
||||
|
||||
+81
-153
@@ -5,19 +5,19 @@ wrapper, while preserving Hermes' persistent snapshot behavior across sessions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import shlex
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from tools.environments.modal_common import (
|
||||
BaseModalExecutionEnvironment,
|
||||
ModalExecStart,
|
||||
PreparedModalExec,
|
||||
from tools.environments.base import (
|
||||
BaseEnvironment,
|
||||
_ThreadedProcessHandle,
|
||||
_file_mtime_key,
|
||||
_load_json_store,
|
||||
_save_json_store,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -26,20 +26,12 @@ _SNAPSHOT_STORE = get_hermes_home() / "modal_snapshots.json"
|
||||
_DIRECT_SNAPSHOT_NAMESPACE = "direct"
|
||||
|
||||
|
||||
def _load_snapshots() -> Dict[str, str]:
|
||||
"""Load snapshot ID mapping from disk."""
|
||||
if _SNAPSHOT_STORE.exists():
|
||||
try:
|
||||
return json.loads(_SNAPSHOT_STORE.read_text())
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
def _load_snapshots() -> dict:
|
||||
return _load_json_store(_SNAPSHOT_STORE)
|
||||
|
||||
|
||||
def _save_snapshots(data: Dict[str, str]) -> None:
|
||||
"""Persist snapshot ID mapping to disk."""
|
||||
_SNAPSHOT_STORE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_SNAPSHOT_STORE.write_text(json.dumps(data, indent=2))
|
||||
def _save_snapshots(data: dict) -> None:
|
||||
_save_json_store(_SNAPSHOT_STORE, data)
|
||||
|
||||
|
||||
def _direct_snapshot_key(task_id: str) -> str:
|
||||
@@ -47,23 +39,18 @@ def _direct_snapshot_key(task_id: str) -> str:
|
||||
|
||||
|
||||
def _get_snapshot_restore_candidate(task_id: str) -> tuple[str | None, bool]:
|
||||
"""Return a snapshot id and whether it came from the legacy key format."""
|
||||
snapshots = _load_snapshots()
|
||||
|
||||
namespaced_key = _direct_snapshot_key(task_id)
|
||||
snapshot_id = snapshots.get(namespaced_key)
|
||||
if isinstance(snapshot_id, str) and snapshot_id:
|
||||
return snapshot_id, False
|
||||
|
||||
legacy_snapshot_id = snapshots.get(task_id)
|
||||
if isinstance(legacy_snapshot_id, str) and legacy_snapshot_id:
|
||||
return legacy_snapshot_id, True
|
||||
|
||||
return None, False
|
||||
|
||||
|
||||
def _store_direct_snapshot(task_id: str, snapshot_id: str) -> None:
|
||||
"""Persist the direct Modal snapshot id under the direct namespace."""
|
||||
snapshots = _load_snapshots()
|
||||
snapshots[_direct_snapshot_key(task_id)] = snapshot_id
|
||||
snapshots.pop(task_id, None)
|
||||
@@ -71,10 +58,8 @@ def _store_direct_snapshot(task_id: str, snapshot_id: str) -> None:
|
||||
|
||||
|
||||
def _delete_direct_snapshot(task_id: str, snapshot_id: str | None = None) -> None:
|
||||
"""Remove direct Modal snapshot entries for a task, including legacy keys."""
|
||||
snapshots = _load_snapshots()
|
||||
updated = False
|
||||
|
||||
for key in (_direct_snapshot_key(task_id), task_id):
|
||||
value = snapshots.get(key)
|
||||
if value is None:
|
||||
@@ -82,13 +67,15 @@ def _delete_direct_snapshot(task_id: str, snapshot_id: str | None = None) -> Non
|
||||
if snapshot_id is None or value == snapshot_id:
|
||||
snapshots.pop(key, None)
|
||||
updated = True
|
||||
|
||||
if updated:
|
||||
_save_snapshots(snapshots)
|
||||
|
||||
|
||||
def _resolve_modal_image(image_spec: Any) -> Any:
|
||||
"""Convert registry references or snapshot ids into Modal image objects."""
|
||||
"""Convert registry references or snapshot ids into Modal image objects.
|
||||
|
||||
Includes add_python support for ubuntu/debian images (absorbed from PR 4511).
|
||||
"""
|
||||
import modal as _modal
|
||||
|
||||
if not isinstance(image_spec, str):
|
||||
@@ -97,12 +84,22 @@ def _resolve_modal_image(image_spec: Any) -> Any:
|
||||
if image_spec.startswith("im-"):
|
||||
return _modal.Image.from_id(image_spec)
|
||||
|
||||
# PR 4511: add python to ubuntu/debian images that don't have it
|
||||
lower = image_spec.lower()
|
||||
add_python = any(base in lower for base in ("ubuntu", "debian"))
|
||||
|
||||
setup_commands = [
|
||||
"RUN rm -rf /usr/local/lib/python*/site-packages/pip* 2>/dev/null; "
|
||||
"python -m ensurepip --upgrade --default-pip 2>/dev/null || true",
|
||||
]
|
||||
if add_python:
|
||||
setup_commands.insert(0,
|
||||
"RUN apt-get update -qq && apt-get install -y -qq python3 python3-venv > /dev/null 2>&1 || true"
|
||||
)
|
||||
|
||||
return _modal.Image.from_registry(
|
||||
image_spec,
|
||||
setup_dockerfile_commands=[
|
||||
"RUN rm -rf /usr/local/lib/python*/site-packages/pip* 2>/dev/null; "
|
||||
"python -m ensurepip --upgrade --default-pip 2>/dev/null || true",
|
||||
],
|
||||
setup_dockerfile_commands=setup_commands,
|
||||
)
|
||||
|
||||
|
||||
@@ -138,19 +135,15 @@ class _AsyncWorker:
|
||||
self._thread.join(timeout=10)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DirectModalExecHandle:
|
||||
thread: threading.Thread
|
||||
result_holder: Dict[str, Any]
|
||||
class ModalEnvironment(BaseEnvironment):
|
||||
"""Modal cloud execution via native Modal sandboxes.
|
||||
|
||||
|
||||
class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
"""Modal cloud execution via native Modal sandboxes."""
|
||||
Spawn-per-call via _ThreadedProcessHandle wrapping async SDK calls.
|
||||
cancel_fn wired to sandbox.terminate for interrupt support.
|
||||
"""
|
||||
|
||||
_stdin_mode = "heredoc"
|
||||
_poll_interval_seconds = 0.2
|
||||
_interrupt_output = "[Command interrupted - Modal sandbox terminated]"
|
||||
_unexpected_error_prefix = "Modal execution error"
|
||||
_snapshot_timeout = 60 # Modal cold starts can be slow
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -160,6 +153,7 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
modal_sandbox_kwargs: Optional[Dict[str, Any]] = None,
|
||||
persistent_filesystem: bool = True,
|
||||
task_id: str = "default",
|
||||
add_python: Optional[str] = None,
|
||||
):
|
||||
super().__init__(cwd=cwd, timeout=timeout)
|
||||
|
||||
@@ -170,6 +164,7 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
self._app = None
|
||||
self._worker = _AsyncWorker()
|
||||
self._synced_files: Dict[str, tuple] = {}
|
||||
self._last_sync_time: float = 0
|
||||
|
||||
sandbox_kwargs = dict(modal_sandbox_kwargs or {})
|
||||
|
||||
@@ -199,27 +194,13 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
remote_path=mount_entry["container_path"],
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"Modal: mounting credential %s -> %s",
|
||||
mount_entry["host_path"],
|
||||
mount_entry["container_path"],
|
||||
)
|
||||
|
||||
# Mount individual skill files (symlinks filtered out).
|
||||
skills_files = iter_skills_files()
|
||||
for entry in skills_files:
|
||||
for entry in iter_skills_files():
|
||||
cred_mounts.append(
|
||||
_modal.Mount.from_local_file(
|
||||
entry["host_path"],
|
||||
remote_path=entry["container_path"],
|
||||
)
|
||||
)
|
||||
if skills_files:
|
||||
logger.info("Modal: mounting %d skill files", len(skills_files))
|
||||
|
||||
# Mount host-side cache files (documents, images, audio,
|
||||
# screenshots). New files arriving mid-session are picked up
|
||||
# by _sync_files() before each command execution.
|
||||
cache_files = iter_cache_files()
|
||||
for entry in cache_files:
|
||||
cred_mounts.append(
|
||||
@@ -228,8 +209,6 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
remote_path=entry["container_path"],
|
||||
)
|
||||
)
|
||||
if cache_files:
|
||||
logger.info("Modal: mounting %d cache files", len(cache_files))
|
||||
except Exception as e:
|
||||
logger.debug("Modal: could not load credential file mounts: %s", e)
|
||||
|
||||
@@ -243,8 +222,7 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
existing_mounts.extend(cred_mounts)
|
||||
create_kwargs["mounts"] = existing_mounts
|
||||
sandbox = await _modal.Sandbox.create.aio(
|
||||
"sleep",
|
||||
"infinity",
|
||||
"sleep", "infinity",
|
||||
image=image_spec,
|
||||
app=app,
|
||||
timeout=int(create_kwargs.pop("timeout", 3600)),
|
||||
@@ -255,57 +233,41 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
try:
|
||||
target_image_spec = restored_snapshot_id or image
|
||||
try:
|
||||
# _resolve_modal_image keeps the Modal bootstrap fix together:
|
||||
# it applies setup_dockerfile_commands with ensurepip before
|
||||
# Modal builds registry images, while snapshot ids restore via
|
||||
# modal.Image.from_id() without rebuilding.
|
||||
effective_image = _resolve_modal_image(target_image_spec)
|
||||
self._app, self._sandbox = self._worker.run_coroutine(
|
||||
_create_sandbox(effective_image),
|
||||
timeout=300,
|
||||
_create_sandbox(effective_image), timeout=300,
|
||||
)
|
||||
except Exception as exc:
|
||||
if not restored_snapshot_id:
|
||||
raise
|
||||
|
||||
logger.warning(
|
||||
"Modal: failed to restore snapshot %s, retrying with base image: %s",
|
||||
restored_snapshot_id[:20],
|
||||
exc,
|
||||
restored_snapshot_id[:20], exc,
|
||||
)
|
||||
_delete_direct_snapshot(self._task_id, restored_snapshot_id)
|
||||
base_image = _resolve_modal_image(image)
|
||||
self._app, self._sandbox = self._worker.run_coroutine(
|
||||
_create_sandbox(base_image),
|
||||
timeout=300,
|
||||
_create_sandbox(base_image), timeout=300,
|
||||
)
|
||||
else:
|
||||
if restored_snapshot_id and restored_from_legacy_key:
|
||||
_store_direct_snapshot(self._task_id, restored_snapshot_id)
|
||||
logger.info(
|
||||
"Modal: migrated legacy snapshot entry for task %s",
|
||||
self._task_id,
|
||||
)
|
||||
except Exception:
|
||||
self._worker.stop()
|
||||
raise
|
||||
|
||||
logger.info("Modal: sandbox created (task=%s)", self._task_id)
|
||||
self.init_session()
|
||||
|
||||
def _push_file_to_sandbox(self, host_path: str, container_path: str) -> bool:
|
||||
"""Push a single file into the sandbox if changed. Returns True if synced."""
|
||||
hp = Path(host_path)
|
||||
try:
|
||||
stat = hp.stat()
|
||||
file_key = (stat.st_mtime, stat.st_size)
|
||||
except OSError:
|
||||
"""Push a single file into the sandbox if changed."""
|
||||
file_key = _file_mtime_key(host_path)
|
||||
if file_key is None:
|
||||
return False
|
||||
|
||||
if self._synced_files.get(container_path) == file_key:
|
||||
return False
|
||||
|
||||
try:
|
||||
content = hp.read_bytes()
|
||||
content = Path(host_path).read_bytes()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -326,85 +288,55 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
return True
|
||||
|
||||
def _sync_files(self) -> None:
|
||||
"""Push credential, skill, and cache files into the running sandbox.
|
||||
|
||||
Runs before each command. Uses mtime+size caching so only changed
|
||||
files are pushed (~13μs overhead in the no-op case). Cache files
|
||||
are especially important here — new uploads/screenshots may appear
|
||||
mid-session after sandbox creation.
|
||||
"""
|
||||
"""Push credential, skill, and cache files into the running sandbox."""
|
||||
try:
|
||||
from tools.credential_files import (
|
||||
get_credential_file_mounts,
|
||||
iter_skills_files,
|
||||
iter_cache_files,
|
||||
)
|
||||
|
||||
for entry in get_credential_file_mounts():
|
||||
if self._push_file_to_sandbox(entry["host_path"], entry["container_path"]):
|
||||
logger.debug("Modal: synced credential %s", entry["container_path"])
|
||||
|
||||
self._push_file_to_sandbox(entry["host_path"], entry["container_path"])
|
||||
for entry in iter_skills_files():
|
||||
if self._push_file_to_sandbox(entry["host_path"], entry["container_path"]):
|
||||
logger.debug("Modal: synced skill file %s", entry["container_path"])
|
||||
|
||||
self._push_file_to_sandbox(entry["host_path"], entry["container_path"])
|
||||
for entry in iter_cache_files():
|
||||
if self._push_file_to_sandbox(entry["host_path"], entry["container_path"]):
|
||||
logger.debug("Modal: synced cache file %s", entry["container_path"])
|
||||
self._push_file_to_sandbox(entry["host_path"], entry["container_path"])
|
||||
except Exception as e:
|
||||
logger.debug("Modal: file sync failed: %s", e)
|
||||
|
||||
def _before_execute(self) -> None:
|
||||
self._sync_files()
|
||||
def _run_bash(self, cmd_string: str, *, login: bool = False,
|
||||
timeout: int = 120,
|
||||
stdin_data: str | None = None):
|
||||
"""Return a _ThreadedProcessHandle wrapping an async Modal sandbox exec."""
|
||||
sandbox = self._sandbox
|
||||
worker = self._worker
|
||||
|
||||
def _start_modal_exec(self, prepared: PreparedModalExec) -> ModalExecStart:
|
||||
full_command = f"cd {shlex.quote(prepared.cwd)} && {prepared.command}"
|
||||
result_holder = {"value": None, "error": None}
|
||||
def cancel():
|
||||
worker.run_coroutine(sandbox.terminate.aio(), timeout=15)
|
||||
|
||||
def _run():
|
||||
try:
|
||||
async def _do_execute():
|
||||
process = await self._sandbox.exec.aio(
|
||||
"bash",
|
||||
"-c",
|
||||
full_command,
|
||||
timeout=prepared.timeout,
|
||||
)
|
||||
stdout = await process.stdout.read.aio()
|
||||
stderr = await process.stderr.read.aio()
|
||||
exit_code = await process.wait.aio()
|
||||
if isinstance(stdout, bytes):
|
||||
stdout = stdout.decode("utf-8", errors="replace")
|
||||
if isinstance(stderr, bytes):
|
||||
stderr = stderr.decode("utf-8", errors="replace")
|
||||
output = stdout
|
||||
if stderr:
|
||||
output = f"{stdout}\n{stderr}" if stdout else stderr
|
||||
return self._result(output, exit_code)
|
||||
def exec_fn() -> tuple[str, int]:
|
||||
async def _do():
|
||||
args = ["bash"]
|
||||
if login:
|
||||
args.extend(["-l", "-c", cmd_string])
|
||||
else:
|
||||
args.extend(["-c", cmd_string])
|
||||
process = await sandbox.exec.aio(*args, timeout=timeout)
|
||||
stdout = await process.stdout.read.aio()
|
||||
stderr = await process.stderr.read.aio()
|
||||
exit_code = await process.wait.aio()
|
||||
if isinstance(stdout, bytes):
|
||||
stdout = stdout.decode("utf-8", errors="replace")
|
||||
if isinstance(stderr, bytes):
|
||||
stderr = stderr.decode("utf-8", errors="replace")
|
||||
output = stdout
|
||||
if stderr:
|
||||
output = f"{stdout}\n{stderr}" if stdout else stderr
|
||||
return output, exit_code
|
||||
|
||||
result_holder["value"] = self._worker.run_coroutine(
|
||||
_do_execute(),
|
||||
timeout=prepared.timeout + 30,
|
||||
)
|
||||
except Exception as e:
|
||||
result_holder["error"] = e
|
||||
return worker.run_coroutine(_do(), timeout=timeout + 30)
|
||||
|
||||
t = threading.Thread(target=_run, daemon=True)
|
||||
t.start()
|
||||
return ModalExecStart(handle=_DirectModalExecHandle(thread=t, result_holder=result_holder))
|
||||
|
||||
def _poll_modal_exec(self, handle: _DirectModalExecHandle) -> dict | None:
|
||||
if handle.thread.is_alive():
|
||||
return None
|
||||
if handle.result_holder["error"]:
|
||||
return self._error_result(f"Modal execution error: {handle.result_holder['error']}")
|
||||
return handle.result_holder["value"]
|
||||
|
||||
def _cancel_modal_exec(self, handle: _DirectModalExecHandle) -> None:
|
||||
self._worker.run_coroutine(
|
||||
self._sandbox.terminate.aio(),
|
||||
timeout=15,
|
||||
)
|
||||
return _ThreadedProcessHandle(exec_fn, cancel_fn=cancel)
|
||||
|
||||
def cleanup(self):
|
||||
"""Snapshot the filesystem (if persistent) then stop the sandbox."""
|
||||
@@ -426,17 +358,13 @@ class ModalEnvironment(BaseModalExecutionEnvironment):
|
||||
_store_direct_snapshot(self._task_id, snapshot_id)
|
||||
logger.info(
|
||||
"Modal: saved filesystem snapshot %s for task %s",
|
||||
snapshot_id[:20],
|
||||
self._task_id,
|
||||
snapshot_id[:20], self._task_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Modal: filesystem snapshot failed: %s", e)
|
||||
|
||||
try:
|
||||
self._worker.run_coroutine(
|
||||
self._sandbox.terminate.aio(),
|
||||
timeout=15,
|
||||
)
|
||||
self._worker.run_coroutine(self._sandbox.terminate.aio(), timeout=15)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
|
||||
@@ -56,7 +56,15 @@ def wrap_modal_sudo_pipe(command: str, sudo_stdin: str) -> str:
|
||||
|
||||
|
||||
class BaseModalExecutionEnvironment(BaseEnvironment):
|
||||
"""Common execute() flow for direct and managed Modal transports."""
|
||||
"""Execution flow for the *managed* Modal transport (gateway-owned sandbox).
|
||||
|
||||
This deliberately overrides :meth:`BaseEnvironment.execute` because the
|
||||
tool-gateway handles command preparation, CWD tracking, and env-snapshot
|
||||
management on the server side. The base class's ``_wrap_command`` /
|
||||
``_wait_for_process`` / snapshot machinery does not apply here — the
|
||||
gateway owns that responsibility. See ``ManagedModalEnvironment`` for the
|
||||
concrete subclass.
|
||||
"""
|
||||
|
||||
_stdin_mode = "payload"
|
||||
_poll_interval_seconds = 0.25
|
||||
@@ -124,7 +132,7 @@ class BaseModalExecutionEnvironment(BaseEnvironment):
|
||||
|
||||
def _before_execute(self) -> None:
|
||||
"""Hook for backends that need pre-exec sync or validation."""
|
||||
return None
|
||||
pass
|
||||
|
||||
def _prepare_modal_exec(
|
||||
self,
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user