Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b54591ddda |
@@ -235,7 +235,6 @@ hermes_cli/skin_engine.py # SkinConfig dataclass, built-in skins, YAML loader
|
||||
| Spinner verbs | `spinner.thinking_verbs` | `display.py` |
|
||||
| Spinner wings (optional) | `spinner.wings` | `display.py` |
|
||||
| Tool output prefix | `tool_prefix` | `display.py` |
|
||||
| Per-tool emojis | `tool_emojis` | `display.py` → `get_tool_emoji()` |
|
||||
| Agent name | `branding.agent_name` | `banner.py`, `cli.py` |
|
||||
| Welcome message | `branding.welcome` | `cli.py` |
|
||||
| Response box label | `branding.response_label` | `cli.py` |
|
||||
|
||||
@@ -62,24 +62,6 @@ hermes doctor # Diagnose any issues
|
||||
|
||||
📖 **[Full documentation →](https://hermes-agent.nousresearch.com/docs/)**
|
||||
|
||||
## CLI vs Messaging Quick Reference
|
||||
|
||||
Hermes has two entry points: start the terminal UI with `hermes`, or run the gateway and talk to it from Telegram, Discord, Slack, WhatsApp, Signal, or Email. Once you're in a conversation, many slash commands are shared across both interfaces.
|
||||
|
||||
| Action | CLI | Messaging platforms |
|
||||
|---------|-----|---------------------|
|
||||
| Start chatting | `hermes` | Run `hermes gateway setup` + `hermes gateway start`, then send the bot a message |
|
||||
| Start fresh conversation | `/new` or `/reset` | `/new` or `/reset` |
|
||||
| Change model | `/model [provider:model]` | `/model [provider:model]` |
|
||||
| Set a personality | `/personality [name]` | `/personality [name]` |
|
||||
| Retry or undo the last turn | `/retry`, `/undo` | `/retry`, `/undo` |
|
||||
| Compress context / check usage | `/compress`, `/usage`, `/insights [--days N]` | `/compress`, `/usage`, `/insights [days]` |
|
||||
| Browse skills | `/skills` or `/<skill-name>` | `/skills` or `/<skill-name>` |
|
||||
| Interrupt current work | `Ctrl+C` or send a new message | `/stop` or send a new message |
|
||||
| Platform-specific status | `/platforms` | `/status`, `/sethome` |
|
||||
|
||||
For the full command lists, see the [CLI guide](https://hermes-agent.nousresearch.com/docs/user-guide/cli) and the [Messaging Gateway guide](https://hermes-agent.nousresearch.com/docs/user-guide/messaging).
|
||||
|
||||
---
|
||||
|
||||
## Documentation
|
||||
|
||||
+5
-151
@@ -42,7 +42,7 @@ from acp_adapter.events import (
|
||||
make_tool_progress_cb,
|
||||
)
|
||||
from acp_adapter.permissions import make_approval_callback
|
||||
from acp_adapter.session import SessionManager, SessionState
|
||||
from acp_adapter.session import SessionManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -226,19 +226,10 @@ class HermesACPAgent(acp.Agent):
|
||||
logger.error("prompt: session %s not found", session_id)
|
||||
return PromptResponse(stop_reason="refusal")
|
||||
|
||||
user_text = _extract_text(prompt).strip()
|
||||
if not user_text:
|
||||
user_text = _extract_text(prompt)
|
||||
if not user_text.strip():
|
||||
return PromptResponse(stop_reason="end_turn")
|
||||
|
||||
# Intercept slash commands — handle locally without calling the LLM
|
||||
if user_text.startswith("/"):
|
||||
response_text = self._handle_slash_command(user_text, state)
|
||||
if response_text is not None:
|
||||
if self._conn:
|
||||
update = acp.update_agent_message_text(response_text)
|
||||
await self._conn.session_update(session_id, update)
|
||||
return PromptResponse(stop_reason="end_turn")
|
||||
|
||||
logger.info("Prompt on session %s: %s", session_id, user_text[:100])
|
||||
|
||||
conn = self._conn
|
||||
@@ -324,149 +315,12 @@ class HermesACPAgent(acp.Agent):
|
||||
stop_reason = "cancelled" if state.cancel_event and state.cancel_event.is_set() else "end_turn"
|
||||
return PromptResponse(stop_reason=stop_reason, usage=usage)
|
||||
|
||||
# ---- Slash commands (headless) -------------------------------------------
|
||||
|
||||
_SLASH_COMMANDS = {
|
||||
"help": "Show available commands",
|
||||
"model": "Show or change current model",
|
||||
"tools": "List available tools",
|
||||
"context": "Show conversation context info",
|
||||
"reset": "Clear conversation history",
|
||||
"compact": "Compress conversation context",
|
||||
"version": "Show Hermes version",
|
||||
}
|
||||
|
||||
def _handle_slash_command(self, text: str, state: SessionState) -> str | None:
|
||||
"""Dispatch a slash command and return the response text.
|
||||
|
||||
Returns ``None`` for unrecognized commands so they fall through
|
||||
to the LLM (the user may have typed ``/something`` as prose).
|
||||
"""
|
||||
parts = text.split(maxsplit=1)
|
||||
cmd = parts[0].lstrip("/").lower()
|
||||
args = parts[1].strip() if len(parts) > 1 else ""
|
||||
|
||||
handler = {
|
||||
"help": self._cmd_help,
|
||||
"model": self._cmd_model,
|
||||
"tools": self._cmd_tools,
|
||||
"context": self._cmd_context,
|
||||
"reset": self._cmd_reset,
|
||||
"compact": self._cmd_compact,
|
||||
"version": self._cmd_version,
|
||||
}.get(cmd)
|
||||
|
||||
if handler is None:
|
||||
return None # not a known command — let the LLM handle it
|
||||
|
||||
try:
|
||||
return handler(args, state)
|
||||
except Exception as e:
|
||||
logger.error("Slash command /%s error: %s", cmd, e, exc_info=True)
|
||||
return f"Error executing /{cmd}: {e}"
|
||||
|
||||
def _cmd_help(self, args: str, state: SessionState) -> str:
|
||||
lines = ["Available commands:", ""]
|
||||
for cmd, desc in self._SLASH_COMMANDS.items():
|
||||
lines.append(f" /{cmd:10s} {desc}")
|
||||
lines.append("")
|
||||
lines.append("Unrecognized /commands are sent to the model as normal messages.")
|
||||
return "\n".join(lines)
|
||||
|
||||
def _cmd_model(self, args: str, state: SessionState) -> str:
|
||||
if not args:
|
||||
model = state.model or getattr(state.agent, "model", "unknown")
|
||||
provider = getattr(state.agent, "provider", None) or "auto"
|
||||
return f"Current model: {model}\nProvider: {provider}"
|
||||
|
||||
new_model = args.strip()
|
||||
target_provider = None
|
||||
|
||||
# Auto-detect provider for the requested model
|
||||
try:
|
||||
from hermes_cli.models import parse_model_input, detect_provider_for_model
|
||||
current_provider = getattr(state.agent, "provider", None) or "openrouter"
|
||||
target_provider, new_model = parse_model_input(new_model, current_provider)
|
||||
if target_provider == current_provider:
|
||||
detected = detect_provider_for_model(new_model, current_provider)
|
||||
if detected:
|
||||
target_provider, new_model = detected
|
||||
except Exception:
|
||||
logger.debug("Provider detection failed, using model as-is", exc_info=True)
|
||||
|
||||
state.model = new_model
|
||||
state.agent = self.session_manager._make_agent(
|
||||
session_id=state.session_id,
|
||||
cwd=state.cwd,
|
||||
model=new_model,
|
||||
)
|
||||
provider_label = target_provider or getattr(state.agent, "provider", "auto")
|
||||
logger.info("Session %s: model switched to %s", state.session_id, new_model)
|
||||
return f"Model switched to: {new_model}\nProvider: {provider_label}"
|
||||
|
||||
def _cmd_tools(self, args: str, state: SessionState) -> str:
|
||||
try:
|
||||
from model_tools import get_tool_definitions
|
||||
toolsets = getattr(state.agent, "enabled_toolsets", None) or ["hermes-acp"]
|
||||
tools = get_tool_definitions(enabled_toolsets=toolsets, quiet_mode=True)
|
||||
if not tools:
|
||||
return "No tools available."
|
||||
lines = [f"Available tools ({len(tools)}):"]
|
||||
for t in tools:
|
||||
name = t.get("function", {}).get("name", "?")
|
||||
desc = t.get("function", {}).get("description", "")
|
||||
# Truncate long descriptions
|
||||
if len(desc) > 80:
|
||||
desc = desc[:77] + "..."
|
||||
lines.append(f" {name}: {desc}")
|
||||
return "\n".join(lines)
|
||||
except Exception as e:
|
||||
return f"Could not list tools: {e}"
|
||||
|
||||
def _cmd_context(self, args: str, state: SessionState) -> str:
|
||||
n_messages = len(state.history)
|
||||
if n_messages == 0:
|
||||
return "Conversation is empty (no messages yet)."
|
||||
# Count by role
|
||||
roles: dict[str, int] = {}
|
||||
for msg in state.history:
|
||||
role = msg.get("role", "unknown")
|
||||
roles[role] = roles.get(role, 0) + 1
|
||||
lines = [
|
||||
f"Conversation: {n_messages} messages",
|
||||
f" user: {roles.get('user', 0)}, assistant: {roles.get('assistant', 0)}, "
|
||||
f"tool: {roles.get('tool', 0)}, system: {roles.get('system', 0)}",
|
||||
]
|
||||
model = state.model or getattr(state.agent, "model", "")
|
||||
if model:
|
||||
lines.append(f"Model: {model}")
|
||||
return "\n".join(lines)
|
||||
|
||||
def _cmd_reset(self, args: str, state: SessionState) -> str:
|
||||
state.history.clear()
|
||||
return "Conversation history cleared."
|
||||
|
||||
def _cmd_compact(self, args: str, state: SessionState) -> str:
|
||||
if not state.history:
|
||||
return "Nothing to compress — conversation is empty."
|
||||
try:
|
||||
agent = state.agent
|
||||
if hasattr(agent, "compress_context"):
|
||||
agent.compress_context(state.history)
|
||||
return f"Context compressed. Messages: {len(state.history)}"
|
||||
return "Context compression not available for this agent."
|
||||
except Exception as e:
|
||||
return f"Compression failed: {e}"
|
||||
|
||||
def _cmd_version(self, args: str, state: SessionState) -> str:
|
||||
return f"Hermes Agent v{HERMES_VERSION}"
|
||||
|
||||
# ---- Model switching (ACP protocol method) -------------------------------
|
||||
# ---- Model switching ----------------------------------------------------
|
||||
|
||||
async def set_session_model(
|
||||
self, model_id: str, session_id: str, **kwargs: Any
|
||||
):
|
||||
"""Switch the model for a session (called by ACP protocol)."""
|
||||
"""Switch the model for a session."""
|
||||
state = self.session_manager.get_session(session_id)
|
||||
if state:
|
||||
state.model = model_id
|
||||
|
||||
@@ -45,19 +45,14 @@ _COMMON_BETAS = [
|
||||
"fine-grained-tool-streaming-2025-05-14",
|
||||
]
|
||||
|
||||
# Additional beta headers required for OAuth/subscription auth.
|
||||
# Matches what Claude Code (and pi-ai / OpenCode) send.
|
||||
# Additional beta headers required for OAuth/subscription auth
|
||||
# Both clawdbot and OpenCode include claude-code-20250219 alongside oauth-2025-04-20.
|
||||
# Without claude-code-20250219, Anthropic's API rejects OAuth tokens with 401.
|
||||
_OAUTH_ONLY_BETAS = [
|
||||
"claude-code-20250219",
|
||||
"oauth-2025-04-20",
|
||||
]
|
||||
|
||||
# Claude Code identity — required for OAuth requests to be routed correctly.
|
||||
# Without these, Anthropic's infrastructure intermittently 500s OAuth traffic.
|
||||
_CLAUDE_CODE_VERSION = "2.1.2"
|
||||
_CLAUDE_CODE_SYSTEM_PREFIX = "You are Claude Code, Anthropic's official CLI for Claude."
|
||||
_MCP_TOOL_PREFIX = "mcp_"
|
||||
|
||||
|
||||
def _is_oauth_token(key: str) -> bool:
|
||||
"""Check if the key is an OAuth/setup token (not a regular Console API key).
|
||||
@@ -93,16 +88,10 @@ def build_anthropic_client(api_key: str, base_url: str = None):
|
||||
kwargs["base_url"] = base_url
|
||||
|
||||
if _is_oauth_token(api_key):
|
||||
# OAuth access token / setup-token → Bearer auth + Claude Code identity.
|
||||
# Anthropic routes OAuth requests based on user-agent and headers;
|
||||
# without Claude Code's fingerprint, requests get intermittent 500s.
|
||||
# OAuth access token / setup-token → Bearer auth + beta headers
|
||||
all_betas = _COMMON_BETAS + _OAUTH_ONLY_BETAS
|
||||
kwargs["auth_token"] = api_key
|
||||
kwargs["default_headers"] = {
|
||||
"anthropic-beta": ",".join(all_betas),
|
||||
"user-agent": f"claude-cli/{_CLAUDE_CODE_VERSION} (external, cli)",
|
||||
"x-app": "cli",
|
||||
}
|
||||
kwargs["default_headers"] = {"anthropic-beta": ",".join(all_betas)}
|
||||
else:
|
||||
# Regular API key → x-api-key header + common betas
|
||||
kwargs["api_key"] = api_key
|
||||
@@ -725,59 +714,14 @@ def build_anthropic_kwargs(
|
||||
max_tokens: Optional[int],
|
||||
reasoning_config: Optional[Dict[str, Any]],
|
||||
tool_choice: Optional[str] = None,
|
||||
is_oauth: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build kwargs for anthropic.messages.create().
|
||||
|
||||
When *is_oauth* is True, applies Claude Code compatibility transforms:
|
||||
system prompt prefix, tool name prefixing, and prompt sanitization.
|
||||
"""
|
||||
"""Build kwargs for anthropic.messages.create()."""
|
||||
system, anthropic_messages = convert_messages_to_anthropic(messages)
|
||||
anthropic_tools = convert_tools_to_anthropic(tools) if tools else []
|
||||
|
||||
model = normalize_model_name(model)
|
||||
effective_max_tokens = max_tokens or 16384
|
||||
|
||||
# ── OAuth: Claude Code identity ──────────────────────────────────
|
||||
if is_oauth:
|
||||
# 1. Prepend Claude Code system prompt identity
|
||||
cc_block = {"type": "text", "text": _CLAUDE_CODE_SYSTEM_PREFIX}
|
||||
if isinstance(system, list):
|
||||
system = [cc_block] + system
|
||||
elif isinstance(system, str) and system:
|
||||
system = [cc_block, {"type": "text", "text": system}]
|
||||
else:
|
||||
system = [cc_block]
|
||||
|
||||
# 2. Sanitize system prompt — replace product name references
|
||||
# to avoid Anthropic's server-side content filters.
|
||||
for block in system:
|
||||
if isinstance(block, dict) and block.get("type") == "text":
|
||||
text = block.get("text", "")
|
||||
text = text.replace("Hermes Agent", "Claude Code")
|
||||
text = text.replace("Hermes agent", "Claude Code")
|
||||
text = text.replace("hermes-agent", "claude-code")
|
||||
text = text.replace("Nous Research", "Anthropic")
|
||||
block["text"] = text
|
||||
|
||||
# 3. Prefix tool names with mcp_ (Claude Code convention)
|
||||
if anthropic_tools:
|
||||
for tool in anthropic_tools:
|
||||
if "name" in tool:
|
||||
tool["name"] = _MCP_TOOL_PREFIX + tool["name"]
|
||||
|
||||
# 4. Prefix tool names in message history (tool_use and tool_result blocks)
|
||||
for msg in anthropic_messages:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict):
|
||||
if block.get("type") == "tool_use" and "name" in block:
|
||||
if not block["name"].startswith(_MCP_TOOL_PREFIX):
|
||||
block["name"] = _MCP_TOOL_PREFIX + block["name"]
|
||||
elif block.get("type") == "tool_result" and "tool_use_id" in block:
|
||||
pass # tool_result uses ID, not name
|
||||
|
||||
kwargs: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": anthropic_messages,
|
||||
@@ -824,15 +768,11 @@ def build_anthropic_kwargs(
|
||||
|
||||
def normalize_anthropic_response(
|
||||
response,
|
||||
strip_tool_prefix: bool = False,
|
||||
) -> Tuple[SimpleNamespace, str]:
|
||||
"""Normalize Anthropic response to match the shape expected by AIAgent.
|
||||
|
||||
Returns (assistant_message, finish_reason) where assistant_message has
|
||||
.content, .tool_calls, and .reasoning attributes.
|
||||
|
||||
When *strip_tool_prefix* is True, removes the ``mcp_`` prefix that was
|
||||
added to tool names for OAuth Claude Code compatibility.
|
||||
"""
|
||||
text_parts = []
|
||||
reasoning_parts = []
|
||||
@@ -844,15 +784,12 @@ def normalize_anthropic_response(
|
||||
elif block.type == "thinking":
|
||||
reasoning_parts.append(block.thinking)
|
||||
elif block.type == "tool_use":
|
||||
name = block.name
|
||||
if strip_tool_prefix and name.startswith(_MCP_TOOL_PREFIX):
|
||||
name = name[len(_MCP_TOOL_PREFIX):]
|
||||
tool_calls.append(
|
||||
SimpleNamespace(
|
||||
id=block.id,
|
||||
type="function",
|
||||
function=SimpleNamespace(
|
||||
name=name,
|
||||
name=block.name,
|
||||
arguments=json.dumps(block.input),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -59,32 +59,6 @@ def get_skin_tool_prefix() -> str:
|
||||
return "┊"
|
||||
|
||||
|
||||
def get_tool_emoji(tool_name: str, default: str = "⚡") -> str:
|
||||
"""Get the display emoji for a tool.
|
||||
|
||||
Resolution order:
|
||||
1. Active skin's ``tool_emojis`` overrides (if a skin is loaded)
|
||||
2. Tool registry's per-tool ``emoji`` field
|
||||
3. *default* fallback
|
||||
"""
|
||||
# 1. Skin override
|
||||
skin = _get_skin()
|
||||
if skin and skin.tool_emojis:
|
||||
override = skin.tool_emojis.get(tool_name)
|
||||
if override:
|
||||
return override
|
||||
# 2. Registry default
|
||||
try:
|
||||
from tools.registry import registry
|
||||
emoji = registry.get_emoji(tool_name, default="")
|
||||
if emoji:
|
||||
return emoji
|
||||
except Exception:
|
||||
pass
|
||||
# 3. Hardcoded fallback
|
||||
return default
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tool preview (one-line summary of a tool call's primary argument)
|
||||
# =========================================================================
|
||||
|
||||
+106
-7
@@ -20,16 +20,65 @@ import json
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.usage_pricing import DEFAULT_PRICING, estimate_cost_usd, format_duration_compact, get_pricing, has_known_pricing
|
||||
# =========================================================================
|
||||
# Model pricing (USD per million tokens) — approximate as of early 2026
|
||||
# =========================================================================
|
||||
MODEL_PRICING = {
|
||||
# OpenAI
|
||||
"gpt-4o": {"input": 2.50, "output": 10.00},
|
||||
"gpt-4o-mini": {"input": 0.15, "output": 0.60},
|
||||
"gpt-4.1": {"input": 2.00, "output": 8.00},
|
||||
"gpt-4.1-mini": {"input": 0.40, "output": 1.60},
|
||||
"gpt-4.1-nano": {"input": 0.10, "output": 0.40},
|
||||
"gpt-4.5-preview": {"input": 75.00, "output": 150.00},
|
||||
"gpt-5": {"input": 10.00, "output": 30.00},
|
||||
"gpt-5.4": {"input": 10.00, "output": 30.00},
|
||||
"o3": {"input": 10.00, "output": 40.00},
|
||||
"o3-mini": {"input": 1.10, "output": 4.40},
|
||||
"o4-mini": {"input": 1.10, "output": 4.40},
|
||||
# Anthropic
|
||||
"claude-opus-4-20250514": {"input": 15.00, "output": 75.00},
|
||||
"claude-sonnet-4-20250514": {"input": 3.00, "output": 15.00},
|
||||
"claude-3-5-sonnet-20241022": {"input": 3.00, "output": 15.00},
|
||||
"claude-3-5-haiku-20241022": {"input": 0.80, "output": 4.00},
|
||||
"claude-3-opus-20240229": {"input": 15.00, "output": 75.00},
|
||||
"claude-3-haiku-20240307": {"input": 0.25, "output": 1.25},
|
||||
# DeepSeek
|
||||
"deepseek-chat": {"input": 0.14, "output": 0.28},
|
||||
"deepseek-reasoner": {"input": 0.55, "output": 2.19},
|
||||
# Google
|
||||
"gemini-2.5-pro": {"input": 1.25, "output": 10.00},
|
||||
"gemini-2.5-flash": {"input": 0.15, "output": 0.60},
|
||||
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
|
||||
# Meta (via providers)
|
||||
"llama-4-maverick": {"input": 0.50, "output": 0.70},
|
||||
"llama-4-scout": {"input": 0.20, "output": 0.30},
|
||||
# Z.AI / GLM (direct provider — pricing not published externally, treat as local)
|
||||
"glm-5": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.7": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.5": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.5-flash": {"input": 0.0, "output": 0.0},
|
||||
# Kimi / Moonshot (direct provider — pricing not published externally, treat as local)
|
||||
"kimi-k2.5": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-thinking": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-turbo-preview": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-0905-preview": {"input": 0.0, "output": 0.0},
|
||||
# MiniMax (direct provider — pricing not published externally, treat as local)
|
||||
"MiniMax-M2.5": {"input": 0.0, "output": 0.0},
|
||||
"MiniMax-M2.5-highspeed": {"input": 0.0, "output": 0.0},
|
||||
"MiniMax-M2.1": {"input": 0.0, "output": 0.0},
|
||||
}
|
||||
|
||||
_DEFAULT_PRICING = DEFAULT_PRICING
|
||||
# Fallback: unknown/custom models get zero cost (we can't assume pricing
|
||||
# for self-hosted models, custom OAI endpoints, local inference, etc.)
|
||||
_DEFAULT_PRICING = {"input": 0.0, "output": 0.0}
|
||||
|
||||
|
||||
def _has_known_pricing(model_name: str) -> bool:
|
||||
"""Check if a model has known pricing (vs unknown/custom endpoint)."""
|
||||
return has_known_pricing(model_name)
|
||||
return _get_pricing(model_name) is not _DEFAULT_PRICING
|
||||
|
||||
|
||||
def _get_pricing(model_name: str) -> Dict[str, float]:
|
||||
@@ -38,17 +87,67 @@ def _get_pricing(model_name: str) -> Dict[str, float]:
|
||||
Returns _DEFAULT_PRICING (zero cost) for unknown/custom models —
|
||||
we can't assume costs for self-hosted endpoints, local inference, etc.
|
||||
"""
|
||||
return get_pricing(model_name)
|
||||
if not model_name:
|
||||
return _DEFAULT_PRICING
|
||||
|
||||
# Strip provider prefix (e.g., "anthropic/claude-..." -> "claude-...")
|
||||
bare = model_name.split("/")[-1].lower()
|
||||
|
||||
# Exact match first
|
||||
if bare in MODEL_PRICING:
|
||||
return MODEL_PRICING[bare]
|
||||
|
||||
# Fuzzy prefix match — prefer the LONGEST matching key to avoid
|
||||
# e.g. "gpt-4o" matching before "gpt-4o-mini" for "gpt-4o-mini-2024-07-18"
|
||||
best_match = None
|
||||
best_len = 0
|
||||
for key, price in MODEL_PRICING.items():
|
||||
if bare.startswith(key) and len(key) > best_len:
|
||||
best_match = price
|
||||
best_len = len(key)
|
||||
if best_match:
|
||||
return best_match
|
||||
|
||||
# Keyword heuristics (checked in most-specific-first order)
|
||||
if "opus" in bare:
|
||||
return {"input": 15.00, "output": 75.00}
|
||||
if "sonnet" in bare:
|
||||
return {"input": 3.00, "output": 15.00}
|
||||
if "haiku" in bare:
|
||||
return {"input": 0.80, "output": 4.00}
|
||||
if "gpt-4o-mini" in bare:
|
||||
return {"input": 0.15, "output": 0.60}
|
||||
if "gpt-4o" in bare:
|
||||
return {"input": 2.50, "output": 10.00}
|
||||
if "gpt-5" in bare:
|
||||
return {"input": 10.00, "output": 30.00}
|
||||
if "deepseek" in bare:
|
||||
return {"input": 0.14, "output": 0.28}
|
||||
if "gemini" in bare:
|
||||
return {"input": 0.15, "output": 0.60}
|
||||
|
||||
return _DEFAULT_PRICING
|
||||
|
||||
|
||||
def _estimate_cost(model: str, input_tokens: int, output_tokens: int) -> float:
|
||||
"""Estimate the USD cost for a given model and token counts."""
|
||||
return estimate_cost_usd(model, input_tokens, output_tokens)
|
||||
pricing = _get_pricing(model)
|
||||
return (input_tokens * pricing["input"] + output_tokens * pricing["output"]) / 1_000_000
|
||||
|
||||
|
||||
def _format_duration(seconds: float) -> str:
|
||||
"""Format seconds into a human-readable duration string."""
|
||||
return format_duration_compact(seconds)
|
||||
if seconds < 60:
|
||||
return f"{seconds:.0f}s"
|
||||
minutes = seconds / 60
|
||||
if minutes < 60:
|
||||
return f"{minutes:.0f}m"
|
||||
hours = minutes / 60
|
||||
if hours < 24:
|
||||
remaining_min = int(minutes % 60)
|
||||
return f"{int(hours)}h {remaining_min}m" if remaining_min else f"{int(hours)}h"
|
||||
days = hours / 24
|
||||
return f"{days:.1f}d"
|
||||
|
||||
|
||||
def _bar_chart(values: List[int], max_width: int = 20) -> List[str]:
|
||||
|
||||
+5
-17
@@ -73,15 +73,9 @@ DEFAULT_AGENT_IDENTITY = (
|
||||
MEMORY_GUIDANCE = (
|
||||
"You have persistent memory across sessions. Save durable facts using the memory "
|
||||
"tool: user preferences, environment details, tool quirks, and stable conventions. "
|
||||
"Memory is injected into every turn, so keep it compact and focused on facts that "
|
||||
"will still matter later.\n"
|
||||
"Prioritize what reduces future user steering — the most valuable memory is one "
|
||||
"that prevents the user from having to correct or remind you again. "
|
||||
"User preferences and recurring corrections matter more than procedural task details.\n"
|
||||
"Do NOT save task progress, session outcomes, completed-work logs, or temporary TODO "
|
||||
"state to memory; use session_search to recall those from past transcripts. "
|
||||
"If you've discovered a new way to do something, solved a problem that could be "
|
||||
"necessary later, save it as a skill with the skill tool."
|
||||
"Memory is injected into every turn, so keep it compact. Do NOT save task progress, "
|
||||
"session outcomes, or completed-work logs to memory; use session_search to recall "
|
||||
"those from past transcripts."
|
||||
)
|
||||
|
||||
SESSION_SEARCH_GUIDANCE = (
|
||||
@@ -92,11 +86,8 @@ SESSION_SEARCH_GUIDANCE = (
|
||||
|
||||
SKILLS_GUIDANCE = (
|
||||
"After completing a complex task (5+ tool calls), fixing a tricky error, "
|
||||
"or discovering a non-trivial workflow, save the approach as a "
|
||||
"skill with skill_manage so you can reuse it next time.\n"
|
||||
"When using a skill and finding it outdated, incomplete, or wrong, "
|
||||
"patch it immediately with skill_manage(action='patch') — don't wait to be asked. "
|
||||
"Skills that aren't maintained become liabilities."
|
||||
"or discovering a non-trivial workflow, consider saving the approach as a "
|
||||
"skill with skill_manage so you can reuse it next time."
|
||||
)
|
||||
|
||||
PLATFORM_HINTS = {
|
||||
@@ -335,9 +326,6 @@ def build_skills_system_prompt(
|
||||
"Before replying, scan the skills below. If one clearly matches your task, "
|
||||
"load it with skill_view(name) and follow its instructions. "
|
||||
"If a skill has issues, fix it with skill_manage(action='patch').\n"
|
||||
"After difficult/iterative tasks, offer to save as a skill. "
|
||||
"If a skill you loaded was missing steps, had wrong commands, or needed "
|
||||
"pitfalls you discovered, update it before finishing.\n"
|
||||
"\n"
|
||||
"<available_skills>\n"
|
||||
+ "\n".join(index_lines) + "\n"
|
||||
|
||||
@@ -1,184 +0,0 @@
|
||||
"""Helpers for optional cheap-vs-strong model routing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
_COMPLEX_KEYWORDS = {
|
||||
"debug",
|
||||
"debugging",
|
||||
"implement",
|
||||
"implementation",
|
||||
"refactor",
|
||||
"patch",
|
||||
"traceback",
|
||||
"stacktrace",
|
||||
"exception",
|
||||
"error",
|
||||
"analyze",
|
||||
"analysis",
|
||||
"investigate",
|
||||
"architecture",
|
||||
"design",
|
||||
"compare",
|
||||
"benchmark",
|
||||
"optimize",
|
||||
"optimise",
|
||||
"review",
|
||||
"terminal",
|
||||
"shell",
|
||||
"tool",
|
||||
"tools",
|
||||
"pytest",
|
||||
"test",
|
||||
"tests",
|
||||
"plan",
|
||||
"planning",
|
||||
"delegate",
|
||||
"subagent",
|
||||
"cron",
|
||||
"docker",
|
||||
"kubernetes",
|
||||
}
|
||||
|
||||
_URL_RE = re.compile(r"https?://|www\.", re.IGNORECASE)
|
||||
|
||||
|
||||
def _coerce_bool(value: Any, default: bool = False) -> bool:
|
||||
if value is None:
|
||||
return default
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
return bool(value)
|
||||
|
||||
|
||||
def _coerce_int(value: Any, default: int) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
|
||||
def choose_cheap_model_route(user_message: str, routing_config: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
"""Return the configured cheap-model route when a message looks simple.
|
||||
|
||||
Conservative by design: if the message has signs of code/tool/debugging/
|
||||
long-form work, keep the primary model.
|
||||
"""
|
||||
cfg = routing_config or {}
|
||||
if not _coerce_bool(cfg.get("enabled"), False):
|
||||
return None
|
||||
|
||||
cheap_model = cfg.get("cheap_model") or {}
|
||||
if not isinstance(cheap_model, dict):
|
||||
return None
|
||||
provider = str(cheap_model.get("provider") or "").strip().lower()
|
||||
model = str(cheap_model.get("model") or "").strip()
|
||||
if not provider or not model:
|
||||
return None
|
||||
|
||||
text = (user_message or "").strip()
|
||||
if not text:
|
||||
return None
|
||||
|
||||
max_chars = _coerce_int(cfg.get("max_simple_chars"), 160)
|
||||
max_words = _coerce_int(cfg.get("max_simple_words"), 28)
|
||||
|
||||
if len(text) > max_chars:
|
||||
return None
|
||||
if len(text.split()) > max_words:
|
||||
return None
|
||||
if text.count("\n") > 1:
|
||||
return None
|
||||
if "```" in text or "`" in text:
|
||||
return None
|
||||
if _URL_RE.search(text):
|
||||
return None
|
||||
|
||||
lowered = text.lower()
|
||||
words = {token.strip(".,:;!?()[]{}\"'`") for token in lowered.split()}
|
||||
if words & _COMPLEX_KEYWORDS:
|
||||
return None
|
||||
|
||||
route = dict(cheap_model)
|
||||
route["provider"] = provider
|
||||
route["model"] = model
|
||||
route["routing_reason"] = "simple_turn"
|
||||
return route
|
||||
|
||||
|
||||
def resolve_turn_route(user_message: str, routing_config: Optional[Dict[str, Any]], primary: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Resolve the effective model/runtime for one turn.
|
||||
|
||||
Returns a dict with model/runtime/signature/label fields.
|
||||
"""
|
||||
route = choose_cheap_model_route(user_message, routing_config)
|
||||
if not route:
|
||||
return {
|
||||
"model": primary.get("model"),
|
||||
"runtime": {
|
||||
"api_key": primary.get("api_key"),
|
||||
"base_url": primary.get("base_url"),
|
||||
"provider": primary.get("provider"),
|
||||
"api_mode": primary.get("api_mode"),
|
||||
},
|
||||
"label": None,
|
||||
"signature": (
|
||||
primary.get("model"),
|
||||
primary.get("provider"),
|
||||
primary.get("base_url"),
|
||||
primary.get("api_mode"),
|
||||
),
|
||||
}
|
||||
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
|
||||
explicit_api_key = None
|
||||
api_key_env = str(route.get("api_key_env") or "").strip()
|
||||
if api_key_env:
|
||||
explicit_api_key = os.getenv(api_key_env) or None
|
||||
|
||||
try:
|
||||
runtime = resolve_runtime_provider(
|
||||
requested=route.get("provider"),
|
||||
explicit_api_key=explicit_api_key,
|
||||
explicit_base_url=route.get("base_url"),
|
||||
)
|
||||
except Exception:
|
||||
return {
|
||||
"model": primary.get("model"),
|
||||
"runtime": {
|
||||
"api_key": primary.get("api_key"),
|
||||
"base_url": primary.get("base_url"),
|
||||
"provider": primary.get("provider"),
|
||||
"api_mode": primary.get("api_mode"),
|
||||
},
|
||||
"label": None,
|
||||
"signature": (
|
||||
primary.get("model"),
|
||||
primary.get("provider"),
|
||||
primary.get("base_url"),
|
||||
primary.get("api_mode"),
|
||||
),
|
||||
}
|
||||
|
||||
return {
|
||||
"model": route.get("model"),
|
||||
"runtime": {
|
||||
"api_key": runtime.get("api_key"),
|
||||
"base_url": runtime.get("base_url"),
|
||||
"provider": runtime.get("provider"),
|
||||
"api_mode": runtime.get("api_mode"),
|
||||
},
|
||||
"label": f"smart route → {route.get('model')} ({runtime.get('provider')})",
|
||||
"signature": (
|
||||
route.get("model"),
|
||||
runtime.get("provider"),
|
||||
runtime.get("base_url"),
|
||||
runtime.get("api_mode"),
|
||||
),
|
||||
}
|
||||
@@ -1,134 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Dict
|
||||
|
||||
|
||||
MODEL_PRICING = {
|
||||
"gpt-4o": {"input": 2.50, "output": 10.00},
|
||||
"gpt-4o-mini": {"input": 0.15, "output": 0.60},
|
||||
"gpt-4.1": {"input": 2.00, "output": 8.00},
|
||||
"gpt-4.1-mini": {"input": 0.40, "output": 1.60},
|
||||
"gpt-4.1-nano": {"input": 0.10, "output": 0.40},
|
||||
"gpt-4.5-preview": {"input": 75.00, "output": 150.00},
|
||||
"gpt-5": {"input": 10.00, "output": 30.00},
|
||||
"gpt-5.4": {"input": 10.00, "output": 30.00},
|
||||
"o3": {"input": 10.00, "output": 40.00},
|
||||
"o3-mini": {"input": 1.10, "output": 4.40},
|
||||
"o4-mini": {"input": 1.10, "output": 4.40},
|
||||
"claude-opus-4-20250514": {"input": 15.00, "output": 75.00},
|
||||
"claude-sonnet-4-20250514": {"input": 3.00, "output": 15.00},
|
||||
"claude-3-5-sonnet-20241022": {"input": 3.00, "output": 15.00},
|
||||
"claude-3-5-haiku-20241022": {"input": 0.80, "output": 4.00},
|
||||
"claude-3-opus-20240229": {"input": 15.00, "output": 75.00},
|
||||
"claude-3-haiku-20240307": {"input": 0.25, "output": 1.25},
|
||||
"deepseek-chat": {"input": 0.14, "output": 0.28},
|
||||
"deepseek-reasoner": {"input": 0.55, "output": 2.19},
|
||||
"gemini-2.5-pro": {"input": 1.25, "output": 10.00},
|
||||
"gemini-2.5-flash": {"input": 0.15, "output": 0.60},
|
||||
"gemini-2.0-flash": {"input": 0.10, "output": 0.40},
|
||||
"llama-4-maverick": {"input": 0.50, "output": 0.70},
|
||||
"llama-4-scout": {"input": 0.20, "output": 0.30},
|
||||
"glm-5": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.7": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.5": {"input": 0.0, "output": 0.0},
|
||||
"glm-4.5-flash": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2.5": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-thinking": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-turbo-preview": {"input": 0.0, "output": 0.0},
|
||||
"kimi-k2-0905-preview": {"input": 0.0, "output": 0.0},
|
||||
"MiniMax-M2.5": {"input": 0.0, "output": 0.0},
|
||||
"MiniMax-M2.5-highspeed": {"input": 0.0, "output": 0.0},
|
||||
"MiniMax-M2.1": {"input": 0.0, "output": 0.0},
|
||||
}
|
||||
|
||||
DEFAULT_PRICING = {"input": 0.0, "output": 0.0}
|
||||
|
||||
|
||||
def get_pricing(model_name: str) -> Dict[str, float]:
|
||||
if not model_name:
|
||||
return DEFAULT_PRICING
|
||||
|
||||
bare = model_name.split("/")[-1].lower()
|
||||
if bare in MODEL_PRICING:
|
||||
return MODEL_PRICING[bare]
|
||||
|
||||
best_match = None
|
||||
best_len = 0
|
||||
for key, price in MODEL_PRICING.items():
|
||||
if bare.startswith(key) and len(key) > best_len:
|
||||
best_match = price
|
||||
best_len = len(key)
|
||||
if best_match:
|
||||
return best_match
|
||||
|
||||
if "opus" in bare:
|
||||
return {"input": 15.00, "output": 75.00}
|
||||
if "sonnet" in bare:
|
||||
return {"input": 3.00, "output": 15.00}
|
||||
if "haiku" in bare:
|
||||
return {"input": 0.80, "output": 4.00}
|
||||
if "gpt-4o-mini" in bare:
|
||||
return {"input": 0.15, "output": 0.60}
|
||||
if "gpt-4o" in bare:
|
||||
return {"input": 2.50, "output": 10.00}
|
||||
if "gpt-5" in bare:
|
||||
return {"input": 10.00, "output": 30.00}
|
||||
if "deepseek" in bare:
|
||||
return {"input": 0.14, "output": 0.28}
|
||||
if "gemini" in bare:
|
||||
return {"input": 0.15, "output": 0.60}
|
||||
|
||||
return DEFAULT_PRICING
|
||||
|
||||
|
||||
def has_known_pricing(model_name: str) -> bool:
|
||||
pricing = get_pricing(model_name)
|
||||
return pricing is not DEFAULT_PRICING and any(
|
||||
float(value) > 0 for value in pricing.values()
|
||||
)
|
||||
|
||||
|
||||
def estimate_cost_usd(model: str, input_tokens: int, output_tokens: int) -> float:
|
||||
pricing = get_pricing(model)
|
||||
total = (
|
||||
Decimal(input_tokens) * Decimal(str(pricing["input"]))
|
||||
+ Decimal(output_tokens) * Decimal(str(pricing["output"]))
|
||||
) / Decimal("1000000")
|
||||
return float(total)
|
||||
|
||||
|
||||
def format_duration_compact(seconds: float) -> str:
|
||||
if seconds < 60:
|
||||
return f"{seconds:.0f}s"
|
||||
minutes = seconds / 60
|
||||
if minutes < 60:
|
||||
return f"{minutes:.0f}m"
|
||||
hours = minutes / 60
|
||||
if hours < 24:
|
||||
remaining_min = int(minutes % 60)
|
||||
return f"{int(hours)}h {remaining_min}m" if remaining_min else f"{int(hours)}h"
|
||||
days = hours / 24
|
||||
return f"{days:.1f}d"
|
||||
|
||||
|
||||
def format_token_count_compact(value: int) -> str:
|
||||
abs_value = abs(int(value))
|
||||
if abs_value < 1_000:
|
||||
return str(int(value))
|
||||
|
||||
sign = "-" if value < 0 else ""
|
||||
units = ((1_000_000_000, "B"), (1_000_000, "M"), (1_000, "K"))
|
||||
for threshold, suffix in units:
|
||||
if abs_value >= threshold:
|
||||
scaled = abs_value / threshold
|
||||
if scaled < 10:
|
||||
text = f"{scaled:.2f}"
|
||||
elif scaled < 100:
|
||||
text = f"{scaled:.1f}"
|
||||
else:
|
||||
text = f"{scaled:.0f}"
|
||||
text = text.rstrip("0").rstrip(".")
|
||||
return f"{sign}{text}{suffix}"
|
||||
|
||||
return f"{value:,}"
|
||||
+7
-53
@@ -51,20 +51,6 @@ model:
|
||||
# # Data policy: "allow" (default) or "deny" to exclude providers that may store data
|
||||
# # data_collection: "deny"
|
||||
|
||||
# =============================================================================
|
||||
# Smart Model Routing (optional)
|
||||
# =============================================================================
|
||||
# Use a cheaper model for short/simple turns while keeping your main model for
|
||||
# more complex requests. Disabled by default.
|
||||
#
|
||||
# smart_model_routing:
|
||||
# enabled: true
|
||||
# max_simple_chars: 160
|
||||
# max_simple_words: 28
|
||||
# cheap_model:
|
||||
# provider: openrouter
|
||||
# model: google/gemini-2.5-flash
|
||||
|
||||
# =============================================================================
|
||||
# Git Worktree Isolation
|
||||
# =============================================================================
|
||||
@@ -90,9 +76,8 @@ model:
|
||||
# - Messaging (Telegram/Discord): Uses MESSAGING_CWD from .env (default: home)
|
||||
terminal:
|
||||
backend: "local"
|
||||
cwd: "." # For local backend: "." = current directory. Ignored for remote backends unless a backend documents otherwise.
|
||||
cwd: "." # For local backend: "." = current directory. Ignored for remote backends.
|
||||
timeout: 180
|
||||
docker_mount_cwd_to_workspace: false # SECURITY: off by default. Opt in to mount the launch cwd into Docker /workspace.
|
||||
lifetime_seconds: 300
|
||||
# sudo_password: "" # Enable sudo commands (pipes via sudo -S) - SECURITY WARNING: plaintext!
|
||||
|
||||
@@ -122,7 +107,12 @@ terminal:
|
||||
# timeout: 180
|
||||
# lifetime_seconds: 300
|
||||
# docker_image: "nikolaik/python-nodejs:python3.11-nodejs20"
|
||||
# docker_mount_cwd_to_workspace: true # Explicit opt-in: mount your launch cwd into /workspace
|
||||
# # Optional: explicitly forward selected env vars into Docker.
|
||||
# # These values come from your current shell first, then ~/.hermes/.env.
|
||||
# # Warning: anything forwarded here is visible to commands run in the container.
|
||||
# docker_forward_env:
|
||||
# - "GITHUB_TOKEN"
|
||||
# - "NPM_TOKEN"
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# OPTION 4: Singularity/Apptainer container
|
||||
@@ -349,25 +339,6 @@ session_reset:
|
||||
idle_minutes: 1440 # Inactivity timeout in minutes (default: 1440 = 24 hours)
|
||||
at_hour: 4 # Daily reset hour, 0-23 local time (default: 4 AM)
|
||||
|
||||
# When true, group/channel chats use one session per participant when the platform
|
||||
# provides a user ID. This is the secure default and prevents users in the same
|
||||
# room from sharing context, interrupts, and token costs. Set false only if you
|
||||
# explicitly want one shared "room brain" per group/channel.
|
||||
group_sessions_per_user: true
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Gateway Streaming
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Stream tokens to messaging platforms in real-time. The bot sends a message
|
||||
# on first token, then progressively edits it as more tokens arrive.
|
||||
# Disabled by default — enable to try the streaming UX on Telegram/Discord/Slack.
|
||||
streaming:
|
||||
enabled: false
|
||||
# transport: edit # "edit" = progressive editMessageText
|
||||
# edit_interval: 0.3 # seconds between message edits
|
||||
# buffer_threshold: 40 # chars before forcing an edit flush
|
||||
# cursor: " ▉" # cursor shown during streaming
|
||||
|
||||
# =============================================================================
|
||||
# Skills Configuration
|
||||
# =============================================================================
|
||||
@@ -729,12 +700,6 @@ display:
|
||||
# Toggle at runtime with /reasoning show or /reasoning hide.
|
||||
show_reasoning: false
|
||||
|
||||
# Stream tokens to the terminal as they arrive instead of waiting for the
|
||||
# full response. The response box opens on first token and text appears
|
||||
# line-by-line. Tool calls are still captured silently.
|
||||
# Disabled by default — enable to try the streaming UX.
|
||||
streaming: false
|
||||
|
||||
# ───────────────────────────────────────────────────────────────────────────
|
||||
# Skin / Theme
|
||||
# ───────────────────────────────────────────────────────────────────────────
|
||||
@@ -775,14 +740,3 @@ display:
|
||||
# tool_prefix: "╎" # Tool output line prefix (default: ┊)
|
||||
#
|
||||
skin: default
|
||||
|
||||
# =============================================================================
|
||||
# Privacy
|
||||
# =============================================================================
|
||||
# privacy:
|
||||
# # Redact PII from the LLM context prompt.
|
||||
# # When true, phone numbers are stripped and user/chat IDs are replaced
|
||||
# # with deterministic hashes before being sent to the model.
|
||||
# # Names and usernames are NOT affected (user-chosen, publicly visible).
|
||||
# # Routing/delivery still uses the original values internally.
|
||||
# redact_pii: false
|
||||
|
||||
+5
-19
@@ -315,7 +315,6 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
|
||||
# Provider routing
|
||||
pr = _cfg.get("provider_routing", {})
|
||||
smart_routing = _cfg.get("smart_model_routing", {}) or {}
|
||||
|
||||
from hermes_cli.runtime_provider import (
|
||||
resolve_runtime_provider,
|
||||
@@ -332,25 +331,12 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
message = format_runtime_provider_error(exc)
|
||||
raise RuntimeError(message) from exc
|
||||
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
turn_route = resolve_turn_route(
|
||||
prompt,
|
||||
smart_routing,
|
||||
{
|
||||
"model": model,
|
||||
"api_key": runtime.get("api_key"),
|
||||
"base_url": runtime.get("base_url"),
|
||||
"provider": runtime.get("provider"),
|
||||
"api_mode": runtime.get("api_mode"),
|
||||
},
|
||||
)
|
||||
|
||||
agent = AIAgent(
|
||||
model=turn_route["model"],
|
||||
api_key=turn_route["runtime"].get("api_key"),
|
||||
base_url=turn_route["runtime"].get("base_url"),
|
||||
provider=turn_route["runtime"].get("provider"),
|
||||
api_mode=turn_route["runtime"].get("api_mode"),
|
||||
model=model,
|
||||
api_key=runtime.get("api_key"),
|
||||
base_url=runtime.get("base_url"),
|
||||
provider=runtime.get("provider"),
|
||||
api_mode=runtime.get("api_mode"),
|
||||
max_iterations=max_iterations,
|
||||
reasoning_config=reasoning_config,
|
||||
prefill_messages=prefill_messages,
|
||||
|
||||
+2
-54
@@ -97,11 +97,10 @@ class SessionResetPolicy:
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SessionResetPolicy":
|
||||
# Handle both missing keys and explicit null values (YAML null → None)
|
||||
mode = data.get("mode")
|
||||
at_hour = data.get("at_hour")
|
||||
idle_minutes = data.get("idle_minutes")
|
||||
return cls(
|
||||
mode=mode if mode is not None else "both",
|
||||
mode=data.get("mode", "both"),
|
||||
at_hour=at_hour if at_hour is not None else 4,
|
||||
idle_minutes=idle_minutes if idle_minutes is not None else 1440,
|
||||
)
|
||||
@@ -146,37 +145,6 @@ class PlatformConfig:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamingConfig:
|
||||
"""Configuration for real-time token streaming to messaging platforms."""
|
||||
enabled: bool = False
|
||||
transport: str = "edit" # "edit" (progressive editMessageText) or "off"
|
||||
edit_interval: float = 0.3 # Seconds between message edits
|
||||
buffer_threshold: int = 40 # Chars before forcing an edit
|
||||
cursor: str = " ▉" # Cursor shown during streaming
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"enabled": self.enabled,
|
||||
"transport": self.transport,
|
||||
"edit_interval": self.edit_interval,
|
||||
"buffer_threshold": self.buffer_threshold,
|
||||
"cursor": self.cursor,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "StreamingConfig":
|
||||
if not data:
|
||||
return cls()
|
||||
return cls(
|
||||
enabled=data.get("enabled", False),
|
||||
transport=data.get("transport", "edit"),
|
||||
edit_interval=float(data.get("edit_interval", 0.3)),
|
||||
buffer_threshold=int(data.get("buffer_threshold", 40)),
|
||||
cursor=data.get("cursor", " ▉"),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GatewayConfig:
|
||||
"""
|
||||
@@ -206,13 +174,7 @@ class GatewayConfig:
|
||||
|
||||
# STT settings
|
||||
stt_enabled: bool = True # Whether to auto-transcribe inbound voice messages
|
||||
|
||||
# Session isolation in shared chats
|
||||
group_sessions_per_user: bool = True # Isolate group/channel sessions per participant when user IDs are available
|
||||
|
||||
# Streaming configuration
|
||||
streaming: StreamingConfig = field(default_factory=StreamingConfig)
|
||||
|
||||
|
||||
def get_connected_platforms(self) -> List[Platform]:
|
||||
"""Return list of platforms that are enabled and configured."""
|
||||
connected = []
|
||||
@@ -277,8 +239,6 @@ class GatewayConfig:
|
||||
"sessions_dir": str(self.sessions_dir),
|
||||
"always_log_local": self.always_log_local,
|
||||
"stt_enabled": self.stt_enabled,
|
||||
"group_sessions_per_user": self.group_sessions_per_user,
|
||||
"streaming": self.streaming.to_dict(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -319,8 +279,6 @@ class GatewayConfig:
|
||||
if stt_enabled is None:
|
||||
stt_enabled = data.get("stt", {}).get("enabled") if isinstance(data.get("stt"), dict) else None
|
||||
|
||||
group_sessions_per_user = data.get("group_sessions_per_user")
|
||||
|
||||
return cls(
|
||||
platforms=platforms,
|
||||
default_reset_policy=default_policy,
|
||||
@@ -331,8 +289,6 @@ class GatewayConfig:
|
||||
sessions_dir=sessions_dir,
|
||||
always_log_local=data.get("always_log_local", True),
|
||||
stt_enabled=_coerce_bool(stt_enabled, True),
|
||||
group_sessions_per_user=_coerce_bool(group_sessions_per_user, True),
|
||||
streaming=StreamingConfig.from_dict(data.get("streaming", {})),
|
||||
)
|
||||
|
||||
|
||||
@@ -388,14 +344,6 @@ def load_gateway_config() -> GatewayConfig:
|
||||
if isinstance(stt_cfg, dict) and "enabled" in stt_cfg:
|
||||
config.stt_enabled = _coerce_bool(stt_cfg.get("enabled"), True)
|
||||
|
||||
# Bridge group session isolation from config.yaml into gateway runtime.
|
||||
# Secure default is per-user isolation in shared chats.
|
||||
if "group_sessions_per_user" in yaml_cfg:
|
||||
config.group_sessions_per_user = _coerce_bool(
|
||||
yaml_cfg.get("group_sessions_per_user"),
|
||||
True,
|
||||
)
|
||||
|
||||
# Bridge discord settings from config.yaml to env vars
|
||||
# (env vars take precedence — only set if not already defined)
|
||||
discord_cfg = yaml_cfg.get("discord", {})
|
||||
|
||||
@@ -752,10 +752,7 @@ class BasePlatformAdapter(ABC):
|
||||
if not self._message_handler:
|
||||
return
|
||||
|
||||
session_key = build_session_key(
|
||||
event.source,
|
||||
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
|
||||
)
|
||||
session_key = build_session_key(event.source)
|
||||
|
||||
# Check if there's already an active handler for this session
|
||||
if session_key in self._active_sessions:
|
||||
|
||||
@@ -135,23 +135,14 @@ def _extract_email_address(raw: str) -> str:
|
||||
return raw.strip().lower()
|
||||
|
||||
|
||||
def _extract_attachments(
|
||||
msg: email_lib.message.Message,
|
||||
skip_attachments: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Extract attachment metadata and cache files locally.
|
||||
|
||||
When *skip_attachments* is True, all attachment/inline parts are ignored
|
||||
(useful for malware protection or bandwidth savings).
|
||||
"""
|
||||
def _extract_attachments(msg: email_lib.message.Message) -> List[Dict[str, Any]]:
|
||||
"""Extract attachment metadata and cache files locally."""
|
||||
attachments = []
|
||||
if not msg.is_multipart():
|
||||
return attachments
|
||||
|
||||
for part in msg.walk():
|
||||
disposition = str(part.get("Content-Disposition", ""))
|
||||
if skip_attachments and ("attachment" in disposition or "inline" in disposition):
|
||||
continue
|
||||
if "attachment" not in disposition and "inline" not in disposition:
|
||||
continue
|
||||
# Skip text/plain and text/html body parts
|
||||
@@ -205,13 +196,6 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
self._smtp_port = int(os.getenv("EMAIL_SMTP_PORT", "587"))
|
||||
self._poll_interval = int(os.getenv("EMAIL_POLL_INTERVAL", "15"))
|
||||
|
||||
# Skip attachments — configured via config.yaml:
|
||||
# platforms:
|
||||
# email:
|
||||
# skip_attachments: true
|
||||
extra = config.extra or {}
|
||||
self._skip_attachments = extra.get("skip_attachments", False)
|
||||
|
||||
# Track message IDs we've already processed to avoid duplicates
|
||||
self._seen_uids: set = set()
|
||||
self._poll_task: Optional[asyncio.Task] = None
|
||||
@@ -322,7 +306,7 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
message_id = msg.get("Message-ID", "")
|
||||
in_reply_to = msg.get("In-Reply-To", "")
|
||||
body = _extract_text_body(msg)
|
||||
attachments = _extract_attachments(msg, skip_attachments=self._skip_attachments)
|
||||
attachments = _extract_attachments(msg)
|
||||
|
||||
results.append({
|
||||
"uid": uid,
|
||||
|
||||
@@ -202,26 +202,8 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
self._handle_media_message
|
||||
))
|
||||
|
||||
# Start polling — retry initialize() for transient TLS resets
|
||||
try:
|
||||
from telegram.error import NetworkError, TimedOut
|
||||
except ImportError:
|
||||
NetworkError = TimedOut = OSError # type: ignore[misc,assignment]
|
||||
_max_connect = 3
|
||||
for _attempt in range(_max_connect):
|
||||
try:
|
||||
await self._app.initialize()
|
||||
break
|
||||
except (NetworkError, TimedOut, OSError) as init_err:
|
||||
if _attempt < _max_connect - 1:
|
||||
wait = 2 ** _attempt
|
||||
logger.warning(
|
||||
"[%s] Connect attempt %d/%d failed: %s — retrying in %ds",
|
||||
self.name, _attempt + 1, _max_connect, init_err, wait,
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
else:
|
||||
raise
|
||||
# Start polling in background
|
||||
await self._app.initialize()
|
||||
await self._app.start()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
@@ -283,8 +265,6 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
release_scoped_lock("telegram-bot-token", self._token_lock_identity)
|
||||
except Exception:
|
||||
pass
|
||||
message = f"Telegram startup failed: {e}"
|
||||
self._set_fatal_error("telegram_connect_error", message, retryable=True)
|
||||
logger.error("[%s] Failed to connect to Telegram: %s", self.name, e, exc_info=True)
|
||||
return False
|
||||
|
||||
@@ -342,59 +322,36 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
# Format and split message if needed
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
if len(chunks) > 1:
|
||||
# truncate_message appends a raw " (1/2)" suffix. Escape the
|
||||
# MarkdownV2-special parentheses so Telegram doesn't reject the
|
||||
# chunk and fall back to plain text.
|
||||
chunks = [
|
||||
re.sub(r" \((\d+)/(\d+)\)$", r" \\(\1/\2\\)", chunk)
|
||||
for chunk in chunks
|
||||
]
|
||||
|
||||
message_ids = []
|
||||
thread_id = metadata.get("thread_id") if metadata else None
|
||||
|
||||
try:
|
||||
from telegram.error import NetworkError as _NetErr
|
||||
except ImportError:
|
||||
_NetErr = OSError # type: ignore[misc,assignment]
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
msg = None
|
||||
for _send_attempt in range(3):
|
||||
try:
|
||||
# Try Markdown first, fall back to plain text if it fails
|
||||
try:
|
||||
msg = await self._bot.send_message(
|
||||
chat_id=int(chat_id),
|
||||
text=chunk,
|
||||
parse_mode=ParseMode.MARKDOWN_V2,
|
||||
reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
)
|
||||
except Exception as md_error:
|
||||
# Markdown parsing failed, try plain text
|
||||
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower():
|
||||
logger.warning("[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error)
|
||||
plain_chunk = _strip_mdv2(chunk)
|
||||
msg = await self._bot.send_message(
|
||||
chat_id=int(chat_id),
|
||||
text=plain_chunk,
|
||||
parse_mode=None,
|
||||
reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
break # success
|
||||
except _NetErr as send_err:
|
||||
if _send_attempt < 2:
|
||||
wait = 2 ** _send_attempt
|
||||
logger.warning("[%s] Network error on send (attempt %d/3), retrying in %ds: %s",
|
||||
self.name, _send_attempt + 1, wait, send_err)
|
||||
await asyncio.sleep(wait)
|
||||
else:
|
||||
raise
|
||||
# Try Markdown first, fall back to plain text if it fails
|
||||
try:
|
||||
msg = await self._bot.send_message(
|
||||
chat_id=int(chat_id),
|
||||
text=chunk,
|
||||
parse_mode=ParseMode.MARKDOWN_V2,
|
||||
reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
)
|
||||
except Exception as md_error:
|
||||
# Markdown parsing failed, try plain text
|
||||
if "parse" in str(md_error).lower() or "markdown" in str(md_error).lower():
|
||||
logger.warning("[%s] MarkdownV2 parse failed, falling back to plain text: %s", self.name, md_error)
|
||||
# Strip MDV2 escape backslashes so the user doesn't
|
||||
# see raw backslashes littered through the message.
|
||||
plain_chunk = _strip_mdv2(chunk)
|
||||
msg = await self._bot.send_message(
|
||||
chat_id=int(chat_id),
|
||||
text=plain_chunk,
|
||||
parse_mode=None, # Plain text
|
||||
reply_to_message_id=int(reply_to) if reply_to and i == 0 else None,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
)
|
||||
else:
|
||||
raise # Re-raise if not a parse error
|
||||
message_ids.append(str(msg.message_id))
|
||||
|
||||
return SendResult(
|
||||
@@ -864,10 +821,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
def _photo_batch_key(self, event: MessageEvent, msg: Message) -> str:
|
||||
"""Return a batching key for Telegram photos/albums."""
|
||||
from gateway.session import build_session_key
|
||||
session_key = build_session_key(
|
||||
event.source,
|
||||
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
|
||||
)
|
||||
session_key = build_session_key(event.source)
|
||||
media_group_id = getattr(msg, "media_group_id", None)
|
||||
if media_group_id:
|
||||
return f"{session_key}:album:{media_group_id}"
|
||||
|
||||
+79
-296
@@ -29,49 +29,6 @@ from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Dict, Optional, Any, List
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SSL certificate auto-detection for NixOS and other non-standard systems.
|
||||
# Must run BEFORE any HTTP library (discord, aiohttp, etc.) is imported.
|
||||
# ---------------------------------------------------------------------------
|
||||
def _ensure_ssl_certs() -> None:
|
||||
"""Set SSL_CERT_FILE if the system doesn't expose CA certs to Python."""
|
||||
if "SSL_CERT_FILE" in os.environ:
|
||||
return # user already configured it
|
||||
|
||||
import ssl
|
||||
|
||||
# 1. Python's compiled-in defaults
|
||||
paths = ssl.get_default_verify_paths()
|
||||
for candidate in (paths.cafile, paths.openssl_cafile):
|
||||
if candidate and os.path.exists(candidate):
|
||||
os.environ["SSL_CERT_FILE"] = candidate
|
||||
return
|
||||
|
||||
# 2. certifi (ships its own Mozilla bundle)
|
||||
try:
|
||||
import certifi
|
||||
os.environ["SSL_CERT_FILE"] = certifi.where()
|
||||
return
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 3. Common distro / macOS locations
|
||||
for candidate in (
|
||||
"/etc/ssl/certs/ca-certificates.crt", # Debian/Ubuntu/Gentoo
|
||||
"/etc/pki/tls/certs/ca-bundle.crt", # RHEL/CentOS 7
|
||||
"/etc/pki/ca-trust/extracted/pem/tls-ca-bundle.pem", # RHEL/CentOS 8+
|
||||
"/etc/ssl/ca-bundle.pem", # SUSE/OpenSUSE
|
||||
"/etc/ssl/cert.pem", # Alpine / macOS
|
||||
"/etc/pki/tls/cert.pem", # Fedora
|
||||
"/usr/local/etc/openssl@1.1/cert.pem", # macOS Homebrew Intel
|
||||
"/opt/homebrew/etc/openssl@1.1/cert.pem", # macOS Homebrew ARM
|
||||
):
|
||||
if os.path.exists(candidate):
|
||||
os.environ["SSL_CERT_FILE"] = candidate
|
||||
return
|
||||
|
||||
_ensure_ssl_certs()
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
@@ -107,6 +64,7 @@ if _config_path.exists():
|
||||
"timeout": "TERMINAL_TIMEOUT",
|
||||
"lifetime_seconds": "TERMINAL_LIFETIME_SECONDS",
|
||||
"docker_image": "TERMINAL_DOCKER_IMAGE",
|
||||
"docker_forward_env": "TERMINAL_DOCKER_FORWARD_ENV",
|
||||
"singularity_image": "TERMINAL_SINGULARITY_IMAGE",
|
||||
"modal_image": "TERMINAL_MODAL_IMAGE",
|
||||
"daytona_image": "TERMINAL_DAYTONA_IMAGE",
|
||||
@@ -120,7 +78,6 @@ if _config_path.exists():
|
||||
"container_persistent": "TERMINAL_CONTAINER_PERSISTENT",
|
||||
"docker_volumes": "TERMINAL_DOCKER_VOLUMES",
|
||||
"sandbox_dir": "TERMINAL_SANDBOX_DIR",
|
||||
"persistent_shell": "TERMINAL_PERSISTENT_SHELL",
|
||||
}
|
||||
for _cfg_key, _env_var in _terminal_env_map.items():
|
||||
if _cfg_key in _terminal_cfg:
|
||||
@@ -157,12 +114,6 @@ if _config_path.exists():
|
||||
"base_url": "AUXILIARY_WEB_EXTRACT_BASE_URL",
|
||||
"api_key": "AUXILIARY_WEB_EXTRACT_API_KEY",
|
||||
},
|
||||
"approval": {
|
||||
"provider": "AUXILIARY_APPROVAL_PROVIDER",
|
||||
"model": "AUXILIARY_APPROVAL_MODEL",
|
||||
"base_url": "AUXILIARY_APPROVAL_BASE_URL",
|
||||
"api_key": "AUXILIARY_APPROVAL_API_KEY",
|
||||
},
|
||||
}
|
||||
for _task_key, _env_map in _aux_task_env.items():
|
||||
_task_cfg = _auxiliary_cfg.get(_task_key, {})
|
||||
@@ -324,7 +275,6 @@ class GatewayRunner:
|
||||
self._show_reasoning = self._load_show_reasoning()
|
||||
self._provider_routing = self._load_provider_routing()
|
||||
self._fallback_model = self._load_fallback_model()
|
||||
self._smart_model_routing = self._load_smart_model_routing()
|
||||
|
||||
# Wire process registry into session store for reset protection
|
||||
from tools.process_registry import process_registry
|
||||
@@ -356,7 +306,7 @@ class GatewayRunner:
|
||||
# Ensure tirith security scanner is available (downloads if needed)
|
||||
try:
|
||||
from tools.tirith_security import ensure_installed
|
||||
ensure_installed(log_failures=False)
|
||||
ensure_installed()
|
||||
except Exception:
|
||||
pass # Non-fatal — fail-open at scan time if unavailable
|
||||
|
||||
@@ -485,11 +435,7 @@ class GatewayRunner:
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
def _flush_memories_for_session(
|
||||
self,
|
||||
old_session_id: str,
|
||||
honcho_session_key: Optional[str] = None,
|
||||
):
|
||||
def _flush_memories_for_session(self, old_session_id: str):
|
||||
"""Prompt the agent to save memories/skills before context is lost.
|
||||
|
||||
Synchronous worker — meant to be called via run_in_executor from
|
||||
@@ -517,7 +463,6 @@ class GatewayRunner:
|
||||
quiet_mode=True,
|
||||
enabled_toolsets=["memory", "skills"],
|
||||
session_id=old_session_id,
|
||||
honcho_session_key=honcho_session_key,
|
||||
)
|
||||
|
||||
# Build conversation history from transcript
|
||||
@@ -545,7 +490,6 @@ class GatewayRunner:
|
||||
tmp_agent.run_conversation(
|
||||
user_message=flush_prompt,
|
||||
conversation_history=msgs,
|
||||
sync_honcho=False,
|
||||
)
|
||||
logger.info("Pre-reset memory flush completed for session %s", old_session_id)
|
||||
# Flush any queued Honcho writes before the session is dropped
|
||||
@@ -557,19 +501,10 @@ class GatewayRunner:
|
||||
except Exception as e:
|
||||
logger.debug("Pre-reset memory flush failed for session %s: %s", old_session_id, e)
|
||||
|
||||
async def _async_flush_memories(
|
||||
self,
|
||||
old_session_id: str,
|
||||
honcho_session_key: Optional[str] = None,
|
||||
):
|
||||
async def _async_flush_memories(self, old_session_id: str):
|
||||
"""Run the sync memory flush in a thread pool so it won't block the event loop."""
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
self._flush_memories_for_session,
|
||||
old_session_id,
|
||||
honcho_session_key,
|
||||
)
|
||||
await loop.run_in_executor(None, self._flush_memories_for_session, old_session_id)
|
||||
|
||||
@property
|
||||
def should_exit_cleanly(self) -> bool:
|
||||
@@ -579,33 +514,6 @@ class GatewayRunner:
|
||||
def exit_reason(self) -> Optional[str]:
|
||||
return self._exit_reason
|
||||
|
||||
def _session_key_for_source(self, source: SessionSource) -> str:
|
||||
"""Resolve the current session key for a source, honoring gateway config when available."""
|
||||
if hasattr(self, "session_store") and self.session_store is not None:
|
||||
try:
|
||||
session_key = self.session_store._generate_session_key(source)
|
||||
if isinstance(session_key, str) and session_key:
|
||||
return session_key
|
||||
except Exception:
|
||||
pass
|
||||
config = getattr(self, "config", None)
|
||||
return build_session_key(
|
||||
source,
|
||||
group_sessions_per_user=getattr(config, "group_sessions_per_user", True),
|
||||
)
|
||||
|
||||
def _resolve_turn_agent_config(self, user_message: str, model: str, runtime_kwargs: dict) -> dict:
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
|
||||
primary = {
|
||||
"model": model,
|
||||
"api_key": runtime_kwargs.get("api_key"),
|
||||
"base_url": runtime_kwargs.get("base_url"),
|
||||
"provider": runtime_kwargs.get("provider"),
|
||||
"api_mode": runtime_kwargs.get("api_mode"),
|
||||
}
|
||||
return resolve_turn_route(user_message, getattr(self, "_smart_model_routing", {}), primary)
|
||||
|
||||
async def _handle_adapter_fatal_error(self, adapter: BasePlatformAdapter) -> None:
|
||||
"""React to a non-retryable adapter failure after startup."""
|
||||
logger.error(
|
||||
@@ -808,20 +716,6 @@ class GatewayRunner:
|
||||
pass
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _load_smart_model_routing() -> dict:
|
||||
"""Load optional smart cheap-vs-strong model routing config."""
|
||||
try:
|
||||
import yaml as _y
|
||||
cfg_path = _hermes_home / "config.yaml"
|
||||
if cfg_path.exists():
|
||||
with open(cfg_path, encoding="utf-8") as _f:
|
||||
cfg = _y.safe_load(_f) or {}
|
||||
return cfg.get("smart_model_routing", {}) or {}
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
async def start(self) -> bool:
|
||||
"""
|
||||
Start the gateway and all configured platform adapters.
|
||||
@@ -864,15 +758,12 @@ class GatewayRunner:
|
||||
logger.warning("Process checkpoint recovery: %s", e)
|
||||
|
||||
connected_count = 0
|
||||
enabled_platform_count = 0
|
||||
startup_nonretryable_errors: list[str] = []
|
||||
startup_retryable_errors: list[str] = []
|
||||
|
||||
# Initialize and connect each configured platform
|
||||
for platform, platform_config in self.config.platforms.items():
|
||||
if not platform_config.enabled:
|
||||
continue
|
||||
enabled_platform_count += 1
|
||||
|
||||
adapter = self._create_adapter(platform, platform_config)
|
||||
if not adapter:
|
||||
@@ -894,22 +785,12 @@ class GatewayRunner:
|
||||
logger.info("✓ %s connected", platform.value)
|
||||
else:
|
||||
logger.warning("✗ %s failed to connect", platform.value)
|
||||
if adapter.has_fatal_error:
|
||||
target = (
|
||||
startup_retryable_errors
|
||||
if adapter.fatal_error_retryable
|
||||
else startup_nonretryable_errors
|
||||
)
|
||||
target.append(
|
||||
if adapter.has_fatal_error and not adapter.fatal_error_retryable:
|
||||
startup_nonretryable_errors.append(
|
||||
f"{platform.value}: {adapter.fatal_error_message}"
|
||||
)
|
||||
else:
|
||||
startup_retryable_errors.append(
|
||||
f"{platform.value}: failed to connect"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("✗ %s error: %s", platform.value, e)
|
||||
startup_retryable_errors.append(f"{platform.value}: {e}")
|
||||
|
||||
if connected_count == 0:
|
||||
if startup_nonretryable_errors:
|
||||
@@ -922,16 +803,7 @@ class GatewayRunner:
|
||||
pass
|
||||
self._request_clean_exit(reason)
|
||||
return True
|
||||
if enabled_platform_count > 0:
|
||||
reason = "; ".join(startup_retryable_errors) or "all configured messaging platforms failed to connect"
|
||||
logger.error("Gateway failed to connect any configured messaging platform: %s", reason)
|
||||
try:
|
||||
from gateway.status import write_runtime_status
|
||||
write_runtime_status(gateway_state="startup_failed", exit_reason=reason)
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
logger.warning("No messaging platforms enabled.")
|
||||
logger.warning("No messaging platforms connected.")
|
||||
logger.info("Gateway will continue running for cron job execution.")
|
||||
|
||||
# Update delivery router with adapters
|
||||
@@ -1008,7 +880,7 @@ class GatewayRunner:
|
||||
entry.session_id, key,
|
||||
)
|
||||
try:
|
||||
await self._async_flush_memories(entry.session_id, key)
|
||||
await self._async_flush_memories(entry.session_id)
|
||||
self._shutdown_gateway_honcho(key)
|
||||
self.session_store._pre_flushed_sessions.add(entry.session_id)
|
||||
except Exception as e:
|
||||
@@ -1070,12 +942,6 @@ class GatewayRunner:
|
||||
config: Any
|
||||
) -> Optional[BasePlatformAdapter]:
|
||||
"""Create the appropriate adapter for a platform."""
|
||||
if hasattr(config, "extra") and isinstance(config.extra, dict):
|
||||
config.extra.setdefault(
|
||||
"group_sessions_per_user",
|
||||
self.config.group_sessions_per_user,
|
||||
)
|
||||
|
||||
if platform == Platform.TELEGRAM:
|
||||
from gateway.platforms.telegram import TelegramAdapter, check_telegram_requirements
|
||||
if not check_telegram_requirements():
|
||||
@@ -1247,11 +1113,8 @@ class GatewayRunner:
|
||||
# Special case: Telegram/photo bursts often arrive as multiple near-
|
||||
# simultaneous updates. Do NOT interrupt for photo-only follow-ups here;
|
||||
# let the adapter-level batching/queueing logic absorb them.
|
||||
_quick_key = self._session_key_for_source(source)
|
||||
_quick_key = build_session_key(source)
|
||||
if _quick_key in self._running_agents:
|
||||
if event.get_command() == "status":
|
||||
return await self._handle_status_command(event)
|
||||
|
||||
if event.message_type == MessageType.PHOTO:
|
||||
logger.debug("PRIORITY photo follow-up for session %s — queueing without interrupt", _quick_key[:20])
|
||||
adapter = self.adapters.get(source.platform)
|
||||
@@ -1436,7 +1299,7 @@ class GatewayRunner:
|
||||
logger.debug("Skill command check failed (non-fatal): %s", e)
|
||||
|
||||
# Check for pending exec approval responses
|
||||
session_key_preview = self._session_key_for_source(source)
|
||||
session_key_preview = build_session_key(source)
|
||||
if session_key_preview in self._pending_approvals:
|
||||
user_text = event.text.strip().lower()
|
||||
if user_text in ("yes", "y", "approve", "ok", "go", "do it"):
|
||||
@@ -1485,17 +1348,8 @@ class GatewayRunner:
|
||||
# Set environment variables for tools
|
||||
self._set_session_env(context)
|
||||
|
||||
# Read privacy.redact_pii from config (re-read per message)
|
||||
_redact_pii = False
|
||||
try:
|
||||
with open(_config_path, encoding="utf-8") as _pf:
|
||||
_pcfg = yaml.safe_load(_pf) or {}
|
||||
_redact_pii = bool((_pcfg.get("privacy") or {}).get("redact_pii", False))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Build the context prompt to inject
|
||||
context_prompt = build_session_context_prompt(context, redact_pii=_redact_pii)
|
||||
context_prompt = build_session_context_prompt(context)
|
||||
|
||||
# If the previous session expired and was auto-reset, prepend a notice
|
||||
# so the agent knows this is a fresh conversation (not an intentional /reset).
|
||||
@@ -1864,17 +1718,9 @@ class GatewayRunner:
|
||||
session_key=session_key
|
||||
)
|
||||
|
||||
response = agent_result.get("final_response") or ""
|
||||
response = agent_result.get("final_response", "")
|
||||
agent_messages = agent_result.get("messages", [])
|
||||
|
||||
# Surface error details when the agent failed silently (final_response=None)
|
||||
if not response and agent_result.get("failed"):
|
||||
error_detail = agent_result.get("error", "unknown error")
|
||||
response = (
|
||||
f"The request failed: {str(error_detail)[:300]}\n"
|
||||
"Try again or use /reset to start a fresh session."
|
||||
)
|
||||
|
||||
# If the agent's session_id changed during compression, update
|
||||
# session_entry so transcript writes below go to the right session.
|
||||
if agent_result.get("session_id") and agent_result["session_id"] != session_entry.session_id:
|
||||
@@ -1977,8 +1823,6 @@ class GatewayRunner:
|
||||
# Update session with actual prompt token count and model from the agent
|
||||
self.session_store.update_session(
|
||||
session_entry.session_key,
|
||||
input_tokens=agent_result.get("input_tokens", 0),
|
||||
output_tokens=agent_result.get("output_tokens", 0),
|
||||
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
|
||||
model=agent_result.get("model"),
|
||||
)
|
||||
@@ -1987,29 +1831,13 @@ class GatewayRunner:
|
||||
if self._should_send_voice_reply(event, response, agent_messages):
|
||||
await self._send_voice_reply(event, response)
|
||||
|
||||
# If streaming already delivered the response, return None so
|
||||
# _process_message_background doesn't send it again.
|
||||
if agent_result.get("already_sent"):
|
||||
return None
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Agent error in session %s", session_key)
|
||||
error_type = type(e).__name__
|
||||
error_detail = str(e)[:300] if str(e) else "no details available"
|
||||
status_hint = ""
|
||||
status_code = getattr(e, "status_code", None)
|
||||
if status_code == 401:
|
||||
status_hint = " Check your API key or run `claude /login` to refresh OAuth credentials."
|
||||
elif status_code == 429:
|
||||
status_hint = " You are being rate-limited. Please wait a moment and try again."
|
||||
elif status_code == 529:
|
||||
status_hint = " The API is temporarily overloaded. Please try again shortly."
|
||||
return (
|
||||
f"Sorry, I encountered an error ({error_type}).\n"
|
||||
f"{error_detail}\n"
|
||||
f"{status_hint}"
|
||||
"Sorry, I encountered an unexpected error. "
|
||||
"The details have been logged for debugging. "
|
||||
"Try again or use /reset to start a fresh session."
|
||||
)
|
||||
finally:
|
||||
@@ -2021,16 +1849,14 @@ class GatewayRunner:
|
||||
source = event.source
|
||||
|
||||
# Get existing session key
|
||||
session_key = self._session_key_for_source(source)
|
||||
session_key = self.session_store._generate_session_key(source)
|
||||
|
||||
# Flush memories in the background (fire-and-forget) so the user
|
||||
# gets the "Session reset!" response immediately.
|
||||
try:
|
||||
old_entry = self.session_store._entries.get(session_key)
|
||||
if old_entry:
|
||||
asyncio.create_task(
|
||||
self._async_flush_memories(old_entry.session_id, session_key)
|
||||
)
|
||||
asyncio.create_task(self._async_flush_memories(old_entry.session_id))
|
||||
except Exception as e:
|
||||
logger.debug("Gateway memory flush on reset failed: %s", e)
|
||||
|
||||
@@ -2194,12 +2020,6 @@ class GatewayRunner:
|
||||
|
||||
# Parse provider:model syntax
|
||||
target_provider, new_model = parse_model_input(args, current_provider)
|
||||
# Auto-detect provider when no explicit provider:model syntax was used
|
||||
if target_provider == current_provider:
|
||||
from hermes_cli.models import detect_provider_for_model
|
||||
detected = detect_provider_for_model(new_model, current_provider)
|
||||
if detected:
|
||||
target_provider, new_model = detected
|
||||
provider_changed = target_provider != current_provider
|
||||
|
||||
# Resolve credentials for the target provider (for API probe)
|
||||
@@ -2982,12 +2802,11 @@ class GatewayRunner:
|
||||
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "90"))
|
||||
reasoning_config = self._load_reasoning_config()
|
||||
self._reasoning_config = reasoning_config
|
||||
turn_route = self._resolve_turn_agent_config(prompt, model, runtime_kwargs)
|
||||
|
||||
def run_sync():
|
||||
agent = AIAgent(
|
||||
model=turn_route["model"],
|
||||
**turn_route["runtime"],
|
||||
model=model,
|
||||
**runtime_kwargs,
|
||||
max_iterations=max_iterations,
|
||||
quiet_mode=True,
|
||||
verbose_logging=False,
|
||||
@@ -3260,7 +3079,7 @@ class GatewayRunner:
|
||||
return "Session database not available."
|
||||
|
||||
source = event.source
|
||||
session_key = self._session_key_for_source(source)
|
||||
session_key = build_session_key(source)
|
||||
name = event.get_command_args().strip()
|
||||
|
||||
if not name:
|
||||
@@ -3304,9 +3123,7 @@ class GatewayRunner:
|
||||
|
||||
# Flush memories for current session before switching
|
||||
try:
|
||||
asyncio.create_task(
|
||||
self._async_flush_memories(current_entry.session_id, session_key)
|
||||
)
|
||||
asyncio.create_task(self._async_flush_memories(current_entry.session_id))
|
||||
except Exception as e:
|
||||
logger.debug("Memory flush on resume failed: %s", e)
|
||||
|
||||
@@ -3334,7 +3151,7 @@ class GatewayRunner:
|
||||
async def _handle_usage_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /usage command -- show token usage for the session's last agent run."""
|
||||
source = event.source
|
||||
session_key = self._session_key_for_source(source)
|
||||
session_key = build_session_key(source)
|
||||
|
||||
agent = self._running_agents.get(session_key)
|
||||
if agent and hasattr(agent, "session_total_tokens") and agent.session_api_calls > 0:
|
||||
@@ -3709,9 +3526,13 @@ class GatewayRunner:
|
||||
1. Immediately understand what the user sent (no extra tool call).
|
||||
2. Re-examine the image with vision_analyze if it needs more detail.
|
||||
|
||||
Athabasca persistence should happen through Athabasca's own POST
|
||||
/api/uploads flow, using the returned asset.publicUrl rather than local
|
||||
cache paths.
|
||||
|
||||
Args:
|
||||
user_text: The user's original caption / message text.
|
||||
image_paths: List of local file paths to cached images.
|
||||
user_text: The user's original caption / message text.
|
||||
image_paths: List of local file paths to cached images.
|
||||
|
||||
Returns:
|
||||
The enriched message string with vision descriptions prepended.
|
||||
@@ -3736,10 +3557,16 @@ class GatewayRunner:
|
||||
result = _json.loads(result_json)
|
||||
if result.get("success"):
|
||||
description = result.get("analysis", "")
|
||||
athabasca_note = (
|
||||
"\n[If this image needs to persist in Athabasca state, upload the cached file "
|
||||
"through Athabasca POST /api/uploads and use the returned asset.publicUrl. "
|
||||
"Do not store the local cache path as the canonical imageUrl.]"
|
||||
)
|
||||
enriched_parts.append(
|
||||
f"[The user sent an image~ Here's what I can see:\n{description}]\n"
|
||||
f"[If you need a closer look, use vision_analyze with "
|
||||
f"image_url: {path} ~]"
|
||||
f"{athabasca_note}"
|
||||
)
|
||||
else:
|
||||
enriched_parts.append(
|
||||
@@ -3803,10 +3630,7 @@ class GatewayRunner:
|
||||
)
|
||||
else:
|
||||
error = result.get("error", "unknown error")
|
||||
if (
|
||||
"No STT provider" in error
|
||||
or error.startswith("Neither VOICE_TOOLS_OPENAI_KEY nor OPENAI_API_KEY is set")
|
||||
):
|
||||
if "No STT provider" in error or "not set" in error:
|
||||
enriched_parts.append(
|
||||
"[The user sent a voice message but I can't listen "
|
||||
"to it right now~ No STT provider is configured "
|
||||
@@ -3851,7 +3675,6 @@ class GatewayRunner:
|
||||
session_key = watcher.get("session_key", "")
|
||||
platform_name = watcher.get("platform", "")
|
||||
chat_id = watcher.get("chat_id", "")
|
||||
thread_id = watcher.get("thread_id", "")
|
||||
notify_mode = self._load_background_notifications_mode()
|
||||
|
||||
logger.debug("Process watcher started: %s (every %ss, notify=%s)",
|
||||
@@ -3899,8 +3722,7 @@ class GatewayRunner:
|
||||
break
|
||||
if adapter and chat_id:
|
||||
try:
|
||||
send_meta = {"thread_id": thread_id} if thread_id else None
|
||||
await adapter.send(chat_id, message_text, metadata=send_meta)
|
||||
await adapter.send(chat_id, message_text)
|
||||
except Exception as e:
|
||||
logger.error("Watcher delivery error: %s", e)
|
||||
break
|
||||
@@ -3919,8 +3741,7 @@ class GatewayRunner:
|
||||
break
|
||||
if adapter and chat_id:
|
||||
try:
|
||||
send_meta = {"thread_id": thread_id} if thread_id else None
|
||||
await adapter.send(chat_id, message_text, metadata=send_meta)
|
||||
await adapter.send(chat_id, message_text)
|
||||
except Exception as e:
|
||||
logger.error("Watcher delivery error: %s", e)
|
||||
|
||||
@@ -4031,8 +3852,45 @@ class GatewayRunner:
|
||||
last_tool[0] = tool_name
|
||||
|
||||
# Build progress message with primary argument preview
|
||||
from agent.display import get_tool_emoji
|
||||
emoji = get_tool_emoji(tool_name, default="⚙️")
|
||||
tool_emojis = {
|
||||
"terminal": "💻",
|
||||
"process": "⚙️",
|
||||
"web_search": "🔍",
|
||||
"web_extract": "📄",
|
||||
"read_file": "📖",
|
||||
"write_file": "✍️",
|
||||
"patch": "🔧",
|
||||
"search": "🔎",
|
||||
"search_files": "🔎",
|
||||
"list_directory": "📂",
|
||||
"image_generate": "🎨",
|
||||
"text_to_speech": "🔊",
|
||||
"browser_navigate": "🌐",
|
||||
"browser_click": "👆",
|
||||
"browser_type": "⌨️",
|
||||
"browser_snapshot": "📸",
|
||||
"browser_scroll": "📜",
|
||||
"browser_back": "◀️",
|
||||
"browser_press": "⌨️",
|
||||
"browser_close": "🚪",
|
||||
"browser_get_images": "🖼️",
|
||||
"browser_vision": "👁️",
|
||||
"moa_query": "🧠",
|
||||
"mixture_of_agents": "🧠",
|
||||
"vision_analyze": "👁️",
|
||||
"skill_view": "📚",
|
||||
"skills_list": "📋",
|
||||
"todo": "📋",
|
||||
"memory": "🧠",
|
||||
"session_search": "🔍",
|
||||
"send_message": "📨",
|
||||
"cronjob": "⏰",
|
||||
"execute_code": "🐍",
|
||||
"delegate_task": "🔀",
|
||||
"clarify": "❓",
|
||||
"skill_manage": "📝",
|
||||
}
|
||||
emoji = tool_emojis.get(tool_name, "⚙️")
|
||||
|
||||
# Verbose mode: show detailed arguments
|
||||
if progress_mode == "verbose" and args:
|
||||
@@ -4159,7 +4017,6 @@ class GatewayRunner:
|
||||
agent_holder = [None] # Mutable container for the agent instance
|
||||
result_holder = [None] # Mutable container for the result
|
||||
tools_holder = [None] # Mutable container for the tool definitions
|
||||
stream_consumer_holder = [None] # Mutable container for stream consumer
|
||||
|
||||
# Bridge sync step_callback → async hooks.emit for agent:step events
|
||||
_loop_for_step = asyncio.get_event_loop()
|
||||
@@ -4222,39 +4079,9 @@ class GatewayRunner:
|
||||
honcho_manager, honcho_config = self._get_or_create_gateway_honcho(session_key)
|
||||
reasoning_config = self._load_reasoning_config()
|
||||
self._reasoning_config = reasoning_config
|
||||
# Set up streaming consumer if enabled
|
||||
_stream_consumer = None
|
||||
_stream_delta_cb = None
|
||||
_scfg = getattr(getattr(self, 'config', None), 'streaming', None)
|
||||
if _scfg is None:
|
||||
from gateway.config import StreamingConfig
|
||||
_scfg = StreamingConfig()
|
||||
|
||||
if _scfg.enabled and _scfg.transport != "off":
|
||||
try:
|
||||
from gateway.stream_consumer import GatewayStreamConsumer, StreamConsumerConfig
|
||||
_adapter = self.adapters.get(source.platform)
|
||||
if _adapter:
|
||||
_consumer_cfg = StreamConsumerConfig(
|
||||
edit_interval=_scfg.edit_interval,
|
||||
buffer_threshold=_scfg.buffer_threshold,
|
||||
cursor=_scfg.cursor,
|
||||
)
|
||||
_stream_consumer = GatewayStreamConsumer(
|
||||
adapter=_adapter,
|
||||
chat_id=source.chat_id,
|
||||
config=_consumer_cfg,
|
||||
metadata={"thread_id": source.thread_id} if source.thread_id else None,
|
||||
)
|
||||
_stream_delta_cb = _stream_consumer.on_delta
|
||||
stream_consumer_holder[0] = _stream_consumer
|
||||
except Exception as _sc_err:
|
||||
logger.debug("Could not set up stream consumer: %s", _sc_err)
|
||||
|
||||
turn_route = self._resolve_turn_agent_config(message, model, runtime_kwargs)
|
||||
agent = AIAgent(
|
||||
model=turn_route["model"],
|
||||
**turn_route["runtime"],
|
||||
model=model,
|
||||
**runtime_kwargs,
|
||||
max_iterations=max_iterations,
|
||||
quiet_mode=True,
|
||||
verbose_logging=False,
|
||||
@@ -4271,7 +4098,6 @@ class GatewayRunner:
|
||||
session_id=session_id,
|
||||
tool_progress_callback=progress_callback if tool_progress_enabled else None,
|
||||
step_callback=_step_callback_sync if _hooks_ref.loaded_hooks else None,
|
||||
stream_delta_callback=_stream_delta_cb,
|
||||
platform=platform_key,
|
||||
honcho_session_key=session_key,
|
||||
honcho_manager=honcho_manager,
|
||||
@@ -4342,23 +4168,15 @@ class GatewayRunner:
|
||||
|
||||
result = agent.run_conversation(message, conversation_history=agent_history, task_id=session_id)
|
||||
result_holder[0] = result
|
||||
|
||||
# Signal the stream consumer that the agent is done
|
||||
if _stream_consumer is not None:
|
||||
_stream_consumer.finish()
|
||||
|
||||
# Return final response, or a message if something went wrong
|
||||
final_response = result.get("final_response")
|
||||
|
||||
# Extract actual token counts from the agent instance used for this run
|
||||
# Extract last actual prompt token count from the agent's compressor
|
||||
_last_prompt_toks = 0
|
||||
_input_toks = 0
|
||||
_output_toks = 0
|
||||
_agent = agent_holder[0]
|
||||
if _agent and hasattr(_agent, "context_compressor"):
|
||||
_last_prompt_toks = getattr(_agent.context_compressor, "last_prompt_tokens", 0)
|
||||
_input_toks = getattr(_agent, "session_prompt_tokens", 0)
|
||||
_output_toks = getattr(_agent, "session_completion_tokens", 0)
|
||||
_resolved_model = getattr(_agent, "model", None) if _agent else None
|
||||
|
||||
if not final_response:
|
||||
@@ -4370,8 +4188,6 @@ class GatewayRunner:
|
||||
"tools": tools_holder[0] or [],
|
||||
"history_offset": len(agent_history),
|
||||
"last_prompt_tokens": _last_prompt_toks,
|
||||
"input_tokens": _input_toks,
|
||||
"output_tokens": _output_toks,
|
||||
"model": _resolved_model,
|
||||
}
|
||||
|
||||
@@ -4435,8 +4251,6 @@ class GatewayRunner:
|
||||
"tools": tools_holder[0] or [],
|
||||
"history_offset": len(agent_history),
|
||||
"last_prompt_tokens": _last_prompt_toks,
|
||||
"input_tokens": _input_toks,
|
||||
"output_tokens": _output_toks,
|
||||
"model": _resolved_model,
|
||||
"session_id": effective_session_id,
|
||||
}
|
||||
@@ -4445,20 +4259,6 @@ class GatewayRunner:
|
||||
progress_task = None
|
||||
if tool_progress_enabled:
|
||||
progress_task = asyncio.create_task(send_progress_messages())
|
||||
|
||||
# Start stream consumer task — polls for consumer creation since it
|
||||
# happens inside run_sync (thread pool) after the agent is constructed.
|
||||
stream_task = None
|
||||
|
||||
async def _start_stream_consumer():
|
||||
"""Wait for the stream consumer to be created, then run it."""
|
||||
for _ in range(200): # Up to 10s wait
|
||||
if stream_consumer_holder[0] is not None:
|
||||
await stream_consumer_holder[0].run()
|
||||
return
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
stream_task = asyncio.create_task(_start_stream_consumer())
|
||||
|
||||
# Track this agent as running for this session (for interrupt support)
|
||||
# We do this in a callback after the agent is created
|
||||
@@ -4541,17 +4341,6 @@ class GatewayRunner:
|
||||
if progress_task:
|
||||
progress_task.cancel()
|
||||
interrupt_monitor.cancel()
|
||||
|
||||
# Wait for stream consumer to finish its final edit
|
||||
if stream_task:
|
||||
try:
|
||||
await asyncio.wait_for(stream_task, timeout=5.0)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
stream_task.cancel()
|
||||
try:
|
||||
await stream_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Clean up tracking
|
||||
tracking_task.cancel()
|
||||
@@ -4565,12 +4354,6 @@ class GatewayRunner:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# If streaming already delivered the response, mark it so the
|
||||
# caller's send() is skipped (avoiding duplicate messages).
|
||||
_sc = stream_consumer_holder[0]
|
||||
if _sc and _sc.already_sent and isinstance(response, dict):
|
||||
response["already_sent"] = True
|
||||
|
||||
return response
|
||||
|
||||
|
||||
+12
-108
@@ -8,11 +8,9 @@ Handles:
|
||||
- Dynamic system prompt injection (agent knows its context)
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta
|
||||
@@ -21,41 +19,6 @@ from typing import Dict, List, Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PII redaction helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_PHONE_RE = re.compile(r"^\+?\d[\d\-\s]{6,}$")
|
||||
|
||||
|
||||
def _hash_id(value: str) -> str:
|
||||
"""Deterministic 12-char hex hash of an identifier."""
|
||||
return hashlib.sha256(value.encode("utf-8")).hexdigest()[:12]
|
||||
|
||||
|
||||
def _hash_sender_id(value: str) -> str:
|
||||
"""Hash a sender ID to ``user_<12hex>``."""
|
||||
return f"user_{_hash_id(value)}"
|
||||
|
||||
|
||||
def _hash_chat_id(value: str) -> str:
|
||||
"""Hash the numeric portion of a chat ID, preserving platform prefix.
|
||||
|
||||
``telegram:12345`` → ``telegram:<hash>``
|
||||
``12345`` → ``<hash>``
|
||||
"""
|
||||
colon = value.find(":")
|
||||
if colon > 0:
|
||||
prefix = value[:colon]
|
||||
return f"{prefix}:{_hash_id(value[colon + 1:])}"
|
||||
return _hash_id(value)
|
||||
|
||||
|
||||
def _looks_like_phone(value: str) -> bool:
|
||||
"""Return True if *value* looks like a phone number (E.164 or similar)."""
|
||||
return bool(_PHONE_RE.match(value.strip()))
|
||||
|
||||
from .config import (
|
||||
Platform,
|
||||
GatewayConfig,
|
||||
@@ -183,21 +146,7 @@ class SessionContext:
|
||||
}
|
||||
|
||||
|
||||
_PII_SAFE_PLATFORMS = frozenset({
|
||||
Platform.WHATSAPP,
|
||||
Platform.SIGNAL,
|
||||
Platform.TELEGRAM,
|
||||
})
|
||||
"""Platforms where user IDs can be safely redacted (no in-message mention system
|
||||
that requires raw IDs). Discord is excluded because mentions use ``<@user_id>``
|
||||
and the LLM needs the real ID to tag users."""
|
||||
|
||||
|
||||
def build_session_context_prompt(
|
||||
context: SessionContext,
|
||||
*,
|
||||
redact_pii: bool = False,
|
||||
) -> str:
|
||||
def build_session_context_prompt(context: SessionContext) -> str:
|
||||
"""
|
||||
Build the dynamic system prompt section that tells the agent about its context.
|
||||
|
||||
@@ -205,15 +154,7 @@ def build_session_context_prompt(
|
||||
- Where messages are coming from
|
||||
- What platforms are connected
|
||||
- Where it can deliver scheduled task outputs
|
||||
|
||||
When *redact_pii* is True **and** the source platform is in
|
||||
``_PII_SAFE_PLATFORMS``, phone numbers are stripped and user/chat IDs
|
||||
are replaced with deterministic hashes before being sent to the LLM.
|
||||
Platforms like Discord are excluded because mentions need real IDs.
|
||||
Routing still uses the original values (they stay in SessionSource).
|
||||
"""
|
||||
# Only apply redaction on platforms where IDs aren't needed for mentions
|
||||
redact_pii = redact_pii and context.source.platform in _PII_SAFE_PLATFORMS
|
||||
lines = [
|
||||
"## Current Session Context",
|
||||
"",
|
||||
@@ -224,25 +165,7 @@ def build_session_context_prompt(
|
||||
if context.source.platform == Platform.LOCAL:
|
||||
lines.append(f"**Source:** {platform_name} (the machine running this agent)")
|
||||
else:
|
||||
# Build a description that respects PII redaction
|
||||
src = context.source
|
||||
if redact_pii:
|
||||
# Build a safe description without raw IDs
|
||||
_uname = src.user_name or (
|
||||
_hash_sender_id(src.user_id) if src.user_id else "user"
|
||||
)
|
||||
_cname = src.chat_name or _hash_chat_id(src.chat_id)
|
||||
if src.chat_type == "dm":
|
||||
desc = f"DM with {_uname}"
|
||||
elif src.chat_type == "group":
|
||||
desc = f"group: {_cname}"
|
||||
elif src.chat_type == "channel":
|
||||
desc = f"channel: {_cname}"
|
||||
else:
|
||||
desc = _cname
|
||||
else:
|
||||
desc = src.description
|
||||
lines.append(f"**Source:** {platform_name} ({desc})")
|
||||
lines.append(f"**Source:** {platform_name} ({context.source.description})")
|
||||
|
||||
# Channel topic (if available - provides context about the channel's purpose)
|
||||
if context.source.chat_topic:
|
||||
@@ -252,10 +175,7 @@ def build_session_context_prompt(
|
||||
if context.source.user_name:
|
||||
lines.append(f"**User:** {context.source.user_name}")
|
||||
elif context.source.user_id:
|
||||
uid = context.source.user_id
|
||||
if redact_pii:
|
||||
uid = _hash_sender_id(uid)
|
||||
lines.append(f"**User ID:** {uid}")
|
||||
lines.append(f"**User ID:** {context.source.user_id}")
|
||||
|
||||
# Platform-specific behavioral notes
|
||||
if context.source.platform == Platform.SLACK:
|
||||
@@ -290,8 +210,7 @@ def build_session_context_prompt(
|
||||
lines.append("")
|
||||
lines.append("**Home Channels (default destinations):**")
|
||||
for platform, home in context.home_channels.items():
|
||||
hc_id = _hash_chat_id(home.chat_id) if redact_pii else home.chat_id
|
||||
lines.append(f" - {platform.value}: {home.name} (ID: {hc_id})")
|
||||
lines.append(f" - {platform.value}: {home.name} (ID: {home.chat_id})")
|
||||
|
||||
# Delivery options for scheduled tasks
|
||||
lines.append("")
|
||||
@@ -301,10 +220,7 @@ def build_session_context_prompt(
|
||||
if context.source.platform == Platform.LOCAL:
|
||||
lines.append("- `\"origin\"` → Local output (saved to files)")
|
||||
else:
|
||||
_origin_label = context.source.chat_name or (
|
||||
_hash_chat_id(context.source.chat_id) if redact_pii else context.source.chat_id
|
||||
)
|
||||
lines.append(f"- `\"origin\"` → Back to this chat ({_origin_label})")
|
||||
lines.append(f"- `\"origin\"` → Back to this chat ({context.source.chat_name or context.source.chat_id})")
|
||||
|
||||
# Local always available
|
||||
lines.append("- `\"local\"` → Save to local files only (~/.hermes/cron/output/)")
|
||||
@@ -399,7 +315,7 @@ class SessionEntry:
|
||||
)
|
||||
|
||||
|
||||
def build_session_key(source: SessionSource, group_sessions_per_user: bool = True) -> str:
|
||||
def build_session_key(source: SessionSource) -> str:
|
||||
"""Build a deterministic session key from a message source.
|
||||
|
||||
This is the single source of truth for session key construction.
|
||||
@@ -412,11 +328,7 @@ def build_session_key(source: SessionSource, group_sessions_per_user: bool = Tru
|
||||
|
||||
Group/channel rules:
|
||||
- chat_id identifies the parent group/channel.
|
||||
- user_id/user_id_alt isolates participants within that parent chat when available when
|
||||
``group_sessions_per_user`` is enabled.
|
||||
- thread_id differentiates threads within that parent chat.
|
||||
- Without participant identifiers, or when isolation is disabled, messages fall back to one
|
||||
shared session per chat.
|
||||
- Without identifiers, messages fall back to one session per platform/chat_type.
|
||||
"""
|
||||
platform = source.platform.value
|
||||
@@ -428,18 +340,13 @@ def build_session_key(source: SessionSource, group_sessions_per_user: bool = Tru
|
||||
if source.thread_id:
|
||||
return f"agent:main:{platform}:dm:{source.thread_id}"
|
||||
return f"agent:main:{platform}:dm"
|
||||
|
||||
participant_id = source.user_id_alt or source.user_id
|
||||
key_parts = ["agent:main", platform, source.chat_type]
|
||||
|
||||
if source.chat_id:
|
||||
key_parts.append(source.chat_id)
|
||||
if source.thread_id:
|
||||
return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}:{source.thread_id}"
|
||||
return f"agent:main:{platform}:{source.chat_type}:{source.chat_id}"
|
||||
if source.thread_id:
|
||||
key_parts.append(source.thread_id)
|
||||
if group_sessions_per_user and participant_id:
|
||||
key_parts.append(str(participant_id))
|
||||
|
||||
return ":".join(key_parts)
|
||||
return f"agent:main:{platform}:{source.chat_type}:{source.thread_id}"
|
||||
return f"agent:main:{platform}:{source.chat_type}"
|
||||
|
||||
|
||||
class SessionStore:
|
||||
@@ -518,10 +425,7 @@ class SessionStore:
|
||||
|
||||
def _generate_session_key(self, source: SessionSource) -> str:
|
||||
"""Generate a session key from a source."""
|
||||
return build_session_key(
|
||||
source,
|
||||
group_sessions_per_user=getattr(self.config, "group_sessions_per_user", True),
|
||||
)
|
||||
return build_session_key(source)
|
||||
|
||||
def _is_session_expired(self, entry: SessionEntry) -> bool:
|
||||
"""Check if a session has expired based on its reset policy.
|
||||
|
||||
+4
-22
@@ -83,7 +83,8 @@ def _looks_like_gateway_process(pid: int) -> bool:
|
||||
"""Return True when the live PID still looks like the Hermes gateway."""
|
||||
cmdline = _read_process_cmdline(pid)
|
||||
if not cmdline:
|
||||
return False
|
||||
# If we cannot inspect the process, fall back to the liveness check.
|
||||
return True
|
||||
|
||||
patterns = (
|
||||
"hermes_cli.main gateway",
|
||||
@@ -93,24 +94,6 @@ def _looks_like_gateway_process(pid: int) -> bool:
|
||||
return any(pattern in cmdline for pattern in patterns)
|
||||
|
||||
|
||||
def _record_looks_like_gateway(record: dict[str, Any]) -> bool:
|
||||
"""Validate gateway identity from PID-file metadata when cmdline is unavailable."""
|
||||
if record.get("kind") != _GATEWAY_KIND:
|
||||
return False
|
||||
|
||||
argv = record.get("argv")
|
||||
if not isinstance(argv, list) or not argv:
|
||||
return False
|
||||
|
||||
cmdline = " ".join(str(part) for part in argv)
|
||||
patterns = (
|
||||
"hermes_cli.main gateway",
|
||||
"hermes gateway",
|
||||
"gateway/run.py",
|
||||
)
|
||||
return any(pattern in cmdline for pattern in patterns)
|
||||
|
||||
|
||||
def _build_pid_record() -> dict:
|
||||
return {
|
||||
"pid": os.getpid(),
|
||||
@@ -342,9 +325,8 @@ def get_running_pid() -> Optional[int]:
|
||||
return None
|
||||
|
||||
if not _looks_like_gateway_process(pid):
|
||||
if not _record_looks_like_gateway(record):
|
||||
remove_pid_file()
|
||||
return None
|
||||
remove_pid_file()
|
||||
return None
|
||||
|
||||
return pid
|
||||
|
||||
|
||||
@@ -1,177 +0,0 @@
|
||||
"""Gateway streaming consumer — bridges sync agent callbacks to async platform delivery.
|
||||
|
||||
The agent fires stream_delta_callback(text) synchronously from its worker thread.
|
||||
GatewayStreamConsumer:
|
||||
1. Receives deltas via on_delta() (thread-safe, sync)
|
||||
2. Queues them to an asyncio task via queue.Queue
|
||||
3. The async run() task buffers, rate-limits, and progressively edits
|
||||
a single message on the target platform
|
||||
|
||||
Design: Uses the edit transport (send initial message, then editMessageText).
|
||||
This is universally supported across Telegram, Discord, and Slack.
|
||||
|
||||
Credit: jobless0x (#774, #1312), OutThisLife (#798), clicksingh (#697).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import queue
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger("gateway.stream_consumer")
|
||||
|
||||
# Sentinel to signal the stream is complete
|
||||
_DONE = object()
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamConsumerConfig:
|
||||
"""Runtime config for a single stream consumer instance."""
|
||||
edit_interval: float = 0.3
|
||||
buffer_threshold: int = 40
|
||||
cursor: str = " ▉"
|
||||
|
||||
|
||||
class GatewayStreamConsumer:
|
||||
"""Async consumer that progressively edits a platform message with streamed tokens.
|
||||
|
||||
Usage::
|
||||
|
||||
consumer = GatewayStreamConsumer(adapter, chat_id, config, metadata=metadata)
|
||||
# Pass consumer.on_delta as stream_delta_callback to AIAgent
|
||||
agent = AIAgent(..., stream_delta_callback=consumer.on_delta)
|
||||
# Start the consumer as an asyncio task
|
||||
task = asyncio.create_task(consumer.run())
|
||||
# ... run agent in thread pool ...
|
||||
consumer.finish() # signal completion
|
||||
await task # wait for final edit
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
adapter: Any,
|
||||
chat_id: str,
|
||||
config: Optional[StreamConsumerConfig] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
self.adapter = adapter
|
||||
self.chat_id = chat_id
|
||||
self.cfg = config or StreamConsumerConfig()
|
||||
self.metadata = metadata
|
||||
self._queue: queue.Queue = queue.Queue()
|
||||
self._accumulated = ""
|
||||
self._message_id: Optional[str] = None
|
||||
self._already_sent = False
|
||||
self._edit_supported = True # Disabled on first edit failure (Signal/Email/HA)
|
||||
self._last_edit_time = 0.0
|
||||
|
||||
@property
|
||||
def already_sent(self) -> bool:
|
||||
"""True if at least one message was sent/edited — signals the base
|
||||
adapter to skip re-sending the final response."""
|
||||
return self._already_sent
|
||||
|
||||
def on_delta(self, text: str) -> None:
|
||||
"""Thread-safe callback — called from the agent's worker thread."""
|
||||
if text:
|
||||
self._queue.put(text)
|
||||
|
||||
def finish(self) -> None:
|
||||
"""Signal that the stream is complete."""
|
||||
self._queue.put(_DONE)
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Async task that drains the queue and edits the platform message."""
|
||||
try:
|
||||
while True:
|
||||
# Drain all available items from the queue
|
||||
got_done = False
|
||||
while True:
|
||||
try:
|
||||
item = self._queue.get_nowait()
|
||||
if item is _DONE:
|
||||
got_done = True
|
||||
break
|
||||
self._accumulated += item
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
# Decide whether to flush an edit
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_edit_time
|
||||
should_edit = (
|
||||
got_done
|
||||
or (elapsed >= self.cfg.edit_interval
|
||||
and len(self._accumulated) > 0)
|
||||
or len(self._accumulated) >= self.cfg.buffer_threshold
|
||||
)
|
||||
|
||||
if should_edit and self._accumulated:
|
||||
display_text = self._accumulated
|
||||
if not got_done:
|
||||
display_text += self.cfg.cursor
|
||||
|
||||
await self._send_or_edit(display_text)
|
||||
self._last_edit_time = time.monotonic()
|
||||
|
||||
if got_done:
|
||||
# Final edit without cursor
|
||||
if self._accumulated and self._message_id:
|
||||
await self._send_or_edit(self._accumulated)
|
||||
return
|
||||
|
||||
await asyncio.sleep(0.05) # Small yield to not busy-loop
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Best-effort final edit on cancellation
|
||||
if self._accumulated and self._message_id:
|
||||
try:
|
||||
await self._send_or_edit(self._accumulated)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error("Stream consumer error: %s", e)
|
||||
|
||||
async def _send_or_edit(self, text: str) -> None:
|
||||
"""Send or edit the streaming message."""
|
||||
try:
|
||||
if self._message_id is not None:
|
||||
if self._edit_supported:
|
||||
# Edit existing message
|
||||
result = await self.adapter.edit_message(
|
||||
chat_id=self.chat_id,
|
||||
message_id=self._message_id,
|
||||
content=text,
|
||||
)
|
||||
if result.success:
|
||||
self._already_sent = True
|
||||
else:
|
||||
# Edit not supported by this adapter — stop streaming,
|
||||
# let the normal send path handle the final response.
|
||||
# Without this guard, adapters like Signal/Email would
|
||||
# flood the chat with a new message every edit_interval.
|
||||
logger.debug("Edit failed, disabling streaming for this adapter")
|
||||
self._edit_supported = False
|
||||
else:
|
||||
# Editing not supported — skip intermediate updates.
|
||||
# The final response will be sent by the normal path.
|
||||
pass
|
||||
else:
|
||||
# First message — send new
|
||||
result = await self.adapter.send(
|
||||
chat_id=self.chat_id,
|
||||
content=text,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
if result.success and result.message_id:
|
||||
self._message_id = result.message_id
|
||||
self._already_sent = True
|
||||
else:
|
||||
# Initial send failed — disable streaming for this session
|
||||
self._edit_supported = False
|
||||
except Exception as e:
|
||||
logger.error("Stream send/edit error: %s", e)
|
||||
@@ -147,14 +147,6 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
api_key_env_vars=("MINIMAX_CN_API_KEY",),
|
||||
base_url_env_var="MINIMAX_CN_BASE_URL",
|
||||
),
|
||||
"deepseek": ProviderConfig(
|
||||
id="deepseek",
|
||||
name="DeepSeek",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://api.deepseek.com/v1",
|
||||
api_key_env_vars=("DEEPSEEK_API_KEY",),
|
||||
base_url_env_var="DEEPSEEK_BASE_URL",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -7,9 +7,7 @@ interactive CLI.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections.abc import Callable, Mapping
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from prompt_toolkit.completion import Completer, Completion
|
||||
@@ -28,7 +26,6 @@ COMMANDS_BY_CATEGORY = {
|
||||
"/title": "Set a title for the current session (usage: /title My Session Name)",
|
||||
"/compress": "Manually compress conversation context (flush memories + summarize)",
|
||||
"/rollback": "List or restore filesystem checkpoints (usage: /rollback [number])",
|
||||
"/stop": "Kill all running background processes",
|
||||
"/background": "Run a prompt in the background (usage: /background <prompt>)",
|
||||
},
|
||||
"Configuration": {
|
||||
@@ -48,8 +45,6 @@ COMMANDS_BY_CATEGORY = {
|
||||
"/skills": "Search, install, inspect, or manage skills from online registries",
|
||||
"/cron": "Manage scheduled tasks (list, add/create, edit, pause, resume, run, remove)",
|
||||
"/reload-mcp": "Reload MCP servers from config.yaml",
|
||||
"/browser": "Connect browser tools to your live Chrome (usage: /browser connect|disconnect|status)",
|
||||
"/plugins": "List installed plugins and their status",
|
||||
},
|
||||
"Info": {
|
||||
"/help": "Show this help message",
|
||||
@@ -97,88 +92,9 @@ class SlashCommandCompleter(Completer):
|
||||
"""
|
||||
return f"{cmd_name} " if cmd_name == word else cmd_name
|
||||
|
||||
@staticmethod
|
||||
def _extract_path_word(text: str) -> str | None:
|
||||
"""Extract the current word if it looks like a file path.
|
||||
|
||||
Returns the path-like token under the cursor, or None if the
|
||||
current word doesn't look like a path. A word is path-like when
|
||||
it starts with ``./``, ``../``, ``~/``, ``/``, or contains a
|
||||
``/`` separator (e.g. ``src/main.py``).
|
||||
"""
|
||||
if not text:
|
||||
return None
|
||||
# Walk backwards to find the start of the current "word".
|
||||
# Words are delimited by spaces, but paths can contain almost anything.
|
||||
i = len(text) - 1
|
||||
while i >= 0 and text[i] != " ":
|
||||
i -= 1
|
||||
word = text[i + 1:]
|
||||
if not word:
|
||||
return None
|
||||
# Only trigger path completion for path-like tokens
|
||||
if word.startswith(("./", "../", "~/", "/")) or "/" in word:
|
||||
return word
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _path_completions(word: str, limit: int = 30):
|
||||
"""Yield Completion objects for file paths matching *word*."""
|
||||
expanded = os.path.expanduser(word)
|
||||
# Split into directory part and prefix to match inside it
|
||||
if expanded.endswith("/"):
|
||||
search_dir = expanded
|
||||
prefix = ""
|
||||
else:
|
||||
search_dir = os.path.dirname(expanded) or "."
|
||||
prefix = os.path.basename(expanded)
|
||||
|
||||
try:
|
||||
entries = os.listdir(search_dir)
|
||||
except OSError:
|
||||
return
|
||||
|
||||
count = 0
|
||||
prefix_lower = prefix.lower()
|
||||
for entry in sorted(entries):
|
||||
if prefix and not entry.lower().startswith(prefix_lower):
|
||||
continue
|
||||
if count >= limit:
|
||||
break
|
||||
|
||||
full_path = os.path.join(search_dir, entry)
|
||||
is_dir = os.path.isdir(full_path)
|
||||
|
||||
# Build the completion text (what replaces the typed word)
|
||||
if word.startswith("~"):
|
||||
display_path = "~/" + os.path.relpath(full_path, os.path.expanduser("~"))
|
||||
elif os.path.isabs(word):
|
||||
display_path = full_path
|
||||
else:
|
||||
# Keep relative
|
||||
display_path = os.path.relpath(full_path)
|
||||
|
||||
if is_dir:
|
||||
display_path += "/"
|
||||
|
||||
suffix = "/" if is_dir else ""
|
||||
meta = "dir" if is_dir else _file_size_label(full_path)
|
||||
|
||||
yield Completion(
|
||||
display_path,
|
||||
start_position=-len(word),
|
||||
display=entry + suffix,
|
||||
display_meta=meta,
|
||||
)
|
||||
count += 1
|
||||
|
||||
def get_completions(self, document, complete_event):
|
||||
text = document.text_before_cursor
|
||||
if not text.startswith("/"):
|
||||
# Try file path completion for non-slash input
|
||||
path_word = self._extract_path_word(text)
|
||||
if path_word is not None:
|
||||
yield from self._path_completions(path_word)
|
||||
return
|
||||
|
||||
word = text[1:]
|
||||
@@ -204,18 +120,3 @@ class SlashCommandCompleter(Completer):
|
||||
display=cmd,
|
||||
display_meta=f"⚡ {short_desc}",
|
||||
)
|
||||
|
||||
|
||||
def _file_size_label(path: str) -> str:
|
||||
"""Return a compact human-readable file size, or '' on error."""
|
||||
try:
|
||||
size = os.path.getsize(path)
|
||||
except OSError:
|
||||
return ""
|
||||
if size < 1024:
|
||||
return f"{size}B"
|
||||
if size < 1024 * 1024:
|
||||
return f"{size / 1024:.0f}K"
|
||||
if size < 1024 * 1024 * 1024:
|
||||
return f"{size / (1024 * 1024):.1f}M"
|
||||
return f"{size / (1024 * 1024 * 1024):.1f}G"
|
||||
|
||||
+3
-79
@@ -106,6 +106,7 @@ DEFAULT_CONFIG = {
|
||||
"cwd": ".", # Use current directory
|
||||
"timeout": 180,
|
||||
"docker_image": "nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
"docker_forward_env": [],
|
||||
"singularity_image": "docker://nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
"modal_image": "nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
"daytona_image": "nikolaik/python-nodejs:python3.11-nodejs20",
|
||||
@@ -118,14 +119,6 @@ DEFAULT_CONFIG = {
|
||||
# Each entry is "host_path:container_path" (standard Docker -v syntax).
|
||||
# Example: ["/home/user/projects:/workspace/projects", "/data:/data"]
|
||||
"docker_volumes": [],
|
||||
# Explicit opt-in: mount the host cwd into /workspace for Docker sessions.
|
||||
# Default off because passing host directories into a sandbox weakens isolation.
|
||||
"docker_mount_cwd_to_workspace": False,
|
||||
# Persistent shell — keep a long-lived bash shell across execute() calls
|
||||
# so cwd/env vars/shell variables survive between commands.
|
||||
# Enabled by default for non-local backends (SSH); local is always opt-in
|
||||
# via TERMINAL_LOCAL_PERSISTENT env var.
|
||||
"persistent_shell": True,
|
||||
},
|
||||
|
||||
"browser": {
|
||||
@@ -137,7 +130,7 @@ DEFAULT_CONFIG = {
|
||||
# When enabled, the agent takes a snapshot of the working directory once per
|
||||
# conversation turn (on first write_file/patch call). Use /rollback to restore.
|
||||
"checkpoints": {
|
||||
"enabled": True,
|
||||
"enabled": False,
|
||||
"max_snapshots": 50, # Max checkpoints to keep per directory
|
||||
},
|
||||
|
||||
@@ -147,12 +140,6 @@ DEFAULT_CONFIG = {
|
||||
"summary_model": "google/gemini-3-flash-preview",
|
||||
"summary_provider": "auto",
|
||||
},
|
||||
"smart_model_routing": {
|
||||
"enabled": False,
|
||||
"max_simple_chars": 160,
|
||||
"max_simple_words": 28,
|
||||
"cheap_model": {},
|
||||
},
|
||||
|
||||
# Auxiliary model config — provider:model for each side task.
|
||||
# Format: provider is the provider name, model is the model slug.
|
||||
@@ -191,12 +178,6 @@ DEFAULT_CONFIG = {
|
||||
"base_url": "",
|
||||
"api_key": "",
|
||||
},
|
||||
"approval": {
|
||||
"provider": "auto",
|
||||
"model": "", # fast/cheap model recommended (e.g. gemini-flash, haiku)
|
||||
"base_url": "",
|
||||
"api_key": "",
|
||||
},
|
||||
"mcp": {
|
||||
"provider": "auto",
|
||||
"model": "",
|
||||
@@ -217,15 +198,8 @@ DEFAULT_CONFIG = {
|
||||
"resume_display": "full",
|
||||
"bell_on_complete": False,
|
||||
"show_reasoning": False,
|
||||
"streaming": False,
|
||||
"show_cost": False, # Show $ cost in the status bar (off by default)
|
||||
"skin": "default",
|
||||
},
|
||||
|
||||
# Privacy settings
|
||||
"privacy": {
|
||||
"redact_pii": False, # When True, hash user IDs and strip phone numbers from LLM context
|
||||
},
|
||||
|
||||
# Text-to-speech configuration
|
||||
"tts": {
|
||||
@@ -310,14 +284,6 @@ DEFAULT_CONFIG = {
|
||||
"auto_thread": True, # Auto-create threads on @mention in channels (like Slack)
|
||||
},
|
||||
|
||||
# Approval mode for dangerous commands:
|
||||
# manual — always prompt the user (default)
|
||||
# smart — use auxiliary LLM to auto-approve low-risk commands, prompt for high-risk
|
||||
# off — skip all approval prompts (equivalent to --yolo)
|
||||
"approvals": {
|
||||
"mode": "manual",
|
||||
},
|
||||
|
||||
# Permanently allowed dangerous command patterns (added via "always" approval)
|
||||
"command_allowlist": [],
|
||||
# User-defined quick commands that bypass the agent loop (type: exec only)
|
||||
@@ -337,7 +303,7 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
|
||||
# Config schema version - bump this when adding new required fields
|
||||
"_config_version": 8,
|
||||
"_config_version": 9,
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
@@ -459,20 +425,6 @@ OPTIONAL_ENV_VARS = {
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"DEEPSEEK_API_KEY": {
|
||||
"description": "DeepSeek API key for direct DeepSeek access",
|
||||
"prompt": "DeepSeek API Key",
|
||||
"url": "https://platform.deepseek.com/api_keys",
|
||||
"password": True,
|
||||
"category": "provider",
|
||||
},
|
||||
"DEEPSEEK_BASE_URL": {
|
||||
"description": "Custom DeepSeek API base URL (advanced)",
|
||||
"prompt": "DeepSeek Base URL",
|
||||
"url": "",
|
||||
"password": False,
|
||||
"category": "provider",
|
||||
},
|
||||
|
||||
# ── Tool API keys ──
|
||||
"FIRECRAWL_API_KEY": {
|
||||
@@ -1017,19 +969,6 @@ _FALLBACK_COMMENT = """
|
||||
# fallback_model:
|
||||
# provider: openrouter
|
||||
# model: anthropic/claude-sonnet-4
|
||||
#
|
||||
# ── Smart Model Routing ────────────────────────────────────────────────
|
||||
# Optional cheap-vs-strong routing for simple turns.
|
||||
# Keeps the primary model for complex work, but can route short/simple
|
||||
# messages to a cheaper model across providers.
|
||||
#
|
||||
# smart_model_routing:
|
||||
# enabled: true
|
||||
# max_simple_chars: 160
|
||||
# max_simple_words: 28
|
||||
# cheap_model:
|
||||
# provider: openrouter
|
||||
# model: google/gemini-2.5-flash
|
||||
"""
|
||||
|
||||
|
||||
@@ -1060,19 +999,6 @@ _COMMENTED_SECTIONS = """
|
||||
# fallback_model:
|
||||
# provider: openrouter
|
||||
# model: anthropic/claude-sonnet-4
|
||||
#
|
||||
# ── Smart Model Routing ────────────────────────────────────────────────
|
||||
# Optional cheap-vs-strong routing for simple turns.
|
||||
# Keeps the primary model for complex work, but can route short/simple
|
||||
# messages to a cheaper model across providers.
|
||||
#
|
||||
# smart_model_routing:
|
||||
# enabled: true
|
||||
# max_simple_chars: 160
|
||||
# max_simple_words: 28
|
||||
# cheap_model:
|
||||
# provider: openrouter
|
||||
# model: google/gemini-2.5-flash
|
||||
"""
|
||||
|
||||
|
||||
@@ -1463,11 +1389,9 @@ def set_config_value(key: str, value: str):
|
||||
"terminal.singularity_image": "TERMINAL_SINGULARITY_IMAGE",
|
||||
"terminal.modal_image": "TERMINAL_MODAL_IMAGE",
|
||||
"terminal.daytona_image": "TERMINAL_DAYTONA_IMAGE",
|
||||
"terminal.docker_mount_cwd_to_workspace": "TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE",
|
||||
"terminal.cwd": "TERMINAL_CWD",
|
||||
"terminal.timeout": "TERMINAL_TIMEOUT",
|
||||
"terminal.sandbox_dir": "TERMINAL_SANDBOX_DIR",
|
||||
"terminal.persistent_shell": "TERMINAL_PERSISTENT_SHELL",
|
||||
}
|
||||
if key in _config_to_env_sync:
|
||||
save_env_value(_config_to_env_sync[key], str(value))
|
||||
|
||||
+22
-76
@@ -119,35 +119,14 @@ def is_windows() -> bool:
|
||||
# Service Configuration
|
||||
# =============================================================================
|
||||
|
||||
_SERVICE_BASE = "hermes-gateway"
|
||||
SERVICE_NAME = "hermes-gateway"
|
||||
SERVICE_DESCRIPTION = "Hermes Agent Gateway - Messaging Platform Integration"
|
||||
|
||||
|
||||
def get_service_name() -> str:
|
||||
"""Derive a systemd service name scoped to this HERMES_HOME.
|
||||
|
||||
Default ``~/.hermes`` returns ``hermes-gateway`` (backward compatible).
|
||||
Any other HERMES_HOME appends a short hash so multiple installations
|
||||
can each have their own systemd service without conflicting.
|
||||
"""
|
||||
import hashlib
|
||||
from pathlib import Path as _Path # local import to avoid monkeypatch interference
|
||||
home = _Path(os.getenv("HERMES_HOME", _Path.home() / ".hermes")).resolve()
|
||||
default = (_Path.home() / ".hermes").resolve()
|
||||
if home == default:
|
||||
return _SERVICE_BASE
|
||||
suffix = hashlib.sha256(str(home).encode()).hexdigest()[:8]
|
||||
return f"{_SERVICE_BASE}-{suffix}"
|
||||
|
||||
|
||||
SERVICE_NAME = _SERVICE_BASE # backward-compat for external importers; prefer get_service_name()
|
||||
|
||||
|
||||
def get_systemd_unit_path(system: bool = False) -> Path:
|
||||
name = get_service_name()
|
||||
if system:
|
||||
return Path("/etc/systemd/system") / f"{name}.service"
|
||||
return Path.home() / ".config" / "systemd" / "user" / f"{name}.service"
|
||||
return Path("/etc/systemd/system") / f"{SERVICE_NAME}.service"
|
||||
return Path.home() / ".config" / "systemd" / "user" / f"{SERVICE_NAME}.service"
|
||||
|
||||
|
||||
def _systemctl_cmd(system: bool = False) -> list[str]:
|
||||
@@ -371,6 +350,8 @@ def get_hermes_cli_path() -> str:
|
||||
# =============================================================================
|
||||
|
||||
def generate_systemd_unit(system: bool = False, run_as_user: str | None = None) -> str:
|
||||
import shutil
|
||||
|
||||
python_path = get_python_path()
|
||||
working_dir = str(PROJECT_ROOT)
|
||||
venv_dir = str(PROJECT_ROOT / "venv")
|
||||
@@ -379,8 +360,7 @@ def generate_systemd_unit(system: bool = False, run_as_user: str | None = None)
|
||||
|
||||
# Build a PATH that includes the venv, node_modules, and standard system dirs
|
||||
sane_path = f"{venv_bin}:{node_bin}:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"
|
||||
|
||||
hermes_home = str(Path(os.getenv("HERMES_HOME", Path.home() / ".hermes")).resolve())
|
||||
hermes_cli = shutil.which("hermes") or f"{python_path} -m hermes_cli.main"
|
||||
|
||||
if system:
|
||||
username, group_name, home_dir = _system_service_identity(run_as_user)
|
||||
@@ -400,12 +380,11 @@ Environment="USER={username}"
|
||||
Environment="LOGNAME={username}"
|
||||
Environment="PATH={sane_path}"
|
||||
Environment="VIRTUAL_ENV={venv_dir}"
|
||||
Environment="HERMES_HOME={hermes_home}"
|
||||
Restart=on-failure
|
||||
RestartSec=10
|
||||
KillMode=mixed
|
||||
KillSignal=SIGTERM
|
||||
TimeoutStopSec=60
|
||||
TimeoutStopSec=15
|
||||
StandardOutput=journal
|
||||
StandardError=journal
|
||||
|
||||
@@ -420,15 +399,15 @@ After=network.target
|
||||
[Service]
|
||||
Type=simple
|
||||
ExecStart={python_path} -m hermes_cli.main gateway run --replace
|
||||
ExecStop={hermes_cli} gateway stop
|
||||
WorkingDirectory={working_dir}
|
||||
Environment="PATH={sane_path}"
|
||||
Environment="VIRTUAL_ENV={venv_dir}"
|
||||
Environment="HERMES_HOME={hermes_home}"
|
||||
Restart=on-failure
|
||||
RestartSec=10
|
||||
KillMode=mixed
|
||||
KillSignal=SIGTERM
|
||||
TimeoutStopSec=60
|
||||
TimeoutStopSec=15
|
||||
StandardOutput=journal
|
||||
StandardError=journal
|
||||
|
||||
@@ -476,7 +455,7 @@ def _print_linger_enable_warning(username: str, detail: str | None = None) -> No
|
||||
print(f" sudo loginctl enable-linger {username}")
|
||||
print()
|
||||
print(" Then restart the gateway:")
|
||||
print(f" systemctl --user restart {get_service_name()}.service")
|
||||
print(f" systemctl --user restart {SERVICE_NAME}.service")
|
||||
print()
|
||||
|
||||
|
||||
@@ -547,7 +526,7 @@ def systemd_install(force: bool = False, system: bool = False, run_as_user: str
|
||||
unit_path.write_text(generate_systemd_unit(system=system, run_as_user=run_as_user), encoding="utf-8")
|
||||
|
||||
subprocess.run(_systemctl_cmd(system) + ["daemon-reload"], check=True)
|
||||
subprocess.run(_systemctl_cmd(system) + ["enable", get_service_name()], check=True)
|
||||
subprocess.run(_systemctl_cmd(system) + ["enable", SERVICE_NAME], check=True)
|
||||
|
||||
print()
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service installed and enabled!")
|
||||
@@ -555,7 +534,7 @@ def systemd_install(force: bool = False, system: bool = False, run_as_user: str
|
||||
print("Next steps:")
|
||||
print(f" {'sudo ' if system else ''}hermes gateway start{scope_flag} # Start the service")
|
||||
print(f" {'sudo ' if system else ''}hermes gateway status{scope_flag} # Check status")
|
||||
print(f" {'journalctl' if system else 'journalctl --user'} -u {get_service_name()} -f # View logs")
|
||||
print(f" {'journalctl' if system else 'journalctl --user'} -u {SERVICE_NAME} -f # View logs")
|
||||
print()
|
||||
|
||||
if system:
|
||||
@@ -573,8 +552,8 @@ def systemd_uninstall(system: bool = False):
|
||||
if system:
|
||||
_require_root_for_system_service("uninstall")
|
||||
|
||||
subprocess.run(_systemctl_cmd(system) + ["stop", get_service_name()], check=False)
|
||||
subprocess.run(_systemctl_cmd(system) + ["disable", get_service_name()], check=False)
|
||||
subprocess.run(_systemctl_cmd(system) + ["stop", SERVICE_NAME], check=False)
|
||||
subprocess.run(_systemctl_cmd(system) + ["disable", SERVICE_NAME], check=False)
|
||||
|
||||
unit_path = get_systemd_unit_path(system=system)
|
||||
if unit_path.exists():
|
||||
@@ -590,7 +569,7 @@ def systemd_start(system: bool = False):
|
||||
if system:
|
||||
_require_root_for_system_service("start")
|
||||
refresh_systemd_unit_if_needed(system=system)
|
||||
subprocess.run(_systemctl_cmd(system) + ["start", get_service_name()], check=True)
|
||||
subprocess.run(_systemctl_cmd(system) + ["start", SERVICE_NAME], check=True)
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service started")
|
||||
|
||||
|
||||
@@ -599,7 +578,7 @@ def systemd_stop(system: bool = False):
|
||||
system = _select_systemd_scope(system)
|
||||
if system:
|
||||
_require_root_for_system_service("stop")
|
||||
subprocess.run(_systemctl_cmd(system) + ["stop", get_service_name()], check=True)
|
||||
subprocess.run(_systemctl_cmd(system) + ["stop", SERVICE_NAME], check=True)
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service stopped")
|
||||
|
||||
|
||||
@@ -609,7 +588,7 @@ def systemd_restart(system: bool = False):
|
||||
if system:
|
||||
_require_root_for_system_service("restart")
|
||||
refresh_systemd_unit_if_needed(system=system)
|
||||
subprocess.run(_systemctl_cmd(system) + ["restart", get_service_name()], check=True)
|
||||
subprocess.run(_systemctl_cmd(system) + ["restart", SERVICE_NAME], check=True)
|
||||
print(f"✓ {_service_scope_label(system).capitalize()} service restarted")
|
||||
|
||||
|
||||
@@ -634,12 +613,12 @@ def systemd_status(deep: bool = False, system: bool = False):
|
||||
print()
|
||||
|
||||
subprocess.run(
|
||||
_systemctl_cmd(system) + ["status", get_service_name(), "--no-pager"],
|
||||
_systemctl_cmd(system) + ["status", SERVICE_NAME, "--no-pager"],
|
||||
capture_output=False,
|
||||
)
|
||||
|
||||
result = subprocess.run(
|
||||
_systemctl_cmd(system) + ["is-active", get_service_name()],
|
||||
_systemctl_cmd(system) + ["is-active", SERVICE_NAME],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
@@ -678,7 +657,7 @@ def systemd_status(deep: bool = False, system: bool = False):
|
||||
if deep:
|
||||
print()
|
||||
print("Recent logs:")
|
||||
subprocess.run(_journalctl_cmd(system) + ["-u", get_service_name(), "-n", "20", "--no-pager"])
|
||||
subprocess.run(_journalctl_cmd(system) + ["-u", SERVICE_NAME, "-n", "20", "--no-pager"])
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -705,7 +684,6 @@ def generate_launchd_plist() -> str:
|
||||
<string>hermes_cli.main</string>
|
||||
<string>gateway</string>
|
||||
<string>run</string>
|
||||
<string>--replace</string>
|
||||
</array>
|
||||
|
||||
<key>WorkingDirectory</key>
|
||||
@@ -729,36 +707,6 @@ def generate_launchd_plist() -> str:
|
||||
</plist>
|
||||
"""
|
||||
|
||||
def launchd_plist_is_current() -> bool:
|
||||
"""Check if the installed launchd plist matches the currently generated one."""
|
||||
plist_path = get_launchd_plist_path()
|
||||
if not plist_path.exists():
|
||||
return False
|
||||
|
||||
installed = plist_path.read_text(encoding="utf-8")
|
||||
expected = generate_launchd_plist()
|
||||
return _normalize_service_definition(installed) == _normalize_service_definition(expected)
|
||||
|
||||
|
||||
def refresh_launchd_plist_if_needed() -> bool:
|
||||
"""Rewrite the installed launchd plist when the generated definition has changed.
|
||||
|
||||
Unlike systemd, launchd picks up plist changes on the next ``launchctl stop``/
|
||||
``launchctl start`` cycle — no daemon-reload is needed. We still unload/reload
|
||||
to make launchd re-read the updated plist immediately.
|
||||
"""
|
||||
plist_path = get_launchd_plist_path()
|
||||
if not plist_path.exists() or launchd_plist_is_current():
|
||||
return False
|
||||
|
||||
plist_path.write_text(generate_launchd_plist(), encoding="utf-8")
|
||||
# Unload/reload so launchd picks up the new definition
|
||||
subprocess.run(["launchctl", "unload", str(plist_path)], check=False)
|
||||
subprocess.run(["launchctl", "load", str(plist_path)], check=False)
|
||||
print("↻ Updated gateway launchd service definition to match the current Hermes install")
|
||||
return True
|
||||
|
||||
|
||||
def launchd_install(force: bool = False):
|
||||
plist_path = get_launchd_plist_path()
|
||||
|
||||
@@ -791,7 +739,6 @@ def launchd_uninstall():
|
||||
print("✓ Service uninstalled")
|
||||
|
||||
def launchd_start():
|
||||
refresh_launchd_plist_if_needed()
|
||||
subprocess.run(["launchctl", "start", "ai.hermes.gateway"], check=True)
|
||||
print("✓ Service started")
|
||||
|
||||
@@ -800,7 +747,6 @@ def launchd_stop():
|
||||
print("✓ Service stopped")
|
||||
|
||||
def launchd_restart():
|
||||
refresh_launchd_plist_if_needed()
|
||||
launchd_stop()
|
||||
launchd_start()
|
||||
|
||||
@@ -1172,7 +1118,7 @@ def _is_service_running() -> bool:
|
||||
|
||||
if user_unit_exists:
|
||||
result = subprocess.run(
|
||||
_systemctl_cmd(False) + ["is-active", get_service_name()],
|
||||
_systemctl_cmd(False) + ["is-active", SERVICE_NAME],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.stdout.strip() == "active":
|
||||
@@ -1180,7 +1126,7 @@ def _is_service_running() -> bool:
|
||||
|
||||
if system_unit_exists:
|
||||
result = subprocess.run(
|
||||
_systemctl_cmd(True) + ["is-active", get_service_name()],
|
||||
_systemctl_cmd(True) + ["is-active", SERVICE_NAME],
|
||||
capture_output=True, text=True
|
||||
)
|
||||
if result.stdout.strip() == "active":
|
||||
|
||||
+18
-122
@@ -1112,32 +1112,8 @@ def _model_flow_custom(config):
|
||||
|
||||
effective_key = api_key or current_key
|
||||
|
||||
from hermes_cli.models import probe_api_models
|
||||
|
||||
probe = probe_api_models(effective_key, effective_url)
|
||||
if probe.get("used_fallback") and probe.get("resolved_base_url"):
|
||||
print(
|
||||
f"Warning: endpoint verification worked at {probe['resolved_base_url']}/models, "
|
||||
f"not the exact URL you entered. Saving the working base URL instead."
|
||||
)
|
||||
effective_url = probe["resolved_base_url"]
|
||||
if base_url:
|
||||
base_url = effective_url
|
||||
elif probe.get("models") is not None:
|
||||
print(
|
||||
f"Verified endpoint via {probe.get('probed_url')} "
|
||||
f"({len(probe.get('models') or [])} model(s) visible)"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Warning: could not verify this endpoint via {probe.get('probed_url')}. "
|
||||
f"Hermes will still save it."
|
||||
)
|
||||
if probe.get("suggested_base_url"):
|
||||
print(f" If this server expects /v1, try base URL: {probe['suggested_base_url']}")
|
||||
|
||||
if base_url:
|
||||
save_env_value("OPENAI_BASE_URL", effective_url)
|
||||
save_env_value("OPENAI_BASE_URL", base_url)
|
||||
if api_key:
|
||||
save_env_value("OPENAI_API_KEY", api_key)
|
||||
|
||||
@@ -2301,106 +2277,26 @@ def cmd_update(args):
|
||||
print()
|
||||
print("✓ Update complete!")
|
||||
|
||||
# Auto-restart gateway if it's running.
|
||||
# Uses the PID file (scoped to HERMES_HOME) to find this
|
||||
# installation's gateway — safe with multiple installations.
|
||||
# Auto-restart gateway if it's running as a systemd service
|
||||
try:
|
||||
from gateway.status import get_running_pid, remove_pid_file
|
||||
from hermes_cli.gateway import (
|
||||
get_service_name, get_launchd_plist_path, is_macos,
|
||||
refresh_launchd_plist_if_needed,
|
||||
check = subprocess.run(
|
||||
["systemctl", "--user", "is-active", "hermes-gateway"],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
import signal as _signal
|
||||
|
||||
_gw_service_name = get_service_name()
|
||||
existing_pid = get_running_pid()
|
||||
has_systemd_service = False
|
||||
has_launchd_service = False
|
||||
|
||||
try:
|
||||
check = subprocess.run(
|
||||
["systemctl", "--user", "is-active", _gw_service_name],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
has_systemd_service = check.stdout.strip() == "active"
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
pass
|
||||
|
||||
# Check for macOS launchd service
|
||||
if is_macos():
|
||||
try:
|
||||
plist_path = get_launchd_plist_path()
|
||||
if plist_path.exists():
|
||||
check = subprocess.run(
|
||||
["launchctl", "list", "ai.hermes.gateway"],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
has_launchd_service = check.returncode == 0
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
pass
|
||||
|
||||
if existing_pid or has_systemd_service or has_launchd_service:
|
||||
if check.stdout.strip() == "active":
|
||||
print()
|
||||
|
||||
# When a service manager is handling the gateway, let it
|
||||
# manage the lifecycle — don't manually SIGTERM the PID
|
||||
# (launchd KeepAlive would respawn immediately, causing races).
|
||||
if has_systemd_service:
|
||||
import time as _time
|
||||
if existing_pid:
|
||||
try:
|
||||
os.kill(existing_pid, _signal.SIGTERM)
|
||||
print(f"→ Stopped gateway process (PID {existing_pid})")
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
except PermissionError:
|
||||
print(f"⚠ Permission denied killing gateway PID {existing_pid}")
|
||||
remove_pid_file()
|
||||
_time.sleep(1) # Brief pause for port/socket release
|
||||
print("→ Restarting gateway service...")
|
||||
restart = subprocess.run(
|
||||
["systemctl", "--user", "restart", _gw_service_name],
|
||||
capture_output=True, text=True, timeout=15,
|
||||
)
|
||||
if restart.returncode == 0:
|
||||
print("✓ Gateway restarted.")
|
||||
else:
|
||||
print(f"⚠ Gateway restart failed: {restart.stderr.strip()}")
|
||||
print(" Try manually: hermes gateway restart")
|
||||
elif has_launchd_service:
|
||||
# Refresh the plist first (picks up --replace and other
|
||||
# changes from the update we just pulled).
|
||||
refresh_launchd_plist_if_needed()
|
||||
# Explicit stop+start — don't rely on KeepAlive respawn
|
||||
# after a manual SIGTERM, which would race with the
|
||||
# PID file cleanup.
|
||||
print("→ Restarting gateway service...")
|
||||
stop = subprocess.run(
|
||||
["launchctl", "stop", "ai.hermes.gateway"],
|
||||
capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
start = subprocess.run(
|
||||
["launchctl", "start", "ai.hermes.gateway"],
|
||||
capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
if start.returncode == 0:
|
||||
print("✓ Gateway restarted via launchd.")
|
||||
else:
|
||||
print(f"⚠ Gateway restart failed: {start.stderr.strip()}")
|
||||
print(" Try manually: hermes gateway restart")
|
||||
elif existing_pid:
|
||||
try:
|
||||
os.kill(existing_pid, _signal.SIGTERM)
|
||||
print(f"→ Stopped gateway process (PID {existing_pid})")
|
||||
except ProcessLookupError:
|
||||
pass # Already gone
|
||||
except PermissionError:
|
||||
print(f"⚠ Permission denied killing gateway PID {existing_pid}")
|
||||
remove_pid_file()
|
||||
print(" ℹ️ Gateway was running manually (not as a service).")
|
||||
print(" Restart it with: hermes gateway run")
|
||||
except Exception as e:
|
||||
logger.debug("Gateway restart during update failed: %s", e)
|
||||
print("→ Gateway service is running — restarting to pick up changes...")
|
||||
restart = subprocess.run(
|
||||
["systemctl", "--user", "restart", "hermes-gateway"],
|
||||
capture_output=True, text=True, timeout=15,
|
||||
)
|
||||
if restart.returncode == 0:
|
||||
print("✓ Gateway restarted.")
|
||||
else:
|
||||
print(f"⚠ Gateway restart failed: {restart.stderr.strip()}")
|
||||
print(" Try manually: hermes gateway restart")
|
||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||
pass # No systemd (macOS, WSL1, etc.) — skip silently
|
||||
|
||||
print()
|
||||
print("Tip: You can now select a provider and model:")
|
||||
|
||||
+19
-211
@@ -78,10 +78,6 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"claude-sonnet-4-20250514",
|
||||
"claude-haiku-4-5-20251001",
|
||||
],
|
||||
"deepseek": [
|
||||
"deepseek-chat",
|
||||
"deepseek-reasoner",
|
||||
],
|
||||
}
|
||||
|
||||
_PROVIDER_LABELS = {
|
||||
@@ -93,7 +89,6 @@ _PROVIDER_LABELS = {
|
||||
"minimax": "MiniMax",
|
||||
"minimax-cn": "MiniMax (China)",
|
||||
"anthropic": "Anthropic",
|
||||
"deepseek": "DeepSeek",
|
||||
"custom": "Custom endpoint",
|
||||
}
|
||||
|
||||
@@ -108,7 +103,6 @@ _PROVIDER_ALIASES = {
|
||||
"minimax_cn": "minimax-cn",
|
||||
"claude": "anthropic",
|
||||
"claude-code": "anthropic",
|
||||
"deep-seek": "deepseek",
|
||||
}
|
||||
|
||||
|
||||
@@ -142,7 +136,7 @@ def list_available_providers() -> list[dict[str, str]]:
|
||||
# Canonical providers in display order
|
||||
_PROVIDER_ORDER = [
|
||||
"openrouter", "nous", "openai-codex",
|
||||
"zai", "kimi-coding", "minimax", "minimax-cn", "anthropic", "deepseek",
|
||||
"zai", "kimi-coding", "minimax", "minimax-cn", "anthropic",
|
||||
]
|
||||
# Build reverse alias map
|
||||
aliases_for: dict[str, list[str]] = {}
|
||||
@@ -218,111 +212,6 @@ def curated_models_for_provider(provider: Optional[str]) -> list[tuple[str, str]
|
||||
return [(m, "") for m in models]
|
||||
|
||||
|
||||
def detect_provider_for_model(
|
||||
model_name: str,
|
||||
current_provider: str,
|
||||
) -> Optional[tuple[str, str]]:
|
||||
"""Auto-detect the best provider for a model name.
|
||||
|
||||
Returns ``(provider_id, model_name)`` — the model name may be remapped
|
||||
(e.g. bare ``deepseek-chat`` → ``deepseek/deepseek-chat`` for OpenRouter).
|
||||
Returns ``None`` when no confident match is found.
|
||||
|
||||
Priority:
|
||||
1. Direct provider with credentials (highest)
|
||||
2. Direct provider without credentials → remap to OpenRouter slug
|
||||
3. OpenRouter catalog match
|
||||
"""
|
||||
name = (model_name or "").strip()
|
||||
if not name:
|
||||
return None
|
||||
|
||||
name_lower = name.lower()
|
||||
|
||||
# Aggregators list other providers' models — never auto-switch TO them
|
||||
_AGGREGATORS = {"nous", "openrouter"}
|
||||
|
||||
# If the model belongs to the current provider's catalog, don't suggest switching
|
||||
current_models = _PROVIDER_MODELS.get(current_provider, [])
|
||||
if any(name_lower == m.lower() for m in current_models):
|
||||
return None
|
||||
|
||||
# --- Step 1: check static provider catalogs for a direct match ---
|
||||
direct_match: Optional[str] = None
|
||||
for pid, models in _PROVIDER_MODELS.items():
|
||||
if pid == current_provider or pid in _AGGREGATORS:
|
||||
continue
|
||||
if any(name_lower == m.lower() for m in models):
|
||||
direct_match = pid
|
||||
break
|
||||
|
||||
if direct_match:
|
||||
# Check if we have credentials for this provider
|
||||
has_creds = False
|
||||
try:
|
||||
from hermes_cli.auth import PROVIDER_REGISTRY
|
||||
pconfig = PROVIDER_REGISTRY.get(direct_match)
|
||||
if pconfig:
|
||||
import os
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
if os.getenv(env_var, "").strip():
|
||||
has_creds = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if has_creds:
|
||||
return (direct_match, name)
|
||||
|
||||
# No direct creds — try to find this model on OpenRouter instead
|
||||
or_slug = _find_openrouter_slug(name)
|
||||
if or_slug:
|
||||
return ("openrouter", or_slug)
|
||||
# Still return the direct provider — credential resolution will
|
||||
# give a clear error rather than silently using the wrong provider
|
||||
return (direct_match, name)
|
||||
|
||||
# --- Step 2: check OpenRouter catalog ---
|
||||
# First try exact match (handles provider/model format)
|
||||
or_slug = _find_openrouter_slug(name)
|
||||
if or_slug:
|
||||
if current_provider != "openrouter":
|
||||
return ("openrouter", or_slug)
|
||||
# Already on openrouter, just return the resolved slug
|
||||
if or_slug != name:
|
||||
return ("openrouter", or_slug)
|
||||
return None # already on openrouter with matching name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _find_openrouter_slug(model_name: str) -> Optional[str]:
|
||||
"""Find the full OpenRouter model slug for a bare or partial model name.
|
||||
|
||||
Handles:
|
||||
- Exact match: ``anthropic/claude-opus-4.6`` → as-is
|
||||
- Bare name: ``deepseek-chat`` → ``deepseek/deepseek-chat``
|
||||
- Bare name: ``claude-opus-4.6`` → ``anthropic/claude-opus-4.6``
|
||||
"""
|
||||
name_lower = model_name.strip().lower()
|
||||
if not name_lower:
|
||||
return None
|
||||
|
||||
# Exact match (already has provider/ prefix)
|
||||
for mid, _ in OPENROUTER_MODELS:
|
||||
if name_lower == mid.lower():
|
||||
return mid
|
||||
|
||||
# Try matching just the model part (after the /)
|
||||
for mid, _ in OPENROUTER_MODELS:
|
||||
if "/" in mid:
|
||||
_, model_part = mid.split("/", 1)
|
||||
if name_lower == model_part.lower():
|
||||
return mid
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def normalize_provider(provider: Optional[str]) -> str:
|
||||
"""Normalize provider aliases to Hermes' canonical provider ids.
|
||||
|
||||
@@ -419,62 +308,6 @@ def _fetch_anthropic_models(timeout: float = 5.0) -> Optional[list[str]]:
|
||||
return None
|
||||
|
||||
|
||||
def probe_api_models(
|
||||
api_key: Optional[str],
|
||||
base_url: Optional[str],
|
||||
timeout: float = 5.0,
|
||||
) -> dict[str, Any]:
|
||||
"""Probe an OpenAI-compatible ``/models`` endpoint with light URL heuristics."""
|
||||
normalized = (base_url or "").strip().rstrip("/")
|
||||
if not normalized:
|
||||
return {
|
||||
"models": None,
|
||||
"probed_url": None,
|
||||
"resolved_base_url": "",
|
||||
"suggested_base_url": None,
|
||||
"used_fallback": False,
|
||||
}
|
||||
|
||||
if normalized.endswith("/v1"):
|
||||
alternate_base = normalized[:-3].rstrip("/")
|
||||
else:
|
||||
alternate_base = normalized + "/v1"
|
||||
|
||||
candidates: list[tuple[str, bool]] = [(normalized, False)]
|
||||
if alternate_base and alternate_base != normalized:
|
||||
candidates.append((alternate_base, True))
|
||||
|
||||
tried: list[str] = []
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
for candidate_base, is_fallback in candidates:
|
||||
url = candidate_base.rstrip("/") + "/models"
|
||||
tried.append(url)
|
||||
req = urllib.request.Request(url, headers=headers)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
data = json.loads(resp.read().decode())
|
||||
return {
|
||||
"models": [m.get("id", "") for m in data.get("data", [])],
|
||||
"probed_url": url,
|
||||
"resolved_base_url": candidate_base.rstrip("/"),
|
||||
"suggested_base_url": alternate_base if alternate_base != candidate_base else normalized,
|
||||
"used_fallback": is_fallback,
|
||||
}
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return {
|
||||
"models": None,
|
||||
"probed_url": tried[-1] if tried else normalized.rstrip("/") + "/models",
|
||||
"resolved_base_url": normalized,
|
||||
"suggested_base_url": alternate_base if alternate_base != normalized else None,
|
||||
"used_fallback": False,
|
||||
}
|
||||
|
||||
|
||||
def fetch_api_models(
|
||||
api_key: Optional[str],
|
||||
base_url: Optional[str],
|
||||
@@ -485,7 +318,22 @@ def fetch_api_models(
|
||||
Returns a list of model ID strings, or ``None`` if the endpoint could not
|
||||
be reached (network error, timeout, auth failure, etc.).
|
||||
"""
|
||||
return probe_api_models(api_key, base_url, timeout=timeout).get("models")
|
||||
if not base_url:
|
||||
return None
|
||||
|
||||
url = base_url.rstrip("/") + "/models"
|
||||
headers: dict[str, str] = {}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
req = urllib.request.Request(url, headers=headers)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
data = json.loads(resp.read().decode())
|
||||
# Standard OpenAI format: {"data": [{"id": "model-name", ...}, ...]}
|
||||
return [m.get("id", "") for m in data.get("data", [])]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def validate_requested_model(
|
||||
@@ -528,53 +376,13 @@ def validate_requested_model(
|
||||
"message": "Model names cannot contain spaces.",
|
||||
}
|
||||
|
||||
# Custom endpoints can serve any model — skip validation
|
||||
if normalized == "custom":
|
||||
probe = probe_api_models(api_key, base_url)
|
||||
api_models = probe.get("models")
|
||||
if api_models is not None:
|
||||
if requested in set(api_models):
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"recognized": True,
|
||||
"message": None,
|
||||
}
|
||||
|
||||
suggestions = get_close_matches(requested, api_models, n=3, cutoff=0.5)
|
||||
suggestion_text = ""
|
||||
if suggestions:
|
||||
suggestion_text = "\n Similar models: " + ", ".join(f"`{s}`" for s in suggestions)
|
||||
|
||||
message = (
|
||||
f"Note: `{requested}` was not found in this custom endpoint's model listing "
|
||||
f"({probe.get('probed_url')}). It may still work if the server supports hidden or aliased models."
|
||||
f"{suggestion_text}"
|
||||
)
|
||||
if probe.get("used_fallback"):
|
||||
message += (
|
||||
f"\n Endpoint verification succeeded after trying `{probe.get('resolved_base_url')}`. "
|
||||
f"Consider saving that as your base URL."
|
||||
)
|
||||
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"recognized": False,
|
||||
"message": message,
|
||||
}
|
||||
|
||||
message = (
|
||||
f"Note: could not reach this custom endpoint's model listing at `{probe.get('probed_url')}`. "
|
||||
f"Hermes will still save `{requested}`, but the endpoint should expose `/models` for verification."
|
||||
)
|
||||
if probe.get("suggested_base_url"):
|
||||
message += f"\n If this server expects `/v1`, try base URL: `{probe.get('suggested_base_url')}`"
|
||||
|
||||
return {
|
||||
"accepted": True,
|
||||
"persist": True,
|
||||
"recognized": False,
|
||||
"message": message,
|
||||
"message": None,
|
||||
}
|
||||
|
||||
# Probe the live API to check if the model actually exists
|
||||
|
||||
@@ -1,449 +0,0 @@
|
||||
"""
|
||||
Hermes Plugin System
|
||||
====================
|
||||
|
||||
Discovers, loads, and manages plugins from three sources:
|
||||
|
||||
1. **User plugins** – ``~/.hermes/plugins/<name>/``
|
||||
2. **Project plugins** – ``./.hermes/plugins/<name>/``
|
||||
3. **Pip plugins** – packages that expose the ``hermes_agent.plugins``
|
||||
entry-point group.
|
||||
|
||||
Each directory plugin must contain a ``plugin.yaml`` manifest **and** an
|
||||
``__init__.py`` with a ``register(ctx)`` function.
|
||||
|
||||
Lifecycle hooks
|
||||
---------------
|
||||
Plugins may register callbacks for any of the hooks in ``VALID_HOOKS``.
|
||||
The agent core calls ``invoke_hook(name, **kwargs)`` at the appropriate
|
||||
points.
|
||||
|
||||
Tool registration
|
||||
-----------------
|
||||
``PluginContext.register_tool()`` delegates to ``tools.registry.register()``
|
||||
so plugin-defined tools appear alongside the built-in tools.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
try:
|
||||
import yaml
|
||||
except ImportError: # pragma: no cover – yaml is optional at import time
|
||||
yaml = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
VALID_HOOKS: Set[str] = {
|
||||
"pre_tool_call",
|
||||
"post_tool_call",
|
||||
"pre_llm_call",
|
||||
"post_llm_call",
|
||||
"on_session_start",
|
||||
"on_session_end",
|
||||
}
|
||||
|
||||
ENTRY_POINTS_GROUP = "hermes_agent.plugins"
|
||||
|
||||
_NS_PARENT = "hermes_plugins"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class PluginManifest:
|
||||
"""Parsed representation of a plugin.yaml manifest."""
|
||||
|
||||
name: str
|
||||
version: str = ""
|
||||
description: str = ""
|
||||
author: str = ""
|
||||
requires_env: List[str] = field(default_factory=list)
|
||||
provides_tools: List[str] = field(default_factory=list)
|
||||
provides_hooks: List[str] = field(default_factory=list)
|
||||
source: str = "" # "user", "project", or "entrypoint"
|
||||
path: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedPlugin:
|
||||
"""Runtime state for a single loaded plugin."""
|
||||
|
||||
manifest: PluginManifest
|
||||
module: Optional[types.ModuleType] = None
|
||||
tools_registered: List[str] = field(default_factory=list)
|
||||
hooks_registered: List[str] = field(default_factory=list)
|
||||
enabled: bool = False
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PluginContext – handed to each plugin's ``register()`` function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class PluginContext:
|
||||
"""Facade given to plugins so they can register tools and hooks."""
|
||||
|
||||
def __init__(self, manifest: PluginManifest, manager: "PluginManager"):
|
||||
self.manifest = manifest
|
||||
self._manager = manager
|
||||
|
||||
# -- tool registration --------------------------------------------------
|
||||
|
||||
def register_tool(
|
||||
self,
|
||||
name: str,
|
||||
toolset: str,
|
||||
schema: dict,
|
||||
handler: Callable,
|
||||
check_fn: Callable | None = None,
|
||||
requires_env: list | None = None,
|
||||
is_async: bool = False,
|
||||
description: str = "",
|
||||
emoji: str = "",
|
||||
) -> None:
|
||||
"""Register a tool in the global registry **and** track it as plugin-provided."""
|
||||
from tools.registry import registry
|
||||
|
||||
registry.register(
|
||||
name=name,
|
||||
toolset=toolset,
|
||||
schema=schema,
|
||||
handler=handler,
|
||||
check_fn=check_fn,
|
||||
requires_env=requires_env,
|
||||
is_async=is_async,
|
||||
description=description,
|
||||
emoji=emoji,
|
||||
)
|
||||
self._manager._plugin_tool_names.add(name)
|
||||
logger.debug("Plugin %s registered tool: %s", self.manifest.name, name)
|
||||
|
||||
# -- hook registration --------------------------------------------------
|
||||
|
||||
def register_hook(self, hook_name: str, callback: Callable) -> None:
|
||||
"""Register a lifecycle hook callback.
|
||||
|
||||
Unknown hook names produce a warning but are still stored so
|
||||
forward-compatible plugins don't break.
|
||||
"""
|
||||
if hook_name not in VALID_HOOKS:
|
||||
logger.warning(
|
||||
"Plugin '%s' registered unknown hook '%s' "
|
||||
"(valid: %s)",
|
||||
self.manifest.name,
|
||||
hook_name,
|
||||
", ".join(sorted(VALID_HOOKS)),
|
||||
)
|
||||
self._manager._hooks.setdefault(hook_name, []).append(callback)
|
||||
logger.debug("Plugin %s registered hook: %s", self.manifest.name, hook_name)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PluginManager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class PluginManager:
|
||||
"""Central manager that discovers, loads, and invokes plugins."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._plugins: Dict[str, LoadedPlugin] = {}
|
||||
self._hooks: Dict[str, List[Callable]] = {}
|
||||
self._plugin_tool_names: Set[str] = set()
|
||||
self._discovered: bool = False
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Public
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def discover_and_load(self) -> None:
|
||||
"""Scan all plugin sources and load each plugin found."""
|
||||
if self._discovered:
|
||||
return
|
||||
self._discovered = True
|
||||
|
||||
manifests: List[PluginManifest] = []
|
||||
|
||||
# 1. User plugins (~/.hermes/plugins/)
|
||||
hermes_home = os.environ.get("HERMES_HOME", os.path.expanduser("~/.hermes"))
|
||||
user_dir = Path(hermes_home) / "plugins"
|
||||
manifests.extend(self._scan_directory(user_dir, source="user"))
|
||||
|
||||
# 2. Project plugins (./.hermes/plugins/)
|
||||
project_dir = Path.cwd() / ".hermes" / "plugins"
|
||||
manifests.extend(self._scan_directory(project_dir, source="project"))
|
||||
|
||||
# 3. Pip / entry-point plugins
|
||||
manifests.extend(self._scan_entry_points())
|
||||
|
||||
# Load each manifest
|
||||
for manifest in manifests:
|
||||
self._load_plugin(manifest)
|
||||
|
||||
if manifests:
|
||||
logger.info(
|
||||
"Plugin discovery complete: %d found, %d enabled",
|
||||
len(self._plugins),
|
||||
sum(1 for p in self._plugins.values() if p.enabled),
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Directory scanning
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def _scan_directory(self, path: Path, source: str) -> List[PluginManifest]:
|
||||
"""Read ``plugin.yaml`` manifests from subdirectories of *path*."""
|
||||
manifests: List[PluginManifest] = []
|
||||
if not path.is_dir():
|
||||
return manifests
|
||||
|
||||
for child in sorted(path.iterdir()):
|
||||
if not child.is_dir():
|
||||
continue
|
||||
manifest_file = child / "plugin.yaml"
|
||||
if not manifest_file.exists():
|
||||
manifest_file = child / "plugin.yml"
|
||||
if not manifest_file.exists():
|
||||
logger.debug("Skipping %s (no plugin.yaml)", child)
|
||||
continue
|
||||
|
||||
try:
|
||||
if yaml is None:
|
||||
logger.warning("PyYAML not installed – cannot load %s", manifest_file)
|
||||
continue
|
||||
data = yaml.safe_load(manifest_file.read_text()) or {}
|
||||
manifest = PluginManifest(
|
||||
name=data.get("name", child.name),
|
||||
version=str(data.get("version", "")),
|
||||
description=data.get("description", ""),
|
||||
author=data.get("author", ""),
|
||||
requires_env=data.get("requires_env", []),
|
||||
provides_tools=data.get("provides_tools", []),
|
||||
provides_hooks=data.get("provides_hooks", []),
|
||||
source=source,
|
||||
path=str(child),
|
||||
)
|
||||
manifests.append(manifest)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to parse %s: %s", manifest_file, exc)
|
||||
|
||||
return manifests
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Entry-point scanning
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def _scan_entry_points(self) -> List[PluginManifest]:
|
||||
"""Check ``importlib.metadata`` for pip-installed plugins."""
|
||||
manifests: List[PluginManifest] = []
|
||||
try:
|
||||
eps = importlib.metadata.entry_points()
|
||||
# Python 3.12+ returns a SelectableGroups; earlier returns dict
|
||||
if hasattr(eps, "select"):
|
||||
group_eps = eps.select(group=ENTRY_POINTS_GROUP)
|
||||
elif isinstance(eps, dict):
|
||||
group_eps = eps.get(ENTRY_POINTS_GROUP, [])
|
||||
else:
|
||||
group_eps = [ep for ep in eps if ep.group == ENTRY_POINTS_GROUP]
|
||||
|
||||
for ep in group_eps:
|
||||
manifest = PluginManifest(
|
||||
name=ep.name,
|
||||
source="entrypoint",
|
||||
path=ep.value,
|
||||
)
|
||||
manifests.append(manifest)
|
||||
except Exception as exc:
|
||||
logger.debug("Entry-point scan failed: %s", exc)
|
||||
|
||||
return manifests
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Loading
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def _load_plugin(self, manifest: PluginManifest) -> None:
|
||||
"""Import a plugin module and call its ``register(ctx)`` function."""
|
||||
loaded = LoadedPlugin(manifest=manifest)
|
||||
|
||||
try:
|
||||
if manifest.source in ("user", "project"):
|
||||
module = self._load_directory_module(manifest)
|
||||
else:
|
||||
module = self._load_entrypoint_module(manifest)
|
||||
|
||||
loaded.module = module
|
||||
|
||||
# Call register()
|
||||
register_fn = getattr(module, "register", None)
|
||||
if register_fn is None:
|
||||
loaded.error = "no register() function"
|
||||
logger.warning("Plugin '%s' has no register() function", manifest.name)
|
||||
else:
|
||||
ctx = PluginContext(manifest, self)
|
||||
register_fn(ctx)
|
||||
loaded.tools_registered = [
|
||||
t for t in self._plugin_tool_names
|
||||
if t not in {
|
||||
n
|
||||
for name, p in self._plugins.items()
|
||||
for n in p.tools_registered
|
||||
}
|
||||
]
|
||||
loaded.hooks_registered = list(
|
||||
{
|
||||
h
|
||||
for h, cbs in self._hooks.items()
|
||||
if cbs # non-empty
|
||||
}
|
||||
- {
|
||||
h
|
||||
for name, p in self._plugins.items()
|
||||
for h in p.hooks_registered
|
||||
}
|
||||
)
|
||||
loaded.enabled = True
|
||||
|
||||
except Exception as exc:
|
||||
loaded.error = str(exc)
|
||||
logger.warning("Failed to load plugin '%s': %s", manifest.name, exc)
|
||||
|
||||
self._plugins[manifest.name] = loaded
|
||||
|
||||
def _load_directory_module(self, manifest: PluginManifest) -> types.ModuleType:
|
||||
"""Import a directory-based plugin as ``hermes_plugins.<name>``."""
|
||||
plugin_dir = Path(manifest.path) # type: ignore[arg-type]
|
||||
init_file = plugin_dir / "__init__.py"
|
||||
if not init_file.exists():
|
||||
raise FileNotFoundError(f"No __init__.py in {plugin_dir}")
|
||||
|
||||
# Ensure the namespace parent package exists
|
||||
if _NS_PARENT not in sys.modules:
|
||||
ns_pkg = types.ModuleType(_NS_PARENT)
|
||||
ns_pkg.__path__ = [] # type: ignore[attr-defined]
|
||||
ns_pkg.__package__ = _NS_PARENT
|
||||
sys.modules[_NS_PARENT] = ns_pkg
|
||||
|
||||
module_name = f"{_NS_PARENT}.{manifest.name.replace('-', '_')}"
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
module_name,
|
||||
init_file,
|
||||
submodule_search_locations=[str(plugin_dir)],
|
||||
)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"Cannot create module spec for {init_file}")
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
module.__package__ = module_name
|
||||
module.__path__ = [str(plugin_dir)] # type: ignore[attr-defined]
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
def _load_entrypoint_module(self, manifest: PluginManifest) -> types.ModuleType:
|
||||
"""Load a pip-installed plugin via its entry-point reference."""
|
||||
eps = importlib.metadata.entry_points()
|
||||
if hasattr(eps, "select"):
|
||||
group_eps = eps.select(group=ENTRY_POINTS_GROUP)
|
||||
elif isinstance(eps, dict):
|
||||
group_eps = eps.get(ENTRY_POINTS_GROUP, [])
|
||||
else:
|
||||
group_eps = [ep for ep in eps if ep.group == ENTRY_POINTS_GROUP]
|
||||
|
||||
for ep in group_eps:
|
||||
if ep.name == manifest.name:
|
||||
return ep.load()
|
||||
|
||||
raise ImportError(
|
||||
f"Entry point '{manifest.name}' not found in group '{ENTRY_POINTS_GROUP}'"
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Hook invocation
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def invoke_hook(self, hook_name: str, **kwargs: Any) -> None:
|
||||
"""Call all registered callbacks for *hook_name*.
|
||||
|
||||
Each callback is wrapped in its own try/except so a misbehaving
|
||||
plugin cannot break the core agent loop.
|
||||
"""
|
||||
callbacks = self._hooks.get(hook_name, [])
|
||||
for cb in callbacks:
|
||||
try:
|
||||
cb(**kwargs)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Hook '%s' callback %s raised: %s",
|
||||
hook_name,
|
||||
getattr(cb, "__name__", repr(cb)),
|
||||
exc,
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Introspection
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def list_plugins(self) -> List[Dict[str, Any]]:
|
||||
"""Return a list of info dicts for all discovered plugins."""
|
||||
result: List[Dict[str, Any]] = []
|
||||
for name, loaded in sorted(self._plugins.items()):
|
||||
result.append(
|
||||
{
|
||||
"name": name,
|
||||
"version": loaded.manifest.version,
|
||||
"description": loaded.manifest.description,
|
||||
"source": loaded.manifest.source,
|
||||
"enabled": loaded.enabled,
|
||||
"tools": len(loaded.tools_registered),
|
||||
"hooks": len(loaded.hooks_registered),
|
||||
"error": loaded.error,
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level singleton & convenience functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_plugin_manager: Optional[PluginManager] = None
|
||||
|
||||
|
||||
def get_plugin_manager() -> PluginManager:
|
||||
"""Return (and lazily create) the global PluginManager singleton."""
|
||||
global _plugin_manager
|
||||
if _plugin_manager is None:
|
||||
_plugin_manager = PluginManager()
|
||||
return _plugin_manager
|
||||
|
||||
|
||||
def discover_plugins() -> None:
|
||||
"""Discover and load all plugins (idempotent)."""
|
||||
get_plugin_manager().discover_and_load()
|
||||
|
||||
|
||||
def invoke_hook(hook_name: str, **kwargs: Any) -> None:
|
||||
"""Invoke a lifecycle hook on all loaded plugins."""
|
||||
get_plugin_manager().invoke_hook(hook_name, **kwargs)
|
||||
|
||||
|
||||
def get_plugin_tool_names() -> Set[str]:
|
||||
"""Return the set of tool names registered by plugins."""
|
||||
return get_plugin_manager()._plugin_tool_names
|
||||
+121
-115
@@ -227,86 +227,54 @@ def prompt(question: str, default: str = None, password: bool = False) -> str:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _curses_prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
"""Single-select menu using curses to avoid simple_term_menu rendering bugs."""
|
||||
try:
|
||||
import curses
|
||||
result_holder = [default]
|
||||
|
||||
def _curses_menu(stdscr):
|
||||
curses.curs_set(0)
|
||||
if curses.has_colors():
|
||||
curses.start_color()
|
||||
curses.use_default_colors()
|
||||
curses.init_pair(1, curses.COLOR_GREEN, -1)
|
||||
curses.init_pair(2, curses.COLOR_YELLOW, -1)
|
||||
cursor = default
|
||||
|
||||
while True:
|
||||
stdscr.clear()
|
||||
max_y, max_x = stdscr.getmaxyx()
|
||||
try:
|
||||
stdscr.addnstr(
|
||||
0,
|
||||
0,
|
||||
question,
|
||||
max_x - 1,
|
||||
curses.A_BOLD | (curses.color_pair(2) if curses.has_colors() else 0),
|
||||
)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
for i, choice in enumerate(choices):
|
||||
y = i + 2
|
||||
if y >= max_y - 1:
|
||||
break
|
||||
arrow = "→" if i == cursor else " "
|
||||
line = f" {arrow} {choice}"
|
||||
attr = curses.A_NORMAL
|
||||
if i == cursor:
|
||||
attr = curses.A_BOLD
|
||||
if curses.has_colors():
|
||||
attr |= curses.color_pair(1)
|
||||
try:
|
||||
stdscr.addnstr(y, 0, line, max_x - 1, attr)
|
||||
except curses.error:
|
||||
pass
|
||||
|
||||
stdscr.refresh()
|
||||
key = stdscr.getch()
|
||||
if key in (curses.KEY_UP, ord("k")):
|
||||
cursor = (cursor - 1) % len(choices)
|
||||
elif key in (curses.KEY_DOWN, ord("j")):
|
||||
cursor = (cursor + 1) % len(choices)
|
||||
elif key in (curses.KEY_ENTER, 10, 13):
|
||||
result_holder[0] = cursor
|
||||
return
|
||||
elif key in (27, ord("q")):
|
||||
return
|
||||
|
||||
curses.wrapper(_curses_menu)
|
||||
return result_holder[0]
|
||||
except Exception:
|
||||
return -1
|
||||
|
||||
|
||||
|
||||
def prompt_choice(question: str, choices: list, default: int = 0) -> int:
|
||||
"""Prompt for a choice from a list with arrow key navigation.
|
||||
|
||||
Escape keeps the current default (skips the question).
|
||||
Ctrl+C exits the wizard.
|
||||
"""
|
||||
idx = _curses_prompt_choice(question, choices, default)
|
||||
if idx >= 0:
|
||||
if idx == default:
|
||||
print_info(" Skipped (keeping current)")
|
||||
print(color(question, Colors.YELLOW))
|
||||
|
||||
# Try to use interactive menu if available
|
||||
try:
|
||||
from simple_term_menu import TerminalMenu
|
||||
import re
|
||||
|
||||
# Strip emoji characters — simple_term_menu miscalculates visual
|
||||
# width of emojis, causing duplicated/garbled lines on redraw.
|
||||
_emoji_re = re.compile(
|
||||
"[\U0001f300-\U0001f9ff\U00002600-\U000027bf\U0000fe00-\U0000fe0f"
|
||||
"\U0001fa00-\U0001fa6f\U0001fa70-\U0001faff\u200d]+",
|
||||
flags=re.UNICODE,
|
||||
)
|
||||
menu_choices = [f" {_emoji_re.sub('', choice).strip()}" for choice in choices]
|
||||
|
||||
print_info(" ↑/↓ Navigate Enter Select Esc Skip Ctrl+C Exit")
|
||||
|
||||
terminal_menu = TerminalMenu(
|
||||
menu_choices,
|
||||
cursor_index=default,
|
||||
menu_cursor="→ ",
|
||||
menu_cursor_style=("fg_green", "bold"),
|
||||
menu_highlight_style=("fg_green",),
|
||||
cycle_cursor=True,
|
||||
clear_screen=False,
|
||||
)
|
||||
|
||||
idx = terminal_menu.show()
|
||||
if idx is None: # User pressed Escape — keep current value
|
||||
print_info(f" Skipped (keeping current)")
|
||||
print()
|
||||
return default
|
||||
print()
|
||||
print() # Add newline after selection
|
||||
return idx
|
||||
|
||||
print(color(question, Colors.YELLOW))
|
||||
except (ImportError, NotImplementedError):
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f" (Interactive menu unavailable: {e})")
|
||||
|
||||
# Fallback to number-based selection (simple_term_menu doesn't support Windows)
|
||||
for i, choice in enumerate(choices):
|
||||
marker = "●" if i == default else "○"
|
||||
if i == default:
|
||||
@@ -376,15 +344,84 @@ def prompt_checklist(title: str, items: list, pre_selected: list = None) -> list
|
||||
if pre_selected is None:
|
||||
pre_selected = []
|
||||
|
||||
from hermes_cli.curses_ui import curses_checklist
|
||||
print(color(title, Colors.YELLOW))
|
||||
print_info(" SPACE Toggle ENTER Confirm ESC Skip Ctrl+C Exit")
|
||||
print()
|
||||
|
||||
chosen = curses_checklist(
|
||||
title,
|
||||
items,
|
||||
set(pre_selected),
|
||||
cancel_returns=set(pre_selected),
|
||||
)
|
||||
return sorted(chosen)
|
||||
try:
|
||||
from simple_term_menu import TerminalMenu
|
||||
import re
|
||||
|
||||
# Strip emoji characters from menu labels — simple_term_menu miscalculates
|
||||
# visual width of emojis on macOS, causing duplicated/garbled lines.
|
||||
_emoji_re = re.compile(
|
||||
"[\U0001f300-\U0001f9ff\U00002600-\U000027bf\U0000fe00-\U0000fe0f"
|
||||
"\U0001fa00-\U0001fa6f\U0001fa70-\U0001faff\u200d]+",
|
||||
flags=re.UNICODE,
|
||||
)
|
||||
menu_items = [f" {_emoji_re.sub('', item).strip()}" for item in items]
|
||||
|
||||
# Map pre-selected indices to the actual menu entry strings
|
||||
preselected = [menu_items[i] for i in pre_selected if i < len(menu_items)]
|
||||
|
||||
terminal_menu = TerminalMenu(
|
||||
menu_items,
|
||||
multi_select=True,
|
||||
show_multi_select_hint=False,
|
||||
multi_select_cursor="[✓] ",
|
||||
multi_select_select_on_accept=False,
|
||||
multi_select_empty_ok=True,
|
||||
preselected_entries=preselected if preselected else None,
|
||||
menu_cursor="→ ",
|
||||
menu_cursor_style=("fg_green", "bold"),
|
||||
menu_highlight_style=("fg_green",),
|
||||
cycle_cursor=True,
|
||||
clear_screen=False,
|
||||
)
|
||||
|
||||
terminal_menu.show()
|
||||
|
||||
if terminal_menu.chosen_menu_entries is None:
|
||||
print_info(" Skipped (keeping current)")
|
||||
return list(pre_selected)
|
||||
|
||||
selected = list(terminal_menu.chosen_menu_indices or [])
|
||||
return selected
|
||||
|
||||
except (ImportError, NotImplementedError):
|
||||
# Fallback: numbered toggle interface (simple_term_menu doesn't support Windows)
|
||||
selected = set(pre_selected)
|
||||
|
||||
while True:
|
||||
for i, item in enumerate(items):
|
||||
marker = color("[✓]", Colors.GREEN) if i in selected else "[ ]"
|
||||
print(f" {marker} {i + 1}. {item}")
|
||||
print()
|
||||
|
||||
try:
|
||||
value = input(
|
||||
color(" Toggle # (or Enter to confirm): ", Colors.DIM)
|
||||
).strip()
|
||||
if not value:
|
||||
break
|
||||
idx = int(value) - 1
|
||||
if 0 <= idx < len(items):
|
||||
if idx in selected:
|
||||
selected.discard(idx)
|
||||
else:
|
||||
selected.add(idx)
|
||||
else:
|
||||
print_error(f"Enter a number between 1 and {len(items)}")
|
||||
except ValueError:
|
||||
print_error("Enter a number")
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return []
|
||||
|
||||
# Clear and redraw (simple approach)
|
||||
print()
|
||||
|
||||
return sorted(selected)
|
||||
|
||||
|
||||
def _prompt_api_key(var: dict):
|
||||
@@ -743,7 +780,6 @@ def setup_model_provider(config: dict):
|
||||
selected_provider = (
|
||||
None # "nous", "openai-codex", "openrouter", "custom", or None (keep)
|
||||
)
|
||||
selected_base_url = None # deferred until after model selection
|
||||
nous_models = [] # populated if Nous login succeeds
|
||||
|
||||
if provider_idx == 0: # Nous Portal (OAuth)
|
||||
@@ -897,35 +933,11 @@ def setup_model_provider(config: dict):
|
||||
|
||||
base_url = prompt(
|
||||
" API base URL (e.g., https://api.example.com/v1)", current_url
|
||||
).strip()
|
||||
)
|
||||
api_key = prompt(" API key", password=True)
|
||||
model_name = prompt(" Model name (e.g., gpt-4, claude-3-opus)", current_model)
|
||||
|
||||
if base_url:
|
||||
from hermes_cli.models import probe_api_models
|
||||
|
||||
probe = probe_api_models(api_key, base_url)
|
||||
if probe.get("used_fallback") and probe.get("resolved_base_url"):
|
||||
print_warning(
|
||||
f"Endpoint verification worked at {probe['resolved_base_url']}/models, "
|
||||
f"not the exact URL you entered. Saving the working base URL instead."
|
||||
)
|
||||
base_url = probe["resolved_base_url"]
|
||||
elif probe.get("models") is not None:
|
||||
print_success(
|
||||
f"Verified endpoint via {probe.get('probed_url')} "
|
||||
f"({len(probe.get('models') or [])} model(s) visible)"
|
||||
)
|
||||
else:
|
||||
print_warning(
|
||||
f"Could not verify this endpoint via {probe.get('probed_url')}. "
|
||||
f"Hermes will still save it."
|
||||
)
|
||||
if probe.get("suggested_base_url"):
|
||||
print_info(
|
||||
f" If this server expects /v1, try base URL: {probe['suggested_base_url']}"
|
||||
)
|
||||
|
||||
save_env_value("OPENAI_BASE_URL", base_url)
|
||||
if api_key:
|
||||
save_env_value("OPENAI_API_KEY", api_key)
|
||||
@@ -1026,8 +1038,8 @@ def setup_model_provider(config: dict):
|
||||
if existing_custom:
|
||||
save_env_value("OPENAI_BASE_URL", "")
|
||||
save_env_value("OPENAI_API_KEY", "")
|
||||
_update_config_for_provider("zai", zai_base_url, default_model="glm-5")
|
||||
_set_model_provider(config, "zai", zai_base_url)
|
||||
selected_base_url = zai_base_url
|
||||
|
||||
elif provider_idx == 5: # Kimi / Moonshot
|
||||
selected_provider = "kimi-coding"
|
||||
@@ -1059,8 +1071,8 @@ def setup_model_provider(config: dict):
|
||||
if existing_custom:
|
||||
save_env_value("OPENAI_BASE_URL", "")
|
||||
save_env_value("OPENAI_API_KEY", "")
|
||||
_update_config_for_provider("kimi-coding", pconfig.inference_base_url, default_model="kimi-k2.5")
|
||||
_set_model_provider(config, "kimi-coding", pconfig.inference_base_url)
|
||||
selected_base_url = pconfig.inference_base_url
|
||||
|
||||
elif provider_idx == 6: # MiniMax
|
||||
selected_provider = "minimax"
|
||||
@@ -1092,8 +1104,8 @@ def setup_model_provider(config: dict):
|
||||
if existing_custom:
|
||||
save_env_value("OPENAI_BASE_URL", "")
|
||||
save_env_value("OPENAI_API_KEY", "")
|
||||
_update_config_for_provider("minimax", pconfig.inference_base_url, default_model="MiniMax-M2.5")
|
||||
_set_model_provider(config, "minimax", pconfig.inference_base_url)
|
||||
selected_base_url = pconfig.inference_base_url
|
||||
|
||||
elif provider_idx == 7: # MiniMax China
|
||||
selected_provider = "minimax-cn"
|
||||
@@ -1125,8 +1137,8 @@ def setup_model_provider(config: dict):
|
||||
if existing_custom:
|
||||
save_env_value("OPENAI_BASE_URL", "")
|
||||
save_env_value("OPENAI_API_KEY", "")
|
||||
_update_config_for_provider("minimax-cn", pconfig.inference_base_url, default_model="MiniMax-M2.5")
|
||||
_set_model_provider(config, "minimax-cn", pconfig.inference_base_url)
|
||||
selected_base_url = pconfig.inference_base_url
|
||||
|
||||
elif provider_idx == 8: # Anthropic
|
||||
selected_provider = "anthropic"
|
||||
@@ -1229,8 +1241,8 @@ def setup_model_provider(config: dict):
|
||||
save_env_value("OPENAI_API_KEY", "")
|
||||
# Don't save base_url for Anthropic — resolve_runtime_provider()
|
||||
# always hardcodes it. Stale base_urls contaminate other providers.
|
||||
_update_config_for_provider("anthropic", "", default_model="claude-opus-4-6")
|
||||
_set_model_provider(config, "anthropic")
|
||||
selected_base_url = ""
|
||||
|
||||
# else: provider_idx == 9 (Keep current) — only shown when a provider already exists
|
||||
# Normalize "keep current" to an explicit provider so downstream logic
|
||||
@@ -1460,12 +1472,6 @@ def setup_model_provider(config: dict):
|
||||
)
|
||||
print_success(f"Model set to: {_display}")
|
||||
|
||||
# Write provider+base_url to config.yaml only after model selection is complete.
|
||||
# This prevents a race condition where the gateway picks up a new provider
|
||||
# before the model name has been updated to match.
|
||||
if selected_provider in ("zai", "kimi-coding", "minimax", "minimax-cn", "anthropic") and selected_base_url is not None:
|
||||
_update_config_for_provider(selected_provider, selected_base_url)
|
||||
|
||||
save_config(config)
|
||||
|
||||
|
||||
|
||||
@@ -60,12 +60,6 @@ All fields are optional. Missing values inherit from the ``default`` skin.
|
||||
# Tool prefix: character for tool output lines (default: ┊)
|
||||
tool_prefix: "┊"
|
||||
|
||||
# Tool emojis: override the default emoji for any tool (used in spinners & progress)
|
||||
tool_emojis:
|
||||
terminal: "⚔" # Override terminal tool emoji
|
||||
web_search: "🔮" # Override web_search tool emoji
|
||||
# Any tool not listed here uses its registry default
|
||||
|
||||
USAGE
|
||||
=====
|
||||
|
||||
@@ -117,7 +111,6 @@ class SkinConfig:
|
||||
spinner: Dict[str, Any] = field(default_factory=dict)
|
||||
branding: Dict[str, str] = field(default_factory=dict)
|
||||
tool_prefix: str = "┊"
|
||||
tool_emojis: Dict[str, str] = field(default_factory=dict) # per-tool emoji overrides
|
||||
banner_logo: str = "" # Rich-markup ASCII art logo (replaces HERMES_AGENT_LOGO)
|
||||
banner_hero: str = "" # Rich-markup hero art (replaces HERMES_CADUCEUS)
|
||||
|
||||
@@ -548,7 +541,6 @@ def _build_skin_config(data: Dict[str, Any]) -> SkinConfig:
|
||||
spinner=spinner,
|
||||
branding=branding,
|
||||
tool_prefix=data.get("tool_prefix", default.get("tool_prefix", "┊")),
|
||||
tool_emojis=data.get("tool_emojis", {}),
|
||||
banner_logo=data.get("banner_logo", ""),
|
||||
banner_hero=data.get("banner_hero", ""),
|
||||
)
|
||||
|
||||
@@ -275,13 +275,8 @@ def show_status(args):
|
||||
print(color("◆ Gateway Service", Colors.CYAN, Colors.BOLD))
|
||||
|
||||
if sys.platform.startswith('linux'):
|
||||
try:
|
||||
from hermes_cli.gateway import get_service_name
|
||||
_gw_svc = get_service_name()
|
||||
except Exception:
|
||||
_gw_svc = "hermes-gateway"
|
||||
result = subprocess.run(
|
||||
["systemctl", "--user", "is-active", _gw_svc],
|
||||
["systemctl", "--user", "is-active", "hermes-gateway"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
@@ -133,13 +133,7 @@ def uninstall_gateway_service():
|
||||
if platform.system() != "Linux":
|
||||
return False
|
||||
|
||||
try:
|
||||
from hermes_cli.gateway import get_service_name
|
||||
svc_name = get_service_name()
|
||||
except Exception:
|
||||
svc_name = "hermes-gateway"
|
||||
|
||||
service_file = Path.home() / ".config" / "systemd" / "user" / f"{svc_name}.service"
|
||||
service_file = Path.home() / ".config" / "systemd" / "user" / "hermes-gateway.service"
|
||||
|
||||
if not service_file.exists():
|
||||
return False
|
||||
@@ -147,14 +141,14 @@ def uninstall_gateway_service():
|
||||
try:
|
||||
# Stop the service
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "stop", svc_name],
|
||||
["systemctl", "--user", "stop", "hermes-gateway"],
|
||||
capture_output=True,
|
||||
check=False
|
||||
)
|
||||
|
||||
# Disable the service
|
||||
subprocess.run(
|
||||
["systemctl", "--user", "disable", svc_name],
|
||||
["systemctl", "--user", "disable", "hermes-gateway"],
|
||||
capture_output=True,
|
||||
check=False
|
||||
)
|
||||
|
||||
@@ -927,11 +927,6 @@ class HonchoSessionManager:
|
||||
return False
|
||||
|
||||
assistant_peer = self._get_or_create_peer(session.assistant_peer_id)
|
||||
honcho_session = self._sessions_cache.get(session.honcho_session_id)
|
||||
if not honcho_session:
|
||||
logger.warning("No Honcho session cached for '%s', skipping AI seed", session_key)
|
||||
return False
|
||||
|
||||
try:
|
||||
wrapped = (
|
||||
f"<ai_identity_seed>\n"
|
||||
@@ -940,7 +935,7 @@ class HonchoSessionManager:
|
||||
f"{content.strip()}\n"
|
||||
f"</ai_identity_seed>"
|
||||
)
|
||||
honcho_session.add_messages([assistant_peer.message(wrapped)])
|
||||
assistant_peer.add_message("assistant", wrapped)
|
||||
logger.info("Seeded AI identity from '%s' into %s", source, session_key)
|
||||
return True
|
||||
except Exception as e:
|
||||
|
||||
+6
-43
@@ -113,13 +113,6 @@ try:
|
||||
except Exception as e:
|
||||
logger.debug("MCP tool discovery failed: %s", e)
|
||||
|
||||
# Plugin tool discovery (user/project/pip plugins)
|
||||
try:
|
||||
from hermes_cli.plugins import discover_plugins
|
||||
discover_plugins()
|
||||
except Exception as e:
|
||||
logger.debug("Plugin discovery failed: %s", e)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Backward-compat constants (built once after discovery)
|
||||
@@ -229,16 +222,6 @@ def get_tool_definitions(
|
||||
for ts_name in get_all_toolsets():
|
||||
tools_to_include.update(resolve_toolset(ts_name))
|
||||
|
||||
# Always include plugin-registered tools — they bypass the toolset filter
|
||||
# because their toolsets are dynamic (created at plugin load time).
|
||||
try:
|
||||
from hermes_cli.plugins import get_plugin_tool_names
|
||||
plugin_tools = get_plugin_tool_names()
|
||||
if plugin_tools:
|
||||
tools_to_include.update(plugin_tools)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Ask the registry for schemas (only returns tools whose check_fn passes)
|
||||
filtered_tools = registry.get_definitions(tools_to_include, quiet=quiet_mode)
|
||||
|
||||
@@ -284,8 +267,6 @@ def handle_function_call(
|
||||
task_id: Optional[str] = None,
|
||||
user_task: Optional[str] = None,
|
||||
enabled_tools: Optional[List[str]] = None,
|
||||
honcho_manager: Optional[Any] = None,
|
||||
honcho_session_key: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Main function call dispatcher that routes calls to the tool registry.
|
||||
@@ -317,39 +298,21 @@ def handle_function_call(
|
||||
if function_name in _AGENT_LOOP_TOOLS:
|
||||
return json.dumps({"error": f"{function_name} must be handled by the agent loop"})
|
||||
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook
|
||||
invoke_hook("pre_tool_call", tool_name=function_name, args=function_args, task_id=task_id or "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if function_name == "execute_code":
|
||||
# Prefer the caller-provided list so subagents can't overwrite
|
||||
# the parent's tool set via the process-global.
|
||||
sandbox_enabled = enabled_tools if enabled_tools is not None else _last_resolved_tool_names
|
||||
result = registry.dispatch(
|
||||
return registry.dispatch(
|
||||
function_name, function_args,
|
||||
task_id=task_id,
|
||||
enabled_tools=sandbox_enabled,
|
||||
honcho_manager=honcho_manager,
|
||||
honcho_session_key=honcho_session_key,
|
||||
)
|
||||
else:
|
||||
result = registry.dispatch(
|
||||
function_name, function_args,
|
||||
task_id=task_id,
|
||||
user_task=user_task,
|
||||
honcho_manager=honcho_manager,
|
||||
honcho_session_key=honcho_session_key,
|
||||
)
|
||||
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook
|
||||
invoke_hook("post_tool_call", tool_name=function_name, args=function_args, result=result, task_id=task_id or "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return result
|
||||
return registry.dispatch(
|
||||
function_name, function_args,
|
||||
task_id=task_id,
|
||||
user_task=user_task,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error executing {function_name}: {str(e)}"
|
||||
|
||||
@@ -1,116 +0,0 @@
|
||||
---
|
||||
name: blender-mcp
|
||||
description: Control Blender directly from Hermes via socket connection to the blender-mcp addon. Create 3D objects, materials, animations, and run arbitrary Blender Python (bpy) code. Use when user wants to create or modify anything in Blender.
|
||||
version: 1.0.0
|
||||
requires: Blender 4.3+ (desktop instance required, headless not supported)
|
||||
author: alireza78a
|
||||
tags: [blender, 3d, animation, modeling, bpy, mcp]
|
||||
---
|
||||
|
||||
# Blender MCP
|
||||
|
||||
Control a running Blender instance from Hermes via socket on TCP port 9876.
|
||||
|
||||
## Setup (one-time)
|
||||
|
||||
### 1. Install the Blender addon
|
||||
|
||||
curl -sL https://raw.githubusercontent.com/ahujasid/blender-mcp/main/addon.py -o ~/Desktop/blender_mcp_addon.py
|
||||
|
||||
In Blender:
|
||||
Edit > Preferences > Add-ons > Install > select blender_mcp_addon.py
|
||||
Enable "Interface: Blender MCP"
|
||||
|
||||
### 2. Start the socket server in Blender
|
||||
|
||||
Press N in Blender viewport to open sidebar.
|
||||
Find "BlenderMCP" tab and click "Start Server".
|
||||
|
||||
### 3. Verify connection
|
||||
|
||||
nc -z -w2 localhost 9876 && echo "OPEN" || echo "CLOSED"
|
||||
|
||||
## Protocol
|
||||
|
||||
Plain UTF-8 JSON over TCP -- no length prefix.
|
||||
|
||||
Send: {"type": "<command>", "params": {<kwargs>}}
|
||||
Receive: {"status": "success", "result": <value>}
|
||||
{"status": "error", "message": "<reason>"}
|
||||
|
||||
## Available Commands
|
||||
|
||||
| type | params | description |
|
||||
|-------------------------|-------------------|---------------------------------|
|
||||
| execute_code | code (str) | Run arbitrary bpy Python code |
|
||||
| get_scene_info | (none) | List all objects in scene |
|
||||
| get_object_info | object_name (str) | Details on a specific object |
|
||||
| get_viewport_screenshot | (none) | Screenshot of current viewport |
|
||||
|
||||
## Python Helper
|
||||
|
||||
Use this inside execute_code tool calls:
|
||||
|
||||
import socket, json
|
||||
|
||||
def blender_exec(code: str, host="localhost", port=9876, timeout=15):
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.connect((host, port))
|
||||
s.settimeout(timeout)
|
||||
payload = json.dumps({"type": "execute_code", "params": {"code": code}})
|
||||
s.sendall(payload.encode("utf-8"))
|
||||
buf = b""
|
||||
while True:
|
||||
try:
|
||||
chunk = s.recv(4096)
|
||||
if not chunk:
|
||||
break
|
||||
buf += chunk
|
||||
try:
|
||||
json.loads(buf.decode("utf-8"))
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
except socket.timeout:
|
||||
break
|
||||
s.close()
|
||||
return json.loads(buf.decode("utf-8"))
|
||||
|
||||
## Common bpy Patterns
|
||||
|
||||
### Clear scene
|
||||
bpy.ops.object.select_all(action='SELECT')
|
||||
bpy.ops.object.delete()
|
||||
|
||||
### Add mesh objects
|
||||
bpy.ops.mesh.primitive_uv_sphere_add(radius=1, location=(0, 0, 0))
|
||||
bpy.ops.mesh.primitive_cube_add(size=2, location=(3, 0, 0))
|
||||
bpy.ops.mesh.primitive_cylinder_add(radius=0.5, depth=2, location=(-3, 0, 0))
|
||||
|
||||
### Create and assign material
|
||||
mat = bpy.data.materials.new(name="MyMat")
|
||||
mat.use_nodes = True
|
||||
bsdf = mat.node_tree.nodes.get("Principled BSDF")
|
||||
bsdf.inputs["Base Color"].default_value = (R, G, B, 1.0)
|
||||
bsdf.inputs["Roughness"].default_value = 0.3
|
||||
bsdf.inputs["Metallic"].default_value = 0.0
|
||||
obj.data.materials.append(mat)
|
||||
|
||||
### Keyframe animation
|
||||
obj.location = (0, 0, 0)
|
||||
obj.keyframe_insert(data_path="location", frame=1)
|
||||
obj.location = (0, 0, 3)
|
||||
obj.keyframe_insert(data_path="location", frame=60)
|
||||
|
||||
### Render to file
|
||||
bpy.context.scene.render.filepath = "/tmp/render.png"
|
||||
bpy.context.scene.render.engine = 'CYCLES'
|
||||
bpy.ops.render.render(write_still=True)
|
||||
|
||||
## Pitfalls
|
||||
|
||||
- Must check socket is open before running (nc -z localhost 9876)
|
||||
- Addon server must be started inside Blender each session (N-panel > BlenderMCP > Connect)
|
||||
- Break complex scenes into multiple smaller execute_code calls to avoid timeouts
|
||||
- Render output path must be absolute (/tmp/...) not relative
|
||||
- shade_smooth() requires object to be selected and in object mode
|
||||
@@ -1,422 +0,0 @@
|
||||
---
|
||||
name: oss-forensics
|
||||
description: |
|
||||
Supply chain investigation, evidence recovery, and forensic analysis for GitHub repositories.
|
||||
Covers deleted commit recovery, force-push detection, IOC extraction, multi-source evidence
|
||||
collection, hypothesis formation/validation, and structured forensic reporting.
|
||||
Inspired by RAPTOR's 1800+ line OSS Forensics system.
|
||||
category: security
|
||||
triggers:
|
||||
- "investigate this repository"
|
||||
- "investigate [owner/repo]"
|
||||
- "check for supply chain compromise"
|
||||
- "recover deleted commits"
|
||||
- "forensic analysis of [owner/repo]"
|
||||
- "was this repo compromised"
|
||||
- "supply chain attack"
|
||||
- "suspicious commit"
|
||||
- "force push detected"
|
||||
- "IOC extraction"
|
||||
toolsets:
|
||||
- terminal
|
||||
- web
|
||||
- file
|
||||
- delegation
|
||||
---
|
||||
|
||||
# OSS Security Forensics Skill
|
||||
|
||||
A 7-phase multi-agent investigation framework for researching open-source supply chain attacks.
|
||||
Adapted from RAPTOR's forensics system. Covers GitHub Archive, Wayback Machine, GitHub API,
|
||||
local git analysis, IOC extraction, evidence-backed hypothesis formation and validation,
|
||||
and final forensic report generation.
|
||||
|
||||
---
|
||||
|
||||
## ⚠️ Anti-Hallucination Guardrails
|
||||
|
||||
Read these before every investigation step. Violating them invalidates the report.
|
||||
|
||||
1. **Evidence-First Rule**: Every claim in any report, hypothesis, or summary MUST cite at least one evidence ID (`EV-XXXX`). Assertions without citations are forbidden.
|
||||
2. **STAY IN YOUR LANE**: Each sub-agent (investigator) has a single data source. Do NOT mix sources. The GH Archive investigator does not query the GitHub API, and vice versa. Role boundaries are hard.
|
||||
3. **Fact vs. Hypothesis Separation**: Mark all unverified inferences with `[HYPOTHESIS]`. Only statements verified against original sources may be stated as facts.
|
||||
4. **No Evidence Fabrication**: The hypothesis validator MUST mechanically check that every cited evidence ID actually exists in the evidence store before accepting a hypothesis.
|
||||
5. **Proof-Required Disproval**: A hypothesis cannot be dismissed without a specific, evidence-backed counter-argument. "No evidence found" is not sufficient to disprove—it only makes a hypothesis inconclusive.
|
||||
6. **SHA/URL Double-Verification**: Any commit SHA, URL, or external identifier cited as evidence must be independently confirmed from at least two sources before being marked as verified.
|
||||
7. **Suspicious Code Rule**: Never run code found inside the investigated repository locally. Analyze statically only, or use `execute_code` in a sandboxed environment.
|
||||
8. **Secret Redaction**: Any API keys, tokens, or credentials discovered during investigation must be redacted in the final report. Log them internally only.
|
||||
|
||||
---
|
||||
|
||||
## Example Scenarios
|
||||
|
||||
- **Scenario A: Dependency Confusion**: A malicious package `internal-lib-v2` is uploaded to NPM with a higher version than the internal one. The investigator must track when this package was first seen and if any PushEvents in the target repo updated `package.json` to this version.
|
||||
- **Scenario B: Maintainer Takeover**: A long-term contributor's account is used to push a backdoored `.github/workflows/build.yml`. The investigator looks for PushEvents from this user after a long period of inactivity or from a new IP/location (if detectable via BigQuery).
|
||||
- **Scenario C: Force-Push Hide**: A developer accidentally commits a production secret, then force-pushes to "fix" it. The investigator uses `git fsck` and GH Archive to recover the original commit SHA and verify what was leaked.
|
||||
|
||||
---
|
||||
|
||||
> **Path convention**: Throughout this skill, `SKILL_DIR` refers to the root of this skill's
|
||||
> installation directory (the folder containing this `SKILL.md`). When the skill is loaded,
|
||||
> resolve `SKILL_DIR` to the actual path — e.g. `~/.hermes/skills/security/oss-forensics/`
|
||||
> or the `optional-skills/` equivalent. All script and template references are relative to it.
|
||||
|
||||
## Phase 0: Initialization
|
||||
|
||||
1. Create investigation working directory:
|
||||
```bash
|
||||
mkdir investigation_$(echo "REPO_NAME" | tr '/' '_')
|
||||
cd investigation_$(echo "REPO_NAME" | tr '/' '_')
|
||||
```
|
||||
2. Initialize the evidence store:
|
||||
```bash
|
||||
python3 SKILL_DIR/scripts/evidence-store.py --store evidence.json list
|
||||
```
|
||||
3. Copy the forensic report template:
|
||||
```bash
|
||||
cp SKILL_DIR/templates/forensic-report.md ./investigation-report.md
|
||||
```
|
||||
4. Create an `iocs.md` file to track Indicators of Compromise as they are discovered.
|
||||
5. Record the investigation start time, target repository, and stated investigation goal.
|
||||
|
||||
---
|
||||
|
||||
## Phase 1: Prompt Parsing and IOC Extraction
|
||||
|
||||
**Goal**: Extract all structured investigative targets from the user's request.
|
||||
|
||||
**Actions**:
|
||||
- Parse the user prompt and extract:
|
||||
- Target repository (`owner/repo`)
|
||||
- Target actors (GitHub handles, email addresses)
|
||||
- Time window of interest (commit date ranges, PR timestamps)
|
||||
- Provided Indicators of Compromise: commit SHAs, file paths, package names, IP addresses, domains, API keys/tokens, malicious URLs
|
||||
- Any linked vendor security reports or blog posts
|
||||
|
||||
**Tools**: Reasoning only, or `execute_code` for regex extraction from large text blocks.
|
||||
|
||||
**Output**: Populate `iocs.md` with extracted IOCs. Each IOC must have:
|
||||
- Type (from: COMMIT_SHA, FILE_PATH, API_KEY, SECRET, IP_ADDRESS, DOMAIN, PACKAGE_NAME, ACTOR_USERNAME, MALICIOUS_URL, OTHER)
|
||||
- Value
|
||||
- Source (user-provided, inferred)
|
||||
|
||||
**Reference**: See [evidence-types.md](./references/evidence-types.md) for IOC taxonomy.
|
||||
|
||||
---
|
||||
|
||||
## Phase 2: Parallel Evidence Collection
|
||||
|
||||
Spawn up to 5 specialist investigator sub-agents using `delegate_task` (batch mode, max 3 concurrent). Each investigator has a **single data source** and must not mix sources.
|
||||
|
||||
> **Orchestrator note**: Pass the IOC list from Phase 1 and the investigation time window in the `context` field of each delegated task.
|
||||
|
||||
---
|
||||
|
||||
### Investigator 1: Local Git Investigator
|
||||
|
||||
**ROLE BOUNDARY**: You query the LOCAL GIT REPOSITORY ONLY. Do not call any external APIs.
|
||||
|
||||
**Actions**:
|
||||
```bash
|
||||
# Clone repository
|
||||
git clone https://github.com/OWNER/REPO.git target_repo && cd target_repo
|
||||
|
||||
# Full commit log with stats
|
||||
git log --all --full-history --stat --format="%H|%ae|%an|%ai|%s" > ../git_log.txt
|
||||
|
||||
# Detect force-push evidence (orphaned/dangling commits)
|
||||
git fsck --lost-found --unreachable 2>&1 | grep commit > ../dangling_commits.txt
|
||||
|
||||
# Check reflog for rewritten history
|
||||
git reflog --all > ../reflog.txt
|
||||
|
||||
# List ALL branches including deleted remote refs
|
||||
git branch -a -v > ../branches.txt
|
||||
|
||||
# Find suspicious large binary additions
|
||||
git log --all --diff-filter=A --name-only --format="%H %ai" -- "*.so" "*.dll" "*.exe" "*.bin" > ../binary_additions.txt
|
||||
|
||||
# Check for GPG signature anomalies
|
||||
git log --show-signature --format="%H %ai %aN" > ../signature_check.txt 2>&1
|
||||
```
|
||||
|
||||
**Evidence to collect** (add via `python3 SKILL_DIR/scripts/evidence-store.py add`):
|
||||
- Each dangling commit SHA → type: `git`
|
||||
- Force-push evidence (reflog showing history rewrite) → type: `git`
|
||||
- Unsigned commits from verified contributors → type: `git`
|
||||
- Suspicious binary file additions → type: `git`
|
||||
|
||||
**Reference**: See [recovery-techniques.md](./references/recovery-techniques.md) for accessing force-pushed commits.
|
||||
|
||||
---
|
||||
|
||||
### Investigator 2: GitHub API Investigator
|
||||
|
||||
**ROLE BOUNDARY**: You query the GITHUB REST API ONLY. Do not run git commands locally.
|
||||
|
||||
**Actions**:
|
||||
```bash
|
||||
# Commits (paginated)
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/commits?per_page=100" > api_commits.json
|
||||
|
||||
# Pull Requests including closed/deleted
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/pulls?state=all&per_page=100" > api_prs.json
|
||||
|
||||
# Issues
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/issues?state=all&per_page=100" > api_issues.json
|
||||
|
||||
# Contributors and collaborator changes
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/contributors" > api_contributors.json
|
||||
|
||||
# Repository events (last 300)
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/events?per_page=100" > api_events.json
|
||||
|
||||
# Check specific suspicious commit SHA details
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/git/commits/SHA" > commit_detail.json
|
||||
|
||||
# Releases
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/releases?per_page=100" > api_releases.json
|
||||
|
||||
# Check if a specific commit exists (force-pushed commits may 404 on commits/ but succeed on git/commits/)
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/commits/SHA" | jq .sha
|
||||
```
|
||||
|
||||
**Cross-reference targets** (flag discrepancies as evidence):
|
||||
- PR exists in archive but missing from API → evidence of deletion
|
||||
- Contributor in archive events but not in contributors list → evidence of permission revocation
|
||||
- Commit in archive PushEvents but not in API commit list → evidence of force-push/deletion
|
||||
|
||||
**Reference**: See [evidence-types.md](./references/evidence-types.md) for GH event types.
|
||||
|
||||
---
|
||||
|
||||
### Investigator 3: Wayback Machine Investigator
|
||||
|
||||
**ROLE BOUNDARY**: You query the WAYBACK MACHINE CDX API ONLY. Do not use the GitHub API.
|
||||
|
||||
**Goal**: Recover deleted GitHub pages (READMEs, issues, PRs, releases, wiki pages).
|
||||
|
||||
**Actions**:
|
||||
```bash
|
||||
# Search for archived snapshots of the repo main page
|
||||
curl -s "https://web.archive.org/cdx/search/cdx?url=github.com/OWNER/REPO&output=json&limit=100&from=YYYYMMDD&to=YYYYMMDD" > wayback_main.json
|
||||
|
||||
# Search for a specific deleted issue
|
||||
curl -s "https://web.archive.org/cdx/search/cdx?url=github.com/OWNER/REPO/issues/NUM&output=json&limit=50" > wayback_issue_NUM.json
|
||||
|
||||
# Search for a specific deleted PR
|
||||
curl -s "https://web.archive.org/cdx/search/cdx?url=github.com/OWNER/REPO/pull/NUM&output=json&limit=50" > wayback_pr_NUM.json
|
||||
|
||||
# Fetch the best snapshot of a page
|
||||
# Use the Wayback Machine URL: https://web.archive.org/web/TIMESTAMP/ORIGINAL_URL
|
||||
# Example: https://web.archive.org/web/20240101000000*/github.com/OWNER/REPO
|
||||
|
||||
# Advanced: Search for deleted releases/tags
|
||||
curl -s "https://web.archive.org/cdx/search/cdx?url=github.com/OWNER/REPO/releases/tag/*&output=json" > wayback_tags.json
|
||||
|
||||
# Advanced: Search for historical wiki changes
|
||||
curl -s "https://web.archive.org/cdx/search/cdx?url=github.com/OWNER/REPO/wiki/*&output=json" > wayback_wiki.json
|
||||
```
|
||||
|
||||
**Evidence to collect**:
|
||||
- Archived snapshots of deleted issues/PRs with their content
|
||||
- Historical README versions showing changes
|
||||
- Evidence of content present in archive but missing from current GitHub state
|
||||
|
||||
**Reference**: See [github-archive-guide.md](./references/github-archive-guide.md) for CDX API parameters.
|
||||
|
||||
---
|
||||
|
||||
### Investigator 4: GH Archive / BigQuery Investigator
|
||||
|
||||
**ROLE BOUNDARY**: You query GITHUB ARCHIVE via BIGQUERY ONLY. This is a tamper-proof record of all public GitHub events.
|
||||
|
||||
> **Prerequisites**: Requires Google Cloud credentials with BigQuery access (`gcloud auth application-default login`). If unavailable, skip this investigator and note it in the report.
|
||||
|
||||
**Cost Optimization Rules** (MANDATORY):
|
||||
1. ALWAYS run a `--dry_run` before every query to estimate cost.
|
||||
2. Use `_TABLE_SUFFIX` to filter by date range and minimize scanned data.
|
||||
3. Only SELECT the columns you need.
|
||||
4. Add a LIMIT unless aggregating.
|
||||
|
||||
```bash
|
||||
# Template: safe BigQuery query for PushEvents to OWNER/REPO
|
||||
bq query --use_legacy_sql=false --dry_run "
|
||||
SELECT created_at, actor.login, payload.commits, payload.before, payload.head,
|
||||
payload.size, payload.distinct_size
|
||||
FROM \`githubarchive.month.*\`
|
||||
WHERE _TABLE_SUFFIX BETWEEN 'YYYYMM' AND 'YYYYMM'
|
||||
AND type = 'PushEvent'
|
||||
AND repo.name = 'OWNER/REPO'
|
||||
LIMIT 1000
|
||||
"
|
||||
# If cost is acceptable, re-run without --dry_run
|
||||
|
||||
# Detect force-pushes: zero-distinct_size PushEvents mean commits were force-erased
|
||||
# payload.distinct_size = 0 AND payload.size > 0 → force push indicator
|
||||
|
||||
# Check for deleted branch events
|
||||
bq query --use_legacy_sql=false "
|
||||
SELECT created_at, actor.login, payload.ref, payload.ref_type
|
||||
FROM \`githubarchive.month.*\`
|
||||
WHERE _TABLE_SUFFIX BETWEEN 'YYYYMM' AND 'YYYYMM'
|
||||
AND type = 'DeleteEvent'
|
||||
AND repo.name = 'OWNER/REPO'
|
||||
LIMIT 200
|
||||
"
|
||||
```
|
||||
|
||||
**Evidence to collect**:
|
||||
- Force-push events (payload.size > 0, payload.distinct_size = 0)
|
||||
- DeleteEvents for branches/tags
|
||||
- WorkflowRunEvents for suspicious CI/CD automation
|
||||
- PushEvents that precede a "gap" in the git log (evidence of rewrite)
|
||||
|
||||
**Reference**: See [github-archive-guide.md](./references/github-archive-guide.md) for all 12 event types and query patterns.
|
||||
|
||||
---
|
||||
|
||||
### Investigator 5: IOC Enrichment Investigator
|
||||
|
||||
**ROLE BOUNDARY**: You enrich EXISTING IOCs from Phase 1 using passive public sources ONLY. Do not execute any code from the target repository.
|
||||
|
||||
**Actions**:
|
||||
- For each commit SHA: attempt recovery via direct GitHub URL (`github.com/OWNER/REPO/commit/SHA.patch`)
|
||||
- For each domain/IP: check passive DNS, WHOIS records (via `web_extract` on public WHOIS services)
|
||||
- For each package name: check npm/PyPI for matching malicious package reports
|
||||
- For each actor username: check GitHub profile, contribution history, account age
|
||||
- Recover force-pushed commits using 3 methods (see [recovery-techniques.md](./references/recovery-techniques.md))
|
||||
|
||||
---
|
||||
|
||||
## Phase 3: Evidence Consolidation
|
||||
|
||||
After all investigators complete:
|
||||
|
||||
1. Run `python3 SKILL_DIR/scripts/evidence-store.py --store evidence.json list` to see all collected evidence.
|
||||
2. For each piece of evidence, verify the `content_sha256` hash matches the original source.
|
||||
3. Group evidence by:
|
||||
- **Timeline**: Sort all timestamped evidence chronologically
|
||||
- **Actor**: Group by GitHub handle or email
|
||||
- **IOC**: Link evidence to the IOC it relates to
|
||||
4. Identify **discrepancies**: items present in one source but absent in another (key deletion indicators).
|
||||
5. Flag evidence as `[VERIFIED]` (confirmed from 2+ independent sources) or `[UNVERIFIED]` (single source only).
|
||||
|
||||
---
|
||||
|
||||
## Phase 4: Hypothesis Formation
|
||||
|
||||
A hypothesis must:
|
||||
- State a specific claim (e.g., "Actor X force-pushed to BRANCH on DATE to erase commit SHA")
|
||||
- Cite at least 2 evidence IDs that support it (`EV-XXXX`, `EV-YYYY`)
|
||||
- Identify what evidence would disprove it
|
||||
- Be labeled `[HYPOTHESIS]` until validated
|
||||
|
||||
**Common hypothesis templates** (see [investigation-templates.md](./references/investigation-templates.md)):
|
||||
- Maintainer Compromise: legitimate account used post-takeover to inject malicious code
|
||||
- Dependency Confusion: package name squatting to intercept installs
|
||||
- CI/CD Injection: malicious workflow changes to run code during builds
|
||||
- Typosquatting: near-identical package name targeting misspellers
|
||||
- Credential Leak: token/key accidentally committed then force-pushed to erase
|
||||
|
||||
For each hypothesis, spawn a `delegate_task` sub-agent to attempt to find disconfirming evidence before confirming.
|
||||
|
||||
---
|
||||
|
||||
## Phase 5: Hypothesis Validation
|
||||
|
||||
The validator sub-agent MUST mechanically check:
|
||||
|
||||
1. For each hypothesis, extract all cited evidence IDs.
|
||||
2. Verify each ID exists in `evidence.json` (hard failure if any ID is missing → hypothesis rejected as potentially fabricated).
|
||||
3. Verify each `[VERIFIED]` piece of evidence was confirmed from 2+ sources.
|
||||
4. Check logical consistency: does the timeline depicted by the evidence support the hypothesis?
|
||||
5. Check for alternative explanations: could the same evidence pattern arise from a benign cause?
|
||||
|
||||
**Output**:
|
||||
- `VALIDATED`: All evidence cited, verified, logically consistent, no plausible alternative explanation.
|
||||
- `INCONCLUSIVE`: Evidence supports hypothesis but alternative explanations exist or evidence is insufficient.
|
||||
- `REJECTED`: Missing evidence IDs, unverified evidence cited as fact, logical inconsistency detected.
|
||||
|
||||
Rejected hypotheses feed back into Phase 4 for refinement (max 3 iterations).
|
||||
|
||||
---
|
||||
|
||||
## Phase 6: Final Report Generation
|
||||
|
||||
Populate `investigation-report.md` using the template in [forensic-report.md](./templates/forensic-report.md).
|
||||
|
||||
**Mandatory sections**:
|
||||
- Executive Summary: one-paragraph verdict (Compromised / Clean / Inconclusive) with confidence level
|
||||
- Timeline: chronological reconstruction of all significant events with evidence citations
|
||||
- Validated Hypotheses: each with status and supporting evidence IDs
|
||||
- Evidence Registry: table of all `EV-XXXX` entries with source, type, and verification status
|
||||
- IOC List: all extracted and enriched Indicators of Compromise
|
||||
- Chain of Custody: how evidence was collected, from what sources, at what timestamps
|
||||
- Recommendations: immediate mitigations if compromise detected; monitoring recommendations
|
||||
|
||||
**Report rules**:
|
||||
- Every factual claim must have at least one `[EV-XXXX]` citation
|
||||
- Executive Summary must state confidence level (High / Medium / Low)
|
||||
- All secrets/credentials must be redacted to `[REDACTED]`
|
||||
|
||||
---
|
||||
|
||||
## Phase 7: Completion
|
||||
|
||||
1. Run final evidence count: `python3 SKILL_DIR/scripts/evidence-store.py --store evidence.json list`
|
||||
2. Archive the full investigation directory.
|
||||
3. If compromise is confirmed:
|
||||
- List immediate mitigations (rotate credentials, pin dependency hashes, notify affected users)
|
||||
- Identify affected versions/packages
|
||||
- Note disclosure obligations (if a public package: coordinate with the package registry)
|
||||
4. Present the final `investigation-report.md` to the user.
|
||||
|
||||
---
|
||||
|
||||
## Ethical Use Guidelines
|
||||
|
||||
This skill is designed for **defensive security investigation** — protecting open-source software from supply chain attacks. It must not be used for:
|
||||
|
||||
- **Harassment or stalking** of contributors or maintainers
|
||||
- **Doxing** — correlating GitHub activity to real identities for malicious purposes
|
||||
- **Competitive intelligence** — investigating proprietary or internal repositories without authorization
|
||||
- **False accusations** — publishing investigation results without validated evidence (see anti-hallucination guardrails)
|
||||
|
||||
Investigations should be conducted with the principle of **minimal intrusion**: collect only the evidence necessary to validate or refute the hypothesis. When publishing results, follow responsible disclosure practices and coordinate with affected maintainers before public disclosure.
|
||||
|
||||
If the investigation reveals a genuine compromise, follow the coordinated vulnerability disclosure process:
|
||||
1. Notify the repository maintainers privately first
|
||||
2. Allow reasonable time for remediation (typically 90 days)
|
||||
3. Coordinate with package registries (npm, PyPI, etc.) if published packages are affected
|
||||
4. File a CVE if appropriate
|
||||
|
||||
---
|
||||
|
||||
## API Rate Limiting
|
||||
|
||||
GitHub REST API enforces rate limits that will interrupt large investigations if not managed.
|
||||
|
||||
**Authenticated requests**: 5,000/hour (requires `GITHUB_TOKEN` env var or `gh` CLI auth)
|
||||
**Unauthenticated requests**: 60/hour (unusable for investigations)
|
||||
|
||||
**Best practices**:
|
||||
- Always authenticate: `export GITHUB_TOKEN=ghp_...` or use `gh` CLI (auto-authenticates)
|
||||
- Use conditional requests (`If-None-Match` / `If-Modified-Since` headers) to avoid consuming quota on unchanged data
|
||||
- For paginated endpoints, fetch all pages in sequence — don't parallelize against the same endpoint
|
||||
- Check `X-RateLimit-Remaining` header; if below 100, pause for `X-RateLimit-Reset` timestamp
|
||||
- BigQuery has its own quotas (10 TiB/day free tier) — always dry-run first
|
||||
- Wayback Machine CDX API: no formal rate limit, but be courteous (1-2 req/sec max)
|
||||
|
||||
If rate-limited mid-investigation, record the partial results in the evidence store and note the limitation in the report.
|
||||
|
||||
---
|
||||
|
||||
## Reference Materials
|
||||
|
||||
- [github-archive-guide.md](./references/github-archive-guide.md) — BigQuery queries, CDX API, 12 event types
|
||||
- [evidence-types.md](./references/evidence-types.md) — IOC taxonomy, evidence source types, observation types
|
||||
- [recovery-techniques.md](./references/recovery-techniques.md) — Recovering deleted commits, PRs, issues
|
||||
- [investigation-templates.md](./references/investigation-templates.md) — Pre-built hypothesis templates per attack type
|
||||
- [evidence-store.py](./scripts/evidence-store.py) — CLI tool for managing the evidence JSON store
|
||||
- [forensic-report.md](./templates/forensic-report.md) — Structured report template
|
||||
@@ -1,89 +0,0 @@
|
||||
# Evidence Types Reference
|
||||
|
||||
Taxonomy of all evidence types, IOC types, GitHub event types, and observation types
|
||||
used in OSS forensic investigations.
|
||||
|
||||
---
|
||||
|
||||
## Evidence Source Types
|
||||
|
||||
| Type | Description | Example Sources |
|
||||
|------|-------------|-----------------|
|
||||
| `git` | Data from local git repository analysis | `git log`, `git fsck`, `git reflog`, `git blame` |
|
||||
| `gh_api` | Data from GitHub REST API responses | `/repos/.../commits`, `/repos/.../pulls`, `/repos/.../events` |
|
||||
| `gh_archive` | Data from GitHub Archive (BigQuery) | `githubarchive.month.*` BigQuery tables |
|
||||
| `web_archive` | Archived web pages from Wayback Machine | CDX API results, `web.archive.org/web/...` snapshots |
|
||||
| `ioc` | Indicator of Compromise from any source | Extracted from vendor reports, git history, network traces |
|
||||
| `analysis` | Derived insight from cross-source correlation | "SHA present in archive but absent from API" |
|
||||
| `vendor_report` | External security vendor or researcher report | CVE advisories, blog posts, NVD records |
|
||||
| `manual` | Manually recorded observation by investigator | Notes on behavioral patterns, timeline gaps |
|
||||
|
||||
---
|
||||
|
||||
## IOC Types
|
||||
|
||||
| Type | Description | Example |
|
||||
|------|-------------|---------|
|
||||
| `COMMIT_SHA` | A git commit hash linked to malicious activity | `abc123def456...` |
|
||||
| `FILE_PATH` | A suspicious file inside the repository | `src/utils/crypto.js`, `dist/index.min.js` |
|
||||
| `API_KEY` | An API key accidentally committed | `AKIA...` (AWS), `ghp_...` (GitHub PAT) |
|
||||
| `SECRET` | A generic secret / credential | Database password, private key blob |
|
||||
| `IP_ADDRESS` | A C2 server or attacker IP | `192.0.2.1` |
|
||||
| `DOMAIN` | A malicious or suspicious domain | `evil-cdn.io`, typosquatted package registry domain |
|
||||
| `PACKAGE_NAME` | A malicious or squatted package name | `colo-rs` (typosquatting `color`), `lodash-utils` |
|
||||
| `ACTOR_USERNAME` | A GitHub handle linked to the attack | `malicious-bot-account` |
|
||||
| `MALICIOUS_URL` | A URL to a malicious resource | `https://evil.example.com/payload.sh` |
|
||||
| `WORKFLOW_FILE` | A suspicious CI/CD workflow file | `.github/workflows/release.yml` |
|
||||
| `BRANCH_NAME` | A suspicious branch | `refs/heads/temp-fix-do-not-merge` |
|
||||
| `TAG_NAME` | A suspicious git tag | `v1.0.0-security-patch` |
|
||||
| `RELEASE_NAME` | A suspicious release | Release with no associated tag or changelog |
|
||||
| `OTHER` | Catch-all for unclassified IOCs | — |
|
||||
|
||||
---
|
||||
|
||||
## GitHub Archive Event Types (12 Types)
|
||||
|
||||
| Event Type | Forensic Relevance |
|
||||
|------------|-------------------|
|
||||
| `PushEvent` | Core: `payload.distinct_size=0` with `payload.size>0` → force push. `payload.before`/`payload.head` shows rewritten history. |
|
||||
| `PullRequestEvent` | Detects deleted PRs, rapid open→close patterns, PRs from new accounts |
|
||||
| `IssueEvent` | Detects deleted issues, coordinated labeling, rapid closure of vulnerability reports |
|
||||
| `IssueCommentEvent` | Deleted comments, rapid activity bursts |
|
||||
| `WatchEvent` | Star-farming campaigns (coordinated starring from new accounts) |
|
||||
| `ForkEvent` | Unusual fork patterns before malicious commit |
|
||||
| `CreateEvent` | Branch/tag creation: signals new release or code injection point |
|
||||
| `DeleteEvent` | Branch/tag deletion: critical — often used to hide traces |
|
||||
| `ReleaseEvent` | Unauthorized releases, release artifacts modified post-publish |
|
||||
| `MemberEvent` | Collaborator added/removed: maintainer compromise indicator |
|
||||
| `PublicEvent` | Repository made public (sometimes to drop malicious code briefly) |
|
||||
| `WorkflowRunEvent` | CI/CD pipeline executions: workflow injection, secret exfiltration |
|
||||
|
||||
---
|
||||
|
||||
## Evidence Verification States
|
||||
|
||||
| State | Meaning |
|
||||
|-------|---------|
|
||||
| `unverified` | Collected from a single source, not cross-referenced |
|
||||
| `single_source` | The primary source has been confirmed directly (e.g., SHA resolves on GitHub), but no second source |
|
||||
| `multi_source_verified` | Confirmed from 2+ independent sources (e.g., GH Archive AND GitHub API both show the same event) |
|
||||
|
||||
Only `multi_source_verified` evidence may be cited as fact in validated hypotheses.
|
||||
`unverified` and `single_source` evidence must be labeled `[UNVERIFIED]` or `[SINGLE-SOURCE]`.
|
||||
|
||||
---
|
||||
|
||||
## Observation Types (Patterned after RAPTOR)
|
||||
|
||||
| Type | Description |
|
||||
|------|-------------|
|
||||
| `CommitObservation` | Specific commit SHA with metadata (author, date, files changed) |
|
||||
| `ForceWashObservation` | Evidence that commits were force-erased from a branch |
|
||||
| `DanglingCommitObservation` | SHA present in git object store but unreachable from any ref |
|
||||
| `IssueObservation` | A GitHub issue (current or archived) with title, body, timestamp |
|
||||
| `PRObservation` | A GitHub PR (current or archived) with diff summary, reviewers |
|
||||
| `IOC` | A single Indicator of Compromise with context |
|
||||
| `TimelineGap` | A period with unusual absence of expected activity |
|
||||
| `ActorAnomalyObservation` | Behavioral anomaly for a specific GitHub actor |
|
||||
| `WorkflowAnomalyObservation` | Suspicious CI/CD workflow change or unexpected run |
|
||||
| `CrossSourceDiscrepancy` | Item present in one source but absent in another (strong deletion indicator) |
|
||||
@@ -1,184 +0,0 @@
|
||||
# GitHub Archive Query Guide (BigQuery)
|
||||
|
||||
GitHub Archive records every public event on GitHub as immutable JSON records. This data is accessible via Google BigQuery and is the most reliable source for forensic investigation — events cannot be deleted or modified after recording.
|
||||
|
||||
## Public Dataset
|
||||
|
||||
- **Project**: `githubarchive`
|
||||
- **Tables**: `day.YYYYMMDD`, `month.YYYYMM`, `year.YYYY`
|
||||
- **Cost**: $6.25 per TiB scanned. Always run dry runs first.
|
||||
- **Access**: Requires a Google Cloud account with BigQuery enabled. Free tier includes 1 TiB/month of queries.
|
||||
|
||||
---
|
||||
|
||||
## The 12 GitHub Event Types
|
||||
|
||||
| Event Type | What It Records | Forensic Value |
|
||||
|------------|-----------------|----------------|
|
||||
| `PushEvent` | Commits pushed to a branch | Force-push detection, commit timeline, author attribution |
|
||||
| `PullRequestEvent` | PR opened, closed, merged, reopened | Deleted PR recovery, review timeline |
|
||||
| `IssuesEvent` | Issue opened, closed, reopened, labeled | Deleted issue recovery, social engineering traces |
|
||||
| `IssueCommentEvent` | Comments on issues and PRs | Deleted comment recovery, communication patterns |
|
||||
| `CreateEvent` | Branch, tag, or repository creation | Suspicious branch creation, tag timing |
|
||||
| `DeleteEvent` | Branch or tag deletion | Evidence of cleanup after compromise |
|
||||
| `MemberEvent` | Collaborator added or removed | Permission changes, access escalation |
|
||||
| `PublicEvent` | Repository made public | Accidental exposure of private repos |
|
||||
| `WatchEvent` | User stars a repository | Actor reconnaissance patterns |
|
||||
| `ForkEvent` | Repository forked | Exfiltration of code before cleanup |
|
||||
| `ReleaseEvent` | Release published, edited, deleted | Malicious release injection, deleted release recovery |
|
||||
| `WorkflowRunEvent` | GitHub Actions workflow triggered | CI/CD abuse, unauthorized workflow runs |
|
||||
|
||||
---
|
||||
|
||||
## Query Templates
|
||||
|
||||
### Basic: All Events for a Repository
|
||||
|
||||
```sql
|
||||
SELECT
|
||||
created_at,
|
||||
type,
|
||||
actor.login,
|
||||
repo.name,
|
||||
payload
|
||||
FROM
|
||||
`githubarchive.day.20240101` -- Adjust date
|
||||
WHERE
|
||||
repo.name = 'owner/repo'
|
||||
AND type IN ('PushEvent', 'DeleteEvent', 'MemberEvent')
|
||||
ORDER BY
|
||||
created_at ASC
|
||||
```
|
||||
|
||||
### Force-Push Detection
|
||||
|
||||
Force-pushes produce PushEvents where commits are overwritten. Key indicators:
|
||||
- `payload.distinct_size = 0` with `payload.size > 0` → commits were erased
|
||||
- `payload.before` contains the SHA before the rewrite (recoverable)
|
||||
|
||||
```sql
|
||||
SELECT
|
||||
created_at,
|
||||
actor.login,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.before') AS before_sha,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.head') AS after_sha,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.size') AS total_commits,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.distinct_size') AS distinct_commits,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.ref') AS branch_ref
|
||||
FROM
|
||||
`githubarchive.month.*`
|
||||
WHERE
|
||||
_TABLE_SUFFIX BETWEEN '202401' AND '202403'
|
||||
AND type = 'PushEvent'
|
||||
AND repo.name = 'owner/repo'
|
||||
AND CAST(JSON_EXTRACT_SCALAR(payload, '$.distinct_size') AS INT64) = 0
|
||||
ORDER BY
|
||||
created_at ASC
|
||||
```
|
||||
|
||||
### Deleted Branch/Tag Detection
|
||||
|
||||
```sql
|
||||
SELECT
|
||||
created_at,
|
||||
actor.login,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.ref') AS deleted_ref,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.ref_type') AS ref_type
|
||||
FROM
|
||||
`githubarchive.month.*`
|
||||
WHERE
|
||||
_TABLE_SUFFIX BETWEEN '202401' AND '202403'
|
||||
AND type = 'DeleteEvent'
|
||||
AND repo.name = 'owner/repo'
|
||||
ORDER BY
|
||||
created_at ASC
|
||||
```
|
||||
|
||||
### Collaborator Permission Changes
|
||||
|
||||
```sql
|
||||
SELECT
|
||||
created_at,
|
||||
actor.login,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.action') AS action,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.member.login') AS member
|
||||
FROM
|
||||
`githubarchive.month.*`
|
||||
WHERE
|
||||
_TABLE_SUFFIX BETWEEN '202401' AND '202403'
|
||||
AND type = 'MemberEvent'
|
||||
AND repo.name = 'owner/repo'
|
||||
ORDER BY
|
||||
created_at ASC
|
||||
```
|
||||
|
||||
### CI/CD Workflow Activity
|
||||
|
||||
```sql
|
||||
SELECT
|
||||
created_at,
|
||||
actor.login,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.action') AS action,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.workflow_run.name') AS workflow_name,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.workflow_run.conclusion') AS conclusion,
|
||||
JSON_EXTRACT_SCALAR(payload, '$.workflow_run.head_sha') AS head_sha
|
||||
FROM
|
||||
`githubarchive.month.*`
|
||||
WHERE
|
||||
_TABLE_SUFFIX BETWEEN '202401' AND '202403'
|
||||
AND type = 'WorkflowRunEvent'
|
||||
AND repo.name = 'owner/repo'
|
||||
ORDER BY
|
||||
created_at ASC
|
||||
```
|
||||
|
||||
### Actor Activity Profiling
|
||||
|
||||
```sql
|
||||
SELECT
|
||||
type,
|
||||
COUNT(*) AS event_count,
|
||||
MIN(created_at) AS first_event,
|
||||
MAX(created_at) AS last_event
|
||||
FROM
|
||||
`githubarchive.month.*`
|
||||
WHERE
|
||||
_TABLE_SUFFIX BETWEEN '202301' AND '202412'
|
||||
AND actor.login = 'suspicious-username'
|
||||
GROUP BY type
|
||||
ORDER BY event_count DESC
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Cost Optimization (MANDATORY)
|
||||
|
||||
1. **Always dry run first**: Add `--dry_run` flag to `bq query` to see estimated bytes scanned before executing.
|
||||
2. **Use `_TABLE_SUFFIX`**: Narrow the date range as much as possible. `day.*` tables are cheapest for narrow windows; `month.*` for broader sweeps.
|
||||
3. **Select only needed columns**: Avoid `SELECT *`. The `payload` column is large — only select specific JSON paths.
|
||||
4. **Add LIMIT**: Use `LIMIT 1000` during exploration. Remove only for final exhaustive queries.
|
||||
5. **Column filtering in WHERE**: Filter on indexed columns (`type`, `repo.name`, `actor.login`) before payload extraction.
|
||||
|
||||
**Cost estimation**: A single month of GH Archive data is ~1-2 TiB uncompressed. Querying a specific repo + event type with `_TABLE_SUFFIX` typically scans 1-10 GiB ($0.006-$0.06).
|
||||
|
||||
---
|
||||
|
||||
## Accessing via Hermes
|
||||
|
||||
**Option A: BigQuery CLI** (if `gcloud` is installed)
|
||||
```bash
|
||||
bq query --use_legacy_sql=false --format=json "YOUR QUERY"
|
||||
```
|
||||
|
||||
**Option B: Python** (via `execute_code`)
|
||||
```python
|
||||
from google.cloud import bigquery
|
||||
client = bigquery.Client()
|
||||
query = "YOUR QUERY"
|
||||
results = client.query(query).result()
|
||||
for row in results:
|
||||
print(dict(row))
|
||||
```
|
||||
|
||||
**Option C: No GCP credentials available**
|
||||
If BigQuery is unavailable, document this limitation in the report. Use the other 4 investigators (Git, GitHub API, Wayback Machine, IOC Enrichment) — they cover most investigation needs without BigQuery.
|
||||
@@ -1,131 +0,0 @@
|
||||
# Investigation Templates
|
||||
|
||||
Pre-built hypothesis and investigation templates for common supply chain attack scenarios.
|
||||
Each template includes: attack pattern, key evidence to collect, and hypothesis starters.
|
||||
|
||||
---
|
||||
|
||||
## Template 1: Maintainer Account Compromise
|
||||
|
||||
**Pattern**: Attacker gains access to a legitimate maintainer account (phishing, credential stuffing)
|
||||
and uses it to push malicious code, create backdoored releases, or exfiltrate CI secrets.
|
||||
|
||||
**Real-world examples**: XZ Utils (2024), Codecov (2021), event-stream (2018)
|
||||
|
||||
**Key Evidence to Collect**:
|
||||
- [ ] Push events from maintainer account outside normal working hours/timezone
|
||||
- [ ] Commits adding new dependencies, obfuscated code, or modified build scripts
|
||||
- [ ] Release creation immediately after suspicious push (to maximize package distribution)
|
||||
- [ ] MemberEvent adding unknown collaborators (attacker adding backup access)
|
||||
- [ ] WorkflowRunEvent with unexpected secret access or exfiltration-like behavior
|
||||
- [ ] Account login location changes (check social media, conference talks for corroboration)
|
||||
|
||||
**Hypothesis Starters**:
|
||||
```
|
||||
[HYPOTHESIS] Actor <HANDLE>'s account was compromised on or around <DATE>,
|
||||
based on anomalous commit timing [EV-XXXX] and geographic access patterns [EV-YYYY].
|
||||
```
|
||||
```
|
||||
[HYPOTHESIS] Release <VERSION> was published by the compromised account to push
|
||||
malicious code to downstream users, evidenced by the malicious commit [EV-XXXX]
|
||||
being added <N> hours before the release [EV-YYYY].
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Template 2: Malicious Dependency Injection
|
||||
|
||||
**Pattern**: A trusted package is modified to include malicious code in a dependency,
|
||||
or a new malicious dependency is injected into an existing package.
|
||||
|
||||
**Key Evidence to Collect**:
|
||||
- [ ] Diff of `package.json`/`requirements.txt`/`go.mod` before and after suspicious commit
|
||||
- [ ] The new dependency's publication timestamp vs. the injection commit timestamp
|
||||
- [ ] Whether the new dependency exists on npm/PyPI and who owns it
|
||||
- [ ] Any obfuscation patterns in the injected dependency code
|
||||
- [ ] Install-time scripts (`postinstall`, `setup.py`, etc.) that execute code on install
|
||||
|
||||
**Hypothesis Starters**:
|
||||
```
|
||||
[HYPOTHESIS] Commit <SHA> [EV-XXXX] introduced dependency <PACKAGE@VERSION>
|
||||
which appears to be a malicious package published by actor <HANDLE> [EV-YYYY],
|
||||
designed to execute <BEHAVIOR> during installation.
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Template 3: CI/CD Pipeline Injection
|
||||
|
||||
**Pattern**: Attacker modifies GitHub Actions workflows to steal secrets, exfiltrate code,
|
||||
or inject malicious artifacts into the build output.
|
||||
|
||||
**Key Evidence to Collect**:
|
||||
- [ ] Diff of all `.github/workflows/*.yml` files before/after suspicious period
|
||||
- [ ] WorkflowRunEvents triggered by the modified workflows
|
||||
- [ ] Any `curl`, `wget`, or network calls added to workflow steps
|
||||
- [ ] New or modified `env:` sections referencing `secrets.*`
|
||||
- [ ] Artifacts produced by modified workflow runs
|
||||
|
||||
**Hypothesis Starters**:
|
||||
```
|
||||
[HYPOTHESIS] Workflow file <FILE> was modified in commit <SHA> [EV-XXXX] to
|
||||
exfiltrate repository secrets via <METHOD>, as evidenced by the added network
|
||||
call pattern [EV-YYYY].
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Template 4: Typosquatting / Dependency Confusion
|
||||
|
||||
**Pattern**: Attacker registers a package with a name similar to a popular package
|
||||
(or an internal package name) to intercept installs from users who mistype.
|
||||
|
||||
**Key Evidence to Collect**:
|
||||
- [ ] Registration timestamp of the suspicious package on the registry
|
||||
- [ ] Package content: does it contain malicious code or is it a stub?
|
||||
- [ ] Download statistics for the suspicious package
|
||||
- [ ] Names of internal packages that could be targeted (if private repo scope)
|
||||
- [ ] Any references to the legitimate package in the malicious one's metadata
|
||||
|
||||
**Hypothesis Starters**:
|
||||
```
|
||||
[HYPOTHESIS] Package <MALICIOUS_NAME> was registered on <DATE> [EV-XXXX] to
|
||||
typosquat on <LEGITIMATE_NAME>, targeting users who misspell the package name.
|
||||
The package contains <BEHAVIOR> [EV-YYYY].
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Template 5: Force-Push History Rewrite (Evidence Erasure)
|
||||
|
||||
**Pattern**: After a malicious commit is detected (or before wider notice), the attacker
|
||||
force-pushes to remove the malicious commit from branch history.
|
||||
|
||||
**Detection is key** — this template focuses on proving the erasure happened.
|
||||
|
||||
**Key Evidence to Collect**:
|
||||
- [ ] GH Archive PushEvent with `distinct_size=0` (force push indicator) [EV-XXXX]
|
||||
- [ ] The SHA of the commit BEFORE the force push (from GH Archive `payload.before`)
|
||||
- [ ] Recovery of the erased commit via direct URL or `git fetch origin SHA`
|
||||
- [ ] Wayback Machine snapshot of the commit page before erasure
|
||||
- [ ] Timeline gap in git log (N commits visible in archive but M < N in current repo)
|
||||
|
||||
**Hypothesis Starters**:
|
||||
```
|
||||
[HYPOTHESIS] Actor <HANDLE> force-pushed branch <BRANCH> on <DATE> [EV-XXXX]
|
||||
to erase commit <SHA> [EV-YYYY], which contained <MALICIOUS_CONTENT>.
|
||||
The erased commit was recovered via <METHOD> [EV-ZZZZ].
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Cross-Cutting Investigation Checklist
|
||||
|
||||
Apply to every investigation regardless of template:
|
||||
|
||||
- [ ] Check all contributors for newly created accounts (< 30 days old at time of malicious activity)
|
||||
- [ ] Check if any maintainer account changed email in the period (sign of account takeover)
|
||||
- [ ] Verify GPG signatures on suspicious commits match known maintainer keys
|
||||
- [ ] Check if the repository changed ownership or transferred orgs near the incident
|
||||
- [ ] Look for "cleanup" commits immediately after the malicious commit (cover-up pattern)
|
||||
- [ ] Check related packages/repos by the same author for similar patterns
|
||||
@@ -1,164 +0,0 @@
|
||||
# Deleted Content Recovery Techniques
|
||||
|
||||
## Key Insight: GitHub Never Fully Deletes Force-Pushed Commits
|
||||
|
||||
Force-pushed commits are removed from the branch history but REMAIN on GitHub's servers until garbage collection runs (which can take weeks to months). This is the foundation of deleted commit recovery.
|
||||
|
||||
---
|
||||
|
||||
## Method 1: Direct GitHub URL (Fastest — No Auth Required)
|
||||
|
||||
If you have a commit SHA, access it directly even if it was force-pushed off a branch:
|
||||
|
||||
```bash
|
||||
# View commit metadata
|
||||
curl -s "https://github.com/OWNER/REPO/commit/SHA"
|
||||
|
||||
# Download as patch (includes full diff)
|
||||
curl -s "https://github.com/OWNER/REPO/commit/SHA.patch" > recovered_commit.patch
|
||||
|
||||
# Download as diff
|
||||
curl -s "https://github.com/OWNER/REPO/commit/SHA.diff" > recovered_commit.diff
|
||||
|
||||
# Example (Istio credential leak - real incident):
|
||||
curl -s "https://github.com/istio/istio/commit/FORCE_PUSHED_SHA.patch"
|
||||
```
|
||||
|
||||
**When this works**: SHA is known (from GH Archive, Wayback Machine, or `git fsck`)
|
||||
**When this fails**: GitHub has already garbage-collected the object (rare, typically 30–90 days post-force-push)
|
||||
|
||||
---
|
||||
|
||||
## Method 2: GitHub REST API
|
||||
|
||||
```bash
|
||||
# Works for commits force-pushed off branches but still on server
|
||||
# Note: /commits/SHA may 404, but /git/commits/SHA often succeeds for orphaned commits
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/git/commits/SHA" | jq .
|
||||
|
||||
# Get the tree (file listing) of a force-pushed commit
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/git/trees/SHA?recursive=1" | jq .
|
||||
|
||||
# Get a specific file from a force-pushed commit
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/contents/PATH?ref=SHA" | jq .content | base64 -d
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Method 3: Git Fetch by SHA (Local — Requires Clone)
|
||||
|
||||
```bash
|
||||
# Fetch an orphaned commit directly by SHA into local repo
|
||||
cd target_repo
|
||||
git fetch origin SHA
|
||||
git log FETCH_HEAD -1 # view the commit
|
||||
git diff FETCH_HEAD~1 FETCH_HEAD # view the diff
|
||||
|
||||
# If the SHA was recently force-pushed it will still be fetchable
|
||||
# This stops working once GitHub GC runs
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Method 4: Dangling Commits via git fsck
|
||||
|
||||
```bash
|
||||
cd target_repo
|
||||
|
||||
# Find all unreachable objects (includes force-pushed commits)
|
||||
git fsck --unreachable --no-reflogs 2>&1 | grep "unreachable commit" | awk '{print $3}' > dangling_shas.txt
|
||||
|
||||
# For each dangling commit, get its metadata
|
||||
while read sha; do
|
||||
echo "=== $sha ===" >> dangling_details.txt
|
||||
git show --stat "$sha" >> dangling_details.txt 2>&1
|
||||
done < dangling_shas.txt
|
||||
|
||||
# Note: dangling objects only exist in LOCAL clone — not the same as GitHub's copies
|
||||
# GitHub's copies are accessible via Methods 1-3 until GC runs
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Recovering Deleted GitHub Issues and PRs
|
||||
|
||||
### Via Wayback Machine CDX API
|
||||
|
||||
```bash
|
||||
# Find all archived snapshots of a specific issue
|
||||
curl -s "https://web.archive.org/cdx/search/cdx?url=github.com/OWNER/REPO/issues/NUMBER&output=json&limit=50&fl=timestamp,statuscode,original" | python3 -m json.tool
|
||||
|
||||
# Fetch the best snapshot
|
||||
# Use the timestamp from the CDX result:
|
||||
# https://web.archive.org/web/TIMESTAMP/https://github.com/OWNER/REPO/issues/NUMBER
|
||||
curl -s "https://web.archive.org/web/TIMESTAMP/https://github.com/OWNER/REPO/issues/NUMBER" > issue_NUMBER_archived.html
|
||||
|
||||
# Find all snapshots of the repo in a date range
|
||||
curl -s "https://web.archive.org/cdx/search/cdx?url=github.com/OWNER/REPO*&output=json&from=20240101&to=20240201&limit=200&fl=timestamp,urlkey,statuscode" | python3 -m json.tool
|
||||
```
|
||||
|
||||
### Via GitHub API (Limited — Only Non-Deleted Content)
|
||||
|
||||
```bash
|
||||
# Closed issues (not deleted) are retrievable
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/issues?state=closed&per_page=100" | jq '.[].number'
|
||||
|
||||
# Note: DELETED issues/PRs do NOT appear in the API. Use Wayback Machine or GH Archive for those.
|
||||
```
|
||||
|
||||
### Via GitHub Archive (For Event History — Not Content)
|
||||
|
||||
```sql
|
||||
-- Find all IssueEvents for a repo in a date range
|
||||
SELECT created_at, actor.login, payload.action, payload.issue.number, payload.issue.title
|
||||
FROM `githubarchive.day.*`
|
||||
WHERE _TABLE_SUFFIX BETWEEN '20240101' AND '20240201'
|
||||
AND type = 'IssuesEvent'
|
||||
AND repo.name = 'OWNER/REPO'
|
||||
ORDER BY created_at
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Recovering Deleted Files from a Known Commit
|
||||
|
||||
```bash
|
||||
# If you have the commit SHA (even force-pushed):
|
||||
git show SHA:path/to/file.py > recovered_file.py
|
||||
|
||||
# Or via API (base64 encoded content):
|
||||
curl -s "https://api.github.com/repos/OWNER/REPO/contents/path/to/file.py?ref=SHA" | python3 -c "
|
||||
import sys, json, base64
|
||||
d = json.load(sys.stdin)
|
||||
print(base64.b64decode(d['content']).decode())
|
||||
"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Evidence Recording
|
||||
|
||||
After recovering any deleted content, immediately record it:
|
||||
|
||||
```bash
|
||||
python3 SKILL_DIR/scripts/evidence-store.py --store evidence.json add \
|
||||
--source "git fetch origin FORCE_PUSHED_SHA" \
|
||||
--content "Recovered commit: FORCE_PUSHED_SHA | Author: attacker@example.com | Date: 2024-01-15 | Added file: malicious.sh" \
|
||||
--type git \
|
||||
--actor "attacker-handle" \
|
||||
--url "https://github.com/OWNER/REPO/commit/FORCE_PUSHED_SHA.patch" \
|
||||
--timestamp "2024-01-15T00:00:00Z" \
|
||||
--verification single_source \
|
||||
--notes "Commit force-pushed off main branch on 2024-01-16. Recovered via direct fetch."
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Recovery Failure Modes
|
||||
|
||||
| Failure | Cause | Workaround |
|
||||
|---------|-------|------------|
|
||||
| `git fetch origin SHA` returns "not our ref" | GitHub GC already ran | Try Method 1/2, search Wayback Machine |
|
||||
| `github.com/OWNER/REPO/commit/SHA` returns 404 | GC ran or SHA is wrong | Verify SHA via GH Archive; try partial SHA search |
|
||||
| Wayback Machine has no snapshots | Page was never crawled by IA | Check `commoncrawl.org`, check Google Cache |
|
||||
| BigQuery shows event but no content | GH Archive stores event metadata, not file contents | Recovery only reveals the event occurred, not the content |
|
||||
@@ -1,313 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
OSS Forensics Evidence Store Manager
|
||||
Manages a JSON-based evidence store for forensic investigations.
|
||||
|
||||
Commands:
|
||||
add - Add a piece of evidence
|
||||
list - List all evidence (optionally filter by type or actor)
|
||||
verify - Re-check SHA-256 hashes for integrity
|
||||
query - Search evidence by keyword
|
||||
export - Export evidence as a Markdown table
|
||||
summary - Print investigation statistics
|
||||
|
||||
Usage example:
|
||||
python3 evidence-store.py --store evidence.json add \
|
||||
--source "git fsck output" --content "dangling commit abc123" \
|
||||
--type git --actor "malicious-user" --url "https://github.com/owner/repo/commit/abc123"
|
||||
|
||||
python3 evidence-store.py --store evidence.json list --type git
|
||||
python3 evidence-store.py --store evidence.json verify
|
||||
python3 evidence-store.py --store evidence.json export > evidence-table.md
|
||||
"""
|
||||
|
||||
import json
|
||||
import argparse
|
||||
import os
|
||||
import datetime
|
||||
import hashlib
|
||||
import sys
|
||||
|
||||
EVIDENCE_TYPES = [
|
||||
"git", # Local git repository data (commits, reflog, fsck)
|
||||
"gh_api", # GitHub REST API responses
|
||||
"gh_archive", # GitHub Archive / BigQuery query results
|
||||
"web_archive", # Wayback Machine snapshots
|
||||
"ioc", # Indicator of Compromise (SHA, domain, IP, package name, etc.)
|
||||
"analysis", # Derived analysis / cross-source correlation result
|
||||
"manual", # Manually noted observation
|
||||
"vendor_report", # External security vendor report excerpt
|
||||
]
|
||||
|
||||
VERIFICATION_STATES = ["unverified", "single_source", "multi_source_verified"]
|
||||
|
||||
IOC_TYPES = [
|
||||
"COMMIT_SHA", "FILE_PATH", "API_KEY", "SECRET", "IP_ADDRESS",
|
||||
"DOMAIN", "PACKAGE_NAME", "ACTOR_USERNAME", "MALICIOUS_URL",
|
||||
"WORKFLOW_FILE", "BRANCH_NAME", "TAG_NAME", "RELEASE_NAME", "OTHER",
|
||||
]
|
||||
|
||||
|
||||
def _now_iso():
|
||||
return datetime.datetime.now(datetime.timezone.utc).isoformat(timespec="seconds") + "Z"
|
||||
|
||||
|
||||
def _sha256(content: str) -> str:
|
||||
return hashlib.sha256(content.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
class EvidenceStore:
|
||||
def __init__(self, filepath: str):
|
||||
self.filepath = filepath
|
||||
self.data = {
|
||||
"metadata": {
|
||||
"version": "2.0",
|
||||
"created_at": _now_iso(),
|
||||
"last_updated": _now_iso(),
|
||||
"investigation": "",
|
||||
"target_repo": "",
|
||||
},
|
||||
"evidence": [],
|
||||
"chain_of_custody": [],
|
||||
}
|
||||
if os.path.exists(filepath):
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
self.data = json.load(f)
|
||||
except (json.JSONDecodeError, IOError) as e:
|
||||
print(f"Error loading evidence store '{filepath}': {e}", file=sys.stderr)
|
||||
print("Hint: The file might be corrupted. Check for manual edits or syntax errors.", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
def _save(self):
|
||||
self.data["metadata"]["last_updated"] = _now_iso()
|
||||
with open(self.filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(self.data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
def _next_id(self) -> str:
|
||||
return f"EV-{len(self.data['evidence']) + 1:04d}"
|
||||
|
||||
def add(
|
||||
self,
|
||||
source: str,
|
||||
content: str,
|
||||
evidence_type: str,
|
||||
actor: str = None,
|
||||
url: str = None,
|
||||
timestamp: str = None,
|
||||
ioc_type: str = None,
|
||||
verification: str = "unverified",
|
||||
notes: str = None,
|
||||
) -> str:
|
||||
evidence_id = self._next_id()
|
||||
entry = {
|
||||
"id": evidence_id,
|
||||
"type": evidence_type,
|
||||
"source": source,
|
||||
"content": content,
|
||||
"content_sha256": _sha256(content),
|
||||
"actor": actor,
|
||||
"url": url,
|
||||
"event_timestamp": timestamp,
|
||||
"collected_at": _now_iso(),
|
||||
"ioc_type": ioc_type,
|
||||
"verification": verification,
|
||||
"notes": notes,
|
||||
}
|
||||
self.data["evidence"].append(entry)
|
||||
self.data["chain_of_custody"].append({
|
||||
"action": "add",
|
||||
"evidence_id": evidence_id,
|
||||
"timestamp": _now_iso(),
|
||||
"source": source,
|
||||
})
|
||||
self._save()
|
||||
return evidence_id
|
||||
|
||||
def list_evidence(self, filter_type: str = None, filter_actor: str = None):
|
||||
results = self.data["evidence"]
|
||||
if filter_type:
|
||||
results = [e for e in results if e.get("type") == filter_type]
|
||||
if filter_actor:
|
||||
results = [e for e in results if e.get("actor") == filter_actor]
|
||||
return results
|
||||
|
||||
def verify_integrity(self):
|
||||
"""Re-compute SHA-256 for all entries and report mismatches."""
|
||||
issues = []
|
||||
for entry in self.data["evidence"]:
|
||||
expected = _sha256(entry["content"])
|
||||
stored = entry.get("content_sha256", "")
|
||||
if expected != stored:
|
||||
issues.append({
|
||||
"id": entry["id"],
|
||||
"stored_sha256": stored,
|
||||
"computed_sha256": expected,
|
||||
})
|
||||
return issues
|
||||
|
||||
def query(self, keyword: str):
|
||||
"""Search for keyword in content, source, actor, or url."""
|
||||
keyword_lower = keyword.lower()
|
||||
return [
|
||||
e for e in self.data["evidence"]
|
||||
if keyword_lower in (e.get("content", "") or "").lower()
|
||||
or keyword_lower in (e.get("source", "") or "").lower()
|
||||
or keyword_lower in (e.get("actor", "") or "").lower()
|
||||
or keyword_lower in (e.get("url", "") or "").lower()
|
||||
]
|
||||
|
||||
def export_markdown(self) -> str:
|
||||
lines = [
|
||||
"# Evidence Registry",
|
||||
"",
|
||||
f"**Store**: `{self.filepath}`",
|
||||
f"**Last Updated**: {self.data['metadata'].get('last_updated', 'N/A')}",
|
||||
f"**Total Evidence Items**: {len(self.data['evidence'])}",
|
||||
"",
|
||||
"| ID | Type | Source | Actor | Verification | Event Timestamp | URL |",
|
||||
"|----|------|--------|-------|--------------|-----------------|-----|",
|
||||
]
|
||||
for e in self.data["evidence"]:
|
||||
url = e.get("url") or ""
|
||||
url_display = f"[link]({url})" if url else ""
|
||||
lines.append(
|
||||
f"| {e['id']} | {e.get('type','')} | {e.get('source','')} "
|
||||
f"| {e.get('actor') or ''} | {e.get('verification','')} "
|
||||
f"| {e.get('event_timestamp') or ''} | {url_display} |"
|
||||
)
|
||||
lines.append("")
|
||||
lines.append("## Chain of Custody")
|
||||
lines.append("")
|
||||
lines.append("| Evidence ID | Action | Timestamp | Source |")
|
||||
lines.append("|-------------|--------|-----------|--------|")
|
||||
for c in self.data["chain_of_custody"]:
|
||||
lines.append(
|
||||
f"| {c.get('evidence_id','')} | {c.get('action','')} "
|
||||
f"| {c.get('timestamp','')} | {c.get('source','')} |"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
def summary(self) -> dict:
|
||||
by_type = {}
|
||||
by_verification = {}
|
||||
actors = set()
|
||||
for e in self.data["evidence"]:
|
||||
t = e.get("type", "unknown")
|
||||
by_type[t] = by_type.get(t, 0) + 1
|
||||
v = e.get("verification", "unverified")
|
||||
by_verification[v] = by_verification.get(v, 0) + 1
|
||||
if e.get("actor"):
|
||||
actors.add(e["actor"])
|
||||
return {
|
||||
"total": len(self.data["evidence"]),
|
||||
"by_type": by_type,
|
||||
"by_verification": by_verification,
|
||||
"unique_actors": sorted(actors),
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="OSS Forensics Evidence Store Manager v2.0",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
parser.add_argument("--store", default="evidence.json", help="Path to evidence JSON file (default: evidence.json)")
|
||||
|
||||
subparsers = parser.add_subparsers(dest="command", metavar="COMMAND")
|
||||
|
||||
# --- add ---
|
||||
add_p = subparsers.add_parser("add", help="Add a new evidence entry")
|
||||
add_p.add_argument("--source", required=True, help="Where this evidence came from (e.g. 'git fsck', 'GH API /commits')")
|
||||
add_p.add_argument("--content", required=True, help="The evidence content (commit SHA, API response excerpt, etc.)")
|
||||
add_p.add_argument("--type", required=True, choices=EVIDENCE_TYPES, dest="evidence_type", help="Evidence type")
|
||||
add_p.add_argument("--actor", help="GitHub handle or email of associated actor")
|
||||
add_p.add_argument("--url", help="URL to original source")
|
||||
add_p.add_argument("--timestamp", help="When the event occurred (ISO 8601)")
|
||||
add_p.add_argument("--ioc-type", choices=IOC_TYPES, help="IOC subtype (for --type ioc)")
|
||||
add_p.add_argument("--verification", choices=VERIFICATION_STATES, default="unverified")
|
||||
add_p.add_argument("--notes", help="Additional investigator notes")
|
||||
add_p.add_argument("--quiet", action="store_true", help="Suppress success message")
|
||||
|
||||
# --- list ---
|
||||
list_p = subparsers.add_parser("list", help="List all evidence entries")
|
||||
list_p.add_argument("--type", dest="filter_type", choices=EVIDENCE_TYPES, help="Filter by type")
|
||||
list_p.add_argument("--actor", dest="filter_actor", help="Filter by actor")
|
||||
|
||||
# --- verify ---
|
||||
subparsers.add_parser("verify", help="Verify SHA-256 integrity of all evidence content")
|
||||
|
||||
# --- query ---
|
||||
query_p = subparsers.add_parser("query", help="Search evidence by keyword")
|
||||
query_p.add_argument("keyword", help="Keyword to search for")
|
||||
|
||||
# --- export ---
|
||||
subparsers.add_parser("export", help="Export evidence as a Markdown table (stdout)")
|
||||
|
||||
# --- summary ---
|
||||
subparsers.add_parser("summary", help="Print investigation statistics")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.command:
|
||||
parser.print_help()
|
||||
sys.exit(0)
|
||||
|
||||
store = EvidenceStore(args.store)
|
||||
|
||||
if args.command == "add":
|
||||
eid = store.add(
|
||||
source=args.source,
|
||||
content=args.content,
|
||||
evidence_type=args.evidence_type,
|
||||
actor=args.actor,
|
||||
url=args.url,
|
||||
timestamp=args.timestamp,
|
||||
ioc_type=args.ioc_type,
|
||||
verification=args.verification,
|
||||
notes=args.notes,
|
||||
)
|
||||
if not getattr(args, "quiet", False):
|
||||
print(f"✓ Added evidence: {eid}")
|
||||
|
||||
elif args.command == "list":
|
||||
items = store.list_evidence(
|
||||
filter_type=getattr(args, "filter_type", None),
|
||||
filter_actor=getattr(args, "filter_actor", None),
|
||||
)
|
||||
if not items:
|
||||
print("No evidence found.")
|
||||
for e in items:
|
||||
actor_str = f" | actor: {e['actor']}" if e.get("actor") else ""
|
||||
url_str = f" | {e['url']}" if e.get("url") else ""
|
||||
print(f"[{e['id']}] {e['type']:12s} | {e['verification']:20s} | {e['source']}{actor_str}{url_str}")
|
||||
|
||||
elif args.command == "verify":
|
||||
issues = store.verify_integrity()
|
||||
if not issues:
|
||||
print(f"✓ All {len(store.data['evidence'])} evidence entries passed SHA-256 integrity check.")
|
||||
else:
|
||||
print(f"✗ {len(issues)} integrity issue(s) detected:")
|
||||
for i in issues:
|
||||
print(f" [{i['id']}] stored={i['stored_sha256'][:16]}... computed={i['computed_sha256'][:16]}...")
|
||||
sys.exit(1)
|
||||
|
||||
elif args.command == "query":
|
||||
results = store.query(args.keyword)
|
||||
print(f"Found {len(results)} result(s) for '{args.keyword}':")
|
||||
for e in results:
|
||||
print(f" [{e['id']}] {e['type']} | {e['source']} | {e['content'][:80]}")
|
||||
|
||||
elif args.command == "export":
|
||||
print(store.export_markdown())
|
||||
|
||||
elif args.command == "summary":
|
||||
s = store.summary()
|
||||
print(f"Total evidence items : {s['total']}")
|
||||
print(f"By type : {json.dumps(s['by_type'], indent=2)}")
|
||||
print(f"By verification : {json.dumps(s['by_verification'], indent=2)}")
|
||||
print(f"Unique actors : {s['unique_actors']}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,151 +0,0 @@
|
||||
# Forensic Investigation Report
|
||||
|
||||
> **Instructions**: Fill in all sections. Every factual claim must cite at least one `[EV-XXXX]` evidence ID.
|
||||
> Remove placeholder text and instruction notes before finalizing. Redact all secrets to `[REDACTED]`.
|
||||
|
||||
---
|
||||
|
||||
## Executive Summary
|
||||
|
||||
**Target Repository**: `OWNER/REPO`
|
||||
**Investigation Period**: YYYY-MM-DD to YYYY-MM-DD
|
||||
**Verdict**: <!-- Compromised / Clean / Inconclusive -->
|
||||
**Confidence Level**: <!-- High / Medium / Low -->
|
||||
**Report Date**: YYYY-MM-DD
|
||||
**Investigator**: <!-- Agent session ID or analyst name -->
|
||||
|
||||
<!-- One paragraph: what was investigated, what was found, what is recommended. -->
|
||||
|
||||
---
|
||||
|
||||
## Timeline of Events
|
||||
|
||||
> All timestamps in UTC. Each event must cite at least one evidence ID.
|
||||
|
||||
| Timestamp (UTC) | Event | Evidence IDs | Source |
|
||||
|-----------------|-------|--------------|--------|
|
||||
| YYYY-MM-DDTHH:MM:SSZ | _Describe event_ | [EV-XXXX] | git / gh_api / gh_archive / web_archive |
|
||||
| | | | |
|
||||
|
||||
---
|
||||
|
||||
## Validated Hypotheses
|
||||
|
||||
### Hypothesis 1: <!-- Short title -->
|
||||
|
||||
**Status**: <!-- VALIDATED / INCONCLUSIVE / REJECTED -->
|
||||
|
||||
**Claim**: _Full statement of the hypothesis._
|
||||
|
||||
**Supporting Evidence**:
|
||||
- [EV-XXXX]: _What this evidence shows_
|
||||
- [EV-YYYY]: _What this evidence shows_
|
||||
|
||||
**Counter-Evidence Considered**: _What might disprove this, and why it was ruled out or not._
|
||||
|
||||
**Confidence**: <!-- High / Medium / Low, and why -->
|
||||
|
||||
---
|
||||
|
||||
## Indicators of Compromise (IOC List)
|
||||
|
||||
| Type | Value | Status | Evidence |
|
||||
|------|-------|--------|----------|
|
||||
| COMMIT_SHA | `abc123...` | Confirmed malicious | [EV-XXXX] |
|
||||
| ACTOR_USERNAME | `handle` | Suspected compromised | [EV-YYYY] |
|
||||
| FILE_PATH | `src/evil.js` | Confirmed malicious | [EV-ZZZZ] |
|
||||
| DOMAIN | `evil-cdn.io` | Confirmed C2 | [EV-WWWW] |
|
||||
|
||||
---
|
||||
|
||||
## Affected Versions
|
||||
|
||||
| Version / Tag | Published | Contains Malicious Code | Evidence |
|
||||
|---------------|-----------|------------------------|----------|
|
||||
| `v1.2.3` | YYYY-MM-DD | Yes / No / Unknown | [EV-XXXX] |
|
||||
|
||||
---
|
||||
|
||||
## Evidence Registry
|
||||
|
||||
> Generated by: `python3 SKILL_DIR/scripts/evidence-store.py --store evidence.json export`
|
||||
|
||||
<!-- Paste the Markdown table output from the evidence-store.py export command here -->
|
||||
|
||||
| ID | Type | Source | Actor | Verification | Event Timestamp | URL |
|
||||
|----|------|--------|-------|--------------|-----------------|-----|
|
||||
| EV-0001 | | | | | | |
|
||||
|
||||
---
|
||||
|
||||
## Chain of Custody
|
||||
|
||||
> Generated by: `python3 SKILL_DIR/scripts/evidence-store.py --store evidence.json export`
|
||||
|
||||
<!-- Paste the chain of custody section from the export output here -->
|
||||
|
||||
| Evidence ID | Action | Timestamp | Source |
|
||||
|-------------|--------|-----------|--------|
|
||||
| EV-0001 | add | | |
|
||||
|
||||
---
|
||||
|
||||
## Technical Findings
|
||||
|
||||
### Git History Analysis
|
||||
|
||||
_Summarize findings from local git analysis: dangling commits, reflog anomalies, unsigned commits, binary additions, etc._
|
||||
|
||||
### GitHub API Analysis
|
||||
|
||||
_Summarize findings from GitHub REST API: deleted PRs/issues, contributor changes, release anomalies, etc._
|
||||
|
||||
### GitHub Archive Analysis
|
||||
|
||||
_Summarize findings from BigQuery: force-push events, delete events, workflow anomalies, member changes, etc._
|
||||
_Note: If BigQuery was unavailable, state this explicitly._
|
||||
|
||||
### Wayback Machine Analysis
|
||||
|
||||
_Summarize findings from archive.org: recovered deleted pages, historical content differences, etc._
|
||||
|
||||
### IOC Enrichment
|
||||
|
||||
_Summarize enrichment results: WHOIS data for domains, recovered commit content, actor account analysis, etc._
|
||||
|
||||
---
|
||||
|
||||
## Recommendations
|
||||
|
||||
### Immediate Actions (If Compromise Confirmed)
|
||||
|
||||
- [ ] Rotate all GitHub tokens, API keys, and credentials that may have been exposed
|
||||
- [ ] Pin dependency versions to hashes in all affected packages
|
||||
- [ ] Publish a security advisory / CVE if applicable
|
||||
- [ ] Notify downstream users/package registries (npm, PyPI, etc.)
|
||||
- [ ] Revoke access for the compromised account and re-secure with hardware 2FA
|
||||
- [ ] Audit all CI/CD workflow files for unauthorized modifications
|
||||
- [ ] Review all releases published during the compromise window
|
||||
|
||||
### Monitoring Recommendations
|
||||
|
||||
- [ ] Enable branch protection on `main`/`master` (require code review, disallow force-push)
|
||||
- [ ] Enable required commit signing (GPG/SSH)
|
||||
- [ ] Set up GitHub audit log streaming for future monitoring
|
||||
- [ ] Pin critical dependencies to known-good SHAs in lock files
|
||||
|
||||
---
|
||||
|
||||
## Limitations and Caveats
|
||||
|
||||
- _List any data sources that were unavailable (e.g., no BigQuery access)_
|
||||
- _Note any evidence that is single-source only (not independently verified)_
|
||||
- _Note any hypotheses that could not be confirmed or denied_
|
||||
|
||||
---
|
||||
|
||||
## References
|
||||
|
||||
- Evidence store: `evidence.json` (SHA-256 integrity: run `python3 SKILL_DIR/scripts/evidence-store.py --store evidence.json verify`)
|
||||
- Related issues: <!-- Link to GitHub issues, CVEs, security advisories -->
|
||||
- RAPTOR framework: https://github.com/gadievron/raptor
|
||||
@@ -1,43 +0,0 @@
|
||||
# Malicious Package Investigation Report
|
||||
|
||||
---
|
||||
|
||||
## 📦 Package Metadata
|
||||
- **Package Name**:
|
||||
- **Registry**: [NPM / PyPI / RubyGems / etc.]
|
||||
- **Affected Versions**:
|
||||
- **Malicious Version(s)**:
|
||||
- **Downloads at Time of Detection**:
|
||||
- **Package URL**:
|
||||
|
||||
---
|
||||
|
||||
## 🚩 Indicators of Compromise (IOCs)
|
||||
- **Malicious URL(s)**:
|
||||
- **Exfiltrated Data Types**: [Environment variables, ~/.ssh/id_rsa, /etc/shadow, etc.]
|
||||
- **Exfiltration Method**: [DNS tunneling, HTTP POST to C2, etc.]
|
||||
- **C2 IP/Domain**:
|
||||
|
||||
---
|
||||
|
||||
## 🛠️ Analysis Summary
|
||||
- **Primary Mechanism**: [Typosquatting / Dependency Confusion / Maintainer Takeover]
|
||||
- **Behavior Description**:
|
||||
- [Example: Installs a postinstall script that exfiltrates environment variables.]
|
||||
- [Example: Patches `setup.py` to download a secondary payload.]
|
||||
|
||||
---
|
||||
|
||||
## 🔍 Evidence Registry
|
||||
| Evidence ID | Type | Source | Description |
|
||||
|-------------|------|--------|-------------|
|
||||
| EV-XXXX | ioc | NPM | Package install script snapshot |
|
||||
| EV-YYYY | web | Wayback| Historical version comparison |
|
||||
|
||||
---
|
||||
|
||||
## 🛡️ Recommended Mitigations
|
||||
1. [ ] Unpublish/Report the package to the registry.
|
||||
2. [ ] Audit `package-lock.json` or `requirements.txt` across all projects.
|
||||
3. [ ] Rotate secrets exfiltrated via environment variables.
|
||||
4. [ ] Pin specific hashes (SHASUM) for mission-critical dependencies.
|
||||
+184
-447
@@ -90,7 +90,6 @@ from agent.display import (
|
||||
KawaiiSpinner, build_tool_preview as _build_tool_preview,
|
||||
get_cute_tool_message as _get_cute_tool_message_impl,
|
||||
_detect_tool_failure,
|
||||
get_tool_emoji as _get_tool_emoji,
|
||||
)
|
||||
from agent.trajectory import (
|
||||
convert_scratchpad_to_think, has_incomplete_scratchpad,
|
||||
@@ -205,33 +204,6 @@ _NEVER_PARALLEL_TOOLS = frozenset({"clarify"})
|
||||
# Maximum number of concurrent worker threads for parallel tool execution.
|
||||
_MAX_TOOL_WORKERS = 8
|
||||
|
||||
# Patterns that indicate a terminal command may modify/delete files.
|
||||
_DESTRUCTIVE_PATTERNS = re.compile(
|
||||
r"""(?:^|\s|&&|\|\||;|`)(?:
|
||||
rm\s|rmdir\s|
|
||||
mv\s|
|
||||
sed\s+-i|
|
||||
truncate\s|
|
||||
dd\s|
|
||||
shred\s|
|
||||
git\s+(?:reset|clean|checkout)\s
|
||||
)""",
|
||||
re.VERBOSE,
|
||||
)
|
||||
# Output redirects that overwrite files (> but not >>)
|
||||
_REDIRECT_OVERWRITE = re.compile(r'[^>]>[^>]|^>[^>]')
|
||||
|
||||
|
||||
def _is_destructive_command(cmd: str) -> bool:
|
||||
"""Heuristic: does this terminal command look like it modifies/deletes files?"""
|
||||
if not cmd:
|
||||
return False
|
||||
if _DESTRUCTIVE_PATTERNS.search(cmd):
|
||||
return True
|
||||
if _REDIRECT_OVERWRITE.search(cmd):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _inject_honcho_turn_context(content, turn_context: str):
|
||||
"""Append Honcho recall to the current-turn user message without mutating history.
|
||||
@@ -296,7 +268,6 @@ class AIAgent:
|
||||
reasoning_callback: callable = None,
|
||||
clarify_callback: callable = None,
|
||||
step_callback: callable = None,
|
||||
stream_delta_callback: callable = None,
|
||||
max_tokens: int = None,
|
||||
reasoning_config: Dict[str, Any] = None,
|
||||
prefill_messages: List[Dict[str, Any]] = None,
|
||||
@@ -396,7 +367,6 @@ class AIAgent:
|
||||
self.reasoning_callback = reasoning_callback
|
||||
self.clarify_callback = clarify_callback
|
||||
self.step_callback = step_callback
|
||||
self.stream_delta_callback = stream_delta_callback
|
||||
self._last_reported_tool = None # Track for "new tool" mode
|
||||
|
||||
# Interrupt mechanism for breaking out of tool loops
|
||||
@@ -546,8 +516,6 @@ class AIAgent:
|
||||
effective_key = api_key or resolve_anthropic_token() or ""
|
||||
self._anthropic_api_key = effective_key
|
||||
self._anthropic_base_url = base_url
|
||||
from agent.anthropic_adapter import _is_oauth_token as _is_oat
|
||||
self._is_anthropic_oauth = _is_oat(effective_key)
|
||||
self._anthropic_client = build_anthropic_client(effective_key, base_url)
|
||||
# No OpenAI client needed for Anthropic mode
|
||||
self.client = None
|
||||
@@ -816,7 +784,7 @@ class AIAgent:
|
||||
logger.debug("peer %s memory_mode=honcho: local USER.md writes disabled", _hcfg.peer_name or "user")
|
||||
|
||||
# Skills config: nudge interval for skill creation reminders
|
||||
self._skill_nudge_interval = 10
|
||||
self._skill_nudge_interval = 15
|
||||
try:
|
||||
from hermes_cli.config import load_config as _load_skills_config
|
||||
skills_config = _load_skills_config().get("skills", {})
|
||||
@@ -860,9 +828,9 @@ class AIAgent:
|
||||
"""Verbose print — suppressed when streaming TTS is active.
|
||||
|
||||
Pass ``force=True`` for error/warning messages that should always be
|
||||
shown even during streaming playback (TTS or display).
|
||||
shown even during streaming TTS playback.
|
||||
"""
|
||||
if not force and self._has_stream_consumers():
|
||||
if not force and getattr(self, "_stream_callback", None) is not None:
|
||||
return
|
||||
print(*args, **kwargs)
|
||||
|
||||
@@ -2606,39 +2574,15 @@ class AIAgent:
|
||||
def _close_request_openai_client(self, client: Any, *, reason: str) -> None:
|
||||
self._close_openai_client(client, reason=reason, shared=False)
|
||||
|
||||
def _run_codex_stream(self, api_kwargs: dict, client: Any = None, on_first_delta: callable = None):
|
||||
def _run_codex_stream(self, api_kwargs: dict, client: Any = None):
|
||||
"""Execute one streaming Responses API request and return the final response."""
|
||||
active_client = client or self._ensure_primary_openai_client(reason="codex_stream_direct")
|
||||
max_stream_retries = 1
|
||||
has_tool_calls = False
|
||||
first_delta_fired = False
|
||||
for attempt in range(max_stream_retries + 1):
|
||||
try:
|
||||
with active_client.responses.stream(**api_kwargs) as stream:
|
||||
for event in stream:
|
||||
if self._interrupt_requested:
|
||||
break
|
||||
event_type = getattr(event, "type", "")
|
||||
# Fire callbacks on text content deltas (suppress during tool calls)
|
||||
if "output_text.delta" in event_type or event_type == "response.output_text.delta":
|
||||
delta_text = getattr(event, "delta", "")
|
||||
if delta_text and not has_tool_calls:
|
||||
if not first_delta_fired:
|
||||
first_delta_fired = True
|
||||
if on_first_delta:
|
||||
try:
|
||||
on_first_delta()
|
||||
except Exception:
|
||||
pass
|
||||
self._fire_stream_delta(delta_text)
|
||||
# Track tool calls to suppress text streaming
|
||||
elif "function_call" in event_type:
|
||||
has_tool_calls = True
|
||||
# Fire reasoning callbacks
|
||||
elif "reasoning" in event_type and "delta" in event_type:
|
||||
reasoning_text = getattr(event, "delta", "")
|
||||
if reasoning_text:
|
||||
self._fire_reasoning_delta(reasoning_text)
|
||||
for _ in stream:
|
||||
pass
|
||||
return stream.get_final_response()
|
||||
except RuntimeError as exc:
|
||||
err_text = str(exc)
|
||||
@@ -2819,7 +2763,6 @@ class AIAgent:
|
||||
result["response"] = self._run_codex_stream(
|
||||
api_kwargs,
|
||||
client=request_client_holder["client"],
|
||||
on_first_delta=getattr(self, "_codex_on_first_delta", None),
|
||||
)
|
||||
elif self.api_mode == "anthropic_messages":
|
||||
result["response"] = self._anthropic_messages_create(api_kwargs)
|
||||
@@ -2861,246 +2804,116 @@ class AIAgent:
|
||||
raise result["error"]
|
||||
return result["response"]
|
||||
|
||||
# ── Unified streaming API call ─────────────────────────────────────────
|
||||
def _streaming_api_call(self, api_kwargs: dict, stream_callback):
|
||||
"""Streaming variant of _interruptible_api_call for voice TTS pipeline.
|
||||
|
||||
def _fire_stream_delta(self, text: str) -> None:
|
||||
"""Fire all registered stream delta callbacks (display + TTS)."""
|
||||
for cb in (self.stream_delta_callback, self._stream_callback):
|
||||
if cb is not None:
|
||||
try:
|
||||
cb(text)
|
||||
except Exception:
|
||||
pass
|
||||
Uses ``stream=True`` and forwards content deltas to *stream_callback*
|
||||
in real-time. Returns a ``SimpleNamespace`` that mimics a normal
|
||||
``ChatCompletion`` so the rest of the agent loop works unchanged.
|
||||
|
||||
def _fire_reasoning_delta(self, text: str) -> None:
|
||||
"""Fire reasoning callback if registered."""
|
||||
cb = self.reasoning_callback
|
||||
if cb is not None:
|
||||
try:
|
||||
cb(text)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _has_stream_consumers(self) -> bool:
|
||||
"""Return True if any streaming consumer is registered."""
|
||||
return (
|
||||
self.stream_delta_callback is not None
|
||||
or getattr(self, "_stream_callback", None) is not None
|
||||
)
|
||||
|
||||
def _interruptible_streaming_api_call(
|
||||
self, api_kwargs: dict, *, on_first_delta: callable = None
|
||||
):
|
||||
"""Streaming variant of _interruptible_api_call for real-time token delivery.
|
||||
|
||||
Handles all three api_modes:
|
||||
- chat_completions: stream=True on OpenAI-compatible endpoints
|
||||
- anthropic_messages: client.messages.stream() via Anthropic SDK
|
||||
- codex_responses: delegates to _run_codex_stream (already streaming)
|
||||
|
||||
Fires stream_delta_callback and _stream_callback for each text token.
|
||||
Tool-call turns suppress the callback — only text-only final responses
|
||||
stream to the consumer. Returns a SimpleNamespace that mimics the
|
||||
non-streaming response shape so the rest of the agent loop is unchanged.
|
||||
|
||||
Falls back to _interruptible_api_call on provider errors indicating
|
||||
streaming is not supported.
|
||||
This method is separate from ``_interruptible_api_call`` to keep the
|
||||
core agent loop untouched for non-voice users.
|
||||
"""
|
||||
if self.api_mode == "codex_responses":
|
||||
# Codex streams internally via _run_codex_stream. The main dispatch
|
||||
# in _interruptible_api_call already calls it; we just need to
|
||||
# ensure on_first_delta reaches it. Store it on the instance
|
||||
# temporarily so _run_codex_stream can pick it up.
|
||||
self._codex_on_first_delta = on_first_delta
|
||||
try:
|
||||
return self._interruptible_api_call(api_kwargs)
|
||||
finally:
|
||||
self._codex_on_first_delta = None
|
||||
|
||||
result = {"response": None, "error": None}
|
||||
request_client_holder = {"client": None}
|
||||
first_delta_fired = {"done": False}
|
||||
deltas_were_sent = {"yes": False} # Track if any deltas were fired (for fallback)
|
||||
|
||||
def _fire_first_delta():
|
||||
if not first_delta_fired["done"] and on_first_delta:
|
||||
first_delta_fired["done"] = True
|
||||
try:
|
||||
on_first_delta()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _call_chat_completions():
|
||||
"""Stream a chat completions response."""
|
||||
stream_kwargs = {**api_kwargs, "stream": True, "stream_options": {"include_usage": True}}
|
||||
request_client_holder["client"] = self._create_request_openai_client(
|
||||
reason="chat_completion_stream_request"
|
||||
)
|
||||
stream = request_client_holder["client"].chat.completions.create(**stream_kwargs)
|
||||
|
||||
content_parts: list = []
|
||||
tool_calls_acc: dict = {}
|
||||
finish_reason = None
|
||||
model_name = None
|
||||
role = "assistant"
|
||||
reasoning_parts: list = []
|
||||
usage_obj = None
|
||||
|
||||
for chunk in stream:
|
||||
if self._interrupt_requested:
|
||||
break
|
||||
|
||||
if not chunk.choices:
|
||||
if hasattr(chunk, "model") and chunk.model:
|
||||
model_name = chunk.model
|
||||
# Usage comes in the final chunk with empty choices
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage_obj = chunk.usage
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0].delta
|
||||
if hasattr(chunk, "model") and chunk.model:
|
||||
model_name = chunk.model
|
||||
|
||||
# Accumulate reasoning content
|
||||
reasoning_text = getattr(delta, "reasoning_content", None) or getattr(delta, "reasoning", None)
|
||||
if reasoning_text:
|
||||
reasoning_parts.append(reasoning_text)
|
||||
self._fire_reasoning_delta(reasoning_text)
|
||||
|
||||
# Accumulate text content — fire callback only when no tool calls
|
||||
if delta and delta.content:
|
||||
content_parts.append(delta.content)
|
||||
if not tool_calls_acc:
|
||||
_fire_first_delta()
|
||||
self._fire_stream_delta(delta.content)
|
||||
deltas_were_sent["yes"] = True
|
||||
|
||||
# Accumulate tool call deltas (silently, no callback)
|
||||
if delta and delta.tool_calls:
|
||||
for tc_delta in delta.tool_calls:
|
||||
idx = tc_delta.index if tc_delta.index is not None else 0
|
||||
if idx not in tool_calls_acc:
|
||||
tool_calls_acc[idx] = {
|
||||
"id": tc_delta.id or "",
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
entry = tool_calls_acc[idx]
|
||||
if tc_delta.id:
|
||||
entry["id"] = tc_delta.id
|
||||
if tc_delta.function:
|
||||
if tc_delta.function.name:
|
||||
entry["function"]["name"] += tc_delta.function.name
|
||||
if tc_delta.function.arguments:
|
||||
entry["function"]["arguments"] += tc_delta.function.arguments
|
||||
|
||||
if chunk.choices[0].finish_reason:
|
||||
finish_reason = chunk.choices[0].finish_reason
|
||||
|
||||
# Usage in the final chunk
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage_obj = chunk.usage
|
||||
|
||||
# Build mock response matching non-streaming shape
|
||||
full_content = "".join(content_parts) or None
|
||||
mock_tool_calls = None
|
||||
if tool_calls_acc:
|
||||
mock_tool_calls = []
|
||||
for idx in sorted(tool_calls_acc):
|
||||
tc = tool_calls_acc[idx]
|
||||
mock_tool_calls.append(SimpleNamespace(
|
||||
id=tc["id"],
|
||||
type=tc["type"],
|
||||
function=SimpleNamespace(
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
),
|
||||
))
|
||||
|
||||
full_reasoning = "".join(reasoning_parts) or None
|
||||
mock_message = SimpleNamespace(
|
||||
role=role,
|
||||
content=full_content,
|
||||
tool_calls=mock_tool_calls,
|
||||
reasoning_content=full_reasoning,
|
||||
)
|
||||
mock_choice = SimpleNamespace(
|
||||
index=0,
|
||||
message=mock_message,
|
||||
finish_reason=finish_reason or "stop",
|
||||
)
|
||||
return SimpleNamespace(
|
||||
id="stream-" + str(uuid.uuid4()),
|
||||
model=model_name,
|
||||
choices=[mock_choice],
|
||||
usage=usage_obj,
|
||||
)
|
||||
|
||||
def _call_anthropic():
|
||||
"""Stream an Anthropic Messages API response.
|
||||
|
||||
Fires delta callbacks for real-time token delivery, but returns
|
||||
the native Anthropic Message object from get_final_message() so
|
||||
the rest of the agent loop (validation, tool extraction, etc.)
|
||||
works unchanged.
|
||||
"""
|
||||
has_tool_use = False
|
||||
|
||||
# Use the Anthropic SDK's streaming context manager
|
||||
with self._anthropic_client.messages.stream(**api_kwargs) as stream:
|
||||
for event in stream:
|
||||
if self._interrupt_requested:
|
||||
break
|
||||
|
||||
event_type = getattr(event, "type", None)
|
||||
|
||||
if event_type == "content_block_start":
|
||||
block = getattr(event, "content_block", None)
|
||||
if block and getattr(block, "type", None) == "tool_use":
|
||||
has_tool_use = True
|
||||
|
||||
elif event_type == "content_block_delta":
|
||||
delta = getattr(event, "delta", None)
|
||||
if delta:
|
||||
delta_type = getattr(delta, "type", None)
|
||||
if delta_type == "text_delta":
|
||||
text = getattr(delta, "text", "")
|
||||
if text and not has_tool_use:
|
||||
_fire_first_delta()
|
||||
self._fire_stream_delta(text)
|
||||
elif delta_type == "thinking_delta":
|
||||
thinking_text = getattr(delta, "thinking", "")
|
||||
if thinking_text:
|
||||
self._fire_reasoning_delta(thinking_text)
|
||||
|
||||
# Return the native Anthropic Message for downstream processing
|
||||
return stream.get_final_message()
|
||||
|
||||
def _call():
|
||||
try:
|
||||
if self.api_mode == "anthropic_messages":
|
||||
self._try_refresh_anthropic_client_credentials()
|
||||
result["response"] = _call_anthropic()
|
||||
else:
|
||||
result["response"] = _call_chat_completions()
|
||||
stream_kwargs = {**api_kwargs, "stream": True}
|
||||
request_client_holder["client"] = self._create_request_openai_client(
|
||||
reason="chat_completion_stream_request"
|
||||
)
|
||||
stream = request_client_holder["client"].chat.completions.create(**stream_kwargs)
|
||||
|
||||
content_parts: list[str] = []
|
||||
tool_calls_acc: dict[int, dict] = {}
|
||||
finish_reason = None
|
||||
model_name = None
|
||||
role = "assistant"
|
||||
|
||||
for chunk in stream:
|
||||
if not chunk.choices:
|
||||
if hasattr(chunk, "model") and chunk.model:
|
||||
model_name = chunk.model
|
||||
continue
|
||||
|
||||
delta = chunk.choices[0].delta
|
||||
if hasattr(chunk, "model") and chunk.model:
|
||||
model_name = chunk.model
|
||||
|
||||
if delta and delta.content:
|
||||
content_parts.append(delta.content)
|
||||
try:
|
||||
stream_callback(delta.content)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if delta and delta.tool_calls:
|
||||
for tc_delta in delta.tool_calls:
|
||||
idx = tc_delta.index if tc_delta.index is not None else 0
|
||||
if idx in tool_calls_acc and tc_delta.id and tc_delta.id != tool_calls_acc[idx]["id"]:
|
||||
matched = False
|
||||
for eidx, eentry in tool_calls_acc.items():
|
||||
if eentry["id"] == tc_delta.id:
|
||||
idx = eidx
|
||||
matched = True
|
||||
break
|
||||
if not matched:
|
||||
idx = (max(k for k in tool_calls_acc if isinstance(k, int)) + 1) if tool_calls_acc else 0
|
||||
if idx not in tool_calls_acc:
|
||||
tool_calls_acc[idx] = {
|
||||
"id": tc_delta.id or "",
|
||||
"type": "function",
|
||||
"function": {"name": "", "arguments": ""},
|
||||
}
|
||||
entry = tool_calls_acc[idx]
|
||||
if tc_delta.id:
|
||||
entry["id"] = tc_delta.id
|
||||
if tc_delta.function:
|
||||
if tc_delta.function.name:
|
||||
entry["function"]["name"] += tc_delta.function.name
|
||||
if tc_delta.function.arguments:
|
||||
entry["function"]["arguments"] += tc_delta.function.arguments
|
||||
|
||||
if chunk.choices[0].finish_reason:
|
||||
finish_reason = chunk.choices[0].finish_reason
|
||||
|
||||
full_content = "".join(content_parts) or None
|
||||
mock_tool_calls = None
|
||||
if tool_calls_acc:
|
||||
mock_tool_calls = []
|
||||
for idx in sorted(tool_calls_acc):
|
||||
tc = tool_calls_acc[idx]
|
||||
mock_tool_calls.append(SimpleNamespace(
|
||||
id=tc["id"],
|
||||
type=tc["type"],
|
||||
function=SimpleNamespace(
|
||||
name=tc["function"]["name"],
|
||||
arguments=tc["function"]["arguments"],
|
||||
),
|
||||
))
|
||||
|
||||
mock_message = SimpleNamespace(
|
||||
role=role,
|
||||
content=full_content,
|
||||
tool_calls=mock_tool_calls,
|
||||
reasoning_content=None,
|
||||
)
|
||||
mock_choice = SimpleNamespace(
|
||||
index=0,
|
||||
message=mock_message,
|
||||
finish_reason=finish_reason or "stop",
|
||||
)
|
||||
mock_response = SimpleNamespace(
|
||||
id="stream-" + str(uuid.uuid4()),
|
||||
model=model_name,
|
||||
choices=[mock_choice],
|
||||
usage=None,
|
||||
)
|
||||
result["response"] = mock_response
|
||||
|
||||
except Exception as e:
|
||||
if deltas_were_sent["yes"]:
|
||||
# Streaming failed AFTER some tokens were already delivered
|
||||
# to consumers. Don't fall back — that would cause
|
||||
# double-delivery (partial streamed + full non-streamed).
|
||||
# Let the error propagate; the partial content already
|
||||
# reached the user via the stream.
|
||||
logger.warning("Streaming failed after partial delivery, not falling back: %s", e)
|
||||
result["error"] = e
|
||||
else:
|
||||
# Streaming failed before any tokens reached consumers.
|
||||
# Safe to fall back to the standard non-streaming path.
|
||||
logger.info("Streaming failed before delivery, falling back to non-streaming: %s", e)
|
||||
try:
|
||||
result["response"] = self._interruptible_api_call(api_kwargs)
|
||||
except Exception as fallback_err:
|
||||
result["error"] = fallback_err
|
||||
result["error"] = e
|
||||
finally:
|
||||
request_client = request_client_holder.get("client")
|
||||
if request_client is not None:
|
||||
@@ -3126,7 +2939,7 @@ class AIAgent:
|
||||
self._close_request_openai_client(request_client, reason="stream_interrupt_abort")
|
||||
except Exception:
|
||||
pass
|
||||
raise InterruptedError("Agent interrupted during streaming API call")
|
||||
raise InterruptedError("Agent interrupted during API call")
|
||||
if result["error"] is not None:
|
||||
raise result["error"]
|
||||
return result["response"]
|
||||
@@ -3374,7 +3187,6 @@ class AIAgent:
|
||||
tools=self.tools,
|
||||
max_tokens=self.max_tokens,
|
||||
reasoning_config=self.reasoning_config,
|
||||
is_oauth=getattr(self, "_is_anthropic_oauth", False),
|
||||
)
|
||||
|
||||
if self.api_mode == "codex_responses":
|
||||
@@ -3489,7 +3301,8 @@ class AIAgent:
|
||||
extra_body["provider"] = provider_preferences
|
||||
_is_nous = "nousresearch" in self.base_url.lower()
|
||||
|
||||
if self._supports_reasoning_extra_body():
|
||||
_is_mistral = "api.mistral.ai" in self.base_url.lower()
|
||||
if (_is_openrouter or _is_nous) and not _is_mistral:
|
||||
if self.reasoning_config is not None:
|
||||
rc = dict(self.reasoning_config)
|
||||
# Nous Portal requires reasoning enabled — don't send
|
||||
@@ -3513,32 +3326,6 @@ class AIAgent:
|
||||
|
||||
return api_kwargs
|
||||
|
||||
def _supports_reasoning_extra_body(self) -> bool:
|
||||
"""Return True when reasoning extra_body is safe to send for this route/model.
|
||||
|
||||
OpenRouter forwards unknown extra_body fields to upstream providers.
|
||||
Some providers/routes reject `reasoning` with 400s, so gate it to
|
||||
known reasoning-capable model families and direct Nous Portal.
|
||||
"""
|
||||
base_url = (self.base_url or "").lower()
|
||||
if "nousresearch" in base_url:
|
||||
return True
|
||||
if "openrouter" not in base_url:
|
||||
return False
|
||||
if "api.mistral.ai" in base_url:
|
||||
return False
|
||||
|
||||
model = (self.model or "").lower()
|
||||
reasoning_model_prefixes = (
|
||||
"deepseek/",
|
||||
"anthropic/",
|
||||
"openai/",
|
||||
"x-ai/",
|
||||
"google/gemini-2",
|
||||
"qwen/qwen3",
|
||||
)
|
||||
return any(model.startswith(prefix) for prefix in reasoning_model_prefixes)
|
||||
|
||||
def _build_assistant_message(self, assistant_message, finish_reason: str) -> dict:
|
||||
"""Build a normalized assistant message dict from an API response message.
|
||||
|
||||
@@ -3558,7 +3345,8 @@ class AIAgent:
|
||||
reasoning_text = combined or None
|
||||
|
||||
if reasoning_text and self.verbose_logging:
|
||||
logging.debug(f"Captured reasoning ({len(reasoning_text)} chars): {reasoning_text}")
|
||||
preview = reasoning_text[:100] + "..." if len(reasoning_text) > 100 else reasoning_text
|
||||
logging.debug(f"Captured reasoning ({len(reasoning_text)} chars): {preview}")
|
||||
|
||||
if reasoning_text and self.reasoning_callback:
|
||||
try:
|
||||
@@ -3702,8 +3490,7 @@ class AIAgent:
|
||||
|
||||
flush_content = (
|
||||
"[System: The session is being compressed. "
|
||||
"Save anything worth remembering — prioritize user preferences, "
|
||||
"corrections, and recurring patterns over task-specific details.]"
|
||||
"Please save anything worth remembering to your memories.]"
|
||||
)
|
||||
_sentinel = f"__flush_{id(self)}_{time.monotonic()}"
|
||||
flush_msg = {"role": "user", "content": flush_content, "_flush_sentinel": _sentinel}
|
||||
@@ -3792,7 +3579,7 @@ class AIAgent:
|
||||
tool_calls = assistant_msg.tool_calls
|
||||
elif self.api_mode == "anthropic_messages" and not _aux_available:
|
||||
from agent.anthropic_adapter import normalize_anthropic_response as _nar_flush
|
||||
_flush_msg, _ = _nar_flush(response, strip_tool_prefix=getattr(self, '_is_anthropic_oauth', False))
|
||||
_flush_msg, _ = _nar_flush(response)
|
||||
if _flush_msg and _flush_msg.tool_calls:
|
||||
tool_calls = _flush_msg.tool_calls
|
||||
elif hasattr(response, "choices") and response.choices:
|
||||
@@ -3978,8 +3765,6 @@ class AIAgent:
|
||||
return handle_function_call(
|
||||
function_name, function_args, effective_task_id,
|
||||
enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None,
|
||||
honcho_manager=self._honcho,
|
||||
honcho_session_key=self._honcho_session_key,
|
||||
)
|
||||
|
||||
def _execute_tool_calls_concurrent(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
|
||||
@@ -4030,18 +3815,6 @@ class AIAgent:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Checkpoint before destructive terminal commands
|
||||
if function_name == "terminal" and self._checkpoint_mgr.enabled:
|
||||
try:
|
||||
cmd = function_args.get("command", "")
|
||||
if _is_destructive_command(cmd):
|
||||
cwd = function_args.get("workdir") or os.getenv("TERMINAL_CWD", os.getcwd())
|
||||
self._checkpoint_mgr.ensure_checkpoint(
|
||||
cwd, f"before terminal: {cmd[:60]}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
parsed_calls.append((tool_call, function_name, function_args))
|
||||
|
||||
# ── Logging / callbacks ──────────────────────────────────────────
|
||||
@@ -4050,12 +3823,8 @@ class AIAgent:
|
||||
print(f" ⚡ Concurrent: {num_tools} tool calls — {tool_names_str}")
|
||||
for i, (tc, name, args) in enumerate(parsed_calls, 1):
|
||||
args_str = json.dumps(args, ensure_ascii=False)
|
||||
if self.verbose_logging:
|
||||
print(f" 📞 Tool {i}: {name}({list(args.keys())})")
|
||||
print(f" Args: {args_str}")
|
||||
else:
|
||||
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
|
||||
print(f" 📞 Tool {i}: {name}({list(args.keys())}) - {args_preview}")
|
||||
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
|
||||
print(f" 📞 Tool {i}: {name}({list(args.keys())}) - {args_preview}")
|
||||
|
||||
for _, name, args in parsed_calls:
|
||||
if self.tool_progress_callback:
|
||||
@@ -4120,20 +3889,17 @@ class AIAgent:
|
||||
logger.warning("Tool %s returned error (%.2fs): %s", function_name, tool_duration, result_preview)
|
||||
|
||||
if self.verbose_logging:
|
||||
result_preview = function_result[:200] if len(function_result) > 200 else function_result
|
||||
logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
|
||||
logging.debug(f"Tool result ({len(function_result)} chars): {function_result}")
|
||||
logging.debug(f"Tool result preview: {result_preview}...")
|
||||
|
||||
# Print cute message per tool
|
||||
if self.quiet_mode:
|
||||
cute_msg = _get_cute_tool_message_impl(name, args, tool_duration, result=function_result)
|
||||
print(f" {cute_msg}")
|
||||
elif not self.quiet_mode:
|
||||
if self.verbose_logging:
|
||||
print(f" ✅ Tool {i+1} completed in {tool_duration:.2f}s")
|
||||
print(f" Result: {function_result}")
|
||||
else:
|
||||
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
|
||||
print(f" ✅ Tool {i+1} completed in {tool_duration:.2f}s - {response_preview}")
|
||||
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
|
||||
print(f" ✅ Tool {i+1} completed in {tool_duration:.2f}s - {response_preview}")
|
||||
|
||||
# Truncate oversized results
|
||||
MAX_TOOL_RESULT_CHARS = 100_000
|
||||
@@ -4209,12 +3975,8 @@ class AIAgent:
|
||||
|
||||
if not self.quiet_mode:
|
||||
args_str = json.dumps(function_args, ensure_ascii=False)
|
||||
if self.verbose_logging:
|
||||
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())})")
|
||||
print(f" Args: {args_str}")
|
||||
else:
|
||||
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
|
||||
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}")
|
||||
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
|
||||
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}")
|
||||
|
||||
if self.tool_progress_callback:
|
||||
try:
|
||||
@@ -4235,18 +3997,6 @@ class AIAgent:
|
||||
except Exception:
|
||||
pass # never block tool execution
|
||||
|
||||
# Checkpoint before destructive terminal commands
|
||||
if function_name == "terminal" and self._checkpoint_mgr.enabled:
|
||||
try:
|
||||
cmd = function_args.get("command", "")
|
||||
if _is_destructive_command(cmd):
|
||||
cwd = function_args.get("workdir") or os.getenv("TERMINAL_CWD", os.getcwd())
|
||||
self._checkpoint_mgr.ensure_checkpoint(
|
||||
cwd, f"before terminal: {cmd[:60]}"
|
||||
)
|
||||
except Exception:
|
||||
pass # never block tool execution
|
||||
|
||||
tool_start_time = time.time()
|
||||
|
||||
if function_name == "todo":
|
||||
@@ -4333,9 +4083,25 @@ class AIAgent:
|
||||
spinner.stop(cute_msg)
|
||||
elif self.quiet_mode:
|
||||
self._vprint(f" {cute_msg}")
|
||||
elif self.quiet_mode and not self._has_stream_consumers():
|
||||
elif self.quiet_mode and self._stream_callback is None:
|
||||
face = random.choice(KawaiiSpinner.KAWAII_WAITING)
|
||||
emoji = _get_tool_emoji(function_name)
|
||||
tool_emoji_map = {
|
||||
'web_search': '🔍', 'web_extract': '📄', 'web_crawl': '🕸️',
|
||||
'terminal': '💻', 'process': '⚙️',
|
||||
'read_file': '📖', 'write_file': '✍️', 'patch': '🔧', 'search_files': '🔎',
|
||||
'browser_navigate': '🌐', 'browser_snapshot': '📸',
|
||||
'browser_click': '👆', 'browser_type': '⌨️',
|
||||
'browser_scroll': '📜', 'browser_back': '◀️',
|
||||
'browser_press': '⌨️', 'browser_close': '🚪',
|
||||
'browser_get_images': '🖼️', 'browser_vision': '👁️',
|
||||
'image_generate': '🎨', 'text_to_speech': '🔊',
|
||||
'vision_analyze': '👁️', 'mixture_of_agents': '🧠',
|
||||
'skills_list': '📚', 'skill_view': '📚',
|
||||
'cronjob': '⏰',
|
||||
'send_message': '📨', 'todo': '📋', 'memory': '🧠', 'session_search': '🔍',
|
||||
'clarify': '❓', 'execute_code': '🐍', 'delegate_task': '🔀',
|
||||
}
|
||||
emoji = tool_emoji_map.get(function_name, '⚡')
|
||||
preview = _build_tool_preview(function_name, function_args) or function_name
|
||||
if len(preview) > 30:
|
||||
preview = preview[:27] + "..."
|
||||
@@ -4346,8 +4112,6 @@ class AIAgent:
|
||||
function_result = handle_function_call(
|
||||
function_name, function_args, effective_task_id,
|
||||
enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None,
|
||||
honcho_manager=self._honcho,
|
||||
honcho_session_key=self._honcho_session_key,
|
||||
)
|
||||
_spinner_result = function_result
|
||||
except Exception as tool_error:
|
||||
@@ -4362,17 +4126,13 @@ class AIAgent:
|
||||
function_result = handle_function_call(
|
||||
function_name, function_args, effective_task_id,
|
||||
enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None,
|
||||
honcho_manager=self._honcho,
|
||||
honcho_session_key=self._honcho_session_key,
|
||||
)
|
||||
except Exception as tool_error:
|
||||
function_result = f"Error executing tool '{function_name}': {tool_error}"
|
||||
logger.error("handle_function_call raised for %s: %s", function_name, tool_error, exc_info=True)
|
||||
tool_duration = time.time() - tool_start_time
|
||||
|
||||
result_preview = function_result if self.verbose_logging else (
|
||||
function_result[:200] if len(function_result) > 200 else function_result
|
||||
)
|
||||
result_preview = function_result[:200] if len(function_result) > 200 else function_result
|
||||
|
||||
# Log tool errors to the persistent error log so [error] tags
|
||||
# in the UI always have a corresponding detailed entry on disk.
|
||||
@@ -4382,7 +4142,7 @@ class AIAgent:
|
||||
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
|
||||
logging.debug(f"Tool result ({len(function_result)} chars): {function_result}")
|
||||
logging.debug(f"Tool result preview: {result_preview}...")
|
||||
|
||||
# Guard against tools returning absurdly large content that would
|
||||
# blow up the context window. 100K chars ≈ 25K tokens — generous
|
||||
@@ -4405,12 +4165,8 @@ class AIAgent:
|
||||
messages.append(tool_msg)
|
||||
|
||||
if not self.quiet_mode:
|
||||
if self.verbose_logging:
|
||||
print(f" ✅ Tool {i} completed in {tool_duration:.2f}s")
|
||||
print(f" Result: {function_result}")
|
||||
else:
|
||||
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
|
||||
print(f" ✅ Tool {i} completed in {tool_duration:.2f}s - {response_preview}")
|
||||
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
|
||||
print(f" ✅ Tool {i} completed in {tool_duration:.2f}s - {response_preview}")
|
||||
|
||||
if self._interrupt_requested and i < len(assistant_message.tool_calls):
|
||||
remaining = len(assistant_message.tool_calls) - i
|
||||
@@ -4508,8 +4264,9 @@ class AIAgent:
|
||||
api_messages.insert(sys_offset + idx, pfm.copy())
|
||||
|
||||
summary_extra_body = {}
|
||||
_is_openrouter = "openrouter" in self.base_url.lower()
|
||||
_is_nous = "nousresearch" in self.base_url.lower()
|
||||
if self._supports_reasoning_extra_body():
|
||||
if _is_openrouter or _is_nous:
|
||||
if self.reasoning_config is not None:
|
||||
summary_extra_body["reasoning"] = self.reasoning_config
|
||||
else:
|
||||
@@ -4553,10 +4310,9 @@ class AIAgent:
|
||||
if self.api_mode == "anthropic_messages":
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs as _bak, normalize_anthropic_response as _nar
|
||||
_ant_kw = _bak(model=self.model, messages=api_messages, tools=None,
|
||||
max_tokens=self.max_tokens, reasoning_config=self.reasoning_config,
|
||||
is_oauth=getattr(self, '_is_anthropic_oauth', False))
|
||||
max_tokens=self.max_tokens, reasoning_config=self.reasoning_config)
|
||||
summary_response = self._anthropic_messages_create(_ant_kw)
|
||||
_msg, _ = _nar(summary_response, strip_tool_prefix=getattr(self, '_is_anthropic_oauth', False))
|
||||
_msg, _ = _nar(summary_response)
|
||||
final_response = (_msg.content or "").strip()
|
||||
else:
|
||||
summary_response = self._ensure_primary_openai_client(reason="iteration_limit_summary").chat.completions.create(**summary_kwargs)
|
||||
@@ -4584,10 +4340,9 @@ class AIAgent:
|
||||
elif self.api_mode == "anthropic_messages":
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs as _bak2, normalize_anthropic_response as _nar2
|
||||
_ant_kw2 = _bak2(model=self.model, messages=api_messages, tools=None,
|
||||
is_oauth=getattr(self, '_is_anthropic_oauth', False),
|
||||
max_tokens=self.max_tokens, reasoning_config=self.reasoning_config)
|
||||
retry_response = self._anthropic_messages_create(_ant_kw2)
|
||||
_retry_msg, _ = _nar2(retry_response, strip_tool_prefix=getattr(self, '_is_anthropic_oauth', False))
|
||||
_retry_msg, _ = _nar2(retry_response)
|
||||
final_response = (_retry_msg.content or "").strip()
|
||||
else:
|
||||
summary_kwargs = {
|
||||
@@ -4630,7 +4385,6 @@ class AIAgent:
|
||||
task_id: str = None,
|
||||
stream_callback: Optional[callable] = None,
|
||||
persist_user_message: Optional[str] = None,
|
||||
sync_honcho: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run a complete conversation with tool calling until completion.
|
||||
@@ -4646,8 +4400,6 @@ class AIAgent:
|
||||
persist_user_message: Optional clean user message to store in
|
||||
transcripts/history when user_message contains API-only
|
||||
synthetic prefixes.
|
||||
sync_honcho: When False, skip writing the final synthetic turn back
|
||||
to Honcho or queuing follow-up prefetch work.
|
||||
|
||||
Returns:
|
||||
Dict: Complete conversation result with final response and message history
|
||||
@@ -4704,9 +4456,8 @@ class AIAgent:
|
||||
self._turns_since_memory += 1
|
||||
if self._turns_since_memory >= self._memory_nudge_interval:
|
||||
user_message += (
|
||||
"\n\n[System: You've had several exchanges. Consider: "
|
||||
"has the user shared preferences, corrected you, or revealed "
|
||||
"something about their workflow worth remembering for future sessions?]"
|
||||
"\n\n[System: You've had several exchanges in this session. "
|
||||
"Consider whether there's anything worth saving to your memories.]"
|
||||
)
|
||||
self._turns_since_memory = 0
|
||||
|
||||
@@ -4716,9 +4467,8 @@ class AIAgent:
|
||||
and self._iters_since_skill >= self._skill_nudge_interval
|
||||
and "skill_manage" in self.valid_tool_names):
|
||||
user_message += (
|
||||
"\n\n[System: The previous task involved many tool calls. "
|
||||
"Save the approach as a skill if it's reusable, or update "
|
||||
"any existing skill you used if it was wrong or incomplete.]"
|
||||
"\n\n[System: The previous task involved many steps. "
|
||||
"If you discovered a reusable workflow, consider saving it as a skill.]"
|
||||
)
|
||||
self._iters_since_skill = 0
|
||||
|
||||
@@ -4972,8 +4722,8 @@ class AIAgent:
|
||||
self._vprint(f"\n{self.log_prefix}🔄 Making API call #{api_call_count}/{self.max_iterations}...")
|
||||
self._vprint(f"{self.log_prefix} 📊 Request size: {len(api_messages)} messages, ~{approx_tokens:,} tokens (~{total_chars:,} chars)")
|
||||
self._vprint(f"{self.log_prefix} 🔧 Available tools: {len(self.tools) if self.tools else 0}")
|
||||
elif not self._has_stream_consumers():
|
||||
# Animated thinking spinner in quiet mode (skip during streaming)
|
||||
elif self._stream_callback is None:
|
||||
# Animated thinking spinner in quiet mode (skip during streaming TTS)
|
||||
face = random.choice(KawaiiSpinner.KAWAII_THINKING)
|
||||
verb = random.choice(KawaiiSpinner.THINKING_VERBS)
|
||||
if self.thinking_callback:
|
||||
@@ -5013,22 +4763,33 @@ class AIAgent:
|
||||
if os.getenv("HERMES_DUMP_REQUESTS", "").strip().lower() in {"1", "true", "yes", "on"}:
|
||||
self._dump_api_request_debug(api_kwargs, reason="preflight")
|
||||
|
||||
if self._has_stream_consumers():
|
||||
# Streaming path: fire delta callbacks for real-time
|
||||
# token delivery to CLI display, gateway, or TTS.
|
||||
def _stop_spinner():
|
||||
nonlocal thinking_spinner
|
||||
if thinking_spinner:
|
||||
thinking_spinner.stop("")
|
||||
thinking_spinner = None
|
||||
if self.thinking_callback:
|
||||
self.thinking_callback("")
|
||||
|
||||
response = self._interruptible_streaming_api_call(
|
||||
api_kwargs, on_first_delta=_stop_spinner
|
||||
)
|
||||
cb = getattr(self, "_stream_callback", None)
|
||||
if cb is not None and self.api_mode == "chat_completions":
|
||||
response = self._streaming_api_call(api_kwargs, cb)
|
||||
else:
|
||||
response = self._interruptible_api_call(api_kwargs)
|
||||
# Forward full response to TTS callback for non-streaming providers
|
||||
# (e.g. Anthropic) so voice TTS still works via batch delivery.
|
||||
if cb is not None and response:
|
||||
try:
|
||||
content = None
|
||||
# Try choices first — _interruptible_api_call converts all
|
||||
# providers (including Anthropic) to this format.
|
||||
try:
|
||||
content = response.choices[0].message.content
|
||||
except (AttributeError, IndexError):
|
||||
pass
|
||||
# Fallback: Anthropic native content blocks
|
||||
if not content and self.api_mode == "anthropic_messages":
|
||||
text_parts = [
|
||||
block.text for block in getattr(response, "content", [])
|
||||
if getattr(block, "type", None) == "text" and getattr(block, "text", None)
|
||||
]
|
||||
content = " ".join(text_parts) if text_parts else None
|
||||
if content:
|
||||
cb(content)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
api_duration = time.time() - api_start_time
|
||||
|
||||
@@ -5283,22 +5044,6 @@ class AIAgent:
|
||||
self.session_completion_tokens += completion_tokens
|
||||
self.session_total_tokens += total_tokens
|
||||
self.session_api_calls += 1
|
||||
|
||||
# Persist token counts to session DB for /insights.
|
||||
# Gateway sessions persist via session_store.update_session()
|
||||
# after run_conversation returns, so only persist here for
|
||||
# CLI (and other non-gateway) platforms to avoid double-counting.
|
||||
if (self._session_db and self.session_id
|
||||
and getattr(self, 'platform', None) == 'cli'):
|
||||
try:
|
||||
self._session_db.update_token_counts(
|
||||
self.session_id,
|
||||
input_tokens=prompt_tokens,
|
||||
output_tokens=completion_tokens,
|
||||
model=self.model,
|
||||
)
|
||||
except Exception:
|
||||
pass # never block the agent loop
|
||||
|
||||
if self.verbose_logging:
|
||||
logging.debug(f"Token usage: prompt={usage_dict['prompt_tokens']:,}, completion={usage_dict['completion_tokens']:,}, total={usage_dict['total_tokens']:,}")
|
||||
@@ -5547,13 +5292,10 @@ class AIAgent:
|
||||
# These indicate a problem with the request itself (bad model ID,
|
||||
# invalid API key, forbidden, etc.) and will never succeed on retry.
|
||||
# Note: 413 and context-length errors are excluded — handled above.
|
||||
# 429 (rate limit) is transient and MUST be retried with backoff.
|
||||
# 529 (Anthropic overloaded) is also transient.
|
||||
# Also catch local validation errors (ValueError, TypeError) — these
|
||||
# are programming bugs, not transient failures.
|
||||
_RETRYABLE_STATUS_CODES = {413, 429, 529}
|
||||
is_local_validation_error = isinstance(api_error, (ValueError, TypeError))
|
||||
is_client_status_error = isinstance(status_code, int) and 400 <= status_code < 500 and status_code not in _RETRYABLE_STATUS_CODES
|
||||
is_client_status_error = isinstance(status_code, int) and 400 <= status_code < 500 and status_code != 413
|
||||
is_client_error = (is_local_validation_error or is_client_status_error or any(phrase in error_msg for phrase in [
|
||||
'error code: 401', 'error code: 403',
|
||||
'error code: 404', 'error code: 422',
|
||||
@@ -5649,9 +5391,7 @@ class AIAgent:
|
||||
assistant_message, finish_reason = self._normalize_codex_response(response)
|
||||
elif self.api_mode == "anthropic_messages":
|
||||
from agent.anthropic_adapter import normalize_anthropic_response
|
||||
assistant_message, finish_reason = normalize_anthropic_response(
|
||||
response, strip_tool_prefix=getattr(self, "_is_anthropic_oauth", False)
|
||||
)
|
||||
assistant_message, finish_reason = normalize_anthropic_response(response)
|
||||
else:
|
||||
assistant_message = response.choices[0].message
|
||||
|
||||
@@ -5678,10 +5418,7 @@ class AIAgent:
|
||||
|
||||
# Handle assistant response
|
||||
if assistant_message.content and not self.quiet_mode:
|
||||
if self.verbose_logging:
|
||||
self._vprint(f"{self.log_prefix}🤖 Assistant: {assistant_message.content}")
|
||||
else:
|
||||
self._vprint(f"{self.log_prefix}🤖 Assistant: {assistant_message.content[:100]}{'...' if len(assistant_message.content) > 100 else ''}")
|
||||
self._vprint(f"{self.log_prefix}🤖 Assistant: {assistant_message.content[:100]}{'...' if len(assistant_message.content) > 100 else ''}")
|
||||
|
||||
# Notify progress callback of model's thinking (used by subagent
|
||||
# delegation to relay the child's reasoning to the parent display).
|
||||
@@ -6152,7 +5889,7 @@ class AIAgent:
|
||||
self._persist_session(messages, conversation_history)
|
||||
|
||||
# Sync conversation to Honcho for user modeling
|
||||
if final_response and not interrupted and sync_honcho:
|
||||
if final_response and not interrupted:
|
||||
self._honcho_sync(original_user_message, final_response)
|
||||
self._queue_honcho_prefetch(original_user_message)
|
||||
|
||||
|
||||
@@ -114,7 +114,6 @@ curl -s "https://export.arxiv.org/api/query?id_list=2402.03300,2401.12345,2403.0
|
||||
|
||||
After fetching metadata for a paper, generate a BibTeX entry:
|
||||
|
||||
{% raw %}
|
||||
```bash
|
||||
curl -s "https://export.arxiv.org/api/query?id_list=1706.03762" | python3 -c "
|
||||
import sys, xml.etree.ElementTree as ET
|
||||
@@ -140,7 +139,6 @@ print(f' url = {{https://arxiv.org/abs/{raw_id}}}')
|
||||
print('}')
|
||||
"
|
||||
```
|
||||
{% endraw %}
|
||||
|
||||
## Reading Paper Content
|
||||
|
||||
|
||||
@@ -215,7 +215,6 @@ def generate_citation_key(bibtex: str) -> str:
|
||||
|
||||
### Complete Citation Manager Class
|
||||
|
||||
{% raw %}
|
||||
```python
|
||||
"""
|
||||
Citation Manager - Verified citation workflow for ML papers.
|
||||
@@ -378,7 +377,6 @@ if __name__ == "__main__":
|
||||
if bibtex:
|
||||
print(bibtex)
|
||||
```
|
||||
{% endraw %}
|
||||
|
||||
### Quick Functions
|
||||
|
||||
|
||||
@@ -295,97 +295,3 @@ class TestOnConnect:
|
||||
mock_conn = MagicMock(spec=acp.Client)
|
||||
agent.on_connect(mock_conn)
|
||||
assert agent._conn is mock_conn
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slash commands
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSlashCommands:
|
||||
"""Test slash command dispatch in the ACP adapter."""
|
||||
|
||||
def _make_state(self, mock_manager):
|
||||
state = mock_manager.create_session(cwd="/tmp")
|
||||
state.agent.model = "test-model"
|
||||
state.agent.provider = "openrouter"
|
||||
state.model = "test-model"
|
||||
return state
|
||||
|
||||
def test_help_lists_commands(self, agent, mock_manager):
|
||||
state = self._make_state(mock_manager)
|
||||
result = agent._handle_slash_command("/help", state)
|
||||
assert result is not None
|
||||
assert "/help" in result
|
||||
assert "/model" in result
|
||||
assert "/tools" in result
|
||||
assert "/reset" in result
|
||||
|
||||
def test_model_shows_current(self, agent, mock_manager):
|
||||
state = self._make_state(mock_manager)
|
||||
result = agent._handle_slash_command("/model", state)
|
||||
assert "test-model" in result
|
||||
|
||||
def test_context_empty(self, agent, mock_manager):
|
||||
state = self._make_state(mock_manager)
|
||||
state.history = []
|
||||
result = agent._handle_slash_command("/context", state)
|
||||
assert "empty" in result.lower()
|
||||
|
||||
def test_context_with_messages(self, agent, mock_manager):
|
||||
state = self._make_state(mock_manager)
|
||||
state.history = [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "hi"},
|
||||
]
|
||||
result = agent._handle_slash_command("/context", state)
|
||||
assert "2 messages" in result
|
||||
assert "user: 1" in result
|
||||
|
||||
def test_reset_clears_history(self, agent, mock_manager):
|
||||
state = self._make_state(mock_manager)
|
||||
state.history = [{"role": "user", "content": "hello"}]
|
||||
result = agent._handle_slash_command("/reset", state)
|
||||
assert "cleared" in result.lower()
|
||||
assert len(state.history) == 0
|
||||
|
||||
def test_version(self, agent, mock_manager):
|
||||
state = self._make_state(mock_manager)
|
||||
result = agent._handle_slash_command("/version", state)
|
||||
assert HERMES_VERSION in result
|
||||
|
||||
def test_unknown_command_returns_none(self, agent, mock_manager):
|
||||
state = self._make_state(mock_manager)
|
||||
result = agent._handle_slash_command("/nonexistent", state)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slash_command_intercepted_in_prompt(self, agent, mock_manager):
|
||||
"""Slash commands should be handled without calling the LLM."""
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
mock_conn = AsyncMock(spec=acp.Client)
|
||||
agent._conn = mock_conn
|
||||
|
||||
prompt = [TextContentBlock(type="text", text="/help")]
|
||||
resp = await agent.prompt(prompt=prompt, session_id=new_resp.session_id)
|
||||
|
||||
assert resp.stop_reason == "end_turn"
|
||||
mock_conn.session_update.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_slash_falls_through_to_llm(self, agent, mock_manager):
|
||||
"""Unknown /commands should be sent to the LLM, not intercepted."""
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
mock_conn = AsyncMock(spec=acp.Client)
|
||||
agent._conn = mock_conn
|
||||
|
||||
# Mock run_in_executor to avoid actually running the agent
|
||||
with patch("asyncio.get_running_loop") as mock_loop:
|
||||
mock_loop.return_value.run_in_executor = AsyncMock(return_value={
|
||||
"final_response": "I processed /foo",
|
||||
"messages": [],
|
||||
})
|
||||
prompt = [TextContentBlock(type="text", text="/foo bar")]
|
||||
resp = await agent.prompt(prompt=prompt, session_id=new_resp.session_id)
|
||||
|
||||
assert resp.stop_reason == "end_turn"
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
"""Tests for get_tool_emoji in agent/display.py — skin + registry integration."""
|
||||
|
||||
from unittest.mock import patch as mock_patch, MagicMock
|
||||
|
||||
from agent.display import get_tool_emoji
|
||||
|
||||
|
||||
class TestGetToolEmoji:
|
||||
"""Verify the skin → registry → fallback resolution chain."""
|
||||
|
||||
def test_returns_registry_emoji_when_no_skin(self):
|
||||
"""Registry-registered emoji is used when no skin is active."""
|
||||
mock_registry = MagicMock()
|
||||
mock_registry.get_emoji.return_value = "🎨"
|
||||
with mock_patch("agent.display._get_skin", return_value=None), \
|
||||
mock_patch("agent.display.registry", mock_registry, create=True):
|
||||
# Need to patch the import inside get_tool_emoji
|
||||
pass
|
||||
# Direct test: patch the lazy import path
|
||||
with mock_patch("agent.display._get_skin", return_value=None):
|
||||
# get_tool_emoji will try to import registry — mock that
|
||||
mock_reg = MagicMock()
|
||||
mock_reg.get_emoji.return_value = "📖"
|
||||
with mock_patch.dict("sys.modules", {}):
|
||||
import sys
|
||||
# Patch tools.registry module
|
||||
mock_module = MagicMock()
|
||||
mock_module.registry = mock_reg
|
||||
with mock_patch.dict(sys.modules, {"tools.registry": mock_module}):
|
||||
result = get_tool_emoji("read_file")
|
||||
assert result == "📖"
|
||||
|
||||
def test_skin_override_takes_precedence(self):
|
||||
"""Skin tool_emojis override registry defaults."""
|
||||
skin = MagicMock()
|
||||
skin.tool_emojis = {"terminal": "⚔"}
|
||||
with mock_patch("agent.display._get_skin", return_value=skin):
|
||||
result = get_tool_emoji("terminal")
|
||||
assert result == "⚔"
|
||||
|
||||
def test_skin_empty_dict_falls_through(self):
|
||||
"""Empty skin tool_emojis falls through to registry."""
|
||||
skin = MagicMock()
|
||||
skin.tool_emojis = {}
|
||||
mock_reg = MagicMock()
|
||||
mock_reg.get_emoji.return_value = "💻"
|
||||
import sys
|
||||
mock_module = MagicMock()
|
||||
mock_module.registry = mock_reg
|
||||
with mock_patch("agent.display._get_skin", return_value=skin), \
|
||||
mock_patch.dict(sys.modules, {"tools.registry": mock_module}):
|
||||
result = get_tool_emoji("terminal")
|
||||
assert result == "💻"
|
||||
|
||||
def test_fallback_default(self):
|
||||
"""When neither skin nor registry has an emoji, use the default."""
|
||||
skin = MagicMock()
|
||||
skin.tool_emojis = {}
|
||||
mock_reg = MagicMock()
|
||||
mock_reg.get_emoji.return_value = ""
|
||||
import sys
|
||||
mock_module = MagicMock()
|
||||
mock_module.registry = mock_reg
|
||||
with mock_patch("agent.display._get_skin", return_value=skin), \
|
||||
mock_patch.dict(sys.modules, {"tools.registry": mock_module}):
|
||||
result = get_tool_emoji("unknown_tool")
|
||||
assert result == "⚡"
|
||||
|
||||
def test_custom_default(self):
|
||||
"""Custom default is returned when nothing matches."""
|
||||
with mock_patch("agent.display._get_skin", return_value=None):
|
||||
mock_reg = MagicMock()
|
||||
mock_reg.get_emoji.return_value = ""
|
||||
import sys
|
||||
mock_module = MagicMock()
|
||||
mock_module.registry = mock_reg
|
||||
with mock_patch.dict(sys.modules, {"tools.registry": mock_module}):
|
||||
result = get_tool_emoji("x", default="⚙️")
|
||||
assert result == "⚙️"
|
||||
|
||||
def test_skin_override_only_for_matching_tool(self):
|
||||
"""Skin override for one tool doesn't affect others."""
|
||||
skin = MagicMock()
|
||||
skin.tool_emojis = {"terminal": "⚔"}
|
||||
mock_reg = MagicMock()
|
||||
mock_reg.get_emoji.return_value = "🔍"
|
||||
import sys
|
||||
mock_module = MagicMock()
|
||||
mock_module.registry = mock_reg
|
||||
with mock_patch("agent.display._get_skin", return_value=skin), \
|
||||
mock_patch.dict(sys.modules, {"tools.registry": mock_module}):
|
||||
assert get_tool_emoji("terminal") == "⚔" # skin override
|
||||
assert get_tool_emoji("web_search") == "🔍" # registry fallback
|
||||
|
||||
|
||||
class TestSkinConfigToolEmojis:
|
||||
"""Verify SkinConfig handles tool_emojis field correctly."""
|
||||
|
||||
def test_skin_config_has_tool_emojis_field(self):
|
||||
from hermes_cli.skin_engine import SkinConfig
|
||||
skin = SkinConfig(name="test")
|
||||
assert skin.tool_emojis == {}
|
||||
|
||||
def test_skin_config_accepts_tool_emojis(self):
|
||||
from hermes_cli.skin_engine import SkinConfig
|
||||
emojis = {"terminal": "⚔", "web_search": "🔮"}
|
||||
skin = SkinConfig(name="test", tool_emojis=emojis)
|
||||
assert skin.tool_emojis == emojis
|
||||
|
||||
def test_build_skin_config_includes_tool_emojis(self):
|
||||
from hermes_cli.skin_engine import _build_skin_config
|
||||
data = {
|
||||
"name": "custom",
|
||||
"tool_emojis": {"terminal": "🗡️", "patch": "⚒️"},
|
||||
}
|
||||
skin = _build_skin_config(data)
|
||||
assert skin.tool_emojis == {"terminal": "🗡️", "patch": "⚒️"}
|
||||
|
||||
def test_build_skin_config_empty_tool_emojis_default(self):
|
||||
from hermes_cli.skin_engine import _build_skin_config
|
||||
data = {"name": "minimal"}
|
||||
skin = _build_skin_config(data)
|
||||
assert skin.tool_emojis == {}
|
||||
@@ -1,61 +0,0 @@
|
||||
from agent.smart_model_routing import choose_cheap_model_route
|
||||
|
||||
|
||||
_BASE_CONFIG = {
|
||||
"enabled": True,
|
||||
"cheap_model": {
|
||||
"provider": "openrouter",
|
||||
"model": "google/gemini-2.5-flash",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_returns_none_when_disabled():
|
||||
cfg = {**_BASE_CONFIG, "enabled": False}
|
||||
assert choose_cheap_model_route("what time is it in tokyo?", cfg) is None
|
||||
|
||||
|
||||
def test_routes_short_simple_prompt():
|
||||
result = choose_cheap_model_route("what time is it in tokyo?", _BASE_CONFIG)
|
||||
assert result is not None
|
||||
assert result["provider"] == "openrouter"
|
||||
assert result["model"] == "google/gemini-2.5-flash"
|
||||
assert result["routing_reason"] == "simple_turn"
|
||||
|
||||
|
||||
def test_skips_long_prompt():
|
||||
prompt = "please summarize this carefully " * 20
|
||||
assert choose_cheap_model_route(prompt, _BASE_CONFIG) is None
|
||||
|
||||
|
||||
def test_skips_code_like_prompt():
|
||||
prompt = "debug this traceback: ```python\nraise ValueError('bad')\n```"
|
||||
assert choose_cheap_model_route(prompt, _BASE_CONFIG) is None
|
||||
|
||||
|
||||
def test_skips_tool_heavy_prompt_keywords():
|
||||
prompt = "implement a patch for this docker error"
|
||||
assert choose_cheap_model_route(prompt, _BASE_CONFIG) is None
|
||||
|
||||
|
||||
def test_resolve_turn_route_falls_back_to_primary_when_route_runtime_cannot_be_resolved(monkeypatch):
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.runtime_provider.resolve_runtime_provider",
|
||||
lambda **kwargs: (_ for _ in ()).throw(RuntimeError("bad route")),
|
||||
)
|
||||
result = resolve_turn_route(
|
||||
"what time is it in tokyo?",
|
||||
_BASE_CONFIG,
|
||||
{
|
||||
"model": "anthropic/claude-sonnet-4",
|
||||
"provider": "openrouter",
|
||||
"base_url": "https://openrouter.ai/api/v1",
|
||||
"api_mode": "chat_completions",
|
||||
"api_key": "sk-primary",
|
||||
},
|
||||
)
|
||||
assert result["model"] == "anthropic/claude-sonnet-4"
|
||||
assert result["runtime"]["provider"] == "openrouter"
|
||||
assert result["label"] is None
|
||||
@@ -26,12 +26,6 @@ def _isolate_hermes_home(tmp_path, monkeypatch):
|
||||
(fake_home / "memories").mkdir()
|
||||
(fake_home / "skills").mkdir()
|
||||
monkeypatch.setenv("HERMES_HOME", str(fake_home))
|
||||
# Reset plugin singleton so tests don't leak plugins from ~/.hermes/plugins/
|
||||
try:
|
||||
import hermes_cli.plugins as _plugins_mod
|
||||
monkeypatch.setattr(_plugins_mod, "_plugin_manager", None)
|
||||
except Exception:
|
||||
pass
|
||||
# Tests should not inherit the agent's current gateway/messaging surface.
|
||||
# Individual tests that need gateway behavior set these explicitly.
|
||||
monkeypatch.delenv("HERMES_SESSION_PLATFORM", raising=False)
|
||||
|
||||
@@ -83,14 +83,6 @@ class TestSessionResetPolicy:
|
||||
assert policy.at_hour == 4
|
||||
assert policy.idle_minutes == 1440
|
||||
|
||||
def test_from_dict_treats_null_values_as_defaults(self):
|
||||
restored = SessionResetPolicy.from_dict(
|
||||
{"mode": None, "at_hour": None, "idle_minutes": None}
|
||||
)
|
||||
assert restored.mode == "both"
|
||||
assert restored.at_hour == 4
|
||||
assert restored.idle_minutes == 1440
|
||||
|
||||
|
||||
class TestGatewayConfigRoundtrip:
|
||||
def test_full_roundtrip(self):
|
||||
@@ -104,7 +96,6 @@ class TestGatewayConfigRoundtrip:
|
||||
},
|
||||
reset_triggers=["/new"],
|
||||
quick_commands={"limits": {"type": "exec", "command": "echo ok"}},
|
||||
group_sessions_per_user=False,
|
||||
)
|
||||
d = config.to_dict()
|
||||
restored = GatewayConfig.from_dict(d)
|
||||
@@ -113,7 +104,6 @@ class TestGatewayConfigRoundtrip:
|
||||
assert restored.platforms[Platform.TELEGRAM].token == "tok_123"
|
||||
assert restored.reset_triggers == ["/new"]
|
||||
assert restored.quick_commands == {"limits": {"type": "exec", "command": "echo ok"}}
|
||||
assert restored.group_sessions_per_user is False
|
||||
|
||||
|
||||
class TestLoadGatewayConfig:
|
||||
@@ -135,18 +125,6 @@ class TestLoadGatewayConfig:
|
||||
|
||||
assert config.quick_commands == {"limits": {"type": "exec", "command": "echo ok"}}
|
||||
|
||||
def test_bridges_group_sessions_per_user_from_config_yaml(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
config_path = hermes_home / "config.yaml"
|
||||
config_path.write_text("group_sessions_per_user: false\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
|
||||
|
||||
config = load_gateway_config()
|
||||
|
||||
assert config.group_sessions_per_user is False
|
||||
|
||||
def test_invalid_quick_commands_in_config_yaml_are_ignored(self, tmp_path, monkeypatch):
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
@@ -90,7 +90,6 @@ class TestGatewayHonchoLifecycle:
|
||||
runner = _make_runner()
|
||||
event = _make_event()
|
||||
runner._shutdown_gateway_honcho = MagicMock()
|
||||
runner._async_flush_memories = AsyncMock()
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store._generate_session_key.return_value = "gateway-key"
|
||||
runner.session_store._entries = {
|
||||
@@ -101,31 +100,4 @@ class TestGatewayHonchoLifecycle:
|
||||
result = await runner._handle_reset_command(event)
|
||||
|
||||
runner._shutdown_gateway_honcho.assert_called_once_with("gateway-key")
|
||||
runner._async_flush_memories.assert_called_once_with("old-session", "gateway-key")
|
||||
assert "Session reset" in result
|
||||
|
||||
def test_flush_memories_reuses_gateway_session_key_and_skips_honcho_sync(self):
|
||||
runner = _make_runner()
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store.load_transcript.return_value = [
|
||||
{"role": "user", "content": "a"},
|
||||
{"role": "assistant", "content": "b"},
|
||||
{"role": "user", "content": "c"},
|
||||
{"role": "assistant", "content": "d"},
|
||||
]
|
||||
tmp_agent = MagicMock()
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "test-key"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="model-name"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent) as mock_agent_cls,
|
||||
):
|
||||
runner._flush_memories_for_session("old-session", "gateway-key")
|
||||
|
||||
mock_agent_cls.assert_called_once()
|
||||
_, kwargs = mock_agent_cls.call_args
|
||||
assert kwargs["session_id"] == "old-session"
|
||||
assert kwargs["honcho_session_key"] == "gateway-key"
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
_, run_kwargs = tmp_agent.run_conversation.call_args
|
||||
assert run_kwargs["sync_honcho"] is False
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_enrichment_uses_athabasca_upload_guidance_without_stale_r2_warning():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
|
||||
with patch(
|
||||
"tools.vision_tools.vision_analyze_tool",
|
||||
return_value='{"success": true, "analysis": "A painted serpent warrior."}',
|
||||
):
|
||||
enriched = await runner._enrich_message_with_vision(
|
||||
"caption",
|
||||
["/tmp/test.jpg"],
|
||||
)
|
||||
|
||||
assert "R2 not configured" not in enriched
|
||||
assert "Gateway media URL available for reference" not in enriched
|
||||
assert "POST /api/uploads" in enriched
|
||||
assert "Do not store the local cache path" in enriched
|
||||
assert "caption" in enriched
|
||||
@@ -1,156 +0,0 @@
|
||||
"""Tests for PII redaction in gateway session context prompts."""
|
||||
|
||||
from gateway.session import (
|
||||
SessionContext,
|
||||
SessionSource,
|
||||
build_session_context_prompt,
|
||||
_hash_id,
|
||||
_hash_sender_id,
|
||||
_hash_chat_id,
|
||||
_looks_like_phone,
|
||||
)
|
||||
from gateway.config import Platform, HomeChannel
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Low-level helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHashHelpers:
|
||||
def test_hash_id_deterministic(self):
|
||||
assert _hash_id("12345") == _hash_id("12345")
|
||||
|
||||
def test_hash_id_12_hex_chars(self):
|
||||
h = _hash_id("user-abc")
|
||||
assert len(h) == 12
|
||||
assert all(c in "0123456789abcdef" for c in h)
|
||||
|
||||
def test_hash_sender_id_prefix(self):
|
||||
assert _hash_sender_id("12345").startswith("user_")
|
||||
assert len(_hash_sender_id("12345")) == 17 # "user_" + 12
|
||||
|
||||
def test_hash_chat_id_preserves_prefix(self):
|
||||
result = _hash_chat_id("telegram:12345")
|
||||
assert result.startswith("telegram:")
|
||||
assert "12345" not in result
|
||||
|
||||
def test_hash_chat_id_no_prefix(self):
|
||||
result = _hash_chat_id("12345")
|
||||
assert len(result) == 12
|
||||
assert "12345" not in result
|
||||
|
||||
def test_looks_like_phone(self):
|
||||
assert _looks_like_phone("+15551234567")
|
||||
assert _looks_like_phone("15551234567")
|
||||
assert _looks_like_phone("+1-555-123-4567")
|
||||
assert not _looks_like_phone("alice")
|
||||
assert not _looks_like_phone("user-123")
|
||||
assert not _looks_like_phone("")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: build_session_context_prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_context(
|
||||
user_id="user-123",
|
||||
user_name=None,
|
||||
chat_id="telegram:99999",
|
||||
platform=Platform.TELEGRAM,
|
||||
home_channels=None,
|
||||
):
|
||||
source = SessionSource(
|
||||
platform=platform,
|
||||
chat_id=chat_id,
|
||||
chat_type="dm",
|
||||
user_id=user_id,
|
||||
user_name=user_name,
|
||||
)
|
||||
return SessionContext(
|
||||
source=source,
|
||||
connected_platforms=[platform],
|
||||
home_channels=home_channels or {},
|
||||
)
|
||||
|
||||
|
||||
class TestBuildSessionContextPromptRedaction:
|
||||
def test_no_redaction_by_default(self):
|
||||
ctx = _make_context(user_id="user-123")
|
||||
prompt = build_session_context_prompt(ctx)
|
||||
assert "user-123" in prompt
|
||||
|
||||
def test_user_id_hashed_when_redact_pii(self):
|
||||
ctx = _make_context(user_id="user-123")
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "user-123" not in prompt
|
||||
assert "user_" in prompt # hashed ID present
|
||||
|
||||
def test_user_name_not_redacted(self):
|
||||
ctx = _make_context(user_id="user-123", user_name="Alice")
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "Alice" in prompt
|
||||
# user_id should not appear when user_name is present (name takes priority)
|
||||
assert "user-123" not in prompt
|
||||
|
||||
def test_home_channel_id_hashed(self):
|
||||
hc = {
|
||||
Platform.TELEGRAM: HomeChannel(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="telegram:99999",
|
||||
name="Home Chat",
|
||||
)
|
||||
}
|
||||
ctx = _make_context(home_channels=hc)
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "99999" not in prompt
|
||||
assert "telegram:" in prompt # prefix preserved
|
||||
assert "Home Chat" in prompt # name not redacted
|
||||
|
||||
def test_home_channel_id_preserved_without_redaction(self):
|
||||
hc = {
|
||||
Platform.TELEGRAM: HomeChannel(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="telegram:99999",
|
||||
name="Home Chat",
|
||||
)
|
||||
}
|
||||
ctx = _make_context(home_channels=hc)
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=False)
|
||||
assert "99999" in prompt
|
||||
|
||||
def test_redaction_is_deterministic(self):
|
||||
ctx = _make_context(user_id="+15551234567")
|
||||
prompt1 = build_session_context_prompt(ctx, redact_pii=True)
|
||||
prompt2 = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert prompt1 == prompt2
|
||||
|
||||
def test_different_ids_produce_different_hashes(self):
|
||||
ctx1 = _make_context(user_id="user-A")
|
||||
ctx2 = _make_context(user_id="user-B")
|
||||
p1 = build_session_context_prompt(ctx1, redact_pii=True)
|
||||
p2 = build_session_context_prompt(ctx2, redact_pii=True)
|
||||
assert p1 != p2
|
||||
|
||||
def test_discord_ids_not_redacted_even_with_flag(self):
|
||||
"""Discord needs real IDs for <@user_id> mentions."""
|
||||
ctx = _make_context(user_id="123456789", platform=Platform.DISCORD)
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "123456789" in prompt
|
||||
|
||||
def test_whatsapp_ids_redacted(self):
|
||||
ctx = _make_context(user_id="+15551234567", platform=Platform.WHATSAPP)
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "+15551234567" not in prompt
|
||||
assert "user_" in prompt
|
||||
|
||||
def test_signal_ids_redacted(self):
|
||||
ctx = _make_context(user_id="+15551234567", platform=Platform.SIGNAL)
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "+15551234567" not in prompt
|
||||
assert "user_" in prompt
|
||||
|
||||
def test_slack_ids_not_redacted(self):
|
||||
"""Slack may need IDs for mentions too."""
|
||||
ctx = _make_context(user_id="U12345ABC", platform=Platform.SLACK)
|
||||
prompt = build_session_context_prompt(ctx, redact_pii=True)
|
||||
assert "U12345ABC" in prompt
|
||||
@@ -199,28 +199,3 @@ class TestHandleResumeCommand:
|
||||
|
||||
assert real_key not in runner._running_agents
|
||||
db.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_flushes_memories_with_gateway_session_key(self, tmp_path):
|
||||
"""Resume should preserve the gateway session key for Honcho flushes."""
|
||||
from hermes_state import SessionDB
|
||||
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("old_session", "telegram")
|
||||
db.set_session_title("old_session", "Old Work")
|
||||
db.create_session("current_session_001", "telegram")
|
||||
|
||||
event = _make_event(text="/resume Old Work")
|
||||
runner = _make_runner(
|
||||
session_db=db,
|
||||
current_session_id="current_session_001",
|
||||
event=event,
|
||||
)
|
||||
|
||||
await runner._handle_resume_command(event)
|
||||
|
||||
runner._async_flush_memories.assert_called_once_with(
|
||||
"current_session_001",
|
||||
_session_key_for_event(event),
|
||||
)
|
||||
db.close()
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import BasePlatformAdapter
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.status import read_runtime_status
|
||||
|
||||
|
||||
class _RetryableFailureAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="***"), Platform.TELEGRAM)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
self._set_fatal_error(
|
||||
"telegram_connect_error",
|
||||
"Telegram startup failed: temporary DNS resolution failure.",
|
||||
retryable=True,
|
||||
)
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._mark_disconnected()
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
class _DisabledAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=False, token="***"), Platform.TELEGRAM)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
raise AssertionError("connect should not be called for disabled platforms")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self._mark_disconnected()
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"id": chat_id}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_returns_failure_for_retryable_startup_errors(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")
|
||||
},
|
||||
sessions_dir=tmp_path / "sessions",
|
||||
)
|
||||
runner = GatewayRunner(config)
|
||||
|
||||
monkeypatch.setattr(runner, "_create_adapter", lambda platform, platform_config: _RetryableFailureAdapter())
|
||||
|
||||
ok = await runner.start()
|
||||
|
||||
assert ok is False
|
||||
assert runner.should_exit_cleanly is False
|
||||
state = read_runtime_status()
|
||||
assert state["gateway_state"] == "startup_failed"
|
||||
assert "temporary DNS resolution failure" in state["exit_reason"]
|
||||
assert state["platforms"]["telegram"]["state"] == "fatal"
|
||||
assert state["platforms"]["telegram"]["error_code"] == "telegram_connect_error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_allows_cron_only_mode_when_no_platforms_are_enabled(monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config = GatewayConfig(
|
||||
platforms={
|
||||
Platform.TELEGRAM: PlatformConfig(enabled=False, token="***")
|
||||
},
|
||||
sessions_dir=tmp_path / "sessions",
|
||||
)
|
||||
runner = GatewayRunner(config)
|
||||
|
||||
ok = await runner.start()
|
||||
|
||||
assert ok is True
|
||||
assert runner.should_exit_cleanly is False
|
||||
assert runner.adapters == {}
|
||||
state = read_runtime_status()
|
||||
assert state["gateway_state"] == "running"
|
||||
@@ -369,54 +369,6 @@ class TestWhatsAppDMSessionKeyConsistency:
|
||||
)
|
||||
assert store._generate_session_key(source) == build_session_key(source)
|
||||
|
||||
def test_store_creates_distinct_group_sessions_per_user(self, store):
|
||||
first = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_type="group",
|
||||
user_id="alice",
|
||||
user_name="Alice",
|
||||
)
|
||||
second = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_type="group",
|
||||
user_id="bob",
|
||||
user_name="Bob",
|
||||
)
|
||||
|
||||
first_entry = store.get_or_create_session(first)
|
||||
second_entry = store.get_or_create_session(second)
|
||||
|
||||
assert first_entry.session_key == "agent:main:discord:group:guild-123:alice"
|
||||
assert second_entry.session_key == "agent:main:discord:group:guild-123:bob"
|
||||
assert first_entry.session_id != second_entry.session_id
|
||||
|
||||
def test_store_shares_group_sessions_when_disabled_in_config(self, store):
|
||||
store.config.group_sessions_per_user = False
|
||||
|
||||
first = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_type="group",
|
||||
user_id="alice",
|
||||
user_name="Alice",
|
||||
)
|
||||
second = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_type="group",
|
||||
user_id="bob",
|
||||
user_name="Bob",
|
||||
)
|
||||
|
||||
first_entry = store.get_or_create_session(first)
|
||||
second_entry = store.get_or_create_session(second)
|
||||
|
||||
assert first_entry.session_key == "agent:main:discord:group:guild-123"
|
||||
assert second_entry.session_key == "agent:main:discord:group:guild-123"
|
||||
assert first_entry.session_id == second_entry.session_id
|
||||
|
||||
def test_telegram_dm_includes_chat_id(self):
|
||||
"""Non-WhatsApp DMs should also include chat_id to separate users."""
|
||||
source = SessionSource(
|
||||
@@ -446,41 +398,6 @@ class TestWhatsAppDMSessionKeyConsistency:
|
||||
key = build_session_key(source)
|
||||
assert key == "agent:main:discord:group:guild-123"
|
||||
|
||||
def test_group_sessions_are_isolated_per_user_when_user_id_present(self):
|
||||
first = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_type="group",
|
||||
user_id="alice",
|
||||
)
|
||||
second = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_type="group",
|
||||
user_id="bob",
|
||||
)
|
||||
|
||||
assert build_session_key(first) == "agent:main:discord:group:guild-123:alice"
|
||||
assert build_session_key(second) == "agent:main:discord:group:guild-123:bob"
|
||||
assert build_session_key(first) != build_session_key(second)
|
||||
|
||||
def test_group_sessions_can_be_shared_when_isolation_disabled(self):
|
||||
first = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_type="group",
|
||||
user_id="alice",
|
||||
)
|
||||
second = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id="guild-123",
|
||||
chat_type="group",
|
||||
user_id="bob",
|
||||
)
|
||||
|
||||
assert build_session_key(first, group_sessions_per_user=False) == "agent:main:discord:group:guild-123"
|
||||
assert build_session_key(second, group_sessions_per_user=False) == "agent:main:discord:group:guild-123"
|
||||
|
||||
def test_group_thread_includes_thread_id(self):
|
||||
"""Forum-style threads need a distinct session key within one group."""
|
||||
source = SessionSource(
|
||||
@@ -492,17 +409,6 @@ class TestWhatsAppDMSessionKeyConsistency:
|
||||
key = build_session_key(source)
|
||||
assert key == "agent:main:telegram:group:-1002285219667:17585"
|
||||
|
||||
def test_group_thread_sessions_are_isolated_per_user(self):
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="-1002285219667",
|
||||
chat_type="group",
|
||||
thread_id="17585",
|
||||
user_id="42",
|
||||
)
|
||||
key = build_session_key(source)
|
||||
assert key == "agent:main:telegram:group:-1002285219667:17585:42"
|
||||
|
||||
|
||||
class TestSessionStoreEntriesAttribute:
|
||||
"""Regression: /reset must access _entries, not _sessions."""
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
"""Tests for SSL certificate auto-detection in gateway/run.py."""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
def _load_ensure_ssl():
|
||||
"""Import _ensure_ssl_certs fresh (gateway/run.py has heavy deps, so we
|
||||
extract just the function source to avoid importing the whole gateway)."""
|
||||
# We can test via the actual module since conftest isolates HERMES_HOME,
|
||||
# but we need to be careful about side effects. Instead, replicate the
|
||||
# logic in a controlled way.
|
||||
from types import ModuleType
|
||||
import textwrap, ssl as _ssl # noqa: F401
|
||||
|
||||
code = textwrap.dedent("""\
|
||||
import os, ssl
|
||||
|
||||
def _ensure_ssl_certs():
|
||||
if "SSL_CERT_FILE" in os.environ:
|
||||
return
|
||||
paths = ssl.get_default_verify_paths()
|
||||
for candidate in (paths.cafile, paths.openssl_cafile):
|
||||
if candidate and os.path.exists(candidate):
|
||||
os.environ["SSL_CERT_FILE"] = candidate
|
||||
return
|
||||
try:
|
||||
import certifi
|
||||
os.environ["SSL_CERT_FILE"] = certifi.where()
|
||||
return
|
||||
except ImportError:
|
||||
pass
|
||||
for candidate in (
|
||||
"/etc/ssl/certs/ca-certificates.crt",
|
||||
"/etc/ssl/cert.pem",
|
||||
):
|
||||
if os.path.exists(candidate):
|
||||
os.environ["SSL_CERT_FILE"] = candidate
|
||||
return
|
||||
""")
|
||||
mod = ModuleType("_ssl_helper")
|
||||
exec(code, mod.__dict__)
|
||||
return mod._ensure_ssl_certs
|
||||
|
||||
|
||||
class TestEnsureSslCerts:
|
||||
def test_respects_existing_env_var(self):
|
||||
fn = _load_ensure_ssl()
|
||||
with patch.dict(os.environ, {"SSL_CERT_FILE": "/custom/ca.pem"}):
|
||||
fn()
|
||||
assert os.environ["SSL_CERT_FILE"] == "/custom/ca.pem"
|
||||
|
||||
def test_sets_from_ssl_default_paths(self, tmp_path):
|
||||
fn = _load_ensure_ssl()
|
||||
cert = tmp_path / "ca.crt"
|
||||
cert.write_text("FAKE CERT")
|
||||
|
||||
mock_paths = MagicMock()
|
||||
mock_paths.cafile = str(cert)
|
||||
mock_paths.openssl_cafile = None
|
||||
|
||||
env = {k: v for k, v in os.environ.items() if k != "SSL_CERT_FILE"}
|
||||
with patch.dict(os.environ, env, clear=True), \
|
||||
patch("ssl.get_default_verify_paths", return_value=mock_paths):
|
||||
fn()
|
||||
assert os.environ.get("SSL_CERT_FILE") == str(cert)
|
||||
|
||||
def test_no_op_when_nothing_found(self):
|
||||
fn = _load_ensure_ssl()
|
||||
mock_paths = MagicMock()
|
||||
mock_paths.cafile = None
|
||||
mock_paths.openssl_cafile = None
|
||||
|
||||
env = {k: v for k, v in os.environ.items() if k != "SSL_CERT_FILE"}
|
||||
with patch.dict(os.environ, env, clear=True), \
|
||||
patch("ssl.get_default_verify_paths", return_value=mock_paths), \
|
||||
patch("os.path.exists", return_value=False), \
|
||||
patch.dict("sys.modules", {"certifi": None}):
|
||||
fn()
|
||||
assert "SSL_CERT_FILE" not in os.environ
|
||||
@@ -26,22 +26,6 @@ class TestGatewayPidState:
|
||||
assert status.get_running_pid() is None
|
||||
assert not pid_path.exists()
|
||||
|
||||
def test_get_running_pid_accepts_gateway_metadata_when_cmdline_unavailable(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
pid_path = tmp_path / "gateway.pid"
|
||||
pid_path.write_text(json.dumps({
|
||||
"pid": os.getpid(),
|
||||
"kind": "hermes-gateway",
|
||||
"argv": ["python", "-m", "hermes_cli.main", "gateway"],
|
||||
"start_time": 123,
|
||||
}))
|
||||
|
||||
monkeypatch.setattr(status.os, "kill", lambda pid, sig: None)
|
||||
monkeypatch.setattr(status, "_get_process_start_time", lambda pid: 123)
|
||||
monkeypatch.setattr(status, "_read_process_cmdline", lambda pid: None)
|
||||
|
||||
assert status.get_running_pid() == os.getpid()
|
||||
|
||||
|
||||
class TestGatewayRuntimeStatus:
|
||||
def test_write_runtime_status_records_platform_failure(self, tmp_path, monkeypatch):
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
"""Tests for gateway /status behavior and token persistence."""
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from gateway.session import SessionEntry, SessionSource, build_session_key
|
||||
|
||||
|
||||
def _make_source() -> SessionSource:
|
||||
return SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="u1",
|
||||
chat_id="c1",
|
||||
user_name="tester",
|
||||
chat_type="dm",
|
||||
)
|
||||
|
||||
|
||||
def _make_event(text: str) -> MessageEvent:
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
source=_make_source(),
|
||||
message_id="m1",
|
||||
)
|
||||
|
||||
|
||||
def _make_runner(session_entry: SessionEntry):
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(
|
||||
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")}
|
||||
)
|
||||
adapter = MagicMock()
|
||||
adapter.send = AsyncMock()
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
runner._voice_mode = {}
|
||||
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
|
||||
runner.session_store = MagicMock()
|
||||
runner.session_store.get_or_create_session.return_value = session_entry
|
||||
runner.session_store.load_transcript.return_value = []
|
||||
runner.session_store.has_any_sessions.return_value = True
|
||||
runner.session_store.append_to_transcript = MagicMock()
|
||||
runner.session_store.rewrite_transcript = MagicMock()
|
||||
runner.session_store.update_session = MagicMock()
|
||||
runner._running_agents = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = None
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._show_reasoning = False
|
||||
runner._is_user_authorized = lambda _source: True
|
||||
runner._set_session_env = lambda _context: None
|
||||
runner._should_send_voice_reply = lambda *_args, **_kwargs: False
|
||||
runner._send_voice_reply = AsyncMock()
|
||||
runner._capture_gateway_honcho_if_configured = lambda *args, **kwargs: None
|
||||
runner._emit_gateway_run_progress = AsyncMock()
|
||||
return runner
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_status_command_reports_running_agent_without_interrupt(monkeypatch):
|
||||
session_entry = SessionEntry(
|
||||
session_key=build_session_key(_make_source()),
|
||||
session_id="sess-1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
total_tokens=321,
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
running_agent = MagicMock()
|
||||
runner._running_agents[build_session_key(_make_source())] = running_agent
|
||||
|
||||
result = await runner._handle_message(_make_event("/status"))
|
||||
|
||||
assert "**Tokens:** 321" in result
|
||||
assert "**Agent Running:** Yes ⚡" in result
|
||||
running_agent.interrupt.assert_not_called()
|
||||
assert runner._pending_messages == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_persists_agent_token_counts(monkeypatch):
|
||||
import gateway.run as gateway_run
|
||||
|
||||
session_entry = SessionEntry(
|
||||
session_key=build_session_key(_make_source()),
|
||||
session_id="sess-1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_type="dm",
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
runner.session_store.load_transcript.return_value = [{"role": "user", "content": "earlier"}]
|
||||
runner._run_agent = AsyncMock(
|
||||
return_value={
|
||||
"final_response": "ok",
|
||||
"messages": [],
|
||||
"tools": [],
|
||||
"history_offset": 0,
|
||||
"last_prompt_tokens": 80,
|
||||
"input_tokens": 120,
|
||||
"output_tokens": 45,
|
||||
"model": "openai/test-model",
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||
monkeypatch.setattr(
|
||||
"agent.model_metadata.get_model_context_length",
|
||||
lambda *_args, **_kwargs: 100000,
|
||||
)
|
||||
|
||||
result = await runner._handle_message(_make_event("hello"))
|
||||
|
||||
assert result == "ok"
|
||||
runner.session_store.update_session.assert_called_once_with(
|
||||
session_entry.session_key,
|
||||
input_tokens=120,
|
||||
output_tokens=45,
|
||||
last_prompt_tokens=80,
|
||||
model="openai/test-model",
|
||||
)
|
||||
@@ -51,27 +51,3 @@ async def test_enrich_message_with_transcription_skips_when_stt_disabled():
|
||||
|
||||
assert "transcription is disabled" in result.lower()
|
||||
assert "caption" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enrich_message_with_transcription_avoids_bogus_no_provider_message_for_backend_key_errors():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
runner = GatewayRunner.__new__(GatewayRunner)
|
||||
runner.config = GatewayConfig(stt_enabled=True)
|
||||
|
||||
with patch(
|
||||
"tools.transcription_tools.transcribe_audio",
|
||||
return_value={"success": False, "error": "VOICE_TOOLS_OPENAI_KEY not set"},
|
||||
), patch(
|
||||
"tools.transcription_tools.get_stt_model_from_config",
|
||||
return_value=None,
|
||||
):
|
||||
result = await runner._enrich_message_with_transcription(
|
||||
"caption",
|
||||
["/tmp/voice.ogg"],
|
||||
)
|
||||
|
||||
assert "No STT provider is configured" not in result
|
||||
assert "trouble transcribing" in result
|
||||
assert "caption" in result
|
||||
|
||||
@@ -100,39 +100,6 @@ async def test_polling_conflict_stops_polling_and_notifies_handler(monkeypatch):
|
||||
fatal_handler.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_marks_retryable_fatal_error_for_startup_network_failure(monkeypatch):
|
||||
adapter = TelegramAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.acquire_scoped_lock",
|
||||
lambda scope, identity, metadata=None: (True, None),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"gateway.status.release_scoped_lock",
|
||||
lambda scope, identity: None,
|
||||
)
|
||||
|
||||
builder = MagicMock()
|
||||
builder.token.return_value = builder
|
||||
app = SimpleNamespace(
|
||||
bot=SimpleNamespace(),
|
||||
updater=SimpleNamespace(),
|
||||
add_handler=MagicMock(),
|
||||
initialize=AsyncMock(side_effect=RuntimeError("Temporary failure in name resolution")),
|
||||
start=AsyncMock(),
|
||||
)
|
||||
builder.build.return_value = app
|
||||
monkeypatch.setattr("gateway.platforms.telegram.Application", SimpleNamespace(builder=MagicMock(return_value=builder)))
|
||||
|
||||
ok = await adapter.connect()
|
||||
|
||||
assert ok is False
|
||||
assert adapter.fatal_error_code == "telegram_connect_error"
|
||||
assert adapter.fatal_error_retryable is True
|
||||
assert "Temporary failure in name resolution" in adapter.fatal_error_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_skips_inactive_updater_and_app(monkeypatch):
|
||||
adapter = TelegramAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
@@ -7,7 +7,7 @@ or corrupt user-visible content.
|
||||
|
||||
import re
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -392,27 +392,3 @@ class TestStripMdv2:
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _strip_mdv2("") == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_escapes_chunk_indicator_for_markdownv2(adapter):
|
||||
adapter.MAX_MESSAGE_LENGTH = 80
|
||||
adapter._bot = MagicMock()
|
||||
|
||||
sent_texts = []
|
||||
|
||||
async def _fake_send_message(**kwargs):
|
||||
sent_texts.append(kwargs["text"])
|
||||
msg = MagicMock()
|
||||
msg.message_id = len(sent_texts)
|
||||
return msg
|
||||
|
||||
adapter._bot.send_message = AsyncMock(side_effect=_fake_send_message)
|
||||
|
||||
content = ("**bold** chunk content " * 12).strip()
|
||||
result = await adapter.send("123", content)
|
||||
|
||||
assert result.success is True
|
||||
assert len(sent_texts) > 1
|
||||
assert re.search(r" \\\([0-9]+/[0-9]+\\\)$", sent_texts[0])
|
||||
assert re.search(r" \\\([0-9]+/[0-9]+\\\)$", sent_texts[-1])
|
||||
|
||||
@@ -12,8 +12,7 @@ EXPECTED_COMMANDS = {
|
||||
"/personality", "/clear", "/history", "/new", "/reset", "/retry",
|
||||
"/undo", "/save", "/config", "/cron", "/skills", "/platforms",
|
||||
"/verbose", "/reasoning", "/compress", "/title", "/usage", "/insights", "/paste",
|
||||
"/reload-mcp", "/rollback", "/stop", "/background", "/skin", "/voice", "/browser", "/quit",
|
||||
"/plugins",
|
||||
"/reload-mcp", "/rollback", "/background", "/skin", "/voice", "/quit",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ def test_systemd_status_warns_when_linger_disabled(monkeypatch, tmp_path, capsys
|
||||
monkeypatch.setattr(gateway, "get_systemd_linger_status", lambda: (False, ""))
|
||||
|
||||
def fake_run(cmd, capture_output=False, text=False, check=False):
|
||||
if cmd[:4] == ["systemctl", "--user", "status", gateway.get_service_name()]:
|
||||
if cmd[:4] == ["systemctl", "--user", "status", gateway.SERVICE_NAME]:
|
||||
return SimpleNamespace(returncode=0, stdout="", stderr="")
|
||||
if cmd[:3] == ["systemctl", "--user", "is-active"]:
|
||||
return SimpleNamespace(returncode=0, stdout="active\n", stderr="")
|
||||
@@ -76,7 +76,7 @@ def test_systemd_install_checks_linger_status(monkeypatch, tmp_path, capsys):
|
||||
assert unit_path.exists()
|
||||
assert [cmd for cmd, _ in calls] == [
|
||||
["systemctl", "--user", "daemon-reload"],
|
||||
["systemctl", "--user", "enable", gateway.get_service_name()],
|
||||
["systemctl", "--user", "enable", gateway.SERVICE_NAME],
|
||||
]
|
||||
assert helper_calls == [True]
|
||||
assert "User service installed and enabled" in out
|
||||
@@ -110,7 +110,7 @@ def test_systemd_install_system_scope_skips_linger_and_uses_systemctl(monkeypatc
|
||||
assert unit_path.read_text(encoding="utf-8") == "scope=True user=alice\n"
|
||||
assert [cmd for cmd, _ in calls] == [
|
||||
["systemctl", "daemon-reload"],
|
||||
["systemctl", "enable", gateway.get_service_name()],
|
||||
["systemctl", "enable", gateway.SERVICE_NAME],
|
||||
]
|
||||
assert helper_calls == []
|
||||
assert "Configured to run as: alice" not in out # generated test unit has no User= line
|
||||
|
||||
@@ -114,7 +114,7 @@ def test_systemd_install_calls_linger_helper(monkeypatch, tmp_path, capsys):
|
||||
assert unit_path.exists()
|
||||
assert [cmd for cmd, _ in calls] == [
|
||||
["systemctl", "--user", "daemon-reload"],
|
||||
["systemctl", "--user", "enable", gateway.get_service_name()],
|
||||
["systemctl", "--user", "enable", gateway.SERVICE_NAME],
|
||||
]
|
||||
assert helper_calls == [True]
|
||||
assert "User service installed and enabled" in out
|
||||
|
||||
@@ -26,7 +26,7 @@ class TestSystemdServiceRefresh:
|
||||
assert unit_path.read_text(encoding="utf-8") == "new unit\n"
|
||||
assert calls[:2] == [
|
||||
["systemctl", "--user", "daemon-reload"],
|
||||
["systemctl", "--user", "start", gateway_cli.get_service_name()],
|
||||
["systemctl", "--user", "start", gateway_cli.SERVICE_NAME],
|
||||
]
|
||||
|
||||
def test_systemd_restart_refreshes_outdated_unit(self, tmp_path, monkeypatch):
|
||||
@@ -49,27 +49,10 @@ class TestSystemdServiceRefresh:
|
||||
assert unit_path.read_text(encoding="utf-8") == "new unit\n"
|
||||
assert calls[:2] == [
|
||||
["systemctl", "--user", "daemon-reload"],
|
||||
["systemctl", "--user", "restart", gateway_cli.get_service_name()],
|
||||
["systemctl", "--user", "restart", gateway_cli.SERVICE_NAME],
|
||||
]
|
||||
|
||||
|
||||
class TestGeneratedSystemdUnits:
|
||||
def test_user_unit_avoids_recursive_execstop_and_uses_extended_stop_timeout(self):
|
||||
unit = gateway_cli.generate_systemd_unit(system=False)
|
||||
|
||||
assert "ExecStart=" in unit
|
||||
assert "ExecStop=" not in unit
|
||||
assert "TimeoutStopSec=60" in unit
|
||||
|
||||
def test_system_unit_avoids_recursive_execstop_and_uses_extended_stop_timeout(self):
|
||||
unit = gateway_cli.generate_systemd_unit(system=True)
|
||||
|
||||
assert "ExecStart=" in unit
|
||||
assert "ExecStop=" not in unit
|
||||
assert "TimeoutStopSec=60" in unit
|
||||
assert "WantedBy=multi-user.target" in unit
|
||||
|
||||
|
||||
class TestGatewayStopCleanup:
|
||||
def test_stop_sweeps_manual_gateway_processes_after_service_stop(self, tmp_path, monkeypatch):
|
||||
unit_path = tmp_path / "hermes-gateway.service"
|
||||
@@ -109,9 +92,9 @@ class TestGatewayServiceDetection:
|
||||
)
|
||||
|
||||
def fake_run(cmd, capture_output=True, text=True, **kwargs):
|
||||
if cmd == ["systemctl", "--user", "is-active", gateway_cli.get_service_name()]:
|
||||
if cmd == ["systemctl", "--user", "is-active", gateway_cli.SERVICE_NAME]:
|
||||
return SimpleNamespace(returncode=0, stdout="inactive\n", stderr="")
|
||||
if cmd == ["systemctl", "is-active", gateway_cli.get_service_name()]:
|
||||
if cmd == ["systemctl", "is-active", gateway_cli.SERVICE_NAME]:
|
||||
return SimpleNamespace(returncode=0, stdout="active\n", stderr="")
|
||||
raise AssertionError(f"Unexpected command: {cmd}")
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from hermes_cli.models import (
|
||||
fetch_api_models,
|
||||
normalize_provider,
|
||||
parse_model_input,
|
||||
probe_api_models,
|
||||
provider_label,
|
||||
provider_model_ids,
|
||||
validate_requested_model,
|
||||
@@ -27,15 +26,7 @@ FAKE_API_MODELS = [
|
||||
|
||||
def _validate(model, provider="openrouter", api_models=FAKE_API_MODELS, **kw):
|
||||
"""Shortcut: call validate_requested_model with mocked API."""
|
||||
probe_payload = {
|
||||
"models": api_models,
|
||||
"probed_url": "http://localhost:11434/v1/models",
|
||||
"resolved_base_url": kw.get("base_url", "") or "http://localhost:11434/v1",
|
||||
"suggested_base_url": None,
|
||||
"used_fallback": False,
|
||||
}
|
||||
with patch("hermes_cli.models.fetch_api_models", return_value=api_models), \
|
||||
patch("hermes_cli.models.probe_api_models", return_value=probe_payload):
|
||||
with patch("hermes_cli.models.fetch_api_models", return_value=api_models):
|
||||
return validate_requested_model(model, provider, **kw)
|
||||
|
||||
|
||||
@@ -156,33 +147,6 @@ class TestFetchApiModels:
|
||||
with patch("hermes_cli.models.urllib.request.urlopen", side_effect=Exception("timeout")):
|
||||
assert fetch_api_models("key", "https://example.com/v1") is None
|
||||
|
||||
def test_probe_api_models_tries_v1_fallback(self):
|
||||
class _Resp:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def read(self):
|
||||
return b'{"data": [{"id": "local-model"}]}'
|
||||
|
||||
calls = []
|
||||
|
||||
def _fake_urlopen(req, timeout=5.0):
|
||||
calls.append(req.full_url)
|
||||
if req.full_url.endswith("/v1/models"):
|
||||
return _Resp()
|
||||
raise Exception("404")
|
||||
|
||||
with patch("hermes_cli.models.urllib.request.urlopen", side_effect=_fake_urlopen):
|
||||
probe = probe_api_models("key", "http://localhost:8000")
|
||||
|
||||
assert calls == ["http://localhost:8000/models", "http://localhost:8000/v1/models"]
|
||||
assert probe["models"] == ["local-model"]
|
||||
assert probe["resolved_base_url"] == "http://localhost:8000/v1"
|
||||
assert probe["used_fallback"] is True
|
||||
|
||||
|
||||
# -- validate — format checks -----------------------------------------------
|
||||
|
||||
@@ -227,7 +191,6 @@ class TestValidateApiFound:
|
||||
)
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is True
|
||||
assert result["recognized"] is True
|
||||
|
||||
|
||||
# -- validate — API not found ------------------------------------------------
|
||||
@@ -269,26 +232,3 @@ class TestValidateApiFallback:
|
||||
result = _validate("some-model", provider="totally-unknown", api_models=None)
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is True
|
||||
|
||||
def test_custom_endpoint_warns_with_probed_url_and_v1_hint(self):
|
||||
with patch(
|
||||
"hermes_cli.models.probe_api_models",
|
||||
return_value={
|
||||
"models": None,
|
||||
"probed_url": "http://localhost:8000/v1/models",
|
||||
"resolved_base_url": "http://localhost:8000",
|
||||
"suggested_base_url": "http://localhost:8000/v1",
|
||||
"used_fallback": False,
|
||||
},
|
||||
):
|
||||
result = validate_requested_model(
|
||||
"qwen3",
|
||||
"custom",
|
||||
api_key="local-key",
|
||||
base_url="http://localhost:8000",
|
||||
)
|
||||
|
||||
assert result["accepted"] is True
|
||||
assert result["persist"] is True
|
||||
assert "http://localhost:8000/v1/models" in result["message"]
|
||||
assert "http://localhost:8000/v1" in result["message"]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Tests for the hermes_cli models module."""
|
||||
|
||||
from hermes_cli.models import OPENROUTER_MODELS, menu_labels, model_ids, detect_provider_for_model
|
||||
from hermes_cli.models import OPENROUTER_MODELS, menu_labels, model_ids
|
||||
|
||||
|
||||
class TestModelIds:
|
||||
@@ -54,66 +54,3 @@ class TestOpenRouterModels:
|
||||
def test_at_least_5_models(self):
|
||||
"""Sanity check that the models list hasn't been accidentally truncated."""
|
||||
assert len(OPENROUTER_MODELS) >= 5
|
||||
|
||||
|
||||
class TestFindOpenrouterSlug:
|
||||
def test_exact_match(self):
|
||||
from hermes_cli.models import _find_openrouter_slug
|
||||
assert _find_openrouter_slug("anthropic/claude-opus-4.6") == "anthropic/claude-opus-4.6"
|
||||
|
||||
def test_bare_name_match(self):
|
||||
from hermes_cli.models import _find_openrouter_slug
|
||||
result = _find_openrouter_slug("claude-opus-4.6")
|
||||
assert result == "anthropic/claude-opus-4.6"
|
||||
|
||||
def test_case_insensitive(self):
|
||||
from hermes_cli.models import _find_openrouter_slug
|
||||
result = _find_openrouter_slug("Anthropic/Claude-Opus-4.6")
|
||||
assert result is not None
|
||||
|
||||
def test_unknown_returns_none(self):
|
||||
from hermes_cli.models import _find_openrouter_slug
|
||||
assert _find_openrouter_slug("totally-fake-model-xyz") is None
|
||||
|
||||
|
||||
class TestDetectProviderForModel:
|
||||
def test_anthropic_model_detected(self):
|
||||
"""claude-opus-4-6 should resolve to anthropic provider."""
|
||||
result = detect_provider_for_model("claude-opus-4-6", "openai-codex")
|
||||
assert result is not None
|
||||
assert result[0] == "anthropic"
|
||||
|
||||
def test_deepseek_model_detected(self):
|
||||
"""deepseek-chat should resolve to deepseek provider."""
|
||||
result = detect_provider_for_model("deepseek-chat", "openai-codex")
|
||||
assert result is not None
|
||||
# Provider is deepseek (direct) or openrouter (fallback) depending on creds
|
||||
assert result[0] in ("deepseek", "openrouter")
|
||||
|
||||
def test_current_provider_model_returns_none(self):
|
||||
"""Models belonging to the current provider should not trigger a switch."""
|
||||
assert detect_provider_for_model("gpt-5.3-codex", "openai-codex") is None
|
||||
|
||||
def test_openrouter_slug_match(self):
|
||||
"""Models in the OpenRouter catalog should be found."""
|
||||
result = detect_provider_for_model("anthropic/claude-opus-4.6", "openai-codex")
|
||||
assert result is not None
|
||||
assert result[0] == "openrouter"
|
||||
assert result[1] == "anthropic/claude-opus-4.6"
|
||||
|
||||
def test_bare_name_gets_openrouter_slug(self):
|
||||
"""Bare model names should get mapped to full OpenRouter slugs."""
|
||||
result = detect_provider_for_model("claude-opus-4.6", "openai-codex")
|
||||
assert result is not None
|
||||
# Should find it on OpenRouter with full slug
|
||||
assert result[1] == "anthropic/claude-opus-4.6"
|
||||
|
||||
def test_unknown_model_returns_none(self):
|
||||
"""Completely unknown model names should return None."""
|
||||
assert detect_provider_for_model("nonexistent-model-xyz", "openai-codex") is None
|
||||
|
||||
def test_aggregator_not_suggested(self):
|
||||
"""nous/openrouter should never be auto-suggested as target provider."""
|
||||
result = detect_provider_for_model("claude-opus-4-6", "openai-codex")
|
||||
assert result is not None
|
||||
assert result[0] not in ("nous",) # nous has claude models but shouldn't be suggested
|
||||
|
||||
@@ -1,184 +0,0 @@
|
||||
"""Tests for file path autocomplete in the CLI completer."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from prompt_toolkit.document import Document
|
||||
from prompt_toolkit.formatted_text import to_plain_text
|
||||
|
||||
from hermes_cli.commands import SlashCommandCompleter, _file_size_label
|
||||
|
||||
|
||||
def _display_names(completions):
|
||||
"""Extract plain-text display names from a list of Completion objects."""
|
||||
return [to_plain_text(c.display) for c in completions]
|
||||
|
||||
|
||||
def _display_metas(completions):
|
||||
"""Extract plain-text display_meta from a list of Completion objects."""
|
||||
return [to_plain_text(c.display_meta) if c.display_meta else "" for c in completions]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def completer():
|
||||
return SlashCommandCompleter()
|
||||
|
||||
|
||||
class TestExtractPathWord:
|
||||
def test_relative_path(self):
|
||||
assert SlashCommandCompleter._extract_path_word("look at ./src/main.py") == "./src/main.py"
|
||||
|
||||
def test_home_path(self):
|
||||
assert SlashCommandCompleter._extract_path_word("edit ~/docs/") == "~/docs/"
|
||||
|
||||
def test_absolute_path(self):
|
||||
assert SlashCommandCompleter._extract_path_word("read /etc/hosts") == "/etc/hosts"
|
||||
|
||||
def test_parent_path(self):
|
||||
assert SlashCommandCompleter._extract_path_word("check ../config.yaml") == "../config.yaml"
|
||||
|
||||
def test_path_with_slash_in_middle(self):
|
||||
assert SlashCommandCompleter._extract_path_word("open src/utils/helpers.py") == "src/utils/helpers.py"
|
||||
|
||||
def test_plain_word_not_path(self):
|
||||
assert SlashCommandCompleter._extract_path_word("hello world") is None
|
||||
|
||||
def test_empty_string(self):
|
||||
assert SlashCommandCompleter._extract_path_word("") is None
|
||||
|
||||
def test_single_word_no_slash(self):
|
||||
assert SlashCommandCompleter._extract_path_word("README.md") is None
|
||||
|
||||
def test_word_after_space(self):
|
||||
assert SlashCommandCompleter._extract_path_word("fix the bug in ./tools/") == "./tools/"
|
||||
|
||||
def test_just_dot_slash(self):
|
||||
assert SlashCommandCompleter._extract_path_word("./") == "./"
|
||||
|
||||
def test_just_tilde_slash(self):
|
||||
assert SlashCommandCompleter._extract_path_word("~/") == "~/"
|
||||
|
||||
|
||||
class TestPathCompletions:
|
||||
def test_lists_current_directory(self, tmp_path):
|
||||
(tmp_path / "file_a.py").touch()
|
||||
(tmp_path / "file_b.txt").touch()
|
||||
(tmp_path / "subdir").mkdir()
|
||||
|
||||
old_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
try:
|
||||
completions = list(SlashCommandCompleter._path_completions("./"))
|
||||
names = _display_names(completions)
|
||||
assert "file_a.py" in names
|
||||
assert "file_b.txt" in names
|
||||
assert "subdir/" in names
|
||||
finally:
|
||||
os.chdir(old_cwd)
|
||||
|
||||
def test_filters_by_prefix(self, tmp_path):
|
||||
(tmp_path / "alpha.py").touch()
|
||||
(tmp_path / "beta.py").touch()
|
||||
(tmp_path / "alpha_test.py").touch()
|
||||
|
||||
completions = list(SlashCommandCompleter._path_completions(f"{tmp_path}/alpha"))
|
||||
names = _display_names(completions)
|
||||
assert "alpha.py" in names
|
||||
assert "alpha_test.py" in names
|
||||
assert "beta.py" not in names
|
||||
|
||||
def test_directories_have_trailing_slash(self, tmp_path):
|
||||
(tmp_path / "mydir").mkdir()
|
||||
(tmp_path / "myfile.txt").touch()
|
||||
|
||||
completions = list(SlashCommandCompleter._path_completions(f"{tmp_path}/"))
|
||||
names = _display_names(completions)
|
||||
metas = _display_metas(completions)
|
||||
assert "mydir/" in names
|
||||
idx = names.index("mydir/")
|
||||
assert metas[idx] == "dir"
|
||||
|
||||
def test_home_expansion(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HOME", str(tmp_path))
|
||||
(tmp_path / "testfile.md").touch()
|
||||
|
||||
completions = list(SlashCommandCompleter._path_completions("~/test"))
|
||||
names = _display_names(completions)
|
||||
assert "testfile.md" in names
|
||||
|
||||
def test_nonexistent_dir_returns_empty(self):
|
||||
completions = list(SlashCommandCompleter._path_completions("/nonexistent_dir_xyz/"))
|
||||
assert completions == []
|
||||
|
||||
def test_respects_limit(self, tmp_path):
|
||||
for i in range(50):
|
||||
(tmp_path / f"file_{i:03d}.txt").touch()
|
||||
|
||||
completions = list(SlashCommandCompleter._path_completions(f"{tmp_path}/", limit=10))
|
||||
assert len(completions) == 10
|
||||
|
||||
def test_case_insensitive_prefix(self, tmp_path):
|
||||
(tmp_path / "README.md").touch()
|
||||
|
||||
completions = list(SlashCommandCompleter._path_completions(f"{tmp_path}/read"))
|
||||
names = _display_names(completions)
|
||||
assert "README.md" in names
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""Test the completer produces path completions via the prompt_toolkit API."""
|
||||
|
||||
def test_slash_commands_still_work(self, completer):
|
||||
doc = Document("/hel", cursor_position=4)
|
||||
event = MagicMock()
|
||||
completions = list(completer.get_completions(doc, event))
|
||||
names = _display_names(completions)
|
||||
assert "/help" in names
|
||||
|
||||
def test_path_completion_triggers_on_dot_slash(self, completer, tmp_path):
|
||||
(tmp_path / "test.py").touch()
|
||||
old_cwd = os.getcwd()
|
||||
os.chdir(tmp_path)
|
||||
try:
|
||||
doc = Document("edit ./te", cursor_position=9)
|
||||
event = MagicMock()
|
||||
completions = list(completer.get_completions(doc, event))
|
||||
names = _display_names(completions)
|
||||
assert "test.py" in names
|
||||
finally:
|
||||
os.chdir(old_cwd)
|
||||
|
||||
def test_no_completion_for_plain_words(self, completer):
|
||||
doc = Document("hello world", cursor_position=11)
|
||||
event = MagicMock()
|
||||
completions = list(completer.get_completions(doc, event))
|
||||
assert completions == []
|
||||
|
||||
def test_absolute_path_triggers_completion(self, completer):
|
||||
doc = Document("check /etc/hos", cursor_position=14)
|
||||
event = MagicMock()
|
||||
completions = list(completer.get_completions(doc, event))
|
||||
names = _display_names(completions)
|
||||
# /etc/hosts should exist on Linux
|
||||
assert any("host" in n.lower() for n in names)
|
||||
|
||||
|
||||
class TestFileSizeLabel:
|
||||
def test_bytes(self, tmp_path):
|
||||
f = tmp_path / "small.txt"
|
||||
f.write_text("hi")
|
||||
assert _file_size_label(str(f)) == "2B"
|
||||
|
||||
def test_kilobytes(self, tmp_path):
|
||||
f = tmp_path / "medium.txt"
|
||||
f.write_bytes(b"x" * 2048)
|
||||
assert _file_size_label(str(f)) == "2K"
|
||||
|
||||
def test_megabytes(self, tmp_path):
|
||||
f = tmp_path / "large.bin"
|
||||
f.write_bytes(b"x" * (2 * 1024 * 1024))
|
||||
assert _file_size_label(str(f)) == "2.0M"
|
||||
|
||||
def test_nonexistent(self):
|
||||
assert _file_size_label("/nonexistent_xyz") == ""
|
||||
@@ -115,13 +115,3 @@ class TestConfigYamlRouting:
|
||||
set_config_value("terminal.docker_image", "python:3.12")
|
||||
config = _read_config(_isolated_hermes_home)
|
||||
assert "python:3.12" in config
|
||||
|
||||
def test_terminal_docker_cwd_mount_flag_goes_to_config_and_env(self, _isolated_hermes_home):
|
||||
set_config_value("terminal.docker_mount_cwd_to_workspace", "true")
|
||||
config = _read_config(_isolated_hermes_home)
|
||||
env_content = _read_env(_isolated_hermes_home)
|
||||
assert "docker_mount_cwd_to_workspace: 'true'" in config or "docker_mount_cwd_to_workspace: true" in config
|
||||
assert (
|
||||
"TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE=true" in env_content
|
||||
or "TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE=True" in env_content
|
||||
)
|
||||
|
||||
@@ -75,58 +75,6 @@ def test_setup_keep_current_custom_from_config_does_not_fall_through(tmp_path, m
|
||||
assert calls["count"] == 1
|
||||
|
||||
|
||||
def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
_clear_provider_env(monkeypatch)
|
||||
|
||||
config = load_config()
|
||||
|
||||
def fake_prompt_choice(question, choices, default=0):
|
||||
if question == "Select your inference provider:":
|
||||
return 3 # Custom endpoint
|
||||
if question == "Configure vision:":
|
||||
return len(choices) - 1 # Skip
|
||||
raise AssertionError(f"Unexpected prompt_choice call: {question}")
|
||||
|
||||
def fake_prompt(message, current=None, **kwargs):
|
||||
if "API base URL" in message:
|
||||
return "http://localhost:8000"
|
||||
if "API key" in message:
|
||||
return "local-key"
|
||||
if "Model name" in message:
|
||||
return "llm"
|
||||
return ""
|
||||
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt", fake_prompt)
|
||||
monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
|
||||
monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
|
||||
monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
|
||||
monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: [])
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.models.probe_api_models",
|
||||
lambda api_key, base_url: {
|
||||
"models": ["llm"],
|
||||
"probed_url": "http://localhost:8000/v1/models",
|
||||
"resolved_base_url": "http://localhost:8000/v1",
|
||||
"suggested_base_url": "http://localhost:8000/v1",
|
||||
"used_fallback": True,
|
||||
},
|
||||
)
|
||||
|
||||
setup_model_provider(config)
|
||||
save_config(config)
|
||||
|
||||
env = _read_env(tmp_path)
|
||||
reloaded = load_config()
|
||||
|
||||
assert env.get("OPENAI_BASE_URL") == "http://localhost:8000/v1"
|
||||
assert env.get("OPENAI_API_KEY") == "local-key"
|
||||
assert reloaded["model"]["provider"] == "custom"
|
||||
assert reloaded["model"]["base_url"] == "http://localhost:8000/v1"
|
||||
assert reloaded["model"]["default"] == "llm"
|
||||
|
||||
|
||||
def test_setup_keep_current_config_provider_uses_provider_specific_model_menu(tmp_path, monkeypatch):
|
||||
"""Keep-current should respect config-backed providers, not fall back to OpenRouter."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
from hermes_cli import setup as setup_mod
|
||||
|
||||
|
||||
def test_prompt_choice_uses_curses_helper(monkeypatch):
|
||||
monkeypatch.setattr(setup_mod, "_curses_prompt_choice", lambda question, choices, default=0: 1)
|
||||
|
||||
idx = setup_mod.prompt_choice("Pick one", ["a", "b", "c"], default=0)
|
||||
|
||||
assert idx == 1
|
||||
|
||||
|
||||
def test_prompt_choice_falls_back_to_numbered_input(monkeypatch):
|
||||
monkeypatch.setattr(setup_mod, "_curses_prompt_choice", lambda question, choices, default=0: -1)
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt="": "2")
|
||||
|
||||
idx = setup_mod.prompt_choice("Pick one", ["a", "b", "c"], default=0)
|
||||
|
||||
assert idx == 1
|
||||
|
||||
|
||||
def test_prompt_checklist_uses_shared_curses_checklist(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.curses_ui.curses_checklist",
|
||||
lambda title, items, selected, cancel_returns=None: {0, 2},
|
||||
)
|
||||
|
||||
selected = setup_mod.prompt_checklist("Pick tools", ["one", "two", "three"], pre_selected=[1])
|
||||
|
||||
assert selected == [0, 2]
|
||||
@@ -1,305 +0,0 @@
|
||||
"""Tests for cmd_update gateway auto-restart — systemd + launchd coverage.
|
||||
|
||||
Ensures ``hermes update`` correctly detects running gateways managed by
|
||||
systemd (Linux) or launchd (macOS) and restarts/informs the user properly,
|
||||
rather than leaving zombie processes or telling users to manually restart
|
||||
when launchd will auto-respawn.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import hermes_cli.gateway as gateway_cli
|
||||
from hermes_cli.main import cmd_update
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_run_side_effect(
|
||||
branch="main",
|
||||
verify_ok=True,
|
||||
commit_count="3",
|
||||
systemd_active=False,
|
||||
launchctl_loaded=False,
|
||||
):
|
||||
"""Build a subprocess.run side_effect that simulates git + service commands."""
|
||||
|
||||
def side_effect(cmd, **kwargs):
|
||||
joined = " ".join(str(c) for c in cmd)
|
||||
|
||||
# git rev-parse --abbrev-ref HEAD
|
||||
if "rev-parse" in joined and "--abbrev-ref" in joined:
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout=f"{branch}\n", stderr="")
|
||||
|
||||
# git rev-parse --verify origin/{branch}
|
||||
if "rev-parse" in joined and "--verify" in joined:
|
||||
rc = 0 if verify_ok else 128
|
||||
return subprocess.CompletedProcess(cmd, rc, stdout="", stderr="")
|
||||
|
||||
# git rev-list HEAD..origin/{branch} --count
|
||||
if "rev-list" in joined:
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout=f"{commit_count}\n", stderr="")
|
||||
|
||||
# systemctl --user is-active
|
||||
if "systemctl" in joined and "is-active" in joined:
|
||||
if systemd_active:
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout="active\n", stderr="")
|
||||
return subprocess.CompletedProcess(cmd, 3, stdout="inactive\n", stderr="")
|
||||
|
||||
# systemctl --user restart
|
||||
if "systemctl" in joined and "restart" in joined:
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout="", stderr="")
|
||||
|
||||
# launchctl list ai.hermes.gateway
|
||||
if "launchctl" in joined and "list" in joined:
|
||||
if launchctl_loaded:
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout="PID\tStatus\tLabel\n123\t0\tai.hermes.gateway\n", stderr="")
|
||||
return subprocess.CompletedProcess(cmd, 113, stdout="", stderr="Could not find service")
|
||||
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout="", stderr="")
|
||||
|
||||
return side_effect
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_args():
|
||||
return SimpleNamespace()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Launchd plist includes --replace
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLaunchdPlistReplace:
|
||||
"""The generated launchd plist must include --replace so respawned
|
||||
gateways kill stale instances."""
|
||||
|
||||
def test_plist_contains_replace_flag(self):
|
||||
plist = gateway_cli.generate_launchd_plist()
|
||||
assert "--replace" in plist
|
||||
|
||||
def test_plist_program_arguments_order(self):
|
||||
"""--replace comes after 'run' in the ProgramArguments."""
|
||||
plist = gateway_cli.generate_launchd_plist()
|
||||
lines = [line.strip() for line in plist.splitlines()]
|
||||
# Find 'run' and '--replace' in the string entries
|
||||
string_values = [
|
||||
line.replace("<string>", "").replace("</string>", "")
|
||||
for line in lines
|
||||
if "<string>" in line and "</string>" in line
|
||||
]
|
||||
assert "run" in string_values
|
||||
assert "--replace" in string_values
|
||||
run_idx = string_values.index("run")
|
||||
replace_idx = string_values.index("--replace")
|
||||
assert replace_idx == run_idx + 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cmd_update — macOS launchd detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLaunchdPlistRefresh:
|
||||
"""refresh_launchd_plist_if_needed rewrites stale plists (like systemd's
|
||||
refresh_systemd_unit_if_needed)."""
|
||||
|
||||
def test_refresh_rewrites_stale_plist(self, tmp_path, monkeypatch):
|
||||
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
||||
plist_path.write_text("<plist>old content</plist>")
|
||||
|
||||
monkeypatch.setattr(gateway_cli, "get_launchd_plist_path", lambda: plist_path)
|
||||
|
||||
calls = []
|
||||
def fake_run(cmd, check=False, **kwargs):
|
||||
calls.append(cmd)
|
||||
return SimpleNamespace(returncode=0, stdout="", stderr="")
|
||||
|
||||
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)
|
||||
|
||||
result = gateway_cli.refresh_launchd_plist_if_needed()
|
||||
|
||||
assert result is True
|
||||
# Plist should now contain the generated content (which includes --replace)
|
||||
assert "--replace" in plist_path.read_text()
|
||||
# Should have unloaded then reloaded
|
||||
assert any("unload" in str(c) for c in calls)
|
||||
assert any("load" in str(c) for c in calls)
|
||||
|
||||
def test_refresh_skips_when_current(self, tmp_path, monkeypatch):
|
||||
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
||||
monkeypatch.setattr(gateway_cli, "get_launchd_plist_path", lambda: plist_path)
|
||||
|
||||
# Write the current expected content
|
||||
plist_path.write_text(gateway_cli.generate_launchd_plist())
|
||||
|
||||
calls = []
|
||||
monkeypatch.setattr(
|
||||
gateway_cli.subprocess, "run",
|
||||
lambda cmd, **kw: calls.append(cmd) or SimpleNamespace(returncode=0),
|
||||
)
|
||||
|
||||
result = gateway_cli.refresh_launchd_plist_if_needed()
|
||||
|
||||
assert result is False
|
||||
assert len(calls) == 0 # No launchctl calls needed
|
||||
|
||||
def test_refresh_skips_when_no_plist(self, tmp_path, monkeypatch):
|
||||
plist_path = tmp_path / "nonexistent.plist"
|
||||
monkeypatch.setattr(gateway_cli, "get_launchd_plist_path", lambda: plist_path)
|
||||
|
||||
result = gateway_cli.refresh_launchd_plist_if_needed()
|
||||
assert result is False
|
||||
|
||||
def test_launchd_start_calls_refresh(self, tmp_path, monkeypatch):
|
||||
"""launchd_start refreshes the plist before starting."""
|
||||
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
||||
plist_path.write_text("<plist>old</plist>")
|
||||
monkeypatch.setattr(gateway_cli, "get_launchd_plist_path", lambda: plist_path)
|
||||
|
||||
calls = []
|
||||
def fake_run(cmd, check=False, **kwargs):
|
||||
calls.append(cmd)
|
||||
return SimpleNamespace(returncode=0, stdout="", stderr="")
|
||||
|
||||
monkeypatch.setattr(gateway_cli.subprocess, "run", fake_run)
|
||||
|
||||
gateway_cli.launchd_start()
|
||||
|
||||
# First calls should be refresh (unload/load), then start
|
||||
cmd_strs = [" ".join(c) for c in calls]
|
||||
assert any("unload" in s for s in cmd_strs)
|
||||
assert any("start" in s for s in cmd_strs)
|
||||
|
||||
|
||||
class TestCmdUpdateLaunchdRestart:
|
||||
"""cmd_update correctly detects and handles launchd on macOS."""
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
@patch("subprocess.run")
|
||||
def test_update_detects_launchd_and_skips_manual_restart_message(
|
||||
self, mock_run, _mock_which, mock_args, capsys, tmp_path, monkeypatch,
|
||||
):
|
||||
"""When launchd is running the gateway, update should print
|
||||
'auto-restart via launchd' instead of 'Restart it with: hermes gateway run'."""
|
||||
# Create a fake launchd plist so is_macos + plist.exists() passes
|
||||
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
||||
plist_path.write_text("<plist/>")
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "is_macos", lambda: True,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "get_launchd_plist_path", lambda: plist_path,
|
||||
)
|
||||
|
||||
mock_run.side_effect = _make_run_side_effect(
|
||||
commit_count="3",
|
||||
launchctl_loaded=True,
|
||||
)
|
||||
|
||||
# Mock get_running_pid to return a PID
|
||||
with patch("gateway.status.get_running_pid", return_value=12345), \
|
||||
patch("gateway.status.remove_pid_file"):
|
||||
cmd_update(mock_args)
|
||||
|
||||
captured = capsys.readouterr().out
|
||||
assert "Gateway restarted via launchd" in captured
|
||||
assert "Restart it with: hermes gateway run" not in captured
|
||||
# Verify launchctl stop + start were called (not manual SIGTERM)
|
||||
launchctl_calls = [
|
||||
c for c in mock_run.call_args_list
|
||||
if len(c.args[0]) > 0 and c.args[0][0] == "launchctl"
|
||||
]
|
||||
stop_calls = [c for c in launchctl_calls if "stop" in c.args[0]]
|
||||
start_calls = [c for c in launchctl_calls if "start" in c.args[0]]
|
||||
assert len(stop_calls) >= 1
|
||||
assert len(start_calls) >= 1
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
@patch("subprocess.run")
|
||||
def test_update_without_launchd_shows_manual_restart(
|
||||
self, mock_run, _mock_which, mock_args, capsys, tmp_path, monkeypatch,
|
||||
):
|
||||
"""When no service manager is running, update should show the manual restart hint."""
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "is_macos", lambda: True,
|
||||
)
|
||||
plist_path = tmp_path / "ai.hermes.gateway.plist"
|
||||
# plist does NOT exist — no launchd service
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "get_launchd_plist_path", lambda: plist_path,
|
||||
)
|
||||
|
||||
mock_run.side_effect = _make_run_side_effect(
|
||||
commit_count="3",
|
||||
launchctl_loaded=False,
|
||||
)
|
||||
|
||||
with patch("gateway.status.get_running_pid", return_value=12345), \
|
||||
patch("gateway.status.remove_pid_file"), \
|
||||
patch("os.kill"):
|
||||
cmd_update(mock_args)
|
||||
|
||||
captured = capsys.readouterr().out
|
||||
assert "Restart it with: hermes gateway run" in captured
|
||||
assert "Gateway restarted via launchd" not in captured
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
@patch("subprocess.run")
|
||||
def test_update_with_systemd_still_restarts_via_systemd(
|
||||
self, mock_run, _mock_which, mock_args, capsys, monkeypatch,
|
||||
):
|
||||
"""On Linux with systemd active, update should restart via systemctl."""
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "is_macos", lambda: False,
|
||||
)
|
||||
|
||||
mock_run.side_effect = _make_run_side_effect(
|
||||
commit_count="3",
|
||||
systemd_active=True,
|
||||
)
|
||||
|
||||
with patch("gateway.status.get_running_pid", return_value=12345), \
|
||||
patch("gateway.status.remove_pid_file"), \
|
||||
patch("os.kill"):
|
||||
cmd_update(mock_args)
|
||||
|
||||
captured = capsys.readouterr().out
|
||||
assert "Gateway restarted" in captured
|
||||
# Verify systemctl restart was called
|
||||
restart_calls = [
|
||||
c for c in mock_run.call_args_list
|
||||
if "restart" in " ".join(str(a) for a in c.args[0])
|
||||
and "systemctl" in " ".join(str(a) for a in c.args[0])
|
||||
]
|
||||
assert len(restart_calls) == 1
|
||||
|
||||
@patch("shutil.which", return_value=None)
|
||||
@patch("subprocess.run")
|
||||
def test_update_no_gateway_running_skips_restart(
|
||||
self, mock_run, _mock_which, mock_args, capsys, monkeypatch,
|
||||
):
|
||||
"""When no gateway is running, update should skip the restart section entirely."""
|
||||
monkeypatch.setattr(
|
||||
gateway_cli, "is_macos", lambda: False,
|
||||
)
|
||||
|
||||
mock_run.side_effect = _make_run_side_effect(
|
||||
commit_count="3",
|
||||
systemd_active=False,
|
||||
)
|
||||
|
||||
with patch("gateway.status.get_running_pid", return_value=None):
|
||||
cmd_update(mock_args)
|
||||
|
||||
captured = capsys.readouterr().out
|
||||
assert "Stopped gateway" not in captured
|
||||
assert "Gateway restarted" not in captured
|
||||
assert "Gateway restarted via launchd" not in captured
|
||||
+105
-110
@@ -16,131 +16,126 @@ from run_agent import AIAgent, IterationBudget
|
||||
from tools.delegate_tool import _run_single_child
|
||||
from tools.interrupt import set_interrupt, is_interrupted
|
||||
|
||||
def main() -> int:
|
||||
set_interrupt(False)
|
||||
set_interrupt(False)
|
||||
|
||||
# Create parent agent (minimal)
|
||||
parent = AIAgent.__new__(AIAgent)
|
||||
parent._interrupt_requested = False
|
||||
parent._interrupt_message = None
|
||||
parent._active_children = []
|
||||
parent.quiet_mode = True
|
||||
parent.model = "test/model"
|
||||
parent.base_url = "http://localhost:1"
|
||||
parent.api_key = "test"
|
||||
parent.provider = "test"
|
||||
parent.api_mode = "chat_completions"
|
||||
parent.platform = "cli"
|
||||
parent.enabled_toolsets = ["terminal", "file"]
|
||||
parent.providers_allowed = None
|
||||
parent.providers_ignored = None
|
||||
parent.providers_order = None
|
||||
parent.provider_sort = None
|
||||
parent.max_tokens = None
|
||||
parent.reasoning_config = None
|
||||
parent.prefill_messages = None
|
||||
parent._session_db = None
|
||||
parent._delegate_depth = 0
|
||||
parent._delegate_spinner = None
|
||||
parent.tool_progress_callback = None
|
||||
parent.iteration_budget = IterationBudget(max_total=100)
|
||||
parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"}
|
||||
# Create parent agent (minimal)
|
||||
parent = AIAgent.__new__(AIAgent)
|
||||
parent._interrupt_requested = False
|
||||
parent._interrupt_message = None
|
||||
parent._active_children = []
|
||||
parent.quiet_mode = True
|
||||
parent.model = "test/model"
|
||||
parent.base_url = "http://localhost:1"
|
||||
parent.api_key = "test"
|
||||
parent.provider = "test"
|
||||
parent.api_mode = "chat_completions"
|
||||
parent.platform = "cli"
|
||||
parent.enabled_toolsets = ["terminal", "file"]
|
||||
parent.providers_allowed = None
|
||||
parent.providers_ignored = None
|
||||
parent.providers_order = None
|
||||
parent.provider_sort = None
|
||||
parent.max_tokens = None
|
||||
parent.reasoning_config = None
|
||||
parent.prefill_messages = None
|
||||
parent._session_db = None
|
||||
parent._delegate_depth = 0
|
||||
parent._delegate_spinner = None
|
||||
parent.tool_progress_callback = None
|
||||
parent.iteration_budget = IterationBudget(max_total=100)
|
||||
parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"}
|
||||
|
||||
child_started = threading.Event()
|
||||
result_holder = [None]
|
||||
child_started = threading.Event()
|
||||
result_holder = [None]
|
||||
|
||||
def run_delegate():
|
||||
with patch("run_agent.OpenAI") as MockOpenAI:
|
||||
mock_client = MagicMock()
|
||||
|
||||
def slow_create(**kwargs):
|
||||
time.sleep(3)
|
||||
resp = MagicMock()
|
||||
resp.choices = [MagicMock()]
|
||||
resp.choices[0].message.content = "Done"
|
||||
resp.choices[0].message.tool_calls = None
|
||||
resp.choices[0].message.refusal = None
|
||||
resp.choices[0].finish_reason = "stop"
|
||||
resp.usage.prompt_tokens = 100
|
||||
resp.usage.completion_tokens = 10
|
||||
resp.usage.total_tokens = 110
|
||||
resp.usage.prompt_tokens_details = None
|
||||
return resp
|
||||
def run_delegate():
|
||||
with patch("run_agent.OpenAI") as MockOpenAI:
|
||||
mock_client = MagicMock()
|
||||
|
||||
mock_client.chat.completions.create = slow_create
|
||||
mock_client.close = MagicMock()
|
||||
MockOpenAI.return_value = mock_client
|
||||
def slow_create(**kwargs):
|
||||
time.sleep(3)
|
||||
resp = MagicMock()
|
||||
resp.choices = [MagicMock()]
|
||||
resp.choices[0].message.content = "Done"
|
||||
resp.choices[0].message.tool_calls = None
|
||||
resp.choices[0].message.refusal = None
|
||||
resp.choices[0].finish_reason = "stop"
|
||||
resp.usage.prompt_tokens = 100
|
||||
resp.usage.completion_tokens = 10
|
||||
resp.usage.total_tokens = 110
|
||||
resp.usage.prompt_tokens_details = None
|
||||
return resp
|
||||
|
||||
original_init = AIAgent.__init__
|
||||
mock_client.chat.completions.create = slow_create
|
||||
mock_client.close = MagicMock()
|
||||
MockOpenAI.return_value = mock_client
|
||||
|
||||
def patched_init(self_agent, *a, **kw):
|
||||
original_init(self_agent, *a, **kw)
|
||||
child_started.set()
|
||||
original_init = AIAgent.__init__
|
||||
|
||||
with patch.object(AIAgent, "__init__", patched_init):
|
||||
try:
|
||||
result = _run_single_child(
|
||||
task_index=0,
|
||||
goal="Test slow task",
|
||||
context=None,
|
||||
toolsets=["terminal"],
|
||||
model="test/model",
|
||||
max_iterations=5,
|
||||
parent_agent=parent,
|
||||
task_count=1,
|
||||
override_provider="test",
|
||||
override_base_url="http://localhost:1",
|
||||
override_api_key="test",
|
||||
override_api_mode="chat_completions",
|
||||
)
|
||||
result_holder[0] = result
|
||||
except Exception as e:
|
||||
print(f"ERROR in delegate: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
def patched_init(self_agent, *a, **kw):
|
||||
original_init(self_agent, *a, **kw)
|
||||
child_started.set()
|
||||
|
||||
print("Starting agent thread...")
|
||||
agent_thread = threading.Thread(target=run_delegate, daemon=True)
|
||||
agent_thread.start()
|
||||
with patch.object(AIAgent, "__init__", patched_init):
|
||||
try:
|
||||
result = _run_single_child(
|
||||
task_index=0,
|
||||
goal="Test slow task",
|
||||
context=None,
|
||||
toolsets=["terminal"],
|
||||
model="test/model",
|
||||
max_iterations=5,
|
||||
parent_agent=parent,
|
||||
task_count=1,
|
||||
override_provider="test",
|
||||
override_base_url="http://localhost:1",
|
||||
override_api_key="test",
|
||||
override_api_mode="chat_completions",
|
||||
)
|
||||
result_holder[0] = result
|
||||
except Exception as e:
|
||||
print(f"ERROR in delegate: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
started = child_started.wait(timeout=10)
|
||||
if not started:
|
||||
print("ERROR: Child never started")
|
||||
set_interrupt(False)
|
||||
return 1
|
||||
|
||||
time.sleep(0.5)
|
||||
print("Starting agent thread...")
|
||||
agent_thread = threading.Thread(target=run_delegate, daemon=True)
|
||||
agent_thread.start()
|
||||
|
||||
print(f"Active children: {len(parent._active_children)}")
|
||||
for i, c in enumerate(parent._active_children):
|
||||
print(f" Child {i}: _interrupt_requested={c._interrupt_requested}")
|
||||
started = child_started.wait(timeout=10)
|
||||
if not started:
|
||||
print("ERROR: Child never started")
|
||||
sys.exit(1)
|
||||
|
||||
t0 = time.monotonic()
|
||||
parent.interrupt("User typed a new message")
|
||||
print("Called parent.interrupt()")
|
||||
time.sleep(0.5)
|
||||
|
||||
for i, c in enumerate(parent._active_children):
|
||||
print(f" Child {i} after interrupt: _interrupt_requested={c._interrupt_requested}")
|
||||
print(f"Global is_interrupted: {is_interrupted()}")
|
||||
print(f"Active children: {len(parent._active_children)}")
|
||||
for i, c in enumerate(parent._active_children):
|
||||
print(f" Child {i}: _interrupt_requested={c._interrupt_requested}")
|
||||
|
||||
agent_thread.join(timeout=10)
|
||||
elapsed = time.monotonic() - t0
|
||||
print(f"Agent thread finished in {elapsed:.2f}s")
|
||||
t0 = time.monotonic()
|
||||
parent.interrupt("User typed a new message")
|
||||
print(f"Called parent.interrupt()")
|
||||
|
||||
result = result_holder[0]
|
||||
if result:
|
||||
print(f"Status: {result['status']}")
|
||||
print(f"Duration: {result['duration_seconds']}s")
|
||||
if elapsed < 2.0:
|
||||
print("✅ PASS: Interrupt detected quickly!")
|
||||
else:
|
||||
print(f"❌ FAIL: Took {elapsed:.2f}s — interrupt was too slow or not detected")
|
||||
for i, c in enumerate(parent._active_children):
|
||||
print(f" Child {i} after interrupt: _interrupt_requested={c._interrupt_requested}")
|
||||
print(f"Global is_interrupted: {is_interrupted()}")
|
||||
|
||||
agent_thread.join(timeout=10)
|
||||
elapsed = time.monotonic() - t0
|
||||
print(f"Agent thread finished in {elapsed:.2f}s")
|
||||
|
||||
result = result_holder[0]
|
||||
if result:
|
||||
print(f"Status: {result['status']}")
|
||||
print(f"Duration: {result['duration_seconds']}s")
|
||||
if elapsed < 2.0:
|
||||
print("✅ PASS: Interrupt detected quickly!")
|
||||
else:
|
||||
print("❌ FAIL: No result!")
|
||||
print(f"❌ FAIL: Took {elapsed:.2f}s — interrupt was too slow or not detected")
|
||||
else:
|
||||
print("❌ FAIL: No result!")
|
||||
|
||||
set_interrupt(False)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
set_interrupt(False)
|
||||
|
||||
@@ -1,480 +0,0 @@
|
||||
"""Tests for Anthropic error handling in the agent retry loop.
|
||||
|
||||
Covers all error paths in run_agent.py's run_conversation() for api_mode=anthropic_messages:
|
||||
- 429 rate limit → retried with backoff
|
||||
- 529 overloaded → retried with backoff
|
||||
- 400 bad request → non-retryable, immediate fail
|
||||
- 401 unauthorized → credential refresh + retry
|
||||
- 500 server error → retried with backoff
|
||||
- "prompt is too long" → context length error triggers compression
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
sys.modules.setdefault("fire", types.SimpleNamespace(Fire=lambda *a, **k: None))
|
||||
sys.modules.setdefault("firecrawl", types.SimpleNamespace(Firecrawl=object))
|
||||
sys.modules.setdefault("fal_client", types.SimpleNamespace())
|
||||
|
||||
import gateway.run as gateway_run
|
||||
import run_agent
|
||||
from gateway.config import Platform
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _patch_agent_bootstrap(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
run_agent,
|
||||
"get_tool_definitions",
|
||||
lambda **kwargs: [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "terminal",
|
||||
"description": "Run shell commands.",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(run_agent, "check_toolset_requirements", lambda: {})
|
||||
|
||||
|
||||
def _anthropic_response(text: str):
|
||||
"""Simulate an Anthropic messages.create() response object."""
|
||||
return SimpleNamespace(
|
||||
content=[SimpleNamespace(type="text", text=text)],
|
||||
stop_reason="end_turn",
|
||||
usage=SimpleNamespace(input_tokens=10, output_tokens=5),
|
||||
model="claude-sonnet-4-6-20250514",
|
||||
)
|
||||
|
||||
|
||||
class _RateLimitError(Exception):
|
||||
"""Simulates Anthropic 429 rate limit error."""
|
||||
def __init__(self):
|
||||
super().__init__("Error code: 429 - Rate limit exceeded. Please retry after 30s.")
|
||||
self.status_code = 429
|
||||
|
||||
|
||||
class _OverloadedError(Exception):
|
||||
"""Simulates Anthropic 529 overloaded error."""
|
||||
def __init__(self):
|
||||
super().__init__("Error code: 529 - API is temporarily overloaded.")
|
||||
self.status_code = 529
|
||||
|
||||
|
||||
class _BadRequestError(Exception):
|
||||
"""Simulates Anthropic 400 bad request error (non-retryable)."""
|
||||
def __init__(self):
|
||||
super().__init__("Error code: 400 - Invalid model specified.")
|
||||
self.status_code = 400
|
||||
|
||||
|
||||
class _UnauthorizedError(Exception):
|
||||
"""Simulates Anthropic 401 unauthorized error."""
|
||||
def __init__(self):
|
||||
super().__init__("Error code: 401 - Unauthorized. Invalid API key.")
|
||||
self.status_code = 401
|
||||
|
||||
|
||||
class _ServerError(Exception):
|
||||
"""Simulates Anthropic 500 internal server error."""
|
||||
def __init__(self):
|
||||
super().__init__("Error code: 500 - Internal server error.")
|
||||
self.status_code = 500
|
||||
|
||||
|
||||
class _PromptTooLongError(Exception):
|
||||
"""Simulates Anthropic prompt-too-long error (triggers context compression)."""
|
||||
def __init__(self):
|
||||
super().__init__("prompt is too long: 250000 tokens > 200000 maximum")
|
||||
self.status_code = 400
|
||||
|
||||
|
||||
class _FakeAnthropicClient:
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
|
||||
def _fake_build_anthropic_client(key, base_url=None):
|
||||
return _FakeAnthropicClient()
|
||||
|
||||
|
||||
def _make_agent_cls(error_cls, recover_after=None):
|
||||
"""Create an AIAgent subclass that raises error_cls on API calls.
|
||||
|
||||
If recover_after is set, the agent succeeds after that many failures.
|
||||
"""
|
||||
|
||||
class _Agent(run_agent.AIAgent):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.setdefault("skip_context_files", True)
|
||||
kwargs.setdefault("skip_memory", True)
|
||||
kwargs.setdefault("max_iterations", 4)
|
||||
super().__init__(*args, **kwargs)
|
||||
self._cleanup_task_resources = lambda task_id: None
|
||||
self._persist_session = lambda messages, history=None: None
|
||||
self._save_trajectory = lambda messages, user_message, completed: None
|
||||
self._save_session_log = lambda messages: None
|
||||
|
||||
def run_conversation(self, user_message, conversation_history=None, task_id=None):
|
||||
calls = {"n": 0}
|
||||
|
||||
def _fake_api_call(api_kwargs):
|
||||
calls["n"] += 1
|
||||
if recover_after is not None and calls["n"] > recover_after:
|
||||
return _anthropic_response("Recovered")
|
||||
raise error_cls()
|
||||
|
||||
self._interruptible_api_call = _fake_api_call
|
||||
return super().run_conversation(
|
||||
user_message, conversation_history=conversation_history, task_id=task_id
|
||||
)
|
||||
|
||||
return _Agent
|
||||
|
||||
|
||||
def _run_with_agent(monkeypatch, agent_cls):
|
||||
"""Run _run_agent through the gateway with the given agent class."""
|
||||
_patch_agent_bootstrap(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.build_anthropic_client", _fake_build_anthropic_client
|
||||
)
|
||||
monkeypatch.setattr(run_agent, "AIAgent", agent_cls)
|
||||
monkeypatch.setattr(
|
||||
gateway_run,
|
||||
"_resolve_runtime_agent_kwargs",
|
||||
lambda: {
|
||||
"provider": "anthropic",
|
||||
"api_mode": "anthropic_messages",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
"api_key": "sk-ant-api03-test-key",
|
||||
},
|
||||
)
|
||||
monkeypatch.setenv("HERMES_TOOL_PROGRESS", "false")
|
||||
|
||||
runner = gateway_run.GatewayRunner.__new__(gateway_run.GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._ephemeral_system_prompt = ""
|
||||
runner._prefill_messages = []
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._running_agents = {}
|
||||
runner.hooks = MagicMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
runner.hooks.loaded_hooks = []
|
||||
runner._session_db = None
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.LOCAL,
|
||||
chat_id="cli",
|
||||
chat_name="CLI",
|
||||
chat_type="dm",
|
||||
user_id="test-user-1",
|
||||
)
|
||||
|
||||
return asyncio.run(
|
||||
runner._run_agent(
|
||||
message="hello",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id="test-session",
|
||||
session_key="agent:main:local:dm",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_429_rate_limit_is_retried_and_recovers(monkeypatch):
|
||||
"""429 should be retried with backoff. First call fails, second succeeds."""
|
||||
agent_cls = _make_agent_cls(_RateLimitError, recover_after=1)
|
||||
result = _run_with_agent(monkeypatch, agent_cls)
|
||||
assert result["final_response"] == "Recovered"
|
||||
|
||||
|
||||
def test_529_overloaded_is_retried_and_recovers(monkeypatch):
|
||||
"""529 should be retried with backoff. First call fails, second succeeds."""
|
||||
agent_cls = _make_agent_cls(_OverloadedError, recover_after=1)
|
||||
result = _run_with_agent(monkeypatch, agent_cls)
|
||||
assert result["final_response"] == "Recovered"
|
||||
|
||||
|
||||
def test_429_exhausts_all_retries_before_raising(monkeypatch):
|
||||
"""429 must retry max_retries times, not abort on first attempt."""
|
||||
agent_cls = _make_agent_cls(_RateLimitError) # always fails
|
||||
with pytest.raises(_RateLimitError):
|
||||
_run_with_agent(monkeypatch, agent_cls)
|
||||
|
||||
|
||||
def test_400_bad_request_is_non_retryable(monkeypatch):
|
||||
"""400 should fail immediately with only 1 API call (regression guard)."""
|
||||
agent_cls = _make_agent_cls(_BadRequestError)
|
||||
result = _run_with_agent(monkeypatch, agent_cls)
|
||||
assert result["api_calls"] == 1
|
||||
assert "400" in str(result.get("final_response", ""))
|
||||
|
||||
|
||||
def test_500_server_error_is_retried_and_recovers(monkeypatch):
|
||||
"""500 should be retried with backoff. First call fails, second succeeds."""
|
||||
agent_cls = _make_agent_cls(_ServerError, recover_after=1)
|
||||
result = _run_with_agent(monkeypatch, agent_cls)
|
||||
assert result["final_response"] == "Recovered"
|
||||
|
||||
|
||||
def test_401_credential_refresh_recovers(monkeypatch):
|
||||
"""401 should trigger credential refresh and retry once."""
|
||||
_patch_agent_bootstrap(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.build_anthropic_client", _fake_build_anthropic_client
|
||||
)
|
||||
monkeypatch.setenv("HERMES_TOOL_PROGRESS", "false")
|
||||
|
||||
refresh_count = {"n": 0}
|
||||
|
||||
class _Auth401ThenSuccessAgent(run_agent.AIAgent):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.setdefault("skip_context_files", True)
|
||||
kwargs.setdefault("skip_memory", True)
|
||||
kwargs.setdefault("max_iterations", 4)
|
||||
super().__init__(*args, **kwargs)
|
||||
self._cleanup_task_resources = lambda task_id: None
|
||||
self._persist_session = lambda messages, history=None: None
|
||||
self._save_trajectory = lambda messages, user_message, completed: None
|
||||
self._save_session_log = lambda messages: None
|
||||
|
||||
def _try_refresh_anthropic_client_credentials(self) -> bool:
|
||||
refresh_count["n"] += 1
|
||||
return True # Simulate successful credential refresh
|
||||
|
||||
def run_conversation(self, user_message, conversation_history=None, task_id=None):
|
||||
calls = {"n": 0}
|
||||
|
||||
def _fake_api_call(api_kwargs):
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
raise _UnauthorizedError()
|
||||
return _anthropic_response("Auth refreshed")
|
||||
|
||||
self._interruptible_api_call = _fake_api_call
|
||||
return super().run_conversation(
|
||||
user_message, conversation_history=conversation_history, task_id=task_id
|
||||
)
|
||||
|
||||
monkeypatch.setattr(run_agent, "AIAgent", _Auth401ThenSuccessAgent)
|
||||
monkeypatch.setattr(
|
||||
gateway_run,
|
||||
"_resolve_runtime_agent_kwargs",
|
||||
lambda: {
|
||||
"provider": "anthropic",
|
||||
"api_mode": "anthropic_messages",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
"api_key": "sk-ant-api03-test-key",
|
||||
},
|
||||
)
|
||||
|
||||
runner = gateway_run.GatewayRunner.__new__(gateway_run.GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._ephemeral_system_prompt = ""
|
||||
runner._prefill_messages = []
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._running_agents = {}
|
||||
runner.hooks = MagicMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
runner.hooks.loaded_hooks = []
|
||||
runner._session_db = None
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.LOCAL, chat_id="cli", chat_name="CLI",
|
||||
chat_type="dm", user_id="test-user-1",
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
runner._run_agent(
|
||||
message="hello", context_prompt="", history=[],
|
||||
source=source, session_id="session-401",
|
||||
session_key="agent:main:local:dm",
|
||||
)
|
||||
)
|
||||
|
||||
assert result["final_response"] == "Auth refreshed"
|
||||
assert refresh_count["n"] == 1
|
||||
|
||||
|
||||
def test_401_refresh_fails_is_non_retryable(monkeypatch):
|
||||
"""401 with failed credential refresh should be treated as non-retryable."""
|
||||
_patch_agent_bootstrap(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.build_anthropic_client", _fake_build_anthropic_client
|
||||
)
|
||||
monkeypatch.setenv("HERMES_TOOL_PROGRESS", "false")
|
||||
|
||||
class _Auth401AlwaysFailAgent(run_agent.AIAgent):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.setdefault("skip_context_files", True)
|
||||
kwargs.setdefault("skip_memory", True)
|
||||
kwargs.setdefault("max_iterations", 4)
|
||||
super().__init__(*args, **kwargs)
|
||||
self._cleanup_task_resources = lambda task_id: None
|
||||
self._persist_session = lambda messages, history=None: None
|
||||
self._save_trajectory = lambda messages, user_message, completed: None
|
||||
self._save_session_log = lambda messages: None
|
||||
|
||||
def _try_refresh_anthropic_client_credentials(self) -> bool:
|
||||
return False # Simulate failed credential refresh
|
||||
|
||||
def run_conversation(self, user_message, conversation_history=None, task_id=None):
|
||||
def _fake_api_call(api_kwargs):
|
||||
raise _UnauthorizedError()
|
||||
|
||||
self._interruptible_api_call = _fake_api_call
|
||||
return super().run_conversation(
|
||||
user_message, conversation_history=conversation_history, task_id=task_id
|
||||
)
|
||||
|
||||
monkeypatch.setattr(run_agent, "AIAgent", _Auth401AlwaysFailAgent)
|
||||
monkeypatch.setattr(
|
||||
gateway_run,
|
||||
"_resolve_runtime_agent_kwargs",
|
||||
lambda: {
|
||||
"provider": "anthropic",
|
||||
"api_mode": "anthropic_messages",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
"api_key": "sk-ant-api03-test-key",
|
||||
},
|
||||
)
|
||||
|
||||
runner = gateway_run.GatewayRunner.__new__(gateway_run.GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._ephemeral_system_prompt = ""
|
||||
runner._prefill_messages = []
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._running_agents = {}
|
||||
runner.hooks = MagicMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
runner.hooks.loaded_hooks = []
|
||||
runner._session_db = None
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.LOCAL, chat_id="cli", chat_name="CLI",
|
||||
chat_type="dm", user_id="test-user-1",
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
runner._run_agent(
|
||||
message="hello", context_prompt="", history=[],
|
||||
source=source, session_id="session-401-fail",
|
||||
session_key="agent:main:local:dm",
|
||||
)
|
||||
)
|
||||
|
||||
# 401 after failed refresh → non-retryable (falls through to is_client_error)
|
||||
assert result["api_calls"] == 1
|
||||
assert "401" in str(result.get("final_response", "")) or "unauthorized" in str(result.get("final_response", "")).lower()
|
||||
|
||||
|
||||
def test_prompt_too_long_triggers_compression(monkeypatch):
|
||||
"""Anthropic 'prompt is too long' error should trigger context compression, not immediate fail."""
|
||||
_patch_agent_bootstrap(monkeypatch)
|
||||
monkeypatch.setattr(
|
||||
"agent.anthropic_adapter.build_anthropic_client", _fake_build_anthropic_client
|
||||
)
|
||||
monkeypatch.setenv("HERMES_TOOL_PROGRESS", "false")
|
||||
|
||||
class _PromptTooLongThenSuccessAgent(run_agent.AIAgent):
|
||||
compress_called = 0
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.setdefault("skip_context_files", True)
|
||||
kwargs.setdefault("skip_memory", True)
|
||||
kwargs.setdefault("max_iterations", 4)
|
||||
super().__init__(*args, **kwargs)
|
||||
self._cleanup_task_resources = lambda task_id: None
|
||||
self._persist_session = lambda messages, history=None: None
|
||||
self._save_trajectory = lambda messages, user_message, completed: None
|
||||
self._save_session_log = lambda messages: None
|
||||
|
||||
def _compress_context(self, messages, system_message, approx_tokens=0, task_id=None):
|
||||
type(self).compress_called += 1
|
||||
# Simulate compression by dropping oldest non-system message
|
||||
if len(messages) > 2:
|
||||
compressed = [messages[0]] + messages[2:]
|
||||
else:
|
||||
compressed = messages
|
||||
return compressed, system_message
|
||||
|
||||
def run_conversation(self, user_message, conversation_history=None, task_id=None):
|
||||
calls = {"n": 0}
|
||||
|
||||
def _fake_api_call(api_kwargs):
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
raise _PromptTooLongError()
|
||||
return _anthropic_response("Compressed and recovered")
|
||||
|
||||
self._interruptible_api_call = _fake_api_call
|
||||
return super().run_conversation(
|
||||
user_message, conversation_history=conversation_history, task_id=task_id
|
||||
)
|
||||
|
||||
_PromptTooLongThenSuccessAgent.compress_called = 0
|
||||
monkeypatch.setattr(run_agent, "AIAgent", _PromptTooLongThenSuccessAgent)
|
||||
monkeypatch.setattr(
|
||||
gateway_run,
|
||||
"_resolve_runtime_agent_kwargs",
|
||||
lambda: {
|
||||
"provider": "anthropic",
|
||||
"api_mode": "anthropic_messages",
|
||||
"base_url": "https://api.anthropic.com",
|
||||
"api_key": "sk-ant-api03-test-key",
|
||||
},
|
||||
)
|
||||
|
||||
runner = gateway_run.GatewayRunner.__new__(gateway_run.GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._ephemeral_system_prompt = ""
|
||||
runner._prefill_messages = []
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._running_agents = {}
|
||||
runner.hooks = MagicMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
runner.hooks.loaded_hooks = []
|
||||
runner._session_db = None
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.LOCAL, chat_id="cli", chat_name="CLI",
|
||||
chat_type="dm", user_id="test-user-1",
|
||||
)
|
||||
|
||||
result = asyncio.run(
|
||||
runner._run_agent(
|
||||
message="hello", context_prompt="", history=[],
|
||||
source=source, session_id="session-prompt-long",
|
||||
session_key="agent:main:local:dm",
|
||||
)
|
||||
)
|
||||
|
||||
assert result["final_response"] == "Compressed and recovered"
|
||||
assert _PromptTooLongThenSuccessAgent.compress_called >= 1
|
||||
@@ -68,22 +68,6 @@ class TestAtomicJsonWrite:
|
||||
tmp_files = [f for f in tmp_path.iterdir() if ".tmp" in f.name]
|
||||
assert len(tmp_files) == 0
|
||||
|
||||
def test_cleans_up_temp_file_on_baseexception(self, tmp_path):
|
||||
class SimulatedAbort(BaseException):
|
||||
pass
|
||||
|
||||
target = tmp_path / "data.json"
|
||||
original = {"preserved": True}
|
||||
target.write_text(json.dumps(original), encoding="utf-8")
|
||||
|
||||
with patch("utils.json.dump", side_effect=SimulatedAbort):
|
||||
with pytest.raises(SimulatedAbort):
|
||||
atomic_json_write(target, {"new": True})
|
||||
|
||||
tmp_files = [f for f in tmp_path.iterdir() if ".tmp" in f.name]
|
||||
assert len(tmp_files) == 0
|
||||
assert json.loads(target.read_text(encoding="utf-8")) == original
|
||||
|
||||
def test_accepts_string_path(self, tmp_path):
|
||||
target = str(tmp_path / "string_path.json")
|
||||
atomic_json_write(target, {"string": True})
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
"""Tests for utils.atomic_yaml_write — crash-safe YAML file writes."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from utils import atomic_yaml_write
|
||||
|
||||
|
||||
class TestAtomicYamlWrite:
|
||||
def test_writes_valid_yaml(self, tmp_path):
|
||||
target = tmp_path / "data.yaml"
|
||||
data = {"key": "value", "nested": {"a": 1}}
|
||||
|
||||
atomic_yaml_write(target, data)
|
||||
|
||||
assert yaml.safe_load(target.read_text(encoding="utf-8")) == data
|
||||
|
||||
def test_cleans_up_temp_file_on_baseexception(self, tmp_path):
|
||||
class SimulatedAbort(BaseException):
|
||||
pass
|
||||
|
||||
target = tmp_path / "data.yaml"
|
||||
original = {"preserved": True}
|
||||
target.write_text(yaml.safe_dump(original), encoding="utf-8")
|
||||
|
||||
with patch("utils.yaml.dump", side_effect=SimulatedAbort):
|
||||
with pytest.raises(SimulatedAbort):
|
||||
atomic_yaml_write(target, {"new": True})
|
||||
|
||||
tmp_files = [f for f in tmp_path.iterdir() if ".tmp" in f.name]
|
||||
assert len(tmp_files) == 0
|
||||
assert yaml.safe_load(target.read_text(encoding="utf-8")) == original
|
||||
|
||||
def test_appends_extra_content(self, tmp_path):
|
||||
target = tmp_path / "data.yaml"
|
||||
|
||||
atomic_yaml_write(target, {"key": "value"}, extra_content="\n# comment\n")
|
||||
|
||||
text = target.read_text(encoding="utf-8")
|
||||
assert "key: value" in text
|
||||
assert "# comment" in text
|
||||
@@ -1,103 +0,0 @@
|
||||
"""Tests for automatic MCP reload when config.yaml mcp_servers section changes."""
|
||||
import time
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def _make_cli(tmp_path, mcp_servers=None):
|
||||
"""Create a minimal HermesCLI instance with mocked config."""
|
||||
import cli as cli_mod
|
||||
obj = object.__new__(cli_mod.HermesCLI)
|
||||
obj.config = {"mcp_servers": mcp_servers or {}}
|
||||
obj._agent_running = False
|
||||
obj._last_config_check = 0.0
|
||||
obj._config_mcp_servers = mcp_servers or {}
|
||||
|
||||
cfg_file = tmp_path / "config.yaml"
|
||||
cfg_file.write_text("mcp_servers: {}\n")
|
||||
obj._config_mtime = cfg_file.stat().st_mtime
|
||||
|
||||
obj._reload_mcp = MagicMock()
|
||||
obj._busy_command = MagicMock()
|
||||
obj._busy_command.return_value.__enter__ = MagicMock(return_value=None)
|
||||
obj._busy_command.return_value.__exit__ = MagicMock(return_value=False)
|
||||
obj._slow_command_status = MagicMock(return_value="reloading...")
|
||||
|
||||
return obj, cfg_file
|
||||
|
||||
|
||||
class TestMCPConfigWatch:
|
||||
|
||||
def test_no_change_does_not_reload(self, tmp_path):
|
||||
"""If mtime and mcp_servers unchanged, _reload_mcp is NOT called."""
|
||||
obj, cfg_file = _make_cli(tmp_path)
|
||||
|
||||
with patch("hermes_cli.config.get_config_path", return_value=cfg_file):
|
||||
obj._check_config_mcp_changes()
|
||||
|
||||
obj._reload_mcp.assert_not_called()
|
||||
|
||||
def test_mtime_change_with_same_mcp_servers_does_not_reload(self, tmp_path):
|
||||
"""If file mtime changes but mcp_servers is identical, no reload."""
|
||||
import yaml
|
||||
obj, cfg_file = _make_cli(tmp_path, mcp_servers={"fs": {"command": "npx"}})
|
||||
|
||||
# Write same mcp_servers but touch the file
|
||||
cfg_file.write_text(yaml.dump({"mcp_servers": {"fs": {"command": "npx"}}}))
|
||||
# Force mtime to appear changed
|
||||
obj._config_mtime = 0.0
|
||||
|
||||
with patch("hermes_cli.config.get_config_path", return_value=cfg_file):
|
||||
obj._check_config_mcp_changes()
|
||||
|
||||
obj._reload_mcp.assert_not_called()
|
||||
|
||||
def test_new_mcp_server_triggers_reload(self, tmp_path):
|
||||
"""Adding a new MCP server to config triggers auto-reload."""
|
||||
import yaml
|
||||
obj, cfg_file = _make_cli(tmp_path, mcp_servers={})
|
||||
|
||||
# Simulate user adding a new MCP server to config.yaml
|
||||
cfg_file.write_text(yaml.dump({"mcp_servers": {"github": {"url": "https://mcp.github.com"}}}))
|
||||
obj._config_mtime = 0.0 # force stale mtime
|
||||
|
||||
with patch("hermes_cli.config.get_config_path", return_value=cfg_file):
|
||||
obj._check_config_mcp_changes()
|
||||
|
||||
obj._reload_mcp.assert_called_once()
|
||||
|
||||
def test_removed_mcp_server_triggers_reload(self, tmp_path):
|
||||
"""Removing an MCP server from config triggers auto-reload."""
|
||||
import yaml
|
||||
obj, cfg_file = _make_cli(tmp_path, mcp_servers={"github": {"url": "https://mcp.github.com"}})
|
||||
|
||||
# Simulate user removing the server
|
||||
cfg_file.write_text(yaml.dump({"mcp_servers": {}}))
|
||||
obj._config_mtime = 0.0
|
||||
|
||||
with patch("hermes_cli.config.get_config_path", return_value=cfg_file):
|
||||
obj._check_config_mcp_changes()
|
||||
|
||||
obj._reload_mcp.assert_called_once()
|
||||
|
||||
def test_interval_throttle_skips_check(self, tmp_path):
|
||||
"""If called within CONFIG_WATCH_INTERVAL, stat() is skipped."""
|
||||
obj, cfg_file = _make_cli(tmp_path)
|
||||
obj._last_config_check = time.monotonic() # just checked
|
||||
|
||||
with patch("hermes_cli.config.get_config_path", return_value=cfg_file), \
|
||||
patch.object(Path, "stat") as mock_stat:
|
||||
obj._check_config_mcp_changes()
|
||||
mock_stat.assert_not_called()
|
||||
|
||||
obj._reload_mcp.assert_not_called()
|
||||
|
||||
def test_missing_config_file_does_not_crash(self, tmp_path):
|
||||
"""If config.yaml doesn't exist, _check_config_mcp_changes is a no-op."""
|
||||
obj, cfg_file = _make_cli(tmp_path)
|
||||
missing = tmp_path / "nonexistent.yaml"
|
||||
|
||||
with patch("hermes_cli.config.get_config_path", return_value=missing):
|
||||
obj._check_config_mcp_changes() # should not raise
|
||||
|
||||
obj._reload_mcp.assert_not_called()
|
||||
@@ -64,8 +64,8 @@ class TestModelCommand:
|
||||
cli_obj.process_command("/model gpt-5.4")
|
||||
|
||||
output = capsys.readouterr().out
|
||||
# Auto-detection remaps bare model names to proper OpenRouter slugs
|
||||
assert cli_obj.model == "openai/gpt-5.4"
|
||||
# Model is accepted (with warning) even if not in API listing
|
||||
assert cli_obj.model == "gpt-5.4"
|
||||
|
||||
def test_validation_crash_falls_back_to_save(self, capsys):
|
||||
cli_obj = self._make_cli()
|
||||
|
||||
@@ -162,57 +162,6 @@ def test_runtime_resolution_rebuilds_agent_on_routing_change(monkeypatch):
|
||||
assert shell.api_mode == "codex_responses"
|
||||
|
||||
|
||||
def test_cli_turn_routing_uses_primary_when_disabled(monkeypatch):
|
||||
cli = _import_cli()
|
||||
shell = cli.HermesCLI(model="gpt-5", compact=True, max_turns=1)
|
||||
shell.provider = "openrouter"
|
||||
shell.api_mode = "chat_completions"
|
||||
shell.base_url = "https://openrouter.ai/api/v1"
|
||||
shell.api_key = "sk-primary"
|
||||
shell._smart_model_routing = {"enabled": False}
|
||||
|
||||
result = shell._resolve_turn_agent_config("what time is it in tokyo?")
|
||||
|
||||
assert result["model"] == "gpt-5"
|
||||
assert result["runtime"]["provider"] == "openrouter"
|
||||
assert result["label"] is None
|
||||
|
||||
|
||||
def test_cli_turn_routing_uses_cheap_model_when_simple(monkeypatch):
|
||||
cli = _import_cli()
|
||||
|
||||
def _runtime_resolve(**kwargs):
|
||||
assert kwargs["requested"] == "zai"
|
||||
return {
|
||||
"provider": "zai",
|
||||
"api_mode": "chat_completions",
|
||||
"base_url": "https://open.z.ai/api/v1",
|
||||
"api_key": "cheap-key",
|
||||
"source": "env/config",
|
||||
}
|
||||
|
||||
monkeypatch.setattr("hermes_cli.runtime_provider.resolve_runtime_provider", _runtime_resolve)
|
||||
|
||||
shell = cli.HermesCLI(model="anthropic/claude-sonnet-4", compact=True, max_turns=1)
|
||||
shell.provider = "openrouter"
|
||||
shell.api_mode = "chat_completions"
|
||||
shell.base_url = "https://openrouter.ai/api/v1"
|
||||
shell.api_key = "primary-key"
|
||||
shell._smart_model_routing = {
|
||||
"enabled": True,
|
||||
"cheap_model": {"provider": "zai", "model": "glm-5-air"},
|
||||
"max_simple_chars": 160,
|
||||
"max_simple_words": 28,
|
||||
}
|
||||
|
||||
result = shell._resolve_turn_agent_config("what time is it in tokyo?")
|
||||
|
||||
assert result["model"] == "glm-5-air"
|
||||
assert result["runtime"]["provider"] == "zai"
|
||||
assert result["runtime"]["api_key"] == "cheap-key"
|
||||
assert result["label"] is not None
|
||||
|
||||
|
||||
def test_cli_prefers_config_provider_over_stale_env_override(monkeypatch):
|
||||
cli = _import_cli()
|
||||
|
||||
@@ -387,42 +336,4 @@ def test_cmd_model_falls_back_to_auto_on_invalid_provider(monkeypatch, capsys):
|
||||
|
||||
assert "Warning:" in output
|
||||
assert "falling back to auto provider detection" in output.lower()
|
||||
assert "No change." in output
|
||||
|
||||
|
||||
def test_model_flow_custom_saves_verified_v1_base_url(monkeypatch, capsys):
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.config.get_env_value",
|
||||
lambda key: "" if key in {"OPENAI_BASE_URL", "OPENAI_API_KEY"} else "",
|
||||
)
|
||||
saved_env = {}
|
||||
monkeypatch.setattr("hermes_cli.config.save_env_value", lambda key, value: saved_env.__setitem__(key, value))
|
||||
monkeypatch.setattr("hermes_cli.auth._save_model_choice", lambda model: saved_env.__setitem__("MODEL", model))
|
||||
monkeypatch.setattr("hermes_cli.auth.deactivate_provider", lambda: None)
|
||||
monkeypatch.setattr("hermes_cli.main._save_custom_provider", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.models.probe_api_models",
|
||||
lambda api_key, base_url: {
|
||||
"models": ["llm"],
|
||||
"probed_url": "http://localhost:8000/v1/models",
|
||||
"resolved_base_url": "http://localhost:8000/v1",
|
||||
"suggested_base_url": "http://localhost:8000/v1",
|
||||
"used_fallback": True,
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.config.load_config",
|
||||
lambda: {"model": {"default": "", "provider": "custom", "base_url": ""}},
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.config.save_config", lambda cfg: None)
|
||||
|
||||
answers = iter(["http://localhost:8000", "local-key", "llm"])
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers))
|
||||
|
||||
hermes_main._model_flow_custom({})
|
||||
output = capsys.readouterr().out
|
||||
|
||||
assert "Saving the working base URL instead" in output
|
||||
assert saved_env["OPENAI_BASE_URL"] == "http://localhost:8000/v1"
|
||||
assert saved_env["OPENAI_API_KEY"] == "local-key"
|
||||
assert saved_env["MODEL"] == "llm"
|
||||
assert "No change." in output
|
||||
@@ -1,175 +0,0 @@
|
||||
from datetime import datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
|
||||
from cli import HermesCLI
|
||||
|
||||
|
||||
def _make_cli(model: str = "anthropic/claude-sonnet-4-20250514"):
|
||||
cli_obj = HermesCLI.__new__(HermesCLI)
|
||||
cli_obj.model = model
|
||||
cli_obj.session_start = datetime.now() - timedelta(minutes=14, seconds=32)
|
||||
cli_obj.conversation_history = [{"role": "user", "content": "hi"}]
|
||||
cli_obj.agent = None
|
||||
return cli_obj
|
||||
|
||||
|
||||
def _attach_agent(
|
||||
cli_obj,
|
||||
*,
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
api_calls: int,
|
||||
context_tokens: int,
|
||||
context_length: int,
|
||||
compressions: int = 0,
|
||||
):
|
||||
cli_obj.agent = SimpleNamespace(
|
||||
model=cli_obj.model,
|
||||
session_prompt_tokens=prompt_tokens,
|
||||
session_completion_tokens=completion_tokens,
|
||||
session_total_tokens=total_tokens,
|
||||
session_api_calls=api_calls,
|
||||
context_compressor=SimpleNamespace(
|
||||
last_prompt_tokens=context_tokens,
|
||||
context_length=context_length,
|
||||
compression_count=compressions,
|
||||
),
|
||||
)
|
||||
return cli_obj
|
||||
|
||||
|
||||
class TestCLIStatusBar:
|
||||
def test_context_style_thresholds(self):
|
||||
cli_obj = _make_cli()
|
||||
|
||||
assert cli_obj._status_bar_context_style(None) == "class:status-bar-dim"
|
||||
assert cli_obj._status_bar_context_style(10) == "class:status-bar-good"
|
||||
assert cli_obj._status_bar_context_style(50) == "class:status-bar-warn"
|
||||
assert cli_obj._status_bar_context_style(81) == "class:status-bar-bad"
|
||||
assert cli_obj._status_bar_context_style(95) == "class:status-bar-critical"
|
||||
|
||||
def test_build_status_bar_text_for_wide_terminal(self):
|
||||
cli_obj = _attach_agent(
|
||||
_make_cli(),
|
||||
prompt_tokens=10_230,
|
||||
completion_tokens=2_220,
|
||||
total_tokens=12_450,
|
||||
api_calls=7,
|
||||
context_tokens=12_450,
|
||||
context_length=200_000,
|
||||
)
|
||||
|
||||
text = cli_obj._build_status_bar_text(width=120)
|
||||
|
||||
assert "claude-sonnet-4-20250514" in text
|
||||
assert "12.4K/200K" in text
|
||||
assert "6%" in text
|
||||
assert "$0.06" not in text # cost hidden by default
|
||||
assert "15m" in text
|
||||
|
||||
def test_build_status_bar_text_shows_cost_when_enabled(self):
|
||||
cli_obj = _attach_agent(
|
||||
_make_cli(),
|
||||
prompt_tokens=10000,
|
||||
completion_tokens=2400,
|
||||
total_tokens=12400,
|
||||
api_calls=7,
|
||||
context_tokens=12400,
|
||||
context_length=200_000,
|
||||
)
|
||||
cli_obj.show_cost = True
|
||||
|
||||
text = cli_obj._build_status_bar_text(width=120)
|
||||
assert "$" in text # cost is shown when enabled
|
||||
|
||||
def test_build_status_bar_text_collapses_for_narrow_terminal(self):
|
||||
cli_obj = _attach_agent(
|
||||
_make_cli(),
|
||||
prompt_tokens=10000,
|
||||
completion_tokens=2400,
|
||||
total_tokens=12400,
|
||||
api_calls=7,
|
||||
context_tokens=12400,
|
||||
context_length=200_000,
|
||||
)
|
||||
|
||||
text = cli_obj._build_status_bar_text(width=60)
|
||||
|
||||
assert "⚕" in text
|
||||
assert "$0.06" not in text # cost hidden by default
|
||||
assert "15m" in text
|
||||
assert "200K" not in text
|
||||
|
||||
def test_build_status_bar_text_handles_missing_agent(self):
|
||||
cli_obj = _make_cli()
|
||||
|
||||
text = cli_obj._build_status_bar_text(width=100)
|
||||
|
||||
assert "⚕" in text
|
||||
assert "claude-sonnet-4-20250514" in text
|
||||
|
||||
|
||||
class TestCLIUsageReport:
|
||||
def test_show_usage_includes_estimated_cost(self, capsys):
|
||||
cli_obj = _attach_agent(
|
||||
_make_cli(),
|
||||
prompt_tokens=10_230,
|
||||
completion_tokens=2_220,
|
||||
total_tokens=12_450,
|
||||
api_calls=7,
|
||||
context_tokens=12_450,
|
||||
context_length=200_000,
|
||||
compressions=1,
|
||||
)
|
||||
cli_obj.verbose = False
|
||||
|
||||
cli_obj._show_usage()
|
||||
output = capsys.readouterr().out
|
||||
|
||||
assert "Model:" in output
|
||||
assert "Input cost:" in output
|
||||
assert "Output cost:" in output
|
||||
assert "Total cost:" in output
|
||||
assert "$" in output
|
||||
assert "0.064" in output
|
||||
assert "Session duration:" in output
|
||||
assert "Compressions:" in output
|
||||
|
||||
def test_show_usage_marks_unknown_pricing(self, capsys):
|
||||
cli_obj = _attach_agent(
|
||||
_make_cli(model="local/my-custom-model"),
|
||||
prompt_tokens=1_000,
|
||||
completion_tokens=500,
|
||||
total_tokens=1_500,
|
||||
api_calls=1,
|
||||
context_tokens=1_000,
|
||||
context_length=32_000,
|
||||
)
|
||||
cli_obj.verbose = False
|
||||
|
||||
cli_obj._show_usage()
|
||||
output = capsys.readouterr().out
|
||||
|
||||
assert "Total cost:" in output
|
||||
assert "n/a" in output
|
||||
assert "Pricing unknown for local/my-custom-model" in output
|
||||
|
||||
def test_zero_priced_provider_models_stay_unknown(self, capsys):
|
||||
cli_obj = _attach_agent(
|
||||
_make_cli(model="glm-5"),
|
||||
prompt_tokens=1_000,
|
||||
completion_tokens=500,
|
||||
total_tokens=1_500,
|
||||
api_calls=1,
|
||||
context_tokens=1_000,
|
||||
context_length=32_000,
|
||||
)
|
||||
cli_obj.verbose = False
|
||||
|
||||
cli_obj._show_usage()
|
||||
output = capsys.readouterr().out
|
||||
|
||||
assert "Total cost:" in output
|
||||
assert "n/a" in output
|
||||
assert "Pricing unknown for glm-5" in output
|
||||
@@ -1,186 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
import importlib.util
|
||||
|
||||
# Load the hyphenated script name dynamically
|
||||
repo_root = Path(__file__).parent.parent
|
||||
script_path = repo_root / "optional-skills" / "security" / "oss-forensics" / "scripts" / "evidence-store.py"
|
||||
|
||||
spec = importlib.util.spec_from_file_location("evidence_store", str(script_path))
|
||||
evidence_store = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(evidence_store)
|
||||
EvidenceStore = evidence_store.EvidenceStore
|
||||
|
||||
|
||||
def test_evidence_store_init(tmp_path):
|
||||
store_file = tmp_path / "test_evidence.json"
|
||||
store = EvidenceStore(str(store_file))
|
||||
assert store.filepath == str(store_file)
|
||||
assert len(store.data["evidence"]) == 0
|
||||
assert "metadata" in store.data
|
||||
assert store.data["metadata"]["version"] == "2.0"
|
||||
assert "chain_of_custody" in store.data
|
||||
|
||||
|
||||
def test_evidence_store_add(tmp_path):
|
||||
store_file = tmp_path / "test_evidence.json"
|
||||
store = EvidenceStore(str(store_file))
|
||||
|
||||
eid = store.add(
|
||||
source="test_source",
|
||||
content="test_content",
|
||||
evidence_type="git",
|
||||
actor="test_actor",
|
||||
notes="test_notes",
|
||||
)
|
||||
|
||||
assert eid == "EV-0001"
|
||||
assert len(store.data["evidence"]) == 1
|
||||
assert store.data["evidence"][0]["content"] == "test_content"
|
||||
assert store.data["evidence"][0]["id"] == "EV-0001"
|
||||
assert store.data["evidence"][0]["actor"] == "test_actor"
|
||||
assert store.data["evidence"][0]["notes"] == "test_notes"
|
||||
# Verify SHA-256 was computed
|
||||
assert store.data["evidence"][0]["content_sha256"] is not None
|
||||
assert len(store.data["evidence"][0]["content_sha256"]) == 64
|
||||
|
||||
|
||||
def test_evidence_store_add_persists(tmp_path):
|
||||
store_file = tmp_path / "test_evidence.json"
|
||||
store = EvidenceStore(str(store_file))
|
||||
store.add(source="s1", content="c1", evidence_type="git")
|
||||
|
||||
# Reload from disk
|
||||
store2 = EvidenceStore(str(store_file))
|
||||
assert len(store2.data["evidence"]) == 1
|
||||
assert store2.data["evidence"][0]["id"] == "EV-0001"
|
||||
|
||||
|
||||
def test_evidence_store_sequential_ids(tmp_path):
|
||||
store_file = tmp_path / "test_evidence.json"
|
||||
store = EvidenceStore(str(store_file))
|
||||
|
||||
eid1 = store.add(source="s1", content="c1", evidence_type="git")
|
||||
eid2 = store.add(source="s2", content="c2", evidence_type="gh_api")
|
||||
eid3 = store.add(source="s3", content="c3", evidence_type="ioc")
|
||||
|
||||
assert eid1 == "EV-0001"
|
||||
assert eid2 == "EV-0002"
|
||||
assert eid3 == "EV-0003"
|
||||
|
||||
|
||||
def test_evidence_store_list(tmp_path):
|
||||
store_file = tmp_path / "test_evidence.json"
|
||||
store = EvidenceStore(str(store_file))
|
||||
|
||||
store.add(source="s1", content="c1", evidence_type="git", actor="a1")
|
||||
store.add(source="s2", content="c2", evidence_type="gh_api", actor="a2")
|
||||
|
||||
all_evidence = store.list_evidence()
|
||||
assert len(all_evidence) == 2
|
||||
|
||||
git_evidence = store.list_evidence(filter_type="git")
|
||||
assert len(git_evidence) == 1
|
||||
assert git_evidence[0]["actor"] == "a1"
|
||||
|
||||
actor_evidence = store.list_evidence(filter_actor="a2")
|
||||
assert len(actor_evidence) == 1
|
||||
assert actor_evidence[0]["type"] == "gh_api"
|
||||
|
||||
|
||||
def test_evidence_store_verify_integrity(tmp_path):
|
||||
store_file = tmp_path / "test_evidence.json"
|
||||
store = EvidenceStore(str(store_file))
|
||||
|
||||
store.add(source="s1", content="c1", evidence_type="git")
|
||||
assert len(store.verify_integrity()) == 0
|
||||
|
||||
# Manually corrupt the content to trigger a hash mismatch
|
||||
store.data["evidence"][0]["content"] = "corrupted_content"
|
||||
issues = store.verify_integrity()
|
||||
assert len(issues) == 1
|
||||
assert issues[0]["id"] == "EV-0001"
|
||||
|
||||
|
||||
def test_evidence_store_query(tmp_path):
|
||||
store_file = tmp_path / "test_evidence.json"
|
||||
store = EvidenceStore(str(store_file))
|
||||
|
||||
store.add(source="github_api", content="malicious activity detected", evidence_type="gh_api")
|
||||
store.add(source="manual", content="clean observation", evidence_type="manual")
|
||||
|
||||
results = store.query("malicious")
|
||||
assert len(results) == 1
|
||||
assert results[0]["source"] == "github_api"
|
||||
|
||||
# Query should be case-insensitive
|
||||
results = store.query("MALICIOUS")
|
||||
assert len(results) == 1
|
||||
|
||||
|
||||
def test_evidence_store_query_searches_multiple_fields(tmp_path):
|
||||
store_file = tmp_path / "test_evidence.json"
|
||||
store = EvidenceStore(str(store_file))
|
||||
|
||||
store.add(source="git_fsck", content="dangling commit abc123", evidence_type="git", actor="attacker")
|
||||
store.add(source="manual", content="clean", evidence_type="manual")
|
||||
|
||||
# Search by source
|
||||
assert len(store.query("fsck")) == 1
|
||||
# Search by actor
|
||||
assert len(store.query("attacker")) == 1
|
||||
# Search returns nothing for non-matching
|
||||
assert len(store.query("nonexistent")) == 0
|
||||
|
||||
|
||||
def test_evidence_store_chain_of_custody(tmp_path):
|
||||
store_file = tmp_path / "test_evidence.json"
|
||||
store = EvidenceStore(str(store_file))
|
||||
|
||||
store.add(source="s1", content="c1", evidence_type="git")
|
||||
store.add(source="s2", content="c2", evidence_type="gh_api")
|
||||
|
||||
chain = store.data["chain_of_custody"]
|
||||
assert len(chain) == 2
|
||||
assert chain[0]["evidence_id"] == "EV-0001"
|
||||
assert chain[0]["action"] == "add"
|
||||
assert chain[1]["evidence_id"] == "EV-0002"
|
||||
|
||||
|
||||
def test_evidence_store_export_markdown(tmp_path):
|
||||
store_file = tmp_path / "test_evidence.json"
|
||||
store = EvidenceStore(str(store_file))
|
||||
|
||||
store.add(source="git_log", content="suspicious commit", evidence_type="git", actor="actor1")
|
||||
|
||||
md = store.export_markdown()
|
||||
assert "# Evidence Registry" in md
|
||||
assert "EV-0001" in md
|
||||
assert "Chain of Custody" in md
|
||||
assert "actor1" in md
|
||||
|
||||
|
||||
def test_evidence_store_summary(tmp_path):
|
||||
store_file = tmp_path / "test_evidence.json"
|
||||
store = EvidenceStore(str(store_file))
|
||||
|
||||
store.add(source="s1", content="c1", evidence_type="git", actor="a1")
|
||||
store.add(source="s2", content="c2", evidence_type="git", actor="a2")
|
||||
store.add(source="s3", content="c3", evidence_type="gh_api", actor="a1")
|
||||
|
||||
s = store.summary()
|
||||
assert s["total"] == 3
|
||||
assert s["by_type"]["git"] == 2
|
||||
assert s["by_type"]["gh_api"] == 1
|
||||
assert "a1" in s["unique_actors"]
|
||||
assert "a2" in s["unique_actors"]
|
||||
|
||||
|
||||
def test_evidence_store_corrupted_file(tmp_path):
|
||||
store_file = tmp_path / "test_evidence.json"
|
||||
store_file.write_text("NOT VALID JSON {{{")
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
EvidenceStore(str(store_file))
|
||||
@@ -206,7 +206,6 @@ class TestHasKnownPricing:
|
||||
def test_unknown_custom_model(self):
|
||||
assert _has_known_pricing("FP16_Hermes_4.5") is False
|
||||
assert _has_known_pricing("my-custom-model") is False
|
||||
assert _has_known_pricing("glm-5") is False
|
||||
assert _has_known_pricing("") is False
|
||||
assert _has_known_pricing(None) is False
|
||||
|
||||
|
||||
+123
-136
@@ -29,6 +29,51 @@ from unittest.mock import MagicMock, patch
|
||||
from run_agent import AIAgent, IterationBudget
|
||||
from tools.interrupt import set_interrupt, is_interrupted
|
||||
|
||||
set_interrupt(False)
|
||||
|
||||
# ─── Create parent agent ───
|
||||
parent = AIAgent.__new__(AIAgent)
|
||||
parent._interrupt_requested = False
|
||||
parent._interrupt_message = None
|
||||
parent._active_children = []
|
||||
parent.quiet_mode = True
|
||||
parent.model = "test/model"
|
||||
parent.base_url = "http://localhost:1"
|
||||
parent.api_key = "test"
|
||||
parent.provider = "test"
|
||||
parent.api_mode = "chat_completions"
|
||||
parent.platform = "cli"
|
||||
parent.enabled_toolsets = ["terminal", "file"]
|
||||
parent.providers_allowed = None
|
||||
parent.providers_ignored = None
|
||||
parent.providers_order = None
|
||||
parent.provider_sort = None
|
||||
parent.max_tokens = None
|
||||
parent.reasoning_config = None
|
||||
parent.prefill_messages = None
|
||||
parent._session_db = None
|
||||
parent._delegate_depth = 0
|
||||
parent._delegate_spinner = None
|
||||
parent.tool_progress_callback = None
|
||||
parent.iteration_budget = IterationBudget(max_total=100)
|
||||
parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"}
|
||||
|
||||
# Monkey-patch parent.interrupt to log
|
||||
_original_interrupt = AIAgent.interrupt
|
||||
def logged_interrupt(self, message=None):
|
||||
log.info(f"🔴 parent.interrupt() called with: {message!r}")
|
||||
log.info(f" _active_children count: {len(self._active_children)}")
|
||||
_original_interrupt(self, message)
|
||||
log.info(f" After interrupt: _interrupt_requested={self._interrupt_requested}")
|
||||
for i, c in enumerate(self._active_children):
|
||||
log.info(f" Child {i}._interrupt_requested={c._interrupt_requested}")
|
||||
parent.interrupt = lambda msg=None: logged_interrupt(parent, msg)
|
||||
|
||||
# ─── Simulate the exact CLI flow ───
|
||||
interrupt_queue = queue.Queue()
|
||||
child_running = threading.Event()
|
||||
agent_result = [None]
|
||||
|
||||
def make_slow_response(delay=2.0):
|
||||
"""API response that takes a while."""
|
||||
def create(**kwargs):
|
||||
@@ -49,154 +94,96 @@ def make_slow_response(delay=2.0):
|
||||
return create
|
||||
|
||||
|
||||
def main() -> int:
|
||||
set_interrupt(False)
|
||||
def agent_thread_func():
|
||||
"""Simulates the agent_thread in cli.py's chat() method."""
|
||||
log.info("🟢 agent_thread starting")
|
||||
|
||||
# ─── Create parent agent ───
|
||||
parent = AIAgent.__new__(AIAgent)
|
||||
parent._interrupt_requested = False
|
||||
parent._interrupt_message = None
|
||||
parent._active_children = []
|
||||
parent.quiet_mode = True
|
||||
parent.model = "test/model"
|
||||
parent.base_url = "http://localhost:1"
|
||||
parent.api_key = "test"
|
||||
parent.provider = "test"
|
||||
parent.api_mode = "chat_completions"
|
||||
parent.platform = "cli"
|
||||
parent.enabled_toolsets = ["terminal", "file"]
|
||||
parent.providers_allowed = None
|
||||
parent.providers_ignored = None
|
||||
parent.providers_order = None
|
||||
parent.provider_sort = None
|
||||
parent.max_tokens = None
|
||||
parent.reasoning_config = None
|
||||
parent.prefill_messages = None
|
||||
parent._session_db = None
|
||||
parent._delegate_depth = 0
|
||||
parent._delegate_spinner = None
|
||||
parent.tool_progress_callback = None
|
||||
parent.iteration_budget = IterationBudget(max_total=100)
|
||||
parent._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1"}
|
||||
with patch("run_agent.OpenAI") as MockOpenAI:
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = make_slow_response(delay=3.0)
|
||||
mock_client.close = MagicMock()
|
||||
MockOpenAI.return_value = mock_client
|
||||
|
||||
# Monkey-patch parent.interrupt to log
|
||||
_original_interrupt = AIAgent.interrupt
|
||||
from tools.delegate_tool import _run_single_child
|
||||
|
||||
def logged_interrupt(self, message=None):
|
||||
log.info(f"🔴 parent.interrupt() called with: {message!r}")
|
||||
log.info(f" _active_children count: {len(self._active_children)}")
|
||||
_original_interrupt(self, message)
|
||||
log.info(f" After interrupt: _interrupt_requested={self._interrupt_requested}")
|
||||
for i, child in enumerate(self._active_children):
|
||||
log.info(f" Child {i}._interrupt_requested={child._interrupt_requested}")
|
||||
# Signal that child is about to start
|
||||
original_init = AIAgent.__init__
|
||||
def patched_init(self_agent, *a, **kw):
|
||||
log.info("🟡 Child AIAgent.__init__ called")
|
||||
original_init(self_agent, *a, **kw)
|
||||
child_running.set()
|
||||
log.info(f"🟡 Child started, parent._active_children = {len(parent._active_children)}")
|
||||
|
||||
parent.interrupt = lambda msg=None: logged_interrupt(parent, msg)
|
||||
with patch.object(AIAgent, "__init__", patched_init):
|
||||
result = _run_single_child(
|
||||
task_index=0,
|
||||
goal="Do a slow thing",
|
||||
context=None,
|
||||
toolsets=["terminal"],
|
||||
model="test/model",
|
||||
max_iterations=3,
|
||||
parent_agent=parent,
|
||||
task_count=1,
|
||||
override_provider="test",
|
||||
override_base_url="http://localhost:1",
|
||||
override_api_key="test",
|
||||
override_api_mode="chat_completions",
|
||||
)
|
||||
agent_result[0] = result
|
||||
log.info(f"🟢 agent_thread finished. Result status: {result.get('status')}")
|
||||
|
||||
# ─── Simulate the exact CLI flow ───
|
||||
interrupt_queue = queue.Queue()
|
||||
child_running = threading.Event()
|
||||
agent_result = [None]
|
||||
|
||||
def agent_thread_func():
|
||||
"""Simulates the agent_thread in cli.py's chat() method."""
|
||||
log.info("🟢 agent_thread starting")
|
||||
# ─── Start agent thread (like chat() does) ───
|
||||
agent_thread = threading.Thread(target=agent_thread_func, name="agent_thread", daemon=True)
|
||||
agent_thread.start()
|
||||
|
||||
with patch("run_agent.OpenAI") as MockOpenAI:
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = make_slow_response(delay=3.0)
|
||||
mock_client.close = MagicMock()
|
||||
MockOpenAI.return_value = mock_client
|
||||
# ─── Wait for child to start ───
|
||||
if not child_running.wait(timeout=10):
|
||||
print("FAIL: Child never started", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
from tools.delegate_tool import _run_single_child
|
||||
# Give child time to enter its main loop and start API call
|
||||
time.sleep(1.0)
|
||||
|
||||
# Signal that child is about to start
|
||||
original_init = AIAgent.__init__
|
||||
# ─── Simulate user typing a message (like handle_enter does) ───
|
||||
log.info("📝 Simulating user typing 'Hey stop that'")
|
||||
interrupt_queue.put("Hey stop that")
|
||||
|
||||
def patched_init(self_agent, *a, **kw):
|
||||
log.info("🟡 Child AIAgent.__init__ called")
|
||||
original_init(self_agent, *a, **kw)
|
||||
child_running.set()
|
||||
log.info(
|
||||
f"🟡 Child started, parent._active_children = {len(parent._active_children)}"
|
||||
)
|
||||
# ─── Simulate chat() polling loop (like the real chat() method) ───
|
||||
log.info("📡 Starting interrupt queue polling (like chat())")
|
||||
interrupt_msg = None
|
||||
poll_count = 0
|
||||
while agent_thread.is_alive():
|
||||
try:
|
||||
interrupt_msg = interrupt_queue.get(timeout=0.1)
|
||||
if interrupt_msg:
|
||||
log.info(f"📨 Got interrupt message from queue: {interrupt_msg!r}")
|
||||
log.info(f" Calling parent.interrupt()...")
|
||||
parent.interrupt(interrupt_msg)
|
||||
log.info(f" parent.interrupt() returned. Breaking poll loop.")
|
||||
break
|
||||
except queue.Empty:
|
||||
poll_count += 1
|
||||
if poll_count % 20 == 0: # Log every 2s
|
||||
log.info(f" Still polling ({poll_count} iterations)...")
|
||||
|
||||
with patch.object(AIAgent, "__init__", patched_init):
|
||||
result = _run_single_child(
|
||||
task_index=0,
|
||||
goal="Do a slow thing",
|
||||
context=None,
|
||||
toolsets=["terminal"],
|
||||
model="test/model",
|
||||
max_iterations=3,
|
||||
parent_agent=parent,
|
||||
task_count=1,
|
||||
override_provider="test",
|
||||
override_base_url="http://localhost:1",
|
||||
override_api_key="test",
|
||||
override_api_mode="chat_completions",
|
||||
)
|
||||
agent_result[0] = result
|
||||
log.info(f"🟢 agent_thread finished. Result status: {result.get('status')}")
|
||||
# ─── Wait for agent to finish ───
|
||||
log.info("⏳ Waiting for agent_thread to join...")
|
||||
t0 = time.monotonic()
|
||||
agent_thread.join(timeout=10)
|
||||
elapsed = time.monotonic() - t0
|
||||
log.info(f"✅ agent_thread joined after {elapsed:.2f}s")
|
||||
|
||||
# ─── Start agent thread (like chat() does) ───
|
||||
agent_thread = threading.Thread(target=agent_thread_func, name="agent_thread", daemon=True)
|
||||
agent_thread.start()
|
||||
|
||||
# ─── Wait for child to start ───
|
||||
if not child_running.wait(timeout=10):
|
||||
print("FAIL: Child never started", file=sys.stderr)
|
||||
set_interrupt(False)
|
||||
return 1
|
||||
|
||||
# Give child time to enter its main loop and start API call
|
||||
time.sleep(1.0)
|
||||
|
||||
# ─── Simulate user typing a message (like handle_enter does) ───
|
||||
log.info("📝 Simulating user typing 'Hey stop that'")
|
||||
interrupt_queue.put("Hey stop that")
|
||||
|
||||
# ─── Simulate chat() polling loop (like the real chat() method) ───
|
||||
log.info("📡 Starting interrupt queue polling (like chat())")
|
||||
interrupt_msg = None
|
||||
poll_count = 0
|
||||
while agent_thread.is_alive():
|
||||
try:
|
||||
interrupt_msg = interrupt_queue.get(timeout=0.1)
|
||||
if interrupt_msg:
|
||||
log.info(f"📨 Got interrupt message from queue: {interrupt_msg!r}")
|
||||
log.info(" Calling parent.interrupt()...")
|
||||
parent.interrupt(interrupt_msg)
|
||||
log.info(" parent.interrupt() returned. Breaking poll loop.")
|
||||
break
|
||||
except queue.Empty:
|
||||
poll_count += 1
|
||||
if poll_count % 20 == 0: # Log every 2s
|
||||
log.info(f" Still polling ({poll_count} iterations)...")
|
||||
|
||||
# ─── Wait for agent to finish ───
|
||||
log.info("⏳ Waiting for agent_thread to join...")
|
||||
t0 = time.monotonic()
|
||||
agent_thread.join(timeout=10)
|
||||
elapsed = time.monotonic() - t0
|
||||
log.info(f"✅ agent_thread joined after {elapsed:.2f}s")
|
||||
|
||||
# ─── Check results ───
|
||||
result = agent_result[0]
|
||||
if result:
|
||||
log.info(f"Result status: {result['status']}")
|
||||
log.info(f"Result duration: {result['duration_seconds']}s")
|
||||
if result["status"] == "interrupted" and elapsed < 2.0:
|
||||
print("✅ PASS: Interrupt worked correctly!", file=sys.stderr)
|
||||
set_interrupt(False)
|
||||
return 0
|
||||
# ─── Check results ───
|
||||
result = agent_result[0]
|
||||
if result:
|
||||
log.info(f"Result status: {result['status']}")
|
||||
log.info(f"Result duration: {result['duration_seconds']}s")
|
||||
if result["status"] == "interrupted" and elapsed < 2.0:
|
||||
print("✅ PASS: Interrupt worked correctly!", file=sys.stderr)
|
||||
else:
|
||||
print(f"❌ FAIL: status={result['status']}, elapsed={elapsed:.2f}s", file=sys.stderr)
|
||||
set_interrupt(False)
|
||||
return 1
|
||||
|
||||
else:
|
||||
print("❌ FAIL: No result returned", file=sys.stderr)
|
||||
set_interrupt(False)
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
set_interrupt(False)
|
||||
|
||||
@@ -59,11 +59,8 @@ def _build_agent(shared_client=None):
|
||||
agent._interrupt_requested = False
|
||||
agent._interrupt_message = None
|
||||
agent._client_lock = threading.RLock()
|
||||
agent._client_kwargs = {"api_key": "***", "base_url": agent.base_url}
|
||||
agent._client_kwargs = {"api_key": "test-key", "base_url": agent.base_url}
|
||||
agent.client = shared_client or FakeSharedClient(lambda **kwargs: {"shared": True})
|
||||
agent.stream_delta_callback = None
|
||||
agent._stream_callback = None
|
||||
agent.reasoning_callback = None
|
||||
return agent
|
||||
|
||||
|
||||
@@ -148,9 +145,8 @@ def test_concurrent_requests_do_not_break_each_other_when_one_client_closes(monk
|
||||
thread_one.join(timeout=5)
|
||||
thread_two.join(timeout=5)
|
||||
|
||||
values = list(results.values())
|
||||
assert sum(isinstance(value, APIConnectionError) for value in values) == 1
|
||||
assert values.count({"ok": "second"}) == 1
|
||||
assert isinstance(results["first"], APIConnectionError)
|
||||
assert results["second"] == {"ok": "second"}
|
||||
assert len(factory.calls) == 2
|
||||
|
||||
|
||||
@@ -176,11 +172,7 @@ def test_streaming_call_recreates_closed_shared_client_before_request(monkeypatc
|
||||
monkeypatch.setattr(run_agent, "OpenAI", factory)
|
||||
|
||||
agent = _build_agent(shared_client=stale_shared)
|
||||
agent.stream_delta_callback = lambda _delta: None
|
||||
# Force chat_completions mode so the streaming path uses
|
||||
# chat.completions.create(stream=True) instead of Codex responses.stream()
|
||||
agent.api_mode = "chat_completions"
|
||||
response = agent._interruptible_streaming_api_call({"model": agent.model, "messages": []})
|
||||
response = agent._streaming_api_call({"model": agent.model, "messages": []}, lambda _delta: None)
|
||||
|
||||
assert response.choices[0].message.content == "Hello world"
|
||||
assert agent.client is replacement_shared
|
||||
|
||||
@@ -1,340 +0,0 @@
|
||||
"""Tests for the Hermes plugin system (hermes_cli.plugins)."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from hermes_cli.plugins import (
|
||||
ENTRY_POINTS_GROUP,
|
||||
VALID_HOOKS,
|
||||
LoadedPlugin,
|
||||
PluginContext,
|
||||
PluginManager,
|
||||
PluginManifest,
|
||||
get_plugin_manager,
|
||||
get_plugin_tool_names,
|
||||
discover_plugins,
|
||||
invoke_hook,
|
||||
)
|
||||
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_plugin_dir(base: Path, name: str, *, register_body: str = "pass",
|
||||
manifest_extra: dict | None = None) -> Path:
|
||||
"""Create a minimal plugin directory with plugin.yaml + __init__.py."""
|
||||
plugin_dir = base / name
|
||||
plugin_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
manifest = {"name": name, "version": "0.1.0", "description": f"Test plugin {name}"}
|
||||
if manifest_extra:
|
||||
manifest.update(manifest_extra)
|
||||
|
||||
(plugin_dir / "plugin.yaml").write_text(yaml.dump(manifest))
|
||||
(plugin_dir / "__init__.py").write_text(
|
||||
f"def register(ctx):\n {register_body}\n"
|
||||
)
|
||||
return plugin_dir
|
||||
|
||||
|
||||
# ── TestPluginDiscovery ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPluginDiscovery:
|
||||
"""Tests for plugin discovery from directories and entry points."""
|
||||
|
||||
def test_discover_user_plugins(self, tmp_path, monkeypatch):
|
||||
"""Plugins in ~/.hermes/plugins/ are discovered."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(plugins_dir, "hello_plugin")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "hello_plugin" in mgr._plugins
|
||||
assert mgr._plugins["hello_plugin"].enabled
|
||||
|
||||
def test_discover_project_plugins(self, tmp_path, monkeypatch):
|
||||
"""Plugins in ./.hermes/plugins/ are discovered."""
|
||||
project_dir = tmp_path / "project"
|
||||
project_dir.mkdir()
|
||||
monkeypatch.chdir(project_dir)
|
||||
plugins_dir = project_dir / ".hermes" / "plugins"
|
||||
_make_plugin_dir(plugins_dir, "proj_plugin")
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "proj_plugin" in mgr._plugins
|
||||
assert mgr._plugins["proj_plugin"].enabled
|
||||
|
||||
def test_discover_is_idempotent(self, tmp_path, monkeypatch):
|
||||
"""Calling discover_and_load() twice does not duplicate plugins."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(plugins_dir, "once_plugin")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
mgr.discover_and_load() # second call should no-op
|
||||
|
||||
assert len(mgr._plugins) == 1
|
||||
|
||||
def test_discover_skips_dir_without_manifest(self, tmp_path, monkeypatch):
|
||||
"""Directories without plugin.yaml are silently skipped."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
(plugins_dir / "no_manifest").mkdir(parents=True)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert len(mgr._plugins) == 0
|
||||
|
||||
def test_entry_points_scanned(self, tmp_path, monkeypatch):
|
||||
"""Entry-point based plugins are discovered (mocked)."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
fake_module = types.ModuleType("fake_ep_plugin")
|
||||
fake_module.register = lambda ctx: None # type: ignore[attr-defined]
|
||||
|
||||
fake_ep = MagicMock()
|
||||
fake_ep.name = "ep_plugin"
|
||||
fake_ep.value = "fake_ep_plugin:register"
|
||||
fake_ep.group = ENTRY_POINTS_GROUP
|
||||
fake_ep.load.return_value = fake_module
|
||||
|
||||
def fake_entry_points():
|
||||
result = MagicMock()
|
||||
result.select = MagicMock(return_value=[fake_ep])
|
||||
return result
|
||||
|
||||
with patch("importlib.metadata.entry_points", fake_entry_points):
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "ep_plugin" in mgr._plugins
|
||||
|
||||
|
||||
# ── TestPluginLoading ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPluginLoading:
|
||||
"""Tests for plugin module loading."""
|
||||
|
||||
def test_load_missing_init(self, tmp_path, monkeypatch):
|
||||
"""Plugin dir without __init__.py records an error."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
plugin_dir = plugins_dir / "bad_plugin"
|
||||
plugin_dir.mkdir(parents=True)
|
||||
(plugin_dir / "plugin.yaml").write_text(yaml.dump({"name": "bad_plugin"}))
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "bad_plugin" in mgr._plugins
|
||||
assert not mgr._plugins["bad_plugin"].enabled
|
||||
assert mgr._plugins["bad_plugin"].error is not None
|
||||
|
||||
def test_load_missing_register_fn(self, tmp_path, monkeypatch):
|
||||
"""Plugin without register() function records an error."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
plugin_dir = plugins_dir / "no_reg"
|
||||
plugin_dir.mkdir(parents=True)
|
||||
(plugin_dir / "plugin.yaml").write_text(yaml.dump({"name": "no_reg"}))
|
||||
(plugin_dir / "__init__.py").write_text("# no register function\n")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "no_reg" in mgr._plugins
|
||||
assert not mgr._plugins["no_reg"].enabled
|
||||
assert "no register()" in mgr._plugins["no_reg"].error
|
||||
|
||||
def test_load_registers_namespace_module(self, tmp_path, monkeypatch):
|
||||
"""Directory plugins are importable under hermes_plugins.<name>."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(plugins_dir, "ns_plugin")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
# Clean up any prior namespace module
|
||||
sys.modules.pop("hermes_plugins.ns_plugin", None)
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "hermes_plugins.ns_plugin" in sys.modules
|
||||
|
||||
|
||||
# ── TestPluginHooks ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPluginHooks:
|
||||
"""Tests for lifecycle hook registration and invocation."""
|
||||
|
||||
def test_register_and_invoke_hook(self, tmp_path, monkeypatch):
|
||||
"""Registered hooks are called on invoke_hook()."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(
|
||||
plugins_dir, "hook_plugin",
|
||||
register_body='ctx.register_hook("pre_tool_call", lambda **kw: None)',
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
# Should not raise
|
||||
mgr.invoke_hook("pre_tool_call", tool_name="test", args={}, task_id="t1")
|
||||
|
||||
def test_hook_exception_does_not_propagate(self, tmp_path, monkeypatch):
|
||||
"""A hook callback that raises does NOT crash the caller."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(
|
||||
plugins_dir, "bad_hook",
|
||||
register_body='ctx.register_hook("post_tool_call", lambda **kw: 1/0)',
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
# Should not raise despite 1/0
|
||||
mgr.invoke_hook("post_tool_call", tool_name="x", args={}, result="r", task_id="")
|
||||
|
||||
def test_invalid_hook_name_warns(self, tmp_path, monkeypatch, caplog):
|
||||
"""Registering an unknown hook name logs a warning."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(
|
||||
plugins_dir, "warn_plugin",
|
||||
register_body='ctx.register_hook("on_banana", lambda **kw: None)',
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="hermes_cli.plugins"):
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert any("on_banana" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
# ── TestPluginContext ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPluginContext:
|
||||
"""Tests for the PluginContext facade."""
|
||||
|
||||
def test_register_tool_adds_to_registry(self, tmp_path, monkeypatch):
|
||||
"""PluginContext.register_tool() puts the tool in the global registry."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
plugin_dir = plugins_dir / "tool_plugin"
|
||||
plugin_dir.mkdir(parents=True)
|
||||
(plugin_dir / "plugin.yaml").write_text(yaml.dump({"name": "tool_plugin"}))
|
||||
(plugin_dir / "__init__.py").write_text(
|
||||
'def register(ctx):\n'
|
||||
' ctx.register_tool(\n'
|
||||
' name="plugin_echo",\n'
|
||||
' toolset="plugin_tool_plugin",\n'
|
||||
' schema={"name": "plugin_echo", "description": "Echo", "parameters": {"type": "object", "properties": {}}},\n'
|
||||
' handler=lambda args, **kw: "echo",\n'
|
||||
' )\n'
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "plugin_echo" in mgr._plugin_tool_names
|
||||
|
||||
from tools.registry import registry
|
||||
assert "plugin_echo" in registry._tools
|
||||
|
||||
|
||||
# ── TestPluginToolVisibility ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPluginToolVisibility:
|
||||
"""Plugin-registered tools appear in get_tool_definitions()."""
|
||||
|
||||
def test_plugin_tools_in_definitions(self, tmp_path, monkeypatch):
|
||||
"""Tools from plugins bypass the toolset filter."""
|
||||
import hermes_cli.plugins as plugins_mod
|
||||
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
plugin_dir = plugins_dir / "vis_plugin"
|
||||
plugin_dir.mkdir(parents=True)
|
||||
(plugin_dir / "plugin.yaml").write_text(yaml.dump({"name": "vis_plugin"}))
|
||||
(plugin_dir / "__init__.py").write_text(
|
||||
'def register(ctx):\n'
|
||||
' ctx.register_tool(\n'
|
||||
' name="vis_tool",\n'
|
||||
' toolset="plugin_vis_plugin",\n'
|
||||
' schema={"name": "vis_tool", "description": "Visible", "parameters": {"type": "object", "properties": {}}},\n'
|
||||
' handler=lambda args, **kw: "ok",\n'
|
||||
' )\n'
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
monkeypatch.setattr(plugins_mod, "_plugin_manager", mgr)
|
||||
|
||||
from model_tools import get_tool_definitions
|
||||
tools = get_tool_definitions(enabled_toolsets=["terminal"], quiet_mode=True)
|
||||
tool_names = [t["function"]["name"] for t in tools]
|
||||
assert "vis_tool" in tool_names
|
||||
|
||||
|
||||
# ── TestPluginManagerList ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPluginManagerList:
|
||||
"""Tests for PluginManager.list_plugins()."""
|
||||
|
||||
def test_list_empty(self):
|
||||
"""Empty manager returns empty list."""
|
||||
mgr = PluginManager()
|
||||
assert mgr.list_plugins() == []
|
||||
|
||||
def test_list_returns_sorted(self, tmp_path, monkeypatch):
|
||||
"""list_plugins() returns results sorted by name."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(plugins_dir, "zulu")
|
||||
_make_plugin_dir(plugins_dir, "alpha")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
listing = mgr.list_plugins()
|
||||
names = [p["name"] for p in listing]
|
||||
assert names == sorted(names)
|
||||
|
||||
def test_list_with_plugins(self, tmp_path, monkeypatch):
|
||||
"""list_plugins() returns info dicts for each discovered plugin."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(plugins_dir, "alpha")
|
||||
_make_plugin_dir(plugins_dir, "beta")
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
listing = mgr.list_plugins()
|
||||
names = [p["name"] for p in listing]
|
||||
assert "alpha" in names
|
||||
assert "beta" in names
|
||||
for p in listing:
|
||||
assert "enabled" in p
|
||||
assert "tools" in p
|
||||
assert "hooks" in p
|
||||
+16
-84
@@ -612,25 +612,6 @@ class TestBuildApiKwargs:
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
assert kwargs["extra_body"]["reasoning"] == {"enabled": False}
|
||||
|
||||
def test_reasoning_not_sent_for_unsupported_openrouter_model(self, agent):
|
||||
agent.model = "minimax/minimax-m2.5"
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
assert "reasoning" not in kwargs.get("extra_body", {})
|
||||
|
||||
def test_reasoning_sent_for_supported_openrouter_model(self, agent):
|
||||
agent.model = "qwen/qwen3.5-plus-02-15"
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
assert kwargs["extra_body"]["reasoning"]["effort"] == "medium"
|
||||
|
||||
def test_reasoning_sent_for_nous_route(self, agent):
|
||||
agent.base_url = "https://inference-api.nousresearch.com/v1"
|
||||
agent.model = "minimax/minimax-m2.5"
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
assert kwargs["extra_body"]["reasoning"]["effort"] == "medium"
|
||||
|
||||
def test_max_tokens_injected(self, agent):
|
||||
agent.max_tokens = 4096
|
||||
messages = [{"role": "user", "content": "hi"}]
|
||||
@@ -930,10 +911,8 @@ class TestConcurrentToolExecution:
|
||||
mock_hfc.assert_called_once_with(
|
||||
"web_search", {"q": "test"}, "task-1",
|
||||
enabled_tools=list(agent.valid_tool_names),
|
||||
honcho_manager=None,
|
||||
honcho_session_key=None,
|
||||
)
|
||||
assert result == "result"
|
||||
assert result == "result"
|
||||
|
||||
def test_invoke_tool_handles_agent_level_tools(self, agent):
|
||||
"""_invoke_tool should handle todo tool directly."""
|
||||
@@ -963,19 +942,6 @@ class TestHandleMaxIterations:
|
||||
assert "error" in result.lower()
|
||||
assert "API down" in result
|
||||
|
||||
def test_summary_skips_reasoning_for_unsupported_openrouter_model(self, agent):
|
||||
agent.model = "minimax/minimax-m2.5"
|
||||
resp = _mock_response(content="Summary")
|
||||
agent.client.chat.completions.create.return_value = resp
|
||||
agent._cached_system_prompt = "You are helpful."
|
||||
messages = [{"role": "user", "content": "do stuff"}]
|
||||
|
||||
result = agent._handle_max_iterations(messages, 60)
|
||||
|
||||
assert result == "Summary"
|
||||
kwargs = agent.client.chat.completions.create.call_args.kwargs
|
||||
assert "reasoning" not in kwargs.get("extra_body", {})
|
||||
|
||||
|
||||
class TestRunConversation:
|
||||
"""Tests for the main run_conversation method.
|
||||
@@ -1586,38 +1552,6 @@ class TestSystemPromptStability:
|
||||
should_prefetch = not conversation_history
|
||||
assert should_prefetch is True
|
||||
|
||||
def test_run_conversation_can_skip_honcho_sync_for_synthetic_turns(self, agent):
|
||||
captured = {}
|
||||
|
||||
def _fake_api_call(api_kwargs):
|
||||
captured.update(api_kwargs)
|
||||
return _mock_response(content="done", finish_reason="stop")
|
||||
|
||||
agent._honcho = MagicMock()
|
||||
agent._honcho_session_key = "session-1"
|
||||
agent._honcho_config = SimpleNamespace(
|
||||
ai_peer="hermes",
|
||||
memory_mode="hybrid",
|
||||
write_frequency="async",
|
||||
recall_mode="hybrid",
|
||||
)
|
||||
agent._use_prompt_caching = False
|
||||
|
||||
with (
|
||||
patch.object(agent, "_honcho_sync") as mock_sync,
|
||||
patch.object(agent, "_queue_honcho_prefetch") as mock_prefetch,
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
patch.object(agent, "_interruptible_api_call", side_effect=_fake_api_call),
|
||||
):
|
||||
result = agent.run_conversation("synthetic flush turn", sync_honcho=False)
|
||||
|
||||
assert result["completed"] is True
|
||||
assert captured["messages"][-1]["content"] == "synthetic flush turn"
|
||||
mock_sync.assert_not_called()
|
||||
mock_prefetch.assert_not_called()
|
||||
|
||||
|
||||
class TestHonchoActivation:
|
||||
def test_disabled_config_skips_honcho_init(self):
|
||||
@@ -2329,9 +2263,8 @@ class TestStreamingApiCall:
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
callback = MagicMock()
|
||||
agent.stream_delta_callback = callback
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||
resp = agent._streaming_api_call({"messages": []}, callback)
|
||||
|
||||
assert resp.choices[0].message.content == "Hello World"
|
||||
assert resp.choices[0].finish_reason == "stop"
|
||||
@@ -2348,7 +2281,7 @@ class TestStreamingApiCall:
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
tc = resp.choices[0].message.tool_calls
|
||||
assert len(tc) == 1
|
||||
@@ -2364,7 +2297,7 @@ class TestStreamingApiCall:
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
tc = resp.choices[0].message.tool_calls
|
||||
assert len(tc) == 2
|
||||
@@ -2379,7 +2312,7 @@ class TestStreamingApiCall:
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
assert resp.choices[0].message.content == "I'll search"
|
||||
assert len(resp.choices[0].message.tool_calls) == 1
|
||||
@@ -2388,7 +2321,7 @@ class TestStreamingApiCall:
|
||||
chunks = [_make_chunk(finish_reason="stop")]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
assert resp.choices[0].message.content is None
|
||||
assert resp.choices[0].message.tool_calls is None
|
||||
@@ -2400,9 +2333,9 @@ class TestStreamingApiCall:
|
||||
_make_chunk(finish_reason="stop"),
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
agent.stream_delta_callback = MagicMock(side_effect=ValueError("boom"))
|
||||
callback = MagicMock(side_effect=ValueError("boom"))
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||
resp = agent._streaming_api_call({"messages": []}, callback)
|
||||
|
||||
assert resp.choices[0].message.content == "Hello World"
|
||||
|
||||
@@ -2413,7 +2346,7 @@ class TestStreamingApiCall:
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
assert resp.model == "gpt-4o"
|
||||
|
||||
@@ -2421,23 +2354,22 @@ class TestStreamingApiCall:
|
||||
chunks = [_make_chunk(content="x"), _make_chunk(finish_reason="stop")]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
agent._interruptible_streaming_api_call({"messages": [], "model": "test"})
|
||||
agent._streaming_api_call({"messages": [], "model": "test"}, MagicMock())
|
||||
|
||||
call_kwargs = agent.client.chat.completions.create.call_args
|
||||
assert call_kwargs[1].get("stream") is True or call_kwargs.kwargs.get("stream") is True
|
||||
|
||||
def test_api_exception_falls_back_to_non_streaming(self, agent):
|
||||
"""When streaming fails before any deltas, fallback to non-streaming is attempted."""
|
||||
def test_api_exception_propagated(self, agent):
|
||||
agent.client.chat.completions.create.side_effect = ConnectionError("fail")
|
||||
# The fallback also uses the same client, so it'll fail too
|
||||
|
||||
with pytest.raises(ConnectionError, match="fail"):
|
||||
agent._interruptible_streaming_api_call({"messages": []})
|
||||
agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
def test_response_has_uuid_id(self, agent):
|
||||
chunks = [_make_chunk(content="x"), _make_chunk(finish_reason="stop")]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
assert resp.id.startswith("stream-")
|
||||
assert len(resp.id) > len("stream-")
|
||||
@@ -2451,7 +2383,7 @@ class TestStreamingApiCall:
|
||||
]
|
||||
agent.client.chat.completions.create.return_value = iter(chunks)
|
||||
|
||||
resp = agent._interruptible_streaming_api_call({"messages": []})
|
||||
resp = agent._streaming_api_call({"messages": []}, MagicMock())
|
||||
|
||||
assert resp.choices[0].message.content == "Hello"
|
||||
assert resp.model == "gpt-4"
|
||||
@@ -2507,7 +2439,7 @@ class TestAnthropicInterruptHandler:
|
||||
def test_streaming_has_anthropic_branch(self):
|
||||
"""_streaming_api_call must also handle Anthropic interrupt."""
|
||||
import inspect
|
||||
source = inspect.getsource(AIAgent._interruptible_streaming_api_call)
|
||||
source = inspect.getsource(AIAgent._streaming_api_call)
|
||||
assert "anthropic_messages" in source, \
|
||||
"_streaming_api_call must handle Anthropic interrupt"
|
||||
|
||||
|
||||
@@ -1,571 +0,0 @@
|
||||
"""Tests for streaming token delivery infrastructure.
|
||||
|
||||
Tests the unified streaming API call, delta callbacks, tool-call
|
||||
suppression, provider fallback, and CLI streaming display.
|
||||
"""
|
||||
import json
|
||||
import threading
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_stream_chunk(
|
||||
content=None, tool_calls=None, finish_reason=None,
|
||||
model=None, reasoning_content=None, usage=None,
|
||||
):
|
||||
"""Build a mock streaming chunk matching OpenAI's ChatCompletionChunk shape."""
|
||||
delta = SimpleNamespace(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
reasoning_content=reasoning_content,
|
||||
reasoning=None,
|
||||
)
|
||||
choice = SimpleNamespace(
|
||||
index=0,
|
||||
delta=delta,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
chunk = SimpleNamespace(
|
||||
choices=[choice],
|
||||
model=model,
|
||||
usage=usage,
|
||||
)
|
||||
return chunk
|
||||
|
||||
|
||||
def _make_tool_call_delta(index=0, tc_id=None, name=None, arguments=None):
|
||||
"""Build a mock tool call delta."""
|
||||
func = SimpleNamespace(name=name, arguments=arguments)
|
||||
return SimpleNamespace(index=index, id=tc_id, function=func)
|
||||
|
||||
|
||||
def _make_empty_chunk(model=None, usage=None):
|
||||
"""Build a chunk with no choices (usage-only final chunk)."""
|
||||
return SimpleNamespace(choices=[], model=model, usage=usage)
|
||||
|
||||
|
||||
# ── Test: Streaming Accumulator ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStreamingAccumulator:
|
||||
"""Verify that _interruptible_streaming_api_call accumulates content
|
||||
and tool calls into a response matching the non-streaming shape."""
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_text_only_response(self, mock_close, mock_create):
|
||||
"""Text-only stream produces correct response shape."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
chunks = [
|
||||
_make_stream_chunk(content="Hello"),
|
||||
_make_stream_chunk(content=" world"),
|
||||
_make_stream_chunk(content="!", finish_reason="stop", model="test-model"),
|
||||
_make_empty_chunk(usage=SimpleNamespace(prompt_tokens=10, completion_tokens=3)),
|
||||
]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.return_value = iter(chunks)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
response = agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert response.choices[0].message.content == "Hello world!"
|
||||
assert response.choices[0].message.tool_calls is None
|
||||
assert response.choices[0].finish_reason == "stop"
|
||||
assert response.usage is not None
|
||||
assert response.usage.completion_tokens == 3
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_tool_call_response(self, mock_close, mock_create):
|
||||
"""Tool call stream accumulates ID, name, and arguments."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
chunks = [
|
||||
_make_stream_chunk(tool_calls=[
|
||||
_make_tool_call_delta(index=0, tc_id="call_123", name="terminal")
|
||||
]),
|
||||
_make_stream_chunk(tool_calls=[
|
||||
_make_tool_call_delta(index=0, arguments='{"command":')
|
||||
]),
|
||||
_make_stream_chunk(tool_calls=[
|
||||
_make_tool_call_delta(index=0, arguments=' "ls"}')
|
||||
]),
|
||||
_make_stream_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.return_value = iter(chunks)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
response = agent._interruptible_streaming_api_call({})
|
||||
|
||||
tc = response.choices[0].message.tool_calls
|
||||
assert tc is not None
|
||||
assert len(tc) == 1
|
||||
assert tc[0].id == "call_123"
|
||||
assert tc[0].function.name == "terminal"
|
||||
assert tc[0].function.arguments == '{"command": "ls"}'
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_mixed_content_and_tool_calls(self, mock_close, mock_create):
|
||||
"""Stream with both text and tool calls accumulates both."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
chunks = [
|
||||
_make_stream_chunk(content="Let me check"),
|
||||
_make_stream_chunk(tool_calls=[
|
||||
_make_tool_call_delta(index=0, tc_id="call_456", name="web_search")
|
||||
]),
|
||||
_make_stream_chunk(tool_calls=[
|
||||
_make_tool_call_delta(index=0, arguments='{"query": "test"}')
|
||||
]),
|
||||
_make_stream_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.return_value = iter(chunks)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
response = agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert response.choices[0].message.content == "Let me check"
|
||||
assert len(response.choices[0].message.tool_calls) == 1
|
||||
|
||||
|
||||
# ── Test: Streaming Callbacks ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStreamingCallbacks:
|
||||
"""Verify that delta callbacks fire correctly."""
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_deltas_fire_in_order(self, mock_close, mock_create):
|
||||
"""Callbacks receive text deltas in order."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
chunks = [
|
||||
_make_stream_chunk(content="a"),
|
||||
_make_stream_chunk(content="b"),
|
||||
_make_stream_chunk(content="c"),
|
||||
_make_stream_chunk(finish_reason="stop"),
|
||||
]
|
||||
|
||||
deltas = []
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.return_value = iter(chunks)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
stream_delta_callback=lambda t: deltas.append(t),
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert deltas == ["a", "b", "c"]
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_on_first_delta_fires_once(self, mock_close, mock_create):
|
||||
"""on_first_delta callback fires exactly once."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
chunks = [
|
||||
_make_stream_chunk(content="a"),
|
||||
_make_stream_chunk(content="b"),
|
||||
_make_stream_chunk(finish_reason="stop"),
|
||||
]
|
||||
|
||||
first_delta_calls = []
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.return_value = iter(chunks)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
agent._interruptible_streaming_api_call(
|
||||
{}, on_first_delta=lambda: first_delta_calls.append(True)
|
||||
)
|
||||
|
||||
assert len(first_delta_calls) == 1
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_tool_only_does_not_fire_callback(self, mock_close, mock_create):
|
||||
"""Tool-call-only stream does not fire the delta callback."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
chunks = [
|
||||
_make_stream_chunk(tool_calls=[
|
||||
_make_tool_call_delta(index=0, tc_id="call_789", name="terminal")
|
||||
]),
|
||||
_make_stream_chunk(tool_calls=[
|
||||
_make_tool_call_delta(index=0, arguments='{"command": "ls"}')
|
||||
]),
|
||||
_make_stream_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
|
||||
deltas = []
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.return_value = iter(chunks)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
stream_delta_callback=lambda t: deltas.append(t),
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert deltas == []
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_text_suppressed_when_tool_calls_present(self, mock_close, mock_create):
|
||||
"""Text deltas are suppressed when tool calls are also in the stream."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
chunks = [
|
||||
_make_stream_chunk(content="thinking..."),
|
||||
_make_stream_chunk(tool_calls=[
|
||||
_make_tool_call_delta(index=0, tc_id="call_abc", name="read_file")
|
||||
]),
|
||||
_make_stream_chunk(content=" more text"),
|
||||
_make_stream_chunk(finish_reason="tool_calls"),
|
||||
]
|
||||
|
||||
deltas = []
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.return_value = iter(chunks)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
stream_delta_callback=lambda t: deltas.append(t),
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
response = agent._interruptible_streaming_api_call({})
|
||||
|
||||
# Text before tool call IS fired (we don't know yet it will have tools)
|
||||
assert "thinking..." in deltas
|
||||
# Text after tool call is NOT fired
|
||||
assert " more text" not in deltas
|
||||
# But content is still accumulated in the response
|
||||
assert response.choices[0].message.content == "thinking... more text"
|
||||
|
||||
|
||||
# ── Test: Streaming Fallback ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStreamingFallback:
|
||||
"""Verify fallback to non-streaming on ANY streaming error."""
|
||||
|
||||
@patch("run_agent.AIAgent._interruptible_api_call")
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_stream_error_falls_back(self, mock_close, mock_create, mock_non_stream):
|
||||
"""'not supported' error triggers fallback to non-streaming."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.side_effect = Exception(
|
||||
"Streaming is not supported for this model"
|
||||
)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
fallback_response = SimpleNamespace(
|
||||
id="fallback",
|
||||
model="test",
|
||||
choices=[SimpleNamespace(
|
||||
index=0,
|
||||
message=SimpleNamespace(
|
||||
role="assistant",
|
||||
content="fallback response",
|
||||
tool_calls=None,
|
||||
reasoning_content=None,
|
||||
),
|
||||
finish_reason="stop",
|
||||
)],
|
||||
usage=None,
|
||||
)
|
||||
mock_non_stream.return_value = fallback_response
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
response = agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert response.choices[0].message.content == "fallback response"
|
||||
mock_non_stream.assert_called_once()
|
||||
|
||||
@patch("run_agent.AIAgent._interruptible_api_call")
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_any_stream_error_falls_back(self, mock_close, mock_create, mock_non_stream):
|
||||
"""ANY streaming error triggers fallback — not just specific messages."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.side_effect = Exception(
|
||||
"Connection reset by peer"
|
||||
)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
fallback_response = SimpleNamespace(
|
||||
id="fallback",
|
||||
model="test",
|
||||
choices=[SimpleNamespace(
|
||||
index=0,
|
||||
message=SimpleNamespace(
|
||||
role="assistant",
|
||||
content="fallback after connection error",
|
||||
tool_calls=None,
|
||||
reasoning_content=None,
|
||||
),
|
||||
finish_reason="stop",
|
||||
)],
|
||||
usage=None,
|
||||
)
|
||||
mock_non_stream.return_value = fallback_response
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
response = agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert response.choices[0].message.content == "fallback after connection error"
|
||||
mock_non_stream.assert_called_once()
|
||||
|
||||
@patch("run_agent.AIAgent._interruptible_api_call")
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_fallback_error_propagates(self, mock_close, mock_create, mock_non_stream):
|
||||
"""When both streaming AND fallback fail, the fallback error propagates."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.side_effect = Exception("stream broke")
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
mock_non_stream.side_effect = Exception("Rate limit exceeded")
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
with pytest.raises(Exception, match="Rate limit exceeded"):
|
||||
agent._interruptible_streaming_api_call({})
|
||||
|
||||
|
||||
# ── Test: Reasoning Streaming ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestReasoningStreaming:
|
||||
"""Verify reasoning content is accumulated and callback fires."""
|
||||
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_reasoning_callback_fires(self, mock_close, mock_create):
|
||||
"""Reasoning deltas fire the reasoning_callback."""
|
||||
from run_agent import AIAgent
|
||||
|
||||
chunks = [
|
||||
_make_stream_chunk(reasoning_content="Let me think"),
|
||||
_make_stream_chunk(reasoning_content=" about this"),
|
||||
_make_stream_chunk(content="The answer is 42"),
|
||||
_make_stream_chunk(finish_reason="stop"),
|
||||
]
|
||||
|
||||
reasoning_deltas = []
|
||||
text_deltas = []
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.return_value = iter(chunks)
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
stream_delta_callback=lambda t: text_deltas.append(t),
|
||||
reasoning_callback=lambda t: reasoning_deltas.append(t),
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
response = agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert reasoning_deltas == ["Let me think", " about this"]
|
||||
assert text_deltas == ["The answer is 42"]
|
||||
assert response.choices[0].message.reasoning_content == "Let me think about this"
|
||||
assert response.choices[0].message.content == "The answer is 42"
|
||||
|
||||
|
||||
# ── Test: _has_stream_consumers ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestHasStreamConsumers:
|
||||
"""Verify _has_stream_consumers() detects registered callbacks."""
|
||||
|
||||
def test_no_consumers(self):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
assert agent._has_stream_consumers() is False
|
||||
|
||||
def test_delta_callback_set(self):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
stream_delta_callback=lambda t: None,
|
||||
)
|
||||
assert agent._has_stream_consumers() is True
|
||||
|
||||
def test_stream_callback_set(self):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent._stream_callback = lambda t: None
|
||||
assert agent._has_stream_consumers() is True
|
||||
|
||||
|
||||
# ── Test: Codex stream fires callbacks ────────────────────────────────
|
||||
|
||||
|
||||
class TestCodexStreamCallbacks:
|
||||
"""Verify _run_codex_stream fires delta callbacks."""
|
||||
|
||||
def test_codex_text_delta_fires_callback(self):
|
||||
from run_agent import AIAgent
|
||||
|
||||
deltas = []
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
stream_delta_callback=lambda t: deltas.append(t),
|
||||
)
|
||||
agent.api_mode = "codex_responses"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
# Mock the stream context manager
|
||||
mock_event_text = SimpleNamespace(
|
||||
type="response.output_text.delta",
|
||||
delta="Hello from Codex!",
|
||||
)
|
||||
mock_event_done = SimpleNamespace(
|
||||
type="response.completed",
|
||||
delta="",
|
||||
)
|
||||
|
||||
mock_stream = MagicMock()
|
||||
mock_stream.__enter__ = MagicMock(return_value=mock_stream)
|
||||
mock_stream.__exit__ = MagicMock(return_value=False)
|
||||
mock_stream.__iter__ = MagicMock(return_value=iter([mock_event_text, mock_event_done]))
|
||||
mock_stream.get_final_response.return_value = SimpleNamespace(
|
||||
output=[SimpleNamespace(
|
||||
type="message",
|
||||
content=[SimpleNamespace(type="output_text", text="Hello from Codex!")],
|
||||
)],
|
||||
status="completed",
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.responses.stream.return_value = mock_stream
|
||||
|
||||
response = agent._run_codex_stream({}, client=mock_client)
|
||||
assert "Hello from Codex!" in deltas
|
||||
@@ -1,10 +1,8 @@
|
||||
"""Tests for tools/checkpoint_manager.py — CheckpointManager."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
@@ -145,12 +143,6 @@ class TestTakeCheckpoint:
|
||||
result = mgr.ensure_checkpoint(str(work_dir), "initial")
|
||||
assert result is True
|
||||
|
||||
def test_successful_checkpoint_does_not_log_expected_diff_exit(self, mgr, work_dir, caplog):
|
||||
with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"):
|
||||
result = mgr.ensure_checkpoint(str(work_dir), "initial")
|
||||
assert result is True
|
||||
assert not any("diff --cached --quiet" in r.getMessage() for r in caplog.records)
|
||||
|
||||
def test_dedup_same_turn(self, mgr, work_dir):
|
||||
r1 = mgr.ensure_checkpoint(str(work_dir), "first")
|
||||
r2 = mgr.ensure_checkpoint(str(work_dir), "second")
|
||||
@@ -383,26 +375,6 @@ class TestErrorResilience:
|
||||
result = mgr.ensure_checkpoint(str(work_dir), "test")
|
||||
assert result is False
|
||||
|
||||
def test_run_git_allows_expected_nonzero_without_error_log(self, tmp_path, caplog):
|
||||
completed = subprocess.CompletedProcess(
|
||||
args=["git", "diff", "--cached", "--quiet"],
|
||||
returncode=1,
|
||||
stdout="",
|
||||
stderr="",
|
||||
)
|
||||
with patch("tools.checkpoint_manager.subprocess.run", return_value=completed):
|
||||
with caplog.at_level(logging.ERROR, logger="tools.checkpoint_manager"):
|
||||
ok, stdout, stderr = _run_git(
|
||||
["diff", "--cached", "--quiet"],
|
||||
tmp_path / "shadow",
|
||||
str(tmp_path / "work"),
|
||||
allowed_returncodes={1},
|
||||
)
|
||||
assert ok is False
|
||||
assert stdout == ""
|
||||
assert stderr == ""
|
||||
assert not caplog.records
|
||||
|
||||
def test_checkpoint_failure_does_not_raise(self, mgr, work_dir, monkeypatch):
|
||||
"""Checkpoint failures should never raise — they're silently logged."""
|
||||
def broken_run_git(*args, **kwargs):
|
||||
|
||||
@@ -1,31 +1,12 @@
|
||||
import logging
|
||||
from io import StringIO
|
||||
import subprocess
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments import docker as docker_env
|
||||
|
||||
|
||||
def _install_fake_minisweagent(monkeypatch, captured_run_args):
|
||||
class MockInnerDocker:
|
||||
container_id = "fake-container"
|
||||
config = type("Config", (), {"executable": "/usr/bin/docker", "forward_env": [], "env": {}})()
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
captured_run_args.extend(kwargs.get("run_args", []))
|
||||
|
||||
minisweagent_mod = types.ModuleType("minisweagent")
|
||||
environments_mod = types.ModuleType("minisweagent.environments")
|
||||
docker_mod = types.ModuleType("minisweagent.environments.docker")
|
||||
docker_mod.DockerEnvironment = MockInnerDocker
|
||||
|
||||
monkeypatch.setitem(sys.modules, "minisweagent", minisweagent_mod)
|
||||
monkeypatch.setitem(sys.modules, "minisweagent.environments", environments_mod)
|
||||
monkeypatch.setitem(sys.modules, "minisweagent.environments.docker", docker_mod)
|
||||
|
||||
|
||||
def _make_dummy_env(**kwargs):
|
||||
"""Helper to construct DockerEnvironment with minimal required args."""
|
||||
return docker_env.DockerEnvironment(
|
||||
@@ -39,8 +20,6 @@ def _make_dummy_env(**kwargs):
|
||||
task_id=kwargs.get("task_id", "test-task"),
|
||||
volumes=kwargs.get("volumes", []),
|
||||
network=kwargs.get("network", True),
|
||||
host_cwd=kwargs.get("host_cwd"),
|
||||
auto_mount_cwd=kwargs.get("auto_mount_cwd", False),
|
||||
)
|
||||
|
||||
|
||||
@@ -109,105 +88,63 @@ def test_ensure_docker_available_uses_resolved_executable(monkeypatch):
|
||||
]
|
||||
|
||||
|
||||
def test_auto_mount_host_cwd_adds_volume(monkeypatch, tmp_path):
|
||||
"""Opt-in docker cwd mounting should bind the host cwd to /workspace."""
|
||||
project_dir = tmp_path / "my-project"
|
||||
project_dir.mkdir()
|
||||
class _FakePopen:
|
||||
def __init__(self, cmd, **kwargs):
|
||||
self.cmd = cmd
|
||||
self.kwargs = kwargs
|
||||
self.stdout = StringIO("")
|
||||
self.stdin = None
|
||||
self.returncode = 0
|
||||
|
||||
def _run_docker_version(*args, **kwargs):
|
||||
return subprocess.CompletedProcess(args[0], 0, stdout="Docker version", stderr="")
|
||||
|
||||
monkeypatch.setattr(docker_env, "find_docker", lambda: "/usr/bin/docker")
|
||||
monkeypatch.setattr(docker_env.subprocess, "run", _run_docker_version)
|
||||
|
||||
captured_run_args = []
|
||||
_install_fake_minisweagent(monkeypatch, captured_run_args)
|
||||
|
||||
_make_dummy_env(
|
||||
cwd="/workspace",
|
||||
host_cwd=str(project_dir),
|
||||
auto_mount_cwd=True,
|
||||
)
|
||||
|
||||
run_args_str = " ".join(captured_run_args)
|
||||
assert f"{project_dir}:/workspace" in run_args_str
|
||||
def poll(self):
|
||||
return self.returncode
|
||||
|
||||
|
||||
def test_auto_mount_disabled_by_default(monkeypatch, tmp_path):
|
||||
"""Host cwd should not be mounted unless the caller explicitly opts in."""
|
||||
project_dir = tmp_path / "my-project"
|
||||
project_dir.mkdir()
|
||||
|
||||
def _run_docker_version(*args, **kwargs):
|
||||
return subprocess.CompletedProcess(args[0], 0, stdout="Docker version", stderr="")
|
||||
|
||||
monkeypatch.setattr(docker_env, "find_docker", lambda: "/usr/bin/docker")
|
||||
monkeypatch.setattr(docker_env.subprocess, "run", _run_docker_version)
|
||||
|
||||
captured_run_args = []
|
||||
_install_fake_minisweagent(monkeypatch, captured_run_args)
|
||||
|
||||
_make_dummy_env(
|
||||
cwd="/root",
|
||||
host_cwd=str(project_dir),
|
||||
auto_mount_cwd=False,
|
||||
)
|
||||
|
||||
run_args_str = " ".join(captured_run_args)
|
||||
assert f"{project_dir}:/workspace" not in run_args_str
|
||||
def _make_execute_only_env(forward_env=None):
|
||||
env = docker_env.DockerEnvironment.__new__(docker_env.DockerEnvironment)
|
||||
env.cwd = "/root"
|
||||
env.timeout = 60
|
||||
env._forward_env = forward_env or []
|
||||
env._prepare_command = lambda command: (command, None)
|
||||
env._timeout_result = lambda timeout: {"output": f"timed out after {timeout}", "returncode": 124}
|
||||
env._inner = type("Inner", (), {
|
||||
"container_id": "test-container",
|
||||
"config": type("Cfg", (), {"executable": "/usr/bin/docker", "env": {}})(),
|
||||
})()
|
||||
return env
|
||||
|
||||
|
||||
def test_auto_mount_skipped_when_workspace_already_mounted(monkeypatch, tmp_path):
|
||||
"""Explicit user volumes for /workspace should take precedence over cwd mount."""
|
||||
project_dir = tmp_path / "my-project"
|
||||
project_dir.mkdir()
|
||||
other_dir = tmp_path / "other"
|
||||
other_dir.mkdir()
|
||||
def test_execute_uses_hermes_dotenv_for_allowlisted_env(monkeypatch):
|
||||
env = _make_execute_only_env(["GITHUB_TOKEN"])
|
||||
popen_calls = []
|
||||
|
||||
def _run_docker_version(*args, **kwargs):
|
||||
return subprocess.CompletedProcess(args[0], 0, stdout="Docker version", stderr="")
|
||||
def _fake_popen(cmd, **kwargs):
|
||||
popen_calls.append(cmd)
|
||||
return _FakePopen(cmd, **kwargs)
|
||||
|
||||
monkeypatch.setattr(docker_env, "find_docker", lambda: "/usr/bin/docker")
|
||||
monkeypatch.setattr(docker_env.subprocess, "run", _run_docker_version)
|
||||
monkeypatch.delenv("GITHUB_TOKEN", raising=False)
|
||||
monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {"GITHUB_TOKEN": "value_from_dotenv"})
|
||||
monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen)
|
||||
|
||||
captured_run_args = []
|
||||
_install_fake_minisweagent(monkeypatch, captured_run_args)
|
||||
result = env.execute("echo hi")
|
||||
|
||||
_make_dummy_env(
|
||||
cwd="/workspace",
|
||||
host_cwd=str(project_dir),
|
||||
auto_mount_cwd=True,
|
||||
volumes=[f"{other_dir}:/workspace"],
|
||||
)
|
||||
|
||||
run_args_str = " ".join(captured_run_args)
|
||||
assert f"{other_dir}:/workspace" in run_args_str
|
||||
assert run_args_str.count(":/workspace") == 1
|
||||
assert result["returncode"] == 0
|
||||
assert "GITHUB_TOKEN=value_from_dotenv" in popen_calls[0]
|
||||
|
||||
|
||||
def test_auto_mount_replaces_persistent_workspace_bind(monkeypatch, tmp_path):
|
||||
"""Persistent mode should still prefer the configured host cwd at /workspace."""
|
||||
project_dir = tmp_path / "my-project"
|
||||
project_dir.mkdir()
|
||||
def test_execute_prefers_shell_env_over_hermes_dotenv(monkeypatch):
|
||||
env = _make_execute_only_env(["GITHUB_TOKEN"])
|
||||
popen_calls = []
|
||||
|
||||
def _run_docker_version(*args, **kwargs):
|
||||
return subprocess.CompletedProcess(args[0], 0, stdout="Docker version", stderr="")
|
||||
def _fake_popen(cmd, **kwargs):
|
||||
popen_calls.append(cmd)
|
||||
return _FakePopen(cmd, **kwargs)
|
||||
|
||||
monkeypatch.setattr(docker_env, "find_docker", lambda: "/usr/bin/docker")
|
||||
monkeypatch.setattr(docker_env.subprocess, "run", _run_docker_version)
|
||||
monkeypatch.setenv("GITHUB_TOKEN", "value_from_shell")
|
||||
monkeypatch.setattr(docker_env, "_load_hermes_env_vars", lambda: {"GITHUB_TOKEN": "value_from_dotenv"})
|
||||
monkeypatch.setattr(docker_env.subprocess, "Popen", _fake_popen)
|
||||
|
||||
captured_run_args = []
|
||||
_install_fake_minisweagent(monkeypatch, captured_run_args)
|
||||
|
||||
_make_dummy_env(
|
||||
cwd="/workspace",
|
||||
persistent_filesystem=True,
|
||||
host_cwd=str(project_dir),
|
||||
auto_mount_cwd=True,
|
||||
task_id="test-persistent-auto-mount",
|
||||
)
|
||||
|
||||
run_args_str = " ".join(captured_run_args)
|
||||
assert f"{project_dir}:/workspace" in run_args_str
|
||||
assert "/sandboxes/docker/test-persistent-auto-mount/workspace:/workspace" not in run_args_str
|
||||
env.execute("echo hi")
|
||||
|
||||
assert "GITHUB_TOKEN=value_from_shell" in popen_calls[0]
|
||||
assert "GITHUB_TOKEN=value_from_dotenv" not in popen_calls[0]
|
||||
|
||||
@@ -5,7 +5,6 @@ handling without requiring a running terminal environment.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tools.file_tools import (
|
||||
@@ -88,26 +87,13 @@ class TestWriteFileHandler:
|
||||
mock_ops.write_file.assert_called_once_with("/tmp/out.txt", "hello world!\n")
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_permission_error_returns_error_json_without_error_log(self, mock_get, caplog):
|
||||
def test_exception_returns_error_json(self, mock_get):
|
||||
mock_get.side_effect = PermissionError("read-only filesystem")
|
||||
|
||||
from tools.file_tools import write_file_tool
|
||||
with caplog.at_level(logging.DEBUG, logger="tools.file_tools"):
|
||||
result = json.loads(write_file_tool("/tmp/out.txt", "data"))
|
||||
result = json.loads(write_file_tool("/tmp/out.txt", "data"))
|
||||
assert "error" in result
|
||||
assert "read-only" in result["error"]
|
||||
assert any("write_file expected denial" in r.getMessage() for r in caplog.records)
|
||||
assert not any(r.levelno >= logging.ERROR for r in caplog.records)
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
|
||||
def test_unexpected_exception_still_logs_error(self, mock_get, caplog):
|
||||
mock_get.side_effect = RuntimeError("boom")
|
||||
|
||||
from tools.file_tools import write_file_tool
|
||||
with caplog.at_level(logging.ERROR, logger="tools.file_tools"):
|
||||
result = json.loads(write_file_tool("/tmp/out.txt", "data"))
|
||||
assert result["error"] == "boom"
|
||||
assert any("write_file error" in r.getMessage() for r in caplog.records)
|
||||
|
||||
|
||||
class TestPatchHandler:
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
"""Regression tests for per-call Honcho tool session routing."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from tools import honcho_tools
|
||||
|
||||
|
||||
class TestHonchoToolSessionContext:
|
||||
def setup_method(self):
|
||||
self.orig_manager = honcho_tools._session_manager
|
||||
self.orig_key = honcho_tools._session_key
|
||||
|
||||
def teardown_method(self):
|
||||
honcho_tools._session_manager = self.orig_manager
|
||||
honcho_tools._session_key = self.orig_key
|
||||
|
||||
def test_explicit_call_context_wins_over_module_global_state(self):
|
||||
global_manager = MagicMock()
|
||||
global_manager.get_peer_card.return_value = ["global"]
|
||||
explicit_manager = MagicMock()
|
||||
explicit_manager.get_peer_card.return_value = ["explicit"]
|
||||
|
||||
honcho_tools.set_session_context(global_manager, "global-session")
|
||||
|
||||
result = json.loads(
|
||||
honcho_tools._handle_honcho_profile(
|
||||
{},
|
||||
honcho_manager=explicit_manager,
|
||||
honcho_session_key="explicit-session",
|
||||
)
|
||||
)
|
||||
|
||||
assert result == {"result": ["explicit"]}
|
||||
explicit_manager.get_peer_card.assert_called_once_with("explicit-session")
|
||||
global_manager.get_peer_card.assert_not_called()
|
||||
@@ -26,7 +26,8 @@ def _make_fake_popen(captured: dict):
|
||||
proc = MagicMock()
|
||||
proc.poll.return_value = 0
|
||||
proc.returncode = 0
|
||||
proc.stdout = MagicMock(__iter__=lambda s: iter([]), __next__=lambda s: (_ for _ in ()).throw(StopIteration))
|
||||
proc.stdout = iter([])
|
||||
proc.stdout.close = lambda: None
|
||||
proc.stdin = MagicMock()
|
||||
return proc
|
||||
return fake_popen
|
||||
@@ -85,7 +86,6 @@ class TestProviderEnvBlocklist:
|
||||
"KIMI_API_KEY": "kimi-key",
|
||||
"MINIMAX_API_KEY": "mm-key",
|
||||
"MINIMAX_CN_API_KEY": "mmcn-key",
|
||||
"DEEPSEEK_API_KEY": "deepseek-key",
|
||||
}
|
||||
result_env = _run_with_env(extra_os_env=registry_vars)
|
||||
|
||||
@@ -96,6 +96,7 @@ class TestProviderEnvBlocklist:
|
||||
"""Extra provider vars not in PROVIDER_REGISTRY must also be blocked."""
|
||||
extra_provider_vars = {
|
||||
"GOOGLE_API_KEY": "google-key",
|
||||
"DEEPSEEK_API_KEY": "deepseek-key",
|
||||
"MISTRAL_API_KEY": "mistral-key",
|
||||
"GROQ_API_KEY": "groq-key",
|
||||
"TOGETHER_API_KEY": "together-key",
|
||||
|
||||
@@ -1,152 +0,0 @@
|
||||
"""Tests for the local persistent shell backend."""
|
||||
|
||||
import glob as glob_mod
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments.local import LocalEnvironment
|
||||
from tools.environments.persistent_shell import PersistentShellMixin
|
||||
|
||||
|
||||
class TestLocalConfig:
|
||||
def test_local_persistent_default_false(self, monkeypatch):
|
||||
monkeypatch.delenv("TERMINAL_LOCAL_PERSISTENT", raising=False)
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["local_persistent"] is False
|
||||
|
||||
def test_local_persistent_true(self, monkeypatch):
|
||||
monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "true")
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["local_persistent"] is True
|
||||
|
||||
def test_local_persistent_yes(self, monkeypatch):
|
||||
monkeypatch.setenv("TERMINAL_LOCAL_PERSISTENT", "yes")
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["local_persistent"] is True
|
||||
|
||||
|
||||
class TestMergeOutput:
|
||||
def test_stdout_only(self):
|
||||
assert PersistentShellMixin._merge_output("out", "") == "out"
|
||||
|
||||
def test_stderr_only(self):
|
||||
assert PersistentShellMixin._merge_output("", "err") == "err"
|
||||
|
||||
def test_both(self):
|
||||
assert PersistentShellMixin._merge_output("out", "err") == "out\nerr"
|
||||
|
||||
def test_empty(self):
|
||||
assert PersistentShellMixin._merge_output("", "") == ""
|
||||
|
||||
def test_strips_trailing_newlines(self):
|
||||
assert PersistentShellMixin._merge_output("out\n\n", "err\n") == "out\nerr"
|
||||
|
||||
|
||||
class TestLocalOneShotRegression:
|
||||
def test_echo(self):
|
||||
env = LocalEnvironment(persistent=False)
|
||||
r = env.execute("echo hello")
|
||||
assert r["returncode"] == 0
|
||||
assert "hello" in r["output"]
|
||||
env.cleanup()
|
||||
|
||||
def test_exit_code(self):
|
||||
env = LocalEnvironment(persistent=False)
|
||||
r = env.execute("exit 42")
|
||||
assert r["returncode"] == 42
|
||||
env.cleanup()
|
||||
|
||||
def test_state_does_not_persist(self):
|
||||
env = LocalEnvironment(persistent=False)
|
||||
env.execute("export HERMES_ONESHOT_LOCAL=yes")
|
||||
r = env.execute("echo $HERMES_ONESHOT_LOCAL")
|
||||
assert r["output"].strip() == ""
|
||||
env.cleanup()
|
||||
|
||||
|
||||
class TestLocalPersistent:
|
||||
@pytest.fixture
|
||||
def env(self):
|
||||
e = LocalEnvironment(persistent=True)
|
||||
yield e
|
||||
e.cleanup()
|
||||
|
||||
def test_echo(self, env):
|
||||
r = env.execute("echo hello-persistent")
|
||||
assert r["returncode"] == 0
|
||||
assert "hello-persistent" in r["output"]
|
||||
|
||||
def test_env_var_persists(self, env):
|
||||
env.execute("export HERMES_LOCAL_PERSIST_TEST=works")
|
||||
r = env.execute("echo $HERMES_LOCAL_PERSIST_TEST")
|
||||
assert r["output"].strip() == "works"
|
||||
|
||||
def test_cwd_persists(self, env):
|
||||
env.execute("cd /tmp")
|
||||
r = env.execute("pwd")
|
||||
assert r["output"].strip() == "/tmp"
|
||||
|
||||
def test_exit_code(self, env):
|
||||
r = env.execute("(exit 42)")
|
||||
assert r["returncode"] == 42
|
||||
|
||||
def test_stderr(self, env):
|
||||
r = env.execute("echo oops >&2")
|
||||
assert r["returncode"] == 0
|
||||
assert "oops" in r["output"]
|
||||
|
||||
def test_multiline_output(self, env):
|
||||
r = env.execute("echo a; echo b; echo c")
|
||||
lines = r["output"].strip().splitlines()
|
||||
assert lines == ["a", "b", "c"]
|
||||
|
||||
def test_timeout_then_recovery(self, env):
|
||||
r = env.execute("sleep 999", timeout=2)
|
||||
assert r["returncode"] in (124, 130)
|
||||
r = env.execute("echo alive")
|
||||
assert r["returncode"] == 0
|
||||
assert "alive" in r["output"]
|
||||
|
||||
def test_large_output(self, env):
|
||||
r = env.execute("seq 1 1000")
|
||||
assert r["returncode"] == 0
|
||||
lines = r["output"].strip().splitlines()
|
||||
assert len(lines) == 1000
|
||||
assert lines[0] == "1"
|
||||
assert lines[-1] == "1000"
|
||||
|
||||
def test_shell_variable_persists(self, env):
|
||||
env.execute("MY_LOCAL_VAR=hello123")
|
||||
r = env.execute("echo $MY_LOCAL_VAR")
|
||||
assert r["output"].strip() == "hello123"
|
||||
|
||||
def test_cleanup_removes_temp_files(self, env):
|
||||
env.execute("echo warmup")
|
||||
prefix = env._temp_prefix
|
||||
assert len(glob_mod.glob(f"{prefix}-*")) > 0
|
||||
env.cleanup()
|
||||
remaining = glob_mod.glob(f"{prefix}-*")
|
||||
assert remaining == []
|
||||
|
||||
def test_state_does_not_leak_between_instances(self):
|
||||
env1 = LocalEnvironment(persistent=True)
|
||||
env2 = LocalEnvironment(persistent=True)
|
||||
try:
|
||||
env1.execute("export LEAK_TEST=from_env1")
|
||||
r = env2.execute("echo $LEAK_TEST")
|
||||
assert r["output"].strip() == ""
|
||||
finally:
|
||||
env1.cleanup()
|
||||
env2.cleanup()
|
||||
|
||||
def test_special_characters_in_command(self, env):
|
||||
r = env.execute("echo 'hello world'")
|
||||
assert r["output"].strip() == "hello world"
|
||||
|
||||
def test_pipe_command(self, env):
|
||||
r = env.execute("echo hello | tr 'h' 'H'")
|
||||
assert r["output"].strip() == "Hello"
|
||||
|
||||
def test_multiple_commands_semicolon(self, env):
|
||||
r = env.execute("X=42; echo $X")
|
||||
assert r["output"].strip() == "42"
|
||||
@@ -91,8 +91,8 @@ class TestCwdHandling:
|
||||
"/home/ paths should be replaced for modal backend."
|
||||
)
|
||||
|
||||
def test_users_path_replaced_for_docker_by_default(self):
|
||||
"""Docker should keep host paths out of the sandbox unless explicitly enabled."""
|
||||
def test_users_path_replaced_for_docker(self):
|
||||
"""TERMINAL_CWD=/Users/... should be replaced with /root for docker."""
|
||||
with patch.dict(os.environ, {
|
||||
"TERMINAL_ENV": "docker",
|
||||
"TERMINAL_CWD": "/Users/someone/projects",
|
||||
@@ -100,22 +100,8 @@ class TestCwdHandling:
|
||||
config = _tt_mod._get_env_config()
|
||||
assert config["cwd"] == "/root", (
|
||||
f"Expected /root, got {config['cwd']}. "
|
||||
"Host paths should be discarded for docker backend by default."
|
||||
"/Users/ paths should be replaced for docker backend."
|
||||
)
|
||||
assert config["host_cwd"] is None
|
||||
assert config["docker_mount_cwd_to_workspace"] is False
|
||||
|
||||
def test_users_path_maps_to_workspace_for_docker_when_enabled(self):
|
||||
"""Docker should map the host cwd into /workspace only when explicitly enabled."""
|
||||
with patch.dict(os.environ, {
|
||||
"TERMINAL_ENV": "docker",
|
||||
"TERMINAL_CWD": "/Users/someone/projects",
|
||||
"TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE": "true",
|
||||
}):
|
||||
config = _tt_mod._get_env_config()
|
||||
assert config["cwd"] == "/workspace"
|
||||
assert config["host_cwd"] == "/Users/someone/projects"
|
||||
assert config["docker_mount_cwd_to_workspace"] is True
|
||||
|
||||
def test_windows_path_replaced_for_modal(self):
|
||||
"""TERMINAL_CWD=C:\\Users\\... should be replaced for modal."""
|
||||
@@ -133,27 +119,12 @@ class TestCwdHandling:
|
||||
# Remove TERMINAL_CWD so it uses default
|
||||
env = os.environ.copy()
|
||||
env.pop("TERMINAL_CWD", None)
|
||||
env.pop("TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE", None)
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
config = _tt_mod._get_env_config()
|
||||
assert config["cwd"] == "/root", (
|
||||
f"Backend {backend}: expected /root default, got {config['cwd']}"
|
||||
)
|
||||
|
||||
def test_docker_default_cwd_maps_current_directory_when_enabled(self):
|
||||
"""Docker should use /workspace when cwd mounting is explicitly enabled."""
|
||||
with patch("tools.terminal_tool.os.getcwd", return_value="/home/user/project"):
|
||||
with patch.dict(os.environ, {
|
||||
"TERMINAL_ENV": "docker",
|
||||
"TERMINAL_DOCKER_MOUNT_CWD_TO_WORKSPACE": "true",
|
||||
}, clear=False):
|
||||
env = os.environ.copy()
|
||||
env.pop("TERMINAL_CWD", None)
|
||||
with patch.dict(os.environ, env, clear=True):
|
||||
config = _tt_mod._get_env_config()
|
||||
assert config["cwd"] == "/workspace"
|
||||
assert config["host_cwd"] == "/home/user/project"
|
||||
|
||||
def test_local_backend_uses_getcwd(self):
|
||||
"""Local backend should use os.getcwd(), not /root."""
|
||||
with patch.dict(os.environ, {"TERMINAL_ENV": "local"}, clear=False):
|
||||
@@ -163,31 +134,6 @@ class TestCwdHandling:
|
||||
config = _tt_mod._get_env_config()
|
||||
assert config["cwd"] == os.getcwd()
|
||||
|
||||
def test_create_environment_passes_docker_host_cwd_and_flag(self, monkeypatch):
|
||||
"""Docker host cwd and mount flag should reach DockerEnvironment."""
|
||||
captured = {}
|
||||
sentinel = object()
|
||||
|
||||
def _fake_docker_environment(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return sentinel
|
||||
|
||||
monkeypatch.setattr(_tt_mod, "_DockerEnvironment", _fake_docker_environment)
|
||||
|
||||
env = _tt_mod._create_environment(
|
||||
env_type="docker",
|
||||
image="python:3.11",
|
||||
cwd="/workspace",
|
||||
timeout=60,
|
||||
container_config={"docker_mount_cwd_to_workspace": True},
|
||||
host_cwd="/home/user/project",
|
||||
)
|
||||
|
||||
assert env is sentinel
|
||||
assert captured["cwd"] == "/workspace"
|
||||
assert captured["host_cwd"] == "/home/user/project"
|
||||
assert captured["auto_mount_cwd"] is True
|
||||
|
||||
def test_ssh_preserves_home_paths(self):
|
||||
"""SSH backend should NOT replace /home/ paths (they're valid remotely)."""
|
||||
with patch.dict(os.environ, {
|
||||
|
||||
@@ -30,6 +30,28 @@ class TestParseEnvVar:
|
||||
result = _parse_env_var("TERMINAL_DOCKER_VOLUMES", "[]", json.loads, "valid JSON")
|
||||
assert result == ["/host:/container"]
|
||||
|
||||
def test_get_env_config_parses_docker_forward_env_json(self):
|
||||
with patch.dict("os.environ", {
|
||||
"TERMINAL_ENV": "docker",
|
||||
"TERMINAL_DOCKER_FORWARD_ENV": '["GITHUB_TOKEN", "NPM_TOKEN"]',
|
||||
}, clear=False):
|
||||
config = _tt_mod._get_env_config()
|
||||
assert config["docker_forward_env"] == ["GITHUB_TOKEN", "NPM_TOKEN"]
|
||||
|
||||
def test_create_environment_passes_docker_forward_env(self):
|
||||
fake_env = object()
|
||||
with patch.object(_tt_mod, "_DockerEnvironment", return_value=fake_env) as mock_docker:
|
||||
result = _tt_mod._create_environment(
|
||||
"docker",
|
||||
image="python:3.11",
|
||||
cwd="/root",
|
||||
timeout=180,
|
||||
container_config={"docker_forward_env": ["GITHUB_TOKEN"]},
|
||||
)
|
||||
|
||||
assert result is fake_env
|
||||
assert mock_docker.call_args.kwargs["forward_env"] == ["GITHUB_TOKEN"]
|
||||
|
||||
def test_falls_back_to_default(self):
|
||||
with patch.dict("os.environ", {}, clear=False):
|
||||
# Remove the var if it exists, rely on default
|
||||
|
||||
@@ -232,48 +232,6 @@ class TestCheckFnExceptionHandling:
|
||||
assert any(u["name"] == "crashes" for u in unavailable)
|
||||
|
||||
|
||||
class TestEmojiMetadata:
|
||||
"""Verify per-tool emoji registration and lookup."""
|
||||
|
||||
def test_emoji_stored_on_entry(self):
|
||||
reg = ToolRegistry()
|
||||
reg.register(
|
||||
name="t", toolset="s", schema=_make_schema(),
|
||||
handler=_dummy_handler, emoji="🔥",
|
||||
)
|
||||
assert reg._tools["t"].emoji == "🔥"
|
||||
|
||||
def test_get_emoji_returns_registered(self):
|
||||
reg = ToolRegistry()
|
||||
reg.register(
|
||||
name="t", toolset="s", schema=_make_schema(),
|
||||
handler=_dummy_handler, emoji="🎯",
|
||||
)
|
||||
assert reg.get_emoji("t") == "🎯"
|
||||
|
||||
def test_get_emoji_returns_default_when_unset(self):
|
||||
reg = ToolRegistry()
|
||||
reg.register(
|
||||
name="t", toolset="s", schema=_make_schema(),
|
||||
handler=_dummy_handler,
|
||||
)
|
||||
assert reg.get_emoji("t") == "⚡"
|
||||
assert reg.get_emoji("t", default="🔧") == "🔧"
|
||||
|
||||
def test_get_emoji_returns_default_for_unknown_tool(self):
|
||||
reg = ToolRegistry()
|
||||
assert reg.get_emoji("nonexistent") == "⚡"
|
||||
assert reg.get_emoji("nonexistent", default="❓") == "❓"
|
||||
|
||||
def test_emoji_empty_string_treated_as_unset(self):
|
||||
reg = ToolRegistry()
|
||||
reg.register(
|
||||
name="t", toolset="s", schema=_make_schema(),
|
||||
handler=_dummy_handler, emoji="",
|
||||
)
|
||||
assert reg.get_emoji("t") == "⚡"
|
||||
|
||||
|
||||
class TestSecretCaptureResultContract:
|
||||
def test_secret_request_result_does_not_include_secret_value(self):
|
||||
result = {
|
||||
|
||||
@@ -1,218 +0,0 @@
|
||||
"""Tests for the SSH remote execution environment backend."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.environments.ssh import SSHEnvironment
|
||||
from tools.environments import ssh as ssh_env
|
||||
|
||||
_SSH_HOST = os.getenv("TERMINAL_SSH_HOST", "")
|
||||
_SSH_USER = os.getenv("TERMINAL_SSH_USER", "")
|
||||
_SSH_PORT = int(os.getenv("TERMINAL_SSH_PORT", "22"))
|
||||
_SSH_KEY = os.getenv("TERMINAL_SSH_KEY", "")
|
||||
|
||||
_has_ssh = bool(_SSH_HOST and _SSH_USER)
|
||||
|
||||
requires_ssh = pytest.mark.skipif(
|
||||
not _has_ssh,
|
||||
reason="TERMINAL_SSH_HOST / TERMINAL_SSH_USER not set",
|
||||
)
|
||||
|
||||
|
||||
def _run(command, task_id="ssh_test", **kwargs):
|
||||
from tools.terminal_tool import terminal_tool
|
||||
return json.loads(terminal_tool(command, task_id=task_id, **kwargs))
|
||||
|
||||
|
||||
def _cleanup(task_id="ssh_test"):
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
cleanup_vm(task_id)
|
||||
|
||||
|
||||
class TestBuildSSHCommand:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_connection(self, monkeypatch):
|
||||
monkeypatch.setattr("tools.environments.ssh.subprocess.run",
|
||||
lambda *a, **k: subprocess.CompletedProcess([], 0))
|
||||
monkeypatch.setattr("tools.environments.ssh.subprocess.Popen",
|
||||
lambda *a, **k: MagicMock(stdout=iter([]),
|
||||
stderr=iter([]),
|
||||
stdin=MagicMock()))
|
||||
monkeypatch.setattr("tools.environments.ssh.time.sleep", lambda _: None)
|
||||
|
||||
def test_base_flags(self):
|
||||
env = SSHEnvironment(host="h", user="u")
|
||||
cmd = " ".join(env._build_ssh_command())
|
||||
for flag in ("ControlMaster=auto", "ControlPersist=300",
|
||||
"BatchMode=yes", "StrictHostKeyChecking=accept-new"):
|
||||
assert flag in cmd
|
||||
|
||||
def test_custom_port(self):
|
||||
env = SSHEnvironment(host="h", user="u", port=2222)
|
||||
cmd = env._build_ssh_command()
|
||||
assert "-p" in cmd and "2222" in cmd
|
||||
|
||||
def test_key_path(self):
|
||||
env = SSHEnvironment(host="h", user="u", key_path="/k")
|
||||
cmd = env._build_ssh_command()
|
||||
assert "-i" in cmd and "/k" in cmd
|
||||
|
||||
def test_user_host_suffix(self):
|
||||
env = SSHEnvironment(host="h", user="u")
|
||||
assert env._build_ssh_command()[-1] == "u@h"
|
||||
|
||||
|
||||
class TestTerminalToolConfig:
|
||||
def test_ssh_persistent_default_true(self, monkeypatch):
|
||||
"""SSH persistent defaults to True (via TERMINAL_PERSISTENT_SHELL)."""
|
||||
monkeypatch.delenv("TERMINAL_SSH_PERSISTENT", raising=False)
|
||||
monkeypatch.delenv("TERMINAL_PERSISTENT_SHELL", raising=False)
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["ssh_persistent"] is True
|
||||
|
||||
def test_ssh_persistent_explicit_false(self, monkeypatch):
|
||||
"""Per-backend env var overrides the global default."""
|
||||
monkeypatch.setenv("TERMINAL_SSH_PERSISTENT", "false")
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["ssh_persistent"] is False
|
||||
|
||||
def test_ssh_persistent_explicit_true(self, monkeypatch):
|
||||
monkeypatch.setenv("TERMINAL_SSH_PERSISTENT", "true")
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["ssh_persistent"] is True
|
||||
|
||||
def test_ssh_persistent_respects_config(self, monkeypatch):
|
||||
"""TERMINAL_PERSISTENT_SHELL=false disables SSH persistent by default."""
|
||||
monkeypatch.delenv("TERMINAL_SSH_PERSISTENT", raising=False)
|
||||
monkeypatch.setenv("TERMINAL_PERSISTENT_SHELL", "false")
|
||||
from tools.terminal_tool import _get_env_config
|
||||
assert _get_env_config()["ssh_persistent"] is False
|
||||
|
||||
|
||||
class TestSSHPreflight:
|
||||
def test_ensure_ssh_available_raises_clear_error_when_missing(self, monkeypatch):
|
||||
monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: None)
|
||||
|
||||
with pytest.raises(RuntimeError, match="SSH is not installed or not in PATH"):
|
||||
ssh_env._ensure_ssh_available()
|
||||
|
||||
def test_ssh_environment_checks_availability_before_connect(self, monkeypatch):
|
||||
monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: None)
|
||||
monkeypatch.setattr(
|
||||
ssh_env.SSHEnvironment,
|
||||
"_establish_connection",
|
||||
lambda self: pytest.fail("_establish_connection should not run when ssh is missing"),
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="openssh-client"):
|
||||
ssh_env.SSHEnvironment(host="example.com", user="alice")
|
||||
|
||||
def test_ssh_environment_connects_when_ssh_exists(self, monkeypatch):
|
||||
called = {"count": 0}
|
||||
|
||||
monkeypatch.setattr(ssh_env.shutil, "which", lambda _name: "/usr/bin/ssh")
|
||||
|
||||
def _fake_establish(self):
|
||||
called["count"] += 1
|
||||
|
||||
monkeypatch.setattr(ssh_env.SSHEnvironment, "_establish_connection", _fake_establish)
|
||||
|
||||
env = ssh_env.SSHEnvironment(host="example.com", user="alice")
|
||||
|
||||
assert called["count"] == 1
|
||||
assert env.host == "example.com"
|
||||
assert env.user == "alice"
|
||||
|
||||
|
||||
def _setup_ssh_env(monkeypatch, persistent: bool):
|
||||
monkeypatch.setenv("TERMINAL_ENV", "ssh")
|
||||
monkeypatch.setenv("TERMINAL_SSH_HOST", _SSH_HOST)
|
||||
monkeypatch.setenv("TERMINAL_SSH_USER", _SSH_USER)
|
||||
monkeypatch.setenv("TERMINAL_SSH_PERSISTENT", "true" if persistent else "false")
|
||||
if _SSH_PORT != 22:
|
||||
monkeypatch.setenv("TERMINAL_SSH_PORT", str(_SSH_PORT))
|
||||
if _SSH_KEY:
|
||||
monkeypatch.setenv("TERMINAL_SSH_KEY", _SSH_KEY)
|
||||
|
||||
|
||||
@requires_ssh
|
||||
class TestOneShotSSH:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup(self, monkeypatch):
|
||||
_setup_ssh_env(monkeypatch, persistent=False)
|
||||
yield
|
||||
_cleanup()
|
||||
|
||||
def test_echo(self):
|
||||
r = _run("echo hello")
|
||||
assert r["exit_code"] == 0
|
||||
assert "hello" in r["output"]
|
||||
|
||||
def test_exit_code(self):
|
||||
r = _run("exit 42")
|
||||
assert r["exit_code"] == 42
|
||||
|
||||
def test_state_does_not_persist(self):
|
||||
_run("export HERMES_ONESHOT_TEST=yes")
|
||||
r = _run("echo $HERMES_ONESHOT_TEST")
|
||||
assert r["output"].strip() == ""
|
||||
|
||||
|
||||
@requires_ssh
|
||||
class TestPersistentSSH:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup(self, monkeypatch):
|
||||
_setup_ssh_env(monkeypatch, persistent=True)
|
||||
yield
|
||||
_cleanup()
|
||||
|
||||
def test_echo(self):
|
||||
r = _run("echo hello-persistent")
|
||||
assert r["exit_code"] == 0
|
||||
assert "hello-persistent" in r["output"]
|
||||
|
||||
def test_env_var_persists(self):
|
||||
_run("export HERMES_PERSIST_TEST=works")
|
||||
r = _run("echo $HERMES_PERSIST_TEST")
|
||||
assert r["output"].strip() == "works"
|
||||
|
||||
def test_cwd_persists(self):
|
||||
_run("cd /tmp")
|
||||
r = _run("pwd")
|
||||
assert r["output"].strip() == "/tmp"
|
||||
|
||||
def test_exit_code(self):
|
||||
r = _run("(exit 42)")
|
||||
assert r["exit_code"] == 42
|
||||
|
||||
def test_stderr(self):
|
||||
r = _run("echo oops >&2")
|
||||
assert r["exit_code"] == 0
|
||||
assert "oops" in r["output"]
|
||||
|
||||
def test_multiline_output(self):
|
||||
r = _run("echo a; echo b; echo c")
|
||||
lines = r["output"].strip().splitlines()
|
||||
assert lines == ["a", "b", "c"]
|
||||
|
||||
def test_timeout_then_recovery(self):
|
||||
r = _run("sleep 999", timeout=2)
|
||||
assert r["exit_code"] == 124
|
||||
r = _run("echo alive")
|
||||
assert r["exit_code"] == 0
|
||||
assert "alive" in r["output"]
|
||||
|
||||
def test_large_output(self):
|
||||
r = _run("seq 1 1000")
|
||||
assert r["exit_code"] == 0
|
||||
lines = r["output"].strip().splitlines()
|
||||
assert len(lines) == 1000
|
||||
assert lines[0] == "1"
|
||||
assert lines[-1] == "1000"
|
||||
@@ -315,23 +315,6 @@ class TestEnsureInstalled:
|
||||
mock_thread.start.assert_called_once()
|
||||
_tirith_mod._resolved_path = None
|
||||
|
||||
@patch("tools.tirith_security._load_security_config")
|
||||
def test_startup_prefetch_can_suppress_install_failure_logs(self, mock_cfg):
|
||||
mock_cfg.return_value = {"tirith_enabled": True, "tirith_path": "tirith",
|
||||
"tirith_timeout": 5, "tirith_fail_open": True}
|
||||
_tirith_mod._resolved_path = None
|
||||
with patch("tools.tirith_security.shutil.which", return_value=None), \
|
||||
patch("tools.tirith_security._hermes_bin_dir", return_value="/nonexistent"), \
|
||||
patch("tools.tirith_security._is_install_failed_on_disk", return_value=False), \
|
||||
patch("tools.tirith_security.threading.Thread") as MockThread:
|
||||
mock_thread = MagicMock()
|
||||
MockThread.return_value = mock_thread
|
||||
result = ensure_installed(log_failures=False)
|
||||
assert result is None
|
||||
assert MockThread.call_args.kwargs["kwargs"] == {"log_failures": False}
|
||||
mock_thread.start.assert_called_once()
|
||||
_tirith_mod._resolved_path = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Failed download caches the miss (Finding #1)
|
||||
@@ -533,22 +516,6 @@ class TestCosignVerification:
|
||||
assert path is None
|
||||
assert reason == "cosign_missing"
|
||||
|
||||
@patch("tools.tirith_security.logger.debug")
|
||||
@patch("tools.tirith_security.logger.warning")
|
||||
@patch("tools.tirith_security.shutil.which", return_value=None)
|
||||
@patch("tools.tirith_security._download_file")
|
||||
@patch("tools.tirith_security._detect_target", return_value="aarch64-apple-darwin")
|
||||
def test_install_quiet_mode_downgrades_cosign_missing_log(self, mock_target, mock_dl,
|
||||
mock_which, mock_warning,
|
||||
mock_debug):
|
||||
"""Startup prefetch should not surface cosign-missing as a warning."""
|
||||
from tools.tirith_security import _install_tirith
|
||||
path, reason = _install_tirith(log_failures=False)
|
||||
assert path is None
|
||||
assert reason == "cosign_missing"
|
||||
mock_warning.assert_not_called()
|
||||
mock_debug.assert_called()
|
||||
|
||||
@patch("tools.tirith_security._verify_cosign", return_value=None)
|
||||
@patch("tools.tirith_security.shutil.which", return_value="/usr/local/bin/cosign")
|
||||
@patch("tools.tirith_security._download_file")
|
||||
|
||||
@@ -7,7 +7,6 @@ end-to-end dispatch. All external dependencies are mocked.
|
||||
|
||||
import os
|
||||
import struct
|
||||
import subprocess
|
||||
import wave
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@@ -46,10 +45,7 @@ def sample_ogg(tmp_path):
|
||||
def clean_env(monkeypatch):
|
||||
"""Ensure no real API keys leak into tests."""
|
||||
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
monkeypatch.delenv("HERMES_LOCAL_STT_COMMAND", raising=False)
|
||||
monkeypatch.delenv("HERMES_LOCAL_STT_LANGUAGE", raising=False)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -136,19 +132,6 @@ class TestGetProviderFallbackPriority:
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({}) == "local"
|
||||
|
||||
def test_openai_fallback_to_local_command(self, monkeypatch):
|
||||
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
monkeypatch.setenv(
|
||||
"HERMES_LOCAL_STT_COMMAND",
|
||||
"whisper {input_path} --output_dir {output_dir} --language {language}",
|
||||
)
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({"provider": "openai"}) == "local_command"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _transcribe_groq
|
||||
@@ -296,63 +279,6 @@ class TestTranscribeOpenAIExtended:
|
||||
assert "Permission denied" in result["error"]
|
||||
|
||||
|
||||
class TestTranscribeLocalCommand:
|
||||
def test_auto_detects_local_whisper_binary(self, monkeypatch):
|
||||
monkeypatch.delenv("HERMES_LOCAL_STT_COMMAND", raising=False)
|
||||
monkeypatch.setattr("tools.transcription_tools._find_whisper_binary", lambda: "/opt/homebrew/bin/whisper")
|
||||
|
||||
from tools.transcription_tools import _get_local_command_template
|
||||
|
||||
template = _get_local_command_template()
|
||||
|
||||
assert template is not None
|
||||
assert template.startswith("/opt/homebrew/bin/whisper ")
|
||||
assert "{model}" in template
|
||||
assert "{output_dir}" in template
|
||||
|
||||
def test_command_fallback_with_template(self, monkeypatch, sample_ogg, tmp_path):
|
||||
out_dir = tmp_path / "local-out"
|
||||
out_dir.mkdir()
|
||||
|
||||
monkeypatch.setenv(
|
||||
"HERMES_LOCAL_STT_COMMAND",
|
||||
"whisper {input_path} --model {model} --output_dir {output_dir} --language {language}",
|
||||
)
|
||||
monkeypatch.setenv("HERMES_LOCAL_STT_LANGUAGE", "en")
|
||||
|
||||
def fake_tempdir(prefix=None):
|
||||
class _TempDir:
|
||||
def __enter__(self_inner):
|
||||
return str(out_dir)
|
||||
|
||||
def __exit__(self_inner, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
return _TempDir()
|
||||
|
||||
def fake_run(cmd, *args, **kwargs):
|
||||
if isinstance(cmd, list):
|
||||
output_path = cmd[-1]
|
||||
with open(output_path, "wb") as handle:
|
||||
handle.write(b"RIFF....WAVEfmt ")
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout="", stderr="")
|
||||
|
||||
(out_dir / "test.txt").write_text("hello from local command\n", encoding="utf-8")
|
||||
return subprocess.CompletedProcess(cmd, 0, stdout="", stderr="")
|
||||
|
||||
monkeypatch.setattr("tools.transcription_tools.tempfile.TemporaryDirectory", fake_tempdir)
|
||||
monkeypatch.setattr("tools.transcription_tools._find_ffmpeg_binary", lambda: "/opt/homebrew/bin/ffmpeg")
|
||||
monkeypatch.setattr("tools.transcription_tools.subprocess.run", fake_run)
|
||||
|
||||
from tools.transcription_tools import _transcribe_local_command
|
||||
|
||||
result = _transcribe_local_command(sample_ogg, "base")
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["transcript"] == "hello from local command"
|
||||
assert result["provider"] == "local_command"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# _transcribe_local — additional tests
|
||||
# ============================================================================
|
||||
@@ -686,29 +612,6 @@ class TestTranscribeAudioDispatch:
|
||||
assert "faster-whisper" in result["error"]
|
||||
assert "GROQ_API_KEY" in result["error"]
|
||||
|
||||
def test_openai_provider_falls_back_to_local_command(self, monkeypatch, sample_ogg):
|
||||
monkeypatch.delenv("VOICE_TOOLS_OPENAI_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
monkeypatch.setenv(
|
||||
"HERMES_LOCAL_STT_COMMAND",
|
||||
"whisper {input_path} --model {model} --output_dir {output_dir} --language {language}",
|
||||
)
|
||||
|
||||
with patch("tools.transcription_tools._load_stt_config", return_value={"provider": "openai"}), \
|
||||
patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True), \
|
||||
patch("tools.transcription_tools._transcribe_local_command", return_value={
|
||||
"success": True,
|
||||
"transcript": "hello from fallback",
|
||||
"provider": "local_command",
|
||||
}) as mock_local_command:
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
result = transcribe_audio(sample_ogg)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["transcript"] == "hello from fallback"
|
||||
mock_local_command.assert_called_once_with(sample_ogg, "base")
|
||||
|
||||
def test_invalid_file_short_circuits(self):
|
||||
from tools.transcription_tools import transcribe_audio
|
||||
result = transcribe_audio("/nonexistent/audio.wav")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user